# Robust quantum gates using smooth pulses and physics-informed neural networks

In [4]:
import jax
import jax.numpy as jnp

from flax import nnx
import optax

import hydra
from omegaconf import DictConfig
'''
%load_ext autoreload
%autoreload 2
import operators as qo
from network import MLP
'''

'\n%load_ext autoreload\n%autoreload 2\nimport operators as qo\nfrom network import MLP\n'

## Examples : Transmon qubit

### utils.py

In [5]:
def dagger(A: jax.Array) -> jax.Array:
    """
    Return Hermitian transpose matrix
    Args:
        A (jnp.ndarray): input matrix of shape (m, n).

    Returns:
        jnp.ndarray: Hermitian transpose matrix
    """
    return jax.lax.conj(jax.lax.transpose(A, (1, 0)))

### operators.py

In [6]:
def identity(dim: int) -> jax.Array:
    '''
    Returns the identity operator.

    Args:
      dim (int): The dimension of Hilbert space.
      
    Returns: 
      jnp.array: A 2x2 matrix of the identity operator with complex64 precision.
    '''
    return jnp.identity(dim, dtype=jnp.complex64)

def num(dim: int) -> jax.Array:
    '''
    Returns the number operator.
    
    Args:
      dim (int): The dimension of Hilbert space.

    Returns:
      jnp.array: A 'dim'x'dim' matrix of the number operator with complex64 precision.
    '''
    return jnp.diag(jnp.arange(0, dim, dtype=jnp.complex64))

### envelope.py

In [7]:
@hydra.main(config_name="system_params", version_base=None, config_path="./config")
def smoothed_square(t: float, cfg: DictConfig):
    '''
    Smoothed unit square pulse expressed as follows:
      A(t) = coth(kappa*T){tanh(kappa*T) - tanh(kappa(t-T))} - 1
    
    Args:
      t (float): Current time (in seconds) at which to evaluate the control Hamiltonian.
      cfg (DictConfig): Hydra‐loaded configuration object containing system parameters, 
        for example:
        - cfg.drive.kappa: the degree of the smoothing
        - cfg.drive.duration: anharmonicity of the transmon  
      
    Returns:
      jnp.array: Matrix representation of the envelope (see system_params.yaml for kappa and duration)
    '''
    term_1 = (1/jax.lax.tanh(cfg.drive.kappa*cfg.drive.duration)) * jax.lax.tanh(cfg.drive.kappa*t)
    term_2 = (1/jax.lax.tanh(cfg.drive.kappa*cfg.drive.duration)) * jax.lax.tanh(cfg.drive.kappa*(t-cfg.drive.duration))
    return term_1 - term_2 - 1

### network.py

In [8]:
class Linear(nnx.Module):
  def __init__(self, num_input: int, num_output: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (num_input, num_output)))
    self.b = nnx.Param(jnp.zeros((num_output,)))
    self.num_input, self.num_output = num_input, num_output

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

class MLP(nnx.Module):
    def __init__(self, num_input: int, num_output: int, rngs: nnx.Rngs):
        self.input_layer = Linear(num_input, 32, rngs=rngs)
        self.hidden_layers1 = Linear(32, 32, rngs=rngs)
        self.hidden_layers2 = Linear(32, 32, rngs=rngs)
        self.output_layer = Linear(32, num_output, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = nnx.tanh(self.input_layer(x))
        x = nnx.tanh(self.hidden_layers1(x))
        x = nnx.tanh(self.hidden_layers2(x))
        return self.output_layer(x)

### train.py

In [9]:
model = MLP(num_input=1, num_output=6, rngs=nnx.Rngs(0))
optimizer = nnx.optimizer(model, optax.lbfgs(0.001))

In [None]:
@hydra.main(config_name="train_params", version_base=None, config_path="./config")
@hydra.main(config_name="system_params", version_base=None, config_path="./config")
def loss(model, cfg: DictConfig):
    # infidelity
    U_0 = (1/jnp.sqrt(2)) * jnp.array([[1, -1j], [-1j, 1]])
    P = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0]])
    U_c = 
    term_1 = 1 - ((jnp.abs())**)/()

    # noise sensitivity
    term_2 = 

@hydra.main(config_name="train_params", version_base=None, config_path="./config")
def train_loop()

@hydra.main(config_name="system_params", version_base=None, config_path="./config")
def compute_loss()

### Settings

In [26]:
# 
sys_dim = 4 # maximum energy level
qubit_dim = 2**1 # logical subspace

# target unitary gate
theta = -jnp.pi/2
U_tar = jnp.cos(theta)*qo.identity(2) + 1.0j*jnp.sin(theta)*qo.pauli_x()

# define system parameters
'''
w_r = 2*jnp.pi * 
w_ge = 2*jnp.pi *
w_ef = 2*jnp.pi *
w_d = 2*jnp.pi *
'''
alpha = -200*10^6

