In [1]:
from PyDRP.Data import DatasetManager, GDSC
from PyDRP.Data.features.drugs import GraphCreator
from PyDRP.Data.features.cells import TensorLineFeaturizer
from PyDRP.Data.features.targets import MinMaxScaling, IdentityPipeline
from PyDRP.Data.utils import TorchGraphsDataset
from PyDRP.Models.PairsNetwork import PairsNetwork, GNNCellDrugEncoderDecoder
from PyDRP.Models.encoders.drugs import GATmannEncoder, GNNAttnDrugPooling
from PyDRP.Models.encoders.cells import GeneExpEncoder
from PyDRP.Models.decoders import FCDecoder,  NonlinearDotDecoder
from PyDRP.Models.NNlayers import AttnDropout
from PyDRP.Models.metrics import ElementwiseMetric



In [2]:
import pandas as pd
from pprint import pprint
import os
import numpy as np
from torch import nn
from torch_geometric import nn as gnn
import torch
import torch_geometric
import torchmetrics

In [3]:
# Read the gene-list file directly from its URL (no header row, no index column)
# and reduce the parsed table to a plain Python list: to_numpy() -> squeeze()
# drops singleton axes -> tolist().
paccmann_genes = (
    pd.read_csv(
        "https://raw.githubusercontent.com/prassepaul/mlmed_ranking/main/data/gdsc_data/paccmann_gene_list.txt",
        index_col=None,
        header=None,
    )
    .to_numpy()
    .squeeze()
    .tolist()
)

In [4]:
# Build the dataset manager:
#   * GDSC pipeline with "LN_IC50" as target, restricted to the genes in
#     `paccmann_genes`, with cell lines set to "expression";
#   * targets go through IdentityPipeline;
#   * partitioning is keyed on "DRUG_ID" with k = 25
#     (presumably 25 folds split by drug -- confirm against DatasetManager);
#   * drugs are featurized with GraphCreator, cell lines with TensorLineFeaturizer.
manager = DatasetManager(
    processing_pipeline=GDSC(
        target="LN_IC50",
        gene_subset=paccmann_genes,
        cell_lines="expression",
    ),
    target_processor=IdentityPipeline(),
    partition_column="DRUG_ID",
    k=25,
    drug_featurizer=GraphCreator(),
    line_featurizer=TensorLineFeaturizer(),
)

In [5]:
# Fetch partition index 0 from the manager as train / validation / test splits
# (presumably one fold of the k = 25 "DRUG_ID" partitioning -- confirm against
# DatasetManager). NOTE(review): `val` is not used by any cell visible here.
train, val, test = manager.get_partition(0)
# Cell-line and drug lookups handed to TorchGraphsDataset as `line_dict` / `drug_dict`;
# the drug lookup is indexed by drug name later on (drug_dict["5-Fluorouracil"]).
line_dict = manager.get_cell_lines()
drug_dict = manager.get_drugs()

In [6]:
# Wrap the train and test splits in TorchGraphsDataset, sharing the same
# drug and cell-line lookups between both.
train_dataset = TorchGraphsDataset(
    data=train,
    drug_dict=drug_dict,
    line_dict=line_dict,
)
test_dataset = TorchGraphsDataset(
    data=test,
    drug_dict=drug_dict,
    line_dict=line_dict,
)

# Size of axis 1 of the first training sample's "cell" entry; passed to
# GeneExpEncoder as its first argument when the model is built.
n_dim = train_dataset[0]["cell"].shape[1]

In [7]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder/decoder network: a gene-expression encoder for cell lines, a graph
# encoder plus attention pooling for drugs, and a fully connected decoder.
# NOTE(review): the decoder's first argument (512) presumably equals the
# concatenation of the two 256-d embeddings -- confirm against FCDecoder.
model = GNNCellDrugEncoderDecoder(
    line_encoder=GeneExpEncoder(n_dim, 1024, 256, genes_do=0.4),
    drug_encoder=GATmannEncoder(
        # Edge-feature width is read off one known drug graph; this relies on
        # "5-Fluorouracil" being present in `drug_dict`.
        edge_features=drug_dict["5-Fluorouracil"]["edge_attr"].shape[1],
        embed_dim=256,
    ),
    line_adapter=nn.Identity(),
    drug_adapter=GNNAttnDrugPooling(embed_dim=256, hidden_dim=1024, output_embed_dim=256),
    decoder=FCDecoder(512, 1024, p_dropout_2=0.3),
    # Alternative decoder, kept for reference:
    # decoder = NonlinearDotDecoder(256, 1024, 64, p_dropout_1=0.3, p_dropout_2 = 0.3),
)

# Move the model to its device *before* building the optimizer, as recommended
# by the torch.optim documentation.
model.to(device)
optim = torch.optim.Adam(model.parameters(), 0.0001)

# Batches are assembled with torch_geometric's Batch.from_data_list; only the
# training loader is shuffled.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=512,
    collate_fn=torch_geometric.data.Batch.from_data_list,
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=512,
    collate_fn=torch_geometric.data.Batch.from_data_list,
)

In [None]:
# Loss and metrics. `metrics` tracks MSE and Pearson correlation over all
# predictions; `elm` is PyDRP's ElementwiseMetric configured with
# average="drugs" (presumably a per-drug averaged score -- confirm against PyDRP).
mse = nn.MSELoss()
metrics = torchmetrics.MetricCollection([torchmetrics.MeanSquaredError(), torchmetrics.PearsonCorrCoef()]).to(device)
elm = ElementwiseMetric(average="drugs")

for epoch in range(2):
    # ---- training pass ----
    model.train()
    elm.reset()
    metrics.reset()
    for batch in train_dataloader:
        batch = batch.to(device)
        # Clear gradients left over from the previous step before the new forward pass.
        optim.zero_grad()
        # NOTE(review): .squeeze() on a batch of size 1 yields 0-d tensors --
        # confirm the metrics accept that or that no such batch can occur.
        y_pred = model(batch).squeeze()
        y_true = batch["y"].squeeze()
        loss = mse(y_pred, y_true)
        loss.backward()
        optim.step()
        # Metric bookkeeping must not be tracked by autograd.
        with torch.no_grad():
            metrics.update(y_pred, y_true)
            elm.update(y_pred, y_true, batch["DRUG_ID"], batch["CELL_ID"])
    metric_dict_train = {name + "_train": value.cpu().item() for name, value in metrics.compute().items()}
    metric_dict_train["R_average_train"] = elm.compute().item()

    # ---- evaluation pass on the test split ----
    model.eval()
    metrics.reset()
    elm.reset()
    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            y_pred = model(batch).squeeze()
            y_true = batch["y"].squeeze()
            metrics.update(y_pred, y_true)
            elm.update(y_pred, y_true, batch["DRUG_ID"], batch["CELL_ID"])
    metric_dict_test = {name + "_test": value.cpu().item() for name, value in metrics.compute().items()}
    # The test-set average belongs in the test dictionary (it was previously
    # filed under the train dictionary); the merged result below is unchanged.
    metric_dict_test["R_average_test"] = elm.compute().item()

    # Merge both dictionaries and report; pprint sorts the keys alphabetically.
    metric_dict = {**metric_dict_test, **metric_dict_train}
    print(f"epoch {epoch}")
    pprint(metric_dict)

epoch 0
{'MeanSquaredError_test': 14.591248512268066,
 'MeanSquaredError_train': 5.1743927001953125,
 'PearsonCorrCoef_test': 0.1492098569869995,
 'PearsonCorrCoef_train': 0.5079371333122253,
 'R_average_test': 0.33869847655296326,
 'R_average_train': 0.08139122277498245}
