In [1]:
import gym
import gym.spaces
import numpy as np
from pyquil import get_qc, Program
from pyquil.api import WavefunctionSimulator
from pyquil.gates import *


# Sanity check: enumerate a small discrete gate set on qubit 0 and confirm
# that at least one gate maps |0> to |+> = (|0> + |1>)/sqrt(2).

# Discrete action set used by the environment below: RY and RZ rotations at
# `num_angles` evenly spaced angles in [0, 2*pi]. Both endpoints are included,
# so RY(0)/RZ(0) are the identity and RY(2*pi)/RZ(2*pi) equal -I (identity up
# to a global phase).
num_angles = 5
angles = np.linspace(0.0, 2 * np.pi, num_angles)
gates = [RY(theta, 0) for theta in angles]
gates += [RZ(theta, 0) for theta in angles]

plus_state = np.array([1.0, 1.0]) / np.sqrt(2)
fidelity_tol = 1e-4  # accept states with |<+|psi>|^2 >= 1 - fidelity_tol

wfn_sim = WavefunctionSimulator()
for g in gates:
    p = Program(g)
    wfn = wfn_sim.wavefunction(p)
    amps = wfn.amplitudes
    # Compare by fidelity rather than amplitude-wise so the check ignores
    # global phase (-|+> is the same physical state as |+>).
    if np.abs(np.vdot(plus_state, amps)) ** 2 >= 1.0 - fidelity_tol:
        print("Found |+> state!!")
        print(p)
        print("*" * 30)

Found |+> state!!
RY(pi/2) 0

******************************


In [4]:
# Display the discrete action set; the environment's action index i applies gates[i].
gates

[<Gate RY(0) 0>,
 <Gate RY(pi/2) 0>,
 <Gate RY(pi) 0>,
 <Gate RY(3*pi/2) 0>,
 <Gate RY(2*pi) 0>,
 <Gate RZ(0) 0>,
 <Gate RZ(pi/2) 0>,
 <Gate RZ(pi) 0>,
 <Gate RZ(3*pi/2) 0>,
 <Gate RZ(2*pi) 0>]

In [27]:
class OneQEnv(gym.Env):
    """Single-qubit state-preparation environment.

    Each action appends one gate from a fixed discrete set to a growing pyQuil
    program on qubit 0. The observation is the resulting wavefunction flattened
    as ``[Re(a0), Re(a1), Im(a0), Im(a1)]``. An episode ends with reward 1.0
    when the state matches ``target`` up to a global phase, or with reward 0.0
    once ``max_steps`` gates have been applied.

    Parameters
    ----------
    gamma : float
        Stored as ``self.gamma``; not read anywhere in this class.
    max_steps : int
        Maximum number of gates per episode.
    actions : sequence of pyquil gates, optional
        Discrete action set. Defaults to the notebook-level ``gates`` list so
        that action indices line up with ``gates[...]`` used in later cells.
    target : array_like of complex, optional
        Target amplitudes ``[a0, a1]`` (normalised internally). Defaults to
        ``(0.5-0.5j)|0> + (0.5+0.5j)|1>``.
    fidelity_tol : float
        The goal is reached when ``|<target|psi>|^2 >= 1 - fidelity_tol``.
    """

    def __init__(self, gamma=0.8, max_steps=20, actions=None, target=None,
                 fidelity_tol=1e-4):
        self.gamma = gamma
        self.wfn_sim = WavefunctionSimulator()
        # 1 qubit -> 2 complex amplitudes -> 4 real numbers, each in [-1, 1].
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=float)
        # NOTE: falls back to the global `gates` list defined in the first cell.
        self._actions = list(gates if actions is None else actions)
        if not self._actions:
            raise ValueError("OneQEnv needs at least one action")
        self.action_space = gym.spaces.Discrete(len(self._actions))
        if target is None:
            target = [0.5 - 0.5j, 0.5 + 0.5j]
        target = np.asarray(target, dtype=complex)
        self.target = target / np.linalg.norm(target)
        self.fidelity_tol = fidelity_tol
        self.max_steps = max_steps
        self.info = {}
        # Initialise _program, _wfn, state and current_step.
        self.reset()

    def _observe(self):
        """Simulate the current program, cache it, and return the flattened state."""
        self._wfn = self.wfn_sim.wavefunction(self._program)
        amps = self._wfn.amplitudes
        self.state = np.concatenate([amps.real, amps.imag])
        return self.state

    def _at_goal(self):
        """Return True when the current wavefunction matches the target up to global phase."""
        fidelity = np.abs(np.vdot(self.target, self._wfn.amplitudes)) ** 2
        return fidelity >= 1.0 - self.fidelity_tol

    def step(self, action):
        """Append gate ``self._actions[action]`` and return ``(state, reward, done, info)``."""
        self._program += self._actions[action]
        self._observe()
        self.current_step += 1
        if self._at_goal():
            reward, done = 1.0, True
        elif self.current_step >= self.max_steps:
            reward, done = 0.0, True
        else:
            reward, done = 0.0, False
        # Fresh dict per step: vec-env wrappers may write keys into info
        # (e.g. a terminal observation), which must not leak into later steps.
        self.info = {}
        return self.state, reward, done, self.info

    def reset(self):
        """Restart from |0> (an identity-only program) and return the initial state."""
        self._program = Program(I(0))
        self.current_step = 0
        return self._observe()

