In [1]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
import matplotlib.colors as mcolors
from matplotlib.transforms import Bbox
from matplotlib.colors import to_rgba
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import scvi
import scanpy as sc
import anndata as an
import scanpy.external as sce
import scipy
import scipy.sparse as sp
import time
import sklearn
import torch
from scipy.sparse import csr_matrix

from importlib import reload

# local imports
import utils as ut
import plotting as plt2

# Scanpy logging verbosity for this session (larger values print more messages).
sc.settings.verbosity = 3

In [2]:
# CUDA version string this torch build reports.
cuda_build_version = torch.version.cuda
print(cuda_build_version)

12.0


In [3]:
# Number of logical CPUs visible to this process's OS.
num_processors = os.cpu_count()
processor_message = f"Number of processors: {num_processors}"
print(processor_message)

Number of processors: 40


In [4]:
# Report whether torch can use CUDA and, if so, describe each visible GPU.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

if not cuda_available:
    print("CUDA not available. Running on CPU.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")

    for i in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(i)
        gpu_props = torch.cuda.get_device_properties(i)
        gib = gpu_props.total_memory / 1024**3  # bytes -> GiB

        print(f"GPU {i}: {gpu_name}")
        print(f"  Compute Capability: {gpu_props.major}.{gpu_props.minor}")
        print(f"  Total Memory: {gib:.2f} GB")

CUDA available: True
Number of GPUs: 1
GPU 0: Tesla V100-PCIE-16GB
  Compute Capability: 7.0
  Total Memory: 15.77 GB


In [5]:
# List the devices JAX can see. Without a CUDA-enabled jaxlib this is CPU only.
import jax

jax_device_list = jax.devices()
print(jax_device_list)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[CpuDevice(id=0)]


# Load data

In [6]:
def load_annotations(fpath):
    """
    Load per-cell annotations from a CSV file and index them by a cell id.

    The cell id is built as ``"<obs_index>_<dataset>"``. It is intended to
    match the ``obs_names`` of the merged AnnData object. The ``dataset``
    column is then dropped, because ``adata.obs`` already carries a
    ``dataset`` column of its own.

    Args:
        fpath (str | path-like | file-like): Anything accepted by
            ``pd.read_csv``. Must contain the columns ``obs_index`` and
            ``dataset``.

    Returns:
        pd.DataFrame: The remaining annotation columns, indexed by
        ``cell_id``.

    Raises:
        KeyError: If ``obs_index`` or ``dataset`` is missing from the file.
    """
    df = pd.read_csv(fpath)
    # Cast BOTH parts to str. A numeric `dataset` column (e.g. integer batch
    # ids) would otherwise make the string concatenation raise a TypeError.
    df['cell_id'] = df['obs_index'].astype(str) + "_" + df["dataset"].astype(str)
    df = df.drop(columns='dataset')
    df = df.set_index('cell_id')
    return df

In [7]:
"""
DATA
"""
fpath = "/scratch/indikar_root/indikar1/shared_data/hematokytos/merged_anndata/merged_adata.h5ad"
adata = sc.read_h5ad(fpath)
sc.logging.print_memory_usage()
print(adata)

"""
ANNOTATIONS
"""
fpath = "/scratch/indikar_root/indikar1/shared_data/hematokytos/annotation/cell_types.csv"
df = load_annotations(fpath)
df = df[df.index.isin(adata.obs_names)]
print(f"{df.shape=}")

adata.obs = pd.concat([adata.obs, df], ignore_index=False, axis=1)
adata

Memory usage: current 8.75 GB, difference +8.75 GB
AnnData object with n_obs × n_vars = 171498 × 18867
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End'
    layers: 'counts'


UnicodeDecodeError: 'utf-8' codec can't decode byte 0x89 in position 0: invalid start byte

# Annotations

In [None]:
# fpaths = [
#     "../resources/PanglaoDB_Augmented_2021.txt",
#     "../resources/Tabula_Sapiens.txt",
# ]

# marker_columns = []

# for fpath in fpaths:    
#     features = ut.load_pathway(fpath)

