In [1]:
from dataset import *
from train import *
from model import *
from features import *

import pickle as pkl

import numpy as np

import torch
from torch.optim import SGD, Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Base directory for the input files opened below; the relative path is
# resolved against the notebook's working directory.
data_dir = "../data"




In [2]:
# Main array: only entry 0 along the leading axis is kept, then singleton
# dimensions are squeezed out. (Axis meanings are not visible here — TODO confirm.)
with open(f"{data_dir}/adata.npy", "rb") as adata_file:
    data = np.load(adata_file)[0].squeeze()
print(data.shape)

# Metadata dictionary stored alongside the array.
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted source.
with open(f"{data_dir}/adata.npy.meta.pickle", "rb") as meta_file:
    meta = pkl.load(meta_file)
print(meta.keys())
print(meta["cell_types"])

# Train / validation target arrays (presumably differential expression,
# judging by the "de_" prefix — confirm against the data preparation code).
with open(f"{data_dir}/de_train.npy", "rb") as de_train_file:
    de_train = np.load(de_train_file)
print(de_train.shape)
with open(f"{data_dir}/de_val.npy", "rb") as de_val_file:
    de_val = np.load(de_val_file)
print(de_val.shape)

(21255, 147, 6)
dict_keys(['gene_names', 'mols', 'cell_types'])
['T cells CD4+', 'T regulatory cells', 'T cells CD8+', 'NK cells', 'B cells', 'Myeloid cells']
(18211, 146, 4)
(18211, 17, 2)


In [3]:
mtypes = meta["mols"]


def _morgan_pipeline(radius):
    """Return a fresh Sm2Smiles -> Smiles2Mol -> Mol2Morgan(2048, radius) chain.

    A new instance of every transform is created on each call, so the
    pipelines do not share objects. ``radius`` is forwarded as the second
    positional argument of ``Mol2Morgan`` (presumably the fingerprint radius,
    given the "morgan2"/"morgan3" key names — confirm against Mol2Morgan).
    """
    return TransformList([
        Sm2Smiles("../config/sm_smiles.csv", mode="path"),
        Smiles2Mol(),
        Mol2Morgan(2048, radius),
    ])


# Named feature extractors that the datasets can select via `sm_out_feature`.
mol_transforms = {
    "morgan2_fp": _morgan_pipeline(2),
    "morgan3_fp": _morgan_pipeline(3),
    "one_hot": TransformList([Type2OneHot(mtypes)]),
}

ctypes = meta["cell_types"]

gene_num = len(meta["gene_names"])

In [4]:
# dataset_train = C2PDataset(data, meta, sm_transforms=mol_transforms)
# dataset_val = C2PDataset(data, meta, sm_transforms=mol_transforms)
# dataset_train.configure(sm_out_feature="morgan2_fp", mode="train", return_y=False)
# dataset_val.configure(sm_out_feature="morgan2_fp", mode="validation")
# d = CP2DEDataset(de_train, dataset_train)

In [5]:
#### LOADERS, DATA ####
# Two separately configured C2PDataset objects over the same arrays: one in
# "train" mode, one in "validation" mode, both emitting the "morgan2_fp" feature.
dataset_train = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_val = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_train.configure(sm_out_feature="morgan2_fp", mode="train")
dataset_val.configure(sm_out_feature="morgan2_fp", mode="validation")
# NOTE(review): neither loader shuffles (DataLoader default) — confirm that
# an unshuffled training loader is intended.
train_dataloader = DataLoader(dataset_train, batch_size=32)
val_dataloader = DataLoader(dataset_val, batch_size=32)

#### MODEL ####
# Input and output widths are both the number of control genes.
# pert_dim=2048 presumably has to match the 2048 given to Mol2Morgan for the
# "morgan2_fp" feature selected above — keep the two in sync.
n_control_genes = len(dataset_train.gene_names_control)
model = Control2Pert(in_dim=n_control_genes, pert_dim=2048,
                     bottle_dim=500, out_dim=n_control_genes)

#### TRAINING ####
lr = 0.01
epochs = 40
# Use the first GPU when available; otherwise fall back to the CPU instead of
# failing on the first `.to(device)` call.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

loss_fn = loss_mrrmse
optimizer = Adam(model.parameters(), lr=lr, weight_decay=2e-4)
# Multiply the learning rate by 0.8 once the monitored (minimised) value has
# stopped improving for longer than `patience` scheduler steps.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

