In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as du
from torchvision import transforms as T
import os
import numpy as np
import time
from tqdm import tqdm
from torch.utils import data
import sidechainnet as scn
import einops
import gc

## Build Model

In [2]:
class Dilation_block(nn.Module):
    def __init__(self, dilation):
        super(Dilation_block, self).__init__()

        self.bn1 = nn.BatchNorm2d(128)
        self.pd = nn.Conv2d(128, 64, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.d_conv = nn.Conv2d(64, 64, 1, dilation=dilation)
        self.bn3 = nn.BatchNorm2d(64)
        self.pu = nn.Conv2d(64, 128, 1)
    
    def forward(self, data):
        x = F.elu(self.bn1(data))
        x = self.pd(x)
        x = F.elu(self.bn2(x))
        x = self.d_conv(x)
        x = F.elu(self.bn3(x))
        x = self.pu(x)

        return x + data

In [3]:
class Dialation_blocks(nn.Module):
    def __init__(self):
        super(Dialation_blocks, self).__init__()

        self.db1 = Dilation_block(1)
        self.db2 = Dilation_block(2)
        self.db3 = Dilation_block(4)
        self.db4 = Dilation_block(8)
    
    def forward(self, data):
        x = self.db1(data)
        x = self.db2(x)
        x = self.db3(x)
        x = self.db4(x)

        return x

In [4]:
class Alphafold1(nn.Module):
    def __init__(self, num_blocks, device):
        super(Alphafold1, self).__init__()

        self.conv1 = nn.Conv2d(124, 128, 1)

        self.blocks = []
        for i in range(num_blocks):
            self.blocks.append(Dialation_blocks().to(device))

        self.conv2 = nn.Conv2d(128, 64, 1)

    def forward(self, data):
        x = self.conv1(data)
        
        for block in self.blocks:
            x = block(x)

        x = self.conv2(x)

        return x

## Define Bin

In [5]:
bin = []
n = 2
gap = 20 / 64
while n < 22:
    bin.append(n)
    n += gap

## Iterable Version of Training Dataset

In [6]:
from torch.utils.data import IterableDataset
import math
class Alphafold_iter_Dataset(IterableDataset):
    def __init__(self, data_block, dmats, dmat_masks, overlap=32):
        super(Alphafold_iter_Dataset).__init__()
        
        m = nn.ZeroPad2d(32)
        # change the data format to fit the input of conv2d
        self.data_block = einops.rearrange(data_block, 'b h w c-> b c h w')
        self.seq_len = self.data_block.size(dim=2)
        self.batch_size = self.data_block.size(dim=0)

        # zero padding the data block, distance map and mask
        self.data_block = m(self.data_block)
        self.dmats_p = m(dmats)
        self.dmat_masks_p = m(dmat_masks)
        
        self.overlap = overlap
        
    def __iter__(self):
        start_pos = torch.randint(32, (2,))
        start_pos = (int(start_pos[0]), int(start_pos[1]))
        pos_x, pos_y = start_pos
        if pos_x > pos_y:
            tmp = pos_x
            pos_x = pos_y
            pos_y = tmp
        for i in range(pos_x, self.seq_len, self.overlap):
            for j in range(pos_y, self.seq_len, self.overlap):
                seq_crop = T.functional.crop(self.data_block, i, j, 64, 64)
                dmat_crop = T.functional.crop(self.dmats_p, i, j, 64, 64)
                dmat_crop = np.searchsorted(bin, dmat_crop)
                dmat_crop[dmat_crop > 63] = 63
                dmat_crop = torch.tensor(dmat_crop.tolist())
                d_mask_crop = T.functional.crop(self.dmat_masks_p, i, j, 64, 64)
                for k in range(self.batch_size):
                    yield seq_crop[k], dmat_crop[k], d_mask_crop[k]
    
    def __len__(self):
        return (math.ceil(self.seq_len / self.overlap) ** 2) * self.batch_size

## Define Dataset for Validation And Testing

In [7]:
class Alphafold_Dataset_Test(IterableDataset):
    def __init__(self, data_block, overlap=32):
        super(Alphafold_Dataset_Test).__init__()

        m = nn.ZeroPad2d(32)
        # change the data format to fit the input of conv2d
        self.data_block = einops.rearrange(data_block, 'b h w c-> b c h w')
        self.seq_len = self.data_block.size(dim=2)
        self.batch_size = self.data_block.size(dim=0)

        # zero padding the data block, distance map and mask
        self.data_block = m(self.data_block)
        self.overlap = overlap
        
    def __iter__(self):
        start_pos = torch.randint(32, (2,))
        start_pos = (int(start_pos[0]), int(start_pos[1]))
        pos_x, pos_y = start_pos
        if pos_x > pos_y:
            tmp = pos_x
            pos_x = pos_y
            pos_y = tmp
        if self.seq_len < 32:
            pos_x, pos_y = 0, 0
        for i in range(pos_x, self.seq_len, self.overlap):
            for j in range(pos_y, self.seq_len, self.overlap):
                seq_crop = T.functional.crop(self.data_block, i, j, 64, 64)
                for k in range(self.batch_size):
                    yield seq_crop[k], i, j
    
    def __len__(self):
        return (math.ceil(self.seq_len / self.overlap) ** 2) * self.batch_size

## Load CASP7 data as pytorch tensors

In [8]:
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
               batch_size=16)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


