In [None]:
# Reload edited modules (e.g. the local druxai package) automatically before each cell executes,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from tqdm import tqdm

from druxai.models.NN import Interaction_Model
from druxai.utils.data import DrugResponseDataset
from druxai.utils.dataframe_utils import split_data_by_cell_line_ids
from druxai.utils.dataframe_utils import standardize_molecular_data_inplace


# Directory holding the preprocessed DruxAI data. Set the DRUXAI_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific absolute default.
file_path = os.environ.get("DRUXAI_DATA_DIR", "/Users/niklaskiermeyer/Desktop/Codespace/DruxAI/data/preprocessed")

In [None]:
# Load Data: build the drug-response dataset from the files under `file_path`.
# `data.targets` is used below to derive the train/val/test split ids.
data = DrugResponseDataset(file_path)

In [None]:
# Create splits: get the dataset indices for the train / validation / test sets,
# partitioned by cell line (these ids are later used directly as DataLoader samplers).
train_id, val_id, test_id = split_data_by_cell_line_ids(data.targets)

# Guard against data leakage: no index may appear in more than one split.
assert set(train_id).isdisjoint(val_id), "train and validation splits overlap"
assert set(train_id).isdisjoint(test_id), "train and test splits overlap"
assert set(val_id).isdisjoint(test_id), "validation and test splits overlap"

In [None]:
# Standardize the molecular features using the split ids (the helper's name indicates it
# mutates `data` in place). Per the original author's note, molecular data is looked up per
# cell line in __getitem__, so only cell lines that appear in the train/val/test splits are standardized.
# NOTE(review): confirm the mean/std are fitted on train_id only, to avoid val/test leakage.
standardize_molecular_data_inplace(data, train_id, val_id, test_id)

In [None]:
# Dataloaders.
# A plain list passed as `sampler` replays the same fixed order every epoch, so training batches
# would never be shuffled. The training loader therefore draws its split indices in a fresh
# (seeded, reproducible) random order each epoch; validation keeps the deterministic split order.
BATCH_SIZE = 8
NUM_WORKERS = 6
LOADER_SEED = 0  # makes the training shuffle order reproducible

train_sampler = SubsetRandomSampler(list(train_id), generator=torch.Generator().manual_seed(LOADER_SEED))
train_loader = DataLoader(data, sampler=train_sampler, batch_size=BATCH_SIZE, pin_memory=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(data, sampler=val_id, batch_size=BATCH_SIZE, pin_memory=True, num_workers=NUM_WORKERS)

In [None]:
# Train Loop
# Choose the accelerator: Apple MPS (the original target) when available, else CUDA, else CPU,
# so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

N_EPOCHS = 1
# Smoke-test limit: stop each epoch after this many batches (the notebook previously hard-coded a
# `break` after the first batch). Set to None to train on the full training set.
MAX_TRAIN_BATCHES = 1
TRAIN_SEED = 0  # seeds torch's global RNG: model initialization and the random optimizer choice below

torch.manual_seed(TRAIN_SEED)

model = Interaction_Model(data)
model.to(device)

# One SGD optimizer per sub-network (model.nn1 / model.nn2); each step updates only one of them.
optimizer1 = SGD(model.nn1.parameters(), momentum=0.9, lr=0.01, weight_decay=1e-5)
optimizer2 = SGD(model.nn2.parameters(), momentum=0.9, lr=0.01, weight_decay=1e-5)
criterion = nn.HuberLoss()  # stateless, so build it once rather than once per batch

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    n_batches = 0

    for X, y, _ in tqdm(train_loader, desc=f"epoch {epoch}"):
        drug = X["drug_encoding"].to(device)
        molecular = X["gene_expression"].to(device)
        outcome = y.to(device)

        optimizer1.zero_grad()
        optimizer2.zero_grad()

        # Call the module itself (not .forward) so registered nn.Module hooks run.
        prediction = model(drug, molecular)
        # HuberLoss takes (input, target); the loss is symmetric in the two, so this is a convention fix.
        loss = criterion(prediction, outcome)
        loss.backward()

        # NOTE(review): the clipping norm is computed over nn1 AND nn2 gradients even though only one
        # optimizer steps below; clipping only the selected network's parameters may be intended.
        clip_grad_norm_(model.parameters(), 1.0)

        # Randomly pick which sub-network to update this step (50/50).
        selected_optimizer = optimizer1 if torch.rand(1).item() < 0.5 else optimizer2
        selected_optimizer.step()

        total_loss += loss.item()
        n_batches += 1
        if MAX_TRAIN_BATCHES is not None and n_batches >= MAX_TRAIN_BATCHES:
            break

    print(f"epoch {epoch}: mean train loss {total_loss / max(n_batches, 1):.4f} over {n_batches} batch(es)")
