In [1]:
import torch
from torch import nn
from torchsummary import summary

import random

import sys
sys.path.append("../..")

data_dir = "../data"

import pandas as pd
import numpy as np
from pyarrow.parquet import ParquetFile

from tqdm import tqdm

In [3]:
class Gene2Gene(nn.Module):
    """Small MLP scoring a (gene, gene, cell) triple plus extra per-sample features.

    Both genes share one embedding table; the cell has its own. The two gene
    embeddings, the cell embedding and ``x`` are concatenated along dim 1 and
    passed through ``Linear -> ReLU -> Linear`` to give one output per sample.

    Args:
        gene_num: number of rows in the gene embedding table.
        cell_num: number of rows in the cell embedding table.
        d_embed_gene: gene embedding width.
        d_embed_cell: cell embedding width.
        d_extra: width of the extra feature tensor ``x`` given to ``forward``
            (default 3, the value the original layer size hard-coded).
        d_hidden: hidden layer width (default 20).
    """

    def __init__(self, gene_num, cell_num, d_embed_gene, d_embed_cell, d_extra=3, d_hidden=20):
        super().__init__()

        self.gene_num = gene_num
        self.cell_num = cell_num
        self.d_gene = d_embed_gene
        self.d_cell = d_embed_cell
        self.d_extra = d_extra
        self.cell_emb = nn.Embedding(cell_num, d_embed_cell)
        self.gene_emb = nn.Embedding(gene_num, d_embed_gene)

        # Input width must equal what forward() concatenates: two gene
        # embeddings, one cell embedding, and the extra features.
        self.layer_1 = nn.Linear(2 * d_embed_gene + d_embed_cell + d_extra, d_hidden)
        self.layer_2 = nn.Linear(d_hidden, 1)

    def forward(self, gene1_idx, gene2_idx, cell_idx, x):
        """Score a batch of (gene1, gene2, cell) triples.

        Args:
            gene1_idx, gene2_idx: integer index tensors into the gene table.
            cell_idx: integer index tensor into the cell table.
            x: float tensor of extra features; concatenated on dim 1, so it is
                expected to be (batch, d_extra) when the indices are (batch,).

        Returns:
            Tensor of shape (batch, 1).
        """
        gene1_embed = self.gene_emb(gene1_idx)
        gene2_embed = self.gene_emb(gene2_idx)
        cell_embed = self.cell_emb(cell_idx)
        h = torch.cat([gene1_embed, gene2_embed, cell_embed, x], dim=1)

        # torch.relu applies the activation; nn.ReLU(tensor) would instead
        # construct a module with the tensor as its `inplace` flag.
        return self.layer_2(torch.relu(self.layer_1(h)))

