In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as du
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
import sidechainnet as scn
import random
import sklearn

In [2]:
# Load SidechainNet (CASP7 release) as PyTorch dataloaders, one protein per batch.
# NOTE: this rebinds the name `data`, which the import cell also binds via
# `from torch.utils import data`; from here on `data` is the object returned by
# scn.load (the training cell indexes it as data['train']).
data = scn.load(
    casp_version=7,
    with_pytorch="dataloaders",
    seq_as_onehot=True,
    aggregate_model_input=False,
    batch_size=1,
)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


In [3]:
def get_seq_features(batch):
    '''
    Unpack one batch into model inputs and pairwise-distance targets.

    Parameters
    ----------
    batch : object exposing the attributes ``str_seqs``, ``seqs``, ``msks``,
        ``evos``, ``angs`` and ``crds``.
        ``crds`` is indexed as (batch, atom, xyz) with 14 consecutive atom
        rows per residue; ``msks`` is indexed as (batch, residue).

    Returns
    -------
    seqs : ``batch.seqs`` unchanged (one-hot sequence).
    evos : ``batch.evos`` unchanged (PSSM / evolutionary info).
    angs : the first two angle channels of ``batch.angs`` (phi, psi).
        Omega is NOT returned.
    masks : ``batch.msks`` unchanged (which positions are valid).
    dmats : (batch, residue, residue) Euclidean distances between one
        representative atom per residue.
    dmat_masks : outer product of ``masks`` with itself (0 means the pair
        i, j is invalid).
    '''
    # Atom layout inside each residue's block of 14 rows:
    # row 1 is c-alpha, row 4 is c-beta.
    atoms_per_res, ca_idx, cb_idx = 14, 1, 4

    str_seqs = batch.str_seqs  # seq in str format
    seqs = batch.seqs          # seq in one-hot format
    masks = batch.msks         # which positions are valid
    evos = batch.evos          # PSSM / evolutionary info
    angs = batch.angs[:, :, 0:2]  # torsion angles: phi, psi
    coords = batch.crds        # seq coord info (all-atom)

    # One strided slice per atom type replaces the per-residue Python loop.
    ca_xyz = coords[:, ca_idx::atoms_per_res, :]
    cb_xyz = coords[:, cb_idx::atoms_per_res, :]
    n_res = ca_xyz.shape[1]

    # Use c-beta only for valid, non-glycine residues; everything else
    # (glycine, masked-out or padded positions) falls back to c-alpha.
    # Positions past the end of a sequence string are treated like padding.
    not_gly = torch.tensor(
        [[k < len(s) and s[k] != 'G' for k in range(n_res)] for s in str_seqs],
        dtype=torch.bool,
        device=coords.device,
    ).reshape(coords.shape[0], n_res)
    use_cb = masks.to(coords.device).bool() & not_gly
    batch_xyz = torch.where(use_cb.unsqueeze(-1), cb_xyz, ca_xyz)

    # now create pairwise distance matrix
    dmats = torch.cdist(batch_xyz, batch_xyz)
    # create matrix mask (0 means i,j invalid)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)

    return seqs, evos, angs, masks, dmats, dmat_masks

In [4]:
class AttentionHead(nn.Module):
    '''
    One scaled dot-product self-attention head over a 3-D token tensor,
    with an optional additive bias derived from the pair representation.

    Parameters
    ----------
    num_heads : stored on the module; not used by the head itself.
    c : per-head channel width of query / key / value.
    c_z : channel width of the pair representation fed to the bias layer.
    c_m : channel width of the incoming tokens.
    '''
    def __init__(self, num_heads, c, c_z, c_m):
        super(AttentionHead, self).__init__()

        self.num_heads = num_heads
        self.c = c
        self.c_z = c_z

        # query / key / value projections (no bias terms)
        self.q = nn.Linear(c_m, self.c, bias=False)
        self.k = nn.Linear(c_m, self.c, bias=False)
        self.v = nn.Linear(c_m, self.c, bias=False)

        # projects each pair feature vector (c_z channels) to one scalar logit bias
        self.bias = nn.Linear(self.c_z, 1, bias=False)

    def forward(self, msa_rep, pair_rep, row):
        '''
        Parameters
        ----------
        msa_rep : (batch, tokens, c_m) tensor.
        pair_rep : (batch, tokens, tokens, c_z) tensor; only read when ``row``
            is truthy.
        row : when truthy, add the pair-derived bias to the attention logits.

        Returns
        -------
        (batch, tokens, c) tensor of attended values.
        '''
        query = self.q(msa_rep)
        key = self.k(msa_rep)
        value = self.v(msa_rep)

        # logits[b, i, j] = <query_i, key_j> / sqrt(c)
        logits = torch.matmul(query, key.transpose(1, 2)) / self.c ** 0.5
        if row:
            # squeeze only the trailing singleton channel: a bare squeeze()
            # would also drop a batch or token axis of size 1
            logits = logits + self.bias(pair_rep).squeeze(-1)

        # Normalise over keys (last axis) so each query's weights sum to 1
        # before they are used to average the values.
        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, value)

