# Install and import packages
--------

In [None]:
# install/import quantum gym environments
!pip install git+https://github.com/qdevpsi3/quantum-arch-search.git

# install/import stable baselines 3
!pip install stable_baselines3

In [None]:
import gym
import torch.optim as optim
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy

import qas_gym

# Basic Environment
------
Create the gym environment:

In [None]:
# Parameters
# NOTE(review): the meaning of these three settings is defined by qas_gym;
# presumably a target fidelity for ending an episode, a per-step reward
# penalty and an episode length cap -- confirm against the qas_gym source.
env_name = 'BasicTwoQubit-v0'
fidelity_threshold = 0.95
reward_penalty = 0.01
max_timesteps = 20

# Environment: forward the parameters above to gym.make as keyword arguments.
env_kwargs = {
    'fidelity_threshold': fidelity_threshold,
    'reward_penalty': reward_penalty,
    'max_timesteps': max_timesteps,
}
env = gym.make(env_name, **env_kwargs)

Display the action gates:

In [None]:
# Print each entry of env.action_gates next to its zero-padded action index.
for action_id, action_gate in enumerate(env.action_gates):
    print(f'Action({action_id:02d}) --> {action_gate}')

Display the state observables:

In [None]:
# Print each entry of env.state_observables next to its zero-padded index.
for state_id, state_observable in enumerate(env.state_observables):
    print(f'State({state_id:02d}) --> {state_observable}')

# A2C Agent
------

In [None]:
# Parameters
gamma = 0.99            # discount factor
learning_rate = 0.0001
# Stable-Baselines3's A2C uses RMSprop unless an optimizer_class is supplied;
# passing Adam here overrides that default.
policy_kwargs = dict(optimizer_class=optim.Adam)

# Agent: collect the hyperparameters and hand them to A2C in one go.
# Training curves are written under logs/ for the TensorBoard cell.
a2c_hyperparams = dict(
    gamma=gamma,
    learning_rate=learning_rate,
    policy_kwargs=policy_kwargs,
    tensorboard_log='logs/',
)
a2c_model = A2C("MlpPolicy", env, **a2c_hyperparams)

In [None]:
# Train the A2C agent for a fixed budget of environment timesteps.
# NOTE(review): this budget is much smaller than the PPO run's, yet the
# Predict section rolls out this A2C agent -- confirm that is intended.
a2c_total_timesteps = 1000
a2c_model.learn(total_timesteps=a2c_total_timesteps)

# PPO Agent
------

In [None]:
# Parameters
gamma = 0.99            # discount factor
n_epochs = 4            # optimisation passes over each collected rollout
clip_range = 0.2        # PPO surrogate-objective clipping range
learning_rate = 0.0001
# Adam is already the default optimizer for PPO's policy in Stable-Baselines3;
# it is named explicitly to mirror the A2C configuration.
policy_kwargs = dict(optimizer_class=optim.Adam)

# Agent: collect the hyperparameters and hand them to PPO in one go.
# Training curves are written under logs/ for the TensorBoard cell.
ppo_hyperparams = dict(
    gamma=gamma,
    n_epochs=n_epochs,
    clip_range=clip_range,
    learning_rate=learning_rate,
    policy_kwargs=policy_kwargs,
    tensorboard_log='logs/',
)
ppo_model = PPO("MlpPolicy", env, **ppo_hyperparams)

In [None]:
# Train the PPO agent for a fixed budget of environment timesteps.
ppo_total_timesteps = 20000
ppo_model.learn(total_timesteps=ppo_total_timesteps)

# Results
------

In [None]:
# Load the TensorBoard notebook extension and open it on logs/, the directory
# passed as tensorboard_log to both the A2C and PPO agents.
%load_ext tensorboard
%tensorboard --logdir=logs/

# Predict
------

In [7]:
import time
from IPython.display import clear_output

# Roll out one episode with the trained A2C agent, redrawing the circuit after
# every step (with a one-second pause) so its construction can be followed.
# NOTE(review): assumes the pre-0.26 gym API that qas_gym targets
# (reset() -> obs, step() -> 4-tuple); confirm against the installed gym.
state = env.reset()
done = False
while not done:
    # predict() samples from the policy distribution by default; with
    # deterministic=True it takes the most likely action, so the rendered
    # circuit reflects the learned policy rather than exploration noise.
    # The second return value (recurrent hidden state) is unused here.
    action, _ = a2c_model.predict(state, deterministic=True)
    state, reward, done, info = env.step(action)

    clear_output(wait=True)
    env.render()
    time.sleep(1)


0: ───I───Z───Z───Rz(0.25π)───@───Y───X───Y───Z───────
                              │       │
1: ───I───H───Z───────────────X───────@───Z───H───H───
