In [35]:
import gym
import gym.spaces
import numpy as np
from pyquil import get_qc, Program
from pyquil.api import WavefunctionSimulator
from pyquil.gates import *
from scipy.interpolate import interp1d


# Assemble the discrete action set the agent chooses from: each single-qubit
# gate applied to each qubit, followed by both CNOT orientations.

num_angles = 5
qubits = 2
angles = np.linspace(0.0, 2 * np.pi, num_angles)

# Grouped by gate type: all H gates first, then all I gates, then all T gates.
gates = [gate(q) for gate in (H, I, T) for q in range(qubits)]
# Alternative rotation-based action set (disabled); this is the only user of `angles`:
# gates = [RY(theta, q) for theta in angles for q in range(qubits)]
# gates += [RZ(theta, q) for theta in angles for q in range(qubits)]
# Further pairs listed for a third qubit: CNOT(1, 2), CNOT(2, 1), CNOT(0, 2), CNOT(2, 0)
gates += [CNOT(0, 1), CNOT(1, 0)]

# wfn_sim = WavefunctionSimulator()
# for g in gates:
#     p = Program(g)
#     wfn = wfn_sim.wavefunction(p)
#     amps = wfn.amplitudes
#     if np.allclose(amps, np.sqrt(np.array([0.5, 0.5])), atol=1e-2):
#         print("Found |+> state!!")
#         print(p)
#         print("*" * 30)

In [36]:
# Display the assembled action set; list positions are the action indices used by the environment.
gates

[Gate { name: "H", parameters: [], qubits: [Fixed(0)], modifiers: [] },
 Gate { name: "H", parameters: [], qubits: [Fixed(1)], modifiers: [] },
 Gate { name: "I", parameters: [], qubits: [Fixed(0)], modifiers: [] },
 Gate { name: "I", parameters: [], qubits: [Fixed(1)], modifiers: [] },
 Gate { name: "T", parameters: [], qubits: [Fixed(0)], modifiers: [] },
 Gate { name: "T", parameters: [], qubits: [Fixed(1)], modifiers: [] },
 Gate { name: "CNOT", parameters: [], qubits: [Fixed(0), Fixed(1)], modifiers: [] },
 Gate { name: "CNOT", parameters: [], qubits: [Fixed(1), Fixed(0)], modifiers: [] }]

In [65]:
def bell_state(qb):
    """Simulate a CNOT-chain entangled state on ``qb`` qubits.

    The circuit applies ``RY(pi/2)`` to qubit 0 and then ``CNOT(i, i+1)`` for
    ``i = 0 .. qb-2``.

    Args:
        qb: Number of qubits in the chain. For ``qb <= 1`` no CNOT is applied.

    Returns:
        The wavefunction object produced by ``WavefunctionSimulator.wavefunction``
        for this circuit; its amplitude vector is available as ``.amplitudes``.
    """
    wfn_sim = WavefunctionSimulator()
    _program = Program(RY(np.pi/2, 0))
    for i in range(qb - 1):
        _program.inst(CNOT(i, i + 1))
    # Only the wavefunction is returned. The density-matrix encoding that used
    # to be computed here was never used, so it has been removed.
    return wfn_sim.wavefunction(_program)

def swap_gate(qb):
    """Simulate three alternating CNOTs on qubits 0 and 1 and return the wavefunction.

    The gate sequence is ``CNOT(0, 1)``, ``CNOT(1, 0)``, ``CNOT(0, 1)``, applied
    to an otherwise empty program.

    Args:
        qb: Accepted for signature parity with ``bell_state``; not used here.

    Returns:
        The wavefunction object produced by ``WavefunctionSimulator.wavefunction``.

    NOTE(review): the saved notebook output of ``print(swap_gate(0))`` is
    ``(1+0j)|00>``, the same as the environment's printed initial state.
    Confirm this is the intended target.
    """
    circuit = Program()
    for control, target in ((0, 1), (1, 0), (0, 1)):
        circuit += CNOT(control, target)
    simulator = WavefunctionSimulator()
    return simulator.wavefunction(circuit)