In [5]:
class rowColAtt(nn.Module):
    '''
    Gated multi-head self-attention over the MSA representation.

    With ``row=True`` each cluster row attends across residues and the
    pair representation contributes a logit bias. With ``row=False`` each
    residue column attends across clusters, without pair bias.

    Parameters
    ----------
    c : per-head channel width.
    c_m : MSA channel width (input and output).
    c_z : pair channel width (used by the heads' bias projection).
    num_heads : number of attention heads.
    row : True for row-wise attention, False for column-wise.
    device : stored for reference; tensors are created on the input's device.
    '''
    def __init__(self, c, c_m, c_z, num_heads, row, device):
        super(rowColAtt, self).__init__()

        self.row = row
        self.num_heads = num_heads
        self.device = device

        # attention heads
        self.mhsa = nn.ModuleList([AttentionHead(num_heads, c, c_z, c_m) for _ in range(self.num_heads)])

        # project the concatenated heads (num_heads * c channels) back to c_m.
        # This was sized by c_z, which only matched when num_heads * c == c_z.
        self.fc1 = nn.Linear(self.num_heads * c, c_m)

        # one scalar sigmoid gate per token and head
        self.gate = nn.ModuleList([nn.Sequential(nn.Linear(c_m, 1), nn.Sigmoid()) for _ in range(self.num_heads)])

    def forward(self, msa_rep, pair_rep):
        '''
        Parameters
        ----------
        msa_rep : (batch, n_clust, n_res, c_m) tensor.
        pair_rep : (batch, n_res, n_res, c_z) tensor, forwarded to the heads.

        Returns
        -------
        Tensor with the same shape as ``msa_rep``.
        '''
        # Column attention is row attention on the transposed MSA: swap the
        # cluster and residue axes so the heads always see (batch, tokens, c_m).
        x = msa_rep if self.row else msa_rep.transpose(1, 2)

        slices = []
        for s in range(x.shape[1]):
            tokens = x[:, s]                                   # (b, tokens, c_m)
            head_outs = []
            for head, gate in zip(self.mhsa, self.gate):
                attended = head(tokens, pair_rep, self.row)    # (b, tokens, c)
                # The gate is computed from the same slice, giving
                # (b, tokens, 1), which broadcasts over the c channels.
                head_outs.append(gate(tokens) * attended)
            slices.append(self.fc1(torch.cat(head_outs, dim=-1)))

        out = torch.stack(slices, dim=1)
        return out if self.row else out.transpose(1, 2)

