In [180]:
import numpy as np
import qutip as qt
from qutip import tensor, basis, identity, projection
from scipy.linalg import expm

"""
A simple 3-qubit environment that rewards driving the system towards a specific entangled target state.
"""

class DickeStateEnv:
    """Three-qubit, three-agent environment rewarding fidelity with a target state.

    Each call to :meth:`step` takes three actions which scale the pairwise
    exchange couplings (0-1, 1-2, 0-2) between the qubits, evolves the joint
    state under the resulting Hamiltonian for ``timestep_duration`` and hands
    every agent the same reward: the fidelity of the joint state with the
    target state.

    Parameters
    ----------
    target_state : qutip.Qobj, optional
        State the agents are rewarded for reaching.  When omitted, the equal
        superposition of the three single-excitation basis states
        (|100> + |010> + |001>) / sqrt(3) is used.
    action_max_val : float
        Factor every action is multiplied by to obtain a coupling strength.
    timestep_duration : float
        Evolution time applied per call to :meth:`step`.
    """

    def __init__(self, target_state=None, action_max_val=0.2, timestep_duration=1):
        self.act_max = action_max_val
        self.duration = timestep_duration

        # `is None` rather than `== None`: the latter would invoke Qobj's
        # equality operator when a target state is supplied.
        if target_state is None:
            self.targ = 1 / np.sqrt(3) * (
                tensor(basis(2, 1), basis(2, 0), basis(2, 0))
                + tensor(basis(2, 0), basis(2, 1), basis(2, 0))
                + tensor(basis(2, 0), basis(2, 0), basis(2, 1))
            )
        else:
            self.targ = target_state

        # Single-qubit operators |a><b| lifted to the three-qubit space.
        # Naming: q<site>_<a><b>.
        self.q0_00 = self._embed(projection(2, 0, 0), 0)
        self.q0_11 = self._embed(projection(2, 1, 1), 0)
        self.q0_10 = self._embed(projection(2, 1, 0), 0)

        self.q1_00 = self._embed(projection(2, 0, 0), 1)
        self.q1_11 = self._embed(projection(2, 1, 1), 1)
        self.q1_10 = self._embed(projection(2, 1, 0), 1)

        self.q2_00 = self._embed(projection(2, 0, 0), 2)
        self.q2_11 = self._embed(projection(2, 1, 1), 2)
        self.q2_10 = self._embed(projection(2, 1, 0), 2)

        # Joint state of the three qubits; populated by reset().
        self.state = None

    @staticmethod
    def _embed(single_qubit_op, site):
        """Return ``single_qubit_op`` acting on qubit ``site`` and identity elsewhere."""
        factors = [identity(2), identity(2), identity(2)]
        factors[site] = single_qubit_op
        return tensor(*factors)

    def format_local_obs(self, reduced_state):
        """Flatten a state into a 1-D array: all real parts, then all imaginary parts.

        ``reduced_state`` must expose ``.full()`` returning a complex array
        (e.g. a ``qutip.Qobj``).
        """
        dense = reduced_state.full()
        return np.concatenate((dense.real.flatten(), dense.imag.flatten()), axis=0)

    def get_reward(self, input_state):
        """Return ``qutip.fidelity`` between ``input_state`` and the target state."""
        return qt.fidelity(input_state, self.targ)

    def reset(self):
        """Start a new episode with one excitation on a uniformly random qubit.

        Returns
        -------
        list of numpy.ndarray
            One formatted local observation per qubit, built from that qubit's
            own basis-state density matrix.
        """
        excitations = np.zeros(3, dtype=np.int32)
        excitations[np.random.randint(0, 3)] = 1

        self.state = tensor(*(basis(2, int(level)) for level in excitations))

        observations = []
        for level in excitations:
            local_ket = basis(2, int(level))
            observations.append(self.format_local_obs(local_ket * local_ket.dag()))
        return observations

    def step(self, actions):
        """Evolve the joint state for one timestep under the chosen couplings.

        Parameters
        ----------
        actions : sequence of float
            ``actions[0]``, ``actions[1]`` and ``actions[2]`` scale the 0-1,
            1-2 and 0-2 couplings respectively (each multiplied by
            ``action_max_val``).

        Returns
        -------
        observations : list of numpy.ndarray
            Formatted single-qubit reduced state for each of the three qubits.
        rewards : list
            The shared fidelity reward, repeated once per agent.
        dones : list of bool
            ``True`` for every agent once the reward exceeds 0.99.
        info : dict
            Always empty.

        Raises
        ------
        RuntimeError
            If called before :meth:`reset`.
        """
        if self.state is None:
            raise RuntimeError("reset() must be called before step()")

        coupling_01 = actions[0] * self.act_max
        coupling_12 = actions[1] * self.act_max
        # Third coupling comes from the third action (index 2, not 3).
        coupling_02 = actions[2] * self.act_max

        # Each pair term is an excitation hop plus its Hermitian conjugate,
        # which keeps H Hermitian and therefore U unitary.
        H = (coupling_01 * (self.q0_10 * self.q1_10.dag() + self.q1_10 * self.q0_10.dag()) +
             coupling_12 * (self.q1_10 * self.q2_10.dag() + self.q2_10 * self.q1_10.dag()) +
             coupling_02 * (self.q0_10 * self.q2_10.dag() + self.q2_10 * self.q0_10.dag()))

        # Qobj.expm keeps the [[2,2,2],[2,2,2]] tensor dims, so U stays
        # compatible with the three-qubit state (a bare Qobj built from a
        # dense matrix would not).
        U = (-1j * H * self.duration).expm()

        # reset() stores a ket; a density matrix needs the two-sided update.
        if self.state.isket:
            new_state = U * self.state
        else:
            new_state = U * self.state * U.dag()

        reward = self.get_reward(new_state)
        rewards = [reward] * 3

        self.state = new_state
        observations = [
            self.format_local_obs(new_state.ptrace(qubit)) for qubit in range(3)
        ]

        dones = [bool(reward > 0.99)] * 3

        info = {}
        return observations, rewards, dones, info
    
    

In [181]:
# Call reset() (not just reference the bound method) to get the initial observations.
obs = env.reset()

[array([0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j]),
 array([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]),
 array([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])]

In [178]:
# Scratch check: imaginary part of the dense 2x2 matrix returned by fock_dm(2, 0, 0).
qt.fock_dm(2,0,0).full().imag

array([[0., 0.],
       [0., 0.]])