class OneQEnv(gym.Env):
    """Gym environment where each action appends one gate to a growing circuit.

    After every step the whole program is re-simulated. The observation is the
    density matrix ``|psi><psi|`` of the resulting wavefunction, split into a
    real and an imaginary channel and rescaled linearly from ``[-1.001, 1.001]``
    to ``[0, 255]``. An episode ends when the squared overlap with ``self.goal``
    exceeds 0.98 (reward 1) or after ``max_steps`` steps (reward 0).

    Args:
        gamma: Discount factor; stored on the instance but not used by this class.
        max_steps: Maximum number of gates per episode.
        qubits: Number of qubits; fixes the observation shape
            ``(2**qubits, 2**qubits, 2)``.

    NOTE(review): the goal comes from ``swap_gate``. The saved notebook output
    shows both the goal and the initial state as ``(1+0j)|00>``. Confirm the
    target is the intended one (``bell_state(self.qubits)`` is the alternative
    that was previously tried).
    """

    def __init__(self, gamma=0.8, max_steps=20, qubits=2):
        # Linear map from [-1.001, 1.001] to [0, 255]; the small margin keeps
        # values of exactly +/-1 inside the interpolation range.
        self.interp = interp1d([-1.001, 1.001], [0, 255])
        self.qubits = qubits
        self.goal = swap_gate(self.qubits)
        self.gamma = gamma
        self.wfn_sim = WavefunctionSimulator()

        dim = 2 ** self.qubits
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(dim, dim, 2), dtype=float)
        # The action index is a position in the module-level `gates` list.
        self._actions = gates
        self.action_space = gym.spaces.Discrete(len(self._actions))

        self.max_steps = max_steps
        self.info = {}
        self._start_episode()
        print("initial state", self._wfn)

    def _encode(self, wfn):
        """Return the rescaled (real, imag) density-matrix observation for ``wfn``."""
        amplitudes = wfn.amplitudes
        # np.outer does not conjugate its second argument, so do it explicitly
        # to obtain |psi><psi| rather than psi psi^T.
        dm = np.outer(amplitudes, np.conj(amplitudes))
        return self.interp(np.stack([dm.real, dm.imag], axis=-1))

    def _start_episode(self):
        """Reset the program to one identity gate per qubit and refresh the state."""
        p = Program()
        for i in range(self.qubits):
            p.inst(I(i))
        self._program = p
        self._wfn = self.wfn_sim.wavefunction(self._program)
        self.state = self._encode(self._wfn)
        self.current_step = 0

    def step(self, action):
        """Append gate ``action`` to the circuit and return ``(state, reward, done, info)``."""
        self._program += self._actions[action]
        self._wfn = self.wfn_sim.wavefunction(self._program)
        self.state = self._encode(self._wfn)
        self.current_step += 1

        # Squared magnitude of the overlap between the current and goal amplitudes.
        fidelity = abs(self._wfn.amplitudes.T.conj() @ self.goal.amplitudes)**2
        # Sparse reward: only reaching the goal pays out; hitting the step
        # limit just ends the episode.
        if fidelity > 0.98:
            done = True
            reward = 1
        elif self.current_step >= self.max_steps:
            done = True
            reward = 0
        else:
            done = False
            reward = 0

        return self.state, reward, done, self.info

    def reset(self):
        """Start a new episode and return the initial observation."""
        self._start_episode()
        return self.state

In [66]:
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import PPO
import torch
import torch.nn as nn

In [67]:
class CustomPolicy(ActorCriticPolicy):
    """``ActorCriticPolicy`` subclass that adds ``step`` and ``value`` convenience methods.

    Construction is forwarded unchanged to the parent class.
    """

    def __init__(self, ob_space, ac_space, lr, **kwargs):
        # Pure pass-through: the three positional arguments and any keyword
        # arguments go straight to the parent constructor.
        super().__init__(ob_space, ac_space, lr, **kwargs)

    def step(self, obs, deterministic=False):
        """Call the policy itself on ``obs`` and return whatever that call returns."""
        return self(obs, deterministic)

    def value(self, obs):
        """Return the result of ``self.predict_values(obs)``."""
        return self.predict_values(obs)

In [68]:
# Build a single environment instance and wrap it in a vectorised env.
env = OneQEnv()
# The lambda returns the already-built `env`, so `env` and `env_vec` refer to
# the same underlying environment object.
env_vec = DummyVecEnv([lambda: env])

# Train PPO with the custom policy; TensorBoard logs go to ./circuit_rl_tensorboard/.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt, so the
# recorded run appears to have been stopped before 60000 timesteps — confirm.
model = PPO(CustomPolicy, env_vec, verbose=1, tensorboard_log="./circuit_rl_tensorboard/")
model.learn(total_timesteps=60000)

