In [None]:
import sys
import time
import argparse
from collections import defaultdict
from tqdm.notebook import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import sidechainnet as scn
import einops as ein

## Creates features for a batch of sequences

In [None]:
def get_seq_features(batch):
    '''
    Unpack a batch of sequence info into model features.

    Returns the tuple
        (seqs, evos, angs, masks, dmats, dmat_masks, pids, lengths)
    where seqs is the one-hot sequence, evos the PSSM / evolutionary info,
    angs the first two torsion-angle channels (phi, psi), masks the
    per-position validity mask, dmats the pairwise distance matrix between
    one representative atom per residue, and dmat_masks the outer product
    of the position mask with itself (0 means the i,j pair is invalid).
    '''
    atoms_per_res = 14          # stride between residues in batch.crds
    cb_slot, ca_slot = 4, 1     # c-beta / c-alpha slots inside one residue

    residue_strs = batch.str_seqs
    masks = batch.msks
    all_atoms = batch.crds      # all-atom coordinates, one row per atom

    # Pick one representative atom per residue: c-beta for valid
    # non-glycine positions, c-alpha otherwise (glycine or masked-out).
    per_protein = []
    for prot in range(all_atoms.shape[0]):
        atoms = all_atoms[prot]
        picked = []
        for start in range(0, atoms.shape[0]-1, atoms_per_res):
            res = start // atoms_per_res
            # the mask is checked first so padded positions never index
            # into the (possibly shorter) sequence string
            if masks[prot][res] and residue_strs[prot][res] != 'G':
                picked.append(atoms[start+cb_slot, :])
            else:
                picked.append(atoms[start+ca_slot, :])
        per_protein.append(torch.stack(picked))
    rep_xyz = torch.stack(per_protein)

    # pairwise distances and the matching pair-validity mask
    dmats = torch.cdist(rep_xyz, rep_xyz)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)

    return (batch.seqs, batch.evos, batch.angs[:,:,0:2], masks,
            dmats, dmat_masks, batch.pids, batch.lengths)

## Alphafold 

In [None]:
class residual_block(nn.Module):
    '''Bottleneck residual unit for 2D feature maps.

    The input (b, dmodel, n, n) is projected down to dproj channels with a
    1x1 conv, passed through a 3x3 dilated conv, projected back up to
    dmodel channels, and added to the input.
    '''

    def __init__(self, dilation, dmodel, dproj, drop_prob):
        super().__init__()
        self.dilation = dilation
        self.dmodel = dmodel
        self.dproj = dproj
        self.drop_prob = drop_prob
        self.dilation_kernel_size = 3

        # NOTE: the order of layers fixes both the state_dict key indices
        # and the order in which parameters are initialised.
        squeeze = [
            nn.BatchNorm2d(dmodel),
            nn.ELU(),
            nn.Conv2d(dmodel, dproj, kernel_size=1),        # proj down
        ]
        dilated = [
            nn.BatchNorm2d(dproj),
            nn.ELU(),
            nn.Dropout2d(p=drop_prob),
            nn.Conv2d(dproj, dproj,
                      kernel_size=self.dilation_kernel_size,
                      padding='same', dilation=dilation),   # dilated conv
        ]
        expand = [
            nn.BatchNorm2d(dproj),
            nn.ELU(),
            nn.Conv2d(dproj, dmodel, kernel_size=1),        # proj up
        ]
        self.res_block = nn.Sequential(*squeeze, *dilated, *expand)

    def forward(self, x):
        # x.shape: b, f, n, n -- the skip connection keeps that shape
        return x + self.res_block(x)


class AlphaFold(nn.Module):
    '''Stack of dilated residual blocks mapping pairwise input features
    (b, in_channels, n, n) to distance-bin logits (b, out_channels, n, n).
    '''

    def __init__(self, in_channels, out_channels, dmodel, dproj,
                 num_dilations, num_blocks, drop_prob):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dmodel = dmodel
        self.dproj = dproj
        self.num_dilations = num_dilations
        self.num_blocks = num_blocks
        self.drop_prob = drop_prob

        # 1x1 conv lifting the input features to dmodel channels
        self.proj_in = nn.Conv2d(in_channels, dmodel, kernel_size=1)

        # num_blocks groups of residual blocks; within each group the
        # dilation doubles: 2^0, 2^1, ..., 2^(num_dilations-1)
        dilation_cycle = [2**exp for exp in range(num_dilations)]
        self.res_blocks = nn.ModuleList([
            residual_block(rate, dmodel, dproj, drop_prob)
            for _ in range(num_blocks)
            for rate in dilation_cycle
        ])

        # 1x1 conv producing the logits for the distance bins
        self.dist_out = nn.Conv2d(dmodel, out_channels, kernel_size=1)

    def forward(self, x):
        # x.shape: b, f, n, n
        hidden = self.proj_in(x)
        for block in self.res_blocks:
            hidden = block(hidden)
        return self.dist_out(hidden)

