An excercise, recreating the results reported in [this paper](https://academic.oup.com/bioinformatics/article/39/4/btad187/7114029) using [the codes proviced by the authors](https://github.com/biomed-AI/GraphBepi)

importing necessary packages:

In [1]:
import torch
import torch.nn as nn
import torchmetrics as tm
import torch.nn.functional as F
import os
import pytorch_lightning as pl
from torch.nn.utils.rnn import pad_sequence,pack_sequence,pack_padded_sequence,pad_packed_sequence
import pandas as pd
import pickle as pk
from tqdm import tqdm,trange
import esm
import warnings
from torch.utils.data import DataLoader,Dataset
import numpy as np
import requests as rq
import time
import random
from collections import defaultdict
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback,EarlyStopping,ModelCheckpoint
warnings.simplefilter('ignore')

  from .autonotebook import tqdm as notebook_tqdm


contents of tool.py, used to evaluate the perfomance:

In [2]:
class METRICS:
    def __init__(self,device='cpu'):
        self.device=device
        self.auroc=tm.AUROC(task='binary').to(device)
        self.auprc=tm.AveragePrecision(task='binary').to(device)
        self.roc=tm.ROC(task='binary').to(device)
        self.prc=tm.PrecisionRecallCurve(task='binary').to(device)
        self.rec=tm.Recall(task='binary').to(device)
        self.prec=tm.Precision(task='binary').to(device)
        self.f1=tm.F1Score(task='binary').to(device)
        self.mcc=tm.MatthewsCorrCoef(task='binary').to(device)
        f=lambda a,b,c,d,e:(a/(a+d)+c/(b+c))/2
        self.stat=tm.StatScores(task='binary').to(device)
        self.bacc=lambda x,y:f(*self.stat(x,y))

    def to(self,pred,y):
        return pred.to(self.device),y.to(self.device)
    def calc_thresh(self,pred,y):
        pred,y=self.to(pred,y)
        prec, rec, thresholds = self.prc(pred,y)
        f1=(2*prec*rec/(prec+rec)).nan_to_num(0)[:-1]
        threshold = thresholds[torch.argmax(f1)]
        return threshold
    def calc_prc(self,pred,y):
        pred,y=self.to(pred,y)
        auroc = self.auroc(pred,y)
        prec, rec, th1 = self.prc(pred,y)
        auprc = self.auprc(pred,y)
        fpr, tpr, th2 = self.roc(pred,y)
        return {
            'AUROC':auroc.cpu().item(),'AUPRC':auprc.cpu().item(),'prc':[rec[:-1],prec[:-1],th1],'roc':[fpr,tpr,th2]
        }
    def __call__(self,pred,y,threshold=None):
        pred,y=self.to(pred,y)
        auroc = self.auroc(pred,y)
        prec, rec, thresholds = self.prc(pred,y)
        auprc = self.auprc(pred,y)
        if threshold is None:
            f1=(2*prec*rec/(prec+rec)).nan_to_num(0)[:-1]
            threshold = thresholds[torch.argmax(f1)]
        threshold=torch.tensor(threshold)
        self.f1.threshold=threshold
        self.rec.threshold=threshold
        self.mcc.threshold=threshold
        self.stat.threshold=threshold
        self.prec.threshold=threshold
        f1 = self.f1(pred,y)
        rec = self.rec(pred,y)
        mcc = self.mcc(pred,y)
        bacc = self.bacc(pred,y)
        prec = self.prec(pred,y)
        return {
            'AUROC':auroc.cpu().item(),'AUPRC':auprc.cpu().item(),
            'RECALL':rec.cpu().item(),'PRECISION':prec.cpu().item(),
            'F1':f1.cpu().item(),'MCC':mcc.cpu().item(),
            'BACC':bacc.cpu().item(),'threshold':threshold.cpu().item(),
        }

contents of EGAT.py, the graph attention layer:

In [3]:
class AE(nn.Module):
    def __init__(self, dim_in, dim_out, hidden, dropout = 0., bias=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_in, hidden, bias=bias),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out, bias=bias),
            nn.LayerNorm(dim_out),
        )
    def forward(self, x):
        return self.net(x)
class EGraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, edge_attr):
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        e = self._prepare_attentional_mechanism_input(Wh)
        e = e*edge_attr
        zero_vec = -9e15*torch.ones_like(e)
        e = torch.where(edge_attr > 0, e, zero_vec)
        e = F.softmax(e, dim=1)
        e = F.dropout(e, self.dropout, training=self.training)

        h_prime=[]
        for i in range(edge_attr.shape[0]):
            h_prime.append(torch.matmul(e[i],Wh))

        if self.concat:
            h_prime = torch.cat(h_prime,dim=1)
        else:
            h_prime = torch.stack(h_prime,dim=0).mean(0)
        return F.elu(h_prime),e

    #compute attention coefficient
    def _prepare_attentional_mechanism_input(self, Wh):
        # Wh.shape (N, out_feature)
        # self.a.shape (2 * out_feature, 1)
        # Wh1&2.shape (N, 1)
        # e.shape (N, N)
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
        # broadcast add
        e = Wh1 + Wh2.T
        return self.leakyrelu(e)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
class EGAT(nn.Module):
    def __init__(self, nfeat, nhid, efeat, dropout=0.2, alpha=0.2):
        super().__init__()
        self.dropout = dropout
        self.in_att = EGraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
        self.out_att = EGraphAttentionLayer(nhid*efeat, nfeat, dropout=dropout, alpha=alpha, concat=False)
    def forward(self, x, edge_attr):
        x_cut=x
        x = F.dropout(x, self.dropout, training=self.training)
        x, edge_attr=self.in_att(x, edge_attr)
        x, edge_attr=self.out_att(x, edge_attr)
        return x+x_cut, edge_attr

contents of model.py, full GraphBepi model:

