In [1]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
from matplotlib.transforms import Bbox
import seaborn as sns
import scanpy as sc
import scipy
import sklearn
import torch
import scvi
from scvi.external import CellAssign
from sklearn.model_selection import train_test_split

from importlib import reload

# locals
import utils as ut
import plotting as plt2

In [None]:
import jax
print(jax.devices())

In [None]:
# Marker 

In [2]:
fpath = "../resources/PanglaoDB_Augmented_2021.txt"
pang = ut.load_pathway(fpath)  # Assuming 'ut.load_pathway' is a custom function to load the data
pang.head()

label,Acinar Cells,Adipocyte Progenitor Cells,Adipocytes,Adrenergic Neurons,Airway Epithelial Cells,Airway Goblet Cells,Airway Smooth Muscle Cells,Alpha Cells,Alveolar Macrophages,Anterior Pituitary Gland Cells,...,Transient Cells,Trichocytes,Trigeminal Neurons,Trophoblast Cells,Trophoblast Progenitor Cells,Trophoblast Stem Cells,Tuft Cells,Undefined Placental Cells,Urothelial Cells,Vascular Smooth Muscle Cells
GDF15,True,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
RARRES2,True,False,True,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
TM4SF4,True,False,False,False,False,False,False,True,False,False,...,False,False,False,False,False,False,False,False,False,False
CELA1,True,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
GCG,True,False,False,False,False,False,False,True,False,False,...,False,False,False,False,False,False,False,False,False,False


# Load Data

In [3]:
# fpath = "/scratch/indikar_root/indikar1/cstansbu/HSC/scanpy/hsc_landscape.anndata.h5ad"
fpath = "/scratch/indikar_root/indikar1/cstansbu/hematokytos/merged_anndata/merged_adata.h5ad"

adata = sc.read_h5ad(fpath)
adata.X = adata.layers['counts'].copy()
sc.logging.print_memory_usage()

lib_size = adata.X.sum(1)
adata.obs["size_factor"] = lib_size / np.mean(lib_size)

adata

Memory usage: current 8.51 GB, difference +8.51 GB


AnnData object with n_obs × n_vars = 171498 × 18867
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'size_factor'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End'
    layers: 'counts'

In [4]:
adata.obs['dataset'].value_counts()

dataset
weng_young2_all        29289
tabula_sapiens         27313
weng_young1_all_t2     25317
pellin                 19180
weng_young1_all_t1     18499
weng_old2_BMMC_HSPC    17292
weng_old1_BMMC_HSPC    14790
iHSC                    8379
sc_fib                  7736
weng_young2_HSC         3703
Name: count, dtype: int64

# Add cell type annotations

In [5]:
fpath = "/scratch/indikar_root/indikar1/cstansbu/hematokytos/annotation/cell_types.csv"

def load_annotations(fpath):
    """
    Loads annotations data from a CSV file, processes it, and returns the resulting DataFrame.

    Args:
        fpath (str): The file path to the CSV file containing annotations data.

    Returns:
        pd.DataFrame: The processed DataFrame with 'cell_id' as the index.
    """
    df = pd.read_csv(fpath)
    df['cell_id'] = df['obs_index'].astype(str) + "_" + df["dataset"]
    df = df.drop(columns='dataset')
    df = df.set_index('cell_id')
    return df

df = load_annotations(fpath)

print(f"\nMerging annotations...")
adata.obs = pd.merge(
    adata.obs,
    df,
    how='left',
    left_index=True,
    right_index=True,
)

adata.obs.head()


Merging annotations...


Unnamed: 0,n_genes,dataset,n_genes_by_counts,total_counts,size_factor,obs_index,cell_type,standard_cell_type
PreBNK_AGTTGAAC-TTGCATAT_1_pellin,2637,pellin,2633,13711.738192,0.949161,PreBNK_AGTTGAAC-TTGCATAT_1,PreBNK,PreBNK
PreBNK_AATCCGGC-TGAAATGA_1_pellin,1144,pellin,1142,7140.276133,0.301382,PreBNK_AATCCGGC-TGAAATGA_1,PreBNK,PreBNK
PreBNK_CAAACATT-TCTGTGGT_1_pellin,877,pellin,873,5812.4095,0.19482,PreBNK_CAAACATT-TCTGTGGT_1,PreBNK,PreBNK
PreBNK_CGTGTACA-TTCCAGAC_1_pellin,2233,pellin,2232,11925.223895,0.844374,PreBNK_CGTGTACA-TTCCAGAC_1,PreBNK,PreBNK
PreBNK_CATGACGA-CTTACGGG_1_pellin,951,pellin,948,6039.884097,0.275699,PreBNK_CATGACGA-CTTACGGG_1,PreBNK,PreBNK


In [6]:
adata.obs['cell_type'].value_counts()