# obs = env_vec.reset()
# for i in range(1000):
#     action, _states = model.predict(obs)
#     obs, rewards, dones, info = env_vec.step(action)

initial state (1+0j)|00>
[[1.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j]]
[[[254.87262737 127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]]
Using cpu device
Logging to ./circuit_rl_tensorboard/PPO_5




KeyboardInterrupt: 

In [52]:
# Module-level simulator used by the inspection cells outside the environment.
wfn_sim = WavefunctionSimulator()

In [53]:
def wfn_to_dm(wfn):
    """Encode a wavefunction as a rescaled two-channel density matrix.

    Args:
        wfn: Object exposing a 1-D ``amplitudes`` vector of length ``n``.

    Returns:
        Array of shape ``(n, n, 2)``: channel 0 holds the real part and
        channel 1 the imaginary part of ``|psi><psi|``, each mapped linearly
        from ``[-1.001, 1.001]`` to ``[0, 255]``.
    """
    amplitudes = wfn.amplitudes
    # np.outer does not conjugate its second argument; without the explicit
    # conjugate the result is psi psi^T, not the density matrix psi psi^dagger.
    dm = np.outer(amplitudes, np.conj(amplitudes))
    state = np.stack([dm.real, dm.imag], axis=-1)
    interp = interp1d([-1.001, 1.001], [0, 255])
    return interp(state)

In [54]:
# program = Program(I(0)).inst(I(1))
# init_wfn = wfn_sim.wavefunction(program)
# dm = np.outer(init_wfn, init_wfn)
# state = wfn_to_dm(init_wfn)
# print(state)
# optimal_action, next_state = model.predict(state)
# prog = Program(gates[optimal_action])
# wfn = wfn_sim.wavefunction(prog)
# print(wfn)
# print(prog)

In [55]:
# Roll out the trained policy for one episode, mirroring every chosen gate
# into `prog` so the resulting circuit can be simulated and printed afterwards.
done = False
# Feed the policy the observation produced by the environment itself, rather
# than rebuilding the initial observation through a separate encoder.
obs = env.reset()
# Same starting circuit as the environment: one identity gate on each of the two qubits.
prog = Program(I(0)).inst(I(1))

while not done:
    optimal_action, _ = model.predict(obs)
    print(gates[optimal_action])
    prog += gates[optimal_action]
    obs, rewards, done, info = env.step(optimal_action)

wfn = wfn_sim.wavefunction(prog)
print(f"Wavefunction: {wfn}")
print(f"Density Matrix: {wfn_to_dm(wfn)}")

T 0
Wavefunction: (1+0j)|00>
Density Matrix: [[[254.87262737 127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]

 [[127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]
  [127.5        127.5       ]]]


In [56]:
# Print the target wavefunction used by the environment; swap_gate does not use its argument.
print(swap_gate(0))

(1+0j)|00>


In [45]:
# Scratch check: simulate RY(pi/2) on qubit 0 followed by CNOT(0, 1) and build
# the unscaled (real, imag) channel representation of its density matrix.
wfn_sim = WavefunctionSimulator()
_program = Program(RY(np.pi/2, 0)).inst(CNOT(0,1))
_wfn = wfn_sim.wavefunction(_program)
# np.outer does not conjugate its second argument, so conjugate explicitly to
# get |psi><psi|.
dm = np.outer(_wfn.amplitudes, np.conj(_wfn.amplitudes))
state = np.moveaxis(np.stack([dm.real, dm.imag], axis=0), 0, 2)

In [46]:
# Inspect the unscaled (real, imag) channel array of the density matrix.
state

array([[[0.5, 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0.5, 0. ]],

       [[0. , 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0. , 0. ]],

       [[0. , 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0. , 0. ]],

       [[0.5, 0. ],
        [0. , 0. ],
        [0. , 0. ],
        [0.5, 0. ]]])

In [47]:
# Overlap magnitude between the most recent `wfn` and the 2-qubit bell_state
# target. bell_state returns a wavefunction object, so its amplitude vector is
# taken explicitly, as OneQEnv.step does with self.goal.amplitudes.
abs(wfn.amplitudes.T.conj() @ bell_state(2).amplitudes)

0.9999999999999999