# (L)BFGS

In [1]:
import torch
from get_coords import get_coords
import numpy as np
import matplotlib.pyplot as plt
from structure import Structure
from G_phi_psi import make_distmap
from optimize import loss
from G_phi_psi import G

coords = get_coords('1a02F00')

backbone = torch.from_numpy(coords[:, 4:][coords[:, 3] != 'CB'].astype(np.float))
cbeta = torch.from_numpy(coords[:, 4:][coords[:, 3] == 'CB'].astype(np.float))

phi = torch.load('1a02F00_phi.pt')
psi = torch.load('1a02F00_psi.pt')

phi = np.radians(phi[1:])
psi = np.radians(psi[:-1])

with open('1a02F00.fasta') as f:
    f.readline()
    seq = f.readline()
    
# real distmap
c = []
for i in coords:
    if i[3] == 'CA' and i[2] == 'G':
        c.append(i[4:])
    elif i[3] == 'CB':
        c.append(i[4:])
        
c = torch.tensor(c, dtype=torch.float)

dist_map_real = make_distmap(c)

G_1a02 = Structure(phi, psi, dist_map_real, seq)



In [2]:
torsion = torch.cat((phi, psi)).requires_grad_()

In [3]:
# initialize B (inverse Hessian approximate), xk, gk

B = torch.diag(torch.ones(len(torsion)))
xk = torch.zeros(len(torsion))
gk = torch.zeros(len(torsion))

In [4]:
pred = G(torsion, seq)

In [5]:
L = loss(pred, dist_map_real)

**Calculate $x_{k + 1}$ and $g_{k + 1}$**

In [6]:
L.backward()
xk1 = torsion.detach()
gk1 = torsion.grad

**Calculate $p_k$ and $q_k$**

In [37]:
pk = (xk1 - xk).to(torch.float)
qk = (gk1 - gk).to(torch.float)

**Calculate $V$ nad $B_{k+1}$**

In [44]:
V = (pk - B.mv(qk.t())).view(1, -1)
#B = B -  

In [47]:
V.mv(qk)

tensor([-6.7386e+10])

In [48]:
V.mm(V.t())

tensor([[6.7386e+10]])

In [7]:
dist_map_real

tensor([[ 0.0000,  5.5745,  5.2658,  ..., 74.2745, 76.5421, 79.8635],
        [ 5.5745,  0.0000,  5.3023,  ..., 72.9742, 74.9657, 78.2487],
        [ 5.2658,  5.3023,  0.0000,  ..., 71.2431, 73.5717, 76.6548],
        ...,
        [74.2745, 72.9742, 71.2431,  ...,  0.0000,  5.3974,  7.1309],
        [76.5421, 74.9657, 73.5717,  ...,  5.3974,  0.0000,  5.5673],
        [79.8635, 78.2487, 76.6548,  ...,  7.1309,  5.5673,  0.0000]])

In [10]:
B.mv(qk)

tensor([-1.1488e+05,  2.1635e+03, -7.4802e+04,  2.1432e+04, -3.8222e+04,
         1.4692e+04, -1.6435e+04,  6.5156e+03, -8.9798e+01,  1.0685e+04,
        -5.8823e+03,  2.0399e+03, -5.9118e+02,  3.1984e+03, -9.7868e+02,
         1.2700e+03, -3.7121e+02,  1.2240e+03, -2.1223e+02,  3.8118e+02,
        -2.8026e+02,  1.0990e+02, -4.3805e+02,  1.3075e+02, -1.2684e+02,
         1.1224e+02, -8.0126e+01,  2.2132e+01, -1.8380e+01,  6.6903e+01,
        -3.3274e+01,  1.7839e+01, -1.6132e+01,  1.1425e+01, -6.1699e+00,
         9.5007e+00, -2.2296e+00,  6.7154e+00,  1.8912e+00,  4.8588e+00,
         1.5443e+00,  2.7634e+00,  1.8943e+00,  3.0088e+00,  1.2432e+00,
         1.3803e+00,  1.5451e+00,  7.6785e-01,  7.4039e-01,  9.3091e-01,
         3.3175e-01,  7.9112e-02,  1.7935e+04, -1.6824e+05,  2.5172e+04,
        -9.7472e+04,  4.3732e+04, -5.2542e+04,  2.8668e+04, -2.8468e+04,
         2.0907e+04, -4.5536e+03,  1.5193e+04, -1.0680e+04,  6.8835e+03,
        -2.8369e+03,  5.2156e+03, -2.5516e+03,  2.6

In [14]:
gk.dtype

torch.float32

In [12]:
torsion = torch.cat((phi, psi)).requires_grad_()

In [13]:
opt = torch.optim.LBFGS([torsion])

In [14]:
for i in range(10):
    def closure():
        opt.zero_grad()
        p = G(torsion, seq)
        L = loss(p, dist_map_real)
        L.backward()
        print(L.item())
        return L
    
    opt.step(closure)

1480.7359619140625
1481.500732421875
2070.434814453125
1521.37109375
1521.7982177734375
3680.231689453125
1556.179443359375
1580.0145263671875
1533.39794921875
1554.70703125
1770.52734375
1815.6947021484375
1798.40673828125
1814.0816650390625
1867.806640625
1963.1470947265625
1947.73974609375
1795.809814453125
1569.416259765625
2153.6708984375
2748.1171875
4525.71728515625
5257.833984375
3411.46240234375
7224.27880859375
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan


KeyboardInterrupt: 

In [6]:
def closure():
    opt.zero_grad()
    p = G(torsion, seq)
    L = loss(p, dist_map_real)
    L.backward()
    return L

In [7]:
opt.step(closure)

tensor(1480.7360, grad_fn=<NegBackward>)