#     # extract columns
#     hsc_columns = [x for x in features.columns if "hemato" in x.lower()]
#     fib_columns = [x for x in features.columns if "fibroblast" in x.lower()]
#     fib_columns = [x for x in fib_columns if not "myofibr" in x.lower()]
    
#     columns = list(set(hsc_columns + fib_columns))

#     for col in columns:
#         gene_list = list(features[features[col].astype(bool)].index)
#         print(col, len(gene_list))
#         col_name = col.lower().replace("-", " ")
#         col_name = col_name.replace(" ", "_") + "_marker"
#         marker_columns.append(col_name)
#         adata.var[col_name] = adata.var.index.isin(gene_list)
    
# adata.var['is_marker'] = adata.var[marker_columns].any(axis=1)
# print()
# adata

# Randomly subsample the tabula sapiens fibroblast data

In [None]:
# sample_size = 1e4 # the number of fibroblast signatures to keep

# mask = (adata.obs['dataset'] == 'tabula_sapiens') & (adata.obs['standard_cell_type'] == 'Fib')

# print(f"Filtering for Fibroblasts from tabula_sapiens dataset...")
# fibdf = adata.obs[mask].copy()
# print(f"Found {len(fibdf)} fibroblasts.")

# print(f"Sampling {sample_size} fibroblasts...")
# keep = fibdf.sample(int(sample_size))
# keep = keep.index
# print("Sampling complete.")

# fib_to_drop = fibdf[~fibdf.index.isin(keep)].index
# all_keep = ~adata.obs.index.isin(fib_to_drop)

# # Print the lengths
# print(f"Total fibroblasts: {len(adata.obs.index)}")
# print(f"Fibroblasts to keep: {len(keep)}")
# print(f"Fibroblasts to drop: {len(fib_to_drop)}") 

# adata = adata[all_keep, :].copy()  
# adata

# Cell type filtering

In [None]:
# Cell populations kept for the analysis. Types marked "single batch" occur in
# only one batch and are excluded; the other commented-out types are also excluded.
cell_types = [
    # 'PreBNK', - only present in a single batch
    # 'LinNeg', - only present in a single batch
    'HSC',
    'CMP',
    'MEP',
    'MPP',
    'GMP',
    # 'MLP', - only present in a single batch
    'EryP',
    'MDP',
    'MKP',
    # 'Mono',
    'CLP',
    # 'T_cell',
    # 'B_cell',
    # 'NK',
    'LMPP',
    # 'Dendritic_cell',
    'Fib',
    'iHSC',
]

# Keep cells that have a label AND whose label is in the list above.
labels = adata.obs['standard_cell_type']
keep_mask = labels.notna() & labels.isin(cell_types)
adata = adata[keep_mask, :].copy()

adata.obs['standard_cell_type'].value_counts()

# Preprocessing

In [None]:
# Basic QC: drop cells with few detected genes and genes seen in few cells.
MIN_GENES_PER_CELL = 1000
MIN_CELLS_PER_GENE = 250

sc.pp.filter_cells(adata, min_genes=MIN_GENES_PER_CELL)
sc.pp.filter_genes(adata, min_cells=MIN_CELLS_PER_GENE)

adata

In [None]:
# Library-size normalization to a FIXED 1e4 total counts per cell (CP10k),
# followed by log1p. (Normalizing to the median total count would instead
# be `target_sum=None`.)
target_sum = 1e4
sc.pp.normalize_total(adata, target_sum=target_sum)
sc.pp.log1p(adata)

# Store a float32 CSR copy of the log-normalized matrix; the HVG step reads
# it via layer='log_norm'. `astype` already returns a new matrix, so the
# previous extra `.copy()` was redundant.
adata.layers["log_norm"] = csr_matrix(adata.X.astype('float32'))

adata

# Make reference and query data

In [None]:
# Datasets that form the labeled reference the model is trained on.
reference_data = [
    'tabula_sapiens',
    'weng_young2_all',
    'sc_fib',
    'weng_young1_all_t2',
    'weng_young1_all_t1',
]

# Dataset(s) mapped onto the reference later (treated as unlabeled).
query_data = [
    'iHSC',
]

