In [2]:
from dataset import *
from train import *
from model import *
from features import *

import pickle as pkl

import numpy as np

import torch
from torch.optim import SGD, Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_dir = "../data"




In [3]:
with open(f"{data_dir}/adata.npy", "rb") as f:
    data = np.load(f)[0].squeeze()
    print(data.shape)
with open(f"{data_dir}/adata.npy.meta.pickle", "rb") as f:
    meta = pkl.load(f)
    print(meta.keys())
    print(meta["cell_types"])

with open(f"{data_dir}/de_train.npy", "rb") as f:
    de_train = np.load(f)
    print(de_train.shape)
with open(f"{data_dir}/de_val.npy", "rb") as f:
    de_val = np.load(f)
    print(de_val.shape)

(21255, 147, 6)
dict_keys(['gene_names', 'mols', 'cell_types'])
['T cells CD4+', 'T regulatory cells', 'T cells CD8+', 'NK cells', 'B cells', 'Myeloid cells']
(18211, 146, 4)
(18211, 17, 2)


In [4]:
mtypes = meta["mols"]
mol_transforms = {
    "morgan2_fp": TransformList([Sm2Smiles("../config/sm_smiles.csv", mode="path"), Smiles2Mol(), Mol2Morgan(2048, 2)]),
    "morgan3_fp": TransformList([Sm2Smiles("../config/sm_smiles.csv", mode="path"), Smiles2Mol(), Mol2Morgan(2048, 3)]),
    "one_hot": TransformList([Type2OneHot(mtypes)])
}

ctypes = meta["cell_types"]

gene_num = len(meta["gene_names"])

In [5]:
class ExprAutoEncoder(nn.Module):

    def __init__(self, x_dim, bottle_dim):
        super().__init__()
        
        self.encoder = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(x_dim, 3000),
            nn.Tanh(),
            nn.Linear(3000, bottle_dim),
            nn.Tanh()
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottle_dim, 3000),
            nn.Tanh(),
            nn.Linear(3000, x_dim)
        )

        self.float()
    
    def forward(self, x, device):
        x = x.to(device)
        self.to(device)
        # u = torch.rand(1)
        # drop = nn.Dropout(float(u))
        # x = drop(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