## Dataset of Crops

In [None]:
class Crops_Dataset(torch.utils.data.IterableDataset):
    '''Iterable over square crops of the pairwise feature map of one batch.

    Assumes that the batch tensors are already on `device`.
    Each item is one crop taken at the same (si, sj) start position
    *for all* sequences in the batch.

    Args:
        seqs, evos, angs: per-position features, indexed (b, n, f).
        masks: per-position validity mask, indexed (b, n).
        dmats, dmat_masks: pairwise distances and pair validity, (b, n, n).
        crop_sz: side length of each square crop.
        stride_sz: step between consecutive crop start positions.
        pad_sz: zero padding added on both ends of every sequence axis.
        num_dist_bins: number of categorical distance bins over 2..22.
        device: device for the bin boundaries and test accumulators.
        test: if True, start positions are deterministic and accumulators
            (cmaps, cnts) are kept for the full padded map.
    '''
    
    def __init__(self, seqs, evos, angs, masks, dmats, dmat_masks,
                 crop_sz, stride_sz, pad_sz, num_dist_bins, device, test=False):
        self.seqs = seqs
        self.evos = evos
        self.angs = angs
        self.masks = masks
        self.dmats = dmats
        self.dmat_masks = dmat_masks
        self.crop_sz = crop_sz
        self.stride_sz = stride_sz
        self.pad_sz = pad_sz
        self.num_dist_bins = num_dist_bins
        self.device = device
        self.test = test
        
        # create distance bin boundaries: num_dist_bins-1 inner edges
        # between 2 and 22, so searchsorted yields ids 0..num_dist_bins-1
        bin_w = (22-2)/self.num_dist_bins
        self.dist_bins = torch.arange(2+bin_w,22,bin_w).to(self.device)
        
        # if test set, then create place holders for the entire NxN contact map
        if self.test:
            B, N, _ = self.seqs.shape
            self.cmaps = torch.zeros((B, num_dist_bins, N, N), device=self.device)
            self.cnts = torch.zeros((B, N, N), device=self.device)

        # pad all tensors by pad_size
        self.pad_tensors()
        
        # create the list of crop start positions (si, sj)
        self.get_start_positions()
    
    def pad_tensors(self):
        '''pad all tensors by pad_sz in each relevant dim'''
        pad_sea = (0, 0, self.pad_sz, self.pad_sz) # don't pad dim 2, pad dim 1
        self.seqs = F.pad(self.seqs,  pad_sea)
        self.evos = F.pad(self.evos,  pad_sea)
        self.angs = F.pad(self.angs,  pad_sea)
        
        pad_masks = (self.pad_sz, self.pad_sz) # pad dim 1
        self.masks = F.pad(self.masks,  pad_masks)
        
        pad_dmat = (self.pad_sz, self.pad_sz, 
                    self.pad_sz, self.pad_sz) # pad both dim 2 and 1
        self.dmats = F.pad(self.dmats, pad_dmat)
        self.dmat_masks = F.pad(self.dmat_masks, pad_dmat)

        # also pad the contact/dist maps for test batch
        if self.test:
            self.cmaps = F.pad(self.cmaps, pad_dmat)
            self.cnts = F.pad(self.cnts, pad_dmat)
    
    def get_start_positions(self):
        '''Create the list of start positions (si, sj) for the batch.

        During training, si is drawn from [0, pad_sz) and sj from
        (si, pad_sz], followed by strides in both directions.
        During testing, si, sj = 0, 1.
        Only pairs with si < sj are kept.
        '''
        N = self.seqs.shape[1]  # padded sequence length
        self.start_pos = []
        
        # cover the test proteins from the start (this includes padding)
        if self.test:
            si, sj = 0, 1
        else:
            # choose si, sj in the pad_sz x pad_sz top left corner.
            # np.random.randint requires low < high, so with pad_sz == 0
            # the corner is treated as 1 wide (giving si, sj = 0, 1)
            # instead of raising ValueError.
            corner = max(self.pad_sz, 1)
            si = np.random.randint(0, corner)
            sj = np.random.randint(si+1, corner+1)
        
        # now generate other start pos si, and sj using strides 
        PI = np.arange(si, N-self.crop_sz+1, self.stride_sz)
        PJ = np.arange(sj, N-self.crop_sz+1, self.stride_sz)

        # list of all si, sj pairs
        self.start_pos = [(i,j) for i in PI for j in PJ if i<j]
            
        
    def add_pred_crop(self, crop_preds, ij_pos):
        '''Used during testing: convert the predicted logits of one crop
        into probabilities over the distance bins and add them to the
        matching "tile" of the accumulated map, counting the visit.
        '''
        si, sj = ij_pos
        B, nbins, n, n = crop_preds.shape
        # compute softmax along the distance bins axis: dim 1
        predicted_probs = F.softmax(crop_preds, dim=1)
        
        # update the cmaps info and counts (scalar add avoids allocating
        # a fresh ones tensor for every crop)
        self.cmaps[:, :, si:si+n, sj:sj+n] += predicted_probs.detach()
        self.cnts[:, si:si+n, sj:sj+n] += 1
        
    
    def get_cmap_data(self):
        '''return cmap/cnt info during testing'''
        return self.cmaps, self.cnts
    
    def get_dmat_data(self):
        '''convert the (padded) true distance map into categorical bins
        and return it together with the padded pair mask'''
        dmats_discrete = torch.searchsorted(self.dist_bins, 
                        self.dmats)
        return dmats_discrete, self.dmat_masks
    
    def get_dist_bins(self):
        '''return the distance bin boundaries'''
        return self.dist_bins
        
    def __len__(self):
        '''how many si, sj pairs are there?'''
        return len(self.start_pos)
    
    def __iter__(self):
        '''Yield one crop starting at si, sj *for all* sequences in a batch.

        Each item is (crop, crop_dmat, crop_mask, (si, sj)), where crop is
        B x f x crop_sz x crop_sz with the features in the 2nd dim,
        crop_dmat holds the binned true distances and crop_mask the valid
        pairs of the same window.
        '''
        for si, sj in self.start_pos:
            si_crop = self.seqs[:, si:si+self.crop_sz]
            sj_crop = self.seqs[:, sj:sj+self.crop_sz]
            ei_crop = self.evos[:, si:si+self.crop_sz]
            ej_crop = self.evos[:, sj:sj+self.crop_sz]

            # tile the i side: each row position is repeated crop_sz times
            si_ary = ein.repeat(si_crop, 'b n f -> b (n repeat) f', repeat=self.crop_sz)
            ei_ary = ein.repeat(ei_crop, 'b n f -> b (n repeat) f', repeat=self.crop_sz)
            # repeat the j side: the whole column range is repeated crop_sz times
            sj_ary = ein.repeat(sj_crop, 'b n f -> b (repeat n) f', repeat=self.crop_sz) 
            ej_ary = ein.repeat(ej_crop, 'b n f -> b (repeat n) f', repeat=self.crop_sz)

            # cat the i and j features along the last dim
            seqs_crop = ein.rearrange([si_ary, sj_ary], 'a b n f -> b n (a f)')
            evos_crop = ein.rearrange([ei_ary, ej_ary], 'a b n f -> b n (a f)')
            evos_diff = ei_ary - ej_ary # diff
            evos_mul = ei_ary * ej_ary # elementwise
            
            # unflatten the pair axis: b, crop_sz, crop_sz, f
            seqs_crop = ein.rearrange(seqs_crop, 'a (n c) f -> a n c f', c=self.crop_sz)
            evos_crop = ein.rearrange(evos_crop, 'a (n c) f -> a n c f', c=self.crop_sz)
            evos_diff = ein.rearrange(evos_diff, 'a (n c) f -> a n c f', c=self.crop_sz)
            evos_mul = ein.rearrange(evos_mul, 'a (n c) f -> a n c f', c=self.crop_sz)

            # cat all together along feature dim (last)
            crop = torch.cat([seqs_crop, evos_crop, evos_diff, evos_mul], dim=3)
            # move features front
            crop = ein.rearrange(crop, 'b i j f -> b f i j') 

            # create discretized dmat
            crop_dmat = torch.searchsorted(self.dist_bins, 
                            self.dmats[:, si:si+self.crop_sz, sj:sj+self.crop_sz])
            # extract the crop mask (for valid pairs)
            crop_mask = self.dmat_masks[:, si:si+self.crop_sz, sj:sj+self.crop_sz]

            # return the crop, crop dist map, crop mask, and si, sj start pos
            yield (crop, crop_dmat, crop_mask, (si, sj))