## Creates features for a batch of sequences

In [9]:
def get_seq_features(batch):
    '''
    Take a batch of sequence info and return the sequence (one-hot),
    evolutionary info and (phi, psi, omega) angles per position, 
    as well as position mask.
    Also return the distance matrix, and distance mask.
    '''
    str_seqs = batch.str_seqs # seq in str format
    seqs = batch.seqs # seq in one-hot format
    int_seqs = batch.int_seqs # seq in int format
    masks = batch.msks # which positions are valid
    lengths = batch.lengths # seq length
    evos = batch.evos # PSSM / evolutionary info
    angs = batch.angs[:,:,0:2] # torsion angles: phi, psi
    
    # use coords to create distance matrix from c-beta
    # except use c-alpha for G
    # coords[:, 4, :] is c-beta, and coords[:, 1, :] is c-alpha
    coords = batch.crds # seq coord info (all-atom)
    batch_xyz = []
    for i in range(coords.shape[0]):
        xyz = []
        xyz = [coords[i][cpos+4,:] 
                if masks[i][cpos//14] and str_seqs[i][cpos//14] != 'G'
                else coords[i][cpos+1,:]
                for cpos in range(0, coords[i].shape[0]-1, 14)]
        batch_xyz.append(torch.stack(xyz))
    batch_xyz = torch.stack(batch_xyz)
    # now create pairwise distance matrix
    dmats = torch.cdist(batch_xyz, batch_xyz)
    # create matrix mask (0 means i,j invalid)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)
    
    return seqs, evos, angs, masks, dmats, dmat_masks

## Model Save

In [10]:
def save_model(count, final_loss, model, optimizer):
    checkpoint = {
        'protein_num': count,
        'loss': final_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, 'checkpoint.pth')


## Function To Set Up Training

In [11]:
def train(data_block, dmats, dmat_masks, device, model, optimizer):
    batch_size = 20
    epochs = 5
    final_loss = None
    train_loader = du.DataLoader(dataset=Alphafold_iter_Dataset(data_block, 
                                dmats, dmat_masks), batch_size=batch_size)
    for epoch in range(1, epochs + 1):
        sum_loss = 0
        for batch_idx, (data, dmat, dmats_mask) in enumerate(train_loader):
            
            data, dmat, dmats_mask = data.to(device), \
            dmat.to(device), dmats_mask.to(device)
            optimizer.zero_grad()

            output = model(data)

            dmat = dmat * dmats_mask
            loss = F.cross_entropy(output, dmat)
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()

            del data
            del output
            del dmat
            del dmats_mask
            gc.collect()

            #break
        final_loss = sum_loss / len(train_loader)
        print('epoch: {}, loss: {:.3f}'.format(epoch, final_loss))
        #break
    print('Final loss {:.3f}'.format(final_loss))
    return final_loss


## Function For Validation and Testing

In [12]:
from sklearn.metrics import confusion_matrix
def get_recall(pred, true, seq_len):
    pred = np.array(pred)
    true = np.array(true)
    num_true = np.count_nonzero(pred > 0.5)
    denom = min(num_true, seq_len)
    short_pred_sort = np.argsort(-1*pred)
    short_pred_top = pred[short_pred_sort][0:denom]
    short_pred_top[short_pred_top>0.5] = 1
    short_pred_top[short_pred_top!=1] = 0
    short_true_top = true[short_pred_sort][0:denom]
    
    tn, fp, fn, tp = confusion_matrix(short_pred_top, short_true_top).ravel()
    res = None
    if tp+fp == 0:
        res = 0
    else:
        res = tp / (tp+fp)
    
    return res

In [13]:
recall_short_L = 0
recall_short_L_2 = 0
recall_short_L_5 = 0
recall_medium_L = 0
recall_medium_L_2 = 0
recall_medium_L_5 = 0
recall_long_L = 0
recall_long_L_2 = 0
recall_long_L_5 = 0
def valid_test(data_block, dmats, dmat_masks, device, model):
    batch_size = 1
    seq_len = data_block.size(dim=2)
    #print(seq_len)
    valid_loader = du.DataLoader(dataset=Alphafold_Dataset_Test(data_block), 
                                 batch_size=batch_size)
    
    #contact_true = np.searchsorted(bin, dmats)
    contact_true = torch.clone(dmats)
    contact_true[contact_true <= 8.0] = 1
    contact_true[contact_true != 1] = 0
    contact_true = torch.tensor(contact_true.tolist()[0]).to(torch.int)
    
    global recall_short_L
    global recall_short_L_2
    global recall_short_L_5
    global recall_medium_L
    global recall_medium_L_2
    global recall_medium_L_5
    global recall_long_L
    global recall_long_L_2
    global recall_long_L_5
    
    # store the tensors for the position when overlapping
    pos_tensor = dict()
    with torch.no_grad():
        for batch_idx, (data, pos_i, pos_j) in enumerate(valid_loader):
            data = data.to(device)
            
            output = model(data)
            output = F.softmax(output, dim=1)
            
            pos_i, pos_j = int(pos_i[0]), int(pos_j[0])
            
            for idx_i, (i) in enumerate(range(pos_i, pos_i+64)):
                for idx_j, (j) in enumerate(range(pos_j, pos_j+64)):
                    if (i, j) not in pos_tensor:
                        pos_tensor[(i, j)] = []
                        pos_tensor[(i, j)].append(output[0,:,idx_i,idx_j])
                    else:
                        pos_tensor[(i, j)].append(output[0,:,idx_i,idx_j])
    
    # compute the mean probability
    pred = []
    for i in range(32, 32+seq_len):
        row = []
        for j in range(32, 32+seq_len):
            pos_tensor[(i, j)] = sum(pos_tensor[(i, j)]) / \
                                len(pos_tensor[(i, j)])
            row.append(torch.sum(pos_tensor[(i, j)][0:20]))
        pred.append(row)
    
    pred = torch.tensor(pred)
    pred = pred * dmat_masks[0]
    
    short_pred = []
    short_true = []
    # short contact
    for i in range(6, 12):
        diag = torch.diagonal(pred, i, dim1=0, dim2=1)
        short_pred.extend(diag.tolist())
        diag = torch.diagonal(contact_true, i, dim1=0, dim2=1)
        short_true.extend(diag.tolist())
    
    # compute L size
    res = get_recall(short_pred, short_true, seq_len)
    recall_short_L += res
    
    # L / 2
    res = get_recall(short_pred, short_true, int(seq_len/2))
    recall_short_L_2 += res
    
    # L / 5
    res = get_recall(short_pred, short_true, int(seq_len/5))
    recall_short_L_5 += res
    
    medium_pred = []
    medium_true = []
    # medium contact
    for i in range(12, 24):
        diag = torch.diagonal(pred, i, dim1=0, dim2=1)
        medium_pred.extend(diag.tolist())
        diag = torch.diagonal(contact_true, i, dim1=0, dim2=1)
        medium_true.extend(diag.tolist())
    
    # compute L size
    res = get_recall(medium_pred, medium_true, seq_len)
    recall_medium_L += res
    
    # L / 2
    res = get_recall(medium_pred, medium_true, int(seq_len/2))
    recall_medium_L_2 += res
    
    # L / 5
    res = get_recall(medium_pred, medium_true, int(seq_len/5))
    recall_medium_L_5 += res
    
    long_pred = []
    long_true = []
    # long contact
    for i in range(24, seq_len):
        diag = torch.diagonal(pred, i, dim1=0, dim2=1)
        long_pred.extend(diag.tolist())
        diag = torch.diagonal(contact_true, i, dim1=0, dim2=1)
        long_true.extend(diag.tolist())
    
    # compute L size
    res = get_recall(long_pred, long_true, seq_len)
    recall_long_L += res
    
    # L / 2
    res = get_recall(long_pred, long_true, int(seq_len/2))
    recall_long_L_2 += res
    
    # L / 5
    res = get_recall(long_pred, long_true, int(seq_len/5))
    recall_long_L_5 += res

    return

## Now iterate through the train data

Do the same for data['test'] and data['valid-10']

In [14]:
count = 0
device = f'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")
# number of blocks needed
num_blocks = 2
model = Alphafold1(num_blocks, device)
model = model.to(device)
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.06, momentum=0.9)

