In [1]:
from tdc.multi_pred import DTI
import re
import torch
import numpy as np
from PyDRP.Data.features.targets import LogMinMaxScaling
from PyDRP.Data.features.drugs import GraphCreator
from PyDRP.Data.features.proteins import BertProteinFeaturizer
from PyDRP.Data import DTIDatasetManager, TDCDTIWrapper
from PyDRP.Data.utils import TorchProteinsGraphsDataset
from PyDRP.Models import GNNProteinDrugEncoderDecoder
from PyDRP.Models.encoders.drugs import GTEncoder, GNNAttnDrugPooling
from PyDRP.Models.encoders.proteins import ProteinConvPooling
from PyDRP.Models.decoders import FCDecoder,  NonlinearDotDecoder
from PyDRP.Models.NNlayers import AttnDropout
import pandas as pd
import torch_geometric
from torch import nn
from torch_geometric import nn as gnn
import torchmetrics
from PyDRP.Models.metrics import ElementwiseMetric
from pprint import pprint

In [2]:
# Assemble the dataset manager for the TDC "BindingDB_Kd" drug-target
# interaction dataset. The collaborators are created in the same order as the
# positional arguments they are passed in: the wrapped TDC dataset, a target
# transform (PyDRP.Data.features.targets), a drug featurizer
# (PyDRP.Data.features.drugs) and a protein featurizer
# (PyDRP.Data.features.proteins).
bindingdb_kd = DTI(name="BindingDB_Kd")
dti_source = TDCDTIWrapper(bindingdb_kd)
target_transform = LogMinMaxScaling()
drug_featurizer = GraphCreator()
protein_featurizer = BertProteinFeaturizer()

# NOTE(review): partition_column="PROTEIN_ID" presumably makes the splits
# disjoint by protein -- confirm against DTIDatasetManager.
manager = DTIDatasetManager(
    dti_source,
    target_transform,
    drug_featurizer,
    protein_featurizer,
    partition_column="PROTEIN_ID",
)

Found local copy...
Loading...
Done!
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Featurized proteins and drugs exposed by the manager; both are handed to the
# torch dataset wrapper in the next cell.
protein_dict = manager.get_proteins()
drug_dict = manager.get_drugs()

# Splits of partition 0, unpacked in the order (train, test, val).
# NOTE(review): the unpacking order is taken from this notebook as written --
# confirm it matches what DTIDatasetManager.get_partition actually returns.
partition_0 = manager.get_partition(0)
train, test, val = partition_0

In [4]:
# Wrap the splits as torch datasets, sharing the same drug and protein lookups.
# NOTE(review): `test_dataset` is built from the `val` split, and the `test`
# split is not used anywhere in the visible cells. If this is deliberate
# (monitoring on validation data), the variable and the "_test" metric names
# are misleading; if not, `test` should be passed here -- confirm intent.
train_dataset, test_dataset = (
    TorchProteinsGraphsDataset(split, drug_dict, protein_dict)
    for split in (train, val)
)

In [5]:
# Settings shared by both loaders: batches of 128 samples, collated into a
# torch_geometric Batch, loaded by 16 worker processes.
loader_kwargs = dict(
    batch_size=128,
    collate_fn=torch_geometric.data.Batch.from_data_list,
    num_workers=16,
)

# Only the training loader reshuffles its samples.
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

In [6]:
# Model: a graph-transformer drug encoder plus pooling adapters for drugs and
# proteins, feeding a fully connected decoder.
# - protein_encoder is nn.Identity(), so protein inputs reach the adapter
#   unchanged (presumably because BertProteinFeaturizer already produced the
#   embeddings -- TODO confirm).
# - The decoder input size is 256 + 256, the sum of the two adapter output
#   sizes (presumably the two embeddings are concatenated -- TODO confirm).
model = GNNProteinDrugEncoderDecoder(
    protein_encoder=nn.Identity(),
    drug_encoder=GTEncoder(embed_dim=256),
    protein_adapter=ProteinConvPooling(hidden_dim=512, output_dim=256),
    drug_adapter=GNNAttnDrugPooling(
        embed_dim=256,
        hidden_dim=1024,
        output_embed_dim=256,
        p_dropout_attn=0.2,
    ),
    decoder=FCDecoder(256 + 256, 2048, p_dropout_2=0.3),
)
# Alternative decoder, kept for reference:
# decoder = NonlinearDotDecoder(256, 1024, 64, p_dropout_1=0.3, p_dropout_2=0.3)

# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs (the original hard-coded "cuda" and failed without a GPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model first, then build the optimizer from the moved parameters.
model.to(device)
optim = torch.optim.Adam(model.parameters(), 0.001)
# Multiply the learning rate by 0.75 after more than 2 epochs without
# improvement in the monitored value.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=2, factor=0.75)

# Let cuDNN auto-tune its kernels.
torch.backends.cudnn.benchmark = True

In [7]:
from tqdm.notebook import trange, tqdm

# Train for two epochs. After each epoch, evaluate on `test_dataloader` and
# print one merged dict of "<metric>_train" / "<metric>_test" values.
model.to(device)  # harmless if the model already lives on `device`
mse = nn.MSELoss()
metrics = torchmetrics.MetricCollection(
    [torchmetrics.MeanSquaredError(), torchmetrics.PearsonCorrCoef()]
).to(device)
elm = ElementwiseMetric(average="lines")

for epoch in range(2):
    # ---- training pass ----
    model.train()
    elm.reset()
    metrics.reset()
    # Wrapping the loader advances the bar once per batch, so it ends at
    # 100%. The previous manual update(10) every tenth batch could never
    # reach the total when the batch count is not a multiple of 10.
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch}"):
        batch = batch.to(device)
        # NOTE(review): squeeze() with no dim also collapses a batch of size
        # one to a 0-d tensor -- confirm the metrics accept that, or squeeze
        # a specific dim once the shapes of the prediction and y are known.
        target = batch["y"].squeeze()
        pred = model(batch).squeeze()
        loss = mse(pred, target)
        # Clear gradients right before backward so nothing stale is
        # accumulated if a previous run of this cell was interrupted.
        optim.zero_grad()
        loss.backward()
        optim.step()
        with torch.no_grad():
            # Detach so the metric objects never hold on to the autograd graph.
            pred_detached = pred.detach()
            metrics.update(pred_detached, target)
            elm.update(pred_detached, target, batch["DRUG_ID"], batch["PROTEIN_ID"])
    metric_dict_train = {
        name + "_train": value.cpu().item()
        for name, value in metrics.compute().items()
    }
    metric_dict_train["R_average_train"] = elm.compute().item()
    # NOTE(review): the plateau scheduler is driven by the *training* MSE;
    # ReduceLROnPlateau is usually fed a held-out metric -- confirm intent.
    scheduler.step(metric_dict_train["MeanSquaredError_train"])

    # ---- evaluation pass ----
    model.eval()
    metrics.reset()
    elm.reset()
    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            target = batch["y"].squeeze()
            pred = model(batch).squeeze()
            metrics.update(pred, target)
            elm.update(pred, target, batch["DRUG_ID"], batch["PROTEIN_ID"])
    metric_dict_test = {
        name + "_test": value.cpu().item()
        for name, value in metrics.compute().items()
    }
    # Store the evaluation value in the evaluation dict (it was previously
    # written into metric_dict_train); the merged dict below is unchanged.
    metric_dict_test["R_average_test"] = elm.compute().item()

    metric_dict = {**metric_dict_test, **metric_dict_train}
    print(epoch)
    pprint(metric_dict)

  0%|          | 0/381 [00:00<?, ?it/s]

0
{'MeanSquaredError_test': 0.0889548659324646,
 'MeanSquaredError_train': 0.15374061465263367,
 'PearsonCorrCoef_test': 0.5811170339584351,
 'PearsonCorrCoef_train': 0.2567463219165802,
 'R_average_test': 0.14781658351421356,
 'R_average_train': 0.06337512284517288}


  0%|          | 0/381 [00:00<?, ?it/s]

1
{'MeanSquaredError_test': 0.09694498032331467,
 'MeanSquaredError_train': 0.1000029668211937,
 'PearsonCorrCoef_test': 0.5284720659255981,
 'PearsonCorrCoef_train': 0.5473455786705017,
 'R_average_test': 0.14151397347450256,
 'R_average_train': 0.14765235781669617}
