In [None]:
# Run this in a Google Colab cell to install the PyTorch Geometric build that matches the installed torch and CUDA versions.
import torch

def format_pytorch_version(version):
    """Return the release part of a torch version string.

    Any local build tag after '+' (e.g. '+cu118') is dropped:
    '2.1.0+cu118' -> '2.1.0'.
    """
    release, _separator, _local_tag = version.partition('+')
    return release


TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
    """Return the PyG wheel tag for a CUDA version string.

    '11.8' -> 'cu118'. CPU-only torch builds report torch.version.cuda as
    None; PyG publishes those wheels under the 'cpu' tag, so None maps to
    'cpu' instead of crashing on None.replace.
    """
    if version is None:
        return 'cpu'
    return 'cu' + version.replace('.', '')


CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# IPython shell escapes (notebook-only syntax). {TORCH} and {CUDA} are expanded
# from the TORCH and CUDA Python variables, so the wheel index URL passed to
# -f is selected by the installed torch/CUDA versions.
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric


In [None]:
import torch.nn as nn
import os
import copy
import random
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv,global_max_pool,global_add_pool
from torch_geometric.nn.norm import BatchNorm
import torch.nn.functional as F
import numpy as np
from scipy.stats import pearsonr
from torch.nn import Sequential, Linear, ReLU

# ---- Configuration ----
indexFullSet = "INDEX_general_PL_data.2020" # set these three to the paths of your PDBbind index files
indexRefinedSet = "INDEX_refined_data.2020"
indexCoreSet = "CoreSet.dat"

dataDir = "" # change to path containing graphs with names in the format {pdbID}_graph.pt
modelPath = "graphModel.pt" # path where the best model's state_dict is saved
# Tab-separated file read by filterKinases (PDB ID in column 0, description in
# column 1); only read when kinaseFilter is True.
pdbInfoPath = ""
# change to True for training without kinases in the training data; the
# kinase complexes of the refined set then form the test set instead of the core set
kinaseFilter = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seeds Python's `random` only (used for the train/validation shuffle); torch's
# global RNG is seeded separately inside GINENet.__init__, numpy is not seeded.
random_seed = 46
random.seed(random_seed)



def parseIndexFile(indexFilePath):
    """Read a PDBbind-style index file, keeping only Kd/Ki entries.

    Each line not starting with '#' is split on whitespace: column 0 is the
    PDB ID, column 3 is parsed as a float (the logK value) and column 4 must
    start with "Kd=" or "Ki=" for the entry to be kept. Blank lines and lines
    with fewer than five columns are skipped instead of raising IndexError.

    Returns:
        (pdbIDs, logKvalues): the kept PDB IDs in file order, and a dict
        mapping each kept ID to its column-3 value.
    """
    pdbIDs = []
    logKvalues = {}
    with open(indexFilePath, "r") as index_file:
        for line in index_file:
            if line.startswith('#'):
                continue
            fields = line.split()  # split once instead of once per column access
            if len(fields) < 5:
                continue
            # Remove this Kd/Ki check to train on the filtered general set.
            if not fields[4].startswith(("Kd=", "Ki=")):
                continue
            pdbIDs.append(fields[0])
            logKvalues[fields[0]] = float(fields[3])
    return pdbIDs, logKvalues

# Parse the three index files into (ID list, ID -> logK dict) pairs.
refinedIndex,logK = parseIndexFile(indexRefinedSet)
coreIndex,logKcore = parseIndexFile(indexCoreSet)
fullIndex,logKFull = parseIndexFile(indexFullSet)
# NOTE(review): graphData and logKvalues are not referenced anywhere else in
# this file; they look like leftovers.
graphData = []
logKvalues = []