using device: cuda:0


In [11]:
for batch in data['train']:
    seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
    # print(seqs.shape, evos.shape, angs.shape, masks.shape)
    # print(dmats.shape, dmat_masks.shape)
    
    '''
    Now write code to create the crops from each protein in the batch
    to create the final batch of crops along with true "labels" 
    for the distances and angles.
    
    dmat should be discretized into 64 bins, and angs into 1296 bins, 
    as described in lecture, to create true labels/bins.
    
    The input to you model should only be the seqs, and evos features.
    Read sidechainnet github page for more info.
    '''
    # generate sequence block with L×L×f
    seq_len = seqs.size(dim=1)
    seqs_evos = torch.cat((seqs, evos), 2)
    #print(seqs_evos.shape)
    seq_block = torch.tensor(seq_len*[seqs_evos.tolist()])
    #print(seq_block.shape)
    seq_block = torch.transpose(seq_block, 0, 1)
    #seq_block = einops.rearrange(seq_block, 'l b h f -> b l h f')
    seq_block_t = torch.transpose(seq_block, 1, 2)
    seq_block = torch.cat((seq_block_t, seq_block), 3)

    evo_block = torch.tensor(seq_len*[evos.tolist()])
    evo_block = torch.transpose(evo_block, 0, 1)
    #evo_block = einops.rearrange(evo_block, 'l b h f -> b l h f')
    evo_block_t = torch.transpose(evo_block, 1, 2)
    final_block = torch.cat((seq_block, evo_block-evo_block_t, 
                             evo_block * evo_block_t), 3)

    print('protein number {}, length {}'.format(count, seq_len))
    # if seq_len > 800:
    #     continue

    final_loss = train(final_block, dmats, dmat_masks, device, model, optimizer)
    save_model(count, final_loss, model, optimizer)

    del seq_block
    del seq_block_t
    del final_block
    del evo_block
    del evo_block_t
    gc.collect()

    # if count == 0:
    #     break
    
    count += 1
    

