In [1]:
from dataset import *
from train import *
from model import *
from features import *

import pickle as pkl

import numpy as np

import torch
from torch.optim import SGD, Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Relative path of the directory that holds the input array (adata.npy) and
# its pickled metadata file (adata.npy.meta.pickle).
data_dir = "../data"




In [2]:
# Read the array: keep the first entry along the leading axis and drop any
# singleton axes that remain.
with open(f"{data_dir}/adata.npy", "rb") as npy_file:
    data = np.load(npy_file)[0].squeeze()
print(data.shape)

# Read the accompanying pickled metadata and show what it contains.
with open(f"{data_dir}/adata.npy.meta.pickle", "rb") as meta_file:
    meta = pkl.load(meta_file)
print(meta.keys())
print(meta["cell_types"])

(21255, 147, 6)
dict_keys(['gene_names', 'mols', 'cell_types'])
['T cells CD4+', 'T regulatory cells', 'T cells CD8+', 'NK cells', 'B cells', 'Myeloid cells']


In [3]:
# Values pulled from the metadata dictionary.
mtypes = meta["mols"]

# Molecule feature pipelines, keyed by the name that is later passed as
# `sm_out_feature`. The two Morgan variants run the same chain and differ only
# in the second Mol2Morgan argument (presumably the radius — TODO confirm).
mol_transforms = {
    feature_key: TransformList([
        Sm2Smiles("../config/sm_smiles.csv", mode="path"),
        Smiles2Mol(),
        Mol2Morgan(2048, morgan_arg),
    ])
    for feature_key, morgan_arg in (("morgan2_fp", 2), ("morgan3_fp", 3))
}
mol_transforms["one_hot"] = TransformList([Type2OneHot(mtypes)])

ctypes = meta["cell_types"]

# Number of entries in the gene name list.
gene_num = len(meta["gene_names"])

In [4]:
# Two separate dataset objects built from the same array, metadata and
# transforms; each one is given its own mode later through `configure(...)`.
dataset_train, dataset_val = (
    C2PDataset(data, meta, sm_transforms=mol_transforms) for _ in range(2)
)

In [6]:
#### LOADERS, DATA ####
# Both datasets emit the same molecule feature; they differ only in the mode
# they are configured with.
dataset_train.configure(sm_out_feature="morgan2_fp", mode="train")
dataset_val.configure(sm_out_feature="morgan2_fp", mode="validation")

batch_size = 32
# NOTE(review): the training loader does not pass shuffle=True — confirm that
# the dataset takes care of sample ordering itself.
train_dataloader = DataLoader(dataset_train, batch_size)
val_dataloader = DataLoader(dataset_val, batch_size)

#### MODEL ####
# NOTE(review): pert_dim=2048 presumably has to match the 2048 passed to
# Mol2Morgan for the "morgan2_fp" feature — confirm before changing either.
model = Control2Pert(in_dim=len(dataset_train.gene_names_control), pert_dim=2048,
                     bottle_dim=500, out_dim=len(meta["gene_names"]))

#### TRAINING ####
lr = 0.01
epochs = 100
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the cell does not fail on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

loss_fn = loss_mrrmse
optimizer = Adam(model.parameters(), lr=lr, weight_decay=2e-4)
# Multiply the learning rate by 0.8 once the monitored value has stopped
# decreasing for more than 7 scheduler steps.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

#### TENSORBOARD ####
writer = SummaryWriter("./runs/test13_plusmoldrop")

#### RUN ####
train_many_epochs(model, train_dataloader, val_dataloader, epochs,
                  loss_fn, optimizer, scheduler, writer=writer, device=device)

# Push any buffered TensorBoard events to disk now that training is done.
writer.flush()

100%|██████████| 100/100 [03:17<00:00,  1.98s/it]
