In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import gymnasium as gym
import os
import qiskit
from gymnasium import spaces
from stable_baselines3 import PPO, A2C, DQN, TD3, SAC
from stable_baselines3.common.env_util import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecNormalize

from qiskit.quantum_info import random_density_matrix, random_statevector, DensityMatrix
from adaptive_qst.plotting import PlotOneQubit
from adaptive_qst.max_info import Posterior, HiddenState
#from adaptive_qst.rl_qst import AQSTEnv
import matplotlib.pyplot as plt
from numpy import pi
from qiskit.quantum_info import state_fidelity

from numpy import sqrt
from numpy.linalg import cholesky

from sb3_contrib import RecurrentPPO

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
class RAQST(gym.Env):
    """Gymnasium environment for adaptive quantum state tomography.

    On every step the agent picks a measurement axis, the hidden state is
    measured along that axis, and the particle ``Posterior`` is updated with the
    outcome. The observation returned to the agent is the normalized axis
    along which the outcome was found (see ``get_observation``).

    Parameters
    ----------
    n_particles : int
        Forwarded to ``Posterior`` when the posterior is (re)created.
    n_measurements : int
        Number of steps after which an episode is truncated.
    hidden_state : optional
        State to be reconstructed. If given, it is kept for every episode.
        If ``None``, a new ``HiddenState()`` is drawn at construction and on
        every ``reset``.
    reward_start : int
        Stored on the instance; not used by any method of this class.
    """

    # Offset inside the reward denominator log(step_num + offset); with the
    # value 8 the denominator is log(8) > 0 on the very first step.
    _REWARD_LOG_OFFSET = 8

    # Lower bound on (1 - fidelity) so the reward stays finite when the best
    # guess matches the hidden state (or fidelity overshoots 1 numerically).
    _MIN_INFIDELITY = 1e-12

    def __init__(self, n_particles = 30, n_measurements = 1000, hidden_state = None, reward_start = 100):
        super().__init__()

        self.n_particles = n_particles
        self.n_measurements = n_measurements
        self.posterior = Posterior(self.n_particles)

        # Only resample the hidden state between episodes when the caller did
        # not pin a specific one.
        self.change_state = (hidden_state is None)
        self.hidden_state = HiddenState() if self.change_state else hidden_state

        self.step_num = 0
        self.reward_start = reward_start

        ##Observations: Orientation of the axis along which the outcome was found
        self.observation_space = gym.spaces.Box(low = -1, high = 1, shape = (2,), dtype = np.float32)

        ##Actions: Orientation of the measured axis
        self.action_space = gym.spaces.Box(low = -1, high = 1, shape = (2,), dtype = np.float32)

    @staticmethod
    def _action_to_config(action):
        """Map an action in [-1, 1]^2 to a pair of angles in [0, pi]^2."""
        return (np.asarray(action) + 1) * pi / 2

    def step(self, action):
        """Measure along the axis encoded by ``action`` and update the posterior.

        Returns the Gymnasium 5-tuple
        ``(observation, reward, terminated, truncated, info)``. ``terminated``
        is always ``False``; ``truncated`` becomes ``True`` once
        ``n_measurements`` steps have been taken.
        """
        config = self._action_to_config(action)
        result = self.hidden_state.measure_along_axis(config)

        self.posterior.update(config, result)
        fidelity = state_fidelity(self.hidden_state.hidden_state, self.posterior.get_best_guess())

        # Reward is -log(infidelity), scaled down by a slowly growing function
        # of the step index. The infidelity is clamped to avoid log(0) = -inf
        # and log of a (numerically) negative number.
        infidelity = max(1.0 - fidelity, self._MIN_INFIDELITY)
        reward = float(-np.log(infidelity) / np.log(self.step_num + self._REWARD_LOG_OFFSET))

        self.step_num += 1
        truncated = (self.step_num >= self.n_measurements)
        terminated = False

        return (self.get_observation(action, result),
                reward,
                terminated,
                truncated,
                {})

    def reset(self, seed=None, options=None):
        """Start a new episode with a fresh posterior.

        The hidden state is redrawn only when no fixed ``hidden_state`` was
        passed to the constructor. Returns an all-zero observation and an
        empty info dict.
        """
        super().reset(seed=seed, options=options)

        self.posterior = Posterior(self.n_particles)
        if self.change_state:
            self.hidden_state = HiddenState()

        self.step_num = 0

        return np.zeros(2, dtype=np.float32), {}

    def get_observation(self, action, result):
        """Build the observation for a measurement along ``action`` with outcome ``result``.

        For ``result == 0`` the observation is the measured axis itself; for
        any other result it is the opposite direction
        ``(pi - config[0], config[1] - pi)``. Both angles are divided by pi
        and the first one is then rescaled to [-1, 1].
        """
        # Recompute the angles from the action: the original body read an
        # undefined name `config`. This also gives a fresh array, so nothing
        # shared with the caller is modified below.
        config = self._action_to_config(action)

        if result == 0:
            obs = config

        ##Otherwise, measured in the opposite direction:
        else:
            obs = np.array([pi - config[0], config[1] - pi])

        #Normalize
        obs = obs / pi
        obs[0] = 2 * obs[0] - 1

        return obs.astype(np.float32)

In [14]:
# Build an environment with the default constructor arguments and run
# Stable-Baselines3's environment checker on it (warn=True is passed through
# to check_env).
env = RAQST()
check_env(env, warn=True)

In [20]:
##Test the Environment with random actions:
env = RAQST()
obs, _ = env.reset()

for _ in range(10):
    action = env.action_space.sample()
    # step() returns (obs, reward, terminated, truncated, info) in that order;
    # the flags must be unpacked as terminated first, truncated second.
    obs, rewards, terminated, truncated, info = env.step(action)
    
    print(action)
    print(obs)
    print()

    # Start a new episode if this one ended, so later steps stay valid.
    if terminated or truncated:
        obs, _ = env.reset()

[0.4435885  0.07334452]
[0.4435885  0.53667223]

[-0.30095756  0.31652617]
[-0.30095756  0.6582631 ]

[-0.30076712  0.16702166]
[-0.30076712  0.5835108 ]

[0.9549845  0.23642556]
[-0.95498455 -0.3817872 ]

[-0.20831513  0.9967349 ]
[ 0.20831512 -0.00163254]

[ 0.17844771 -0.31732884]
[0.17844772 0.3413356 ]

[-0.46157548 -0.92410755]
[-0.4615755   0.03794622]

[0.56021947 0.78713393]
[-0.5602195  -0.10643299]

[0.89643925 0.88191444]
[-0.8964394  -0.05904279]

[0.91942364 0.41523916]
[-0.9194237  -0.29238042]



In [25]:
# Scratch cell: one of the printed value pairs from the run above, scaled by pi.
# NOTE(review): `arr` is not used in any visible cell — confirm it is still needed.
arr = np.array([0.56021947, 0.78713393]) * pi

# Sanity check of the (a + 1) / 2 mapping for the sampled action component
# 0.16702166; the result matches the second observation component printed for
# that action in the run above (0.5835108).
(0.16702166 + 1) / 2

0.58351083