using device: cuda:0
protein number 0, length 498
epoch: 1, loss: 1.732
epoch: 2, loss: 1.641
epoch: 3, loss: 1.458
epoch: 4, loss: 1.508
epoch: 5, loss: 1.472
Final loss 1.472
protein number 1, length 143
epoch: 1, loss: 2.103
epoch: 2, loss: 1.993
epoch: 3, loss: 2.206
epoch: 4, loss: 1.870
epoch: 5, loss: 1.870
Final loss 1.870
protein number 2, length 497
epoch: 1, loss: 1.468
epoch: 2, loss: 1.448
epoch: 3, loss: 1.468
epoch: 4, loss: 1.435
epoch: 5, loss: 1.380
Final loss 1.380
protein number 3, length 66
epoch: 1, loss: 0.927
epoch: 2, loss: 0.875
epoch: 3, loss: 1.076
epoch: 4, loss: 1.025
epoch: 5, loss: 0.801
Final loss 0.801
protein number 4, length 266
epoch: 1, loss: 1.835
epoch: 2, loss: 1.704
epoch: 3, loss: 1.693
epoch: 4, loss: 1.703
epoch: 5, loss: 1.804
Final loss 1.804
protein number 5, length 219
epoch: 1, loss: 2.101
epoch: 2, loss: 2.100
epoch: 3, loss: 1.948
epoch: 4, loss: 1.929
epoch: 5, loss: 2.075
Final loss 2.075
protein number 6, length 128
epoch: 1, loss:

## Validation

In [15]:
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
               batch_size=1)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
count = 0
for batch in tqdm(data['valid-10']):
    seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
    #print(seqs.shape, evos.shape, angs.shape, masks.shape)
    # print(dmats.shape, dmat_masks.shape)
    #print(dmats)
    # generate sequence block with L×L×f
    seq_len = seqs.size(dim=1)
    seqs_evos = torch.cat((seqs, evos), 2)
    #print(seqs_evos.shape)
    seq_block = torch.tensor(seq_len*[seqs_evos.tolist()])
    #print(seq_block.shape)
    seq_block = torch.transpose(seq_block, 0, 1)
    #seq_block = einops.rearrange(seq_block, 'l b h f -> b l h f')
    seq_block_t = torch.transpose(seq_block, 1, 2)
    seq_block = torch.cat((seq_block_t, seq_block), 3)

    evo_block = torch.tensor(seq_len*[evos.tolist()])
    evo_block = torch.transpose(evo_block, 0, 1)
    #evo_block = einops.rearrange(evo_block, 'l b h f -> b l h f')
    evo_block_t = torch.transpose(evo_block, 1, 2)
    final_block = torch.cat((seq_block, evo_block-evo_block_t, 
                             evo_block * evo_block_t), 3)
    #print(final_block.shape)
    valid_test(final_block, dmats, dmat_masks, device, model)
    # t_short_acc += short_acc
    # t_medium_acc += medium_acc
    # t_long_acc += long_acc
    #print('protein number {}, length {}'.format(count, seq_len))
    del seq_block
    del seq_block_t
    del final_block
    del evo_block
    del evo_block_t
    gc.collect()
    
    count += 1
    # if count == 5:
    #     break
    #break

