In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning as L

import os, sys

import pandas as pd
import numpy as np
import pickle
import gzip
import sklearn.model_selection
import matplotlib.pyplot as plt

In [None]:
# Ground-truth table: one row per sample. Column 'Unnamed: 0' holds the key
# "<drug1>_<drug2>_<cell_line>", 'fold' the CV fold id and 'synergy' the raw score.
LABELS_CSV = "/gpfs/gibbs/pi/zhao/tl688/synergy_prediction/labels_synergy_value.csv"
df_grountruth_score = pd.read_csv(LABELS_CSV)

# Fail fast with a readable message if the file does not have the columns used below.
_missing_cols = {'Unnamed: 0', 'fold', 'synergy'} - set(df_grountruth_score.columns)
assert not _missing_cols, f"labels CSV is missing expected columns: {sorted(_missing_cols)}"

df_grountruth_score.head()

In [None]:
def _read_pickle(path):
    """Deserialize and return the single object stored in the pickle file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# GenePT embedding lookups, indexed by the cell-line / drug names parsed from the sample keys.
cellline_name_getembedding = _read_pickle(
    "/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_deepsynergycellline.pickle"
)
drug_name_getembedding = _read_pickle(
    "/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_deepsynergydrug.pickle"
)


In [None]:
# Hold out one CV fold as the test set; all remaining folds form the training pool.
test_fold = 0
# A sample is labelled synergistic (1) when its synergy score is strictly above this value.
SYNERGY_THRESHOLD = 30


def _build_features(frame):
    """Return a 2-D array with one feature row per row of ``frame``.

    Each ``'Unnamed: 0'`` key is split on ``'_'`` into ``(drug1, drug2, cell_line)``.
    The row is the element-wise mean of the two drug embeddings followed by the
    cell-line embedding. Row order follows ``frame``.
    """
    rows = []
    for key in frame['Unnamed: 0']:
        d1, d2, cl = key.split('_')
        drug_mean = (drug_name_getembedding[d1] + drug_name_getembedding[d2]) / 2
        rows.append(np.hstack([drug_mean, cellline_name_getembedding[cl]]))
    return np.array(rows)


def _binary_labels(frame):
    """Return an int array: 1 where ``frame['synergy'] > SYNERGY_THRESHOLD``, else 0."""
    return (frame['synergy'].values > SYNERGY_THRESHOLD) * 1


is_test = df_grountruth_score['fold'] == test_fold
df_train = df_grountruth_score[~is_test]
df_test = df_grountruth_score[is_test]
train_index = df_train.index

X_train = _build_features(df_train)
y_train = _binary_labels(df_train)
X_test = _build_features(df_test)
y_test = _binary_labels(df_test)

# Model / optimisation hyperparameters.
layers = [10240, 4096, 1]  # widths of the two hidden layers and the output layer
epochs = 1000
eta = 1e-5                 # Adam learning rate
input_dropout = 0.2        # dropout after each hidden block of the Encoder
drug_num_dim = 16          # width of the Encoder's drug_num_emb embedding
# NOTE(review): act_func, dropout and norm are not read anywhere in this notebook;
# the Encoder hard-codes ReLU and uses input_dropout for both hidden blocks.
act_func = 'GELU'
dropout = 0.5
norm = 'tanh'

# Split the training pool into fit / validation sets (sklearn's default 75/25 split).
X_tr, X_val, y_tr, y_val = sklearn.model_selection.train_test_split(X_train, y_train, random_state=2023)


In [None]:
class Encoder(nn.Module):
    """Feed-forward binary classifier over the concatenated drug/cell-line features.

    Architecture: ``Linear -> BatchNorm1d -> ReLU -> Dropout`` twice, then
    ``Linear -> Sigmoid``. For a 2-D input of shape ``(batch, in_features)``,
    ``forward`` returns values in [0, 1] of shape ``(batch, layer_sizes[2])``.

    Every argument defaults to ``None``, meaning "use the notebook global". So
    ``Encoder()`` builds exactly the same network as before, and the class can
    also be constructed without those globals.

    Args:
        in_features: width of one input row; defaults to ``X_tr.shape[1]``.
        layer_sizes: ``[hidden1, hidden2, out]`` widths; defaults to ``layers``.
        p_dropout: dropout probability after each hidden block; defaults to
            ``input_dropout``.
        emb_dim: embedding width of ``drug_num_emb``; defaults to ``drug_num_dim``.
    """

    def __init__(self, in_features=None, layer_sizes=None, p_dropout=None, emb_dim=None):
        super().__init__()
        if in_features is None:
            in_features = X_tr.shape[1]
        if layer_sizes is None:
            layer_sizes = layers
        if p_dropout is None:
            p_dropout = input_dropout
        if emb_dim is None:
            emb_dim = drug_num_dim

        self.l1 = nn.Sequential(
            nn.Linear(in_features, layer_sizes[0]),
            nn.BatchNorm1d(layer_sizes[0], momentum=0.1),
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(layer_sizes[0], layer_sizes[1]),
            nn.BatchNorm1d(layer_sizes[1], momentum=0.1),
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(layer_sizes[1], layer_sizes[2]),
            nn.Sigmoid(),  # output is a probability, as required by F.binary_cross_entropy
        )

        # Not used by forward()/predict(); kept so the module's parameter names and
        # state_dict keys stay unchanged. 10 = number of embedding rows.
        self.drug_num_emb = nn.Embedding(10, emb_dim)

    def forward(self, x):
        """Return output probabilities for ``x``, honouring the current train/eval mode."""
        return self.l1(x)

    def predict(self, x):
        """Return output probabilities for ``x`` in inference mode.

        Temporarily switches to ``eval()``, so Dropout is disabled and BatchNorm
        uses its running statistics, and runs without autograd. The previous
        train/eval mode is restored afterwards. Without this, predicting right
        after training would apply random dropout masks and batch statistics.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.l1(x)
        finally:
            self.train(was_training)

