In [1]:
!pip install torch
!pip install einops
!pip install sidechainnet
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import transforms as T
import os
import numpy as np
import time
from tqdm import tqdm
from torch.utils import data
import sidechainnet as scn
import einops
import gc
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader



In [2]:
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

using device: cuda:0


## Load CASP7 data as pytorch tensors

In [3]:
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
               batch_size=16)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


## Creates features for a batch of sequences

In [4]:
def get_seq_features(batch):
    '''
    Take a batch of sequence info and return the sequence (one-hot),
    evolutionary info and (phi, psi, omega) angles per position, 
    as well as position mask.
    Also return the distance matrix, and distance mask.
    '''
    str_seqs = batch.str_seqs # seq in str format
    seqs = batch.seqs # seq in one-hot format
    int_seqs = batch.int_seqs # seq in int format
    masks = batch.msks # which positions are valid
    lengths = batch.lengths # seq length
    evos = batch.evos # PSSM / evolutionary info
    angs = batch.angs[:,:,0:2] # torsion angles: phi, psi
    
    # use coords to create distance matrix from c-beta
    # except use c-alpha for G
    # coords[:, 4, :] is c-beta, and coords[:, 1, :] is c-alpha
    coords = batch.crds # seq coord info (all-atom)
    batch_xyz = []
    for i in range(coords.shape[0]):
        xyz = []
        xyz = [coords[i][cpos+4,:] 
                if masks[i][cpos//14] and str_seqs[i][cpos//14] != 'G'
                else coords[i][cpos+1,:]
                for cpos in range(0, coords[i].shape[0]-1, 14)]
        batch_xyz.append(torch.stack(xyz))
    batch_xyz = torch.stack(batch_xyz)
    # now create pairwise distance matrix
    dmats = torch.cdist(batch_xyz, batch_xyz)
    # create matrix mask (0 means i,j invalid)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)
    
    return seqs, evos, angs, masks, dmats, dmat_masks

## Pair Representation

In [5]:
class PairRep(nn.Module):
    def __init__(self):
        super(PairRep, self).__init__()

        self.fcA = nn.Linear(20, 128)
        self.fcB = nn.Linear(20, 128)
    
    def forward(self, seqs):
        seq_len = seqs.size(dim=1)
        A = self.fcA(seqs.to(torch.float))
        B = self.fcB(seqs.to(torch.float))
        As = A.repeat(seq_len, 1, 1, 1)
        pair_rep = torch.transpose(As, 0, 1)
        pair_rep = torch.transpose(pair_rep, 1, 2)

        Bs = B.repeat(seq_len, 1, 1, 1)
        pair_rep = pair_rep + torch.transpose(Bs, 0, 1)

        return pair_rep

## MSI resp

In [6]:
class MSA(nn.Module):
    def __init__(self, n_cluster):
        super(MSA, self).__init__()

        self.n_cluster = n_cluster
        self.fcA = nn.Linear(21, 256)
    
    def forward(self, evos):
        msa = self.fcA(evos)
        msa = torch.unsqueeze(msa, dim=1)
        msa = msa.repeat((1,self.n_cluster,1,1))

        return msa

## Row Attention

In [7]:
class RowAttention(nn.Module):
    '''Self Attention'''

    def __init__(self, d, dk,num_heads):
        '''define WQ, WK, WV projection matrices:
        d: d_model is the original model dimension
        dk: projection dimension for query, keys and values
        '''
        super(RowAttention, self).__init__()
        
        self.d = d  # d_model
        self.dk = dk  # d_k: projection dimension
        self.B_map = nn.Linear(128,1, bias=False)
        self.G_map = nn.Linear(256,1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.num_heads = num_heads  # number of attention heads
        self.WQ = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WK = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WV = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WO = nn.Linear(self.dk * self.num_heads, self.d, bias=False)

    def forward(self, msa_rep, pair_rep):
        '''project the context onto key, query and value spaces and
        return the final value vectors
        '''
        #print(msa_rep.shape) # batch, cluter, seq, msa=256 
        #print(pair_rep.shape) # b,c, seq, z= 128
        
        # calculate b first
        B = self.B_map(pair_rep).squeeze(3).unsqueeze(1).unsqueeze(1) # b,1,1,seq,seq
        Gate = self.G_map(msa_rep)
        Gate = self.sigmoid(Gate).unsqueeze(2) # b,1,1,gate,seq

        # input shape: (batch_size, block_size, d)
        # let batch_size=b, block_size=l, num_heads=h
        Q = self.WQ(msa_rep)  # shape: b, c, seq, dk
        K = self.WK(msa_rep)  # shape: b, c, seq, dk
        V = self.WV(msa_rep)  # shape: b, c, seq, dk

        # split Q, K, V into heads and dk, move heads up front; KT is transpose of K
        Q = einops.rearrange(
            Q, 'b c s (h dk)-> b c h s dk', h=self.num_heads
        )  # size: (b c h s dk)
        KT = einops.rearrange(
            K, 'b c s (h dk)-> b c h dk s', h=self.num_heads
        )  # size: (b c h dk s)
        V = einops.rearrange(
            V, 'b c s (h dk)-> b c h s dk', h=self.num_heads
        )  # size: (b c h s dk)


        QKT = torch.einsum('bchsd,bchdt->bchst', Q, KT)
        # size: b, c, h, seq, seq
        # attention matrix
        # row specifies weights for the value vectors, row add up; to one
        A = F.softmax(QKT / np.sqrt(self.dk) + B, dim=4)  # shape: b, c, h, seq, seq
        V = torch.einsum('bchst,bchtd->bchsd', A, V) #shape: b, c, h, seq, dk
        V = V * Gate
        V = einops.rearrange(V, 'b c h s d-> b c s (h d)') #shape: b, c, seq, h*dk
        output = self.WO(V)
        return output

## Column Attention

In [8]:
class ColAttention(nn.Module):
    '''Self Attention'''

    def __init__(self, d, dk,num_heads):
        '''define WQ, WK, WV projection matrices:
        d: d_model is the original model dimension
        dk: projection dimension for query, keys and values
        '''
        super(ColAttention, self).__init__()
        
        self.d = d  # d_model
        self.dk = dk  # d_k: projection dimension
        self.num_heads = num_heads  # number of attention heads
        self.G_map = nn.Linear(256,1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.WQ = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WK = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WV = nn.Linear(self.d, self.dk * self.num_heads, bias=False)
        self.WO = nn.Linear(self.dk * self.num_heads, self.d, bias=False)

    def forward(self, msa_rep, pair_rep):
        '''project the context onto key, query and value spaces and
        return the final value vectors
        '''
        #print(msa_rep.shape) # batch, cluter, seq, msa=256 
        #print(pair_rep.shape) # b,c, seq, z= 128
        
        msa_col_rep = einops.rearrange(msa_rep, 'b c s m-> b s c m')
        # shape: b, seq, c, 256
        Gate = self.G_map(msa_col_rep)
        Gate = self.sigmoid(Gate).unsqueeze(2) # b,1,gate,seq

        # input shape: (batch_size, block_size, d)
        # let batch_size=b, block_size=l, num_heads=h
        Q = self.WQ(msa_col_rep)  # shape: b, seq, c, h*dk
        K = self.WK(msa_col_rep)  # shape: b, seq, c, h*dk
        V = self.WV(msa_col_rep)  # shape: b, seq, c, h*dk

        
        # split Q, K, V into heads and dk, move heads up front; KT is transpose of K
        Q = einops.rearrange(
            Q, 'b seq c (h dk)-> b seq h c dk', h=self.num_heads
        )  # size: (b seq h c dk)
        KT = einops.rearrange(
            K, 'b seq c (h dk)-> b seq h dk c', h=self.num_heads
        )  # size: (b seq h dk c)
        V = einops.rearrange(
            V, 'b seq c (h dk)-> b seq h c dk', h=self.num_heads
        )  # size: (b seq h c dk)

        QKT = torch.einsum('bshcd,bshdm->bshcm', Q, KT)
        # shape: b, seq, h, c, c
        # attention matrix
        # row specifies weights for the value vectors, row add up; to one
        A = F.softmax(QKT / np.sqrt(self.dk), dim=4)  # shape: b, seq, h, c, c
        # new value representation
        V = torch.einsum('bshcd,bshde->bshce', A, V) #shape: b, seq, h, c, dk
        V - V * Gate #shape: b, seq, h, c, dk 
        V = einops.rearrange(V, 'b s h c d-> b c s (h d)') #shape: b, c, seq, h*dk
        output = self.WO(V)
        return output


## MSA Information

In [9]:
class MSA_Information(nn.Module):
    '''Separate Headed Self Attention: List of Attention Heads
    This is a straightforward implementation of the multiple heads.
    We have separate WQ, WK and WV matrices, one per head.'''

    def __init__(self):
        '''create separate heads:
        d: d_model dimension
        dk: projection dimension for query, keys and values
        num_heads: number of attention heads
        '''
        super(MSA_Information, self).__init__()
        self.row_attention = RowAttention(256,32,8)
        self.col_attention = ColAttention(256,32,8)
        #trans
        self.project_up = nn.Linear(256,256*4, bias=False)
        self.relu = nn.ReLU()
        self.project_down = nn.Linear(256*4,256, bias=False)
        #temp_project_down
        self.project_c = nn.Linear(256,4, bias=False) 
        self.project_z = nn.Linear(4*4,128, bias=False) # use 8 instead of 32
        
    def forward(self, msa_rep, pair_rep):
        output_row = self.row_attention(msa_rep, pair_rep)
        output = msa_rep + output_row
        output_col = self.col_attention(output, pair_rep)
        output = output_col + output
        # trans
        output_trans = self.project_up(output)
        output_trans = self.relu(output_trans)
        output_trans = self.project_down(output_trans)
        output = output_trans + output
        output = self.project_c(output)  # project c to calculate outer product

        outer_product = torch.einsum('bcxm,bcyn->bcxymn', output, output) 
        #5*16*256*256*32*32
        outer_product = torch.mean(outer_product,axis=1)
        outer_product = einops.rearrange(outer_product, 'a b c i j-> a b c (i j)')
        #print(outer_product.shape)
        outer_product = self.project_z(outer_product)
        outer_product = outer_product + pair_rep
        return outer_product

## Triangular Multiplicative Update

In [10]:
class Tri_Multi(nn.Module):
    def __init__(self, c_z, c, mode):
        super(Tri_Multi, self).__init__()

        self.A = nn.Linear(c_z, c, bias=False)
        self.B = nn.Linear(c_z, c, bias=False)
        self.G_A = nn.Linear(c_z, c, bias=False)
        self.G_B = nn.Linear(c_z, c, bias=False)
        self.G = nn.Linear(c_z, c, bias=False)
        # project back to original dim
        self.pb = nn.Linear(c, c_z, bias=False)
        self.mode = mode

    def forward(self, x):
        # input dim: B x Nseq x Nseq x c_z
        new_z = torch.clone(x)

        A = self.A(x)
        B = self.B(x)
        if self.mode == 'in':
            A = torch.transpose(A, 1, 2)
            B = torch.transpose(B, 1, 2)

        G_A = torch.sigmoid(self.G_A(x))
        G_B = torch.sigmoid(self.G_B(x))
        G = torch.sigmoid(self.G(x))

        A = A * G_A
        B = B * G_B

        # compute the pair wise element wise
        ele_mult = torch.einsum('bijk,bajk->biajk', A, B)
        new_z = torch.sum(ele_mult, dim=3)

        new_z = self.pb(new_z)
        #print(new_z.shape)

        return new_z

## Triangular Attention (Starting)

In [11]:
class Tri_Attn(nn.Module):
    def __init__(self, c_z, c, mode):
        super(Tri_Attn, self).__init__()

        self.W_Q = nn.Linear(c_z, c, bias=False)
        self.W_K = nn.Linear(c_z, c, bias=False)
        self.W_V = nn.Linear(c_z, c, bias=False)
        self.W_G = nn.Linear(c_z, c, bias=False)
        self.W_B = nn.Linear(c_z, 1, bias=False)

        self.c = c
        self.mode = mode
    
    def forward(self, x):
        # input dim: B x Nseq x Nseq x c_z
        
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        G = self.W_G(x)
        
        N_seq = x.size(dim=1)
        B = self.W_B(x)
        B = B.reshape((len(x), N_seq, N_seq))
        B = torch.transpose(B, 1, 2)

        # for each batch, pick one slice at a time
        new_z = torch.zeros(len(x), N_seq, N_seq, self.c).to("cuda")
        for i in range(N_seq):
            A_i = torch.matmul(Q[:,i,:,:], torch.transpose(K[:,i,:,:], 1, 2)) \
                / torch.sqrt(torch.tensor(self.c)) + B
            if self.mode == 'end':
                A_i = torch.matmul(Q[:,i,:,:], K[:,i,:,:]) \
                / torch.sqrt(torch.tensor(self.c)) + B
            A_I = F.softmax(A_i, dim=1)
            O_i = G[:,i,:,:] * torch.matmul(A_i, V[:,i,:,:])
            new_z[:,i,:,:] = O_i
        
        return new_z

class SepHead_Tri_Attn(nn.Module):
    def __init__(self, c_z, c, n_heads, mode):
        super(SepHead_Tri_Attn, self).__init__()

        self.n_heads = n_heads
        self.ta_layers = nn.ModuleList()
        for i in range(self.n_heads):
            self.ta_layers.append(Tri_Attn(c_z, c, mode))
        
        self.W_O = nn.Linear(c*n_heads, c_z, bias=False)

    def forward(self, x):
        V = []
        for i in range(self.n_heads):
            V.append(self.ta_layers[i](x))
        
        V = torch.cat(V, dim=3)

        x = self.W_O(V)

        return x

class MultiHead_Tri_Attn(nn.Module):
    def __init__(self, c_z, c, n_heads, mode):
        super(MultiHead_Tri_Attn, self).__init__()

        self.W_Q = nn.Linear(c_z, c*n_heads, bias=False)
        self.W_K = nn.Linear(c_z, c*n_heads, bias=False)
        self.W_V = nn.Linear(c_z, c*n_heads, bias=False)
        self.W_G = nn.Linear(c_z, c*n_heads, bias=False)
        self.W_O = nn.Linear(c*n_heads, c_z, bias=False)
        self.W_B = nn.Linear(c_z, 1*n_heads, bias=False)

        self.n_heads = n_heads
        self.c = c
        self.mode = mode

    def forward(self, x):
        # input dim: B x Nseq x Nseq x c_z
        Q = self.W_Q(x) # size: (B, Nseq, Nseq, n_heads*c)
        K = self.W_K(x) # size: (B, Nseq, Nseq, n_heads*c)
        V = self.W_V(x) # size: (B, Nseq, Nseq, n_heads*c)
        G = self.W_G(x) # size: (B, Nseq, Nseq, n_heads*c)
        B = self.W_B(x) # size: (B, Nseq, Nseq, n_heads*1)

        if self.mode == 'end':
            K = torch.transpose(K, 1, 2)

        Q = einops.rearrange(
            Q, 'b n s (h c) -> b n h s c', h = self.n_heads
        )
        KT = einops.rearrange(
            K, 'b n s (h c) -> b n h c s', h = self.n_heads
        )
        V = einops.rearrange(
            V, 'b n s (h c) -> b n h s c', h = self.n_heads
        )
        G = einops.rearrange(
            G, 'b n s (h c) -> b n h s c', h = self.n_heads
        )

        N_seq = x.size(dim=1)
        B = einops.rearrange(
            B, 'b n s (h c) -> b h n s c', h = self.n_heads
        )
        B = torch.squeeze(B)
        B = B.repeat(N_seq, 1, 1, 1, 1)
        B = einops.rearrange(
            B, 'x b h n s -> b x h n s', h = self.n_heads
        )

        QKT = torch.einsum('bnhij,bnhjk->bnhik', Q, KT)
        A = F.softmax((QKT + B) / np.sqrt(self.c), dim=4)

        V = torch.einsum('bnhij,bnhjk->bnhik', A, V)
        V = V * G
        V = einops.rearrange(V, 'b n h s c -> b n s (h c)')

        x = self.W_O(V)

        return x

## Triangular Attention Block

In [12]:
class Tri_Attn_Block(nn.Module):
    def __init__(self):
        super(Tri_Attn_Block, self).__init__()

        self.tri_in = Tri_Multi(128, 16, 'out')
        self.tri_out = Tri_Multi(128, 16, 'in')
        self.tri_start = MultiHead_Tri_Attn(128, 16, 4, 'start')
        self.tri_end = MultiHead_Tri_Attn(128, 16, 4, 'end')
        self.trans_up = nn.Linear(128, 128*4, bias=False)
        self.trans_down = nn.Linear(128*4, 128, bias=False)
    
    def forward(self, x):
        tri_in_out = self.tri_in(x)
        x = x + tri_in_out
        tri_out_out = self.tri_out(x)
        x = x + tri_out_out
        tri_start_out = self.tri_start(x)
        x = x + tri_start_out
        tri_end_out = self.tri_end(x)
        x = x + tri_end_out
        x = self.trans_up(x)
        x = self.trans_down(x)

        return x

## Alphafold 2

In [13]:
class Alphafold2(nn.Module):
    def __init__(self):
        super(Alphafold2, self).__init__()

        self.pr = PairRep()
        self.msa = MSA(16)
        self.msa_info = MSA_Information()
        self.tri_attn_b = Tri_Attn_Block()
        self.fc = nn.Linear(128, 64, bias=False)
        self.fc_angle = nn.Linear(128,1296,bias=False)
        self.max_pool_i = nn.MaxPool2d(kernel_size=(256,1))
        self.max_pool_j = nn.MaxPool2d(kernel_size=(1,256))

    def forward(self, seq, evo):
        pair_rep = self.pr(seq)
        msa_rep = self.msa(evo)
        Z = self.msa_info(msa_rep, pair_rep)
        Z = self.tri_attn_b(Z)
        Angle = self.fc_angle(Z)
        Angle = einops.rearrange(Angle, 'b s t c-> b c s t')
        Angle_i = self.max_pool_i(Angle).squeeze(2)
        Angle_j = self.max_pool_j(Angle).squeeze(3)
        Angle_out = torch.cat((Angle_i, Angle_j),dim=2)

        return Z, Angle_out

## Define Bin

In [14]:
bin = []
n = 2
gap = 20 / 64
while n < 22:
    bin.append(n)
    n += gap

angle_bin_w = 6.30/36
angle_bin = torch.arange(-3.15+angle_bin_w,3.15+angle_bin_w,angle_bin_w)

## Define Dataset Class

In [15]:
def get_angle_bin(input_angles):
  indices = torch.searchsorted(angle_bin,input_angles)
  angles_discrete = indices[:,:,0]*36 + indices[:,:,1]
  angles_discrete = angles_discrete
  return angles_discrete

def prepossess_data(seqs, evos, angs, masks, dmats, dmat_masks, overlap=32):
  # zero padding the data block, distance map and mask
  seq_len = seqs.size(dim=1)
  m = nn.ZeroPad2d((0, 0, 128, 128))
  seqs = m(seqs)
  evos = m(evos)
  angs = get_angle_bin(angs.contiguous())

  angs = F.pad(angs,(128,128), "constant", 0)
  masks = F.pad(masks,(128,128), "constant", 0)
  m = nn.ZeroPad2d(128)
  dmats_p = m(dmats)
  dmat_masks_p = m(dmat_masks)
  start_pos = torch.randint(32, (1,))
  start_pos = int(start_pos[0])
  pos_x = start_pos
  if seq_len < 40:
    pos_x = 0
  sc_list = list()
  ec_list = list()
  ang_list = list()
  mask_list = list()
  dmat_crop_list = list()
  d_mask_crop_list = list()
  for i in range(pos_x, seq_len, overlap):
    seq_crop = seqs[:,pos_x:pos_x+256,:]
    evo_crop = evos[:,pos_x:pos_x+256,:]
    ang_crop = angs[:,pos_x:pos_x+256]
    mask_crop = masks[:,pos_x:pos_x+256]
   
    #seq_crop = PairRep(seq_crop)
    dmat_crop = T.functional.crop(dmats_p, i, i, 256, 256)
    dmat_crop = np.searchsorted(bin, dmat_crop)
    dmat_crop[dmat_crop > 63] = 63
    dmat_crop = torch.tensor(dmat_crop.tolist())
    d_mask_crop = T.functional.crop(dmat_masks_p, i, i, 256, 256)
    for j in range(seq_crop.shape[0]):
      sc_list.append(seq_crop[j])
      ec_list.append(evo_crop[j])
      ang_list.append(ang_crop[j])
      mask_list.append(mask_crop[j])
      dmat_crop_list.append(dmat_crop[j])
      d_mask_crop_list.append(d_mask_crop[j])
  return sc_list, ec_list, ang_list, mask_list, dmat_crop_list, d_mask_crop_list
  

In [16]:

class AlphaFold_IterableDataset(torch.utils.data.IterableDataset):
  def __init__(self, raw_data, num_workers,overlap=32):
    super(AlphaFold_IterableDataset).__init__()
    self.raw_data = raw_data 
    #how many workers
    if num_workers <= 0:
      self.num_workers = 1
    else:
      self.num_workers = num_workers

    self.raw_data_list = [[]*4 for _ in range(self.num_workers)]
    self.overlap = overlap
 
    for idx, raw_data_info in enumerate(self.raw_data):
      parallel_idx = idx % self.num_workers
      self.raw_data_list[parallel_idx].append(raw_data_info)
  def __iter__(self):
    worker = torch.utils.data.get_worker_info()
    if worker is not None:
        worker_id = worker.id
        num_workers = worker.num_workers
    else:
        worker_id = 0
        num_workers = 1
    for m, batch in enumerate(self.raw_data_list[worker_id]):
      seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
      sc_list, ec_list, ang_list, mask_list, dmat_crop_list, d_mask_crop_list = prepossess_data(seqs, evos, angs, masks, dmats, dmat_masks, overlap=32)
      for i, data_info in enumerate(zip(sc_list, ec_list, ang_list, mask_list, dmat_crop_list, d_mask_crop_list)):
          yield data_info

  def __len__(self):
      return len(self.raw_data) * 26 * int(207 / self.overlap)

In [17]:
def save_model(model, optimizer):
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'checkpoint_2epoch.pth')

In [18]:
model = Alphafold2()
model = model.to(device)
checkpoint = torch.load('checkpoint_2epoch.pth')
model.load_state_dict(checkpoint['state_dict'])
model.train()
optimizer = optim.Adam(model.parameters(), lr=0.00008)
optimizer.load_state_dict(checkpoint['optimizer'])
count = 0
dataset = AlphaFold_IterableDataset(data['train'], num_workers=4)
batch_size = 3
train_loader = DataLoader(dataset, batch_size=batch_size,shuffle=False, num_workers=4)
epochs = 1
final_loss = None
for epoch in range(1, epochs + 1):
  sum_loss = 0
  for batch_idx, (seq, evo, ang,mask, dmat, dmats_mask) in enumerate(tqdm(train_loader)):
    seq, evo, ang, mask, dmat, dmats_mask = seq.to(device), evo.to(device), \
    ang.to(device), mask.to(device), dmat.to(device), dmats_mask.to(device)
    optimizer.zero_grad()
    output_dist, output_angle = model(seq, evo)
    
    dmat = dmat * dmats_mask
    ang = ang *mask

    ang_label = torch.cat((ang,ang),dim=1)
    pred = einops.rearrange(output_dist, 'b h w c -> b c h w')
    loss_dist = F.cross_entropy(pred, dmat)
    loss_angle = F.cross_entropy(output_angle,ang_label)
    loss = loss_dist + loss_angle * 0.001
    sum_loss += loss.item()
    loss.backward()
    optimizer.step()

    del seq
    del evo
    del ang 
    del mask
    del dmat
    del dmats_mask
    gc.collect()
    #print(loss.item())
  final_loss = sum_loss / len(train_loader)
  print('epoch: {}, loss: {:.3f}'.format(epoch, final_loss))
  save_model(model, optimizer)

print('Final loss {:.3f}'.format(final_loss))
print("Finish")

 65%|██████▌   | 21448/32864 [3:13:44<1:43:07,  1.85it/s]


epoch: 1, loss: 0.813
Final loss 0.813
Finish


## Dataset For Testing

In [19]:
class Alphafold_Dataset_Test(IterableDataset):
    def __init__(self, seqs, evos, overlap=32):
        super(Alphafold_Dataset_Test).__init__()

        self.seq_len = seqs.size(dim=1)
        self.batch_size = seqs.size(dim=0)

        m = nn.ZeroPad2d((0, 0, 128, 128))
        self.seqs = m(seqs)
        self.evos = m(evos)

        self.overlap = overlap
        
    def __iter__(self):
        start_pos = torch.randint(32, (1,))
        start_pos = int(start_pos[0])
        pos_x = start_pos
        if self.seq_len < 36:
            pos_x = 0
        for i in range(pos_x, self.seq_len, self.overlap):
            seq_crop = self.seqs[:,pos_x:pos_x+256,:]
            evo_crop = self.evos[:,pos_x:pos_x+256,:]
            for j in range(self.batch_size):
                yield seq_crop[j], evo_crop[j], i, i

## Function For Validation and Testing

In [20]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import recall_score
def get_recall(pred, true, seq_len):
    pred = np.array(pred)
    true = np.array(true)
    num_true = np.count_nonzero(true > 0.5)
    denom = min(num_true, seq_len)
    short_pred_sort = np.argsort(-1*pred)
    short_pred_top = pred[short_pred_sort][0:denom]
    short_pred_top[short_pred_top>0.5] = 1
    short_pred_top[short_pred_top!=1] = 0
    short_true_top = true[short_pred_sort][0:denom]
    
    # print(short_true_top)
    # print(short_pred_top)
    recall = recall_score(short_true_top, short_pred_top, average='binary', zero_division=0)
    
    return recall
    #return

In [23]:
recall_short_L = 0
recall_short_L_2 = 0
recall_short_L_5 = 0
recall_medium_L = 0
recall_medium_L_2 = 0
recall_medium_L_5 = 0
recall_long_L = 0
recall_long_L_2 = 0
recall_long_L_5 = 0
def valid_test(seqs, evos, dmats, dmat_masks, model):
    batch_size = 1
    seq_len = seqs.size(dim=1)
    #print(seq_len)
    dataset = Alphafold_Dataset_Test(seqs, evos)
    valid_loader = du.DataLoader(dataset=dataset, batch_size=batch_size)
    
    #contact_true = np.searchsorted(bin, dmats)
    contact_true = torch.clone(dmats)
    contact_true[contact_true <= 8.0] = 1
    contact_true[contact_true != 1] = 0
    contact_true = torch.tensor(contact_true.tolist()[0]).to(torch.int)
    
    global recall_short_L
    global recall_short_L_2
    global recall_short_L_5
    global recall_medium_L
    global recall_medium_L_2
    global recall_medium_L_5
    global recall_long_L
    global recall_long_L_2
    global recall_long_L_5
    
    # store the tensors for the position when overlapping
    pos_tensor = dict()
    with torch.no_grad():
        for batch_idx, (seq, evo, pos_i, pos_j) in enumerate(valid_loader):
            seq, evo = seq.to(device), evo.to(device)
            
            output, ang_output = model(seq, evo)
            #output = einops.rearrange(output, 'b h w c -> b c h w')
            output = F.softmax(output, dim=3)
            
            pos_i, pos_j = int(pos_i[0]), int(pos_j[0])
            
            for idx_i, (i) in enumerate(range(pos_i, pos_i+256)):
                for idx_j, (j) in enumerate(range(pos_j, pos_j+256)):
                    if (i, j) not in pos_tensor:
                        pos_tensor[(i, j)] = []
                        pos_tensor[(i, j)].append(output[0,idx_i,idx_j,:])
                    else:
                        pos_tensor[(i, j)].append(output[0,idx_i,idx_j,:])
    
    # compute the mean probability
    pred = []
    for i in range(128, 128+seq_len):
        row = []
        for j in range(128, 128+seq_len):
            if (i, j) not in pos_tensor:
                row.append(0)
            else:
                pos_tensor[(i, j)] = sum(pos_tensor[(i, j)]) / \
                                    len(pos_tensor[(i, j)])
                row.append(torch.sum(pos_tensor[(i, j)][0:20]))
        pred.append(row)
    
    pred = torch.tensor(pred)
    pred = pred * dmat_masks[0]
    
    #print(pred)
    short_pred = []
    short_true = []
    # short contact
    for i in range(6, 12):
        diag = torch.diagonal(pred, i, dim1=0, dim2=1)
        short_pred.extend(diag.tolist())
        diag = torch.diagonal(contact_true, i, dim1=0, dim2=1)
        short_true.extend(diag.tolist())
    
    # compute L size
    get_recall(short_pred, short_true, seq_len)
    res = get_recall(short_pred, short_true, seq_len)
    recall_short_L += res
    
    # L / 2
    res = get_recall(short_pred, short_true, int(seq_len/2))
    recall_short_L_2 += res
    
    # L / 5
    res = get_recall(short_pred, short_true, int(seq_len/5))
    recall_short_L_5 += res
    
    medium_pred = []
    medium_true = []
    # medium contact
    for i in range(12, 24):
        diag = torch.diagonal(pred, i, dim1=0, dim2=1)
        medium_pred.extend(diag.tolist())
        diag = torch.diagonal(contact_true, i, dim1=0, dim2=1)
        medium_true.extend(diag.tolist())
    
    # compute L size
    res = get_recall(medium_pred, medium_true, seq_len)
    recall_medium_L += res
    
    # L / 2
    res = get_recall(medium_pred, medium_true, int(seq_len/2))
    recall_medium_L_2 += res
    
    # L / 5
    res = get_recall(medium_pred, medium_true, int(seq_len/5))
    recall_medium_L_5 += res
    
    long_pred = []
    long_true = []
    # long contact
    for i in range(24, seq_len):
        diag = torch.diagonal(pred, i, dim1=0, dim2=1)
        long_pred.extend(diag.tolist())
        diag = torch.diagonal(contact_true, i, dim1=0, dim2=1)
        long_true.extend(diag.tolist())
    
    # compute L size
    res = get_recall(long_pred, long_true, seq_len)
    recall_long_L += res
    
    # L / 2
    res = get_recall(long_pred, long_true, int(seq_len/2))
    recall_long_L_2 += res
    
    # L / 5
    res = get_recall(long_pred, long_true, int(seq_len/5))
    recall_long_L_5 += res

    return

## Validation

In [24]:
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
               batch_size=1)
checkpoint = torch.load('checkpoint_2epoch.pth')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
count = 0
for batch in tqdm(data['valid-10']):
    seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
    valid_test(seqs, evos, dmats, dmat_masks, model)

    gc.collect()
    
    count += 1
    # if count == 1:
    #     break
print('Validation short accuracy L: {:.3f}'.format(recall_short_L/count))
print('Validation short accuracy L/2: {:.3f}'.format(recall_short_L_2/count))
print('Validation short accuracy L/5: {:.3f}'.format(recall_short_L_5/count))
print('Validation medium accuracy L: {:.3f}'.format(recall_medium_L/count))
print('Validation medium accuracy L/2: {:.3f}'.format(recall_medium_L_2/count))
print('Validation medium accuracy L/5: {:.3f}'.format(recall_medium_L_5/count))
print('Validation long accuracy L: {:.3f}'.format(recall_long_L/count))
print('Validation long accuracy L/2: {:.3f}'.format(recall_long_L_2/count))
print('Validation long accuracy L/5: {:.3f}'.format(recall_long_L_5/count))

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


100%|██████████| 32/32 [03:28<00:00,  6.52s/it]

Validation short accuracy L: 0.594
Validation short accuracy L/2: 0.594
Validation short accuracy L/5: 0.500
Validation medium accuracy L: 0.531
Validation medium accuracy L/2: 0.500
Validation medium accuracy L/5: 0.375
Validation long accuracy L: 0.469
Validation long accuracy L/2: 0.438
Validation long accuracy L/5: 0.281





## Test

In [26]:
count = 0
recall_short_L = 0
recall_short_L_2 = 0
recall_short_L_5 = 0
recall_medium_L = 0
recall_medium_L_2 = 0
recall_medium_L_5 = 0
recall_long_L = 0
recall_long_L_2 = 0
recall_long_L_5 = 0
for batch in tqdm(data['test']):
    seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
    valid_test(seqs, evos, dmats, dmat_masks, model)

    gc.collect()
    
    count += 1

print('Test short accuracy L: {:.3f}'.format(recall_short_L/count))
print('Test short accuracy L/2: {:.3f}'.format(recall_short_L_2/count))
print('Test short accuracy L/5: {:.3f}'.format(recall_short_L_5/count))
print('Test medium accuracy L: {:.3f}'.format(recall_medium_L/count))
print('Test medium accuracy L/2: {:.3f}'.format(recall_medium_L_2/count))
print('Test medium accuracy L/5: {:.3f}'.format(recall_medium_L_5/count))
print('Test long accuracy L: {:.3f}'.format(recall_long_L/count))
print('Test long accuracy L/2: {:.3f}'.format(recall_long_L_2/count))
print('Test long accuracy L/5: {:.3f}'.format(recall_long_L_5/count))

100%|██████████| 93/93 [09:58<00:00,  6.44s/it]

Test short accuracy L: 0.792
Test short accuracy L/2: 0.781
Test short accuracy L/5: 0.656
Test medium accuracy L: 0.710
Test medium accuracy L/2: 0.688
Test medium accuracy L/5: 0.548
Test long accuracy L: 0.774
Test long accuracy L/2: 0.581
Test long accuracy L/5: 0.376



