In [66]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as du
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import Dataset
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
import sidechainnet as scn
import random
import sklearn
import pytorch3d
from pytorch3d import transforms

In [122]:
class IPA(nn.Module):
    '''
    Invariant Point Attention
    '''
    def __init__(self, c_s = 384, c_z = 16, Nq=4, Nv=8, c=16, num_heads = 12):
        super(IPA, self).__init__()
        self.c = c
        self.num_heads = num_heads
        self.Nq = Nq
        self.Nv = Nv
        self.wl = np.sqrt(1/3)
        self.wc = np.sqrt(2/(9*Nq))
        
        self.q = nn.ModuleList([nn.Linear(c_s, c, bias = False) for i in range(num_heads)])
        self.k = nn.ModuleList([nn.Linear(c_s, c, bias = False) for i in range(num_heads)])
        self.v = nn.ModuleList([nn.Linear(c_s, c, bias = False) for i in range(num_heads)])
        self.qp = nn.ModuleList([nn.Linear(c_s, self.Nq*3, bias = False) for i in range(num_heads)])
        self.kp = nn.ModuleList([nn.Linear(c_s, self.Nq*3, bias = False) for i in range(num_heads)])
        self.vp = nn.ModuleList([nn.Linear(c_s, self.Nv*3, bias = False) for i in range(num_heads)])
        self.b = nn.ModuleList([nn.Linear(c_z, 1, bias = False) for i in range(num_heads)])
        self.gamma = nn.Parameter(torch.rand(num_heads))
        self.fc1 = nn.Linear(num_heads * c_z, c_s)
        self.fc2 = nn.Linear(num_heads * c_z, c_s)
        self.fc3 = nn.Linear(num_heads * self.Nv*3, c_s)
        
    def forward(self, s, z, t):
        o_hat_list = []
        o_list = []
        o_p_list = []
        g = F.softplus(self.gamma)
        for h in range(self.num_heads):
            query = self.q[h](s).unsqueeze(dim=2)
            key = self.k[h](s).unsqueeze(dim=2)
            value = self.v[h](s).unsqueeze(dim=2)
            query_p = self.qp[h](s)
            key_p = self.kp[h](s)
            value_p = self.vp[h](s)
            query_p = torch.reshape(query_p, (query_p.shape[0], query_p.shape[1], self.Nq, 3)).unsqueeze(dim=2)     
            key_p = torch.reshape(key_p, (key_p.shape[0], key_p.shape[1], self.Nq, 3)).unsqueeze(dim=2)
            value_p = torch.reshape(value_p, (value_p.shape[0], value_p.shape[1], self.Nv, 3)).unsqueeze(dim=2)
            bias = self.b[h](z)
            a = 1/np.sqrt(self.c) * torch.einsum("bihc,bjhc->bijh", query, key) + bias
            ti = torch.einsum("bilm,bihpl->bihpm", t[0], query_p) + t[1].unsqueeze(dim=2).unsqueeze(dim=2)
            tj = torch.einsum("bjlm,bjhpl->bjhpm", t[0], key_p) + t[1].unsqueeze(dim=2).unsqueeze(dim=2)
            sqrt_diff = torch.norm(ti.unsqueeze(dim=2) - tj.unsqueeze(dim=1), dim=-1) ** 2
            a -= g[h] * self.wc/2 * torch.sum(sqrt_diff, dim=-1)
            a = F.softmax(self.wl * a, dim=2)
            o_hat = torch.sum(torch.einsum("bmnh,booc->bmnhc", a, z), dim=2)
            o = torch.sum(torch.einsum("bmnh,bnhc->bmnhc", a, value), dim=2)
            t_v = torch.einsum("bjkl,bjhmk->bjhml", t[0], value_p) + t[1].unsqueeze(dim=2).unsqueeze(dim=2)
            a_t_v = torch.sum(torch.einsum("bijh,bjhpt->bijhpt", a, t_v), dim=2)
            o_p = torch.einsum("bmij,bmhpi->bmhpj", torch.linalg.inv(t[0]), a_t_v) + t[1].unsqueeze(dim=2).unsqueeze(dim=2)
            o_hat_list.append(o_hat)
            o_list.append(o)
            o_p_list.append(o_p)
        o_hat_tensor = torch.concat(o_hat_list, 2)
        o_tensor = torch.concat(o_list, 2)
        o_p_tensor = torch.concat(o_p_list, 2)
        o_hat_tensor = o_hat_tensor.reshape((o_hat_tensor.shape[0], o_hat_tensor.shape[1], o_hat_tensor.shape[2] * o_hat_tensor.shape[3]))
        o_tensor = o_tensor.reshape((o_tensor.shape[0], o_tensor.shape[1], o_tensor.shape[2] * o_tensor.shape[3]))
        o_p_tensor = o_p_tensor.reshape((o_p_tensor.shape[0], o_p_tensor.shape[1], o_p_tensor.shape[2] * o_p_tensor.shape[3] * o_p_tensor.shape[4]))
        return (self.fc1(o_hat_tensor) + self.fc2(o_tensor) + self.fc3(o_p_tensor))