In [6]:
class Outer_Prod_Mean(nn.Module):
    '''
    Outer-product-mean update from the MSA representation to the pair
    representation.

    For every residue pair (i, j) the MSA columns i and j are projected to
    ``c`` channels, their outer product is averaged over the cluster axis,
    flattened to ``c * c`` and projected to ``c_z``.

    Parameters
    ----------
    c : width of the two intermediate projections.
    c_m : MSA channel width.
    c_z : pair channel width (output).
    device : stored for reference; the output lives on the input's device.
    '''
    def __init__(self, c, c_m, c_z, device):
        super(Outer_Prod_Mean, self).__init__()
        self.device = device

        # linear layers to project the two MSA columns to c dim
        self.fc1 = nn.Linear(c_m, c)
        self.fc2 = nn.Linear(c_m, c)

        # flatten the trailing (c, c) outer product to c*c
        self.flatten = nn.Flatten(start_dim=-2)

        # project the flattened outer product (c*c) to c_z; this was sized
        # by c_m, which only matched when c * c == c_m
        self.fc3 = nn.Linear(c * c, c_z)

    def forward(self, msa_rep, pair_rep):
        '''
        Parameters
        ----------
        msa_rep : (batch, n_clust, n_res, c_m) tensor.
        pair_rep : unused; kept so existing call sites keep working.

        Returns
        -------
        (batch, n_res, n_res, c_z) tensor.
        '''
        a = self.fc1(msa_rep)   # (b, s, n_res, c)
        b = self.fc2(msa_rep)   # (b, s, n_res, c)

        # mean over clusters s of a[s, i] (outer) b[s, j] for every pair (i, j).
        # Every entry is computed, so no uninitialised buffer is involved.
        outer_mean = torch.einsum('bsic,bsjd->bijcd', a, b) / msa_rep.shape[1]

        return self.fc3(self.flatten(outer_mean))

In [7]:
class Mult_Attention(nn.Module):
    def __init__(self, c_z, c_m, out, device):
        '''
        Triangular multiplicative update on a pair representation.

        c_z: pair channel width (input, intermediate and output).
        c_m: accepted for call-site compatibility; not used, because the
             update keeps the pair channel width throughout.
        out: True for the "outgoing" update (sum_k a[i,k] * b[j,k]),
             False for the "incoming" update (sum_k a[k,i] * b[k,j]).
        device: stored for reference; tensors live on the input's device.
        '''
        super(Mult_Attention, self).__init__()
        self.device = device
        self.out = out
        self.ln = nn.LayerNorm(c_z)
        self.fc1 = nn.Linear(c_z, c_z)
        self.fc2 = nn.Linear(c_z, c_z)
        self.fc3 = nn.Linear(c_z, c_z)

        self.gate1 = nn.Sequential(nn.Linear(c_z, c_z), nn.Sigmoid())
        self.gate2 = nn.Sequential(nn.Linear(c_z, c_z), nn.Sigmoid())
        # The output gate must match the width of the summed update (c_z);
        # it was sized by c_m, which only worked when n_res == c_m.
        self.gate3 = nn.Sequential(nn.Linear(c_z, c_z), nn.Sigmoid())

    def forward(self, pair_rep):
        '''
        pair_rep: (batch, n_res, n_res, c_z) tensor.
        Returns a tensor of the same shape.
        '''
        # Do a layer norm on pair_rep
        pair_rep = self.ln(pair_rep)

        # gated projections A and B, plus the output gate
        gated_a = self.fc1(pair_rep) * self.gate1(pair_rep)
        gated_b = self.fc2(pair_rep) * self.gate2(pair_rep)
        gate_z = self.gate3(pair_rep)

        # the incoming update is the outgoing one on the transposed edges
        if not self.out:
            gated_a = gated_a.transpose(1, 2)
            gated_b = gated_b.transpose(1, 2)

        # update[b, i, j, c] = sum_k a[b, i, k, c] * b[b, j, k, c]
        # The sum runs over the shared third residue k and keeps the channel
        # axis; this replaces the O(n_res^2) Python loop.
        update = torch.einsum('bikc,bjkc->bijc', gated_a, gated_b)

        return self.fc3(gate_z * update)