#### BATCH PROCESSING ####
def process_batch(batch):
    """Compute the masked loss of `model` on one batch.

    `batch` is unpacked as ``(x_batch, (y_batch, mask_batch))``. The target
    and the mask are moved to the notebook-level `device`; the elements of
    `x_batch` are passed positionally to `model`, followed by `device`.
    `model`, `device` and `loss_fn` are looked up as globals at call time.

    Returns whatever ``loss_fn(prediction, target, mask)`` returns.
    """
    inputs, (targets, mask) = batch
    targets = targets.to(device)
    mask = mask.to(device)
    # TODO: Send to device the x in model?
    prediction = model(*inputs, device)
    return loss_fn(prediction, targets, mask)

#### TENSORBOARD ####
writer = SummaryWriter("./runs/final_test7/initial_model")

#### RUN ####
try:
    train_many_epochs(model, train_dataloader, val_dataloader, epochs,
                      process_batch, optimizer, scheduler, writer=writer, device=device)
finally:
    # Flush and release the event file even when training is interrupted;
    # otherwise buffered scalars can be lost once `writer` is rebound.
    # Closing an already closed SummaryWriter is a no-op.
    writer.close()

100%|██████████| 40/40 [01:07<00:00,  1.68s/it]


In [6]:
#### LOADERS, DATA ####
# Wrap the stage-1 datasets together with the train / validation target arrays.
dataset_final_train = CP2DEDataset(de_train, dataset_train)
dataset_final_val = CP2DEDataset(de_val, dataset_val)
# NOTE: these rebind the loader names used by the first training stage.
train_dataloader = DataLoader(dataset_final_train, batch_size=32)
val_dataloader = DataLoader(dataset_final_val, batch_size=32)

#### MODEL ####
# 500 + 2048 presumably equals bottle_dim + pert_dim of the pretrained
# Control2Pert `model` built earlier — keep these numbers in sync.
# The output width is the first dimension of `de_train`.
model_final = ContPert2DE(model, 500+2048, # len(dataset_train.gene_names_control)
                          layers_sizes=[1000, 1000], out_dim=de_train.shape[0])
# Cast the model that is actually trained below. Module.float() recurses into
# registered sub-modules, so the wrapped `model` is covered as well (assuming
# ContPert2DE registers it as a sub-module — confirm).
model_final.float()

#### TRAINING ####
lr = 0.05
epochs = 100
# Use the first GPU when available; otherwise fall back to the CPU instead of
# failing on the first `.to(device)` call.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

loss_fn = loss_mrrmse
optimizer = Adam(model_final.parameters(), lr=lr, weight_decay=2e-3)
# Multiply the learning rate by 0.7 once the monitored (minimised) value has
# stopped improving for longer than `patience` scheduler steps.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.7, patience=5)

#### BATCH PROCESSING ####
def process_batch(batch):
    """Compute the loss of `model_final` on one batch.

    `batch` is unpacked as ``(((x_batch, pert_batch), _), y_batch)``; the
    ignored element is not used here. The target is moved to the
    notebook-level `device`; the two inputs are handed to `model_final`
    together with `device`. `model_final`, `device` and `loss_fn` are looked
    up as globals at call time.

    Returns whatever ``loss_fn(prediction, target)`` returns.
    """
    ((controls, perturbations), _), targets = batch
    targets = targets.to(device)
    # TODO: Send to device the x in model?
    prediction = model_final(controls, perturbations, device)
    return loss_fn(prediction, targets)

#### TENSORBOARD ####
writer = SummaryWriter("./runs/final_test7/final_model")

#### RUN ####
try:
    train_many_epochs(model_final, train_dataloader, val_dataloader, epochs,
                      process_batch, optimizer, scheduler, writer=writer, device=device)
finally:
    # Flush and release the event file even when training is interrupted, so
    # no buffered scalars are lost. Closing an already closed SummaryWriter
    # is a no-op.
    writer.close()

100%|██████████| 100/100 [01:13<00:00,  1.36it/s]


In [7]:
# Module.parameters() returns a one-shot iterator: every next(p) advances it,
# so a cell calling next(p) yields a different tensor each time it is re-run.
p=model_final.parameters()

In [8]:
# Shape of the first parameter tensor of `model_final`. A fresh iterator is
# taken on every execution so that re-running this cell always shows the same
# tensor instead of advancing an iterator shared across cells.
next(model_final.parameters()).shape

torch.Size([2000, 16049])