In [None]:
def save_checkpoint(args, bidx, e, running_loss, model, optimizer):
    '''Serialize the training state with torch.save.

    The file is named after args.jobid and the epoch e, so saving again
    within the same epoch overwrites the previous file. Stored entries:
    batch index, epoch, running loss, model weights and optimizer state.
    '''
    state = {}
    state['batch'] = bidx
    state['epoch'] = e
    state['loss'] = running_loss
    state['state_dict'] = model.state_dict()
    state['optimizer'] = optimizer.state_dict()
    torch.save(state, f'ckpt_J{args.jobid}_e{e}.pth')


def load_checkpoint(args, model, optimizer, device):
    '''Restore model and optimizer state from args.ckpt_fname.

    Args:
        args: namespace providing ckpt_fname, the checkpoint path.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        device: device the stored tensors are mapped to while loading, so
            a checkpoint written on a GPU can be restored on a CPU-only
            machine (and vice versa).

    Returns:
        (epoch, batch index, running loss) as stored in the checkpoint.
    '''
    # map_location was previously missing, which left `device` unused
    checkpoint = torch.load(args.ckpt_fname, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    running_loss = checkpoint['loss']
    bidx = checkpoint['batch']
    e = checkpoint['epoch']
    return e, bidx, running_loss

In [None]:
def parse_args(cmd='-e 2 -b 16'):
    '''Build the hyper-parameter namespace for the notebook.

    Args:
        cmd: command-line style options, either as one whitespace-separated
            string or as a list of tokens. The options are passed in
            explicitly rather than read from sys.argv. The default
            reproduces the notebook's original setting
            (2 epochs, batch size 16).

    Returns:
        argparse.Namespace with ckpt_fname, epochs, batch_size,
        learning_rate, num_blocks, dmodel, dproj, num_dilations,
        num_dist_bins, crop_sz, stride_sz, pad_sz, dropout, num_workers
        and jobid.
    '''
    parser = argparse.ArgumentParser(description='alphafold.py')
    # checkpoint file
    parser.add_argument('-cf', dest='ckpt_fname', default=None)

    parser.add_argument('-e', dest='epochs', default=10, type=int)
    parser.add_argument('-b', dest='batch_size',
                        default=32, type=int)
    parser.add_argument('-lr', dest='learning_rate',
                        default=1e-5, type=float)

    # model hyperparams
    parser.add_argument('-nb', dest='num_blocks', default=1, type=int)
    parser.add_argument('-dmodel', dest='dmodel', default=128, type=int)
    parser.add_argument('-dproj', dest='dproj', default=64, type=int)
    parser.add_argument('-dilations', dest='num_dilations', default=4, type=int)
    parser.add_argument('-dbins', dest='num_dist_bins', default=64, type=int)
    parser.add_argument('-c', dest='crop_sz', default=64, type=int)
    parser.add_argument('-s', dest='stride_sz', default=16, type=int)
    parser.add_argument('-p', dest='pad_sz', default=32, type=int)
    parser.add_argument('-dropout', dest='dropout',
                        default=0.15, type=float)
    parser.add_argument('-num_workers', dest='num_workers',
                        default=1, type=int)
    parser.add_argument('-j', dest='jobid', default='1')

    # accept either a single string or an already tokenised sequence
    argv = cmd.split() if isinstance(cmd, str) else list(cmd)
    args = parser.parse_args(argv)

    return args

## Training 

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
args = parse_args()
print("ARGS", args)
# number of pairwise input features produced by Crops_Dataset
# (presumably 2*20 one-hot + 4*21 PSSM channels -- confirm against the data)
in_channels = 124
out_channels = args.num_dist_bins

# Load CASP7 data as pytorch tensors
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
                batch_size=args.batch_size, dynamic_batching=True)