In [8]:
class Tri_Attention(nn.Module):
    '''
    Gated multi-head triangular self-attention on a pair representation.

    Starting-node form (``ending=False``): for each edge (i, j) attend over
    k with logits  q[i,j] . k[i,k] / sqrt(c) + b[j,k]  and average v[i,k].
    Ending-node form (``ending=True``): logits are
    q[i,j] . k[k,j] / sqrt(c) + b[k,i]  and the averaged values are v[k,j];
    it is computed as the starting-node form on the transposed input.

    Parameters
    ----------
    c : per-head channel width.
    c_z : pair channel width (input and output).
    ending : set to True to do ending triangular attention.
    num_heads : number of attention heads.
    '''
    def __init__(self, c, c_z, ending = False, num_heads = 4):
        super(Tri_Attention, self).__init__()
        self.ending = ending
        self.num_heads = num_heads
        self.c = c

        self.q = nn.ModuleList([nn.Linear(c_z, c) for _ in range(num_heads)])
        self.k = nn.ModuleList([nn.Linear(c_z, c) for _ in range(num_heads)])
        self.v = nn.ModuleList([nn.Linear(c_z, c) for _ in range(num_heads)])
        self.b = nn.ModuleList([nn.Linear(c_z, 1) for _ in range(num_heads)])
        self.g = nn.ModuleList([nn.Sequential(nn.Linear(c_z, c), nn.Sigmoid()) for _ in range(num_heads)])

        # concatenated heads -> c_z (this was hard-coded to 64 == 4 * 16)
        self.fc1 = nn.Linear(num_heads * c, c_z)

    def forward(self, pair_rep):
        '''
        pair_rep: (batch, n_res, n_res, c_z) tensor.
        Returns a tensor of the same shape.
        '''
        # Ending-node attention equals starting-node attention on the
        # transposed edge matrix, transposed back at the end.
        z = pair_rep.transpose(1, 2) if self.ending else pair_rep

        head_outputs = []
        for h in range(self.num_heads):
            query = self.q[h](z)                           # (b, i, j, c)
            key = self.k[h](z)                             # (b, i, k, c)
            value = self.v[h](z)                           # (b, i, k, c)
            gate = self.g[h](z)                            # (b, i, j, c)
            # bias[b, j, k], shared across every row i
            bias = self.b[h](z).squeeze(-1).unsqueeze(1)   # (b, 1, j, k)

            # logits[b, i, j, k] = <q[i,j], k[i,k]> / sqrt(c) + bias[j,k]
            logits = torch.matmul(query, key.transpose(-1, -2)) / self.c ** 0.5 + bias
            weights = F.softmax(logits, dim=-1)            # normalise over k
            head_outputs.append(gate * torch.matmul(weights, value))

        # concat all heads and project back to c_z
        output = self.fc1(torch.cat(head_outputs, dim=-1))
        return output.transpose(1, 2) if self.ending else output

