In [1]:
import gym
import gym.spaces
import numpy as np
from pyquil import get_qc, Program
from pyquil.api import WavefunctionSimulator
from pyquil.gates import *
from scipy.interpolate import interp1d


# Do a sanity check first

# identify discrete gates on qubit 0
num_angles = 5
qubits = 2
angles = np.linspace(0.0, 2*np.pi, num_angles)
gates = [RY(theta, q) for theta in angles for q in range(qubits)]
gates += [RZ(theta, q) for theta in angles for q in range(qubits)]
gates += [CNOT(0, 1), CNOT(1, 0)] #, CNOT(1, 2), CNOT(2, 1), CNOT(0, 2), CNOT(2, 0)]

# wfn_sim = WavefunctionSimulator()
# for g in gates:
#     p = Program(g)
#     wfn = wfn_sim.wavefunction(p)
#     amps = wfn.amplitudes
#     if np.allclose(amps, np.sqrt(np.array([0.5, 0.5])), atol=1e-2):
#         print("Found |+> state!!")
#         print(p)
#         print("*" * 30)

In [2]:
gates

[<Gate RY(0) 0>,
 <Gate RY(0) 1>,
 <Gate RY(pi/2) 0>,
 <Gate RY(pi/2) 1>,
 <Gate RY(pi) 0>,
 <Gate RY(pi) 1>,
 <Gate RY(3*pi/2) 0>,
 <Gate RY(3*pi/2) 1>,
 <Gate RY(2*pi) 0>,
 <Gate RY(2*pi) 1>,
 <Gate RZ(0) 0>,
 <Gate RZ(0) 1>,
 <Gate RZ(pi/2) 0>,
 <Gate RZ(pi/2) 1>,
 <Gate RZ(pi) 0>,
 <Gate RZ(pi) 1>,
 <Gate RZ(3*pi/2) 0>,
 <Gate RZ(3*pi/2) 1>,
 <Gate RZ(2*pi) 0>,
 <Gate RZ(2*pi) 1>,
 <Gate CNOT 0 1>,
 <Gate CNOT 1 0>]

In [30]:
def bell_state(qb):
    wfn_sim = WavefunctionSimulator()
    _program = Program(RY(np.pi/2, 0))
    for i in range(qb-1):
        _program.inst(CNOT(i, i+1))
    _program.inst(RY(np.pi, 1))
    _wfn = wfn_sim.wavefunction(_program)
    dm = np.outer(_wfn.amplitudes, _wfn.amplitudes)
    state = np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2)
    return _wfn

class OneQEnv(gym.Env):
    def __init__(self, gamma=0.8, max_steps=20, qubits=2):
        self.interp = interp1d([-1.001, 1.001], [0, 255])
        self.qubits = qubits
        self.goal = bell_state(self.qubits)
        # discount factor
        self.gamma = gamma
        # initialize a WavefunctionSimulator
        self.wfn_sim = WavefunctionSimulator()
        # identify the observation and action spaces
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(2**self.qubits, 2**self.qubits, 2), dtype=float)
        self._actions = gates
        self.action_space = gym.spaces.Discrete(len(self._actions))
        # the state will be the wavefunction probs
        p = Program()
        for i in range(self.qubits):
            p.inst(I(i))
        self._program = p
        self._wfn = self.wfn_sim.wavefunction(self._program)
#         self.state = self._wfn.amplitudes
        dm = np.outer(self._wfn.amplitudes, self._wfn.amplitudes)
        self.state = self.interp(np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2))
        self.current_step = 0
        self.max_steps = max_steps
        self.info = {}
        
    def step(self, action):
        gate = self._actions[action]
        self._program += gate
        self._wfn = self.wfn_sim.wavefunction(self._program)
        dm = np.outer(self._wfn.amplitudes, self._wfn.amplitudes)
        self.state = self.interp(np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2))
        self.current_step += 1
        # detect if found terminal state
        # if np.allclose(self.state, self.goal, atol=1e-2):
        #     reward = 1.0
        #     done = True

        # elif self.current_step >= self.max_steps:
        #     reward = 0.0
        #     done = True
            
        # else:
        #     reward = 0.0
        #     done = False
        reward = abs(self._wfn.amplitudes.T.conj() @ self.goal.amplitudes)**2
        if reward > 0.98:
            done = True
            reward = 1
        elif self.current_step >= self.max_steps:
            done = True
            reward = 0
        else:
            done = False
            reward = 0

        return self.state, reward, done, self.info
    
    def reset(self):
        p = Program()
        for i in range(self.qubits):
            p.inst(I(i))
        self._program = p
        self._wfn = self.wfn_sim.wavefunction(self._program)
        dm = np.outer(self._wfn.amplitudes, self._wfn.amplitudes)
        self.state = self.interp(np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2))
        self.current_step = 0
        
        return self.state

In [4]:
from stable_baselines.deepq.policies import CnnPolicy
from stable_baselines.common.policies import ActorCriticPolicy, MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2, DQN
from stable_baselines.common.tf_layers import conv, linear, conv_to_fc, lstm
import tensorflow as tf

