# RGN Modeling

The most recent possibly useful paper that I could find is: https://www.biorxiv.org/content/biorxiv/early/2018/02/14/265231.full.pdf

Many details are missing, so I expect it will be difficult to implement, but the architecture is fairly simple. Feed the sequence into a bi-LSTM and predict three torsion angles per residue (φ, ψ, ω), with bond lengths and bond angles held fixed at average values. Pass the predicted angles, together with the atoms placed so far, into a "geometric unit" that adds each residue sequentially and extends the "nascent structure" accordingly. The last step is the loss, distance-based root mean square deviation (dRMSD), which accounts for both global and local structural detail. Importantly, it does not require a specific orientation of the predicted structure, because it only compares the distances between every pair of atoms.
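A minimal NumPy sketch of dRMSD, assuming one common normalization (the mean over all ordered atom pairs i ≠ j; the paper's exact normalization may differ). It shows the orientation independence described above: rotating and translating the target leaves the loss at zero. Note that a reflection also preserves every distance, so dRMSD cannot tell a structure from its mirror image.

```python
import numpy as np

def pairwise_distances(x):
    """(n_atoms, 3) coordinates -> (n_atoms, n_atoms) Euclidean distance matrix."""
    diff = x[:, None, :] - x[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))

def drmsd(pred, target):
    """Root mean square difference of the two distance matrices over pairs i != j."""
    n = len(pred)
    d = pairwise_distances(pred) - pairwise_distances(target)
    return np.sqrt((d ** 2).sum() / (n * (n - 1)))   # diagonal terms are always 0

rng = np.random.default_rng(0)
target = rng.normal(size=(30, 3))

# q is orthogonal (a rotation, possibly combined with a reflection)
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
moved = target @ q.T + np.array([5.0, -2.0, 1.0])

print(drmsd(moved, target))         # ~0, round-off only
print(drmsd(1.1 * target, target))  # > 0: stretching the structure changes distances
```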

For training data the author uses targets from CASP 1-10 and tests results on CASP 11.

Task list:
<ul>
    <li>Create new bcolz array to attach sequence and structure together</li>
    <li>Pad structures to match length of sequences</li>
    <li>Handling of inaccurate PDB files</li>
</ul>
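For the padding task, a hedged sketch of one way to lay out a padded backbone target: three rows (N, CA, C) per residue of the full sequence, zero rows for residues missing from the PDB file, plus a boolean mask. It assumes residues have already been mapped to 0-based sequence positions; `pad_structure` and the `observed` dict are hypothetical names, not part of the existing `data` module.

```python
import numpy as np

def pad_structure(seq_len, observed):
    """Build a (3 * seq_len, 3) backbone array (N, CA, C per residue) and a mask.

    `observed` maps a 0-based residue index to a (3, 3) array holding that
    residue's N, CA and C coordinates. Residues absent from the PDB file stay
    as zero rows, and their mask entries stay False.
    """
    coords = np.zeros((3 * seq_len, 3))
    mask = np.zeros(3 * seq_len, dtype=bool)
    for res_ix, backbone in observed.items():
        coords[3 * res_ix: 3 * res_ix + 3] = backbone
        mask[3 * res_ix: 3 * res_ix + 3] = True
    return coords, mask

coords, mask = pad_structure(4, {0: np.ones((3, 3)), 2: 2 * np.ones((3, 3))})
print(coords.shape, mask.sum())   # (12, 3) 6
```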

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import ipywidgets as ip
from matplotlib import pyplot as plt
import os
#import utils
from fastai import *
import matplotlib
import seaborn as sns
from tqdm import tqdm
import collections
from collections import Counter as cs
#import nglview as nv
import sys
import Bio.PDB as bio
import bcolz #used directly below; don't rely on the fastai star import for it
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from keras.utils.np_utils import to_categorical
import torch.optim
import pdb
%matplotlib inline

Using TensorFlow backend.


In [85]:
import utils
from data import ProteinDataset, sequence_collate
from model import geometric_unit, pair_dist, dRMSD
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [86]:
plt.rcParams['figure.figsize'] = [16,6]

In [87]:
data_path = os.curdir + '/data/'
#load sequence file as a bcolz array for the sake of space and speed
#seq_file = load_array(data_path+'sequences.bc')
pdb_path = os.curdir + '/data/pdb/structures/pdb/'

In [88]:
#trn_samp = load_array(data_path+'train_samples.dat')
#val_samp = load_array(data_path+'validation_samples.dat')
#tst_samp = load_array(data_path+'test_samples.dat')
#keep = load_array(data_path+'keep.dat')

In [89]:
c1 = bcolz.carray(rootdir=data_path+'proteins_1.bc')

In [90]:
#for ix in range(len(c1)):
#    name, _, _ = c1[ix]
#    if name[0] == '2kw2':
#        print(ix)
#        break
#sd = np.setdiff1d(range(len(c1)), [1916])
#new_c1 = bcolz.carray(c1[sd], rootdir=data_path+'proteins_1.bc', mode='w')
#new_c1.flush()

In [549]:
shix = []
for ix in range(len(c1)):
    name, sequence, coords = c1[ix]
    length = len(sequence[0])
    if (length > 100) and (length < 150):
    #if length == 120:
        shix.append(ix)

In [550]:
shopro = bcolz.carray(c1[shix], rootdir=data_path+'proteins_short.bc', mode='w')
shopro.flush()

In [551]:
len(shopro)

354

## PyTorch DataLoader

First, construct the DataLoader used to train the model.

Known PDB file errors and issues:
<ul>
    <li>38 of 1992 chain_1 proteins have no coordinates, caused by unusual files like pdb5da6.ent</li>
    <li>some chain_1 proteins have HETATMs in the main coordinate section because the residues are modified versions of a standard residue (e.g. selenomethionine in pdb1rfe.ent)</li>
    <li>in 634 of 1992 chain_1 proteins the index of the last residue is greater than the number of residues in the sequence, because residue numbering in many files does not start at 1 (and neither does the sequence's)</li>
</ul>
</ul>
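The last issue comes from indexing by raw PDB residue numbers (`residue.get_id()[1]` in Bio.PDB). A minimal sketch of offsetting by the number of the first sequence residue and flagging gaps; `seq_start` is a hypothetical parameter (it has to come from the file's sequence records), and insertion codes are ignored here.

```python
def residue_positions(resseqs, seq_start):
    """Map PDB residue numbers to 0-based sequence positions.

    resseqs:   residue numbers of the observed (non-HETATM) residues, in file order
    seq_start: residue number of the first residue of the full sequence (assumed known)
    Returns the positions and the (previous, next) residue numbers around each gap.
    """
    positions = [r - seq_start for r in resseqs]
    gaps = [(a, b) for a, b in zip(resseqs, resseqs[1:]) if b != a + 1]
    return positions, gaps

positions, gaps = residue_positions([5, 6, 7, 10, 11], seq_start=3)
print(positions)   # [2, 3, 4, 7, 8]
print(gaps)        # [(7, 10)]
```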

In [552]:
dataset = ProteinDataset(data_path, 'short', encoding='onehot')
trn_data = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=sequence_collate)

In [553]:
for i_batch, sample_batched in enumerate(trn_data):
    vec = sample_batched['sequence']
    print(i_batch, sample_batched['sequence'].size(),
         sample_batched['coords'].size())
    if i_batch == 3:
        break

(0, torch.Size([148, 16, 20]), torch.Size([444, 16, 3]))
(1, torch.Size([147, 16, 20]), torch.Size([441, 16, 3]))
(2, torch.Size([149, 16, 20]), torch.Size([447, 16, 3]))
(3, torch.Size([145, 16, 20]), torch.Size([435, 16, 3]))


Potential to-dos for handling exceptions and errors in the PDB data:
<ul>
    <li>Atoms with multiple possible positions (A, B)</li>
    <li>PDB files with multiple chains</li>
    <li>Masking to use chains with atoms that don't have position 1</li>
    <li>HETATMs like water can play a substantial role in the final folds</li>
    <li>Consider adjusting loss function to reduce penalty for atoms with multiple occupancy</li>
</ul>
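For the masking to-do, a sketch of a masked dRMSD in NumPy: only atom pairs where both atoms are present in the target contribute. It assumes missing atoms are stored as all-zero rows, the same convention used later to find complete structures; note this heuristic would also flag a real atom sitting exactly at the origin.

```python
import numpy as np

def masked_drmsd(pred, target, mask):
    """dRMSD over atom pairs where both atoms are observed in the target.

    pred, target: (n_atoms, 3); mask: (n_atoms,) bool, False for missing atoms.
    """
    def dist(x):
        return np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)

    pair_mask = mask[:, None] & mask[None, :]
    np.fill_diagonal(pair_mask, False)            # skip i == j pairs
    diff = (dist(pred) - dist(target))[pair_mask]
    return np.sqrt(np.mean(diff ** 2))

rng = np.random.default_rng(1)
target = rng.normal(size=(6, 3))
target[2] = 0.0                                   # pretend atom 2 is missing
mask = target.any(axis=1)

pred = target.copy()
pred[2] = [100.0, 100.0, 100.0]                   # garbage where the target is missing
print(masked_drmsd(pred, target, mask))           # 0.0: the missing atom is ignored
```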

NOTE: Always make the input tensor a float! (Since PyTorch 0.4, `Variable` has been merged into `Tensor`, so wrapping inputs in `torch.autograd.Variable` is no longer needed.)

## RGN Model

In [554]:
aa2vec = bcolz.open(data_path + 'c3_embs.bc')

def create_emb_layer(aa2vec, trainable=True):
    #pretrained amino-acid embeddings (bcolz carray) -> nn.Embedding
    weights = torch.tensor(np.asarray(aa2vec[:]), dtype=torch.float32)
    vocab_sz, embed_dim = weights.size()
    emb_layer = nn.Embedding.from_pretrained(weights, freeze=not trainable)

    return emb_layer, vocab_sz, embed_dim

In [567]:
class RGN(nn.Module):
    def __init__(self, hidden_size, num_layers, model_type='hardtanh', input_type='onehot', 
                 aa2vec=None, linear_units=None, input_size=21):
        super(RGN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_type = input_type
        self.model_type = model_type
        self.grads = {}
        
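        if self.input_type == 'onehot':
            self.input_size = input_size #20 one-hot amino acids + 1 absolute-position feature
        elif self.input_type == 'tokens':
            #+1 for the absolute-position feature concatenated in forward()
            self.input_size = aa2vec.shape[1] + 1
            self.embeds, vocab_sz, embed_dim = create_emb_layer(aa2vec)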
        
        self.lstm = nn.LSTM(self.input_size, hidden_size, num_layers, bidirectional=True)
        
        if self.model_type == 'hardtanh':
            self.linear1 = nn.Linear(hidden_size*2, 3)
            self.linear2 = nn.Linear(hidden_size*2, 3)
            self.hardtanh = nn.Hardtanh()
        elif self.model_type == 'alphabet':
            u = torch.distributions.Uniform(-3.14, 3.14)
            self.alphabet = nn.Parameter(u.rsample(torch.Size([linear_units,3])))
            self.linear = nn.Linear(hidden_size*2, linear_units)
        
        #as per Mohammed, we simply use rows of the identity matrix as the first 3 atoms (N, CA, C of residue 1)
        self.A = torch.tensor([0., 0., 1.])
        self.B = torch.tensor([0., 1., 0.])
        self.C = torch.tensor([1., 0., 0.])

        #bond length vectors C-N, N-CA, CA-C
        self.avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
        #bond angle vector, in radians, CA-C-N, C-N-CA, N-CA-C
        self.avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

    
    def forward(self, sequences, lengths):
        max_len = sequences.size(0)
        batch_sz = sequences.size(1)
        lengths = torch.tensor(lengths, dtype=torch.long)
        #pack_padded_sequence needs the batch sorted by decreasing length;
        #unsort is the inverse permutation that restores the original batch order
        sorted_lens, order = lengths.sort(descending=True)
        _, unsort = order.sort()
        
        abs_pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        abs_pos = (abs_pos * torch.ones((1, batch_sz))).unsqueeze(2)
        
        h0 = torch.zeros((self.num_layers*2, batch_sz, self.hidden_size))
        c0 = torch.zeros((self.num_layers*2, batch_sz, self.hidden_size))
        
        #set sequence input type
        if self.input_type == 'onehot':
            pad_seq = torch.cat([sequences.float(), abs_pos], 2)
        elif self.input_type == 'tokens':
            pad_seq = self.embeds(sequences.long())
            pad_seq = torch.cat([pad_seq, abs_pos], 2)
    
        packed = pack_padded_sequence(pad_seq[:, order], sorted_lens, batch_first=False)
        
        lstm_out, _ = self.lstm(packed, (h0,c0))
        unpacked, _ = pad_packed_sequence(lstm_out, batch_first=False, padding_value=0.0)
        unpacked = unpacked[:, unsort] #undo the sort so outputs line up with the targets
        #unpacked.register_hook(self.save_grad('unpacked'))

        if self.model_type == 'hardtanh':
            sin_out = self.hardtanh(self.linear1(unpacked))
            cos_out = self.hardtanh(self.linear2(unpacked))
            #sin_out.register_hook(self.save_grad('sin_out'))
            out = torch.atan2(sin_out, cos_out)
            #out.register_hook(self.save_grad('out'))
        elif self.model_type == 'alphabet':
            softmax_out = F.softmax(self.linear(unpacked), dim=2)
            sine = torch.matmul(softmax_out, torch.sin(self.alphabet))
            cosine = torch.matmul(softmax_out, torch.cos(self.alphabet))
            out = torch.atan2(sine, cosine)
        
        #create as many copies of first residue as there are samples in the batch
        #broadcast = torch.ones((batch_sz, 3))
        #pred_coords = torch.stack([self.A*broadcast, self.B*broadcast, self.C*broadcast])
        
        out = out.permute(1, 0, 2)
        #NeRF: place each new atom D from the previous three atoms A, B, C
        all_coords = []
        for seq in out:
            A, B, C = self.A, self.B, self.C
            pcs = [A, B, C]
            for triplet in seq[1:]:
                for i in range(3):
                    T = self.avg_bond_angles[i] #angle_BCD
                    R = self.avg_bond_lens[i] #bond_CD
                    P = triplet[i] #torsion about BC

                    #position of D in the local reference frame
                    D2 = torch.stack([-R*torch.cos(T),
                                      R*torch.cos(P)*torch.sin(T),
                                      R*torch.sin(P)*torch.sin(T)])

                    BC = C - B
                    bc = BC/torch.norm(BC, 2)

                    AB = B - A

                    N = torch.cross(AB, bc, dim=0)
                    n = N/torch.norm(N, 2)

                    #rotate D2 into the global frame and translate it to C
                    M = torch.stack([bc, torch.cross(n, bc, dim=0), n], dim=1)
                    D = torch.mv(M, D2) + C
                    pcs.append(D)

                    A, B, C = B, C, D
            all_coords.append(torch.stack(pcs))
        #(3 * max_len, batch_sz, 3), matching the target layout (also for batch_sz == 1)
        pred_coords = torch.stack(all_coords, 1)
        #TODO: find bug in bmm (or maybe user error)
        #for ix, triplet in enumerate(out[1:]):
            #triplet.register_hook(self.save_grad('tr{}'.format(ix)))
        #    pred_coords = geometric_unit(pred_coords, triplet, 
        #                                 self.avg_bond_angles, 
        #                                 self.avg_bond_lens)
            #pred_coords.register_hook(self.save_grad('pc{}'.format(ix)))
            
        #pred_coords.register_hook(self.save_grad('pc'))
            
        #pdb.set_trace()
        return pred_coords
    
    def save_grad(self, name):
        def hook(grad): self.grads[name] = grad
        return hook
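Neither output head predicts an angle directly: the hardtanh head predicts a sine and a cosine and combines them with `atan2`, and the alphabet head mixes a learned set of torsion triplets. Mixing has to happen on the circle, since angles near +π and -π are almost the same but average to 0. A NumPy sketch of the alphabet mixing, mirroring `F.softmax` followed by the two `matmul`s in `forward` (the two-entry alphabet is made up for illustration):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def alphabet_angles(logits, alphabet):
    """Mix alphabet torsions with softmax weights, averaging on the circle.

    logits:   (..., k) scores for k alphabet entries
    alphabet: (k, 3) torsion triplets in radians
    """
    w = softmax(logits)
    return np.arctan2(w @ np.sin(alphabet), w @ np.cos(alphabet))

alphabet = np.array([[np.pi - 0.1] * 3, [-np.pi + 0.1] * 3])
mixed = alphabet_angles(np.zeros(2), alphabet)   # equal weights
naive = alphabet.mean(axis=0)
print(mixed)   # each entry is +/-pi: both alphabet entries sit near +/-pi
print(naive)   # ~0: plain averaging points the opposite way
```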

In [568]:
#for i_batch, sampled_batch in enumerate(trn_data):
#    inp_seq = sampled_batch['sequence']
#    inp_lens = sampled_batch['length']
#    rgn = RGN(20, 1, 'hardtanh', 'onehot')
#    out = rgn(inp_seq, inp_lens)
#    print(i_batch, inp_seq.size(), sampled_batch['coords'].size(), out.size())
    
#    if i_batch == 1:
#        break

In [569]:
def adaptive_lr(optimizer, step_size):
    #for now just a linear schedule: add step_size to every param group's learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] += step_size
    
    return optimizer

In [574]:
rgn = RGN(200, 3, 'hardtanh', 'onehot')
drmsd = dRMSD()

In [575]:
optimizer = torch.optim.SGD(rgn.parameters(), lr=1e-3)
#optimizer = torch.optim.Adam(rgn.parameters(), lr=1e-3)

Next step: try debugging the gradients using https://gist.github.com/apaszke/f93a377244be9bfcb96d3547b9bc424d.

In [None]:
loss_history = []
for epoch in range(20):
    running_loss = 0.0
    n_ok, n_failed = 0, 0
    #for i, data in tqdm(enumerate(trn_data)):
    for i, data in enumerate(trn_data):
        try:
            coords = data['coords']

            optimizer.zero_grad()
            outputs = rgn(data['sequence'], data['length'])

            loss = drmsd(outputs, coords)

            #print(i, loss.item(), rgn.embeds.state_dict()['weight'][0][0])
            loss.backward()
            nn.utils.clip_grad_norm_(rgn.parameters(), max_norm=1)
            optimizer.step()
        except Exception:
            #skip malformed batches (e.g. bad PDB coordinates) but count them;
            #KeyboardInterrupt is not an Exception, so it still stops training
            n_failed += 1
            continue

        running_loss += loss.item()
        n_ok += 1
        loss_history.append((epoch, loss.item()))

    #average over the batches that actually ran
    print('Epoch {}, Loss {}, skipped batches {}'.format(
        epoch, running_loss/max(n_ok, 1), n_failed))

print('Finished Training')

In [435]:
#rgn.embeds.state_dict()['weight'][0]

In [None]:
torch.save(rgn, data_path+'models/rgn1.pt')
#rgn = torch.load(data_path+'models/rgn1.pt')

In [None]:
#plt.plot(np.array(loss_history)[:, 0], np.array(loss_history)[:, 1])

In [None]:
torch.__version__

## Validation

To actually reproduce the results from the RGN paper, I need to use the ProteinNet dataset, https://github.com/aqlaboratory/proteinnet. In particular, Mohammed used the CASP 11 data to test his model. The full dataset may be too large for my disk space without deleting all the hard work I did with the PDB files. However, if I delete all the PDB files I currently have, I at least still have the tools to reproduce the datasets if necessary.


## Geometric Units

Some basic information about bond angles and lengths can be found here: https://www.ruppweb.org/Xray/tutorial/protein_structure.htm

I'll use this as my primary source, but it may be somewhat inaccurate (I have since found a more reliable source, saved in my Dropbox).

To validate my implementation of the NeRF (Natural Extension Reference Frame) algorithm, I want to take a PDB file, use Bio.PDB to calculate its torsion angles (along with its bond angles and bond lengths), and then use these ground-truth (gt) values to reconstruct the coordinates. The goal is for the dRMSD between the reconstructed structure and the gt structure to be zero. This would imply that if my LSTM model correctly predicts the torsion angles (and the bond geometry is right), the calculated coordinates should match the gt PDB file.
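The same round trip can be sketched in NumPy without a PDB file: extract (bond length, bond angle, dihedral) from a chain, rebuild it with the NeRF step used in `RGN.forward`, and check that the coordinates come back. The random-walk chain is a stand-in for real backbone atoms, and the function names are hypothetical; the dihedral uses the usual atan2 formulation, which matches the sign convention of the local frame built from `bc`, `n × bc` and `n`.

```python
import numpy as np

def place_atom(a, b, c, bond_len, bond_angle, torsion):
    """NeRF step: position of atom d from the previous atoms a, b, c.

    bond_len = |cd|, bond_angle = angle(b, c, d), torsion = dihedral(a, b, c, d), radians.
    """
    bc = (c - b) / np.linalg.norm(c - b)
    n = np.cross(b - a, bc)
    n = n / np.linalg.norm(n)
    m = np.column_stack([bc, np.cross(n, bc), n])      # local -> global rotation
    d2 = bond_len * np.array([-np.cos(bond_angle),
                              np.cos(torsion) * np.sin(bond_angle),
                              np.sin(torsion) * np.sin(bond_angle)])
    return m @ d2 + c

def internal_coords(x):
    """(bond length, bond angle, dihedral) for every atom of x (n, 3) after the first three."""
    params = []
    for i in range(3, len(x)):
        a, b, c, d = x[i - 3], x[i - 2], x[i - 1], x[i]
        r = np.linalg.norm(d - c)
        cos_t = np.dot(b - c, d - c) / (np.linalg.norm(b - c) * r)
        theta = np.arccos(np.clip(cos_t, -1.0, 1.0))
        b1 = (c - b) / np.linalg.norm(c - b)
        v = (a - b) - np.dot(a - b, b1) * b1           # parts perpendicular to bc
        w = (d - c) - np.dot(d - c, b1) * b1
        phi = np.arctan2(np.dot(np.cross(b1, v), w), np.dot(v, w))
        params.append((r, theta, phi))
    return params

def rebuild(seed, params):
    coords = [np.asarray(p, dtype=float) for p in seed]
    for r, theta, phi in params:
        coords.append(place_atom(coords[-3], coords[-2], coords[-1], r, theta, phi))
    return np.array(coords)

rng = np.random.default_rng(0)
chain = np.cumsum(rng.normal(size=(30, 3)), axis=0)   # stand-in for real backbone atoms
rebuilt = rebuild(chain[:3], internal_coords(chain))
print(np.abs(rebuilt - chain).max())                   # round-off level
```

Seeding with the first three gt atoms makes the match exact in absolute coordinates; with any other seed, only the dRMSD (not the raw coordinates) would be zero.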

In [443]:
#First find a pdb file with no missing coordinates
chain_1 = load_array(data_path+'proteins_1.bc')

In [444]:
for ix, chain in enumerate(chain_1[:20]):
    #rows that sum to zero are treated as missing atoms
    msk = chain[2].sum(1) == 0
    if not np.any(msk):
        print(ix)

2
15
19


In [445]:
chain_1[2][0]

['1zur']

The protein at index 2 in the proteins_1.bc dataset (1zur) has no missing atoms, so we can use it for testing.

In [464]:
t_angles, b_angles, b_len = utils.gt_dihedral_angles(pdb_path+'pdb1zur.ent')

Note that the angles are in radians, whereas my implementation assumes degrees (the conversion by 180/π can be removed). The ω angles are all roughly equal to π, in accordance with the literature I've read (ω ≈ 180° corresponds to a planar trans peptide bond).
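Because ω sits right at the ±π boundary, checks and comparisons on torsions should use a wrapped difference; otherwise 179° and -179° look 358° apart. A small sketch with made-up ω values:

```python
import numpy as np

def angle_diff(a, b):
    """Signed difference a - b wrapped into [-pi, pi] (radians)."""
    return np.arctan2(np.sin(a - b), np.cos(a - b))

omega = np.array([3.12, -3.13, 3.09])          # hypothetical omega values near +/-pi
print(np.degrees(omega).round(1))              # radians -> degrees for reading
print(np.abs(omega - np.pi) < 0.1)             # naive check misses -3.13
print(np.abs(angle_diff(omega, np.pi)) < 0.1)  # wrapped check: all trans
```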

In [465]:
A = torch.tensor(chain_1[2][2][0], dtype=torch.float)
B = torch.tensor(chain_1[2][2][1], dtype=torch.float)
C = torch.tensor(chain_1[2][2][2], dtype=torch.float)

#A = torch.tensor([0., 0., 1.], dtype=torch.float)
#B = torch.tensor([0., 1., 0.], dtype=torch.float)
#C = torch.tensor([1., 0., 0.], dtype=torch.float)

#avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
#avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

pred_coords = torch.stack([A, B, C])

for ix,triplet in enumerate(t_angles):
    for i in range(3):
        T = b_angles[ix][i] #avg_bond_angles[i] #angle_BCD
        R = b_len[ix][i] #avg_bond_lens[i] #bond_CD
        P = triplet[i] #torsionBC
        
        D2 = torch.stack([-R*torch.cos(T),
                          R*torch.cos(P)*torch.sin(T),
                          R*torch.sin(P)*torch.sin(T)])

        BC = C - B
        bc = BC/torch.norm(BC, 2)

        AB = B - A

        N = torch.cross(AB, bc)
        n = N/torch.norm(N, 2)

        M = torch.stack([bc, torch.cross(n, bc), n], dim=1)

        D = torch.mm(M, D2.view(-1,1)).squeeze() + C
        pred_coords = torch.cat([pred_coords, D.view(1,3)])
        
        A = pred_coords[-3]
        B = pred_coords[-2]
        C = pred_coords[-1]

In [466]:
pair_dist(pred_coords[:5])

tensor([[ 0.0000,  1.4874,  2.5037,  3.6471,  4.9417],
        [ 1.4874,  0.0000,  1.5289,  2.4205,  3.8290],
        [ 2.5037,  1.5289,  0.0000,  1.3304,  2.4853],
        [ 3.6471,  2.4205,  1.3304,  0.0000,  1.4620],
        [ 4.9417,  3.8290,  2.4853,  1.4620,  0.0000]])

In [462]:
gt_coords = utils.create_targets(pdb_path+'pdb1zur.ent')
print(gt_coords)
#pair_dist(gt_coords[:5])

[[]]


Notice that using only average bond lengths and angles injects a considerable amount of error into the geometric units. In particular, because the bond lengths are fixed, no model can reach zero loss (dRMSD depends directly on the bond lengths). Using the identity matrix as Mohammed suggested also leads to larger errors, even with the gt torsions: the three identity rows are all √2 apart and form a 60° angle, whereas the average N-CA and CA-C bonds used here are 1.459 and 1.525 with an N-CA-C angle of about 111°. I dislike lazily leaving these sources of loss in the model, but I want to see whether the paper's results can be reproduced before tinkering with the architecture. At the very least, these parameters should probably be learnable.
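A low-effort step before making anything learnable would be to seed the chain with the average first-residue geometry instead of identity rows. A NumPy sketch comparing the two seeds, reusing the averages from `RGN.__init__` (`ideal_seed` is a hypothetical helper; since dRMSD ignores orientation, placing the seed in the xy-plane costs nothing):

```python
import numpy as np

N_CA, CA_C = 1.459, 1.525   # average bond lengths used in RGN (angstroms)
N_CA_C = 1.937              # average N-CA-C bond angle used in RGN (radians)

def bond_angle(a, b, c):
    """Angle at b formed by a-b-c, in radians."""
    u, v = a - b, c - b
    return np.arccos(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

def ideal_seed():
    """N, CA, C of a first residue with the average geometry, placed in the xy-plane."""
    n = np.zeros(3)
    ca = np.array([N_CA, 0.0, 0.0])
    c = ca + CA_C * np.array([-np.cos(N_CA_C), np.sin(N_CA_C), 0.0])
    return n, ca, c

identity_seed = (np.array([0., 0., 1.]), np.array([0., 1., 0.]), np.array([1., 0., 0.]))
for name, (a, b, c) in [('identity', identity_seed), ('ideal', ideal_seed())]:
    print(name, np.linalg.norm(b - a).round(3), np.linalg.norm(c - b).round(3),
          np.degrees(bond_angle(a, b, c)).round(1))
```

This only fixes the seed; the fixed bond lengths and angles downstream are unchanged.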

## Parse PDB files

The target for each protein structure is the matrix of distances between every pair of atoms. Only the N, Cα, and C' atoms of each residue are used, so a protein sequence with 100 residues gives a 300x300 target matrix.
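A small NumPy sketch of that target for a made-up 100-residue chain: per-residue (N, CA, C) coordinates are flattened to 300 atoms in chain order, and the distance matrix uses the same |xi|² + |xj|² - 2 xi·xj expansion as `batch_pair_dist`, checked against direct broadcasting.

```python
import numpy as np

def backbone_atoms(residues):
    """(L, 3 atoms, 3) per-residue N, CA, C coordinates -> (3L, 3) in chain order."""
    return residues.reshape(-1, 3)

def distance_matrix(x):
    """Pairwise distances via |xi|^2 + |xj|^2 - 2 xi.xj."""
    sq = (x ** 2).sum(1)
    d2 = sq[:, None] + sq[None, :] - 2 * x @ x.T
    return np.sqrt(np.clip(d2, 0.0, None))   # clip round-off negatives

residues = np.random.default_rng(0).normal(size=(100, 3, 3))
atoms = backbone_atoms(residues)
d = distance_matrix(atoms)
direct = np.linalg.norm(atoms[:, None] - atoms[None], axis=-1)
print(atoms.shape, d.shape)                # (300, 3) (300, 300)
print(np.allclose(d, direct, atol=1e-6))   # True
```

The expansion avoids materializing the (n, n, 3) difference tensor, which matters for long chains and large batches.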

To start, I think I need to restrict myself to single-chain proteins, since multi-chain proteins seem to require post-processing.

Alphabet matrix: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.122.1578&rep=rep1&type=pdf

Torsion space conversion: https://pdfs.semanticscholar.org/6310/0da463862f35f4188754bdeb7f41e24dbfb7.pdf

In [9]:
#import warnings
#warnings.filterwarnings("ignore")

#keep = []
#for ix,s in tqdm(enumerate(trn_samp)):
#p = bio.PDBParser()
#structure = p.get_structure('X', pdb_path+'pdb1rfe.ent')

#for model in structure:
#    for chain in model:
#        for ix, residue in enumerate(chain):
#            if residue.get_id()[0] == ' ':
                #print(residue.get_id()[1])
#                if ix == 0:
#                    ex = residue.get_id()[1]  
#                elif residue.get_id()[1] == ex+1:
#                    ex = residue.get_id()[1]
#                else:
#                    print("error at {}".format(residue.get_id()[1]))

## LSTM example

Try the example from http://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html.

In [101]:
import torch.optim as optim
torch.manual_seed(1)

<torch._C.Generator at 0x7f6ad4d41b70>

In [102]:
lstm = BiLSTM(3,3,1)
inputs = [Variable(torch.randn((1,3))) for _ in range(5)]
inputs

[Variable containing:
 -0.1186  0.4903  0.8349
 [torch.FloatTensor of size 1x3], Variable containing:
  0.8894  0.4148  0.0507
 [torch.FloatTensor of size 1x3], Variable containing:
 -0.9644 -2.0111  0.5245
 [torch.FloatTensor of size 1x3], Variable containing:
  2.1332 -0.0822  0.8388
 [torch.FloatTensor of size 1x3], Variable containing:
 -1.3233  0.0701  1.2200
 [torch.FloatTensor of size 1x3]]

In [103]:
lstm

BiLSTM (
  (lstm): LSTM(3, 3)
)

In [104]:
hidden = (Variable(torch.randn((1,1,3))), 
         Variable(torch.randn((1,1,3))))
hidden

(Variable containing:
 (0 ,.,.) = 
   0.4251 -1.2328 -0.6195
 [torch.FloatTensor of size 1x1x3], Variable containing:
 (0 ,.,.) = 
   1.5133  1.9954 -0.6585
 [torch.FloatTensor of size 1x1x3])

In [108]:
#for i in inputs:
#    out, h = lstm.forward(i.view(1,1,-1), hidden)

In [109]:
hidden

(Variable containing:
 (0 ,.,.) = 
   0.4251 -1.2328 -0.6195
 [torch.FloatTensor of size 1x1x3], Variable containing:
 (0 ,.,.) = 
   1.5133  1.9954 -0.6585
 [torch.FloatTensor of size 1x1x3])

In [116]:
inputs = torch.cat(inputs).view(len(inputs), 1, 3)
hidden = (Variable(torch.randn((1,1,3))), 
         Variable(torch.randn((1,1,3))))
#out, hidden = lstm(inputs, hidden)
inputs

Variable containing:
(0 ,.,.) = 
 -0.1186  0.4903  0.8349

(1 ,.,.) = 
  0.8894  0.4148  0.0507

(2 ,.,.) = 
 -0.9644 -2.0111  0.5245

(3 ,.,.) = 
  2.1332 -0.0822  0.8388

(4 ,.,.) = 
 -1.3233  0.0701  1.2200
[torch.FloatTensor of size 5x1x3]

In [26]:
inputs

Variable containing:
(0 ,.,.) = 
  0.6133 -0.2240  1.8343

(1 ,.,.) = 
 -0.1765  0.6837  1.2409

(2 ,.,.) = 
 -0.3073 -1.0962  1.6789

(3 ,.,.) = 
  0.2860 -0.4774 -0.1175

(4 ,.,.) = 
  0.1739 -0.1030 -0.5680
[torch.FloatTensor of size 5x1x3]

In [23]:
hidden

(Variable containing:
 (0 ,.,.) = 
  -0.3807 -0.2034 -0.2926
 [torch.FloatTensor of size 1x1x3], Variable containing:
 (0 ,.,.) = 
  -0.7260 -0.3826 -0.8743
 [torch.FloatTensor of size 1x1x3])

In [21]:
out

Variable containing:
(0 ,.,.) = 
 -0.0823  0.0944 -0.1291

(1 ,.,.) = 
 -0.0851 -0.1749 -0.0420

(2 ,.,.) = 
 -0.1065 -0.3703 -0.0488

(3 ,.,.) = 
 -0.2781 -0.2692 -0.2353

(4 ,.,.) = 
 -0.3807 -0.2034 -0.2926
[torch.FloatTensor of size 5x1x3]

## Pad Packed Example
Try the example from https://discuss.pytorch.org/t/simple-working-example-how-to-use-packing-for-variable-length-sequence-inputs-for-rnn/2120

In [54]:
batch_size = 3
max_length = 3
hidden_size = 2
n_layers = 1

batch_in = torch.zeros((batch_size, 1, max_length))

vec_1 = torch.FloatTensor([[1,2,3]])
vec_2 = torch.FloatTensor([[1,2,0]])
vec_3 = torch.FloatTensor([[1,0,0]])

batch_in[0] = vec_1
batch_in[1] = vec_2
batch_in[2] = vec_3

batch_in = Variable(batch_in)

seq_lengths = [3,2,1]

pack = nn.utils.rnn.pack_padded_sequence(batch_in, seq_lengths, batch_first=True)

pack

PackedSequence(data=Variable containing:
 1  2  3
 1  2  0
 1  0  0
[torch.FloatTensor of size 3x3]
, batch_sizes=[3])

In [58]:
lstm = nn.RNN(max_length, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
out, _ = lstm(pack, h0)
unpacked, unpacked_len = nn.utils.rnn.pad_packed_sequence(out)
unpacked

Variable containing:
(0 ,.,.) = 
 -0.7852 -0.9670
  0.1932 -0.7175
 -0.8542 -0.4358
[torch.FloatTensor of size 1x3x2]
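In the cell above, `batch_in` has shape (3, 1, 3), so with `batch_first=True` each sequence is a single time step with three features; the padding sits on the feature axis, which is why the packed output reports `batch_sizes=[3]` rather than `[3, 2, 1]`. Putting time on axis 1 (shape (3, 3, 1), with `nn.RNN(1, hidden_size, ...)`) gives the intended packing. The pure-NumPy sketch below only illustrates the time-major layout a `PackedSequence` stores; it is not PyTorch's implementation.

```python
import numpy as np

def pack(padded, lengths):
    """Mimic the PackedSequence layout for batch-first padded input.

    padded: (batch, max_len, features), rows sorted by decreasing length.
    Returns time-major data (sum(lengths), features) and the batch size per step.
    """
    max_len = padded.shape[1]
    batch_sizes = [sum(l > t for l in lengths) for t in range(max_len)]
    data = np.concatenate([padded[:bs, t] for t, bs in enumerate(batch_sizes)])
    return data, batch_sizes

# time runs along axis 1 here: shape (batch=3, max_len=3, features=1)
padded = np.array([[1, 2, 3], [1, 2, 0], [1, 0, 0]], dtype=float)[:, :, None]
data, batch_sizes = pack(padded, [3, 2, 1])
print(batch_sizes)    # [3, 2, 1]
print(data.ravel())   # [1. 1. 1. 2. 2. 3.]
```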

## Loss Testing

In [203]:
class _Loss(nn.Module):
    def __init__(self, size_average=True, reduce=True):
        super(_Loss, self).__init__()
        self.size_average = size_average
        self.reduce = reduce

In [204]:
def _assert_no_grad(tensor):
    assert not tensor.requires_grad

In [193]:
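def _pairwise_loss(lambd, input, target, size_average=True, reduce=True):
    #lambd returns one loss value per sample in the batch
    d = lambd(input, target)
    if not reduce:
        return d
    return torch.mean(d) if size_average else torch.sum(d)

In [194]:
def batch_pair_dist(x):
    #x: (n_atoms, batch, 3) -> (batch, n_atoms, n_atoms) Euclidean distances
    x = x.permute(1, 0, 2)
    x_norm = (x**2).sum(2).unsqueeze(2) #(batch, n_atoms, 1)
    
    dist = x_norm + x_norm.transpose(1, 2) - 2*torch.bmm(x, x.transpose(1, 2))
    #clamp round-off negatives and exact zeros: the gradient of sqrt at 0
    #is infinite and would turn into NaN during backward
    dist = torch.clamp(dist, min=1e-10)
    
    return torch.sqrt(dist)

In [195]:
def _drmsd(input, target):
    #per-sample dRMSD over the n*(n-1) off-diagonal atom pairs
    d_pred = batch_pair_dist(input)
    d_true = batch_pair_dist(target)
    batch_sz, n = d_pred.size(0), d_pred.size(1)
    sq_err = ((d_pred - d_true)**2).view(batch_sz, -1).sum(1)
    return torch.sqrt(sq_err / (n*(n - 1)))

def drmsd_loss(input, target, size_average=True, reduce=True):
    return _pairwise_loss(_drmsd, input, target, size_average, reduce)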

In [196]:
class dRMSDLoss(_Loss):
    def __init__(self, size_average=True, reduce=True):
        super(dRMSDLoss, self).__init__(size_average, reduce)

    def forward(self, input, target):
        _assert_no_grad(target)
        return drmsd_loss(input, target, size_average=self.size_average, reduce=self.reduce)

In [197]:
drmsd_loss(out, Variable(sampled_batch['coords'], requires_grad=True))

RuntimeError: $ Torch: not enough memory: you tried to allocate 1GB. Buy new RAM! at /opt/conda/conda-bld/pytorch-cpu_1524577316810/work/aten/src/TH/THGeneral.c:218

In [202]:
loss = dRMSDLoss()
loss(out, sampled_batch['coords'])

RuntimeError: $ Torch: not enough memory: you tried to allocate 0GB. Buy new RAM! at /opt/conda/conda-bld/pytorch-cpu_1524577316810/work/aten/src/TH/THGeneral.c:218

## Folding Layer

In [46]:
from torch.autograd import gradcheck, Function

In [None]:
class Fold(Function):
    #TODO: custom autograd Function for the folding step; not implemented yet
    @staticmethod
    def forward(ctx, pred_coords, pred_torsions, bond_angles, bond_lens):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError

def geometric_unit(pred_coords, pred_torsions, bond_angles, bond_lens):
    #pred_coords: (n_atoms, batch_sz, 3); pred_torsions: (batch_sz, 3)
    for i in range(3):
        #coordinates of last three atoms, each (batch_sz, 3)
        A, B, C = pred_coords[-3], pred_coords[-2], pred_coords[-1]

        #internal coordinates
        T = bond_angles[i]
        R = bond_lens[i]
        P = pred_torsions[:, i]

        #batch_sz x 3: new atom in the local frame, one row per sample
        D2 = torch.stack([-R*torch.ones_like(P)*torch.cos(T),
                          R*torch.cos(P)*torch.sin(T),
                          R*torch.sin(P)*torch.sin(T)], dim=1)

        #normalize each sample separately (dim=1), not over the whole batch
        BC = C - B
        bc = BC/torch.norm(BC, 2, dim=1, keepdim=True)

        AB = B - A

        N = torch.cross(AB, bc, dim=1)
        n = N/torch.norm(N, 2, dim=1, keepdim=True)

        #batch_sz x 3 x 3 rotation matrices with columns bc, n x bc, n
        M = torch.stack([bc, torch.cross(n, bc, dim=1), n], dim=2)

        #squeeze(2) keeps the batch dimension even when batch_sz == 1
        D = torch.bmm(M, D2.unsqueeze(2)).squeeze(2) + C
        pred_coords = torch.cat([pred_coords, D.unsqueeze(0)])
    
    return pred_coords
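The batched geometric unit must normalize each sample's vectors separately. Calling `torch.norm(BC, 2)` without a `dim` argument takes one norm over the whole batch, so the rotation matrices stop being orthonormal and bond lengths drift; this may be related to the bmm bug flagged in the TODO in `RGN.forward`, though that is unconfirmed. A NumPy sketch with a hypothetical `per_sample` switch, checking that every new bond has length R only with per-sample normalization:

```python
import numpy as np

def place_atoms_batched(a, b, c, bond_len, bond_angle, torsions, per_sample=True):
    """Batched NeRF step: a, b, c are (batch, 3); torsions is (batch,); returns (batch, 3)."""
    axis = 1 if per_sample else None   # None: one norm over the whole batch (the bug)
    bc = c - b
    bc = bc / np.linalg.norm(bc, axis=axis, keepdims=True)
    n = np.cross(b - a, bc)
    n = n / np.linalg.norm(n, axis=axis, keepdims=True)
    m = np.stack([bc, np.cross(n, bc), n], axis=2)         # (batch, 3, 3), columns
    d2 = np.stack([-bond_len * np.cos(bond_angle) * np.ones_like(torsions),
                   bond_len * np.cos(torsions) * np.sin(bond_angle),
                   bond_len * np.sin(torsions) * np.sin(bond_angle)], axis=1)
    return np.einsum('bij,bj->bi', m, d2) + c

rng = np.random.default_rng(0)
a, b, c = (rng.normal(size=(4, 3)) for _ in range(3))
tors = rng.uniform(-np.pi, np.pi, size=4)
R, theta = 1.329, 2.034   # C-N length and CA-C-N angle from RGN

good = place_atoms_batched(a, b, c, R, theta, tors)
bad = place_atoms_batched(a, b, c, R, theta, tors, per_sample=False)
print(np.linalg.norm(good - c, axis=1))   # every bond equals R
print(np.linalg.norm(bad - c, axis=1))    # generally not R
```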