In [9]:
class Evoformer(nn.Module):
    # Relative-position encoding: offsets are bucketed into N_REL_BINS bins
    # whose boundaries span [-MAX_REL_DIST, MAX_REL_DIST].
    N_REL_BINS = 64
    MAX_REL_DIST = 32

    def __init__(self, c, c_m, c_z, n_clust, num_heads, device):
        '''
        Embeds a one-hot sequence and a PSSM into an MSA representation and a
        pair representation, then applies one round of attention updates.

        c: per-head channel width used by the attention / outer-product modules.
        c_m: MSA channel width.
        c_z: pair channel width.
        n_clust: number of PSSM projections stacked to form the MSA rows.
        num_heads: number of heads for the row / column MSA attention.
        device: forwarded to the sub-modules.
        '''
        super(Evoformer, self).__init__()

        self.n_clust = n_clust
        self.num_heads = num_heads
        self.device = device

        # one projection per MSA row: evos (21 features) -> c_m
        self.fc0 = nn.ModuleList([nn.Linear(21, c_m) for _ in range(n_clust)])
        # seqs (20 features) -> c_m
        self.fc1 = nn.Linear(20, c_m)
        # seqs (20 features) -> c_z, "left" and "right" halves of the pair embedding
        self.fc2 = nn.Linear(20, c_z)
        self.fc3 = nn.Linear(20, c_z)
        # one-hot relative position -> c_z
        self.fc4 = nn.Linear(self.N_REL_BINS, c_z)
        # pair_rep -> scalar bias (only used by create_bias)
        self.fc5 = nn.Linear(c_z, 1)
        # single-representation projections: c_m -> c_m -> 384
        self.fc6 = nn.Linear(c_m, c_m)
        self.fc7 = nn.Linear(c_m, 384)

        # transition (position-wise feed-forward) layers
        self.transition1 = nn.Sequential(nn.Linear(c_m, 4*c_m), nn.ReLU(), nn.Linear(4*c_m, c_m))
        self.transition2 = nn.Sequential(nn.Linear(c_z, 4*c_z), nn.ReLU(), nn.Linear(4*c_z, c_z))

        # define all attentions
        self.row_att = rowColAtt(c, c_m, c_z, self.num_heads, True, self.device)
        self.col_att = rowColAtt(c, c_m, c_z, self.num_heads, False, self.device)
        self.mul_att_in = Mult_Attention(c_z, c_m, False, device)
        self.mul_att_out = Mult_Attention(c_z, c_m, True, device)
        # the triangular attentions use Tri_Attention's default head count
        self.tri_att_start = Tri_Attention(c, c_z, ending = False)
        self.tri_att_end = Tri_Attention(c, c_z, ending = True)

        # define outer_product_mean
        self.out_prod_mean = Outer_Prod_Mean(c, c_m, c_z, self.device)

    def create_msa_rep(self, evos, seqs):
        '''
        Build the MSA representation.

        evos: (batch, n_res, 21) evolutionary features.
        seqs: (batch, n_res, 20) one-hot sequence.
        Returns a (batch, n_clust, n_res, c_m) tensor.
        '''
        # n_clust independent projections of the PSSM, stacked as MSA rows
        msa_rep = torch.stack([proj(evos) for proj in self.fc0], dim=1)

        # add the embedded sequence to every row (broadcast over n_clust)
        return msa_rep + self.fc1(seqs).unsqueeze(dim=1)

    def create_pair_rep(self, seqs):
        '''
        Build the pair representation from the sequence plus a relative
        position encoding.

        seqs: (batch, n_res, 20) one-hot sequence.
        Returns a (batch, n_res, n_res, c_z) tensor.
        '''
        # pair_rep[b, i, j] = fc2(seq)[b, i] + fc3(seq)[b, j], via broadcasting
        left = self.fc2(seqs).unsqueeze(dim=2)     # (b, n_res, 1, c_z)
        right = self.fc3(seqs).unsqueeze(dim=1)    # (b, 1, n_res, c_z)
        pair_rep = left + right

        # signed offset[r, c] = c - r, built directly on the input's device
        positions = torch.arange(seqs.shape[1], device=seqs.device)
        offsets = (positions.unsqueeze(0) - positions.unsqueeze(1)).float()
        bins = torch.linspace(-self.MAX_REL_DIST, self.MAX_REL_DIST,
                              self.N_REL_BINS, device=seqs.device)
        # offsets beyond the last boundary land in the last bin
        bin_idx = torch.bucketize(offsets, bins).clamp(max=self.N_REL_BINS - 1)

        # num_classes must be explicit: inferring it from the data yields fewer
        # than N_REL_BINS columns for short sequences and breaks fc4.
        rel_onehot = F.one_hot(bin_idx, num_classes=self.N_REL_BINS).type(torch.float)
        rel_pos = self.fc4(rel_onehot)             # (n_res, n_res, c_z)

        # broadcast over the batch axis
        return pair_rep + rel_pos

    def create_bias(self, pair_rep):
        '''
        given the pairwise representation create the bias
        '''
        return self.fc5(pair_rep)

    def single_rep(self, msa_rep):
        '''
        Project the first MSA row to the single representation
        (batch, n_res, 384). Should only be done on the last block.
        '''
        # Row 0, not row 1, so this also works when n_clust == 1.
        first_row = msa_rep[:, 0, :, :]
        return self.fc7(self.fc6(first_row))

    def forward(self, seqs, evos):
        '''
        seqs: (batch, n_res, 20) one-hot sequence.
        evos: (batch, n_res, 21) evolutionary features.
        Returns (msa_rep, pair_rep, single_rep).
        '''
        # create msa_rep and pair_rep
        msa_rep = self.create_msa_rep(evos, seqs)
        pair_rep = self.create_pair_rep(seqs)

        # MSA stack: row attention -> column attention -> transition (all residual)
        msa_rep = msa_rep + self.row_att(msa_rep, pair_rep)
        msa_rep = msa_rep + self.col_att(msa_rep, pair_rep)
        msa_rep = msa_rep + self.transition1(msa_rep)

        # communicate MSA -> pair via the outer product mean
        pair_rep = pair_rep + self.out_prod_mean(msa_rep, pair_rep)

        # pair stack: multiplicative updates, then triangular attention
        pair_rep = pair_rep + self.mul_att_out(pair_rep)
        pair_rep = pair_rep + self.mul_att_in(pair_rep)
        pair_rep = pair_rep + self.tri_att_start(pair_rep)
        pair_rep = pair_rep + self.tri_att_end(pair_rep)

        # pair transition
        pair_rep = pair_rep + self.transition2(pair_rep)

        single_rep = self.single_rep(msa_rep)
        return msa_rep, pair_rep, single_rep