model = AlphaFold(in_channels, out_channels, args.dmodel, args.dproj, 
             args.num_dilations, args.num_blocks, args.dropout)

# Move the model BEFORE building the optimizer / restoring its state, so
# that the optimizer state tensors end up on the same device as the params.
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

start_epoch = 0 # epoch number
if args.ckpt_fname is not None:
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine
    saveinfo = torch.load(args.ckpt_fname, map_location=device)
    start_epoch = saveinfo['epoch']+1
    model.load_state_dict(saveinfo['state_dict'])
    optimizer.load_state_dict(saveinfo['optimizer'])
    print("resume from epoch:", start_epoch) # resume from epoch

model.train()

st = time.time()
train = 'train'
for e in range(start_epoch, args.epochs):
    running_loss = 0
    running_len = 0
    for bidx, batch in tqdm(enumerate(data[train]), total=len(data[train])):
        # get the batch info and send to GPU
        seqs, evos, angs, masks, dmats, dmat_masks, pids, lengths = get_seq_features(batch)
        seqs, evos, masks, dmats, dmat_masks =\
            seqs.to(device), evos.to(device), masks.to(device),\
            dmats.to(device), dmat_masks.to(device)

        # dataset of crops for this batch
        train_data = Crops_Dataset(seqs, evos, angs, masks, dmats, dmat_masks,
            args.crop_sz, args.stride_sz, args.pad_sz, args.num_dist_bins, device)

        # now iterate over the crops, which are already on GPU
        for cidx, (crop, crop_dmat, crop_mask, ij_pos) in enumerate(train_data):
            # clear gradients left over from the previous crop; without
            # this, backward() keeps accumulating into .grad and every
            # step uses the sum of all gradients seen so far
            optimizer.zero_grad()

            output = model(crop)
            # compute entropy loss per crop element
            loss = F.cross_entropy(output, crop_dmat, reduction='none')
            # consider only valid positions based on crop_mask
            loss = torch.sum(loss * crop_mask)

            running_loss += loss.item() # total loss
            running_len += torch.sum(crop_mask).item() # how many valid pos

            loss.backward()
            optimizer.step()

        # save after each batch (same file name within an epoch, so the
        # checkpoint is overwritten)
        save_checkpoint(args, bidx, e, running_loss, model, optimizer)

    print("epoch", e, running_loss, running_len, running_loss/running_len)