sample_size = None  # set to an int to subsample the reference

dataset_col = adata.obs['dataset']

rdata = adata[dataset_col.isin(reference_data), :].copy()
if sample_size is not None:
    sc.pp.subsample(rdata, n_obs=sample_size)
print(rdata)
print()

qdata = adata[dataset_col.isin(query_data), :].copy()
print(qdata)

In [None]:
def _clean_label(value):
    """Return ``value`` converted to ``str`` with surrounding whitespace removed."""
    return str(value).strip()

# Reference cells are labeled with their (whitespace-trimmed) standard cell type.
rdata.obs['cell_label'] = rdata.obs['standard_cell_type'].apply(_clean_label)
rdata.obs['cell_label'].value_counts()

In [None]:
# Every query cell gets the placeholder label later passed to SCANVI as
# `unlabeled_category`.
qdata.obs = qdata.obs.assign(cell_label='Unknown')
qdata.obs['cell_label'].value_counts()

# Store the Data

In [None]:
# Persist the combined reference + query object (origin recorded in obs['batch']).
output_path = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/full_data.h5ad"
fdata = an.concat([rdata, qdata], label="batch")

# Write to `output_path`. The previous `fdata.write(fpath)` wrote to whatever
# `fpath` last held (the annotation CSV path). That overwrote cell_types.csv
# with an HDF5 file, which is why re-reading it failed with
# "can't decode byte 0x89".
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fdata.write(output_path)
fdata

# Feature Selection

In [None]:
n_genes = 5000

# Keep the full (pre-HVG) matrix in .raw before subsetting genes.
rdata.raw = rdata

# Batch-aware HVG selection on the log-normalized layer.
sc.pp.highly_variable_genes(
    rdata,
    layer='log_norm',
    batch_key="dataset",
    n_top_genes=n_genes,
)

hvg_mask = rdata.var['highly_variable']
rdata = rdata[:, hvg_mask].copy()

rdata

In [None]:
# Convert the (HVG-subset, log-normalized) expression matrix to a dense array.
# Guarded so the cell can be re-run: once X is a dense ndarray it has no
# `.toarray()`, and the unguarded call raised AttributeError.
if sp.issparse(rdata.X):
    rdata.X = rdata.X.toarray()
rdata

# Reference Mapping (scANVI)

In [None]:
# Register rdata with scvi-tools: read counts from layers['counts'], treat
# `dataset` as the batch covariate, and record `cell_label` for SCANVI later.
scvi.model.SCVI.setup_anndata(
    rdata,
    layer="counts",
    batch_key="dataset",
    labels_key='cell_label',
)

In [None]:
torch.cuda.empty_cache()

epochs = 400

# scVI architecture: layer norm instead of batch norm, covariates also fed
# to the encoder.
scvi_arch = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    n_latent=12,
    encode_covariates=True,
    dropout_rate=0.3,
    n_layers=2,
)
model = scvi.model.SCVI(rdata, **scvi_arch)

start_time = time.time()
model.train(
    max_epochs=epochs,
    batch_size=5000,
    early_stopping=True,
    accelerator='gpu',
    devices='auto',
    enable_model_summary=True,
)
total_time = time.time() - start_time

print(f"Training completed in {total_time:.2f} seconds")

# Same wall-clock time broken into whole minutes and seconds.
minutes, seconds = divmod(int(total_time), 60)
print(f"Training time: {minutes} minutes {seconds} seconds")

In [None]:
# Collect every logged scVI metric into one frame (index 'epoch' -> column).
metrics = pd.concat(list(model.history.values()), axis=1, ignore_index=False).reset_index()

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 3.5, 3.5

fig, ax = plt.subplots()
sns.lineplot(
    data=metrics,
    x='epoch',
    y='train_loss_epoch',
    linewidth=1.5,
    markersize=5,
    ax=ax,
)
ax.set_xlabel('Epoch')
ax.set_ylabel('Training Loss')
ax.grid(True, linestyle='--', alpha=0.5)

sns.despine(ax=ax)
fig.tight_layout()
plt.show()