In [123]:
class BackUpdate(nn.Module):
    '''
    Given the single reprensation, returns a tuple of rotations and transitions.
    '''
    def __init__(self, c_s, device):
        super(BackUpdate, self).__init__()
        self.device = device
        
        #linear layer to project each row of the single rep
        self.fc1 = nn.Linear(c_s,6)
        
    def forward(self, s):
        #creates a list of rotations and transitions
        list_r = torch.empty((s.shape[0], s.shape[1], 3, 3))
        list_t = torch.empty((s.shape[0], s.shape[1], 3))
        
        for i in range(s.shape[1]):
            proj = self.fc1(s[:,i,:])
            
            #obtain b,c,d quaternion and the transition t
            b = proj[:,0]
            c = proj[:,1]
            d = proj[:,2]
            a = torch.ones(b.shape)
            t = proj[:,3:]
            
            #compute the quaternion
            a, b, c, d = torch.unsqueeze(a, dim = 0), torch.unsqueeze(b, dim = 0), torch.unsqueeze(c, dim = 0), torch.unsqueeze(d, dim = 0)
            total = torch.concat((a,b,c,d), dim = 0)
            
            #batch x normalized quaternion 
            q = F.normalize(total, dim = 0)
            
            #convert the quaternion into a rotation matrix
            r = pytorch3d.transforms.quaternion_to_matrix(q)
            
            list_r[:,i,:,:] = r
            list_t[:,i,:] = t
        #returns a tuple of all predicted rotations(b x n_res x 3 x3)
        #and transitions (b x n_res x 3)
        return (list_r, list_t)

In [124]:
class PredictAngleTorsion(nn.Module):
    '''
    Given the updated single representation and single representation,
    predicts phi and psi angles.
    '''
    def __init__(self, c_s, device):
        super(PredictAngleTorsion, self).__init__()
        self.device = device
        
        #project s to 128
        self.fc1 = nn.Linear(c_s, 128)
        #project s_i to 128
        self.fc2 = nn.Linear(c_s, 128)
        
        #projects a from 128 to ?
        self.layers1 = nn.Sequential(nn.Linear(128,128), nn.ReLU(), nn.Linear(128,128), nn.ReLU())
        self.layers2 = nn.Sequential(nn.Linear(128,128), nn.ReLU(), nn.Linear(128,128), nn.ReLU())
        
        #projects a into phi and psi
        self.phi = nn.Sequential(nn.Linear(128, 2), nn.ReLU())
        self.psi = nn.Sequential(nn.Linear(128, 2), nn.ReLU())
        
    def forward(self, s, s_i):
        proj_s = self.fc1(s)
        proj_si = self.fc2(s_i)
        
        a_i = proj_s + proj_si
        a_i = self.layers1(a_i) + a_i
        a_i = self.layers2(a_i) + a_i
        
        phi = self.phi(a_i)
        psi = self.psi(a_i)
        
        return phi, psi

In [125]:
class FAPE(nn.Module):
    '''
    Calculates Frame Aligned Point Error Loss.
    '''
    def __init__(self, device):
        super(FAPE, self).__init__()
        self.device = device
    
    def forward(self, T_pred, x_pred, T_true, x_true):
        p_rot = T_pred[0]
        p_trans = T_pred[1]
        t_rot = T_true[0]
        t_trans = T_true[1]
        
        all_dist = []
        #for all frames, align then true and predicted frames
        for i in range(p_rot.shape[1]):
            p_inv = torch.linalg.inv(p_rot[:,i,:,:])
            t_inv = torch.linalg.inv(t_rot[:,i,:,:])
            aligned_xp = torch.matmul(x_pred, p_rot[:,i,:,:])
            aligned_xt = torch.matmul(x_true, t_rot[:,i,:,:])
            #for all j, compute distances.
            for j in range(aligned_xp.shape[1]):
                dist = torch.linalg.norm(aligned_xp[:,j,:] - aligned_xt[:,j,:], dim = 1)
                eps = torch.ones(dist.shape)*(0.0001)
                dist = torch.sqrt(torch.square(dist) + eps)
                all_dist.append(dist)
        
        #concat all bx1 tensors and obtain a single mean loss value.
        all_dist = torch.concat(all_dist, dim = -1)
        f_loss = torch.mean(all_dist)/10
        
        return f_loss