In [10]:
class EvoformerBlock(nn.Module):
    '''
    A stack of Evoformer modules followed by two 1x1-convolution heads:
    a distance-bin head over residue pairs and an angle-bin head over
    residues.

    Parameters
    ----------
    c, c_m, c_z, n_clust, num_heads, device : forwarded to every Evoformer.
    num_blocks : number of Evoformer modules to create.
    '''
    def __init__(self, c, c_m, c_z, num_blocks, n_clust, num_heads, device):
        super(EvoformerBlock, self).__init__()

        self.num_blocks = num_blocks
        self.evo_blocks = nn.ModuleList([Evoformer(c, c_m, c_z, n_clust, num_heads, device) for _ in range(num_blocks)])

        # distance head: c_z pair channels -> 256 logits per residue pair
        self.conv1 = nn.Conv2d(c_z, 256, 1)

        # Angle head: max over the whole last spatial axis, whatever its
        # length, then c_z -> 1296 logits per residue. The fixed (1, c_m)
        # window gave the same result only when the crop length equalled c_m.
        self.maxpool = nn.AdaptiveMaxPool2d((None, 1))
        self.conv2 = nn.Conv2d(c_z, 1296, 1)

    def forward(self, seqs, evos):
        '''
        Returns (pred_dmat, pred_angles, single_rep) where pred_dmat is
        (batch, 256, n_res, n_res) and pred_angles is (batch, 1296, n_res, 1).
        '''
        output = self.evo_blocks[0](seqs, evos)

        # single rep is calculated but ignored until the last output
        for i in range(1, self.num_blocks):
            # Use block i (this used to re-apply block 0 every iteration).
            # NOTE(review): Evoformer.forward embeds raw (seqs, evos); feeding
            # it (msa_rep, pair_rep) as done here will not match its input
            # layers, so num_blocks > 1 needs Evoformer to accept
            # representations directly -- confirm the intended design.
            output = self.evo_blocks[i](output[0], output[1])

        # move the channel axis to position 1 for the 2-D convolutions
        pair_channels_first = torch.transpose(output[1], 1, 3)

        # distance-bin logits for every residue pair
        pred_dmat = self.conv1(pair_channels_first)

        # angle-bin logits: max-pool over one residue axis, then 1x1 conv
        pred_angles = self.conv2(self.maxpool(pair_channels_first))

        return pred_dmat, pred_angles, output[2]

In [11]:
# Seed the RNGs used below (parameter init, random crop offsets) so a
# Restart & Run All reproduces the same run.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

# training schedule
epochs = 1
# NOTE(review): 0.1 is far above Adam's default of 1e-3 -- confirm this is intended.
learning_rate = 0.1

# model hyper-parameters
num_heads = 8     # heads for the MSA row / column attention
n_clust = 16      # number of MSA rows built from the PSSM
num_blocks = 1    # number of stacked Evoformer modules
c = 16            # per-head channel width
c_m = 256         # MSA channel width
c_z = 128         # pair channel width
c_s = 256         # crop length used by the training loop

# Move the model to its device BEFORE building the optimizer, so the
# optimizer is constructed over the parameters in their final location.
model = EvoformerBlock(c, c_m, c_z, num_blocks, n_clust, num_heads, device).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train()

using device: cuda:0