en = time.time()
print("tot time", en-st)

## Class to compute prediction statistics

In [None]:
class testStats:
    '''Aggregate the contact-prediction statistics for all test proteins.

    Precision is reported for the top L, L/2 and L/5 predictions, denoted
    k=1, k=2, k=5. For the 'all' range, precision and recall over every
    prediction with probability >= 0.5 are also stored, denoted k=0.

    Stored entries (self.prec[pid][len_lbl][k]):
        k >= 1: (precision, num_correct, Lk)
        k == 0: (precision, recall, num_correct, num_preds, num_true)
    A ratio whose denominator is 0 is stored as NaN and is skipped when
    the overall means are computed.
    '''
    
    def __init__(self):
        '''precision by pid, length, Lk'''
        self.prec = defaultdict(lambda: defaultdict(dict)) #nested dict

    @staticmethod
    def _ratio(num, den):
        '''num/den, or NaN when the ratio is undefined (den == 0)'''
        return num / den if den > 0 else float('nan')

    @staticmethod
    def _nan_safe_mean(values):
        '''mean of the defined (non-NaN) values; NaN if there are none'''
        arr = np.asarray(values, dtype=float)
        arr = arr[~np.isnan(arr)]
        return arr.mean() if arr.size else float('nan')
        
    def print_stats(self):
        '''print per protein precision, and compute stats over all pids'''
        prec_Lk = defaultdict(lambda: defaultdict(list))
        for pid, by_len in self.prec.items():
            for len_lbl, by_k in by_len.items():
                for k, entry in by_k.items():
                    prec_Lk[len_lbl][k].append(entry)
                    print("Protein:", pid, len_lbl, k, entry)
                    
        print("\nOverall Stats:")
        for len_lbl in prec_Lk:
            for k in prec_Lk[len_lbl]:
                entries = prec_Lk[len_lbl][k]
                # precision is the first field of every entry; proteins
                # with an undefined (NaN) value are left out of the mean
                # instead of turning the whole average into NaN
                precs = [entry[0] for entry in entries]
                if k >= 1:
                    print(len_lbl, k, self._nan_safe_mean(precs))
                else:
                    recs = [entry[1] for entry in entries]
                    print(len_lbl, k, self._nan_safe_mean(precs),
                          self._nan_safe_mean(recs))
                
   
    def compute_lk(self, len_lbl, cmaps_pred, cmaps_true, lb, ub, pids, lengths):
        '''Compute the precision (and recall) for sequence separations in
        [lb, ub], for the top L/k predictions, and store them in self.prec.

        Args:
            len_lbl: label of the separation range ('all' additionally
                triggers the k=0 precision/recall entry).
            cmaps_pred: predicted contact probabilities, (b, n, n).
            cmaps_true: true binary contact maps, (b, n, n).
            lb, ub: lowest and highest diagonal offset (j - i) kept.
            pids: protein ids, one per batch element.
            lengths: sequence lengths, one per batch element.
        '''
        
        # create a mask for only those i,j pairs that are within [lb, ub] band;
        # built on the inputs' device rather than a notebook-level global
        N = cmaps_pred.shape[1]
        ones = torch.ones((N, N), device=cmaps_pred.device)
        sep_mask = torch.triu(ones, diagonal=lb) * torch.tril(ones, diagonal=ub)
        
        # filter out elements not in the lb, ub diagonal band
        cmaps_pred = cmaps_pred * sep_mask
        cmaps_true = cmaps_true * sep_mask
        
        # compute stats for each protein in the batch
        for pi in range(len(pids)):
            pid = pids[pi]
            L = int(lengths[pi]) # plain int even if lengths holds tensors
            cmap_p = cmaps_pred[pi] # predicted probabilities
            cmap_t = cmaps_true[pi] # true binary cmap
            cmap_b = torch.where(cmap_p >= 0.5, 1, 0) # predicted binary cmap

            # candidate predictions (prob >= 0.5) do not depend on k
            idxs = torch.where(cmap_p >= 0.5)
            probs = cmap_p[idxs]
            num_true = int(cmap_t.sum())
            
            for k in [1, 2, 5]:
                # choose smaller of L/k or # true contacts
                Lk = min(L//k, num_true)
                
                # positions of the (at most) Lk most confident candidates
                top_Lk = torch.topk(probs, min(Lk, len(probs))).indices
                top_tup = (idxs[0][top_Lk], idxs[1][top_Lk])

                # the number of contacts common to both pred and true
                num_correct = int(torch.sum(cmap_b[top_tup] * cmap_t[top_tup]))
                
                # precision @ L/k (NaN when Lk == 0, i.e. undefined)
                self.prec[pid][len_lbl][k] = (self._ratio(num_correct, Lk),
                                              num_correct, Lk)
            
            # for the full protein, compute precision and recall
            if len_lbl == 'all':
                k = 0 # let k =0 mean no topk restriction
                num_correct = int(torch.sum(cmap_b * cmap_t))
                num_preds = int(torch.sum(cmap_b))
                self.prec[pid][len_lbl][k] = (self._ratio(num_correct, num_preds),
                                              self._ratio(num_correct, num_true),
                                              num_correct, num_preds,
                                              num_true)
            
                
    def update(self, cmaps, cnts, dmaps, dmap_masks, pids, lengths, dist_bins):
        '''Take the accumulated distance-bin probabilities for all pids in a
        batch, and compute the precision (and recall) for various length
        thresholds -- short, medium, long -- and overall.

        Args:
            cmaps: summed per-bin probabilities, (b, f, i, j).
            cnts: number of crops that contributed to each i,j, (b, i, j).
            dmaps: true distances as bin ids, (b, i, j).
            dmap_masks: 1 for valid i,j pairs, 0 otherwise, (b, i, j).
            pids, lengths: protein ids and sequence lengths of the batch.
            dist_bins: sorted bin boundaries used to discretize distances.
        '''
        # move dist bins (f) last
        cmaps = ein.rearrange(cmaps, 'b f i j -> b i j f')
        
        cnts = cnts * dmap_masks # ignore any invalid ij pair
        cmaps = cmaps * dmap_masks.unsqueeze(dim=3) # ignore the preds at invalid pos
        
        # divide cmaps by cnts, but make sure to not divide by 0
        cnts_mask = torch.where(cnts > 0)
        cmaps[cnts_mask] = cmaps[cnts_mask] / cnts[cnts_mask].unsqueeze(dim=1)
        
        # now make dist bins sum up to one (prob vector) by dividing the 
        # bins with the sum of all bins
        cmaps_sum = cmaps.sum(dim=3, keepdim=True)
        cmaps[cnts_mask] = cmaps[cnts_mask]/cmaps_sum[cnts_mask]

        # bin8 is the id of the bin that contains 8A. The predicted contact
        # probability must cover exactly the bins that count as a true
        # contact below (ids <= bin8), hence the inclusive upper end.
        bin8 = torch.searchsorted(dist_bins, 8).item()  # which bin for 8A
        cmaps_pred = cmaps[:, :, :, 0:bin8+1].sum(dim=3)
        
        #create true cmap, and retain only valid pos
        cmaps_true = torch.where(dmaps <= bin8, 1, 0) * dmap_masks
    
        # short, medium, long range contacts
        N = cmaps_true.shape[1]
        ranges = {'short':(6,11), 'medium':(12,23), 'long':(24, N), 'all':(2,N)}
        for len_lbl in ['short', 'medium', 'long', 'all']:
            lb, ub = ranges[len_lbl]
            self.compute_lk(len_lbl, cmaps_pred, cmaps_true, lb, ub, pids, lengths)

## Load model from checkpoint (if needed)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
args = parse_args()
print("ARGS", args)
# number of pairwise input features produced by Crops_Dataset
# (presumably 2*20 one-hot + 4*21 PSSM channels -- confirm against the data)
in_channels = 124
out_channels = args.num_dist_bins

# Load CASP7 data as pytorch tensors
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
                batch_size=args.batch_size, dynamic_batching=True)

