In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as du
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
import sidechainnet as scn
import random
import sklearn

In [None]:
# Load the SidechainNet dataset; the result is indexed by split name below
# (e.g. data['train']) and iterated to obtain batches.
# NOTE(review): this rebinds the name `data`, shadowing the
# `from torch.utils import data` import in the first cell.
data = scn.load(
    casp_version=7,
    with_pytorch="dataloaders",
    seq_as_onehot=True,
    aggregate_model_input=False,
    batch_size=16,
)

In [None]:
def get_seq_features(batch):
    '''
    Unpack a batch into model inputs and pairwise-distance targets.

    batch: object exposing the attributes `str_seqs`, `seqs`, `msks`,
        `evos`, `angs` and `crds`. `crds[i]` is indexed as a flat
        (n_res * 14, 3) table, i.e. 14 consecutive atom rows per residue.

    Returns a tuple (seqs, evos, angs, masks, dmats, dmat_masks):
        seqs       - `batch.seqs`, unchanged
        evos       - `batch.evos`, unchanged (PSSM / evolutionary info)
        angs       - first two angle columns of `batch.angs` (phi, psi)
        masks      - `batch.msks`, unchanged (non-zero means position is valid)
        dmats      - (batch, n_res, n_res) pairwise Euclidean distances
                     between one representative atom per residue
        dmat_masks - (batch, n_res, n_res) outer product of `masks` with
                     itself (0 means the pair i,j is invalid)
    '''
    str_seqs = batch.str_seqs  # seq in str format
    seqs = batch.seqs          # seq in one-hot format
    masks = batch.msks         # which positions are valid
    evos = batch.evos          # PSSM / evolutionary info
    angs = batch.angs[:, :, 0:2]  # torsion angles: phi, psi

    # Representative atom per residue: c-beta (row 4 of the residue's 14 atom
    # rows), except glycine and masked positions, which fall back to c-alpha
    # (row 1).
    coords = batch.crds  # seq coord info (all-atom)
    atoms_per_res, ca_idx, cb_idx = 14, 1, 4

    batch_xyz = []
    for i in range(coords.shape[0]):
        n_res = coords[i].shape[0] // atoms_per_res
        xyz = []
        for res in range(n_res):
            # The mask is tested first so that str_seqs[i] is never indexed
            # at a masked (e.g. padded) position.
            use_cb = masks[i][res] and str_seqs[i][res] != 'G'
            atom_row = res * atoms_per_res + (cb_idx if use_cb else ca_idx)
            xyz.append(coords[i][atom_row, :])
        batch_xyz.append(torch.stack(xyz))
    batch_xyz = torch.stack(batch_xyz)

    # now create pairwise distance matrix
    dmats = torch.cdist(batch_xyz, batch_xyz)
    # create matrix mask (0 means i,j invalid)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)

    return seqs, evos, angs, masks, dmats, dmat_masks

In [None]:
# Shape sanity check: pull the first batch of every epoch, move its features
# to `device` and print their shapes.
# NOTE(review): `epochs` and `device` are not defined in the cells shown
# here; they must exist in the kernel before this cell runs.
for epoch in range(1,epochs+1):
    for batch in data['train']:
        features = get_seq_features(batch)
        seqs, evos, angs, masks, dmats, dmat_masks = (t.to(device) for t in features)

        print(seqs.shape, evos.shape, angs.shape)
        # was `dmat_mask` (undefined name) -> NameError
        print(masks.shape, dmats.shape, dmat_masks.shape)

        # only the first batch of each epoch is inspected
        break

In [3]:
class AttentionHead(nn.Module):
    def __init__(self, in_dim = 256, d_k = 32):
        '''
        Represents a single attention head for multihead attention.

        in_dim: size of the last dimension of the input (256 by default).
        d_k: size of the query/key/value projections (32 by default).
        '''
        super(AttentionHead, self).__init__()

        self.d_k = d_k
        # query, key and value projections
        self.q = nn.Linear(in_dim, d_k)
        self.k = nn.Linear(in_dim, d_k)
        self.v = nn.Linear(in_dim, d_k)

    def forward(self, sequence, bias, row_or_col):
        '''
        Scaled dot-product self-attention over the second-to-last axis.

        sequence: tensor of shape (..., n, in_dim); a single (n, in_dim)
            sequence or any number of leading batch dimensions.
        bias: tensor broadcastable to (..., n, n), added to the attention
            scores when row_or_col == "row". Ignored otherwise; may be None.
        row_or_col: "row" to include the bias, any other value to skip it.

        Returns a tensor of shape (..., n, d_k).
        '''
        query = self.q(sequence)
        key = self.k(sequence)
        value = self.v(sequence)

        # d_k is a Python int, so scale with ** 0.5 (torch.sqrt only accepts
        # tensors). Transposing the last two axes works for 2-D and batched
        # inputs alike.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_k ** 0.5)

        if row_or_col == "row" and bias is not None:
            # out-of-place so the bias may broadcast freely
            scores = scores + bias

        # normalise over the key axis so each query's weights sum to 1
        weights = F.softmax(scores, dim=-1)

        return torch.matmul(weights, value)