In [34]:
# from stable_baselines.common.policies import MlpPolicy
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2, DQN

In [35]:
# Train a DQN agent on the single-qubit environment. DummyVecEnv wraps the
# same `env` instance, which the rollout cell further down reuses.
SEED = 0  # fixed seed so the training run is reproducible

env = OneQEnv()
env_vec = DummyVecEnv([lambda: env])

# NOTE(review): env.gamma (0.8) is not passed to the agent; DQN uses its own
# `gamma` argument. Pass gamma=env.gamma here if that was the intent.
# n_cpu_tf_sess=1 is needed alongside `seed` for fully deterministic results.
model = DQN(MlpPolicy, env_vec, verbose=0, seed=SEED, n_cpu_tf_sess=1)
model.learn(total_timesteps=20000);



<stable_baselines.deepq.dqn.DQN at 0x24e902cdfd0>

In [36]:
# Simulator used by the policy-inspection cells below to print wavefunctions.
wfn_sim = WavefunctionSimulator()

In [37]:
# Query the trained policy from the environment's initial state |0>, whose
# observation [Re(a0), Re(a1), Im(a0), Im(a1)] is [1, 0, 0, 0].
initial_obs = np.array([1.0, 0.0, 0.0, 0.0])
# predict() returns (action, policy_state); the second value is a recurrent
# policy state, NOT the next environment state, so it is discarded.
optimal_action, _ = model.predict(initial_obs, deterministic=True)
prog = Program(gates[optimal_action])
wfn = wfn_sim.wavefunction(prog)
print(wfn)

(0.7071067812+0j)|0> + (0.7071067812+0j)|1>


In [38]:

# Roll out the trained greedy policy for one episode. The chosen gates are
# mirrored into `prog` so the final wavefunction can be printed independently.
obs = env.reset()
prog = Program()
done = False

while not done:
    optimal_action, _ = model.predict(obs, deterministic=True)
    gate = gates[optimal_action]
    print(gate)
    prog += gate
    obs, reward, done, info = env.step(optimal_action)
    print(obs, reward, done)

wfn = wfn_sim.wavefunction(prog)
print(f"Wavefunction: {wfn}")

RY(pi/2) 0
[0.70710678 0.70710678 0.         0.        ] 0.0 False
RZ(pi/2) 0
[ 0.5  0.5 -0.5  0.5] 1.0 True
Wavefunction: (0.5-0.5j)|0> + (0.5+0.5j)|1>


In [8]:
# Two-qubit check: H on qubit 0 followed by CNOT(0 -> 1); the printed
# amplitudes show equal weight on |00> and |11>.
wfn_sim = WavefunctionSimulator()
p = Program(H(0), CNOT(0, 1))

wfn = wfn_sim.wavefunction(p)
print(wfn.amplitudes)

[0.70710678+0.j 0.        +0.j 0.        +0.j 0.70710678+0.j]


In [9]:
# Show the Quil source of the program `p` built in the previous cell.
print(p)

H 0
CNOT 0 1

