In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax import jit, grad, vmap
import matplotlib.pyplot as plt
import qcsys as qs
import jax.numpy as jnp
from rar_helpers import *

In [3]:
import jaxquantum as jqt

# Relevant Hamiltonian Terms for a Time Dep Simulation

In [15]:
metrics, system, ϕ0, metrics0, system0 = get_metrics_normal_rar({
    "ATS__phi_sum_ext": 0.25,
    "ATS__phi_delta_ext": 0,
})

In [16]:
H_full = system.get_H()

phi_op = system.params["phi"] 
a_op = system.params["a"]
b_op = system.params["b"]
c_op = system.params["c"]

phi_sum = system.devices[1].params["phi_sum_ext"]
phi_delta = system.devices[1].params["phi_delta_ext"]
ATS_Ej = system.devices[1].params["Ej"]
ATS_dEj = system.devices[1].params["dEj"]
ATS_Ej2 = system.devices[1].params["Ej2"]

omega_a = metrics0["ω_ResonatorA"]
omega_b = metrics0["ω_ResonatorB"]
omega_c = metrics0["ω_ATS"]

H_rot = omega_a * a_op.dag() * a_op + omega_b * b_op.dag() * b_op 
H_static = H_full - qs.ATS.get_H_nonlinear_static(phi_op, ATS_Ej, ATS_dEj, ATS_Ej2, phi_sum, phi_delta)

In [17]:
jnp.max(jnp.abs((H_static - H_full).data ))

Array(0., dtype=float64)

In [21]:
# Sanity Check (at the right bias point)
H_static_expected = omega_a * a_op.dag() * a_op + omega_b * b_op.dag() * b_op + omega_c * c_op.dag() * c_op
constant_offset = (8.93783831) * jqt.identity_like(H_rot)
jnp.max(jnp.abs((H_static - H_static_expected - constant_offset).data))

Array(1.66802394e-09, dtype=float64)

In [22]:
phi = system.params["phi"]
Ejb = system.devices[1].params["Ej"]
H_drive_factor = +2*Ejb*jqt.cosm(phi) # only valid at the (phi_sum,phi_delta) = (0.25,0.25) bias point

In [23]:
# The full system Hamiltonian H = H_static + sin(eps(t)) * H_drive_factor
H_rot_qt = jqt.jqt2qt(H_rot)
H_static_qt = jqt.jqt2qt(H_static)
H_drive_factor_qt = jqt.jqt2qt(H_drive_factor) 

# Run Sweep

In [29]:
N_CONS = {
    "resonator_a": {
        "bare": 8,
        "normal": 8,
        "truncated": 8,
    },
    "ats": {
        "bare": 5,
        "normal": 100,
        "truncated": 10,
    },
    "resonator_b": {
        "bare": 8,
        "normal": 8,
        "truncated": 8,
    }
}

In [32]:
initial_state = jqt.basis(N_CONS["resonator_a"]["truncated"], 1) ^ jqt.basis(N_CONS["ats"]["truncated"], 0) ^ jqt.basis(N_CONS["resonator_b"]["truncated"], 0)

omega_d = omega_a - omega_b
epsilon_0 = 1

def Ht(t):
    return H_static + jnp.sin(epsilon_0 * omega_d * t) * H_drive_factor

g_3 = jnp.real(metrics0["g_3"])

t_BS = jnp.pi/g_3

ts = jnp.linspace(0, t_BS, 100)

In [33]:
res = jqt.sesolve(Ht, initial_state, ts)

100% |[35m██████████[0m| [04:43<00:00,  2.83s/%]


In [38]:
res[-1].ptrace(1)

Quantum array: dims = ((10,), (10,)), bdims = (), shape = (10, 10), type = oper
Qarray data =
[[ 0.15656616+0.j          0.        +0.j         -0.00294079-0.01599856j
   0.        +0.j         -0.00049961-0.01644442j  0.        +0.j
   0.00611516-0.01211252j  0.        +0.j          0.00536113+0.01785983j
   0.        +0.j        ]
 [ 0.        +0.j          0.12841971+0.j          0.        +0.j
  -0.02670501+0.00727236j  0.        +0.j         -0.02592368+0.0063989j
   0.        +0.j         -0.02127212-0.00192519j  0.        +0.j
  -0.00217783+0.00093789j]
 [-0.00294079+0.01599856j  0.        +0.j          0.10768361+0.j
   0.        +0.j          0.01950022-0.00076459j  0.        +0.j
  -0.00254493-0.01515785j  0.        +0.j          0.01216492+0.02843343j
   0.        +0.j        ]
 [ 0.        +0.j         -0.02670501-0.00727236j  0.        +0.j
   0.09105446+0.j          0.        +0.j          0.01502004-0.00285153j
   0.        +0.j         -0.00504622-0.00279811j  0.       