In [1]:
import pandas as pd
import requests
import io
import numpy as np
import pickle
import os
import torch
import torch_geometric
from torch import nn
from torch.nn import functional as F
from torch_geometric import nn as gnn
import torchmetrics
from tdc import single_pred



In [2]:
from PyDRP.Data import PreprocessingPipeline
from PyDRP.src import Splitter
from PyDRP.Data.features.targets import MultitargetMinMaxScaling
from PyDRP.Data.features.drugs import GraphCreator
from PyDRP.Data.transfer.drugs import ToxRicPreprocessingPipeline, TransferDrugsDatasetManager,  TDCSingleInstanceWrapper, MultiTaskPreprocessingPipeline, MakeDrugwise
from PyDRP.Data.utils import TorchGraphsTransferDataset
from PyDRP.Models.PairsNetwork import GNNDrugEncoderDecoder
from PyDRP.Models.layers import FCBlock
from PyDRP.Models.encoders.drugs import GTEncoder, GNNAttnDrugPooling
from PyDRP.Models.layers import AttnDropout
from PyDRP.Data import NI60

In [3]:
# Source preprocessing pipelines that MultiTaskPreprocessingPipeline combines below:
# two single-label Tox21 tasks from TDC, the NI60 dataset made drug-wise, and ToxRic.
ppls = [
    TDCSingleInstanceWrapper(single_pred.Tox(name="Tox21", label_name=tox21_label))
    for tox21_label in ("NR-AR", "NR-AhR")
] + [
    MakeDrugwise(NI60()),
    ToxRicPreprocessingPipeline(),
]

Found local copy...
Loading...
Done!
Found local copy...
Loading...
Done!


In [4]:
# Merge all source pipelines into one multi-task dataset: targets are min-max scaled
# per target and each drug is featurized as a graph.
manager = TransferDrugsDatasetManager(
    drugs_processing_pipeline=MultiTaskPreprocessingPipeline(ppls),
    target_processor=MultitargetMinMaxScaling(),
    drug_featurizer=GraphCreator(),
)

In [5]:
# Build graph datasets for one partition and wrap them in mini-batch loaders.
PARTITION = 1
BATCH_SIZE = 512

drug_dict = manager.get_drugs()
train, test, val = manager.get_partition(PARTITION)
train_dataset = TorchGraphsTransferDataset(train, drug_dict)
val_dataset = TorchGraphsTransferDataset(val, drug_dict)