print('Validation short accuracy L: {:.3f}'.format(recall_short_L/count))
print('Validation short accuracy L/2: {:.3f}'.format(recall_short_L_2/count))
print('Validation short accuracy L/5: {:.3f}'.format(recall_short_L_5/count))
print('Validation medium accuracy L: {:.3f}'.format(recall_medium_L/count))
print('Validation medium accuracy L/2: {:.3f}'.format(recall_medium_L_2/count))
print('Validation medium accuracy L/5: {:.3f}'.format(recall_medium_L_5/count))
print('Validation long accuracy L: {:.3f}'.format(recall_long_L/count))
print('Validation long accuracy L/2: {:.3f}'.format(recall_long_L_2/count))
print('Validation long accuracy L/5: {:.3f}'.format(recall_long_L_5/count))

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


100%|██████████| 32/32 [03:11<00:00,  5.97s/it]

Validation short accuracy L: 0.938
Validation short accuracy L/2: 0.938
Validation short accuracy L/5: 0.719
Validation medium accuracy L: 0.875
Validation medium accuracy L/2: 0.781
Validation medium accuracy L/5: 0.562
Validation long accuracy L: 0.875
Validation long accuracy L/2: 0.844
Validation long accuracy L/5: 0.594





## Testing

In [16]:
count = 0
recall_short_L = 0
recall_short_L_2 = 0
recall_short_L_5 = 0
recall_medium_L = 0
recall_medium_L_2 = 0
recall_medium_L_5 = 0
recall_long_L = 0
recall_long_L_2 = 0
recall_long_L_5 = 0
for batch in tqdm(data['test']):
    seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
    # print(seqs.shape, evos.shape, angs.shape, masks.shape)
    # print(dmats.shape, dmat_masks.shape)

    # generate sequence block with L×L×f
    seq_len = seqs.size(dim=1)
    seqs_evos = torch.cat((seqs, evos), 2)
    #print(seqs_evos.shape)
    seq_block = torch.tensor(seq_len*[seqs_evos.tolist()])
    #print(seq_block.shape)
    seq_block = torch.transpose(seq_block, 0, 1)
    #seq_block = einops.rearrange(seq_block, 'l b h f -> b l h f')
    seq_block_t = torch.transpose(seq_block, 1, 2)
    seq_block = torch.cat((seq_block_t, seq_block), 3)

    evo_block = torch.tensor(seq_len*[evos.tolist()])
    evo_block = torch.transpose(evo_block, 0, 1)
    #evo_block = einops.rearrange(evo_block, 'l b h f -> b l h f')
    evo_block_t = torch.transpose(evo_block, 1, 2)
    final_block = torch.cat((seq_block, evo_block-evo_block_t, 
                             evo_block * evo_block_t), 3)
    #print(final_block.shape)
    valid_test(final_block, dmats, dmat_masks, device, model)
    
    #print('protein number {}, length {}'.format(count, seq_len))
    del seq_block
    del seq_block_t
    del final_block
    del evo_block
    del evo_block_t
    gc.collect()
    
    count += 1
    #break

print('Test short accuracy L: {:.3f}'.format(recall_short_L/count))
print('Test short accuracy L/2: {:.3f}'.format(recall_short_L_2/count))
print('Test short accuracy L/5: {:.3f}'.format(recall_short_L_5/count))
print('Test medium accuracy L: {:.3f}'.format(recall_medium_L/count))
print('Test medium accuracy L/2: {:.3f}'.format(recall_medium_L_2/count))
print('Test medium accuracy L/5: {:.3f}'.format(recall_medium_L_5/count))
print('Test long accuracy L: {:.3f}'.format(recall_long_L/count))
print('Test long accuracy L/2: {:.3f}'.format(recall_long_L_2/count))
print('Test long accuracy L/5: {:.3f}'.format(recall_long_L_5/count))

100%|██████████| 93/93 [08:32<00:00,  5.52s/it]

Test short accuracy L: 0.989
Test short accuracy L/2: 0.989
Test short accuracy L/5: 0.828
Test medium accuracy L: 0.989
Test medium accuracy L/2: 0.968
Test medium accuracy L/5: 0.710
Test long accuracy L: 0.989
Test long accuracy L/2: 0.892
Test long accuracy L/5: 0.645



