In [1]:
from tdc.multi_pred import DTI
import re
import torch
import numpy as np
from PyDRP.Data.features.targets import MinMaxScaling
from PyDRP.Data.features.drugs import GraphCreator
from PyDRP.Data.features.proteins import BertProteinFeaturizer
from PyDRP.Data import DTIDatasetManager, TDCDTIWrapper
from PyDRP.Data.utils import TorchProteinsGraphsDataset
from PyDRP.Models.PairsNetwork import PairsNetwork
from PyDRP.Models.encoders.drugs import GATmannEncoder
from PyDRP.Models.encoders.cells import GeneExpEncoder
from PyDRP.Models.decoders import FCDecoder,  NonlinearDotDecoder
from PyDRP.Models.NNlayers import AttnDropout
from  PyDRP.Models.NNlayers import TransGAT, GatedGNNRes
import pandas as pd
import torch_geometric
from torch import nn
from torch_geometric import nn as gnn
import torchmetrics
from PyDRP.Models.metrics import ElementwiseMetric
from pprint import pprint

In [2]:
# TDC BindingDB_Kd drug–target interaction data, with these components:
# - targets min-max scaled,
# - drugs turned into graphs by GraphCreator,
# - proteins featurized by BertProteinFeaturizer (loads Rostlab/prot_bert, see output below).
# Partitions are keyed on PROTEIN_ID (presumably a protein-grouped split — confirm).
dti_source = TDCDTIWrapper(DTI(name = 'BindingDB_Kd'))
manager = DTIDatasetManager(dti_source,
                            MinMaxScaling(),
                            GraphCreator(),
                            BertProteinFeaturizer(),
                            partition_column="PROTEIN_ID")

Found local copy...
Loading...
Done!
Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Per-entity feature lookups shared by the train and evaluation datasets.
protein_dict, drug_dict = manager.get_proteins(), manager.get_drugs()

PARTITION = 0  # which precomputed split to use
# NOTE(review): assumes get_partition returns (train, test, val) in this order — confirm.
train, test, val = manager.get_partition(PARTITION)

In [4]:
# Monitoring during training uses the *validation* split. It is bound to the
# name `test_dataset` (and later `test_dataloader`) so downstream cells keep working.
train_dataset, test_dataset = (TorchProteinsGraphsDataset(split, drug_dict, protein_dict)
                               for split in (train, val))

In [5]:
# Settings shared by both loaders; PyG's Batch.from_data_list merges the
# per-pair graph samples into one batched graph.
LOADER_KWARGS = dict(batch_size = 128,
                     collate_fn = torch_geometric.data.Batch.from_data_list,
                     num_workers = 16)
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **LOADER_KWARGS)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **LOADER_KWARGS)

In [6]:
class ProteinConvPooling(nn.Module):
    """Attention-pool a per-position protein embedding into one vector per protein.

    Two parallel convolutional branches read the same input. ``conv_attn``
    yields one score per sequence position, and the scores are softmax-normalised
    over the sequence axis. ``conv_seq`` yields an ``output_dim``-wide value per
    position. The result is the attention-weighted sum of the values.

    Args:
        init_dim: channel size of the input features (last axis of ``x``).
        hidden_dim: channel size of the intermediate grouped convolution.
        output_dim: size of the pooled output vector.
        n_groups: groups of the first convolution in each branch; must divide
            both ``init_dim`` and ``hidden_dim`` (a Conv1d requirement).
        p_dropout_1: dropout after the first conv of the attention branch.
        p_dropout_2: dropout after the first conv of the value branch.
        kernel_size: kernel width of the first conv in each branch (default 6,
            the value previously hard-coded).
    """
    def __init__(self,
                 init_dim = 1024,
                 hidden_dim = 512,
                 output_dim = 256,
                 n_groups = 8,
                 p_dropout_1 = 0.4,
                 p_dropout_2 = 0.4,
                 kernel_size = 6):
        super().__init__()
        # Attention branch: a single score channel per position.
        self.conv_attn = self._branch(init_dim, hidden_dim, 1, n_groups, p_dropout_1, kernel_size)
        # Value branch: output_dim channels per position.
        self.conv_seq = self._branch(init_dim, hidden_dim, output_dim, n_groups, p_dropout_2, kernel_size)

    @staticmethod
    def _branch(in_dim, hidden_dim, out_dim, n_groups, p_dropout, kernel_size):
        """Build grouped 'same'-padded conv -> ReLU -> Dropout -> pointwise conv."""
        return nn.Sequential(
            nn.Conv1d(in_channels = in_dim,
                      out_channels = hidden_dim,
                      kernel_size = kernel_size,
                      padding = "same",
                      groups = n_groups),
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Conv1d(in_channels = hidden_dim, out_channels = out_dim, kernel_size = 1, stride = 1),
        )

    def forward(self, x):
        """Pool ``x`` of shape (batch, seq_len, init_dim) to (batch, output_dim).

        Only the squeezed axis is named, so a batch of size 1 keeps its batch
        dimension. A bare ``.squeeze()`` would drop it and crash the transpose.

        NOTE(review): every position gets attention weight, including any
        padding positions the featurizer may add; confirm whether a mask is needed.
        """
        x_t = x.transpose(1, 2)                                  # (B, init_dim, L) for Conv1d
        attn = self.conv_attn(x_t).squeeze(1).softmax(-1)        # (B, L)
        values = self.conv_seq(x_t).transpose(1, 2)              # (B, L, output_dim)
        return (attn.unsqueeze(-1) * values).sum(dim=1)          # (B, output_dim)