# SCVI Latent Space

In [None]:
SCVI_LATENT_KEY = "X_scVI"

# Embed reference cells in the scVI latent space, then cluster and visualize there.
rdata.obsm[SCVI_LATENT_KEY] = model.get_latent_representation()

sc.pp.neighbors(rdata, use_rep=SCVI_LATENT_KEY)
sc.tl.leiden(rdata, key_added='scvi_clusters', resolution=0.3)
sc.tl.umap(rdata, method='rapids', min_dist=0.25)

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (5, 5)})

umap_colors = ['dataset', 'standard_cell_type', 'scvi_clusters']
sc.pl.umap(rdata, color=umap_colors, ncols=1)

In [None]:
# break

# SCANVI Model

In [None]:
torch.cuda.empty_cache()

# Initialise SCANVI from the trained scVI model. Cells labeled "Unknown"
# (the query placeholder) are treated as unlabeled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(model, unlabeled_category="Unknown")

epochs = 400
scanvi_train_kwargs = dict(
    max_epochs=epochs,
    batch_size=5000,
    early_stopping=True,
    accelerator='gpu',
    devices='auto',
    enable_model_summary=True,
)

start_time = time.time()
scanvi_model.train(**scanvi_train_kwargs)
total_time = time.time() - start_time

print(f"Training completed in {total_time:.2f} seconds")

minutes, seconds = divmod(int(total_time), 60)
print(f"Training time: {minutes} minutes {seconds} seconds")

In [None]:
# Collect every logged SCANVI metric into one frame (index 'epoch' -> column).
metrics = pd.concat(list(scanvi_model.history.values()), axis=1, ignore_index=False).reset_index()

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 3.5, 3.5

fig, ax = plt.subplots()
sns.lineplot(
    data=metrics,
    x='epoch',
    y='train_loss_epoch',
    linewidth=1.5,
    markersize=5,
    ax=ax,
)
ax.set_xlabel('Epoch')
ax.set_ylabel('Training Loss')
ax.grid(True, linestyle='--', alpha=0.5)

sns.despine(ax=ax)
fig.tight_layout()
plt.show()

# Latent Representations (SCANVI)

In [None]:
SCANVI_LATENT_KEY = "X_scANVI"

# Embed reference cells in the SCANVI latent space, then cluster and visualize there.
rdata.obsm[SCANVI_LATENT_KEY] = scanvi_model.get_latent_representation()

sc.pp.neighbors(rdata, use_rep=SCANVI_LATENT_KEY)
sc.tl.leiden(rdata, key_added='scanvi_clusters', resolution=0.3)
sc.tl.umap(rdata, method='rapids', min_dist=0.25)

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (5, 5)})

umap_colors = ['dataset', 'standard_cell_type', 'scanvi_clusters']
sc.pl.umap(rdata, color=umap_colors, ncols=1)

In [None]:
# Model-normalized expression from SCANVI. Despite the layer name, these are
# get_normalized_expression() values, not raw counts.
scanvi_norm_expr = scanvi_model.get_normalized_expression(return_mean=False)
rdata.layers['SCANVI_counts'] = scanvi_norm_expr
rdata

# Save the model (SCANVI only)

In [None]:
# Save the trained reference SCANVI model (with its AnnData) under a 'reference_' prefix.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/model/"
scanvi_model.save(fpath, prefix='reference_', save_anndata=True, overwrite=True)
print('done')

# Benchmarking

In [None]:
from scib_metrics.benchmark import Benchmarker

torch.cuda.empty_cache()

# 'X_pca' serves as the unintegrated baseline embedding. Nothing earlier in
# the notebook computes it: every `sc.pp.neighbors` call passes `use_rep=...`,
# and the loaded object had no obsm entries. Compute it from rdata.X
# (HVG-subset, log-normalized) when absent; otherwise Benchmarker would look
# up a missing obsm key.
if "X_pca" not in rdata.obsm:
    sc.pp.pca(rdata)

bm = Benchmarker(
    rdata,
    batch_key="dataset",
    label_key="cell_label",
    embedding_obsm_keys=['X_pca', SCVI_LATENT_KEY, SCANVI_LATENT_KEY],
)

