# Model-Free Quantum Control with Reinforcement Learning

Author: R.Maekura (ryomaekura@g.ecc.ut-kyo.ac.jp) \
This notebook is based on the paper below. It provides a `QAX` implementation of that paper's Fock state preparation task, in which the control sequence is optimized with reinforcement learning. \
V. V. Sivak et al., "Model-Free Quantum Control with Reinforcement Learning", Phys. Rev. X 12, 011059 (published 28 March 2022) \
https://journals.aps.org/prx/abstract/10.1103/PhysRevX.12.011059

## Import packages

In [1]:
import sys
sys.path.append("/home/users/u0001529/ondemand/qax-project")

import jax
import jax.numpy as jnp
from flax import nnx
import optax

import qax
from qax import state
from qax import operator as op
from qax.utils import linalg
from qax.utils import device

## Environment

In [None]:
class FockEnv:
    """Episodic environment for SNAP-displacement Fock state preparation.

    The oscillator is truncated to ``sys_dim`` Fock levels and its state is
    kept as a dense ket (a 1-D complex ``jax.numpy`` array). Each step applies
    one composite gate ``U = D^dagger(alpha) SNAP(phi) D(alpha)``. After
    ``max_steps`` steps the episode ends, and a single projective measurement
    onto the target state is simulated. It gives a reward of +1 (success) or -1.
    """

    def __init__(
        self,
        sys_dim: int,
        target_state: "StateVector",
        max_steps: int = 5,
        seed: int = 0,
    ) -> None:
        """
        Args:
            sys_dim: Number of Fock levels kept in the truncated Hilbert space.
            target_state: Target ket. It must be convertible with
                ``jnp.asarray`` to ``sys_dim`` amplitudes. It is flattened, so
                shapes ``(sys_dim,)`` and ``(sys_dim, 1)`` both work.
            max_steps: Episode length T, i.e. the number of gates per episode.
                The default of 5 is a starting point; tune it as needed.
            seed: Seed of the JAX PRNG used to sample measurement outcomes.

        Raises:
            ValueError: If ``sys_dim`` or ``max_steps`` is not positive, or if
                the target state does not hold exactly ``sys_dim`` amplitudes.
        """
        if sys_dim < 1:
            raise ValueError(f"sys_dim must be positive, got {sys_dim}")
        if max_steps < 1:
            raise ValueError(f"max_steps must be positive, got {max_steps}")

        self.sys_dim = sys_dim
        self.target_state = target_state
        self.max_steps = max_steps

        # Dense, flattened copy of the target used for the overlap in step().
        self._target_ket = jnp.ravel(jnp.asarray(target_state))
        if self._target_ket.shape != (sys_dim,):
            raise ValueError(
                f"target_state must have {sys_dim} amplitudes, "
                f"got shape {self._target_ket.shape}"
            )

        # Explicit PRNG state so measurement sampling is reproducible.
        # It is not reset between episodes, so the outcomes differ per episode.
        self._key = jax.random.PRNGKey(seed)

        # Initialise the episode state so that step() is valid even before
        # an explicit reset() call.
        self.reset()

    def _observation(self) -> jax.Array:
        """Return the step counter as a float32 array of shape (1,)."""
        return jnp.array([self._step_count], dtype=jnp.float32)

    @staticmethod
    def _displace(dim: int, alpha) -> jax.Array:
        """Return the truncated displacement operator D(alpha).

        It is computed as expm(alpha a^dagger - alpha^* a) on a ``dim``-level
        truncation, where ``a`` is the annihilation operator.
        """
        from jax.scipy.linalg import expm

        # Annihilation operator: a|n> = sqrt(n)|n-1>, i.e. sqrt(1..dim-1) on
        # the first superdiagonal.
        a = jnp.diag(jnp.sqrt(jnp.arange(1, dim, dtype=jnp.float32)), k=1)
        generator = alpha * a.T - jnp.conj(alpha) * a
        return expm(generator)

    @staticmethod
    def _snap(phi) -> jax.Array:
        """Return the SNAP gate diag(exp(i * phi_n)) for the phase list ``phi``."""
        return jnp.diag(jnp.exp(1j * jnp.asarray(phi)))

    def reset(self):
        """
        Reset at the beginning of an episode.
        - Return to the vacuum state
        - Set the step counter to 0
        Returns:
            observation: Observation returned to the agent (the step count,
                as a float32 array of shape (1,))
        """
        # Dense vacuum ket |0>. It uses the same representation as the dense
        # gate matrices built in step().
        self._state = jnp.zeros(self.sys_dim, dtype=complex).at[0].set(1.0)
        self._step_count = 0
        return self._observation()

    def step(self, action):
        """
        Apply a gate for one step.
        Args:
            action: Action output by the agent (parameters for the
                Displacement and SNAP gates)

        - Create a Displacement gate from the parameters alpha_real, alpha_imag
        - Create a SNAP gate from the parameters phi

        action format (N = sys_dim):
        [alpha_real, alpha_imag, phi_0, phi_1, ..., phi_(N-1)]
        Total (N + 2) dimensions

        Apply the unitary operator:
        U = D^dagger(alpha) * SNAP(phi_0,...,phi_{N-1}) * D(alpha)

        Returns:
            observation: Observation of the next state (step count, float32)
            reward: 0.0 while the episode runs. At the end it is +1.0 or -1.0,
                sampled with success probability |<target|state>|^2.
            done: Episode termination flag (True once max_steps gates ran)
            info: Debugging information (dict)

        Raises:
            ValueError: If ``action`` does not have shape (sys_dim + 2,).
        """
        action = jnp.asarray(action)
        expected_dim = self.sys_dim + 2
        if action.shape != (expected_dim,):
            raise ValueError(
                f"action must have shape ({expected_dim},), got {action.shape}"
            )

        alpha = action[0] + 1j * action[1]
        D_op = self._displace(self.sys_dim, alpha)
        # Independent phase for each Fock level.
        SNAP_op = self._snap(action[2:])

        # Composite gate U = D^dagger * SNAP * D, with D^dagger(alpha) = D(-alpha).
        U_step = D_op.conj().T @ SNAP_op @ D_op
        self._state = U_step @ self._state

        self._step_count += 1
        done = self._step_count >= self.max_steps

        if done:
            # A projective measurement onto the target succeeds with
            # probability |<target|state>|^2. jnp.vdot conjugates its first
            # argument.
            overlap = jnp.vdot(self._target_ket, self._state)
            success_prob = float(jnp.abs(overlap) ** 2)
            self._key, subkey = jax.random.split(self._key)
            sample = float(jax.random.uniform(subkey))
            reward = 1.0 if sample < success_prob else -1.0
        else:
            # Reward is 0 while the episode continues.
            reward = 0.0

        info = {}
        return self._observation(), reward, done, info

In [None]:
class PolicyNetwork(nnx.Module):
    '''

    '''
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int,
        *,
        rngs: nnx.Rngs | None = None,
    ) -> None:
        """
        Build the layers of the recurrent Gaussian policy.

        Args:
            obs_dim (int): Dimension of the observation.
            action_dim (int): Dimension of the action.
            hidden_dim (int): Dimension of the hidden layers and LSTM state.
            rngs (nnx.Rngs, optional): RNG streams used to initialise the
                parameters. Defaults to ``nnx.Rngs(0)`` for reproducibility.
        """
        super().__init__()
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # --- Network Layers ---
        # LSTM over sequences of observations. nnx.RNN scans the cell over the
        # time axis. time_major=False means inputs are laid out as
        # (batch, time, obs_dim), the NNX equivalent of PyTorch's
        # batch_first=True.
        self.lstm = nnx.RNN(
            nnx.LSTMCell(obs_dim, hidden_dim, rngs=rngs),
            time_major=False,
        )

        # Fully connected layers that process the LSTM's output.
        self.fc1 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.mean_layer = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

        # One learnable log-std per action dimension, independent of the state.
        # It is initialised to 0, i.e. std = 1.
        self.log_std = nnx.Param(jnp.zeros(action_dim))