cell_type
FB                        26553
HSC                       11693
CD4                       11399
EryP                       9902
MPP                        9391
Refined.HSC                9096
iHSC                       8379
CD8                        7904
sc_fib                     7736
MEP                        7433
Mono                       6988
ProB                       6735
LinNegCD34PosCD164Pos      5992
GMP                        4882
LinNegCD34lowCD164high     4194
NK                         4189
MDP                        3989
MKP                        3980
CLP                        3640
B                          3329
LinNegCD34NegCD164high     3016
CMP                        2804
pDC                        1809
cDC                         885
LMPP                        805
Plasma                      670
PreBNK                      554
LinNegCD34NegCD164low       194
MLP                         123
Name: count, dtype: int64

In [7]:
adata.obs['standard_cell_type'].value_counts()

standard_cell_type
Fib               34289
HSC               20789
T_cell            19303
LinNeg            13396
B_cell            10734
EryP               9902
MPP                9391
iHSC               8379
MEP                7433
Mono               6988
GMP                4882
NK                 4189
MDP                3989
MKP                3980
CLP                3640
CMP                2804
Dendritic_cell     2694
LMPP                805
PreBNK              554
MLP                 123
Name: count, dtype: int64

In [8]:
adata.obs['cell_type'] = adata.obs['standard_cell_type']

# Get features

In [9]:
cell_types = [
    # 'B Cells',
    # 'B Cells Memory',
    # 'B Cells Naive',
    # 'Dendritic Cells',
    'Embryonic Stem Cells',
    'Endothelial Cells',
    'Endothelial Cells (Aorta)',
    'Endothelial Cells (Blood Brain Barrier)',
    'Epithelial Cells',
    'Erythroblasts',
    'Erythroid-like And Erythroid Precursor Cells',
    'Fibroblasts',
    'Hematopoietic Stem Cells',
    'Macrophages',
    'Megakaryocytes',
    'Monocytes',
    # 'Myeloid-derived Suppressor Cells',
    # 'T Cells',
    # 'T Cells Naive',
    # 'T Cytotoxic Cells',
    # 'T Follicular Helper Cells',
    # 'T Helper Cells',
    # 'T Memory Cells',
    # 'T Regulatory Cells',
]

marker_gene_mat = pang[cell_types].copy()
marker_gene_mat = marker_gene_mat[marker_gene_mat.index.isin(adata.var_names)]

print(f"{marker_gene_mat.shape=}")
marker_gene_mat = marker_gene_mat[marker_gene_mat.sum(axis=1) > 0] 
print(f"{marker_gene_mat.shape=}")
marker_gene_mat.head()

marker_gene_mat.shape=(6274, 12)
marker_gene_mat.shape=(1464, 12)


label,Embryonic Stem Cells,Endothelial Cells,Endothelial Cells (Aorta),Endothelial Cells (Blood Brain Barrier),Epithelial Cells,Erythroblasts,Erythroid-like And Erythroid Precursor Cells,Fibroblasts,Hematopoietic Stem Cells,Macrophages,Megakaryocytes,Monocytes
GDF15,False,False,False,False,False,False,False,False,False,True,False,False
RARRES2,False,False,False,False,False,False,False,True,False,False,False,False
CELA1,False,False,False,False,False,False,False,True,False,False,False,False
PRR15L,False,False,False,False,True,False,False,False,False,False,False,False
DEFB1,False,False,False,False,True,False,False,False,False,False,False,False


# Define the trainning data

In [10]:
fraction = 0.25

exclude_datasets = [
    'iHSC',
    'sc_fib',
]

train_types= [
    'Fib',
    'HSC',
    'T_cell',
    'LinNeg',
    'B_cell',
    'EryP',
    'MPP',
    'MEP',
    'Mono',
    'GMP',
    'NK',
    'MDP',
    'MKP',
    'CLP',
    'CMP',
    'Dendritic_cell',
    'LMPP',
    'PreBNK',
    'MLP',
]

mask = (~adata.obs['dataset'].isin(exclude_datasets)) & (adata.obs['cell_type'].isin(train_types))
bdata = adata[mask, marker_gene_mat.index].copy()

sc.pp.subsample(bdata, fraction=fraction)
sc.pp.filter_cells(bdata, min_genes=2)

# Split the data into training and testing sets
train_idx, test_idx = train_test_split(bdata.obs.index, test_size=0.2, random_state=42)

# Create AnnData objects for training and testing
train_data = bdata[train_idx, :].copy()
test_data = bdata[test_idx, :].copy()

print(train_data)
print()
print(test_data)

AnnData object with n_obs × n_vars = 30429 × 1464
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'size_factor', 'obs_index', 'cell_type', 'standard_cell_type'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End'
    layers: 'counts'

AnnData object with n_obs × n_vars = 7608 × 1464
    obs: 'n_genes', 'dataset', 'n_genes_by_counts', 'total_counts', 'size_factor', 'obs_index', 'cell_type', 'standard_cell_type'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_id', 'token_id', 'gene_biotype', 'Chromosome', 'Start', 'End'
    layers: 'counts'


In [11]:
train_data.obs['cell_type'].value_counts()

