In [30]:
%load_ext autoreload
%autoreload 2

import numpy as np
import gymnasium as gym
import os
import qiskit
from gymnasium import spaces
from stable_baselines3 import PPO, A2C, DQN, TD3, SAC
from stable_baselines3.common.env_util import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecNormalize

from qiskit.quantum_info import random_density_matrix, random_statevector, DensityMatrix
from adaptive_qst.plotting import PlotOneQubit
from adaptive_qst.max_info import Posterior, HiddenState
#from adaptive_qst.rl_qst import AQSTEnv
import matplotlib.pyplot as plt
from numpy import pi
from qiskit.quantum_info import state_fidelity

from numpy import sqrt
from numpy.linalg import cholesky

from sb3_contrib import RecurrentPPO

import os

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
class RAQST(gym.Env):
    
    def __init__(self, n_particles = 30, n_measurements = 1000, hidden_state = None, reward_start = 100):
        super(RAQST, self).__init__()
        
        self.n_particles = n_particles
        self.n_measurements = n_measurements
        self.posterior = Posterior(self.n_particles)
        
        self.change_state = (hidden_state is None)
        self.hidden_state = hidden_state
        if self.change_state:
            self.hidden_state = HiddenState()
    
        self.step_num = 0
        self.reward_start = reward_start
        
        ##Observations: Orientation of the measured axis
        self.observation_space = gym.spaces.Box(low = -1, high = 1, shape = (2,))
        
        ##Actions: Orientation of the measured axis
        self.action_space = gym.spaces.Box(low= -1, high = 1, shape = (2,))  ##Orientation of measurement

    def step(self, action):
        
        config = (action + 1) * pi / 2
        result = self.hidden_state.measure_along_axis(config)
        
        
        self.posterior.update(config, result)
        fidelity = state_fidelity(self.hidden_state.hidden_state, self.posterior.get_best_guess())
        
        reward = -np.log(1 - fidelity) / np.log(self.step_num + 8)

        self.step_num += 1
        truncated = (self.step_num >= self.n_measurements)
        terminated = False

        return (self.get_observation(config, result), 
                reward, 
                terminated, 
                truncated, 
                {})

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)

        self.posterior = Posterior(self.n_particles)
        if not self.change_state:
            self.hidden_state = HiddenState()
            
        self.step_num = 0
        
        return np.array([0, 0]).astype(np.float32), {}
    
    #Package complex density matrix and weights into observation vector (x, y, z positions)
    def get_observation(self, config, result):
        
        if result == 0:
            obs = config
        
        ##Otherwise, measured in the opposite direction:
        else:
            obs = np.array([pi - config[0], config[1] - pi])
            
        #Normalize
        obs /= pi
        obs[0] = 2 * obs[0] - 1

        return obs.astype(np.float32)

In [52]:
env = RAQST()
check_env(env, warn=True)

In [53]:
##Test the Environment with random actions:
env = RAQST()
obs, _ = env.reset()

for _ in range(10):
    action = env.action_space.sample()
    obs, rewards, truncated, terminated, info = env.step(action)
    
    print(action)
    print(obs)
    print()

[0.37932843 0.6402268 ]
[0.3793285  0.82011336]

[-0.4464065  0.1054719]
[ 0.44640645 -0.44726405]

[ 0.49388504 -0.3730353 ]
[0.49388492 0.31348234]

[ 0.7541785 -0.7234208]
[0.7541784 0.1382896]

[0.15258603 0.17683956]
[0.15258598 0.5884198 ]

[0.92862636 0.00646803]
[-0.9286264  -0.49676597]

[0.75920683 0.97662836]
[0.75920665 0.98831415]

[ 0.4609135 -0.4513865]
[-0.46091357 -0.7256932 ]

[-0.58582515 -0.5354774 ]
[-0.58582515  0.2322613 ]

[0.21020266 0.6387054 ]
[0.2102027 0.8193527]



In [54]:
model_save_dir = "models/lstm_qst"
tb_log_dir = "tb_logs/lstm_qst/1_qubit"

os.makedirs(model_save_dir, exist_ok = True)
os.makedirs(tb_log_dir, exist_ok = True)

In [50]:
run_name = "PPO_default"
model_save_path = f"{model_save_dir}/{run_name}"

batch_size = 32
n_measurements = 1000
n_train_episodes = 10
eval_episode_freq = 1
n_eval_episodes = 10

train_env = DummyVecEnv([lambda: Monitor(RAQST(n_measurements = n_measurements)) for _ in range(batch_size)])
#train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True)

eval_env = DummyVecEnv([lambda: Monitor(RAQST(n_measurements = n_measurements))])
#eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False)

eval_callback = EvalCallback(eval_env, best_model_save_path=model_save_path,
                             deterministic=True, render=False, n_eval_episodes = n_eval_episodes,
                             eval_freq = eval_episode_freq * n_measurements)


# model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)

model = RecurrentPPO("MlpLstmPolicy", train_env, tensorboard_log = tb_log_dir)
model.learn(batch_size * n_measurements * n_train_episodes, tb_log_name=run_name, callback = eval_callback)

KeyboardInterrupt: 

In [48]:
##Load best model and evaluate
model = RecurrentPPO.load(f"{model_save_path}/best_model")

model = RecurrentPPO("MlpLstmPolicy", train_env, tensorboard_log = tb_log_dir)

obs = eval_env.reset()
# cell and hidden state of the LSTM
lstm_states = None
num_envs = 1
# Episode start signals are used to reset the lstm states
episode_starts = np.ones((num_envs,), dtype=bool)

for i in range(100):
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, info = eval_env.step(action)
    episode_starts = dones

    print(action)
    

[[ 1.3609309e-04 -8.0317521e-05]]
[[ 0.00027938 -0.00026112]]
[[ 0.00036952 -0.00035382]]
[[ 0.00042778 -0.00039793]]
[[ 0.00046694 -0.00041662]]
[[ 0.0004484  -0.00013724]]
[[4.631229e-04 2.123799e-05]]
[[0.0004875  0.00011202]]
[[0.0005083  0.00016208]]
[[ 5.654755e-04 -9.257115e-05]]
[[ 0.00057907 -0.00023797]]
[[ 0.00057369 -0.00032093]]
[[ 0.00056534 -0.00036646]]
[[ 0.00055851 -0.0003909 ]]
[[ 0.0005085  -0.00011904]]
[[ 0.00054603 -0.00025027]]
[[ 5.151750e-04 -3.922511e-05]]
[[ 0.00055505 -0.00020613]]
[[ 5.2276300e-04 -1.5825279e-05]]
[[5.1656814e-04 8.9213645e-05]]
[[0.00052323 0.00014943]]
[[0.00053125 0.00018251]]
[[ 5.795658e-04 -8.094103e-05]]
[[ 0.00058766 -0.00023133]]
[[ 0.00057888 -0.00031713]]
[[ 5.2402326e-04 -8.0160062e-05]]
[[5.122069e-04 5.369253e-05]]
[[ 0.00056172 -0.00015171]]
[[5.3061242e-04 1.4013854e-05]]
[[ 0.00056579 -0.00017721]]
[[ 0.00057425 -0.00028414]]
[[ 5.2448490e-04 -6.1479535e-05]]
[[ 0.00055732 -0.00021922]]
[[ 5.2304112e-04 -2.3022756e-05]]
[[