In [39]:
import torch as t
from time import time

In [7]:
Id = t.diag(t.ones(5))
M1 = t.rand((5,5))
Id, M1

(tensor([[1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 1.]]),
 tensor([[0.9488, 0.8391, 0.1563, 0.6694, 0.0181],
         [0.2093, 0.1539, 0.5345, 0.3523, 0.4059],
         [0.7389, 0.6839, 0.4605, 0.1577, 0.5186],
         [0.6262, 0.2533, 0.8873, 0.7578, 0.4532],
         [0.8747, 0.1859, 0.8035, 0.5447, 0.6156]]))

In [19]:
class Net(t.nn.Module):
    def __init__(self, Id = Id, M1 = M1):
        super(Net,self).__init__()
        self.Id = Id
        self.M1 = M1
    
    def forward(self,pulse):
        out = Id
        for i in range(100):
            out = t.matmul(Id,t.matrix_exp(-pulse[i]*self.M1))
        return out

In [20]:
model = Net()

In [21]:
init_pulse = t.rand(100)

In [22]:
traced_model = t.jit.trace(model,init_pulse)

In [46]:
start = time()
for i in range(1000):
    model(init_pulse)
finish = time()
print(finish - start)

31.851975679397583


In [49]:
start = time()
for i in range(1000):
    traced_model(init_pulse)
finish = time()
print(finish - start)

0.3480522632598877


# So great improvement!
# But can it handle autograd and complex values and so on?

In [60]:
Id = t.diag(t.ones((5),dtype=t.complex128))
M1 = t.rand((5,5)).type(t.complex128)
Id, M1

(tensor([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
         [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
         [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
         [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
         [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j]], dtype=torch.complex128),
 tensor([[0.6340+0.j, 0.5379+0.j, 0.9402+0.j, 0.2743+0.j, 0.6731+0.j],
         [0.6342+0.j, 0.2932+0.j, 0.1673+0.j, 0.5774+0.j, 0.3775+0.j],
         [0.6241+0.j, 0.5165+0.j, 0.4739+0.j, 0.4809+0.j, 0.8085+0.j],
         [0.2621+0.j, 0.4750+0.j, 0.9476+0.j, 0.3699+0.j, 0.9760+0.j],
         [0.9825+0.j, 0.3388+0.j, 0.9130+0.j, 0.2982+0.j, 0.1757+0.j]],
        dtype=torch.complex128))

In [113]:
class Net(t.nn.Module):
    def __init__(self, Id = Id, M1 = M1):
        super(Net,self).__init__()
        self.Id = Id
        self.M1 = M1
        self.pulse = t.nn.parameter.Parameter(t.rand(100).type(t.complex128))
    
    def forward(self,pulse):
        out = Id
        for i in range(100):
            out = t.matmul(out,t.matrix_exp(-1j*pulse[i]*self.M1))
        return out

In [114]:
model = Net()

In [115]:
traced_model = t.jit.trace(model,init_pulse)

In [116]:
new_pulse = t.rand(100)

In [117]:
start = time()
for i in range(1000):
    model(new_pulse)
finish = time()
print(finish - start)

36.5143039226532


In [118]:
start = time()
for i in range(1000):
    traced_model(new_pulse)
finish = time()
print(finish - start)

31.740001916885376


In [101]:
out = traced_model(init_pulse)
loss = t.trace(t.abs(M1-out))
optimizer = t.optim.Adam(traced_model.parameters(),lr=1e-2)
loss.backward()
optimizer.step()

# Seems to be working!!
# ... Damn it, I had made a mistake, such that the gain is, in fact, not very large