In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from torch.optim import SGD
from torch.utils.data import DataLoader
from tqdm import tqdm

from druxai.models.fusion_model import FusionModel

from druxai.utils.data import DrugResponseDataset
from druxai.utils.dataframe_utils import split_data_by_cell_line_ids
from druxai.utils.dataframe_utils import standardize_molecular_data_inplace
from druxai.utils.data import DataloaderSampler

file_path = "/Users/niklaskiermeyer/Desktop/Codespace/DruxAI/data/preprocessed"

In [3]:
# Load Data
data = DrugResponseDataset(file_path)

# Create splits for cell lines
train_id, val_id, test_id = split_data_by_cell_line_ids(data.targets)

# Standardize molecular data
standardize_molecular_data_inplace(data, train_id=train_id, val_id=val_id, test_id=test_id)

# Create dataloader sampler
train_sampler, _ = DataloaderSampler(train_id), DataloaderSampler(val_id)

# Create a train Dataloader
train_loader = DataLoader(data, sampler=train_id, batch_size=32, shuffle=False, pin_memory=True, num_workers=6)
val_loader = DataLoader(data, sampler=val_id, batch_size=32, shuffle=False, pin_memory=True, num_workers=6)

[34mINFO    [0m Loaded targets with shape: [1m([0m[1;36m556840[0m, [1;36m8[0m[1m)[0m                                                                    
[34mINFO    [0m Loaded molecular data with shape: [1m([0m[1;36m1479[0m, [1;36m19193[0m[1m)[0m                                                           


In [4]:
model = FusionModel(data, 64, 10, 0.2, 0.2)

In [5]:
optimizer = SGD(model.parameters(), lr=0.01)
epoch = 0

model.train()
model.to(torch.device("mps"))
while epoch < 3:
    for X, y, _ in tqdm(train_loader):
        drug_data, molecular_data, outcome = (X["drug_encoding"].to(torch.device("mps")),
        X["gene_expression"].to(torch.device("mps")),
        y.to(torch.device("mps")))

        optimizer.zero_grad()

        outputs = model(drug_data, molecular_data)
        loss = torch.nn.functional.huber_loss(outputs, outcome)

        loss.backward()
        optimizer.step()

    epoch += 1
    print(f"Epoch {epoch} Loss: {loss.item()}")

100%|██████████| 12159/12159 [01:01<00:00, 198.42it/s]


Epoch 1 Loss: 0.33440861105918884


100%|██████████| 12159/12159 [00:54<00:00, 222.51it/s]


Epoch 2 Loss: 0.3225771188735962


100%|██████████| 12159/12159 [00:55<00:00, 217.73it/s]

Epoch 3 Loss: 0.32297563552856445





In [7]:
from scipy.stats import spearmanr

model.eval()
predictions = []
outcomes = []

for X, y, _ in tqdm(val_loader):
    drug_data, molecular_data, outcome = (X["drug_encoding"].to(torch.device("mps")),
        X["gene_expression"].to(torch.device("mps")),
        y.to(torch.device("mps")))
    with torch.no_grad():
        output = model(drug_data, molecular_data)
    predictions.extend(output.squeeze(1).cpu().tolist())
    outcomes.extend(outcome.squeeze(1).cpu().tolist())

spearman_corr, _ = spearmanr(predictions, outcomes)
print(f"Spearman correlation: {spearman_corr}")

100%|██████████| 2399/2399 [00:12<00:00, 186.58it/s]

Spearman correlation: 0.1150701892489979