$$
\begin{align}
\hat{H}_{\rm c} \approx \delta \left( t \right) \hat{a}^\dagger \hat{a} + \frac{\Delta}{2} \hat{a}^\dagger \hat{a} \left( \hat{a}^\dagger \hat{a} - \hat{I} \right) + \frac{\Omega \left( t \right) \hat{a} + \Omega^* \left( t \right) \hat{a}^\dagger}{2}
\end{align}
$$

In [2]:
@hydra.main(config_name="control_params", version_base=None, config_path="./config")
def smooth_square_envelope(t: float, cfg: DictConfig):
    '''
    
    '''
    term_1 = (1/jax.lax.tanh(cfg.drive.kappa*cfg.drive.duration)) * jax.lax.tanh(cfg.drive.kappa*t)
    term_2 = (1/jax.lax.tanh(cfg.drive.kappa*cfg.drive.duration)) * jax.lax.tanh(cfg.drive.kappa*(t-cfg.drive.duration))
    return term_1 - term_2 - 1

@hydra.main(config_name="system_params", version_base=None, config_path="./config")
def pulse(t: float, model_outputs: jax.Array, cfg: DictConfig):
    '''

    '''
    outputs = model(t)
    Omega_x = 4*cfg.transmon.anharmonicity*smooth_square_envelope(t)*(2/jnp.pi) * jax.lax.atan(outputs[0]) * jax.lax.sin(outputs[1])
    Omega_y = 4*cfg.transmon.anharmonicity*smooth_square_envelope(t)*(2/jnp.pi) * jax.lax.atan(outputs[2]) * jax.lax.sin(outputs[3])
    Omega = omega_x + 1.0j*omega_y
    
    delta = 2*cfg.transmon.anharmonicity*(2/jnp.pi) * jax.lax.atan(outputs[4]) * jax.lax.sin(outputs[5])

    return Omega, delta

@hydra.main(config_name="system_params", version_base=None, config_path="./config")
def control_Hamiltonian(t: float, cfg: DictConfig):
    '''
    Control Hamiltonian for the single transmon in Eq.(7)
    
    Args:
      t (float): Current time (in seconds) at which to evaluate the control Hamiltonian.
      cfg (DictConfig): Hydra‐loaded configuration object containing system parameters, 
        for example:
        - cfg.transmon.dim: dimension of the Hilbert space  
        - cfg.transmon.anharmonicity: anharmonicity of the transmon  
        - any other nested fields defining pulse characteristics or coupling strengths
      
    Returns:
      jnp.array: Matrix representation of the Hamiltonian (see system_params.yaml for matrix size)
    '''
    I = qo.identity(cfg.transmon.dim)
    num = qo.num(cfg.transmon.dim)
    Omega, delta = pulse(t, model)
    
    return delta*num + (cfg.transmon.anharmonicity/2)*jnp.dot(num, num-I) + (Omega*a+jax.lax.conj(Omega)*a_dag)/2

def error_Hamiltonian(cfg: DictConfig):
    '''

    '''
    return cfg.system.error_strength * qo.num(cfg.transmon.dim)

def loss()

SyntaxError: incomplete input (3406110079.py, line 45)

In [10]:
# define operators
I = qo.identity(sys_dim)
a = qo.annihilate(sys_dim)
a_dag = qo.create(sys_dim)

# define control Hamiltonian & error Hamiltonian
#H_c = 
#H_e = 

NameError: name 'sys_dim' is not defined

Array([[0.       -0.j, 0.       -0.j, 0.       -0.j, 0.       -0.j],
       [1.       -0.j, 0.       -0.j, 0.       -0.j, 0.       -0.j],
       [0.       -0.j, 1.4142135-0.j, 0.       -0.j, 0.       -0.j],
       [0.       -0.j, 0.       -0.j, 1.7320508-0.j, 0.       -0.j]],      dtype=complex64)

### neural network setting

In [15]:
num_input = 1
num_output = 6
#lr = 

model = MLP(num_input, num_output, rngs=nnx.Rngs(0))
#optimizer = nnx.Optimizer(model, optax.lbfgs(lr))
y = model(x=jnp.ones((5, 1)))

In [16]:
y

Array([[15.629814, 15.024139, 16.238647, 16.590088, 15.046219, 14.143402],
       [15.629814, 15.024139, 16.238647, 16.590088, 15.046219, 14.143402],
       [15.629814, 15.024139, 16.238647, 16.590088, 15.046219, 14.143402],
       [15.629814, 15.024139, 16.238647, 16.590088, 15.046219, 14.143402],
       [15.629814, 15.024139, 16.238647, 16.590088, 15.046219, 14.143402]],      dtype=float32)

In [12]:
nnx.display(model)

MLP(
  input_layer=Linear(
    w=Param(
      value=Array(shape=(1, 32), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    num_input=1,
    num_output=32
  ),
  hidden_layers1=Linear(
    w=Param(
      value=Array(shape=(32, 32), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    num_input=32,
    num_output=32
  ),
  hidden_layers2=Linear(
    w=Param(
      value=Array(shape=(32, 32), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    num_input=32,
    num_output=32
  ),
  output_layer=Linear(
    w=Param(
      value=Array(shape=(32, 6), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(6,), dtype=float32)
    ),
    num_input=32,
    num_output=6
  )
)