In [126]:
class AngleTorsionLoss(nn.Module):
    '''
    Computes angular torsion loss given predicted angles and true angles psi and phi.
    '''
    def __init__(self):
        super(AngleTorsionLoss, self).__init__()
    
    def forward(self, p_psi, p_phi, t_psi, t_phi):
        norms = []
        dists = []
        #iterate through all angles and calculate angle norm and distance.
        for i in range(p_psi.shape[1]):
            #calculate angle norms for psi and phi
            l_psi = torch.norm(p_psi[:,i,:], dim = 1)
            l_phi = torch.norm(p_phi[:,i,:], dim = 1)
            
            lt_psi = torch.tile(l_psi.unsqueeze(dim = -1), (1, 2))
            lt_phi = torch.tile(l_phi.unsqueeze(dim = -1), (1, 2))
            
            #make unit vectors for psi and phi, WILL GET NANS WHEN li_psi has 0s. The model should not be predicting 0,0 for angles as it is not a valid angle.
            u_psi = torch.div(p_psi[:,i,:],lt_psi)
            u_phi = torch.div(p_phi[:,i,:],lt_phi)
            
            #calculate distance between prediction and true values
            dist_psi = torch.square(torch.linalg.norm(u_psi - t_psi[:,i,:]))
            dist_phi = torch.square(torch.linalg.norm(u_phi - t_phi[:,i,:]))
            
            #append angle norm and distance to their respective lists
            norms.append(abs(l_psi-1))
            norms.append(abs(l_phi-1))
            dists.append(dist_psi)
            dists.append(dist_phi)
        
        #concatenate all norms and all dists and find respective means.
        all_norms = torch.mean(torch.stack(norms, dim = 0))
        all_dists = torch.mean(torch.stack(dists, dim = 0))
        
        #calculate loss
        torsion_angular_loss = all_norms + all_dists
        
        return torsion_angular_loss

In [127]:
class StructureModel(nn.Module):
    '''
    Predicts 3D protein structure coordinate.
    '''
    def __init__(self, c_s, layers, device):
        super(StructureModel, self).__init__()
        self.layers = layers
        self.device = device
        
        #Define a layer in structure prediction.
        self.transition = nn.Sequential(nn.Linear(c_s, c_s), nn.ReLU(), nn.Linear(c_s, c_s), nn.ReLU(), nn.Linear(c_s, c_s))
        self.ipa = IPA(c_s)
        self.update_b = BackUpdate(c_s, device)
        self.torsion = PredictAngleTorsion(c_s, device)
        self.fape = FAPE(device)
        self.t_loss = AngleTorsionLoss()
        
    def forward(self, s_i, pair_rep):
        #Initialize a reference frame T
        identity = iden = torch.tile(torch.eye(3).unsqueeze(dim=0).unsqueeze(dim=0), (s_i.shape[0], s_i.shape[1], 1, 1))
        translation = torch.zeros(s_i.shape[0], s_i.shape[1], 3)
        t_i = (identity, translation)
        
        #store all calculated losses
        total_l_aux = []
        for layer in range(self.layers):
            s = self.ipa(s_i, pair_rep, t_i) + s_i
            s = self.transition(s) + s

            t_new = self.update_b(s)
            
            #multiply all respective rotation matricies, add all respective transitions
            new_trans = t_new[1] + t_i[1]
            new_rot = torch.matmul(t_new[0], t_i[0])
            
            #update t_i
            t_i = (new_rot, new_trans)
                
            #the predicted x is just the transition
            x_pred = t_i[1]
            
            #predict phi psi angles
            p_phi, p_psi = self.torsion(s, s_i)
            
            #calculate loss
            fape_loss = self.fape(t_i, x_pred, t_true, x_true)
            torsion_loss = self.t_loss(p_psi, p_phi, t_psi, t_phi)
            loss_aux = fape_loss + torsion_loss
            
            total_l_aux.append(loss_aux)
        mean_loss_aux = torch.mean(torch.stack(loss_aux, dim = 0))
        return mean_loss_aux

In [128]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

#Generate test tensors
t_s_i = torch.rand(4, 64, 384)
pair_rep = torch.rand(4, 64, 64, 16)

#create test model
struct = StructureModel(384, 1, device)
out = struct(t_s_i, pair_rep)

using device: cpu


NameError: name 't_true' is not defined