EvoformerBlock(
  (evo_blocks): ModuleList(
    (0): Evoformer(
      (fc0): ModuleList(
        (0): Linear(in_features=21, out_features=256, bias=True)
        (1): Linear(in_features=21, out_features=256, bias=True)
        (2): Linear(in_features=21, out_features=256, bias=True)
        (3): Linear(in_features=21, out_features=256, bias=True)
        (4): Linear(in_features=21, out_features=256, bias=True)
        (5): Linear(in_features=21, out_features=256, bias=True)
        (6): Linear(in_features=21, out_features=256, bias=True)
        (7): Linear(in_features=21, out_features=256, bias=True)
        (8): Linear(in_features=21, out_features=256, bias=True)
        (9): Linear(in_features=21, out_features=256, bias=True)
        (10): Linear(in_features=21, out_features=256, bias=True)
        (11): Linear(in_features=21, out_features=256, bias=True)
        (12): Linear(in_features=21, out_features=256, bias=True)
        (13): Linear(in_features=21, out_features=256, bias=Tru

In [12]:
# Per-element loss so that invalid / padded residue pairs can be masked out.
loss_func = nn.CrossEntropyLoss(reduction = 'none')
crop_stride = 128   # distance between the starts of consecutive crops

for epoch in range(1, epochs+1):
    # Reset the running statistics every epoch; accumulating across epochs
    # (and dividing in place) corrupts every average after the first.
    training_loss = 0.
    num_crops = 0

    for bidx, batch in enumerate(tqdm(data['train'])):
        seqs, evos, angs, masks, dmats, dmat_masks = (
            t.to(device) for t in get_seq_features(batch)
        )

        # generate a random starting index for the first crop
        start_idx = random.randint(1, 16)

        # zero-pad along the residue axis so every crop has exactly c_s positions
        seq_len = seqs.shape[1]
        pad = seq_len + c_s
        seqs = F.pad(seqs, (0, 0, 0, pad), 'constant', 0)
        evos = F.pad(evos, (0, 0, 0, pad), 'constant', 0)

        # discretize the distance matrix into bins over [2, 22]
        bins = torch.linspace(2, 22, 64).to(device)
        discretized = torch.bucketize(torch.clamp(dmats, min = 2, max = 22), bins, right = True)
        discretized = F.pad(discretized, (0, pad, 0, pad, 0, 0), 'constant', 0)
        # pad the pair mask the same way: padded pairs get weight 0 in the loss
        loss_masks = F.pad(dmat_masks.float(), (0, pad, 0, pad, 0, 0), 'constant', 0)

        # discretize the angles
        # NOTE(review): the angle loss below is commented out, so d_angs is
        # currently unused; the clamp range / bin layout should be re-checked
        # against the units of `angs` before enabling it.
        bins = torch.linspace(0, 36, 1296).to(device)
        d_angs = torch.clamp(angs, min = 0, max = 36)
        d_angs = torch.bucketize(d_angs, bins, right = True)
        d_angs = 36 * d_angs[:,:,0] + d_angs[:,:,1]
        d_angs = F.pad(d_angs, (0, d_angs.shape[1] + c_s, 0, 0), 'constant', 0)

        for i in range(start_idx, seq_len, crop_stride):
            seq_crop = seqs[:, i:i+c_s, :]
            evo_crop = evos[:, i:i+c_s, :]
            ddmat = discretized[:, i:i+c_s, i:i+c_s]
            crop_mask = loss_masks[:, i:i+c_s, i:i+c_s]
            new_angs = d_angs[:, i:i+c_s]

            # zero out previous gradients
            optimizer.zero_grad()

            # forward pass
            dmat_pred, ang_pred, single_rep = model(seq_crop.type(torch.float), evo_crop)

            # Masked mean of the per-pair loss: only valid, un-padded residue
            # pairs contribute. The clamp avoids 0/0 on a fully masked crop.
            dmat_loss = loss_func(dmat_pred, ddmat.long())
            #angs_loss = loss_func(ang_pred.squeeze(dim=-1), new_angs.long())
            loss = (dmat_loss * crop_mask).sum() / crop_mask.sum().clamp(min=1) #+ torch.mean(angs_loss)

            training_loss += loss.item()
            num_crops += 1

            # backpropagate and step
            loss.backward()
            optimizer.step()

        if bidx % 100 == 0:
            # 'loss' is the sum of crop losses seen so far in this epoch
            checkpoint = {
                'batch': bidx,
                'epoch': epoch,
                'loss': training_loss,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(checkpoint, "alphafold2.pth")

    # mean loss per crop for this epoch (guard against an epoch with no crops)
    training_loss /= max(num_crops, 1)
    # the `with` block closes the file; no explicit close is needed
    with open('output.txt', 'a') as fp:
        fp.write(f"Epoch: {epoch} Loss: {training_loss}\n")
    print(f"Epoch: {epoch} Loss: {training_loss}")

  0%|          | 0/10110 [00:08<?, ?it/s]


KeyboardInterrupt: 