In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from torch.optim import SGD
from torch.utils.data import DataLoader
from tqdm import tqdm

from druxai.models.fusion_model import FusionModel

from druxai.utils.data import DrugResponseDataset
from druxai.utils.dataframe_utils import split_data_by_cell_line_ids
from druxai.utils.dataframe_utils import standardize_molecular_data_inplace
from druxai.utils.data import DataloaderSampler

file_path = "/Users/niklaskiermeyer/Desktop/Codespace/DruxAI/data/preprocessed"

In [3]:
# Load Data
data = DrugResponseDataset(file_path)

# Create splits for cell lines
train_id, val_id, test_id = split_data_by_cell_line_ids(data.targets)

# Standardize molecular data
standardize_molecular_data_inplace(data, train_id=train_id, val_id=val_id, test_id=test_id)

# Create dataloader sampler
train_sampler, _ = DataloaderSampler(train_id), DataloaderSampler(val_id)

# Create a train Dataloader
train_loader = DataLoader(data, sampler=train_id, batch_size=32, shuffle=False, pin_memory=True, num_workers=6)
val_loader = DataLoader(data, sampler=val_id, batch_size=32, shuffle=False, pin_memory=True, num_workers=6)

[34mINFO    [0m Loaded targets with shape: [1m([0m[1;36m556840[0m, [1;36m8[0m[1m)[0m                                                                    
[34mINFO    [0m Loaded molecular data with shape: [1m([0m[1;36m1479[0m, [1;36m19193[0m[1m)[0m                                                           


In [4]:
model = FusionModel(data, 64, 10, 1, 0.2, 0.2)

In [5]:
optimizer = SGD(model.parameters(), lr=0.01)
epoch = 0

model.train()
model.to(torch.device("mps"))
while epoch < 3:
    for X, y, _ in tqdm(train_loader):
        drug_data, molecular_data, outcome = (X["drug_encoding"].to(torch.device("mps")),
        X["gene_expression"].to(torch.device("mps")),
        y.to(torch.device("mps")))

        optimizer.zero_grad()

        outputs = model(drug_data, molecular_data)
        loss = torch.nn.functional.huber_loss(outputs, outcome)

        loss.backward()
        optimizer.step()

    epoch += 1
    print(f"Epoch {epoch} Loss: {loss.item()}")

100%|██████████| 12159/12159 [01:14<00:00, 163.61it/s]


Epoch 1 Loss: 0.009509323164820671


100%|██████████| 12159/12159 [01:13<00:00, 164.68it/s]


Epoch 2 Loss: 0.009250925853848457


100%|██████████| 12159/12159 [01:29<00:00, 136.48it/s]

Epoch 3 Loss: 0.007269097957760096





In [6]:
from scipy.stats import spearmanr

model.eval()
predictions = []
outcomes = []

for X, y, _ in tqdm(val_loader):
    drug_data, molecular_data, outcome = (X["drug_encoding"].to(torch.device("mps")),
        X["gene_expression"].to(torch.device("mps")),
        y.to(torch.device("mps")))
    with torch.no_grad():
        output = model(drug_data, molecular_data)
    predictions.extend(output.detach().cpu().numpy())
    outcomes.extend(outcome.detach().cpu().numpy())

spearman_corr, _ = spearmanr(predictions, outcomes)
print(f"Spearman correlation: {spearman_corr}")

100%|██████████| 2399/2399 [00:25<00:00, 93.79it/s] 

Spearman correlation: 0.5697544076124905





In [7]:
predictions[100:150]

[array([-0.13262278], dtype=float32),
 array([-0.13723084], dtype=float32),
 array([-0.1895616], dtype=float32),
 array([-0.21674111], dtype=float32),
 array([-0.11026599], dtype=float32),
 array([-0.15850471], dtype=float32),
 array([-0.12916355], dtype=float32),
 array([-0.12135497], dtype=float32),
 array([-0.15679702], dtype=float32),
 array([-0.11425522], dtype=float32),
 array([-0.15274216], dtype=float32),
 array([-0.24977687], dtype=float32),
 array([-0.11958277], dtype=float32),
 array([-0.26117608], dtype=float32),
 array([-0.17570165], dtype=float32),
 array([-0.12253809], dtype=float32),
 array([-0.15890926], dtype=float32),
 array([-0.16622847], dtype=float32),
 array([-0.16797934], dtype=float32),
 array([-0.15096883], dtype=float32),
 array([-0.15710549], dtype=float32),
 array([-0.17116013], dtype=float32),
 array([-0.13938446], dtype=float32),
 array([-0.09306007], dtype=float32),
 array([-0.1011591], dtype=float32),
 array([-0.12006402], dtype=float32),
 array([-0.087

In [8]:
outcomes

[array([0.11115613], dtype=float32),
 array([0.2633268], dtype=float32),
 array([-0.5258362], dtype=float32),
 array([-0.49935508], dtype=float32),
 array([0.11194454], dtype=float32),
 array([0.16227844], dtype=float32),
 array([0.18993562], dtype=float32),
 array([0.15420808], dtype=float32),
 array([-0.21554437], dtype=float32),
 array([-1.1675398], dtype=float32),
 array([0.3030128], dtype=float32),
 array([0.1439186], dtype=float32),
 array([0.19248554], dtype=float32),
 array([-0.16299194], dtype=float32),
 array([-0.40179667], dtype=float32),
 array([-0.69512033], dtype=float32),
 array([-0.3093679], dtype=float32),
 array([0.23811617], dtype=float32),
 array([0.11718196], dtype=float32),
 array([-0.01960478], dtype=float32),
 array([-0.06835456], dtype=float32),
 array([0.03726429], dtype=float32),
 array([-0.25503963], dtype=float32),
 array([0.33207047], dtype=float32),
 array([-0.56761575], dtype=float32),
 array([-0.9696719], dtype=float32),
 array([0.15501778], dtype=float