In [28]:
import torch 
import numpy as np
import scanpy as sc
from torch.utils.data import random_split
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from scvi.distributions import NegativeBinomial

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch.nn.functional as F

import sys
sys.path.insert(0, "/home/icb/alessandro.palma/environment/scportrait_ot/src")
from dataloader import EmbeddingDecoderDataset
from decoding_modules import DecoderFromHarmony
from pathlib import Path

### Initialize dataset

Torch dataset 

In [2]:
dataset = EmbeddingDecoderDataset(adata_path="/lustre/groups/ml01/workspace/alessandro.palma/scportrait/data/scrnaseq/sce_converted_processed_discovery.h5ad", 
                                     count_label="X_counts", 
                                     embedding_label="HARMONY",
                                     batch_label="donor_id")

Initialize dataloaders

In [3]:
train_data, valid_data = random_split(dataset,
                                      lengths=[0.80, 0.20])   
        
train_dataloader = torch.utils.data.DataLoader(train_data,
                                                batch_size=256,
                                                shuffle=True,
                                                num_workers=4)
        
valid_dataloader = torch.utils.data.DataLoader(valid_data,
                                                batch_size=256,
                                                shuffle=False,
                                                num_workers=4)

### Initialize model 

In [4]:
decoder_model = DecoderFromHarmony(input_dim=50, 
                                   output_dim=55,
                                   dims=[64, 64],
                                   batch_norm=False, 
                                   dropout=False,
                                   dropout_p=0.0, 
                                   batch_encoding=False, 
                                   batch_encoding_dim=None,
                                   learning_rate=1e-3
                                  )

In [5]:
decoder_model

DecoderFromHarmony(
  (decoder): MLP(
    (net): Sequential(
      (0): Linear(in_features=50, out_features=64, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ELU(alpha=1.0)
      (4): Linear(in_features=64, out_features=55, bias=True)
    )
  )
)

### Training setup

In [6]:
training_dir = "/lustre/groups/ml01/workspace/alessandro.palma/scportrait/experiements"

logger = WandbLogger(offline=False,
                     anonymous=None,
                     project="harmony_decoder",
                     log_model=False,
                     save_dir=training_dir
                    )

# Use wandb run name to create a subfolder
run_name = "harmony_decoder"
run_dir = Path("/lustre/groups/ml01/workspace/alessandro.palma/scportrait/experiements") / run_name
run_dir.mkdir(parents=True, exist_ok=True)

# Callbacks for saving checkpoints in the run-specific folder
checkpoint_callback = ModelCheckpoint(dirpath=run_dir / "checkpoints",
                                      filename="epoch_{epoch:01d}",
                                      monitor="valid/loss",
                                      mode="min",               
                                      every_n_epochs=50,
                                      save_last=True,
                                      auto_insert_metric_name=False
                                     )
callbacks = [checkpoint_callback]

# Initialize trainer with custom dir
trainer = Trainer(
    callbacks=callbacks,
    default_root_dir=run_dir,
    logger=logger,
    max_epochs=20,
    accelerator="gpu",
    devices=1,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    detect_anomaly=True,
    deterministic=False,
    gradient_clip_val=1)

/home/icb/alessandro.palma/miniconda3/envs/sc_exp_design/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/alessandro.palma/miniconda3/envs/sc_exp_de ...
You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


## Training 

In [7]:
trainer.fit(decoder_model,
            train_dataloaders=train_dataloader,
            val_dataloaders=valid_dataloader)

You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mallepalma[0m ([33minverse-perturbation-models[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/icb/alessandro.palma/miniconda3/envs/sc_exp_design/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /ictstr01/groups/ml01/workspace/alessandro.palma/scportrait/experiements/harmony_decoder/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-cc9abd83-6b36-5bab-9380-fc37aeddff07]

  | Name         | Type | Params | Mode 
----------------------------------------------
0 | decoder      | MLP  | 11.0 K | train
  | other params | n/a  | 55     | n/a  
----------------------------------------------
11.1 K    Trainable params
0         Non-trainable params
11.1 K    Total params
0.044     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


## Check reconstruction  

In [25]:
Xs = []
X_hat = []

with torch.no_grad():
    for batch in train_dataloader:
        Xs.append(batch["X"])
        size_factor = batch["X"].sum(1, keepdim=True)
        mu_hat = decoder_model.decoder(batch["X_emb"])
        mu_hat = F.softmax(mu_hat, dim=1)
        px = NegativeBinomial(mu=mu_hat * size_factor, theta=torch.exp(decoder_model.theta))
        X_hat.append(px.sample())

In [26]:
X_hat = torch.cat(X_hat, dim=0).cpu().detach().numpy()
Xs = torch.cat(Xs, dim=0).cpu().detach().numpy()

In [29]:
adata_generated = sc.AnnData(X=np.concatenate([Xs, X_hat]))

In [None]:
sc.tl.pca(adata_generated)
sc.pp.neighbors(adata_generated)
sc.tl.umap(adata_generated)

In [None]:
sc.pl.umap(adata_generated)

In [9]:
# adata = sc.read_h5ad("/lustre/groups/ml01/workspace/alessandro.palma/scportrait/data/scrnaseq/sce_converted_processed_discovery.h5ad")

In [None]:
# sc.tl.pca(adata)
# sc.pp.neighbors(adata)
# sc.tl.umap(adata)

In [None]:
# sc.pl.umap(adata, color="donor_id")

In [12]:
adata.obsm["HARMONY"].shape

(263286, 50)