def _graph_loader(dataset, shuffle=False):
    """Return a DataLoader that collates samples into a torch_geometric Batch.

    Args:
        dataset: a TorchGraphsTransferDataset (or any dataset of torch_geometric Data objects).
        shuffle: whether to reshuffle the samples every epoch.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=torch_geometric.data.Batch.from_data_list,
        shuffle=shuffle,
    )


train_dataloader = _graph_loader(train_dataset, shuffle=True)
val_dataloader = _graph_loader(val_dataset)
# NOTE(review): the evaluation loop reads `test_dataloader` and labels its metrics "_test",
# but this loader iterates the `val` split (the `test` split is not used in the visible cells).
# The alias keeps that loop working; confirm which partition evaluation is meant to use.
test_dataloader = val_dataloader

In [6]:
# Model: GTEncoder drug encoder -> GNNAttnDrugPooling adapter -> FCBlock decoder,
# trained jointly with Adam.
SEED = 0
LEARNING_RATE = 0.0005
EMBED_DIM = 256  # shared width of encoder output, adapter input/output and decoder input

# Seeds torch's global RNG (all devices): weight init, dropout and DataLoader shuffle order.
torch.manual_seed(SEED)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GNNDrugEncoderDecoder(
    drug_encoder=GTEncoder(embed_dim=EMBED_DIM),
    drug_adapter=GNNAttnDrugPooling(embed_dim=EMBED_DIM, hidden_dim=1024, output_embed_dim=EMBED_DIM),
    # NOTE(review): output_dim presumably must equal the number of target columns
    # produced by `manager`; confirm 94 if the source pipelines change.
    decoder=FCBlock(input_dim=EMBED_DIM, hidden_dim=2048, output_dim=94, outp_dropout=0.4),
).to(device)
# Build the optimizer after the move so it clearly tracks the on-device parameters.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
model

GNNDrugEncoderDecoder(
  (drug_encoder): GTEncoder(
    (init_gat): GATConv(79, 256, heads=1)
    (layers): GatedGNNRes(
      (layers): ModuleList(
        (0): TransGAT(
          (k): GATv2Conv(512, 256, heads=1)
          (v): GATv2Conv(256, 256, heads=1)
          (q): GATv2Conv(512, 256, heads=1)
          (pool): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
        )
        (1): TransGAT(
          (k): GATv2Conv(512, 256, heads=1)
          (v): GATv2Conv(256, 256, heads=1)
          (q): GATv2Conv(512, 256, heads=1)
          (pool): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
        )
      )
    )
  )
  (drug_adapter): GNNAttnDrugPooling(
    (pool): GlobalAttention(gate_nn=Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.0, inpl

In [7]:
# Jointly train encoder, adapter and decoder on the multi-task data.
# Targets contain NaN for missing labels, so loss and metrics use only observed entries.
EPOCHS = 2

mse = nn.MSELoss()
metrics = torchmetrics.MetricCollection([torchmetrics.MeanSquaredError(), torchmetrics.PearsonCorrCoef()]).to(device)


def _observed(y_pred, y_true):
    """Return (predictions, targets) restricted to entries whose target is not NaN.

    Both tensors are squeezed first; the NaN mask is taken from the squeezed targets
    and applied to the squeezed predictions.
    """
    y_pred, y_true = y_pred.squeeze(), y_true.squeeze()
    is_data = ~y_true.isnan()
    return y_pred[is_data], y_true[is_data]


def _suffixed(metric_values, suffix):
    """Turn a MetricCollection result dict into Python floats with `suffix` appended to each key."""
    return {name + suffix: value.cpu().item() for name, value in metric_values.items()}


for epoch in range(EPOCHS):
    model.train()
    metrics.reset()
    for b in train_dataloader:
        b = b.to(device)
        pred, target = _observed(model(b), b["y"])
        if target.numel() == 0:
            # An all-NaN batch would give a NaN loss (mean over zero elements); skip it
            # instead of stepping the optimizer and updating metrics with empty tensors.
            continue
        loss = mse(pred, target)
        loss.backward()
        optim.step()
        optim.zero_grad()
        with torch.no_grad():
            metrics.update(pred.detach(), target)
    metric_dict_train = _suffixed(metrics.compute(), "_train")

    model.eval()
    metrics.reset()
    with torch.no_grad():
        # NOTE(review): `test_dataloader` is built from the `val` split in the data cell,
        # yet these metrics are labelled "_test"; confirm the intended partition.
        for b in test_dataloader:
            b = b.to(device)
            pred, target = _observed(model(b), b["y"])
            if target.numel() > 0:
                metrics.update(pred, target)
    metric_dict_test = _suffixed(metrics.compute(), "_test")

    metric_dict = {**metric_dict_test, **metric_dict_train}
    print(f"epoch {epoch}")
    print(metric_dict)

epoch 0
{'MeanSquaredError_test': 0.04631167650222778, 'PearsonCorrCoef_test': 0.8108320236206055, 'MeanSquaredError_train': 0.05116969347000122, 'PearsonCorrCoef_train': 0.7881336212158203}
epoch 1
{'MeanSquaredError_test': 0.04392028972506523, 'PearsonCorrCoef_test': 0.8226858377456665, 'MeanSquaredError_train': 0.04532036930322647, 'PearsonCorrCoef_train': 0.8130165934562683}


In [8]:
# Save the trained drug encoder's weights; torch.save does not create missing directories.
os.makedirs("saved_weights", exist_ok=True)
torch.save(model.drug_encoder.state_dict(), "saved_weights/drug_encoder_multitask.pt")

In [9]:
# Save the trained drug adapter's weights; torch.save does not create missing directories.
os.makedirs("saved_weights", exist_ok=True)
torch.save(model.drug_adapter.state_dict(), "saved_weights/drug_adapter_multitask.pt")