In [1]:
import torch as t
import matplotlib.pyplot as plt
from torchdiffeq import odeint_adjoint as odeint
# import torchdiffeq

In [2]:
class System():
    def __init__(self):
        self.T = 10
        self.NT = 11
        self.sigx = t.tensor([[0,1],[1,0]]).cdouble()
        self.sigy = t.tensor([[0,-1j],[1j,0]]).cdouble()
        self.sigz = t.tensor([[1,0],[0,-1]]).cdouble()
        self.ReLU = t.nn.ReLU()
    
    def _get_H(self,B = t.tensor([0,0,1.0])):
        return B[0]*self.sigx + B[1]*self.sigy + B[2]*self.sigz
    
    def pol_to_cart(self,R,theta,phi):
        theta = t.as_tensor(theta)
        phi = t.as_tensor(phi)
        out = t.zeros(3)
        out[0] = t.sin(theta)*t.cos(phi)
        out[1] = t.sin(theta)*t.sin(phi)
        out[2] = t.cos(phi)
        return R*out

In [3]:
B_heights = t.tensor([[0.0937],
        [0.3855],
        [0.8383],
        [0.8387],
        [1.0343],
        [0.8581],
        [0.3721],
        [0.6748],
        [0.9914],
        [0.2749],
        [0.9984]])
theta_heights = t.tensor([[1.3168],
        [1.5062],
        [1.5751],
        [1.6749],
        [0.7596],
        [1.8052],
        [1.3482],
        [0.7155],
        [1.6957],
        [1.7506],
        [1.8256]])
phi_heights = t.tensor([[-2.6207],
        [-2.7394],
        [-0.4964],
        [-0.8642],
        [-0.8476],
        [-0.0271],
        [ 1.6253],
        [ 1.1360],
        [ 0.4909],
        [ 1.8453],
        [ 2.3140]])

In [4]:
class Model(t.nn.Module,System):
    def __init__(self):
        super().__init__()
        super(t.nn.Module,self).__init__()
        self.time_places = t.linspace(0,self.T,self.NT).view(1,-1)
        self.dt = (self.time_places[0,1:] - self.time_places[0,:-1]).mean()
        # self.B_heights = t.nn.parameter.Parameter(t.rand(self.NT).view(-1,1).double())
        # self.theta_heights = t.nn.parameter.Parameter(t.rand(self.NT).view(-1,1).double())
        # self.phi_heights = t.nn.parameter.Parameter(t.rand(self.NT).view(-1,1).double())
        self.B_heights = t.nn.parameter.Parameter(B_heights.double())
        self.theta_heights = t.nn.parameter.Parameter(theta_heights.double())
        self.phi_heights = t.nn.parameter.Parameter(phi_heights.double())
        # self.B = t.nn.parameter.Parameter(self.init_B())

    def get_interpol_weights(self,times):
        times = t.as_tensor(times,dtype=t.double)
        dists = times.view(-1,1) - self.time_places
        weights = self.ReLU(1-dists.abs()/self.dt).square()
        weights = weights/weights.sum(1,keepdim=True)
        return weights
    
    def get_H(self, times):
        # print(times.shape)
        w = self.get_interpol_weights(times)
        B = w@self.B_heights
        theta = w@self.theta_heights
        phi = w@self.phi_heights
        
        B = self.pol_to_cart(B.squeeze(),theta.squeeze(),phi.squeeze())
        # print(B.shape)
        H = self._get_H(B)
        return H
    
    def get_jac(self, times, y=None):
        # print(y.shape)
        H = self.get_H(times)
        return t.block_diag(H,H)
    
    def forward(self, times: t.Tensor, y: t.Tensor) -> t.Tensor:
        # w = self.get_interpol_weights(times)
        # B = w@self.B_heights
        # theta = w@self.theta_heights
        # phi = w@self.phi_heights
        
        # B = self.pol_to_cart(B.squeeze(),theta,phi)
        # # print(B.shape)
        # H = self.get_H(B)
        H = self.get_H(times)
        U = y.view(2,2)
        return (-1j*H@U).flatten()

In [5]:
target_gate_adj = t.tensor([[0,1],[1,0]]).cdouble().adjoint()
def loss_func(U):
    return 1 - 0.25*t.square(t.abs(t.trace(target_gate_adj@U)))

In [6]:
obj = Model()
# obj = t.jit.script(Model())

In [20]:
y0 = 2*(t.rand(4) + 1j*t.rand(4) - 0.5*(1+1j)).cdouble()
traced_obj = t.jit.trace(obj, (t.rand(1)*10,y0))

torch.Size([1])
torch.Size([1])
torch.Size([1])


  times = t.as_tensor(times,dtype=t.double)
  theta = t.as_tensor(theta)
  phi = t.as_tensor(phi)