bm.benchmark()

bm.plot_results_table(min_max_scale=False)

In [None]:
# Persist the raw (un-scaled) benchmark table with the embedding name as a column.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/benchmarks.csv"
bmdf = (
    bm.get_results(min_max_scale=False)
    .reset_index(drop=False)
)
bmdf.to_csv(fpath, index=False)
bmdf.head()

# Query mapping (SCANVI)

In [None]:
# Align the query genes to the reference model, then build a query model from it.
scvi.model.SCANVI.prepare_query_anndata(qdata, scanvi_model)
scanvi_query = scvi.model.SCANVI.load_query_data(qdata, scanvi_model)
print('Done!')

torch.cuda.empty_cache()

epochs = 15
SCANVI_PREDICTIONS_KEY = "predictions_scanvi"
query_plan_kwargs = {"weight_decay": 0.01}

# Short fine-tuning on the query cells.
scanvi_query.train(max_epochs=epochs, plan_kwargs=query_plan_kwargs)

# Transfer the latent space, normalized expression and hard label predictions.
qdata.obsm[SCANVI_LATENT_KEY] = scanvi_query.get_latent_representation()
qdata.layers['SCANVI_counts'] = scanvi_query.get_normalized_expression(return_mean=False)
qdata.obs[SCANVI_PREDICTIONS_KEY] = scanvi_query.predict()

qdata

In [None]:
"""ADD SCVI LATENT SPACE AS WELL"""

scvi.model.SCVI.prepare_query_anndata(
    qdata, 
    model,
)

scvi_query = scvi.model.SCVI.load_query_data(
    qdata, 
    model,
)

epochs = 15

scvi_query.train(
    max_epochs=epochs, 
    plan_kwargs={"weight_decay": 0.01},
)

qdata.obsm[SCVI_LATENT_KEY] = scvi_query.get_latent_representation()
qdata

# Store Query Model (SCANVI only)

In [None]:
# Save the fine-tuned query SCANVI model (with its AnnData) under a 'query_' prefix.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/model/"
scanvi_query.save(fpath, prefix='query_', save_anndata=True, overwrite=True)
print('done')

# Predicted Probability

In [None]:
# Per-class probabilities for each query cell (one column per reference label).
pred_proba = scanvi_query.predict(soft=True)
pred_proba.head(5)

In [None]:
# Distribution of the predicted HSC probability across query cells (log-scaled axis).
sns.histplot(
    pred_proba,
    x='HSC',
    bins=31,
    log_scale=True,
)

In [None]:
preds = scanvi_query.predict()
# Top-level `pd.value_counts` is deprecated since pandas 2.1; count through a
# Series instead (same counts, same ordering).
pd.Series(preds).value_counts()

# Store predictions

In [None]:
# Save per-class probabilities plus the arg-max label, with cell ids as a column.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/ihsc_predictions.csv"
probs = pred_proba.astype(float)
df = (
    probs
    .assign(prediction=probs.idxmax(axis=1))
    .reset_index(drop=False, names='cell_id')
)

df.to_csv(fpath, index=False)

df.head()

# Integrate

In [None]:
# Stack reference then query; obs['batch'] records which object each cell came from.
to_merge = [rdata, qdata]
fdata = an.concat(to_merge, label="batch")
fdata

In [None]:
# Joint embedding of reference + query in the shared SCANVI latent space.
sc.pp.neighbors(fdata, use_rep=SCANVI_LATENT_KEY)
sc.tl.umap(fdata, method='rapids', min_dist=0.25)

plt.rcParams.update({'figure.dpi': 200, 'figure.figsize': (5, 5)})

umap_colors = ['dataset', 'standard_cell_type', 'batch']
sc.pl.umap(fdata, color=umap_colors, ncols=1)

# Store ADATA

In [None]:
# Write the integrated reference + query AnnData to disk as .h5ad.
fpath = "/scratch/indikar_root/indikar1/shared_data/sc_HSC/SCANVI/imputed_data.h5ad"
fdata.write_h5ad(fpath)
fdata