In [1]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
import matplotlib.colors as mcolors
from matplotlib.transforms import Bbox
from matplotlib.colors import to_rgba
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import scvi
import scanpy as sc
import anndata as an
import scanpy.external as sce
import scipy
import scipy.sparse as sp
import time
import sklearn
import torch
from scipy.sparse import csr_matrix

from importlib import reload

# local imports
import utils as ut
import plotting as plt2

# Notebook-wide logging / numerics configuration.
SCANPY_VERBOSITY = 3
MATMUL_PRECISION = "high"

sc.settings.verbosity = SCANPY_VERBOSITY
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [2]:
# Show the CUDA version string exposed by the installed PyTorch build.
cuda_build_version = torch.version.cuda
print(cuda_build_version)

12.0


In [3]:
# Number of CPUs reported by the operating system.
num_processors = os.cpu_count()
print("Number of processors: {}".format(num_processors))

Number of processors: 64


In [4]:
# Report whether PyTorch can reach a CUDA device and, if so, describe each one.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

if not cuda_available:
    print("CUDA not available. Running on CPU.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")

    bytes_per_gib = 1024**3
    for device_index in range(num_gpus):
        print(f"GPU {device_index}: {torch.cuda.get_device_name(device_index)}")

        # Compute capability and total device memory for this GPU
        props = torch.cuda.get_device_properties(device_index)
        print(f"  Compute Capability: {props.major}.{props.minor}")
        print(f"  Total Memory: {props.total_memory / bytes_per_gib:.2f} GB")

CUDA available: True
Number of GPUs: 1
GPU 0: NVIDIA A100 80GB PCIe MIG 3g.40gb
  Compute Capability: 8.0
  Total Memory: 39.25 GB


In [5]:
# List the devices JAX can use in this environment.
import jax

jax_devices = jax.devices()
print(jax_devices)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[CpuDevice(id=0)]


# Load data

In [6]:
"""
DATA
"""
# Read the combined AnnData file and make X hold the values of the 'counts' layer.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/raw_data.h5ad"
adata = sc.read_h5ad(fpath)

counts_layer = adata.layers['counts']
adata.X = counts_layer.copy()

sc.logging.print_memory_usage()
print(adata)

Memory usage: current 8.75 GB, difference +8.75 GB
AnnData object with n_obs × n_vars = 171498 × 18867
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End'
    layers: 'counts'


In [7]:
# Number of cells per harmonized cell-type label.
cell_type_counts = adata.obs['standard_cell_type'].value_counts()
cell_type_counts

standard_cell_type
Fib               34289
HSC               20789
T_cell            19303
LinNeg            13396
B_cell            10734
EryP               9902
MPP                9391
iHSC               8379
MEP                7433
Mono               6988
GMP                4882
NK                 4189
MDP                3989
MKP                3980
CLP                3640
CMP                2804
Dendritic_cell     2694
LMPP                805
PreBNK              554
MLP                 123
Name: count, dtype: int64

In [8]:
# Distinct values of the harmonized label column (missing values show up here too).
observed_cell_types = adata.obs['standard_cell_type'].unique()
observed_cell_types

['PreBNK', 'LinNeg', 'HSC', 'CMP', 'MEP', ..., NaN, 'LMPP', 'Dendritic_cell', 'Fib', 'iHSC']
Length: 21
Categories (20, object): ['B_cell', 'CLP', 'CMP', 'Dendritic_cell', ..., 'NK', 'PreBNK', 'T_cell', 'iHSC']

# Cell type filtering

In [9]:
# Cell types retained for the rest of the analysis.
cell_types = [
    'HSC', 'CMP', 'MEP', 'MPP',
    'GMP', 'EryP', 'MDP', 'MKP',
    'CLP', 'LMPP', 'Fib', 'iHSC',
]

# Keep cells whose label is present AND belongs to the list above.
labels = adata.obs['standard_cell_type']
keep_mask = labels.notna() & labels.isin(cell_types)
adata = adata[keep_mask, :].copy()

adata.obs['standard_cell_type'].value_counts()

standard_cell_type
Fib     34289
HSC     20789
EryP     9902
MPP      9391
iHSC     8379
MEP      7433
GMP      4882
MDP      3989
MKP      3980
CLP      3640
CMP      2804
LMPP      805
Name: count, dtype: int64

# Preprocessing

In [10]:
# Quality-control thresholds; both filters modify `adata` in place.
MIN_GENES_PER_CELL = 500
MIN_CELLS_PER_GENE = 250

sc.pp.filter_cells(adata, min_genes=MIN_GENES_PER_CELL)
sc.pp.filter_genes(adata, min_cells=MIN_CELLS_PER_GENE)

adata

filtered out 1470 genes that are detected in less than 250 cells


AnnData object with n_obs × n_vars = 110283 × 17397
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End', 'n_cells'
    layers: 'counts'

# Make reference and query data

In [11]:
# Datasets that form the labelled reference ...
reference_data = [
    'tabula_sapiens',
    'weng_young2_all',
    'sc_fib',
    'weng_young1_all_t2',
    'weng_young1_all_t1',
]

# ... and the dataset to be mapped onto it.
query_data = ['iHSC']

# Split on the per-cell 'dataset' column; each subset is an independent copy.
dataset_of_cell = adata.obs['dataset']
rdata = adata[dataset_of_cell.isin(reference_data), :].copy()
qdata = adata[dataset_of_cell.isin(query_data), :].copy()

print(rdata)
print()
print(qdata)

AnnData object with n_obs × n_vars = 81442 × 17397
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End', 'n_cells'
    layers: 'counts'

AnnData object with n_obs × n_vars = 8379 × 17397
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End', 'n_cells'
    layers: 'counts'


In [12]:
def _as_clean_label(value):
    """Return `value` as a string with surrounding whitespace removed."""
    return str(value).strip()


# Reference cells keep their curated label as the training label.
rdata.obs['cell_label'] = rdata.obs['standard_cell_type'].apply(_as_clean_label)
rdata.obs['cell_label'].value_counts()

cell_label
Fib     34289
HSC     15110
MPP      7683
EryP     5889
MEP      5234
CLP      3345
MKP      3147
GMP      2656
MDP      2091
CMP      1193
LMPP      805
Name: count, dtype: int64

In [13]:
# Every query cell gets the same placeholder label.
UNLABELED_CATEGORY = 'Unknown'
qdata.obs['cell_label'] = UNLABELED_CATEGORY
qdata.obs['cell_label'].value_counts()

cell_label
Unknown    8379
Name: count, dtype: int64

# Store the Data

In [14]:
# Stack reference and query cells and persist the result.
# `label="batch"` adds an obs column recording which input each cell came from.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/full_data.h5ad"
parts = [rdata, qdata]
fdata = an.concat(parts, label="batch")
fdata.write(fpath)
fdata

AnnData object with n_obs × n_vars = 89821 × 17397
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type', 'cell_label', 'batch'
    layers: 'counts'

# Feature Selection

In [15]:
# Number of highly variable genes to keep; also used later in output file names.
n_genes = 5000

# subset=True shrinks `rdata` in place to the selected genes.
hvg_kwargs = dict(
    n_top_genes=n_genes,
    flavor="seurat_v3",
    subset=True,
)
sc.pp.highly_variable_genes(rdata, **hvg_kwargs)

rdata

extracting highly variable genes
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)


AnnData object with n_obs × n_vars = 81442 × 5000
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'obs_index', 'cell_type', 'standard_cell_type', 'cell_label'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg'
    layers: 'counts'

In [16]:
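# NOTE: disabled alternative, kept for reference only — batch-aware HVG
# selection with 3000 genes. This cell is intentionally not executed.
# n_genes = 3000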

# rdata.raw = rdata
# sc.pp.highly_variable_genes(
#     rdata, 
#     n_top_genes=n_genes,
#     batch_key="dataset", 
# )

# # actually subset the data
# rdata = rdata[:, rdata.var['highly_variable']].copy()

# rdata

# Batch Correction (scVI)

In [17]:
# Register the reference AnnData with scVI: which layer to model and which
# obs columns hold the batch and label annotations.
scvi_setup_kwargs = {
    'batch_key': "dataset",
    'layer': "counts",
    'labels_key': 'cell_label',
}
scvi.model.SCVI.setup_anndata(rdata, **scvi_setup_kwargs)

In [18]:
torch.cuda.empty_cache()

# Fix the random seed before the model is built so that weight initialisation,
# the train/validation split and minibatch order are repeatable on a fresh
# Restart & Run All.
scvi.settings.seed = 0

# Upper bound on training epochs; also read by the SCANVI training cell.
epochs = 400

model = scvi.model.SCVI(
    rdata,
    use_layer_norm="both",
    use_batch_norm="none",
    n_latent=24,
    encode_covariates=True,
    dropout_rate=0.3,
    n_layers=2,
)

# Optimiser / scheduler settings; reused by later training cells.
plan_kwargs = {
    'lr': 0.001,
    'n_epochs_kl_warmup': 10,
    'reduce_lr_on_plateau': True,
    'lr_patience': 8,
    'lr_factor': 0.1,
}

# perf_counter is monotonic, so the measured duration is not affected by
# system clock adjustments during a long training run.
start_time = time.perf_counter()

model.train(
    max_epochs=epochs,
    accelerator='gpu',
    devices='auto',
    enable_model_summary=True,
    early_stopping=True,
    batch_size=5000,
    load_sparse_tensor=True,
    plan_kwargs=plan_kwargs,
    early_stopping_patience=5,
)

end_time = time.perf_counter()
total_time = end_time - start_time  # elapsed seconds

print(f"Training completed in {total_time:.2f} seconds")

# Same duration, broken into whole minutes and remaining seconds
whole_minutes, leftover_seconds = divmod(total_time, 60)
minutes = int(whole_minutes)
seconds = int(leftover_seconds)
print(f"Training time: {minutes} minutes {seconds} seconds")

/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/cstansbu/miniconda3/envs/scanpy/lib/python3.12 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/cstansbu/miniconda3/envs/scanpy/lib/python3.12 ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-cfb2a8ae-864b-50df-94a5-98983023f29d]

  | Name            | Type                | Params | Mode 
---------------------------------

Training:   0%|          | 0/400 [00:00<?, ?it/s]

  return sparse_csr_tensor(

Detected KeyboardInterrupt, attempting graceful shutdown ...

KeyboardInterrupt



In [None]:
# Gather every tracked quantity of the SCVI run into one frame (one column per
# metric), keep the rows that have a validation loss, and reshape to long form.
history = pd.concat(model.history.values(), ignore_index=False, axis=1)
has_validation = history['validation_loss'].notna()

metrics = (
    history.loc[has_validation]
    .reset_index(drop=False, names='epoch')
    .melt(
        id_vars='epoch',
        value_vars=['train_loss_epoch', 'validation_loss'],
    )
)

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (4, 3)})

# Train vs. validation loss per epoch
ax = sns.lineplot(
    data=metrics,
    x='epoch',
    y='value',
    hue='variable',
    style='variable',
    linewidth=1.5,
    markersize=5,
)

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.xticks()
plt.yticks()
ax.grid(True, linestyle='--', alpha=0.5)
sns.move_legend(ax, title="", loc='best')

sns.despine()

# SCVI Latent Space

In [None]:
# Store the SCVI embedding, then build the neighbour graph, clusters and UMAP
# from that embedding.
SCVI_LATENT_KEY = "X_scVI"
rdata.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()

sc.pp.neighbors(rdata, use_rep=SCVI_LATENT_KEY)
sc.tl.leiden(rdata, resolution=0.3, key_added='scvi_clusters')
sc.tl.umap(rdata, min_dist=0.25, method='rapids')

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (5, 5)})

# One panel per annotation, stacked vertically
umap_color_keys = ['dataset', 'standard_cell_type', 'scvi_clusters']
sc.pl.umap(rdata, color=umap_color_keys, ncols=1)

# SCANVI Model

In [None]:
torch.cuda.empty_cache()

# Initialise SCANVI from the trained SCVI model; cells labelled "Unknown" are
# treated as unlabeled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    model,
    unlabeled_category="Unknown",
)

# Reuses `epochs` and `plan_kwargs` defined in the SCVI training cell.
scanvi_train_kwargs = dict(
    max_epochs=epochs,
    accelerator='gpu',
    devices='auto',
    enable_model_summary=True,
    early_stopping=True,
    batch_size=5000,
    plan_kwargs=plan_kwargs,
    early_stopping_patience=5,
)

start_time = time.time()
scanvi_model.train(**scanvi_train_kwargs)
end_time = time.time()

total_time = end_time - start_time  # elapsed seconds
print(f"Training completed in {total_time:.2f} seconds")

# Same duration, broken into whole minutes and remaining seconds
whole_minutes, leftover_seconds = divmod(total_time, 60)
minutes = int(whole_minutes)
seconds = int(leftover_seconds)
print(f"Training time: {minutes} minutes {seconds} seconds")

In [None]:
# Gather every tracked quantity of the SCANVI run into one frame (one column
# per metric), keep the rows that have a validation loss, and reshape to long form.
history = pd.concat(scanvi_model.history.values(), ignore_index=False, axis=1)
has_validation = history['validation_loss'].notna()

metrics = (
    history.loc[has_validation]
    .reset_index(drop=False, names='epoch')
    .melt(
        id_vars='epoch',
        value_vars=['train_loss_epoch', 'validation_loss'],
    )
)

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (4, 3)})

# Train vs. validation loss per epoch
ax = sns.lineplot(
    data=metrics,
    x='epoch',
    y='value',
    hue='variable',
    style='variable',
    linewidth=1.5,
    markersize=5,
)

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.xticks()
plt.yticks()
ax.grid(True, linestyle='--', alpha=0.5)
sns.move_legend(ax, title="", loc='best')

sns.despine()

# Latent Representations (SCANVI)

In [None]:
# Store the SCANVI embedding, then rebuild the neighbour graph, clusters and
# UMAP from it (this replaces the graph/UMAP computed from the SCVI embedding).
SCANVI_LATENT_KEY = "X_scANVI"
rdata.obsm[SCANVI_LATENT_KEY] = scanvi_model.get_latent_representation()

sc.pp.neighbors(rdata, use_rep=SCANVI_LATENT_KEY)
sc.tl.leiden(rdata, resolution=0.3, key_added='scanvi_clusters')
sc.tl.umap(rdata, min_dist=0.25, method='rapids')

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (5, 5)})

# One panel per annotation, stacked vertically
umap_color_keys = ['dataset', 'standard_cell_type', 'scanvi_clusters']
sc.pl.umap(rdata, color=umap_color_keys, ncols=1)

In [None]:
# Model-derived normalized expression for the reference cells, stored as a layer.
scanvi_expression = scanvi_model.get_normalized_expression(return_mean=False)
rdata.layers['SCANVI_counts'] = scanvi_expression
rdata

# Differential Expression

In [None]:
torch.cuda.empty_cache()

# Output file is tagged with the number of selected genes.
outpath = f"/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/scANVI_deg_{n_genes}.csv"

# Differential expression between the groups defined by 'cell_label'.
de_kwargs = dict(
    groupby='cell_label',
    batch_correction=True,
    filter_outlier_cells=True,
)
deg = scanvi_model.differential_expression(rdata, **de_kwargs)
print(f"{deg.shape=}")

# Move the index into a regular column so it is written to the CSV.
deg = deg.reset_index()
deg.to_csv(outpath, index=False)
deg.head()

# Save the model (SCANVI only)

In [None]:
# Persist the reference SCANVI model (and its AnnData) under the 'reference_' prefix.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/model/"
reference_save_kwargs = {
    'overwrite': True,
    'save_anndata': True,
    'prefix': 'reference_',
}
scanvi_model.save(fpath, **reference_save_kwargs)
print('done')

# Benchmarking

In [None]:
from scib_metrics.benchmark import Benchmarker

torch.cuda.empty_cache()

# 'X_pca' serves as the unintegrated baseline below, but no PCA is computed
# anywhere earlier in this notebook, so the key would be missing from
# `rdata.obsm`. Build it on a throw-away object (log-normalized counts) so that
# `rdata.X` itself is left untouched.
if 'X_pca' not in rdata.obsm:
    pca_input = an.AnnData(X=rdata.layers['counts'].copy())
    sc.pp.normalize_total(pca_input, target_sum=1e4)
    sc.pp.log1p(pca_input)
    sc.pp.pca(pca_input, n_comps=50)
    # Rows of `pca_input` are in the same order as `rdata`.
    rdata.obsm['X_pca'] = pca_input.obsm['X_pca']
    del pca_input

# Compare the baseline with the two learned embeddings.
bm = Benchmarker(
    rdata,
    batch_key="dataset",
    label_key="cell_label",
    embedding_obsm_keys=['X_pca', SCVI_LATENT_KEY, SCANVI_LATENT_KEY],
)

bm.benchmark()

bm.plot_results_table(min_max_scale=False)

In [None]:
# Export the unscaled benchmark table; the index becomes a regular column.
fpath = f"/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/benchmarks_{n_genes}.csv"
bmdf = bm.get_results(min_max_scale=False).reset_index(drop=False)
bmdf.to_csv(fpath, index=False)
bmdf.head()

# Query mapping (SCANVI)

In [None]:
# Match the query AnnData to the reference model's gene set, then create a
# query model that starts from the trained reference SCANVI model.
scvi.model.SCANVI.prepare_query_anndata(
    qdata,
    scanvi_model,
)

scanvi_query = scvi.model.SCANVI.load_query_data(
    qdata,
    scanvi_model,
)

print('Done!')

torch.cuda.empty_cache()

# Dedicated name: the global `epochs` is read by the reference training cells
# and must not be overwritten here, otherwise re-running those cells after this
# one would silently train for 15 epochs.
query_epochs = 15

SCANVI_PREDICTIONS_KEY = "predictions_scanvi"

# weight_decay must be 0.0 for query fine-tuning. A non-zero value also shrinks
# the weights inherited from the reference model, so the query embedding would
# no longer live in the same latent space as the reference embedding it is
# concatenated with further down.
scanvi_query.train(
    max_epochs=query_epochs,
    plan_kwargs={"weight_decay": 0.0},
)

# Embedding, model-derived expression and predicted labels for the query cells
qdata.obsm[SCANVI_LATENT_KEY] = scanvi_query.get_latent_representation()
qdata.layers['SCANVI_counts'] = scanvi_query.get_normalized_expression(return_mean=False)
qdata.obs[SCANVI_PREDICTIONS_KEY] = scanvi_query.predict()

qdata

In [None]:
"""ADD SCVI LATENT SPACE AS WELL"""

# Match the query AnnData to the SCVI reference model's gene set, then create
# a query model that starts from the trained reference SCVI model.
scvi.model.SCVI.prepare_query_anndata(
    qdata,
    model,
)

scvi_query = scvi.model.SCVI.load_query_data(
    qdata,
    model,
)

# Dedicated name so the global `epochs` used by the reference training cells
# is not overwritten.
scvi_query_epochs = 100

# early_stopping must be switched on explicitly; without it the patience
# setting below is ignored and training always runs for the full epoch budget.
scvi_query.train(
    max_epochs=scvi_query_epochs,
    plan_kwargs=plan_kwargs,
    early_stopping=True,
    early_stopping_patience=5,
)

qdata.obsm[SCVI_LATENT_KEY] = scvi_query.get_latent_representation()
qdata

# Store Query Model (SCANVI only)

In [None]:
# Persist the query SCANVI model (and its AnnData) under the 'query_' prefix,
# in the same directory as the reference model.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/model/"
query_save_kwargs = {
    'overwrite': True,
    'save_anndata': True,
    'prefix': 'query_',
}
scanvi_query.save(fpath, **query_save_kwargs)
print('done')

# Predicted Probability

In [None]:
# Soft predictions for the query cells; reused below for the histogram and export.
pred_proba = scanvi_query.predict(soft=True)
pred_proba.head(5)

In [None]:
# Distribution of the 'HSC' column of the soft predictions, on a log-scaled axis.
hsc_hist_kwargs = {
    'x': 'HSC',
    'log_scale': True,
    'bins': 31,
}
sns.histplot(data=pred_proba, **hsc_hist_kwargs)

In [None]:
# Hard label predictions for the query cells, tallied per predicted class.
# The top-level pd.value_counts() helper is deprecated (pandas >= 2.1); the
# Series method returns the same table.
preds = scanvi_query.predict()
pd.Series(preds).value_counts()

# Store predictions

In [None]:
# Export the soft predictions plus the highest-scoring class per cell.
fpath = f"/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/ihsc_predictions_{n_genes}.csv"

df = (
    pred_proba
    .astype(float)
    # idxmax runs on the probability columns only (before 'prediction' exists)
    .assign(prediction=lambda proba: proba.idxmax(axis=1))
    .reset_index(drop=False, names='cell_id')
)

df.to_csv(fpath, index=False)

df.head()

# Integrate

In [None]:
# Stack the annotated reference and query objects; the 'batch' obs column
# records which of the two inputs each cell came from.
integrated_parts = [rdata, qdata]
fdata = an.concat(integrated_parts, label="batch")
fdata

In [None]:
# Neighbour graph and UMAP of the combined data, based on the SCANVI embedding.
sc.pp.neighbors(fdata, use_rep=SCANVI_LATENT_KEY)
sc.tl.umap(fdata, min_dist=0.25, method='rapids')

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (5, 5)})

# One panel per annotation, stacked vertically
integrated_color_keys = ['dataset', 'standard_cell_type', 'batch']
sc.pl.umap(fdata, color=integrated_color_keys, ncols=1)

# Store ADATA

In [None]:
# Write the combined object to disk; the file name is tagged with the gene count.
output_dir = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI"
fpath = f"{output_dir}/imputed_data_{n_genes}.h5ad"
fdata.write(fpath)
fdata