In [1]:
import pandas as pd
from PyDRP.Data import DatasetManager, GDSC
from PyDRP.Data.features.drugs import GraphCreator
from PyDRP.Data.features.cells import TensorLineFeaturizer
from PyDRP.Data.features.targets import MinMaxScaling, IdentityPipeline
from PyDRP.Data.utils import TorchGraphsDataset
from PyDRP.Models.PairsNetwork import PairsNetwork
from PyDRP.Models.encoders.drugs import GATmannEncoder
from PyDRP.Models.encoders.cells import GeneExpEncoder
from PyDRP.Models.decoders import FCDecoder,  NonlinearDotDecoder
from PyDRP.Models.NNlayers import AttnDropout



In [2]:
import os
import numpy as np
from torch import nn
from torch_geometric import nn as gnn
import torch
import torch_geometric
import torchmetrics

In [3]:
paccmann_genes = pd.read_csv("https://raw.githubusercontent.com/prassepaul/mlmed_ranking/main/data/gdsc_data/paccmann_gene_list.txt", index_col=None, header=None).to_numpy().squeeze().tolist()

In [4]:
manager = DatasetManager(processing_pipeline = GDSC(target = "LN_IC50",
                                                    gene_subset = paccmann_genes,
                                                    cell_lines = "expression&mutation"),
                        target_processor = IdentityPipeline(),
                        partition_column = "CELL_ID",
                        k = 25,
                        drug_featurizer = GraphCreator(),
                        line_featurizer = TensorLineFeaturizer())

In [5]:
train, val, test = manager.get_partition(0)

In [6]:
line_dict = manager.get_cell_lines()
line_dict[683667]

tensor([3.8235, 4.7562, 3.2011,  ..., 0.0000, 0.0000, 0.0000])

In [7]:
drug_dict = manager.get_drugs()

In [8]:
train_dataset = TorchGraphsDataset(data=train,
                   drug_dict = drug_dict,
                   line_dict = line_dict)
test_dataset = TorchGraphsDataset(data=test,
                   drug_dict = drug_dict,
                   line_dict = line_dict)
n_dim = train_dataset[0]["cell"].shape[1]

In [9]:
class GNNDrugPooling(nn.Module):
    def __init__(self,
                 embed_dim,
                 hidden_dim,
                 output_embed_dim,
                 p_dropout_attn = 0.0,
                 p_dropout_nodes = 0.0,
                 **kwargs):
        super().__init__()
        self.pool = gnn.GlobalAttention(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                             nn.ReLU(),
                                                             nn.Dropout(p_dropout_attn),
                                                             nn.Linear(hidden_dim, 1),
                                                             AttnDropout(p_dropout_nodes)),
                                               nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                             nn.ReLU(),
                                                             nn.Dropout(p_dropout_attn),
                                                             nn.Linear(hidden_dim, output_embed_dim)))
    def forward(self, x, batch):
        return self.pool(x, batch)
class GNNEncoderDecoder(PairsNetwork):
    def __init__(self,
                 line_encoder,
                 drug_encoder,
                 line_adapter,
                 drug_adapter,
                 decoder,
                 **kwargs):
        """
        Network consisting of two encoders, two adapters and a decoder.
        The forward method has to be reimplemented.
        """
        super().__init__()
        self.line_encoder = line_encoder
        self.drug_encoder = drug_encoder
        self.line_adapter = line_adapter
        self.drug_adapter = drug_adapter
        self.decoder = decoder
    def forward(self, data, *args, **kwargs):
        x_lines = self.line_adapter(self.line_encoder(data["cell"]))
        x_drugs = self.drug_adapter(self.drug_encoder(data["x"],
                                                      data["edge_index"],
                                                      data["edge_attr"]),
                                    data["batch"])
        return self.decoder(x_lines, x_drugs)

In [10]:
model = GNNEncoderDecoder(line_encoder = GeneExpEncoder(n_dim, 1024, 256, genes_do = 0.4),
                         drug_encoder = GATmannEncoder(edge_features = drug_dict["5-Fluorouracil"]["edge_attr"].shape[1],embed_dim = 256),
                         line_adapter = nn.Identity(),
                         drug_adapter = GNNDrugPooling(embed_dim = 256, hidden_dim = 1024, output_embed_dim=256),
                         #decoder = FCDecoder(512, 1024, p_dropout_2 = 0.3))
                         decoder = NonlinearDotDecoder(256, 1024, 64, p_dropout_1=0.3, p_dropout_2 = 0.3))
optim = torch.optim.Adam(model.parameters(), 0.0001)
device = torch.device("cuda")
model.to(device)

GNNEncoderDecoder(
  (line_encoder): GeneExpEncoder(
    (do): Dropout(p=0.4, inplace=False)
    (net): Sequential(
      (0): Linear(in_features=3163, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=256, bias=True)
    )
  )
  (drug_encoder): GATmannEncoder(
    (gat_init): GATv2Conv(79, 256, heads=1)
    (gat_layers): ModuleList(
      (0): GATv2Conv(256, 256, heads=1)
      (1): GATv2Conv(256, 256, heads=1)
    )
  )
  (line_adapter): Identity()
  (drug_adapter): GNNDrugPooling(
    (pool): GlobalAttention(gate_nn=Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.0, inplace=False)
      (3): Linear(in_features=1024, out_features=1, bias=True)
      (4): AttnDropout(
        (id): Identity()
      )
    ), nn=Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.0, inplace=False)
      (3): Linear(in_featu

In [11]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = 512,
                                               collate_fn = torch_geometric.data.Batch.from_data_list, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 512,
                                               collate_fn = torch_geometric.data.Batch.from_data_list)

In [12]:
from PyDRP.Models.metrics import ElementwiseMetric

In [13]:
mse = nn.MSELoss()
metrics = torchmetrics.MetricCollection([torchmetrics.MeanSquaredError(), torchmetrics.PearsonCorrCoef()]).to(device)
elm = ElementwiseMetric(average="drugs")
for epoch in range(10):
    model.train()
    elm.reset()
    metrics.reset()
    for b in train_dataloader:
        b = b.to(device)
        y_pred = model(b)
        l = mse(y_pred.squeeze(), b["y"].squeeze())
        l.backward()
        optim.step()
        with torch.no_grad():
            metrics.update(y_pred.squeeze(), b["y"].squeeze())
            elm.update(y_pred.squeeze(), b["y"].squeeze(), b["DRUG_ID"], b["CELL_ID"])
        optim.zero_grad()
    metric_dict_train = {it[0] + "_train":it[1].cpu().item() for it in metrics.compute().items()}
    metric_dict_train["R_average_train"] = elm.compute().item()
    model.eval()
    metrics.reset()
    elm.reset()
    with torch.no_grad():
        for b in test_dataloader:
            b = b.to(device)
            y_pred = model(b)
            metrics.update(y_pred.squeeze(), b["y"].squeeze())
            elm.update(y_pred.squeeze(), b["y"].squeeze(), b["DRUG_ID"], b["CELL_ID"])
    metric_dict_test = {it[0] + "_test":it[1].cpu().item() for it in metrics.compute().items()}
    metric_dict_train["R_average_test"] = elm.compute().item()
    metric_dict = {**metric_dict_test, **metric_dict_train}
    print(f"epoch {epoch}", **metric_dict)

TypeError: print() takes at most 4 keyword arguments (6 given)