In [1]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
import matplotlib.colors as mcolors
from matplotlib.transforms import Bbox
from matplotlib.colors import to_rgba
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import scvi
import scanpy as sc
import anndata as an
import scanpy.external as sce
import scipy
import scipy.sparse as sp
import time
import sklearn
import torch
from scipy.sparse import csr_matrix

from importlib import reload

import ray
from ray import tune
from scvi import autotune

In [2]:
# Location of the combined AnnData file on shared scratch storage.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/full_data.h5ad"

# Read the file and make the 'counts' layer the working matrix, so that
# everything downstream that reads `adata.X` sees the same values as
# `adata.layers['counts']` (copied, so the layer itself is not aliased).
adata = sc.read_h5ad(fpath)
adata.X = adata.layers["counts"].copy()

# Report the memory footprint, then a structural summary of the object.
sc.logging.print_memory_usage()
print(adata)

Memory usage: current 6.01 GB, difference +6.01 GB
AnnData object with n_obs × n_vars = 89821 × 17397
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type', 'cell_label', 'batch'
    layers: 'counts'


In [3]:
# Partition cells by the string-valued 'batch' column in `adata.obs`:
# batch "0" becomes `rdata`, batch "1" becomes `qdata`.
# NOTE(review): the r/q prefixes presumably stand for reference/query — confirm.
batch_labels = adata.obs["batch"]
is_batch_0 = batch_labels == "0"
is_batch_1 = batch_labels == "1"

# Copy so each subset is an independent AnnData rather than a view of `adata`.
rdata = adata[is_batch_0].copy()
qdata = adata[is_batch_1].copy()

rdata

AnnData object with n_obs × n_vars = 81442 × 17397
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type', 'cell_label', 'batch'
    layers: 'counts'

In [4]:
# Keep only the top 2000 highly variable genes of `rdata`.
# `subset=True` modifies `rdata` in place (its n_vars drops to 2000), so this
# cell changes the object that the previous cell displayed.
sc.pp.highly_variable_genes(
    rdata,
    flavor="seurat_v3",
    n_top_genes=2000,
    subset=True,
)
rdata

AnnData object with n_obs × n_vars = 81442 × 2000
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type', 'cell_label', 'batch'
    var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg'
    layers: 'counts'

In [5]:
# Register the AnnData fields scVI should use: the 'counts' layer as input,
# 'dataset' as the batch covariate and 'cell_label' as the labels column.
# (This registration and the cache flush below were previously duplicated
# verbatim in this cell; one copy is sufficient.)
#
# NOTE(review): the model is set up on the full `adata` (all genes, both
# batches), not on `rdata`, which earlier cells subset to batch "0" and to
# 2000 highly variable genes. Confirm which object is intended here.
scvi.model.SCVI.setup_anndata(
    adata,
    batch_key='dataset',
    layer='counts',
    labels_key='cell_label',
)

# Release cached, unused GPU memory held by PyTorch before building the model.
torch.cuda.empty_cache()

# Build the SCVI model: layer norm in both encoder and decoder, no batch norm,
# a 10-dimensional latent space, and 2 hidden layers of 64 units each.
model = scvi.model.SCVI(
    adata,
    use_layer_norm="both",
    use_batch_norm="none",
    n_latent=10,
    n_hidden=64,
    encode_covariates=True,
    dropout_rate=0.25,
    n_layers=2,
)

model






In [9]:
# IPython help lookup (interactive aid only): prints the signature and
# docstring of scvi.train.SaveCheckpoint. Not required by any later cell.
?scvi.train.SaveCheckpoint

[0;31mInit signature:[0m
[0mscvi[0m[0;34m.[0m[0mtrain[0m[0;34m.[0m[0mSaveCheckpoint[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mdirpath[0m[0;34m:[0m [0;34m'str | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfilename[0m[0;34m:[0m [0;34m'str | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmonitor[0m[0;34m:[0m [0;34m'str'[0m [0;34m=[0m [0;34m'validation_loss'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mload_best_on_end[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m**[0m[0mkwargs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
``BETA`` Saves model checkpoints based on a monitored metric.

Inherits from :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and modifies the default
behavior to save the full model state instead of just the state dict. This is necessary for


In [10]:
# Options forwarded to the training plan via `plan_kwargs`.
plan_kwargs = dict(
    lr=0.001,
    n_epochs_kl_warmup=10,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

# Checkpoint callback: writes checkpoints under ./test/ (relative to the
# notebook's working directory) based on 'validation_loss', and is asked to
# load the best checkpoint back when training ends.
checkpointer = scvi.train.SaveCheckpoint(
    dirpath='./test/', monitor='validation_loss', load_best_on_end=True
)

# Arguments for the training run. Note that early stopping watches
# 'elbo_validation' while the checkpoint callback watches 'validation_loss'.
# NOTE(review): max_epochs=2 with a patience of 5 looks like a smoke test of
# the checkpointing path rather than a real training run — confirm.
train_options = dict(
    max_epochs=2,
    accelerator="gpu",
    devices="auto",
    enable_model_summary=True,
    batch_size=2500,
    load_sparse_tensor=True,
    plan_kwargs=plan_kwargs,
    early_stopping=True,
    early_stopping_patience=5,
    early_stopping_monitor='elbo_validation',
    enable_checkpointing=True,
    callbacks=[checkpointer],
)

model.train(**train_options)

/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/cstansbu/miniconda3/envs/scanpy/lib/python3.12 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/cstansbu/miniconda3/envs/scanpy/lib/python3.12 ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-847aa0ee-3fcc-58ba-8b00-ce2921c0a71c]

  | Name            | Type                | Params | Mode 
---------------------------------

Training:   0%|          | 0/2 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


[34mINFO    [0m File                                                                                                      
         [35m/home/cstansbu/git_repositories/scVI-trainer/notebooks/test/[0m[95mepoch[0m=[1;36m1[0m-[33mstep[0m=[1;36m66[0m-[33mvalidation_loss[0m=[1;36m8226[0m[1;36m.172851562[0m
         [1;36m5[0m/model.pt already downloaded                                                                             


  model = torch.load(model_path, map_location=map_location)


In [None]:
# A bare `break` outside any loop is a SyntaxError, so this cell always fails.
# NOTE(review): presumably left here on purpose to stop "Run All" at this
# point — confirm, and consider an explicit `raise` with a message instead.
break