In [5]:
def custom_cnn(scaled_images, **kwargs):
    activ = tf.nn.relu
    layer_1 = activ(conv(scaled_images, 'c1', n_filters=8, filter_size=2, stride=2, init_scale=np.sqrt(2), **kwargs))
    # layer_2 = activ(conv(layer_1, 'c2', n_filters=16, filter_size=2, stride=2, init_scale=np.sqrt(2), **kwargs))
    # layer_3 = activ(conv(layer_2, 'c3', n_filters=64, filter_size=3, stride=1, init_scale=np.sqrt(2), **kwargs))
    layer_3 = conv_to_fc(layer_1)
    layer_4 = activ(linear(layer_3, 'fc1', n_hidden=128, init_scale=np.sqrt(2)))
    return activ(linear(layer_4, 'fc2', n_hidden=64, init_scale=np.sqrt(2)))

class CustomPolicy(ActorCriticPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None, net_arch=None,
                 act_fun=tf.tanh, cnn_extractor=custom_cnn, feature_extraction="cnn", **kwargs):
        super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
                                                scale=(feature_extraction == "cnn"))

        self._kwargs_check(feature_extraction, kwargs)

        with tf.variable_scope("model", reuse=reuse):
            if feature_extraction == "cnn":
                pi_latent = vf_latent = cnn_extractor(self.processed_obs, **kwargs)

            self._value_fn = linear(vf_latent, 'vf', 1)

            self._proba_distribution, self._policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)

        self._setup_init()

    def step(self, obs, state=None, mask=None, deterministic=False):
        if deterministic:
            action, value, neglogp = self.sess.run([self.deterministic_action, self.value_flat, self.neglogp],
                                                   {self.obs_ph: obs})
        else:
            action, value, neglogp = self.sess.run([self.action, self.value_flat, self.neglogp],
                                                   {self.obs_ph: obs})
        return action, value, self.initial_state, neglogp

    def proba_step(self, obs, state=None, mask=None):
        return self.sess.run(self.policy_proba, {self.obs_ph: obs})

    def value(self, obs, state=None, mask=None):
        return self.sess.run(self.value_flat, {self.obs_ph: obs})

In [33]:
env = OneQEnv()
env_vec = DummyVecEnv([lambda: env])

model = PPO2(CustomPolicy, env_vec, verbose=1)
model.learn(total_timesteps=60000)

# obs = env_vec.reset()
# for i in range(1000):
#     action, _states = model.predict(obs)
#     obs, rewards, dones, info = env_vec.step(action)

--------------------------------------
| approxkl           | 2.3025823e-05 |
| clipfrac           | 0.0           |
| explained_variance | -0.0257       |
| fps                | 39            |
| n_updates          | 1             |
| policy_entropy     | 3.0910146     |
| policy_loss        | -0.0024622604 |
| serial_timesteps   | 128           |
| time_elapsed       | 0             |
| total_timesteps    | 128           |
| value_loss         | 0.053438015   |
--------------------------------------
---------------------------------------
| approxkl           | 1.359204e-05   |
| clipfrac           | 0.0            |
| explained_variance | -0.174         |
| fps                | 48             |
| n_updates          | 2              |
| policy_entropy     | 3.0909023      |
| policy_loss        | -0.00096567185 |
| serial_timesteps   | 256            |
| time_elapsed       | 3.26           |
| total_timesteps    | 256            |
| value_loss         | 0.00332289     |
-------------

<stable_baselines.ppo2.ppo2.PPO2 at 0x24887b51860>

In [8]:
wfn_sim = WavefunctionSimulator()

In [20]:
def wfn_to_dm(wfn):
    dm = np.outer(wfn.amplitudes, wfn.amplitudes)
    state = np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2)
    interp = interp1d([-1.001, 1.001], [0, 255])
    return interp(state)

In [91]:
# program = Program(I(0)).inst(I(1))
# init_wfn = wfn_sim.wavefunction(program)
# dm = np.outer(init_wfn, init_wfn)
# state = wfn_to_dm(init_wfn)
# print(state)
# optimal_action, next_state = model.predict(state)
# prog = Program(gates[optimal_action])
# wfn = wfn_sim.wavefunction(prog)
# print(wfn)
# print(prog)

In [34]:
done = False
env.reset()
prog = Program(I(0)).inst(I(1))
wfn = wfn_sim.wavefunction(prog)
obs = wfn_to_dm(wfn)

while not done:
    optimal_action, _ = model.predict(obs)
    print(gates[optimal_action])
    prog += gates[optimal_action]
    obs, rewards, done, info = env.step(optimal_action)
    
wfn = wfn_sim.wavefunction(prog)
print(f"Wavefunction: {wfn}")
print(f"Density Matrix: {wfn_to_dm(wfn)}")

RY(pi/2) 0
CNOT 0 1
RY(pi/2) 0
RY(pi/2) 0
Wavefunction: (0.7071067812+0j)|01> + (-0.7071067812+0j)|10>
Density Matrix: [[[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [191.18631369 127.5       ]
  [ 63.81368631 127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [ 63.81368631 127.5       ]
  [191.18631369 127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]]


In [55]:
wfn_sim = WavefunctionSimulator()
_program = Program(RY(np.pi/2, 0)).inst(CNOT(0,1))
_wfn = wfn_sim.wavefunction(_program)
#         self.state = self._wfn.amplitudes
dm = np.outer(_wfn.amplitudes, _wfn.amplitudes)
state = np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2)

In [56]:
state

array([[[0.5, 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0.5, 0. ]],

       [[0. , 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0. , 0. ]],

       [[0. , 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0. , 0. ]],

       [[0.5, 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0.5, 0. ]]])

In [68]:
abs(wfn.amplitudes.T.conj() @ bell_state(2))

0.7071067811865476