In [1]:
from PyDRP.Benchmarks import BenchCanc
import os
from torch import nn
import torch
import torch_geometric
import numpy as np

In [2]:
# Defining a "Dummy model class" that returns the output multiplied by a scalar
class DummyClass(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lin = nn.Linear(1, 1)
    def forward(self, data):
        return self.lin(data["y"])

In [3]:
# Defining the configuration
config = {"optimizer":{"batch_size":256,
                        "learning_rate":0.001,
                         "max_epochs":2,
                         "patience":10,
                         "kwargs":{},
                         "clip_norm":1,},
            "env":{"device":"cuda:7",
                   "mixed_precision":True},
            "model":{}}

In [4]:
# Instantiating the benchmark
n_folds = 3
benchmark = BenchCanc(config = config,
                        n_folds=n_folds,
                        dataset = "GDSC1", 
                        line_features="expression+mutations+cnv")

In [5]:
# Instantiating the model at each fold and passing it to the train_model method 
performance = []
for i in range(n_folds):
    print(f"fold {i}")
    model =  DummyClass()
    _, _, test_metrics = benchmark.train_model(model, fold=i)
    performance += [test_metrics.compute()["MeanSquaredError"].cpu().detach().numpy()]
    
print(f" MSE: {np.mean(performance)} +- {np.std(performance)/np.sqrt(n_folds)}")

fold 0
epoch : 0, test_metrics: {'MeanSquaredError': 9.587281227111816, 'R_cellwise': -0.9999995827674866, 'R_cellwise_residuals': -0.9999978542327881}
epoch : 1, test_metrics: {'MeanSquaredError': 1.7074977159500122, 'R_cellwise': 1.0, 'R_cellwise_residuals': 0.9999994039535522}
fold 1
epoch : 0, test_metrics: {'MeanSquaredError': 14.385035514831543, 'R_cellwise': -1.0, 'R_cellwise_residuals': -0.9999997019767761}
epoch : 1, test_metrics: {'MeanSquaredError': 2.952277660369873, 'R_cellwise': 0.9999998807907104, 'R_cellwise_residuals': 0.9999992847442627}
fold 2
epoch : 0, test_metrics: {'MeanSquaredError': 0.19636879861354828, 'R_cellwise': 1.0, 'R_cellwise_residuals': 0.9999995827674866}
epoch : 1, test_metrics: {'MeanSquaredError': 6.656610639765859e-05, 'R_cellwise': 1.0, 'R_cellwise_residuals': 0.9999995827674866}
 MSE: 1.5532807111740112 +- 0.6986852480280793