In [4]:
class GraphBepi(pl.LightningModule):
    def __init__(
        self,
        feat_dim=2560, hidden_dim=256,
        exfeat_dim=13, edge_dim=51,
        augment_eps=0.05, dropout=0.2,
        lr=1e-6, metrics=None, result_path=None
    ):
        super().__init__()
        self.metrics=metrics
        self.path=result_path
        # loss function
        self.loss_fn=nn.BCELoss()
        # Hyperparameters
        self.exfeat_dim=exfeat_dim
        self.augment_eps = augment_eps
        self.lr = lr
        self.cls = 1
        bias=False
        self.W_v = nn.Linear(feat_dim, hidden_dim, bias=bias)
        self.W_u1 = AE(exfeat_dim,hidden_dim,hidden_dim, bias=bias)
        self.edge_linear=nn.Sequential(
            nn.Linear(edge_dim,hidden_dim//4, bias=True),
            nn.ELU(),
        )
        self.gat=EGAT(2*hidden_dim,hidden_dim,hidden_dim//4,dropout)
        self.lstm1 = nn.LSTM(hidden_dim,hidden_dim//2,3,batch_first=True,bidirectional=True,dropout=dropout)
        self.lstm2 = nn.LSTM(hidden_dim,hidden_dim//2,3,batch_first=True,bidirectional=True,dropout=dropout)
        # output
        self.mlp=nn.Sequential(
            nn.Linear(4*hidden_dim,hidden_dim,bias=True),
            nn.ReLU(),
            nn.Linear(hidden_dim,1,bias=True),
            nn.Sigmoid()
        )
        # Initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, V, edge):
        h=[]
        V = pad_sequence(V, batch_first=True, padding_value=0).float()
        mask=V.sum(-1)!=0
        if self.training and self.augment_eps > 0:
            aug=torch.randn_like(V)
            aug[~mask]=0
            V = V+self.augment_eps * aug
        mask=mask.sum(1)
        feats,exfeats=self.W_v(V[:,:,:-self.exfeat_dim]),self.W_u1(V[:,:,-self.exfeat_dim:])
        x_gcns=[]
        for i in range(len(V)):
            E=self.edge_linear(edge[i]).permute(2,0,1)
            x1,x2=feats[i,:mask[i]],exfeats[i,:mask[i]]
            x_gcn=torch.cat([x1,x2],-1)
            x_gcn,E=self.gat(x_gcn,E)
            x_gcns.append(x_gcn)
        feats=pack_padded_sequence(feats,mask.cpu(),True,False)
        exfeats=pack_padded_sequence(exfeats,mask.cpu(),True,False)
        feats=pad_packed_sequence(self.lstm1(feats)[0],True)[0]
        exfeats=pad_packed_sequence(self.lstm2(exfeats)[0],True)[0]
        x_attns=torch.cat([feats,exfeats],-1)

        x_attns=[x_attns[i,:mask[i]] for i in range(len(x_attns))]
        h=[torch.cat([x_attn,x_gcn],-1) for x_attn,x_gcn in zip(x_attns,x_gcns)]
        h=torch.cat(h,0)
        return self.mlp(h)
    def training_step(self, batch, batch_idx):
        feat, edge, y = batch
        pred = self(feat, edge).squeeze(-1)
        loss=self.loss_fn(pred,y.float())
        self.log('train_loss', loss.cpu().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        if self.metrics is not None:
            result=self.metrics.calc_prc(pred.detach().clone(),y.detach().clone())
            self.log('train_auc', result['AUROC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('train_prc', result['AUPRC'], on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        feat, edge, y = batch
        pred = self(feat, edge).squeeze(-1)
        return pred,y
    def validation_epoch_end(self,outputs):
        pred,y=[],[]
        for i,j in outputs:
            pred.append(i)
            y.append(j)
        pred=torch.cat(pred,0)
        y=torch.cat(y,0)
        loss=self.loss_fn(pred,y.float())
        self.log('val_loss', loss.cpu().item(), on_epoch=True, prog_bar=True, logger=True)
        if self.metrics is not None:
            result=self.metrics(pred.detach().clone(),y.detach().clone())
            self.log('val_AUROC', result['AUROC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('val_AUPRC', result['AUPRC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('val_mcc', result['MCC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('val_f1', result['F1'], on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        feat, edge, y = batch
        pred = self(feat, edge).squeeze(-1)
        return pred,y
    def test_epoch_end(self,outputs):
        pred,y=[],[]
        for i,j in outputs:
            pred.append(i)
            y.append(j)
        pred=torch.cat(pred,0)
        y=torch.cat(y,0)
        loss=self.loss_fn(pred,y.float())
        if self.path:
            if not os.path.exists(self.path):
                os.system(f'mkdir -p {self.path}')
            torch.save({'pred':pred.cpu(),'gt':y.cpu()},f'{self.path}/result.pkl')
        if self.metrics is not None:
            result=self.metrics(pred.detach().clone(),y.detach().clone())
            self.log('test_loss', loss.cpu().item(), on_epoch=True, prog_bar=True, logger=True)
            self.log('test_AUROC', result['AUROC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_AUPRC', result['AUPRC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_recall', result['RECALL'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_precision', result['PRECISION'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_f1', result['F1'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_mcc', result['MCC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_bacc', result['BACC'], on_epoch=True, prog_bar=True, logger=True)
            self.log('test_threshold', result['threshold'], on_epoch=True, prog_bar=True, logger=True)
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), betas=(0.9, 0.99), lr=self.lr, weight_decay=1e-5, eps=1e-5)

contents of preprocess.py, used to preprocess the inputs:

In [5]:
DICT={
    'ALA': 'A', 'CYS': 'C', 'CCS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'MSE': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
}
def pdb_split(line):
    order=int(line[6:11].strip())
    atom=line[11:16].strip()
    amino=line[16:21].strip()
    chain=line[21]
    site=line[22:28].strip()
    x=line[28:38].strip()
    y=line[38:46].strip()
    z=line[46:54].strip()
    return order,atom,amino,chain,site,x,y,z
def judge(line,filt_atom='CA'):
    kind=line[:6].strip()
    if kind not in ['HETATM','ATOM']:
        return None
    order,atom,amino,chain,site,x,y,z=pdb_split(line)
    if filt_atom is not None and atom!=filt_atom:
        return None
    prefix=''
    if len(amino)>3:
        prefix=amino[0]
        amino=amino[-3:]
    if amino=='MSE':
        amino='MET'
    elif amino=='CCS' or amino[:-1]=='CS':
        amino='CYS'
    elif amino not in DICT.keys():
        return None
    return prefix+amino,chain,site,float(x),float(y),float(z)
def process_dssp(dssp_file):
    aa_type = "ACDEFGHIKLMNPQRSTVWY"
    SS_type = "HBEGITSC"
    rASA_std = [115, 135, 150, 190, 210, 75, 195, 175, 200, 170,
                185, 160, 145, 180, 225, 115, 140, 155, 255, 230]
    with open(dssp_file, "r") as f:
        lines = f.readlines()
    seq = ""
    dssp_feature = []
    position = []
    p = 0
    while lines[p].strip()[0] != "#":
        p += 1
    for i in range(p + 1, len(lines)):
        aa = lines[i][13]
        if aa == "!" or aa == "*":
            continue
        seq += aa
        POS = lines[i][5:11].strip()
        position.append(POS)
        SS = lines[i][16]
        if SS == " ":
            SS = "C"
        SS_vec = np.zeros(8)
        SS_vec[SS_type.find(SS)] = 1
        PHI = float(lines[i][103:109].strip())
        PSI = float(lines[i][109:115].strip())
        ACC = float(lines[i][34:38].strip())
        ASA = min(100, round(ACC / rASA_std[aa_type.find(aa)] * 100)) / 100
        dssp_feature.append(np.concatenate((np.array([PHI, PSI, ASA]), SS_vec)))

    return seq, dssp_feature,position
def transform_dssp(dssp_feature):
    dssp_feature = np.array(dssp_feature)
    angle = dssp_feature[:,0:2]
    ASA_SS = dssp_feature[:,2:]
    radian = angle * (np.pi / 180)
    dssp_feature = np.concatenate([np.sin(radian), np.cos(radian), ASA_SS], axis = 1)
    return dssp_feature
def get_dssp(ID,root):
    if not os.path.exists(f"{root}/dssp/"):
        os.mkdir(f"{root}/dssp/")
    os.system(f"./mkdssp/mkdssp -i {root}/purePDB/{ID}.pdb -o {root}/dssp/{ID}.dssp")
    if not os.path.exists(f"{root}/dssp/" + ID + ".dssp"):
        return None
    dssp_seq, dssp_matrix,position = process_dssp(f"{root}/dssp/" + ID + ".dssp")
    np.save(f"{root}/dssp/" + ID, transform_dssp(dssp_matrix))
    np.save(f"{root}/dssp/"+ID+"_pos",position)

contents of graph_construction.py, used to create a representation of the structural information:

In [6]:
ID={
    'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
    'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
    'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
    'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19
}

def calcPROgraph(seq,coord,dseq=3,dr=10,dlong=5,k=10):
    nodes=coord.shape[0]
    adj=torch.zeros((nodes,nodes))
    E=torch.zeros((nodes,nodes,21*2+2*dseq+3))
    # C=coord.to('cuda:1')
    dist=torch.cdist(coord,coord,2)
    knn=dist.argsort(1)[:,1:k+1]
    for i in range(nodes):
        # knn=dist[i].argsort()[1:k+1]
        for j in range(nodes):
            not_edge=True
            dij_seq=abs(i-j)
            if dij_seq<dseq:
                E[i][j][41+i-j+dseq]=1
                not_edge=False
            if dist[i][j]<dr and dij_seq>=dlong:
                E[i][j][41+2*dseq]=1
                not_edge=False
            if j in knn[i] and dij_seq>=dlong:
                E[i][j][42+2*dseq]=1
                not_edge=False
            if not_edge:
                continue
            adj[i][j]=1
            E[i][j][ID.get(seq[i],20)]=1
            E[i][j][21+ID.get(seq[j],20)]=1
            E[i][j][43+2*dseq]=dij_seq
            E[i][j][44+2*dseq]=dist[i][j]
    idx=adj.nonzero().T
    data=adj[idx[0],idx[1]]
    adj=torch.sparse.FloatTensor(idx,data,adj.shape)
    idx=E.nonzero().T
    data=E[idx[0],idx[1],idx[2]]
    E=torch.sparse.FloatTensor(idx,data,E.shape)
    return {'adj':adj,'edge':E}


contents of utils.py, methods for loading and working with the data:

In [7]:
amino2id={
    '<null_0>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3,
    'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10,
    'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16,
    'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22,
    'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28,
    '.': 29, '-': 30, '<null_1>': 31, '<mask>': 32, '<cath>': 33, '<af2>': 34
}
class chain:
    def __init__(self):
        self.sequence=[]
        self.amino=[]
        self.coord=[]
        self.site={}
        self.date=''
        self.length=0
        self.adj=None
        self.edge=None
        self.feat=None
        self.dssp=None
        self.name=''
        self.chain_name=''
        self.protein_name=''
    def add(self,amino,pos,coord):
        self.sequence.append(DICT[amino])
        self.amino.append(amino2id[DICT[amino]])
        self.coord.append(coord)
        self.site[pos]=self.length
        self.length+=1
    def process(self):
        self.amino=torch.LongTensor(self.amino)
        self.coord=torch.FloatTensor(self.coord)
        self.label=torch.zeros_like(self.amino)
        self.sequence=''.join(self.sequence)
    def extract(self,model,device,path):
        if len(self)>1024 or model is None:
            return
        f=lambda x:model(x.to(device).unsqueeze(0),[36])['representations'][36].squeeze(0).cpu()
        with torch.no_grad():
            feat=f(self.amino)
        torch.save(feat,f'{path}/feat/{self.name}_esm2.ts')
    def load_dssp(self,path):
        dssp=torch.Tensor(np.load(f'{path}/dssp/{self.name}.npy'))
        pos=np.load(f'{path}/dssp/{self.name}_pos.npy')
        self.dssp=torch.Tensor([
            -2.4492936e-16, -2.4492936e-16,
            1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ]).repeat(self.length,1)
        self.rsa=torch.zeros(self.length)
        for i in range(len(dssp)):
            self.dssp[self.site[pos[i]]]=dssp[i]
            if dssp[i][4]>0.15:
                self.rsa[i]=1
        self.rsa=self.rsa.bool()
    def load_feat(self,path):
        self.feat=torch.load(f'{path}/feat/{self.name}_esm2.ts')
    def load_adj(self,path,self_cycle=False):
        graph=torch.load(f'{path}/graph/{self.name}.graph')
        self.adj=graph['adj'].to_dense()
        self.edge=graph['edge'].to_dense()
        if not self_cycle:
            self.adj[range(len(self)),range(len(self))]=0
            self.edge[range(len(self)),range(len(self))]=0
    def get_adj(self,path,dseq=3,dr=10,dlong=5,k=10):
        graph=calcPROgraph(self.sequence,self.coord,dseq,dr,dlong,k)
        torch.save(graph,f'{path}/graph/{self.name}.graph')
    def update(self,pos,amino):
        if amino not in DICT.keys():
            return
        amino_id=amino2id[DICT[amino]]
        idx=self.site.get(pos,None)
        if idx is None:
            for i in self.site.keys():
                # print(i,pos)
                if i[:len(pos)]==pos:
                    idx=self.site.get(i)
                    if amino_id==self.amino[idx]:
                        self.label[idx]=1
                        return
        elif amino_id!=self.amino[idx]:
            for i in self.site.keys():
                if i[:len(pos)]==pos:
                    idx=self.site.get(i)
                    if amino_id==self.amino[idx]:
                        self.label[idx]=1
                        return
        else:
            self.label[idx]=1
    def __len__(self):
        return self.length
    def __getitem__(self,idx):
        return self.amino[idx],self.coord[idx],self.label[idx]
def collate_fn(batch):
    edges = [item['edge'] for item in batch]
    feats = [item['feat'] for item in batch]
    labels = torch.cat([item['label'] for item in batch],0)
    return feats,edges,labels

def extract_chain(root,pid,chain,force=False):
    if not force and os.path.exists(f'{root}/purePDB/{pid}_{chain}.pdb'):
        return True
    if not os.path.exists(f'{root}/PDB/{pid}.pdb'):
        retry=5
        pdb=None
        with rq.get(f'https://files.rcsb.org/download/{pid}.pdb') as f:
            if f.status_code==200:
                pdb=f.content
        while retry>0:
            try:
                with rq.get(f'https://files.rcsb.org/download/{pid}.pdb') as f:
                    if f.status_code==200:
                        pdb=f.content
                        break
            except:
                retry-=1
                continue
        if pdb is None:
            print(f'PDB file {pid} failed to download')
            return False
        with open(f'{root}/PDB/{pid}.pdb','wb') as f:
            f.write(pdb)
    lines=[]
    with open(f'{root}/PDB/{pid}.pdb','r') as f:
        for line in f:
            if line[:6]=='HEADER':
                lines.append(line)
            if line[:6].strip()=='TER' and line[21]==chain:
                lines.append(line)
                break
            feats=judge(line,None)
            if feats is not None and feats[1]==chain:
                lines.append(line)
    with open(f'{root}/purePDB/{pid}_{chain}.pdb','w') as f:
        for i in lines:
            f.write(i)
    return True
def process_chain(data,root,pid,model,device):
    get_dssp(pid,root)
    same={}
    with open(f'{root}/purePDB/{pid}.pdb','r') as f:
        for line in f:
            if line[:6]=='HEADER':
                date=line[50:59].strip()
                data.date=date
                continue
            feats=judge(line,'CA')
            if feats is None:
                continue
            amino,_,site,x,y,z=feats
            if len(amino)>3:
                if same.get(site) is None:
                    same[site]=amino[0]
                if same[site]!=amino[0]:
                    continue
                amino=amino[-3:]
            data.add(amino,site,[x,y,z])
    data.process()
    data.get_adj(root)
    data.extract(model,device,root)
    return data
def initial(file,root,model=None,device='cpu',from_native_pdb=True):
    df=pd.read_csv(f'{root}/{file}',header=0,index_col=0)
    prefix=df.index
    labels=df['Epitopes (resi_resn)']
    samples=[]
    with tqdm(prefix) as tbar:
        for i in tbar:
            tbar.set_postfix(protein=i)
            if from_native_pdb:
                state=extract_chain(root,i[:4],i[-1])
                if not state:
                    continue
            data=chain()
            p,c=i.split('_')
            data.protein_name=p
            data.chain_name=c
            data.name=f"{p}_{c}"
            process_chain(data,root,i,model,device)
            label=labels.loc[i].split(', ')
            for j in label:
                site,amino=j.split('_')
                data.update(site,amino)
            samples.append(data)
    with open(f'{root}/total.pkl','wb') as f:
        pk.dump(samples,f)

first part of the contents of dataset.py, a class for handling the protein data:

In [8]:
class PDB(Dataset):
    def __init__(
        self,mode='train',fold=-1,root='.',self_cycle=False
    ):
        self.root=root
        assert mode in ['train','val','test']
        if mode in ['train','val']:
            with open(f'{self.root}/train.pkl','rb') as f:
                self.samples=pk.load(f)
        else:
            with open(f'{self.root}/test.pkl','rb') as f:
                self.samples=pk.load(f)
        self.data=[]
        idx=np.load(f'{self.root}/cross-validation.npy')
        cv=10
        inter=len(idx)//cv
        ex=len(idx)%cv
        if mode=='train':
            order=[]
            for i in range(cv):
                if i==fold:
                    continue
                order+=list(idx[i*inter:(i+1)*inter+ex*(i==cv-1)])
        elif mode=='val':
            order=list(idx[fold*inter:(fold+1)*inter+ex*(fold==cv-1)])
        else:
            order=list(range(len(self.samples)))
        order.sort()
        tbar=tqdm(order)
        for i in tbar:
            tbar.set_postfix(chain=f'{self.samples[i].name}')
            self.samples[i].load_feat(self.root)
            self.samples[i].load_dssp(self.root)
            self.samples[i].load_adj(self.root,self_cycle)
            self.data.append(self.samples[i])
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):
        seq=self.data[idx]
        feat=torch.cat([seq.feat,seq.dssp],1)
        return {
            'feat':feat,
            'label':seq.label,
            'adj':seq.adj,
            'edge':seq.edge,
        }


second part of dataset.py, downloads and creates the necessary datasets (needs to be ran just once to create the needed files):

In [None]:
'''
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='./data/BCE_633', help='dataset path')
parser.add_argument('--gpu', type=int, default=0, help='gpu.')
args = parser.parse_args()
root = args.root
device='cpu' if args.gpu==-1 else f'cuda:{args.gpu}'
'''
root = '.'
device = 'cuda'

os.system(f'cd {root} && mkdir PDB purePDB feat dssp graph')
model,_=esm.pretrained.esm2_t36_3B_UR50D()
model=model.to(device)
model.eval()
train='DATA/total.csv'
initial(train,root,model,device)

third part of dataset.py:

In [9]:
root = '.'
device = 'cuda'
with open(f'{root}/total.pkl','rb') as f:
    dataset=pk.load(f)
    dates={i.name:i.date for i in dataset}
    filt_data=[]
    for i in dataset:
        if len(i)<1024 and i.label.sum()>0:
            filt_data.append(i)
    month={'JAN':1,'FEB':2,'MAR':3,'APR':4,'MAY':5,'JUN':6,'JUL':7,'AUG':8,'SEP':9,'OCT':10,'NOV':11,'DEC':12}
    trainset,valset,testset=[],[],[]
    D,M,Y=[],[],[]
    test=20210401
    dates_=[]
    for i in filt_data:
        d,m,y=dates[i.name].split('-')
        d,m,y=int(d),month[m],int(y)
        if y<23:
            y+=2000
        else:
            y+=1900
        date=y*10000+m*100+d
        if date<test:
            dates_.append(date)
            trainset.append(i)
        else:
            testset.append(i)
    with open(f'{root}/train.pkl','wb') as f:
        pk.dump(trainset,f)
    with open(f'{root}/test.pkl','wb') as f:
        pk.dump(testset,f)
    idx=np.array(dates_).argsort()
    np.save(f'{root}/cross-validation.npy',idx)

seed everything:

In [10]:
def seed_everything(seed=2022):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(2022)

set parameters:

In [11]:
gpu=0
fold=-1
lr=1e-6
batch=4
epochs=300
root='.'
log_name=f'BCE_633_GraphBepi'

load data:

In [12]:
trainset=PDB(mode='train',fold=fold,root=root)
valset=PDB(mode='val',fold=fold,root=root)
testset=PDB(mode='test',root=root)
train_loader=DataLoader(trainset, batch, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_loader=DataLoader(valset, batch, shuffle=False, collate_fn=collate_fn)
test_loader=DataLoader(testset, batch, shuffle=False, collate_fn=collate_fn)

100%|███████████████████████████████████████████████████████████████████| 577/577 [00:11<00:00, 49.40it/s, chain=3lh2_V]
0it [00:00, ?it/s]
100%|█████████████████████████████████████████████████████████████████████| 56/56 [00:01<00:00, 40.86it/s, chain=7ue9_C]


create models:

In [13]:
device='cpu' if gpu==-1 else f'cuda:{gpu}'
metrics=METRICS(device)
es=EarlyStopping('val_AUPRC',patience=40,mode='max')
mc=ModelCheckpoint(
    f'./model/{log_name}/',f'model_{fold}',
    'val_AUPRC',
    mode='max',
    save_weights_only=True, 
)
logger = TensorBoardLogger(
    './log', 
    name=f'{log_name}_{fold}'
)
cb=[mc,es]
trainer = pl.Trainer( 
    max_epochs=epochs, callbacks=cb,
    logger=logger,check_val_every_n_epoch=1,
)
model=GraphBepi(
    feat_dim=2560,                     # esm2 representation dim
    hidden_dim=256,                    # hidden representation dim
    exfeat_dim=13,                     # dssp feature dim
    edge_dim=51,                       # edge feature dim
    augment_eps=0.05,                  # random noise rate
    dropout=0.2,
    lr=lr,                             # learning rate
    metrics=metrics,                   # an implement to compute performance
    result_path=f'./model/{log_name}', # path to save temporary result file of testset
)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


train:

In [None]:
trainer.fit(model, train_loader, test_loader)


  | Name        | Type       | Params
-------------------------------------------
0 | loss_fn     | BCELoss    | 0     
1 | W_v         | Linear     | 655 K 
2 | W_u1        | AE         | 69.9 K
3 | edge_linear | Sequential | 3.3 K 
4 | gat         | EGAT       | 8.5 M 
5 | lstm1       | LSTM       | 1.2 M 
6 | lstm2       | LSTM       | 1.2 M 
7 | mlp         | Sequential | 262 K 
-------------------------------------------
11.9 M    Trainable params
0         Non-trainable params
11.9 M    Total params
47.536    Total estimated model params size (MB)


Epoch 0:  91%|██████▍| 144/158 [32:13<03:07, 13.42s/it, loss=0.384, v_num=1, train_auc_step=0.563, train_prc_step=0.132]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                | 0/14 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                   | 0/14 [00:00<?, ?it/s][A
Validation DataLoader 0:   7%|████▏                                                      | 1/14 [00:00<00:09,  1.44it/s][A
Epoch 0:  92%|██████▍| 145/158 [32:13<02:53, 13.34s/it, loss=0.384, v_num=1, train_auc_step=0.563, train_prc_step=0.132][A
Validation DataLoader 0:  14%|████████▍                                                  | 2/14 [00:02<00:15,  1.26s/it][A
Epoch 0:  92%|██████▍| 146/158 [32:15<02:39, 13.26s/it, loss=0.384, v_num=1, train_auc_step=0.563, train_prc_step=0.132][A
Validation DataLoader 0:  21%|████████████▋                                              | 3/14 [00:0

Epoch 4:  92%|▉| 146/158 [31:40<02:36, 13.02s/it, loss=0.374, v_num=1, train_auc_step=0.737, train_prc_step=0.154, val_l[A
Validation DataLoader 0:  21%|████████████▋                                              | 3/14 [00:05<00:21,  1.99s/it][A
Epoch 4:  93%|▉| 147/158 [31:42<02:22, 12.94s/it, loss=0.374, v_num=1, train_auc_step=0.737, train_prc_step=0.154, val_l[A
Validation DataLoader 0:  29%|████████████████▊                                          | 4/14 [00:07<00:19,  1.99s/it][A
Epoch 4:  94%|▉| 148/158 [31:44<02:08, 12.87s/it, loss=0.374, v_num=1, train_auc_step=0.737, train_prc_step=0.154, val_l[A
Validation DataLoader 0:  36%|█████████████████████                                      | 5/14 [00:08<00:14,  1.57s/it][A
Epoch 4:  94%|▉| 149/158 [31:45<01:55, 12.79s/it, loss=0.374, v_num=1, train_auc_step=0.737, train_prc_step=0.154, val_l[A
Validation DataLoader 0:  43%|█████████████████████████▎                                 | 6/14 [00:10<00:13,  1.68s/it][A
Epoch 4:

Epoch 8:  94%|▉| 149/158 [32:54<01:59, 13.25s/it, loss=0.345, v_num=1, train_auc_step=0.622, train_prc_step=0.224, val_l[A
Validation DataLoader 0:  43%|█████████████████████████▎                                 | 6/14 [00:11<00:14,  1.85s/it][A
Epoch 8:  95%|▉| 150/158 [32:56<01:45, 13.17s/it, loss=0.345, v_num=1, train_auc_step=0.622, train_prc_step=0.224, val_l[A
Validation DataLoader 0:  50%|█████████████████████████████▌                             | 7/14 [00:13<00:12,  1.84s/it][A
Epoch 8:  96%|▉| 151/158 [32:57<01:31, 13.10s/it, loss=0.345, v_num=1, train_auc_step=0.622, train_prc_step=0.224, val_l[A
Validation DataLoader 0:  57%|█████████████████████████████████▋                         | 8/14 [00:16<00:14,  2.39s/it][A
Epoch 8:  96%|▉| 152/158 [33:01<01:18, 13.04s/it, loss=0.345, v_num=1, train_auc_step=0.622, train_prc_step=0.224, val_l[A
Validation DataLoader 0:  64%|█████████████████████████████████████▉                     | 9/14 [00:17<00:10,  2.03s/it][A
Epoch 8:

TBC