In [None]:
import argparse
import logging
import math
from pathlib import Path
from typing import List, Union
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
from lightning.pytorch import seed_everything
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams
import cell2location
from cell2location.utils.filtering import filter_genes
from cell2location.models import RegressionModel
from anndata import AnnData
from scipy.sparse import issparse
from scanpy.plotting._tools.scatterplots import _get_palette

In [None]:
# Configure logging: every record carries a timestamp and its level name,
# and the root logger is set to INFO.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# One seed value shared by both seeding calls below.
seed = 42
# Seed scvi-tools through its settings object.
scvi.settings.seed = seed
# Seed Lightning as well; workers=True is passed explicitly here.
# NOTE(review): the scvi seed setter may itself reseed Lightning — keep this
# call last unless that is confirmed, so its workers=True setting is the one
# that remains in effect.
seed_everything(seed, workers=True)

In [None]:
# Input .h5ad file for the reference step and the directory that later cells
# write the reference outputs into.
reference_file = "./scData.h5ad"  # Replace with your file path
reference_output_dir = "./demo_reference_output"
# Read the reference AnnData object from disk.
adata_ref = sc.read_h5ad(reference_file)

In [None]:
# Select genes from the reference with the cutoffs below, save the current
# matplotlib figure next to the other reference outputs, then subset.
selected = filter_genes(
    adata_ref,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.05,
    nonz_mean_cutoff=1.25,
)
Path(reference_output_dir).mkdir(parents=True, exist_ok=True)
# NOTE(review): this saves whatever figure is current after filter_genes
# returns. If filter_genes calls plt.show() and the inline backend closes the
# figure, the saved PNG may be blank — confirm by opening the file.
filter_plot_path = Path(reference_output_dir) / f'{Path(reference_file).stem}_filter_genes.png'
plt.savefig(filter_plot_path)
plt.close()
# .copy() makes the gene subset an independent object rather than a view.
adata_ref = adata_ref[:, selected].copy()

In [None]:
# Register the reference AnnData with the regression model, using the
# "Broad_labels" column as labels and no batch key.
RegressionModel.setup_anndata(
    adata=adata_ref,
    batch_key=None,
    labels_key="Broad_labels",
)
mod = RegressionModel(adata_ref)
mod.view_anndata_setup()

# Train, plot the training history, and save that figure.
mod.train(max_epochs=250)
mod.plot_history(20)
history_plot_path = Path(reference_output_dir) / f'{Path(reference_file).stem}_train_history.png'
plt.savefig(history_plot_path)

In [None]:
# Export the posterior into adata_ref, then pull out one
# 'means_per_cluster_mu_fg_<factor>' column per factor name.
adata_ref = mod.export_posterior(
    adata_ref,
    sample_kwargs={'num_samples': 1000, 'batch_size': 2500},
)

ref_factor_names = adata_ref.uns['mod']['factor_names']
signature_cols = [f'means_per_cluster_mu_fg_{name}' for name in ref_factor_names]

# The per-cluster means are read from .varm when that key exists there,
# otherwise from .var.
signature_source = (
    adata_ref.varm['means_per_cluster_mu_fg']
    if 'means_per_cluster_mu_fg' in adata_ref.varm.keys()
    else adata_ref.var
)
inf_aver = signature_source[signature_cols].copy()

# Rename the columns to the bare factor names and drop the index name
# before writing the table to CSV.
inf_aver.columns = ref_factor_names
inf_aver.index.name = None
inf_aver.to_csv(Path(reference_output_dir) / f'{Path(reference_file).stem}_inf_aver.csv', sep=',')
logger.info(f'Reference saved to {reference_output_dir}')

In [None]:
# Inputs and output directory for the spatial annotation step.
spatial_file = "./bin100.scanpy.h5ad"  # Replace with your file path
reference_csv = "demo_reference_output/scData_inf_aver.csv"  # From step 1
annotation_output_dir = "./demo_annotation_output"

adata = sc.read_h5ad(spatial_file)

# Flag genes whose names start with 'mt-' or 'MT-' and drop them.
adata.var['mt'] = adata.var_names.str.startswith(('mt-', 'MT-'))
keep_genes = ~adata.var['mt'].values
adata = adata[:, keep_genes]
adata