class LitAutoEncoder(L.LightningModule):
    """LightningModule that trains ``encoder`` as a binary classifier.

    Despite the name this is not an autoencoder. The encoder output is scored
    against the labels with ``F.binary_cross_entropy``, so it must already be a
    probability in [0, 1] (e.g. produced by a final ``nn.Sigmoid``).

    Args:
        encoder: ``nn.Module`` mapping a batch of flattened features to
            probabilities with the same number of elements per sample as ``y``.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def _shared_step(self, batch):
        """Return the mean BCE loss of ``self.encoder`` on one ``(x, y)`` batch."""
        x, y = batch
        x = x.view(x.size(0), -1)  # flatten every sample to a 1-D feature vector
        z = self.encoder(x)
        # Reshape targets to (batch, k) so they line up element-wise with the output.
        return F.binary_cross_entropy(z, y.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return ``val_loss`` (monitored by the LR scheduler)."""
        val_loss = self._shared_step(batch)
        self.log("val_loss", val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute, log and return ``test_loss``."""
        test_loss = self._shared_step(batch)
        self.log("test_loss", test_loss)
        return test_loss

    def forward(self, x):
        """Delegate to the wrapped encoder."""
        return self.encoder(x)

    def configure_optimizers(self):
        """Adam with the notebook-level learning rate ``eta`` plus plateau-based LR decay."""
        optimizer = torch.optim.Adam(params=self.parameters(), lr=eta)
        # Reduce the LR when val_loss stops improving for `patience` scheduler steps
        # (Lightning steps it once per epoch by default). `verbose` is deprecated in
        # recent PyTorch; use a LearningRateMonitor callback to track LR changes.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss',
        }


In [None]:
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Convert every split to float32 tensors. Labels become float too, because
# F.binary_cross_entropy expects float targets.
X_tr, X_val, X_train, X_test, y_tr, y_val, y_train, y_test = (
    torch.as_tensor(arr, dtype=torch.float32)
    for arr in (X_tr, X_val, X_train, X_test, y_tr, y_val, y_train, y_test)
)

train_dataset = torch.utils.data.TensorDataset(X_tr, y_tr)
valid_dataset = torch.utils.data.TensorDataset(X_val, y_val)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)

# Seed Python/NumPy/torch RNGs so weight init, dropout masks and shuffling are reproducible.
L.seed_everything(2023)

model = LitAutoEncoder(Encoder())

BATCH_SIZE = 1024
NUM_WORKERS = 5
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Fit on the training split and validate on the held-out validation split.
# Early stopping watches val_loss; the LR monitor logs the scheduler's learning rate.
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = L.Trainer(
    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=100), lr_monitor],
    max_epochs=epochs,
)
trainer.fit(model, train_loader, valid_loader)

In [None]:
import scipy.stats
import sklearn.metrics

# Probability cut-off used to turn predicted probabilities into 0/1 labels for accuracy.
DECISION_THRESHOLD = 0.5

# Switch to inference mode before predicting. In training mode Dropout is active
# and BatchNorm normalises with the test batch's own statistics, which makes the
# reported metrics depend on random dropout masks.
model.eval()
with torch.no_grad():
    # Feed the inputs to whatever device the model ended up on after fitting.
    y_pred = model.encoder.predict(X_test.to(model.device)).detach()

test_prob = y_pred[:, 0].cpu().numpy()  # probability of the positive class per sample
test_true = np.asarray(y_test)

print("We use the fold with number:", test_fold)
print(
    sklearn.metrics.roc_auc_score(test_true, test_prob),
    sklearn.metrics.accuracy_score(test_true, (test_prob > DECISION_THRESHOLD) * 1),
)