class MolAutoEncoder(nn.Module):

    def __init__(self, x_dim, bottle_dim):
        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(x_dim, 500),
            nn.Tanh(),
            nn.Linear(500, bottle_dim),
            nn.Tanh()
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottle_dim, 500),
            nn.Tanh(),
            nn.Linear(500, x_dim),
            nn.Sigmoid()
        )

        self.float()
    
    def forward(self, x, device):
        x = x.to(device)
        self.to(device)
        u = float(torch.rand(1)) * 0.5 + 0.3
        drop = nn.Dropout(u)
        x = drop(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
#### LOADERS, DATA ####
dataset_train = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_val = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_train.configure(sm_out_feature="morgan2_fp", mode="train")
dataset_val.configure(sm_out_feature="morgan2_fp", mode="validation")
train_dataloader = DataLoader(dataset_train, 32)
val_dataloader = DataLoader(dataset_val, 32)

#### MODEL ####
mol_ae = MolAutoEncoder(x_dim=2048,
                        bottle_dim=30)

#### TRAINING ####
lr = 0.01
epochs = 100
device = "cuda:0"

loss_fn = nn.BCELoss(reduction="mean") # loss_mrrmse
optimizer = Adam(mol_ae.parameters(), lr=lr, weight_decay=5e-3)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

#### BATCH PROCESSING ####
def process_batch(batch):

    x_batch, (y_batch, mask_batch) = batch
    y_batch = y_batch.to(device)
    mask_batch = mask_batch.to(device)
    x_batch[1] = x_batch[1].to(device)
    x_pred = mol_ae(x_batch[1].float(), device) # TODO: Send to device the x in model?

    loss = loss_fn(x_pred, x_batch[1])
    return loss

#### TENSORBOARD ####
writer = SummaryWriter("./runs/mol_ae_rand_final/bottle30_reg5e3")

#### RUN ####
train_many_epochs(mol_ae, train_dataloader, val_dataloader, epochs, 
                  process_batch, optimizer, scheduler, writer=writer, device=device)

torch.save(mol_ae, "mol_ae.pt")

100%|██████████| 100/100 [00:43<00:00,  2.28it/s]


In [6]:
#### LOADERS, DATA ####
dataset_train = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_val = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_train.configure(sm_out_feature="morgan2_fp", mode="train")
dataset_val.configure(sm_out_feature="morgan2_fp", mode="validation")
train_dataloader = DataLoader(dataset_train, 16)
val_dataloader = DataLoader(dataset_val, 16)

#### MODEL ####
cell_ae = ExprAutoEncoder(x_dim=len(dataset_train.gene_names_control),
                          bottle_dim=100)

#### TRAINING ####
lr = 0.005
epochs = 100
device = "cuda:0"

loss_fn = loss_mrrmse
optimizer = Adam(cell_ae.parameters(), lr=lr, weight_decay=1e-4)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

#### BATCH PROCESSING ####
gi = dataset_train.gene_idx
print(gi.shape)
def process_batch(batch):

    x_batch, (y_batch, mask_batch) = batch
    y_batch = y_batch.to(device)
    mask_batch = mask_batch.to(device)
    x_batch[0] = x_batch[0].to(device)

    u = float(torch.rand(1))
    # if u > 0.9:
    #     x_pred = cell_ae(x_batch[0].float(), device) # TODO: Send to device the x in model?
    #     loss = loss_fn(x_pred, x_batch[0])
    #     print(loss)
    #     return loss
    # else: 
    y_pred = cell_ae(y_batch.float(), device)
    mnotnan = mask_batch.sum(1) > 0
    loss = loss_fn(y_pred[mnotnan], y_batch[mnotnan], mask_batch[mnotnan])
    if torch.isnan(loss):
        print(mask_batch[mnotnan])
    return loss

#### TENSORBOARD ####
writer = SummaryWriter("./runs/expr_ae_rand_newnew/bottle100_reg1e4_almosttfinal")

#### RUN ####
train_many_epochs(cell_ae, train_dataloader, val_dataloader, epochs, 
                  process_batch, optimizer, scheduler, writer=writer, device=device)
torch.save(cell_ae, "cell_ae.pt")

(21255,)


  1%|          | 1/100 [00:04<07:15,  4.39s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  2%|▏         | 2/100 [00:07<05:32,  3.40s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  3%|▎         | 3/100 [00:09<04:57,  3.07s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  4%|▍         | 4/100 [00:12<04:39,  2.91s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  5%|▌         | 5/100 [00:15<04:29,  2.83s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  6%|▌         | 6/100 [00:17<04:21,  2.78s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  7%|▋         | 7/100 [00:20<04:16,  2.76s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  8%|▊         | 8/100 [00:23<04:11,  2.74s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


  9%|▉         | 9/100 [00:25<04:07,  2.72s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


 10%|█         | 10/100 [00:28<04:05,  2.72s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


 11%|█         | 11/100 [00:31<04:01,  2.71s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


 12%|█▏        | 12/100 [00:34<03:58,  2.71s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


 13%|█▎        | 13/100 [00:36<03:55,  2.71s/it]

tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)
tensor([], device='cuda:0', size=(0, 16049), dtype=torch.bool)


 13%|█▎        | 13/100 [00:38<04:18,  2.97s/it]


KeyboardInterrupt: 

In [11]:
#### LOADERS, DATA ####
dataset_train = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_val = C2PDataset(data, meta, sm_transforms=mol_transforms)
dataset_train.configure(sm_out_feature="morgan2_fp", mode="train")
dataset_val.configure(sm_out_feature="morgan2_fp", mode="validation")
dataset_train_pval = CP2DEDataset(de_train, dataset_train)
dataset_val_pval = CP2DEDataset(de_val, dataset_val)
train_dataloader = DataLoader(dataset_train_pval, 32)
val_dataloader = DataLoader(dataset_val_pval, 32)

#### MODEL ####
pval_model = nn.Sequential(
    nn.Linear(45, 250),
    nn.Tanh(),
    nn.Linear(250, 250),
    nn.Tanh(),
    nn.Linear(250, 500),
    nn.Tanh(),
    nn.Linear(500, 2500),
    nn.Tanh(),
    nn.Linear(2500, 18211)
)
pval_model.to("cuda:0")

#### TRAINING ####
lr = 0.1
epochs = 100
device = "cuda:0"

loss_fn = loss_mrrmse
optimizer = Adam(pval_model.parameters(), lr=lr, weight_decay=5e-5)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

#### BATCH PROCESSING ####
def process_batch(batch):

    (x_batch, (y_batch, mask_batch)), final_y = batch
    y_batch = y_batch.to(device)
    final_y = final_y.to(device)
    mask_batch = mask_batch.to(device)
    x_batch[0], x_batch[1] = x_batch[0].to(device), x_batch[1].to(device)
    with torch.no_grad():
        enc_cell = cell_ae.encoder(x_batch[0].float())
        enc_mol = mol_ae.encoder(x_batch[1].float())
    y_pred = pval_model(torch.concat([enc_cell, enc_mol], dim=1)) # TODO: Send to device the x in model?

    loss = loss_fn(y_pred, final_y)
    return loss

#### TENSORBOARD ####
writer = SummaryWriter("./runs/pval_model_final/cell15_mol30_reg5e3_nplus_try4")

#### RUN ####
train_many_epochs(pval_model, train_dataloader, val_dataloader, epochs, 
                  process_batch, optimizer, scheduler, writer=writer, device=device)

100%|██████████| 100/100 [01:51<00:00,  1.11s/it]