class MultiomeDataset(torch.utils.data.Dataset):
    """Gene pairs drawn from the multiome training table.

    Each item is a pair of gene names (the unique ``location`` values whose
    ``feature_type`` is "Gene Expression") together with the rows of
    ``multiome_df`` for both genes, restricted to observations in which both
    genes appear.

    Usage: construct, call ``load()`` to read the parquet file, then
    ``combine()`` to build the shuffled pair index before ``len()``/indexing.
    """

    def __init__(self, multiome_pf, multiome_var, multiome_obs):
        """
        Args:
            multiome_pf: parquet file handle exposing ``metadata.num_rows`` and
                ``iter_batches(batch_size=...)`` (e.g. ``pyarrow.parquet.ParquetFile``).
            multiome_var: variable metadata with ``feature_type`` and ``location`` columns.
            multiome_obs: observation metadata with ``obs_id`` and ``cell_type`` columns.
        """
        super().__init__()

        self.raw_pf = multiome_pf
        self.var_df = multiome_var
        self.gene_names = self.var_df[self.var_df["feature_type"] == "Gene Expression"]["location"].unique()
        self.obs_df = multiome_obs

    def load(self, cell_names="all", batch_num=10):
        """Read the parquet file in roughly ``batch_num`` chunks, keeping wanted cells and genes.

        Args:
            cell_names: "all", or a list-like of ``cell_type`` values to keep.
            batch_num: approximate number of chunks to read the file in.

        Sets ``self.multiome_df`` to the rows whose ``obs_id`` belongs to a wanted
        cell type and whose ``location`` is a gene, with a fresh RangeIndex.
        """
        # isinstance guard: `array == "all"` would be elementwise/ambiguous.
        if isinstance(cell_names, str) and cell_names == "all":
            cell_names = self.obs_df["cell_type"].unique()

        wanted_obs_ids = set(
            self.obs_df.loc[self.obs_df["cell_type"].isin(cell_names), "obs_id"]
        )
        wanted_genes = set(self.gene_names)

        num_rows = self.raw_pf.metadata.num_rows
        # max(1, ...) avoids a zero batch size when num_rows < batch_num.
        batch_size = max(1, num_rows // batch_num)
        n_batches = -(-num_rows // batch_size)  # ceil division, for the progress bar

        # Collect filtered chunks and concatenate once; concatenating inside the
        # loop re-copies the accumulated frame on every iteration.
        chunks = []
        for batch in tqdm(self.raw_pf.iter_batches(batch_size=batch_size), total=n_batches):
            chunk = batch.to_pandas()
            keep = chunk["obs_id"].isin(wanted_obs_ids) & chunk["location"].isin(wanted_genes)
            chunks.append(chunk[keep])

        # ignore_index=True already yields a clean RangeIndex, so no extra
        # reset_index (which would add a spurious "index" column).
        self.multiome_df = (
            pd.concat(chunks, axis=0, ignore_index=True) if chunks else pd.DataFrame()
        )

    def combine(self):
        """Build a shuffled index over all ordered gene pairs (self-pairs included)."""
        self.combined_idx = np.arange(len(self.gene_names) ** 2)
        np.random.shuffle(self.combined_idx)

    def choose(self):
        """Draw two gene names at random (with replacement, so they may coincide)."""
        chosen = np.random.choice(self.gene_names, 2)
        return chosen[0], chosen[1]

    def __len__(self):
        """Number of gene pairs; requires ``combine()`` to have been called."""
        return self.combined_idx.shape[0]

    def _decode_pair(self, flat_idx):
        """Map a flat pair index (row-major over gene_names) to (gene1, gene2)."""
        i1, i2 = divmod(int(flat_idx), len(self.gene_names))
        return self.gene_names[i1], self.gene_names[i2]

    def _pair_rows(self, gene1, gene2):
        """Return (gene1, gene2, rows1, rows2), rows indexed by obs_id and restricted to shared obs_ids."""
        df = self.multiome_df
        rows1 = df[df["location"] == gene1].set_index("obs_id")
        rows2 = df[df["location"] == gene2].set_index("obs_id")
        # Index.intersection keeps rows1's order, unlike list(set & set).
        common = rows1.index.intersection(rows2.index)
        return gene1, gene2, rows1.loc[common], rows2.loc[common]

    def __getitem__(self, idx):
        """Return the gene pair(s) at ``idx`` in the shuffled pair order.

        A scalar ``idx`` yields one ``(gene1, gene2, rows1, rows2)`` tuple; an
        array or slice ``idx`` yields a list of such tuples.

        TODO: convert to model-ready tensors once the value column and feature
        layout expected by the model are settled.
        """
        out_idx = self.combined_idx[idx]
        if np.ndim(out_idx) == 0:
            return self._pair_rows(*self._decode_pair(out_idx))
        return [self._pair_rows(*self._decode_pair(i)) for i in out_idx]

In [4]:
# Metadata tables (CSV) and the raw training table (parquet).
multiome_obs_meta, multiome_var_meta = (
    pd.read_csv(path)
    for path in ("../../data/multiome_obs_meta.csv", "../../data/multiome_var_meta.csv")
)
multiome_file = ParquetFile("../../data/multiome_train.parquet")

# Build the dataset and read the parquet file in chunks (slow: ~2 min in the recorded run).
data = MultiomeDataset(multiome_file, multiome_var_meta, multiome_obs_meta)
data.load()

100%|██████████| 11/11 [02:03<00:00, 11.26s/it]


In [9]:
# Scratch check: intersecting a set with itself returns that same set.
l = list(range(4))
set(l) & set(l)

{0, 1, 2, 3}

In [None]:
import pandas as pd

from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from dataset import *
from features import *
from train import *
from model import *

array([0, 1])

In [None]:
import os

# Run from the parent directory (as the original cell did) and always restore
# the original working directory afterwards, even if training raises.
_notebook_dir = os.getcwd()
os.chdir("..")
try:
    #### LOADERS, DATA ####
    BATCH_SIZE = 32
    de_dataset_train.configure(cell_out_feature="one_hot", sm_out_feature="morgan2_fp")
    de_dataset_val.configure(cell_out_feature="one_hot", sm_out_feature="morgan2_fp")
    train_dataloader = DataLoader(de_dataset_train, BATCH_SIZE)
    val_dataloader = DataLoader(de_dataset_val, BATCH_SIZE)

    #### MODEL ####
    # NOTE(review): mol_in=2048 presumably matches the morgan2_fp length, and
    # "-5" presumably drops non-target columns of de_df — confirm both.
    model = BaselineModel(cell_in=len(ctypes), mol_in=2048, out_size=len(de_df.columns)-5)

    #### TRAINING ####
    lr = 0.02
    epochs = 500
    # Fall back to CPU so the cell still runs on machines without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    loss_fn = loss_mrrmse
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=2e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=7)

    #### TENSORBOARD ####
    writer = SummaryWriter("./runs/trying_out2/2")

    #### RUN ####
    try:
        train_many_epochs(model, train_dataloader, val_dataloader, epochs,
                          loss_fn, optimizer, scheduler, writer=writer, device=device)
    finally:
        # Flush pending events to disk; SummaryWriter.close ignores a double close.
        writer.close()
finally:
    os.chdir(_notebook_dir)