In [1]:
import sys, os
sys.path.append("..")

import pandas as pd

from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from dataset import *
from features import *
from train import *
from model import *

In [2]:
# Load the differential-expression training table.
de_df = pd.read_parquet("../data/de_train.parquet")

# Split row positions into train / validation, stratified on cell type.
# NOTE(review): 0.2 and 45 look like the validation fraction and the random
# seed, respectively -- confirm against the signature of `stratified_split`.
train_index, val_index = stratified_split(de_df["cell_type"], 0.2, 45)

# Wrap each positional slice of the table in its own DataFrameDataset
# (train first, then validation).
de_df_dataset_train, de_df_dataset_val = (
    DataFrameDataset(de_df.iloc[row_positions], mode="df")
    for row_positions in (train_index, val_index)
)

In [3]:
# Vocabulary of small-molecule names used by the one-hot encoder.
# Sorted on purpose: iterating a set of strings yields a different order on
# every kernel start (string hashing is randomized per process), which would
# silently change which one-hot column each name maps to between runs and
# make results / saved weights non-reproducible.
mtypes = sorted(set(de_df["sm_name"].to_list()))

# Available molecule featurizations, keyed by feature name.
# NOTE(review): Mol2Morgan(2048, r) presumably means a 2048-bit Morgan
# fingerprint of radius r -- confirm in `features`.
mol_transforms = {
    "morgan2_fp": TransformList([Sm2Smiles("../config/sm_smiles.csv", mode="path"), Smiles2Mol(), Mol2Morgan(2048, 2)]),
    "morgan3_fp": TransformList([Sm2Smiles("../config/sm_smiles.csv", mode="path"), Smiles2Mol(), Mol2Morgan(2048, 3)]),
    "one_hot": TransformList([Type2OneHot(mtypes)])
}

# Vocabulary of cell types, sorted for the same reproducibility reason as
# `mtypes` above.
ctypes = sorted(set(de_df["cell_type"].to_list()))

# One control-mean CSV per cell type, listed in the same order as `ctypes` so
# the two lists stay aligned when passed together to CType2CSVEncoding.
# File name = cell type with spaces replaced by "_" and "+" removed.
file_names = ["../data/temp/"+name.replace(" ", "_").replace("+", "")+"_control_mean.csv"
              for name in ctypes]
# Row count of the first control-mean file.
# NOTE(review): presumably one row per gene -- confirm the CSV layout.
gene_num = len(pd.read_csv(file_names[0]))

# Available cell-type featurizations, keyed by feature name.
cell_transforms = {
    "one_hot": TransformList([Type2OneHot(ctypes)]),
    "gene_exp": TransformList([CType2CSVEncoding(ctypes, file_names)])
}

In [4]:
# Build the train and validation DEDataset objects (train first, then
# validation); both receive the same molecule / cell transform dictionaries.
de_dataset_train, de_dataset_val = (
    DEDataset(frame_dataset, mol_transforms, cell_transforms)
    for frame_dataset in (de_df_dataset_train, de_df_dataset_val)
)

In [6]:
# Run from the parent directory so that relative paths used below (e.g. the
# TensorBoard log dir) resolve there. The directory we started in is recorded
# and restored in the `finally` block, so a failed or interrupted run does not
# leave the kernel in the wrong working directory (re-running the cell would
# otherwise climb one level higher each time).
notebook_dir = os.getcwd()
os.chdir("..")

writer = None
try:
    #### LOADERS, DATA ####
    # Select which of the pre-registered features each dataset emits.
    de_dataset_train.configure(cell_out_feature="one_hot", sm_out_feature="morgan2_fp")
    de_dataset_val.configure(cell_out_feature="one_hot", sm_out_feature="morgan2_fp")
    # Batch size 32.
    # NOTE(review): the training loader uses the DataLoader default
    # shuffle=False -- confirm that an unshuffled training order is intended.
    train_dataloader = DataLoader(de_dataset_train, 32)
    val_dataloader = DataLoader(de_dataset_val, 32)

    #### MODEL ####
    # Input sizes match the features configured above: one-hot over `ctypes`
    # and the 2048-wide "morgan2_fp" molecule feature.
    # NOTE(review): the "- 5" presumably excludes five non-target metadata
    # columns of `de_df` -- confirm against the parquet schema.
    model = BaselineModel(cell_in=len(ctypes), mol_in=2048, out_size=len(de_df.columns)-5)

    #### TRAINING ####
    lr = 0.02
    epochs = 500
    device = "cuda:0"

    loss_fn = loss_mrrmse
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=2e-5)
    # Multiply the learning rate by 0.8 once the monitored metric has not
    # improved for more than 7 consecutive epochs.
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

    #### TENSORBOARD ####
    writer = SummaryWriter("./runs/trying_out2/2")

    #### RUN ####
    train_many_epochs(model, train_dataloader, val_dataloader, epochs,
                      loss_fn, optimizer, scheduler, writer=writer, device=device)
finally:
    # Flush and release the TensorBoard event file even if training failed.
    if writer is not None:
        writer.close()
    # Return to wherever the notebook was running from, without assuming the
    # folder is literally named "notebooks".
    os.chdir(notebook_dir)

100%|██████████| 500/500 [05:24<00:00,  1.54it/s]