In [None]:
# Load the reference signatures written by the reference step and restrict
# both the spatial data and the signatures to the genes they share.
inf_aver = pd.read_csv(reference_csv, index_col=0)
inf_aver.index = inf_aver.index.astype(str)

# np.intersect1d returns the sorted, unique shared gene names.
intersect = np.intersect1d(adata.var_names, inf_aver.index)

# Fail here with a clear message instead of letting an empty gene set reach
# model setup/training, where the error is much harder to interpret.
if intersect.size == 0:
    raise ValueError(
        f"No genes shared between {spatial_file!r} and {reference_csv!r}; "
        "check that both use the same gene identifiers."
    )

adata = adata[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()

In [None]:
# Register the spatial AnnData and build the mapping model from the
# reference signatures in inf_aver.
cell2location.models.Cell2location.setup_anndata(adata=adata)
mod2 = cell2location.models.Cell2location(
    adata,
    cell_state_df=inf_aver,
    N_cells_per_location=25,
)
mod2.view_anndata_setup()

# batch_size=None and train_size=1 are passed explicitly.
# NOTE(review): presumably this means full-data training with no validation
# split — confirm against the cell2location documentation.
mod2.train(
    max_epochs=5000,
    batch_size=None,
    train_size=1,
)


In [None]:
# Export the posterior of the spatial model back into adata.
adata = mod2.export_posterior(
    adata,
    sample_kwargs={
        'num_samples': 1000,
        'batch_size': 10000,
    },
)

In [None]:
# Show the first ten rows of the 'q05_cell_abundance_w_sf' table.
# NOTE(review): presumably the 5% posterior quantile of cell abundance — confirm.
adata.obsm['q05_cell_abundance_w_sf'][:10]

In [None]:
# Factor names stored under adata.uns['mod']; displayed as the cell output.
mod_info = adata.uns['mod']
factor_names = mod_info['factor_names']
factor_names

In [None]:
# Copy the 'q05_cell_abundance_w_sf' table into adata.obs, one column per
# factor name, then display adata.obs.
abundance_q05 = adata.obsm['q05_cell_abundance_w_sf']
adata.obs[factor_names] = abundance_q05
adata.obs

In [None]:
# Cell types that are both reference signature columns and adata.obs columns.
# The list keeps the column order of inf_aver instead of going through a set
# intersection: set iteration order for strings can differ between interpreter
# runs, which made idxmax tie-breaking (first maximal column wins) and the
# panel order of any plot using cellList non-reproducible.
obs_columns = set(adata.obs.columns)
cellList = list(dict.fromkeys(name for name in inf_aver.columns if name in obs_columns))

# Label each observation with the cell type that has the largest value.
adata.obs['cell2loc_anno'] = adata.obs[cellList].idxmax(axis=1)

# Per-factor values plus the chosen label, written as one CSV.
out_df = adata.obs[factor_names].copy()
out_df['annotation'] = adata.obs['cell2loc_anno']

# Write into annotation_output_dir (defined with the spatial inputs) rather
# than the hard-coded reference directory, creating it if needed.
Path(annotation_output_dir).mkdir(parents=True, exist_ok=True)
out_df.to_csv(Path(annotation_output_dir) / f'{Path(spatial_file).stem}_anno_cell2location.csv')

In [None]:
# Save the annotated spatial AnnData (gzip-compressed) into
# annotation_output_dir rather than the hard-coded reference directory;
# the directory is created first so the write cannot fail on a missing path.
Path(annotation_output_dir).mkdir(parents=True, exist_ok=True)
adata.write_h5ad(
    Path(annotation_output_dir) / f'{Path(spatial_file).stem}_cell2location.h5ad',
    compression='gzip',
)

In [None]:

# Plot one spatial panel per entry of cellList on a black background,
# three panels per row, with the colour scale capped via vmax='p99'.
dark_panel_rc = {
    'axes.facecolor': 'black',
    'figure.figsize': [4.5, 5],
}
with mpl.rc_context(dark_panel_rc):
    sc.pl.spatial(
        adata,
        color=cellList,
        cmap='magma',
        ncols=3,
        size=5,
        vmax='p99',
        # img_key='hires',
    )