In [7]:
class GNNAttnDrugPooling(nn.Module):
    """Graph readout that reduces per-node drug embeddings to one vector per graph.

    This wraps ``torch_geometric.nn.GlobalAttention``. A gate MLP scores every
    node, the scores are softmax-normalised within each graph, and the outputs
    of a second (value) MLP are summed with those weights.

    Args:
        embed_dim: width of the incoming node embeddings.
        hidden_dim: hidden width of both MLPs.
        output_embed_dim: width of the pooled per-graph vector.
        p_dropout_attn: dropout inside both MLPs.
        p_dropout_nodes: rate for the project's ``AttnDropout`` applied to the
            gate scores (its exact semantics are defined elsewhere).
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 embed_dim,
                 hidden_dim,
                 output_embed_dim,
                 p_dropout_attn = 0.0,
                 p_dropout_nodes = 0.0,
                 **kwargs):
        super().__init__()
        # Built before the value MLP so parameter initialisation order is unchanged.
        gate_mlp = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                 nn.ReLU(),
                                 nn.Dropout(p_dropout_attn),
                                 nn.Linear(hidden_dim, 1),
                                 AttnDropout(p_dropout_nodes))
        value_mlp = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                  nn.ReLU(),
                                  nn.Dropout(p_dropout_attn),
                                  nn.Linear(hidden_dim, output_embed_dim))
        self.pool = gnn.GlobalAttention(gate_mlp, value_mlp)

    def set_cold(self):
        """Freeze every parameter of this module (stop gradient updates)."""
        self.requires_grad_(False)

    def forward(self, x, batch):
        """Pool node features ``x`` per graph using the node-to-graph index ``batch``."""
        # Positional call: the keyword name for the index differs across PyG versions.
        return self.pool(x, batch)
class GNNEncoderDecoder(PairsNetwork):
    """Protein–drug pair model with an encoder and an adapter per modality plus a shared decoder.

    Protein branch: ``protein_adapter(protein_encoder(data["protein"]))``.
    Drug branch: ``drug_encoder`` runs on the graph fields ``x``,
    ``edge_index``, ``edge_attr`` and ``batch``, then ``drug_adapter`` pools
    the result using ``batch``. The decoder receives
    ``(protein_embedding, drug_embedding)``. Extra ``**kwargs`` are ignored.
    """
    def __init__(self,
                 protein_encoder,
                 drug_encoder,
                 protein_adapter,
                 drug_adapter,
                 decoder,
                 **kwargs):
        super().__init__()
        # Registration order is kept so parameters()/state_dict ordering is unchanged.
        self.protein_encoder = protein_encoder
        self.drug_encoder = drug_encoder
        self.protein_adapter = protein_adapter
        self.drug_adapter = drug_adapter
        self.decoder = decoder

    def forward(self, data, *args, **kwargs):
        """Score one batch of (protein, drug) pairs; extra args are ignored."""
        protein_embedding = self.protein_adapter(self.protein_encoder(data["protein"]))
        node_batch = data["batch"]
        node_repr = self.drug_encoder(data["x"], data["edge_index"], data["edge_attr"], node_batch)
        drug_embedding = self.drug_adapter(node_repr, node_batch)
        return self.decoder(protein_embedding, drug_embedding)

In [8]:
class GTEncoder(nn.Module):
    """Drug graph encoder: a GATConv projection followed by stacked TransGAT layers.

    The TransGAT layers are wrapped by the project's ``GatedGNNRes``.

    Args:
        embed_dim: node embedding width used by every layer.
        num_heads: attention heads passed to each TransGAT layer. This was
            previously accepted but ignored, because 1 was hard-coded.
        node_dim: width of the raw node features fed to ``init_gat`` (79 here,
            presumably GraphCreator's atom featurization — TODO confirm).
        edge_dim: width of the edge features.
        n_layers: number of stacked TransGAT layers.

    NOTE(review): confirm that TransGAT keeps its output width at
    ``output_dim`` when ``num_heads > 1``. Otherwise stacked layers would mismatch.
    """
    def __init__(self, embed_dim, num_heads=1, node_dim=79, edge_dim=10, n_layers=2):
        super().__init__()
        self.init_gat = gnn.GATConv(node_dim, embed_dim, edge_dim = edge_dim)
        self.layers = GatedGNNRes(TransGAT,
                                  {"input_dim": embed_dim,
                                   "output_dim": embed_dim,
                                   "edge_dim": edge_dim,
                                   "num_heads": num_heads,},
                                  n_layers = n_layers)

    def forward(self, x, edge_index, edge_attr, batch):
        """Project node features with ``init_gat``, then run the layer stack.

        ``batch`` is passed to the stack unchanged. It is presumably the
        node-to-graph index of a PyG batch (confirm).
        """
        x = self.init_gat(x, edge_index, edge_attr)
        return self.layers(x, edge_index, edge_attr, batch)

In [9]:
SEED = 0
# Seeds weight initialisation and (later) the shuffling of train_dataloader.
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

model = GNNEncoderDecoder(protein_encoder = nn.Identity(),
                          drug_encoder = GTEncoder(embed_dim = 256),
                          protein_adapter = ProteinConvPooling(hidden_dim = 512, output_dim=256),
                          drug_adapter = GNNAttnDrugPooling(embed_dim = 256, hidden_dim = 1024, output_embed_dim=256, p_dropout_attn=0.2),
                          decoder = FCDecoder(256+256, 2048, p_dropout_2 = 0.3))
# The torch.optim docs recommend moving the model before constructing its optimizer.
model.to(device)
optim = torch.optim.Adam(model.parameters(), 0.001)
# Scale the LR by 0.75 once the stepped metric stops improving for `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience = 2, factor = 0.75)

In [None]:
from tqdm.notebook import trange, tqdm

N_EPOCHS = 50

mse = nn.MSELoss()
metrics = torchmetrics.MetricCollection([torchmetrics.MeanSquaredError(), torchmetrics.PearsonCorrCoef()]).to(device)
# Project metric, fed (pred, target, DRUG_ID, PROTEIN_ID). With average="drugs" it
# presumably averages a per-drug score — TODO confirm (it reports nan in the outputs).
elm = ElementwiseMetric(average="drugs")


def _flat(t):
    """Flatten to 1-D. Unlike ``.squeeze()``, a batch of one stays shape (1,)."""
    return t.reshape(-1)


def run_epoch(loader, train, suffix):
    """Make one pass over ``loader`` and return its metrics.

    With ``train=True`` the model is in training mode and ``optim`` takes one
    MSE step per batch. Otherwise the pass runs in eval mode with gradients
    disabled. Keys are the MetricCollection names plus ``R_average``, each
    suffixed with ``_{suffix}``.
    """
    model.train(train)
    metrics.reset()
    elm.reset()
    with torch.set_grad_enabled(train):
        # Iterating through tqdm counts every batch, including the remainder.
        for b in tqdm(loader, desc=suffix, leave=False):
            b = b.to(device)
            y_pred = _flat(model(b))
            y_true = _flat(b["y"])
            if train:
                optim.zero_grad()
                mse(y_pred, y_true).backward()
                optim.step()
            with torch.no_grad():
                metrics.update(y_pred, y_true)
                elm.update(y_pred, y_true, b["DRUG_ID"], b["PROTEIN_ID"])
    results = {f"{name}_{suffix}": value.cpu().item() for name, value in metrics.compute().items()}
    results[f"R_average_{suffix}"] = elm.compute().item()
    return results


for epoch in range(N_EPOCHS):
    metric_dict_train = run_epoch(train_dataloader, True, "train")
    metric_dict_test = run_epoch(test_dataloader, False, "test")
    # Step the plateau scheduler on the held-out MSE (test_dataloader wraps the
    # validation split), not on the training MSE, which keeps falling as the model fits.
    scheduler.step(metric_dict_test["MeanSquaredError_test"])
    metric_dict = {**metric_dict_test, **metric_dict_train}
    print(epoch)
    pprint(metric_dict)

  0%|          | 0/381 [00:00<?, ?it/s]

0
{'MeanSquaredError_test': 0.00040473032277077436,
 'MeanSquaredError_train': 0.04489264637231827,
 'PearsonCorrCoef_test': -0.008448257111012936,
 'PearsonCorrCoef_train': 0.005974019877612591,
 'R_average_test': nan,
 'R_average_train': nan}


  0%|          | 0/381 [00:00<?, ?it/s]

1
{'MeanSquaredError_test': 0.0017297266749665141,
 'MeanSquaredError_train': 0.008432043716311455,
 'PearsonCorrCoef_test': 0.3168678283691406,
 'PearsonCorrCoef_train': 0.09814435988664627,
 'R_average_test': nan,
 'R_average_train': nan}


  0%|          | 0/381 [00:00<?, ?it/s]