In [1]:
import torch
from torch import nn
import torch.functional as F
from torch.utils.data import random_split, DataLoader
import pandas as pd
import scanpy as sc
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import lightning.pytorch as pl

from pytorch_lightning.loggers import WandbLogger
import wandb

In [18]:
class ObsDataset(pl.LightningDataModule):
    def __init__(self, file, batch_size=32, num_workers=10, train_split=0.9, toy=False):

        super().__init__()
        
        self.file = file
        self.batch_size = batch_size
        self.train_split = train_split
        self.num_workers = num_workers
        self.toy = toy
        
        self._load_data()
        
        self.n_genes = len(self.adata.var)
        self.n_cells = len(self.adata.obs)
        
    def _load_data(self):
        self.adata = sc.read_h5ad(self.file)
        if self.toy:
            sc.pp.subsample(self.adata, n_obs=1000)
            self.adata = self.adata[:, :100]
        
    def prepare_data(self, normalize=False):
        
        max_val = np.apply_along_axis(lambda x: np.percentile(x, 0.99), 0, self.adata.X)
        for i in range(self.n_genes):
            self.adata.X[self.adata.X[:, i] > max_val[i], i] = max_val[i]
            
        if normalize:
            pass
        
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            train_size = int(self.n_cells * self.train_split)
            val_size = self.n_cells - int(self.n_cells * self.train_split)
            self.train, self.val = random_split(self.adata.X, [train_size, val_size])

    def train_dataloader(self):
        return DataLoader(self.train,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)


In [19]:
class Model(pl.LightningModule):
    def __init__(self, n_genes, n_genes_in_minibatch, embed_size=10, lr=0.001, perc_hidden=0.25):
        super().__init__()
        self.n_genes = n_genes
        self.n_genes_in_minibatch = n_genes_in_minibatch
        self.embed_size = embed_size
        self.lr = lr
        self.perc_hidden = perc_hidden
        
        self.act = nn.ReLU()
        self.embedding = nn.Embedding(n_genes, embed_size)
        self.attn = nn.MultiheadAttention(self.embed_size + 1, 1, batch_first=True)
        self.ll1 = nn.Linear(self.embed_size + 1, 11)
        self.bn_attn1 = nn.BatchNorm1d(self.embed_size + 1)
        self.bn1 = nn.BatchNorm1d(self.embed_size + 1)
        self.bn_attn2 = nn.BatchNorm1d(self.embed_size + 1)
        self.ll2 = nn.Linear(self.embed_size + 1, self.embed_size + 1)
        self.bn2 = nn.BatchNorm1d(self.embed_size + 1)
        self.ll3 = nn.Linear((self.embed_size + 1) * 3, 1)
        
        self.model = nn.ModuleList(
                             [self.embedding, self.attn, self.bn_attn1,
                             self.ll1, self.bn1,
                             self.bn_attn2, self.ll2, self.bn2,
                             self.ll3]
                    )
        
        self.loss_fn = nn.MSELoss()
        
        
    def forward(self, x, gene_idxs):

        x_emb = self.embedding(gene_idxs)
        x_emb = x_emb[None, :, :].repeat((x.shape[0], 1, 1))
        x = x[:, :, None]
        x = torch.concat((x, x_emb), 2)

        attn_output1, attn_output_weights1 = self.attn(x, x, x)
        attn_output1 = self.act(attn_output1)
        attn_output1 = attn_output1.permute(0, 2, 1)
        attn_output1 = self.bn_attn1(attn_output1)
        attn_output1 = attn_output1.permute(0, 2, 1)

        attn_output1 = self.ll1(attn_output1)
        attn_output1 = self.act(attn_output1)
        attn_output1 = attn_output1.permute(0, 2, 1)
        attn_output1 = self.bn1(attn_output1)
        attn_output1 = attn_output1.permute(0, 2, 1)

        attn_output2 = torch.einsum('ijk, ijl -> ijk', attn_output1, attn_output_weights1)
        attn_output2 = self.act(attn_output2)
        attn_output2 = attn_output2.permute(0, 2, 1)
        attn_output2 = self.bn_attn2(attn_output2)
        attn_output2 = attn_output2.permute(0, 2, 1)

        attn_output2 = self.ll2(attn_output2)
        attn_output2 = self.act(attn_output2)
        attn_output2 = attn_output2.permute(0, 2, 1)
        attn_output2 = self.bn2(attn_output2)
        attn_output2 = attn_output2.permute(0, 2, 1)
        
        combined = torch.concat((x, attn_output1, attn_output2), axis=2)

        res = self.ll3(combined)
        res = res.squeeze(2)
        
        return res
        
    def training_step(self, batch, batch_idx):
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('train_loss', loss)
        return loss
        
    def validation_step(self, batch, batch_idx):
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('val_loss', loss)
        return loss
        
    def test_step(self, batch, batch_idx):
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('test_loss', loss)
        return loss
    
    def _common_step(self, batch, batch_idx):
        x = batch 
        gene_mask = self._generate_random_mask(self.n_genes, self.n_genes_in_minibatch)
        gene_idxs = torch.where(gene_mask)[0]
        x = x[:, gene_mask]

        
        hidden_mask = self._generate_random_mask(self.n_genes_in_minibatch, 
                                    int(self.n_genes_in_minibatch * self.perc_hidden))
        x_corrupted = torch.clone(x)
        x_corrupted[:, hidden_mask] = 0.
        
        preds = self.forward(x_corrupted, gene_idxs)
        loss = self.loss_fn(preds, x)
        return loss, preds, x
    
    def predict_step(self, batch, batch_idx):
        x, y = batch 
        scores = self.forward(x)
        preds = torch.argmax(scores, dim=1)
        return preds
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
    
    def _generate_random_mask(self, total_size, sample_size):
        mask = torch.full((total_size,), False, dtype=bool)
        mask[:sample_size] = True
        mask = mask[torch.randperm(total_size)].to(self.device)
        return mask
    