cell_type
Fib               5318
HSC               4165
T_cell            3868
LinNeg            2647
B_cell            2167
EryP              2041
MPP               1914
MEP               1459
Mono              1356
GMP                964
MKP                829
NK                 818
MDP                790
CLP                681
CMP                564
Dendritic_cell     550
LMPP               147
PreBNK             128
MLP                 23
Name: count, dtype: int64

# Model config and Trainning

In [12]:
torch.cuda.empty_cache()
    
scvi.external.CellAssign.setup_anndata(
    train_data,
    size_factor_key="size_factor",
    # batch_key='dataset',
)

model = CellAssign(
    train_data, 
    marker_gene_mat.astype(int),
)

print("Done and ready for trainning!")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Done and ready for trainning!


In [13]:
model.train(
    max_epochs=10,
    lr=0.01,
    batch_size=16,
)

/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/cstansbu/miniconda3/envs/scanpy/lib/python3.12 ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/cstansbu/miniconda3/envs/scanpy/lib/python3.12 ...
/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers wh

Epoch 1/10:   0%|          | 0/10 [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x152256344380>>
Traceback (most recent call last):
  File "/home/cstansbu/miniconda3/envs/scanpy/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
KeyboardInterrupt: 


KeyboardInterrupt: 

# Training Eval

In [None]:
metrics = pd.concat(
    model.history.values(), 
    ignore_index=False,
    axis=1
)

metrics = metrics.reset_index()
metrics.head()

In [None]:
# Melt the DataFrame
df_melted = pd.melt(
    metrics,
    id_vars='epoch',
    value_vars=['train_loss_epoch', 'validation_loss'],
    var_name='Loss Type',  # More descriptive name
    value_name='Loss'      # More descriptive name
)

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 3.5, 3.5

sns.lineplot(
    data=df_melted,
    x='epoch',
    y='Loss',
    hue='Loss Type',
    palette=['#1f77b4', '#ff7f0e'],  # Use distinct colors
    linewidth=1.5,                  # Make lines thicker
    marker='o',                   # Add markers for data points
    markersize=5
)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks()
plt.yticks()
plt.legend(title='Loss Type') 
plt.grid(True, linestyle='--', alpha=0.5)  # Add a subtle grid

sns.despine()

plt.tight_layout()  # Adjust layout to prevent clipping
plt.show()

# Prediction

In [None]:
torch.cuda.empty_cache()
predictions = model.predict()
predictions.head()

In [None]:
# train_data.obs["scvi-tools predictions"] = predictions.idxmax(axis=1).values

# # celltype is the original CellAssign prediction
# sc.pl.umap(
#     train_data, 
#     color=["standard_cell_type", "scvi-tools predictions"], 
#     frameon=False, 
#     ncols=1,
# )

In [None]:
train_data.obs["predictions"] = predictions.idxmax(axis=1).values
df = train_data.obs
confusion_matrix = pd.crosstab(
    df["predictions"],
    df["cell_type"],
    rownames=["Predictions"],
    colnames=["Cell Type"],
)
confusion_matrix /= confusion_matrix.sum(1).ravel().reshape(-1, 1)

sns.heatmap(
    confusion_matrix,
    square=True,
    cbar_kws=dict(shrink=0.4, aspect=12),
    annot=True,
    lw='1',
)

plt.title('Training Data')

In [None]:
break

# Testing Data Validation 

In [None]:
def predict_other(adata, model):
    """Predict soft cell type assignment probability for each cell."""
    adata = model._validate_anndata(adata)
    scdl = model._make_data_loader(adata=adata)
    
    predictions = []
    for tensors in scdl:
        generative_inputs = model.module._get_generative_input(tensors, None)
        outputs = model.module.generative(**generative_inputs)
        gamma = outputs["gamma"]
        predictions += [gamma.cpu()]
    prob = pd.DataFrame(
        torch.cat(predictions).detach().numpy(), 
        columns=model.cell_type_markers.columns
    )
    prob['prediction'] = prob.idxmax(axis=1).values
    return prob


torch.cuda.empty_cache()
test_pred = predict_other(test_data, model)
test_pred['cell_type'] = test_data.obs['cell_type'].values

test_pred

In [None]:
confusion_matrix = pd.crosstab(
    test_pred["prediction"],
    test_pred["cell_type"],
    rownames=["Predictions"],
    colnames=["Cell Type"],
)
confusion_matrix /= confusion_matrix.sum(1).ravel().reshape(-1, 1)

sns.heatmap(
    confusion_matrix,
    square=True,
    cbar_kws=dict(shrink=0.4, aspect=12),
    annot=True,
    lw='1',
)

plt.title('True Hold Out Data')

# Our data

In [None]:
our_data = adata[adata.obs['dataset'].isin(exclude_datasets), marker_gene_mat.index].copy()
our_data

In [None]:
pred = predict_other(our_data, model)
pred['cell_type'] = our_data.obs['cell_type'].values
pred['dataset'] = our_data.obs['dataset'].values

pred

In [None]:
pred[['dataset', 'prediction']].value_counts()