In [7]:
optimizer = t.optim.Adam(obj.parameters(),lr=1e-2)

In [8]:
# B0s = []
losss = []
def train(Nepochs):
    for i in range(Nepochs):
        optimizer.zero_grad()
        pred_y = odeint(obj,
                        y0=t.eye(2).cdouble().flatten(),
                        t=t.tensor([0.,10.]),
                        method='scipy_solver',
                        options={'solver':'BDF',
                                 'vectorized': True}, #BDF
                        adjoint_method="scipy_solver",
                        adjoint_options={'solver':'BDF'})
        print(pred_y.shape)
        # return pred_y
        U = pred_y.view(-1,2,2)
        loss = loss_func(U[-1])
        loss.backward()
        optimizer.step()

        with t.no_grad():
            losss.append(loss.item())
            # B0s.append(obj.B0.item())
            print(f"loss: {loss.item()}, step: {i}", end='\r')

In [60]:
from time import time
times = t.linspace(0,10,2000)

y0 = 2*(t.rand(4) + 1j*t.rand(4) - 0.5*(1+1j)).cdouble()
start = time()
for ti in times:
    obj(ti,y0)
print(time()-start)

1.163940191268921


In [9]:
train(1)

  times = t.as_tensor(times,dtype=t.double)


Perhaps convert_func_to_numpy is not needed
torch.Size([2, 4])
Perhaps convert_func_to_numpy is not needed
vjp_params:  tensor([[0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [1.1252]], dtype=torch.float64) tensor([[ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [-0.2254]], dtype=torch.float64) tensor([[0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.4763]], dtype=torch.float64)
func_eval:  tensor([-0.4979-0.2814j, -0.5979+0.8398j,  0.5979+0.8398j, -0.4979+0.2814j],
       dtype=torch.complex128, grad_fn=<ReshapeAliasBackward0>)
vjp_params:  tensor([[0.0000e+00],
        [0.0000e+00],

In [87]:
from scipy.integrate import solve_ivp
def func_wrap(times,y):
    H = obj.get_H(times).detach()
    return -1j*H@y
def jac_wrap(times,y):
    H = obj.get_H(times).detach()
    return -1j*H
tmp = solve_ivp(func_wrap,y0=t.tensor([1,0]).cdouble(),t_span=(0,10),t_eval=[10],method='BDF',jac=jac_wrap)

In [88]:
tmp

  message: 'The solver successfully reached the end of the integration interval.'
     nfev: 244
     njev: 1
      nlu: 24
      sol: None
   status: 0
  success: True
        t: array([10])
 t_events: None
        y: array([[ 0.48241767+0.36998683j],
       [-0.27218211+0.74596508j]])
 y_events: None

In [None]:
fig, ax = plt.subplots()
ax.plot(losss)
ax.set_yscale('log')

In [None]:
fig, ax = plt.subplots()
ax.plot(losss)
ax.set_yscale('log')

In [None]:
def get_B(times, self = obj):
    with t.no_grad():
        w = self.get_interpol_weights(times)
        B = w@self.B_heights
        theta = w@self.theta_heights
        phi = w@self.phi_heights
    return B.detach().squeeze(), theta.detach().squeeze(), phi.detach().squeeze()
times = t.linspace(0,10,100).double()
B = get_B(times)

fig, ax = plt.subplots()
ax.plot(times,B[0])

In [None]:
def get_occs(self=obj):
    with t.no_grad():
        pred_ys = odeint(self,
                        y0=t.eye(2).cdouble().flatten(),
                        t=t.linspace(0,10,100),
                        method='scipy_solver',
                        options={'solver':'BDF'})
        Us = pred_ys.view(-1,2,2)
        occs = (Us@t.tensor([1,0]).cdouble()).abs().square()
    return occs
occs = get_occs()

In [None]:
fig, ax = plt.subplots()
ax.plot(occs);
print(occs[-1])

In [None]:
times = t.linspace(0,10,2).view(1,-1)
dt = (times[0,1:] - times[0,:-1]).mean()*1
vals = ((t.rand(times.shape[1])-0.5)*2).view(-1,1)

In [None]:
# def interpolate(time):
#     dists = time.view(-1,1) - times
#     return t.exp(-0.5*t.square(dists)/dt)@vals

def interpolate(time):
    dists = time.view(-1,1) - times
    weights = t.functional.F.relu(1-dists.abs()/dt)**(2)
    # weights = t.exp(-0.5*t.square(dists)/dt.square()*10)
    weights = weights/weights.sum(1,keepdim=True)
    return weights@vals

ts = t.linspace(0,10,300)
plt.figure(figsize=(15,5))
plt.plot(ts,interpolate(ts))
plt.vlines(times,0,vals)
plt.hlines(0,0,10)

In [None]:
t.functional.F.relu(1-t.linspace(0,10,11))