model = AlphaFold(in_channels, out_channels, args.dmodel, args.dproj, 
             args.num_dilations, args.num_blocks, args.dropout)

# Move the model before restoring any state so the optimizer state is
# created on the same device as the parameters.
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

# map_location lets a checkpoint written on a GPU load on a CPU-only machine
saveinfo = torch.load("ckpt_aws_e2.pth", map_location=device)
model.load_state_dict(saveinfo['state_dict'])
optimizer.load_state_dict(saveinfo['optimizer'])

## Testing

In [None]:
test_stride = 16

model.eval()
st = time.time()
test_stats = testStats()
test_loader = data['test']
with torch.no_grad():
    for bidx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
        # unpack the batch and move the tensors the crops are built from
        seqs, evos, angs, masks, dmats, dmat_masks, pids, lengths = get_seq_features(batch)
        seqs, evos, masks, dmats, dmat_masks = (
            tensor.to(device)
            for tensor in (seqs, evos, masks, dmats, dmat_masks))

        test_data = Crops_Dataset(
            seqs, evos, angs, masks, dmats, dmat_masks,
            args.crop_sz, test_stride, args.pad_sz, args.num_dist_bins,
            device, test=True)

        # per-batch loss bookkeeping
        running_loss, running_len = 0, 0
        for crop, crop_dmat, crop_mask, ij_pos in test_data:
            output = model(crop)

            # accumulate this crop's predictions into the batch-wide map
            test_data.add_pred_crop(output, ij_pos)

            # masked cross entropy, summed over the valid positions
            loss = F.cross_entropy(output, crop_dmat, reduction='none')
            loss = torch.sum(loss * crop_mask)

            running_loss += loss.item()
            running_len += torch.sum(crop_mask).item()

        print("batch", bidx, running_loss, running_len, running_loss/running_len)

        # hand the accumulated predictions and the true maps to the stats
        cmaps, cnts = test_data.get_cmap_data()
        dmaps, dmap_masks = test_data.get_dmat_data()
        dist_bins = test_data.get_dist_bins()
        test_stats.update(cmaps, cnts, dmaps, dmap_masks, pids, lengths, dist_bins)

en = time.time()
print("test time", en-st)

## Print the stats

In [None]:
# Print the stored entry for every protein / range / k, followed by the
# overall averages per range and k.
test_stats.print_stats()