In [4]:
wandb_logger = WandbLogger(name='Adam-32-0.001', project='causal_dev')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbakulin[0m ([33mml-enthusiast[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [22]:
dataset = '/home/artemy/causal_proj/data/competition_data/datasets/k562.h5ad'
data = ObsDataset(dataset, toy=True)

In [23]:
N_SAMPLE_GENES = 100
LR = 0.001
PERC_HIDDEN = 0.25

model = Model(data.n_genes, N_SAMPLE_GENES, lr=LR, perc_hidden=PERC_HIDDEN)

In [24]:
ACCELERATOR = 'gpu'
DEVICES = [0]
NUM_EPOCHS = 200

torch.set_float32_matmul_precision('medium')


trainer = pl.Trainer(
        logger=wandb_logger,
        accelerator=ACCELERATOR,
        devices=DEVICES,
        min_epochs=1,
        max_epochs=NUM_EPOCHS,
        log_every_n_steps=29
    )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [25]:
trainer.fit(model, data)

  self.adata.X[self.adata.X[:, i] > max_val[i], i] = max_val[i]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type               | Params
--------------------------------------------------
0  | act       | ReLU               | 0     
1  | embedding | Embedding          | 1.0 K 
2  | attn      | MultiheadAttention | 528   
3  | ll1       | Linear             | 132   
4  | bn_attn1  | BatchNorm1d        | 22    
5  | bn1       | BatchNorm1d        | 22    
6  | bn_attn2  | BatchNorm1d        | 22    
7  | ll2       | Linear             | 132   
8  | bn2       | BatchNorm1d        | 22    
9  | ll3       | Linear             | 34    
10 | model     | ModuleList         | 1.9 K 
11 | loss_fn   | MSELoss            | 0     
--------------------------------------------------
1.9 K     Trainable params
0         Non-trainable params
1.9 K     Total params
0.008     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.