In [None]:
class Evoformer(nn.Module):
    # Feature sizes shared by the layers below.
    IN_DIM = 21        # per-residue input features (evos / target)
    MSA_DIM = 256      # channels of the MSA representation
    PAIR_DIM = 128     # channels of the pair representation
    HEAD_DIM = 32      # output channels of each attention head
    OPM_DIM = 32       # projection size used by the outer product mean
    MAX_REL_POS = 32   # relative positions are clipped to [-32, 32]

    def __init__(self, n_clust, num_heads = 8):
        '''
        Creates the MSA_representation and the Z(pairwise) matrix given a PSSM and a sequence.
        n_clust: number of PSSMs.
        num_heads: number of attention heads(8 by default)
        '''
        super(Evoformer, self).__init__()

        self.n_clust = n_clust
        self.num_heads = num_heads

        #linear layer to project evos into n_clust x n_res x 256
        self.fc0 = nn.Linear(self.IN_DIM, self.MSA_DIM)
        #linear layer to project targets to n_res x 256
        self.fc1 = nn.Linear(self.IN_DIM, self.MSA_DIM)
        #linear layer to project target to n_res x 128
        self.fc2 = nn.Linear(self.IN_DIM, self.PAIR_DIM)
        #linear layer to project the one-hot relative position
        #(2 * MAX_REL_POS + 1 classes) into 128 space
        self.fc3 = nn.Linear(2 * self.MAX_REL_POS + 1, self.PAIR_DIM)
        #linear layer to project pair_rep to bias
        self.fc4 = nn.Linear(self.PAIR_DIM, 1)
        #linear layer to project the concatenated head outputs into 256 dim
        self.fc5 = nn.Linear(num_heads * self.HEAD_DIM, self.MSA_DIM)
        #linear layer to project i[s] and j[s] to 32 dim
        self.fc6 = nn.Linear(self.MSA_DIM, self.OPM_DIM)
        #flatten the mean outer product (trailing C x C block) to C*C
        self.flatten = nn.Flatten(start_dim=2)
        #linear layer to project the flattened outer product mean to 128 dim
        self.fc7 = nn.Linear(self.OPM_DIM * self.OPM_DIM, self.PAIR_DIM)

        #define attention heads
        self.mha = nn.ModuleList([AttentionHead(in_dim=self.MSA_DIM, d_k=self.HEAD_DIM)
                                  for _ in range(num_heads)])

        #create a gate for each head, corresponding to each index.
        #a gate maps msa_rep to 1 and sigmoids it to determine how much information is kept from a head.
        self.gates = nn.ModuleList([nn.Sequential(nn.Linear(self.MSA_DIM, 1), nn.Sigmoid())
                                    for _ in range(num_heads)])

        #define the transitional layers to pass the new msa_rep through
        self.transition = nn.Sequential(nn.Linear(self.MSA_DIM, 1024),
                                        nn.ReLU(),
                                        nn.Linear(1024, self.MSA_DIM))

    def create_msa_rep(self, evos, target):
        '''
        Create the msa_representation given evolutionary data evos
        and the targets, both are n_res x 21.
        Returns a n_clust x n_res x 256 tensor.
        '''
        #project the PSSM once and repeat it along a new leading cluster axis.
        #every cluster row starts out identical, since all of them come from
        #the same evos through the same projection.
        msa_rep = self.fc0(evos).unsqueeze(0).expand(self.n_clust, -1, -1)

        #project the targets from n_res x 21 to n_res x 256
        new_target = self.fc1(target)

        #add the targets to every cluster row (broadcast over the cluster axis)
        return msa_rep + new_target.unsqueeze(0)

    def create_pair_rep(self, target):
        '''
        Create pair_wise representations given targets (n_res x 21).
        Returns a n_res x n_res x 128 tensor.
        '''
        #create the pairwise rep matrix as a per-channel outer product:
        #pair_rep[i, j, c] = a_i[i, c] * b_j[j, c]
        #NOTE(review): both sides use fc2, so a_i and b_j are equal - confirm
        #whether a second projection was intended.
        a_i = self.fc2(target)
        b_j = self.fc2(target)
        pair_rep = torch.einsum('ic,jc->ijc', a_i, b_j)

        #add the relative position rel_pos: (i - j) clipped to
        #[-MAX_REL_POS, MAX_REL_POS], shifted to start at 0 and one-hot encoded
        idx = torch.arange(target.shape[0], device=target.device)
        dist_ij = idx.unsqueeze(1) - idx.unsqueeze(0)
        dist_ij = torch.clamp(dist_ij, -self.MAX_REL_POS, self.MAX_REL_POS) + self.MAX_REL_POS
        one_hot = F.one_hot(dist_ij, num_classes=2 * self.MAX_REL_POS + 1)
        #one_hot returns integers; the linear layer needs the input's float dtype
        rel_pos = self.fc3(one_hot.to(a_i.dtype))

        return pair_rep + rel_pos

    def create_bias(self, pair_rep):
        '''
        given the pairwise representation create the bias
        (the last dimension is projected from 128 to 1)
        '''
        bias = self.fc4(pair_rep)
        return bias

    def compute_attention(self, row_or_col, msa_rep, bias):
        '''
        compute either row-wise or column-wise attention depending on the given argument.

        row_or_col: "row" attends along the residue axis of each cluster row
            and passes the bias to the heads; "col" attends along the cluster
            axis for each residue.
        msa_rep: n_clust x n_res x 256 tensor.
        bias: pair bias, either n_res x n_res or n_res x n_res x 1
            (as returned by create_bias).

        Returns a tensor with the same shape as msa_rep.
        '''
        t_msa_rep = msa_rep
        if row_or_col == "col":
            #transpose the msa_rep if we are doing column wise attention
            t_msa_rep = torch.transpose(t_msa_rep, 0, 1)

        if row_or_col == "row" and bias is not None and bias.dim() == 3 and bias.shape[-1] == 1:
            #drop the trailing channel so the bias broadcasts against the
            #(n_clust, n_res, n_res) attention scores
            bias = bias.squeeze(-1)

        #each head output is scaled elementwise by its sigmoid gate
        #(the gate has a trailing dimension of 1 and broadcasts over channels)
        outputs = [gate(t_msa_rep) * head(t_msa_rep, bias, row_or_col)
                   for head, gate in zip(self.mha, self.gates)]

        #concatenate them along the channel axis to form O_sh
        O_sh = torch.cat(outputs, dim = -1)
        new_msa_rep = self.fc5(O_sh)

        if row_or_col == "col":
            #undo the transpose so the result lines up with msa_rep
            new_msa_rep = torch.transpose(new_msa_rep, 0, 1)

        return new_msa_rep

    def outer_product_mean(self, msa_rep):
        '''
        Finds the outer product mean of the msa representation
        (n_clust x n_res x 256).
        The output is a n_res x n_res x 128 pair_rep
        '''
        #project every (cluster, residue) slice into 32 dim
        proj = self.fc6(msa_rep)

        #o[i, j] = mean over clusters s of outer(proj[s, i], proj[s, j])
        outer_mean = torch.einsum('sic,sjd->ijcd', proj, proj) / msa_rep.shape[0]

        #flatten each C x C block to C*C and project to n_res x n_res x 128 dim
        new_pair_rep = self.fc7(self.flatten(outer_mean))

        #make sure to do residual connection after calling the function
        return new_pair_rep

    def mult_attention(self, out = False):
        '''
        Does incoming(default) multiplicative attention on a given pair_rep.
        out: set to true to do outgoing attention
        TODO: not implemented yet; currently returns None.
        '''


    def tri_attention(self, ending = False):
        '''
        Does starting triangular attention by default.
        ending: set to true to do ending triangular attention
        TODO: not implemented yet; currently returns None.
        '''

    def single_rep(self):
        '''
        Find the singular representation of M
        Should only be done on the last block.
        TODO: not implemented yet; currently returns None.
        '''


    def forward(self, seqs, evos, dmats):
        '''
        TODO: not implemented yet; currently returns None.
        '''
        return

In [14]:
# TODO: Eventually divide Evoformer into separate models so that the code becomes clearer