def filterKinases(PDBInfo):
    """Return the PDB IDs whose description mentions "kinase".

    PDBInfo is the path to a tab-separated text file: column 0 is read as the
    PDB ID and column 1 as a free-text description, matched case-insensitively
    against "kinase". Each matching ID is returned once, in order of first
    appearance. Lines with fewer than two tab-separated columns (e.g. blank
    lines) are skipped instead of raising IndexError.
    """
    kinases = []
    seen = set()  # O(1) duplicate check; the list keeps first-seen order
    with open(PDBInfo) as file:
        for line in file:
            fields = line.split("\t")
            if len(fields) < 2:
                continue
            pdbID, pdbinfo = fields[0], fields[1]
            if "kinase" in pdbinfo.lower() and pdbID not in seen:
                seen.add(pdbID)
                kinases.append(pdbID)
    return kinases


class PDBData(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory list of samples.

    Items are returned as stored, or passed through ``transform`` first when
    one was supplied.
    """

    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        sample = self.data_list[index]
        return sample if self.transform is None else self.transform(sample)


class GINENet(torch.nn.Module):
    """Two-layer GINEConv graph network regressing one value per graph.

    Integer-coded node and edge features are embedded, concatenated with the
    raw float columns, passed through two GINEConv layers, pooled per graph
    (sum and max) and fed to an MLP with a single output.
    """
    def __init__(self):
        super(GINENet, self).__init__()
        # Seeds torch's global RNG so weight initialisation is reproducible.
        # This is a global side effect, and it makes the order in which the
        # layers below are created part of the initialisation.
        torch.manual_seed(12345)
        dim = 64
        # Node feature embeddings: nn.Embedding(num_categories, width).
        # Widths 5+6+1+4+3+3+5+6+1+14, plus the raw column 9 used in forward,
        # sum to 49 = conv1's input size.
        # NOTE(review): every stored integer code must be < num_categories —
        # confirm against the graph-construction code.
        self.atomicNumEmb = nn.Embedding(10,5)
        self.formalChargeEmb = nn.Embedding(12,6)
        self.aromaticEmb = nn.Embedding(2,1)
        self.valenceEmb = nn.Embedding(9,4)
        self.hybridizationEmb = nn.Embedding(6,3)
        self.chiralityEmb = nn.Embedding(6,3)
        self.numHEmb = nn.Embedding(10,5)
        self.degreeEmb = nn.Embedding(12,6)
        self.typeEmb = nn.Embedding(2,1)
        self.residueEmb = nn.Embedding(31, 14)

        # Edge feature embeddings: widths 3+2+3+1+1, plus the raw distance
        # column, sum to 11 = edge_dim of both GINEConv layers.
        self.bondTypeEmb = nn.Embedding(7,3)
        self.bondDirEmb = nn.Embedding(5,2)
        self.stereoEmb = nn.Embedding(6,3)
        self.inRingEmb = nn.Embedding(2,1)
        self.innerEmb = nn.Embedding(2,1)

        self.conv1 = GINEConv(Sequential(Linear(49, dim), BatchNorm(dim), ReLU(),
                                         Linear(dim, dim), ReLU()),
                                         edge_dim = 11)

        self.conv2 = GINEConv(Sequential(Linear(dim, dim), BatchNorm(dim), ReLU(),
                                         Linear(dim, dim), ReLU()),
                                         edge_dim = 11)

        self.norm1 = BatchNorm(dim)
        self.norm2 = BatchNorm(dim)

        # MLP head. Input: (conv1 || conv2) = 2*dim node features, pooled with
        # both sum and max -> 2*dim*2 features per graph.
        l1_size = 512
        input_size = 2*dim*2
        self.mlp = Sequential(Linear(input_size,l1_size),
                                     BatchNorm(l1_size),
                                     ReLU(),
                                     Linear(l1_size, l1_size),
                                     BatchNorm(l1_size),
                                     ReLU(),
                                     Linear(l1_size, l1_size),
                                     BatchNorm(l1_size),
                                     ReLU(),
                                     Linear(l1_size,1))



    def forward(self, x, edge_index, edge_attr, batch):
        """Return a (num_graphs, 1) tensor of predictions.

        x: node feature matrix; columns 0-8 and 10 are integer codes that are
        embedded, column 9 (named mass here) is used as a raw float.
        edge_attr: edge feature matrix; column 0 (named dists here) is used
        as a raw float, columns 1-5 are embedded.
        batch: graph index of every node, as consumed by the PyG pooling ops.
        """

        atomicNums = self.atomicNumEmb(x[:,0].long())
        chirality = self.chiralityEmb(x[:,1].long())
        formalCharge = self.formalChargeEmb(x[:,2].long())
        hybridizations = self.hybridizationEmb(x[:,4].long())
        # NOTE(review): numHs and valences both read column 5, and column 3 is
        # never read. One of them probably should read column 3. Confirm
        # against the graph-construction code before changing this; a change
        # alters the model's input, so saved models would need retraining.
        numHs = self.numHEmb(x[:,5].long())
        valences = self.valenceEmb(x[:,5].long())
        degrees = self.degreeEmb(x[:,6].long())
        aromatics = self.aromaticEmb(x[:,7].long())
        types = self.typeEmb(x[:,8].long())
        mass = x[:,9].view(-1,1)
        residues = self.residueEmb(x[:,10].long())


        dists = edge_attr[:,0].view(-1,1)
        bondTypes = self.bondTypeEmb(edge_attr[:,1].long())
        bondDirs = self.bondDirEmb(edge_attr[:,2].long())
        stereo = self.stereoEmb(edge_attr[:,3].long())
        inRing = self.inRingEmb(edge_attr[:,4].long())
        inner = self.innerEmb(edge_attr[:,5].long())

        # 49 node features and 11 edge features (see the sizes in __init__).
        x = torch.cat((atomicNums,chirality,formalCharge,hybridizations,numHs,valences,degrees,aromatics,types,mass, residues),dim=1)
        edge_attr = torch.cat((dists, bondTypes, bondDirs, stereo, inRing, inner), dim=1)
        x1 = self.conv1(x, edge_index, edge_attr)
        x1 = self.norm1(x1)
        x2 = self.conv2(x1, edge_index, edge_attr)
        x2 = self.norm2(x2)

        # Concatenate both layers' node outputs (2*dim), then sum- and
        # max-pool per graph (4*dim features per graph).
        x = torch.cat((x1,x2), dim=1)
        x = torch.cat((global_add_pool(x,batch),global_max_pool(x,batch)),dim=1)
        x  = self.mlp(x)
        return x


# Build the training and test graph lists.
#   kinaseFilter False: core-set complexes form the test set; every other
#       complex of the Kd/Ki-filtered general index is used for training.
#   kinaseFilter True: refined-set complexes whose description mentions
#       "kinase" form the test set; the remaining refined-set complexes that
#       are not in the core set are used for training.
# Complexes without a "{pdbID}_graph.pt" file in dataDir are skipped.
trainData = []
testData = []


def _loadLabelledGraph(pdbID, affinities):
    """Load ``{dataDir}/{pdbID}_graph.pt`` and attach its label.

    The label ``affinities[pdbID]`` is stored in ``graph.y`` as a (1, 1)
    tensor. Returns None when the graph file does not exist.
    """
    graphPath = f"{dataDir}/{pdbID}_graph.pt"
    if not os.path.isfile(graphPath):
        return None
    # torch.load unpickles the file, so only load graph files you trust.
    with open(graphPath, "rb") as file:
        graph = torch.load(file)
    graph.y = torch.tensor([affinities[pdbID]]).unsqueeze(1)
    return graph


# Sets give O(1) membership tests inside the loops below.
coreIDs = set(coreIndex)

if not kinaseFilter:
    trainPool = set(fullIndex)  # change to set(refinedIndex) to train on the refined set only
    for pdbID in fullIndex:
        graph = _loadLabelledGraph(pdbID, logKFull)
        if graph is None:
            continue
        if pdbID in coreIDs:
            testData.append(graph)
        elif pdbID in trainPool:
            trainData.append(graph)

else:
    kinases = filterKinases(pdbInfoPath)
    kinaseIDs = set(kinases)
    for pdbID in refinedIndex:
        # Label from the refined-set dict parsed together with refinedIndex,
        # so every ID iterated here is guaranteed to have a value (the
        # general-set dict is not guaranteed to contain every refined ID).
        graph = _loadLabelledGraph(pdbID, logK)
        if graph is None:
            continue
        if pdbID in kinaseIDs:
            testData.append(graph)
        elif pdbID not in coreIDs:
            trainData.append(graph)






# Hold out the first tenth of the shuffled training graphs for validation.
random.shuffle(trainData)
splitSize = len(trainData) // 10
validData, trainData = trainData[:splitSize], trainData[splitSize:]



# Wrap each split in a dataset.
trainSet = PDBData(data_list=trainData)
validSet = PDBData(data_list=validData)
testSet = PDBData(data_list=testData)

# Model and optimiser (GINENet seeds torch's RNG when constructed).
model = GINENet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Evaluation loaders use batch_size=1 without shuffling, so predictions line
# up one-to-one with the label lists built below.
trainloader = DataLoader(trainSet, batch_size=64, shuffle=True)
testloader = DataLoader(testSet, batch_size=1, shuffle=False)
validloader = DataLoader(validSet, batch_size=1, shuffle=False)

# Ground-truth labels in loader order, used for the Pearson correlations.
validValues = [batch.y.item() for batch in validloader]
testValues = [batch.y.item() for batch in testloader]
def train():
    """Run one optimisation epoch over trainloader and print the mean
    per-sample training MSE."""
    model.train()
    running_loss = 0.0
    seen = 0
    for data in trainloader:
        if data.num_graphs < 2:
            # BatchNorm in training mode needs more than one sample, so a
            # trailing one-graph batch would raise. len() on the squeezed
            # 0-d label tensor would also fail. Skip such a batch.
            continue
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.mse_loss(out.view(-1), data.y.view(-1), reduction='mean')
        # loss is a per-batch mean; weight it by the batch size so the epoch
        # figure is a per-sample mean.
        running_loss += loss.item() * data.num_graphs
        seen += data.num_graphs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(running_loss / max(seen, 1))

def test(loader):
    """Evaluate the model on ``loader`` without updating it.

    Returns:
        (predictions, loss): the model output tensor of every batch, in
        loader order, and the mean per-sample squared error over
        ``loader.dataset``.
    """
    model.eval()
    predictions = []
    all_loss = 0.0
    # Evaluation needs no gradients. Without no_grad, every stored output
    # would keep its autograd graph alive.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = F.mse_loss(out.view(-1), data.y.view(-1), reduction="mean")
            # Weight each batch mean by its size so the result is a
            # per-sample mean for any batch size, not only batch_size=1.
            all_loss += loss.item() * data.num_graphs
            predictions.append(out)
    loss = all_loss / len(loader.dataset)

    return predictions, loss

print("start training")
# Best epoch so far, selected on validation RMSE:
# (test pearson, test RMSE, valid RMSE, valid pearson, epoch)
best_valid = (0,0,np.inf,0,0)
# Start from the initial weights so there is always a state to save, even if
# the validation loss never improves (e.g. it is NaN every epoch).
best_model_state = copy.deepcopy(model.state_dict())
for epoch in range(1, 1000):
    train()
    valid_preds, valid_loss = test(validloader)
    test_preds, test_loss = test(testloader)
    predictions = [float(t.item()) for t in test_preds]
    valid_preds = [float(t.item()) for t in valid_preds]
    pearson, pvalue = pearsonr(predictions, testValues)
    validpearson, pvalid = pearsonr(valid_preds, validValues)
    test_rmse = round(np.sqrt(test_loss), 3)
    valid_rmse = round(np.sqrt(valid_loss), 3)
    if valid_rmse < best_valid[2]:
        best_model_state = copy.deepcopy(model.state_dict())
        best_valid = (round(pearson,3), test_rmse, valid_rmse, round(validpearson,3), epoch)
    print(f" Epoch: {epoch} pearson: {round(pearson,3)} test loss: {test_rmse} valid loss: {valid_rmse} best: {best_valid}")

torch.save(best_model_state,modelPath)
