In [None]:
# Imports and global plotting configuration for the mC <-> spatial integration notebook.
import os
import sys
from rich import inspect, print as rprint
from pathlib import Path
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import scanpy.external as sce
from scipy.sparse import issparse, csr_matrix
from sklearn.decomposition import TruncatedSVD

# NOTE(review): star import; later cells call log_scale, significant_pc_test and tsne,
# which presumably come from here — prefer explicit imports once confirmed.
from ALLCools.clustering import *
from ALLCools.integration.seurat_class import SeuratIntegration
from spida.P.setup_adata import _calc_embeddings

import matplotlib.pyplot as plt
import seaborn as sns
# NOTE(review): star import; categorical_scatter used by the plotting cells presumably comes from here.
from ALLCools.plot import *
from spida.pl import plot_categorical, plot_continuous
# Move Plot ALLCools functions outside of the cli
plt.rcParams['axes.facecolor'] = 'white'

# Presumably provides HeatmapAnnotation / anno_label / ClusterMapPlotter used by the heatmap cell.
from PyComplexHeatmap import *

from datetime import datetime 
# Run timestamp at minute resolution; not referenced by any later cell visible here.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H:%M")

In [None]:
#parameters
# Parameter cell; values are presumably overridden by papermill (or a similar runner).
EXP = None  # experiment id; fills the "{EXP}" placeholders in the path templates below
REF_EXP = None  # passed to mc_adata_path.format(), which currently has no {REF_EXP} placeholder

# ref_adata_path="/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_{REF_EXP}.h5ad"
# spatial_adata_path="/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}.h5ad"


# Paths
integrator_path="/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}_joint.h5ad"  # output: merged/integrated AnnData
anndata_store_path = "/home/x-aklein2/projects/aklein/BICAN/data/aggregated"  # not referenced by later visible cells
annotations_store_path = "/home/x-aklein2/projects/aklein/BICAN/data/annotated"  # not referenced by later visible cells
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/annotations"  # figures go to image_path/EXP
cell_mapping_out_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}_cell_mappings.tsv"  # output: leiden -> label table (TSV)

mc_subset_col="Subclass"
mc_subset_value=None  # if set, keep only reference cells with obs[mc_subset_col] == this value
spatial_subset_col="c2c_allcools_label_Subclass"
spatial_subset_value=None  # spatial cells are subset on obs[spatial_subset_col] == this value

# umap_coord_base="base_round1_umap"
# tsne_coord_base="base_round1_tsne"

## Spatial related params
spatial_adata_path="/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}.h5ad"
spatial_ann_path = None # .tsv file of annotations; not referenced by later visible cells
spatial_ann_cols=["Neighborhood", "Subclass"]  # not referenced by later visible cells
spatial_cell_type = "all_round1_leiden"  # overwritten with f"{added_joint_emb_key}leiden" after re-embedding
spatial_downsample=None  # not referenced by later visible cells
normalize_spatial=False  # not referenced by later visible cells
spatial_query=None  # not referenced by later visible cells
spatial_std_cutoff=None  # if set, drop spatial genes with std <= cutoff
spatial_tsne_key='base_round1_tsne'  # overwritten after re-embedding
spatial_umap_key='base_round1_umap'  # overwritten after re-embedding
spatial_batch_key="MERSCOPE"  # Modality label given to spatial cells in the merged object
harmony_batch_key=["replicate", "donor"]  # batch_key passed to _calc_embeddings
added_joint_emb_key =  "subclass_"  # prefix for the embedding/cluster keys from _calc_embeddings


## mC related Params
mc_adata_path = "/home/x-aklein2/projects/aklein/BICAN/data/reference/mC/BG.gene-CHN.h5ad"
mc_ann_path = "/home/x-aklein2/projects/aklein/BICAN/data/reference/mC/annotations.tsv"  # .h5ad, .csv(.gz), else tab-separated
mc_ann_cols = ["Class", "CellClass", "Subclass", "Group", "Region"]  # annotation columns copied onto the reference
mc_query = " CellClass != 'Nonneuron' "  # pandas query applied to reference obs
mc_cell_type = "Group"  # reference label column (also used as ref_cell_type)
mc_std_cutoff = 0.01  # drop reference genes with std <= cutoff
mc_downsample = 2000  # max reference cells kept per cell type
mc_palette = None  # not referenced by later visible cells
level=None  # only used when mc_cell_type == 'cluster_id': keep the first `level` '+'-separated parts
key_to_transfer="Group"  # reference obs column(s) transferred to spatial cells
normalize_mc_per_cell=True  # divide by obs['prior_mean'] when that column exists
mc_batch_key="mC"  # Modality label given to reference cells in the merged object

# General Factors: 
min_cell=20  # minimum cells per reference cell type
topn=200  # marker genes taken per group and modality
n_train_cell=100000  # max reference cells used to fit the SVD
chunk_size = 50000  # rows per SVD transform chunk
resolution = 4  # leiden resolution for the joint clustering
cpu=16  # n_jobs for tSNE
ref_palette=None  # Excel palette path or reference obs column holding colours
spatial_palette=None  # Excel palette path or spatial obs column holding colours
use_genes = None  # None -> union of marker genes; otherwise a list or a gene-list file path
k_weight_transfer = 20  # k_weight for label transfer

outdir=None  # not referenced by later visible cells

In [None]:
# Resolve the "{EXP}" path templates and create the output directories.
# Fail early when EXP was not supplied: formatting with None would silently create
# ".../BICAN_BG_None/..." directories and then crash on `Path(...) / None` below.
if EXP is None:
    raise ValueError("EXP must be set (e.g. via the parameters cell) before resolving paths")
# mc_adata_path currently has no {REF_EXP} placeholder, so .format() returns it unchanged.
ref_adata_path = mc_adata_path.format(REF_EXP=REF_EXP)
spatial_adata_path = spatial_adata_path.format(EXP=EXP)
integrator_path = integrator_path.format(EXP=EXP)
cell_mapping_out_path = cell_mapping_out_path.format(EXP=EXP)
Path(integrator_path).parent.mkdir(parents=True, exist_ok=True)
Path(cell_mapping_out_path).parent.mkdir(parents=True, exist_ok=True)
image_path = Path(image_path) / EXP
image_path.mkdir(parents=True, exist_ok=True)
# Reference label column used throughout the rest of the notebook.
ref_cell_type = mc_cell_type

In [None]:
# LABEL_COL = qry_subset_col
# TRANSFER_LEVEL = mc_cell_type
# added_joint_emb_key = "subclass_"

## Setup

### Spatial

In [None]:
# Load the spatial object, optionally subset it, and recompute embeddings + clusters
# on its highly variable genes.
spatial_adata = ad.read_h5ad(spatial_adata_path)
# Optional subset to a single label. Guarded like the mC subset below: with the default
# spatial_subset_value=None an unconditional comparison would select zero cells.
if spatial_subset_value is not None:
    spatial_adata = spatial_adata[spatial_adata.obs[spatial_subset_col] == spatial_subset_value].copy()

# To dense
if issparse(spatial_adata.X):
    spatial_adata.X = spatial_adata.X.toarray()
# std cutoff: drop genes whose std across cells does not exceed the cutoff
if spatial_std_cutoff is not None:
    std_filter = spatial_adata.X.std(axis=0) > spatial_std_cutoff
    spatial_adata._inplace_subset_var(std_filter)

spatial_adata.strings_to_categoricals()
# Full spatial gene list, used later to restrict the reference to shared genes.
spatial_genes = spatial_adata.var_names.tolist()

# Embeddings/clusters are computed on a 200-HVG copy and copied back, so spatial_adata
# keeps its full gene set. Assumes 'counts' and 'volume_norm' layers exist — TODO confirm.
adata_cp = spatial_adata.copy()
sc.experimental.pp.highly_variable_genes(adata_cp, n_top_genes=200, subset=True, layer='counts')
# TODO: 
_calc_embeddings(
    adata_cp,
    layer="volume_norm",
    key_added=added_joint_emb_key,
    leiden_res=2.5,
    min_dist=0.25,
    knn=50,
    run_harmony=True,
    batch_key=harmony_batch_key,
    harmony_nclust=20,
    max_iter_harmony=20,
)
# Copy over only the obs columns / obsm keys that are new on the HVG copy.
for col in adata_cp.obs.columns:
    if col not in spatial_adata.obs:
        spatial_adata.obs[col] = adata_cp.obs[col].copy()
for key, value in adata_cp.obsm.items():
    if key not in spatial_adata.obsm:
        spatial_adata.obsm[key] = value.copy()

# Point downstream cells at the freshly computed keys (assumed to be produced by
# _calc_embeddings with this prefix — TODO confirm).
spatial_tsne_key=f'{added_joint_emb_key}tsne'
spatial_umap_key=f'{added_joint_emb_key}umap'
spatial_cell_type=f"{added_joint_emb_key}leiden"

# Spatial marker genes per cluster, used for feature selection.
sc.tl.rank_genes_groups(spatial_adata,groupby=spatial_cell_type,method="wilcoxon",use_raw=False)

spatial_adata

### Methylation 

In [None]:
# Load the methylation (mC) reference, attach annotations, drop rare cell types,
# optionally query/downsample, restrict to genes shared with the spatial panel,
# and compute reference marker genes.
# Read in ref (backed: only the selected cells are loaded into memory below)
ref_adata = ad.read_h5ad(ref_adata_path, backed='r')
keep_cells = set(ref_adata.obs_names)

# Add the annotations from file (.h5ad -> its obs, .csv/.csv.gz -> comma, else tab-separated)
if mc_ann_path is not None:
    ann_file = os.path.expanduser(mc_ann_path)
    if mc_ann_path.endswith('.h5ad'):
        obs = ad.read_h5ad(ann_file).obs.copy()
    elif mc_ann_path.endswith(('.csv', '.csv.gz')):
        obs = pd.read_csv(ann_file, index_col=0)
    else:
        obs = pd.read_csv(ann_file, index_col=0, sep='\t')
    # Only cells present in the annotation table are kept.
    keep_cells &= set(obs.index)
    if mc_ann_cols is None:
        mc_ann_cols = []
    for col in obs.columns.tolist():
        if col in mc_ann_cols + [mc_cell_type]:
            ref_adata.obs[col] = ref_adata.obs_names.map(obs[col].to_dict())

if level is not None and mc_cell_type == 'cluster_id' and mc_cell_type in ref_adata.obs.columns.tolist():
    # too many clusters for cluster_id, keep the first `level` parts (such as c5+c20+c0 -> c5+c20)
    ref_adata.obs[mc_cell_type] = ref_adata.obs['cluster_id'].apply(
        lambda x: "+".join(x.split('+')[:level]) if not pd.isna(x) else x)

# Drop cell types with fewer than min_cell cells.
ref_adata.obs[mc_cell_type] = ref_adata.obs[mc_cell_type].astype('category')
vc = ref_adata.obs[mc_cell_type].value_counts()
keep_cts = vc[vc >= min_cell].index.tolist()
keep_cells &= set(ref_adata.obs.index[ref_adata.obs[mc_cell_type].isin(keep_cts).to_numpy()])

# Optional query filter and per-cell-type downsampling. A fixed random_state makes the
# sampled subset (and everything computed from it) reproducible between runs.
labelled = ref_adata.obs.loc[ref_adata.obs[mc_cell_type].notna()]
if mc_query is not None:
    labelled = labelled.query(mc_query)
if mc_downsample and ref_adata.n_obs > mc_downsample:
    use_cells = []
    for _, grp in labelled.groupby(mc_cell_type, observed=True):
        if grp.shape[0] > mc_downsample:
            grp = grp.sample(mc_downsample, random_state=0)
        use_cells.extend(grp.index.tolist())
    keep_cells &= set(use_cells)
elif mc_query is not None:
    keep_cells &= set(labelled.index)
# Subset with a boolean mask so cells stay in file order; set iteration order varies
# between interpreter runs.
ref_adata = ref_adata[ref_adata.obs_names.isin(keep_cells), :].to_memory()
# final setup
ref_adata.obs[mc_cell_type] = ref_adata.obs[mc_cell_type].astype('category')
ref_adata.strings_to_categoricals()

# Get the genes from the spatial data (kept in the reference's gene order for determinism)
spatial_gene_set = set(spatial_genes)
common_genes = [g for g in ref_adata.var_names if g in spatial_gene_set]
ref_adata = ref_adata[:, common_genes].copy()

# Subset the mC data further (only when a subset value is given)
if mc_subset_value is not None:
    ref_adata = ref_adata[ref_adata.obs[mc_subset_col] == mc_subset_value].copy()

# Get marker genes
# NOTE(review): the single-group shortcut runs in the next cell, after this call — confirm
# rank_genes_groups tolerates a subset with only one group.
sc.tl.rank_genes_groups(ref_adata, groupby=ref_cell_type, method="wilcoxon")

# divide frac by prior mean (determined by alpha and beta) for each cell, when available
if normalize_mc_per_cell and 'prior_mean' in ref_adata.obs.columns:
    print("Normalizing cell level fraction by alpha and beta (prior_mean)")
    ref_adata.X = ref_adata.X / ref_adata.obs.prior_mean.values[:, None]
ref_adata

In [None]:
# If the (possibly subset) reference has exactly one cell type with > 20 cells, there is
# nothing to discriminate: give every spatial cell that label, save, and stop.
vc = ref_adata.obs[mc_cell_type].value_counts()
print(vc)

# NOTE(review): hard-coded 20 with a strict ">" while the reference filter uses min_cell with
# ">="; zero qualifying groups falls through to the full pipeline.
if len(vc[vc > 20].index) ==  1: 
    print(f"Only one group with > 20 cells found in {mc_subset_col}, skipping")
    # value_counts sorts descending, so the first entry is the qualifying group.
    annot = vc.index[0]

    key = mc_cell_type
    
    spatial_adata.obs[f'infer_{key}'] = annot
    spatial_adata.obs[f'infer_{key}_prob'] = 1.0
    spatial_adata.obs[f'infer_{key}_c2c'] = annot
    
    Path(integrator_path).parent.mkdir(parents=True, exist_ok=True)
    # Writes the spatial object itself (no merged/integrated data) to the integrator path.
    spatial_adata.write_h5ad(integrator_path)
    # NOTE(review): sys.exit raises SystemExit; whether the runner treats that as a clean stop
    # or as an error is not visible here — confirm with the caller.
    sys.exit(0)

## View

In [None]:
# Repeats the import and axes-facecolor setting from the first cell (redundant on Run All).
from spida.pl import plot_categorical
plt.rcParams['axes.facecolor'] = 'white'

In [None]:
# Spatial clusters on the spatial-only tSNE and UMAP, plus a cluster -> colour map.
spatial_colors=None
if spatial_tsne_key is not None:
    with plt.rc_context({"figure.figsize": (4, 4), "figure.dpi": (100)}):
        plt.figure()
        plot_categorical(spatial_adata, cluster_col=spatial_cell_type, coord_base=spatial_tsne_key, show=False)
        # plt.savefig(os.path.join("figures","MajorType.tsne.summary.pdf"),bbox_inches='tight',dpi=300)
        plt.show()
if spatial_umap_key is not None:
    with plt.rc_context({"figure.figsize": (4, 4), "figure.dpi": (100)}):
        plt.figure()
        plot_categorical(spatial_adata, cluster_col=spatial_cell_type, coord_base=spatial_umap_key, show=False)
        # plt.savefig(os.path.join("figures","MajorType.umap.summary.pdf"),bbox_inches='tight',dpi=300)
        plt.show()
    # Only built in the UMAP branch; assumes uns[f'{spatial_cell_type}_colors'] exists
    # (presumably written by plot_categorical — TODO confirm).
    spatial_colors={cluster:color for cluster,color in zip(spatial_adata.obs[spatial_cell_type].cat.categories.tolist(),spatial_adata.uns[f'{spatial_cell_type}_colors'])}
spatial_adata.obs[spatial_cell_type].value_counts()

In [None]:
# Reference cell types on the reference UMAP (assumes ref_adata.obsm has 'X_umap' — TODO confirm),
# then a label -> colour map from the colours scanpy keeps in .uns.
with plt.rc_context({"figure.figsize": (4, 4), "figure.dpi": (100)}):
    plt.figure()
    sc.pl.umap(ref_adata,color=[ref_cell_type],cmap='jet',
           ncols=2,wspace=0.25,show=False,vmin='p5',vmax='p95')
    plt.show()
ref_colors={cluster:color for cluster,color in zip(ref_adata.obs[ref_cell_type].cat.categories.tolist(),ref_adata.uns[f'{ref_cell_type}_colors'])}
ref_adata.obs[ref_cell_type].value_counts()

## Feature Selection

In [None]:
# Print the gene overlap in both directions, then restrict the reference to genes in the spatial panel.
print(ref_adata.var_names.isin(spatial_adata.var_names).sum(),spatial_adata.var_names.isin(ref_adata.var_names).sum())
ref_adata._inplace_subset_var(ref_adata.var_names.isin(spatial_adata.var_names))

### mC

In [None]:
if use_genes is None:
    # Reference (mC) markers: per group, the top `topn` significant genes with lower
    # methylation in the group (score < 0, p < 0.05), ranked by most negative log fold change.
    # These are unioned with the spatial markers in the combine step.
    markers = sc.get.rank_genes_groups_df(ref_adata, group=ref_adata.obs[ref_cell_type].unique())
    markers = markers.loc[
        markers.names.notna()
        & markers.logfoldchanges.notna()
        & (markers.scores < 0)
        & (markers.pvals < 0.05)
    ].sort_values('logfoldchanges', ascending=True)
    # groupby().head() keeps the sorted order inside each group. Unlike apply(list).sum(),
    # it yields an empty selection (rather than the integer 0) when no gene passes.
    ref_features = np.unique(markers.groupby('group').head(topn)['names'].to_numpy().astype(str))
    print(len(ref_features), ref_features[:10])

### Spatial

In [None]:

if use_genes is None:
    # Spatial markers: per cluster, the top `topn` significant up-regulated genes
    # (score > 0, p < 0.05), ranked by largest log fold change.
    markers = sc.get.rank_genes_groups_df(spatial_adata, group=spatial_adata.obs[spatial_cell_type].unique())
    markers = markers.loc[
        markers.names.notna()
        & markers.logfoldchanges.notna()
        & (markers.scores > 0)
        & (markers.pvals < 0.05)
    ].sort_values('logfoldchanges', ascending=False)
    # groupby().head() keeps the sorted order inside each group and handles an empty selection.
    spatial_features = np.unique(markers.groupby('group').head(topn)['names'].to_numpy().astype(str))
    print(len(spatial_features), spatial_features[:10])

### Combine marker sets

In [None]:
# Final integration features: an explicit gene list, a gene-list file (genes in the first
# column, tab-separated, no header), or the union of the mC and spatial marker genes.
if use_genes is not None:
    if isinstance(use_genes, list):
        selected_features = use_genes
    else:
        print(use_genes)
        selected_features = pd.read_csv(
            os.path.expanduser(use_genes), index_col=0, sep='\t', header=None
        ).index.tolist()
else:
    selected_features = np.unique(np.concatenate((ref_features, spatial_features)))
print(len(selected_features), selected_features[:10])

In [None]:
# Restrict both modalities (in place) to the selected integration features.
ref_adata._inplace_subset_var(ref_adata.var_names.isin(selected_features))
spatial_adata._inplace_subset_var(spatial_adata.var_names.isin(selected_features))

In [None]:
# Display both objects' summaries.
ref_adata, spatial_adata

## Filter + Normalization

In [None]:
# Drop near-constant reference genes (per-gene std <= mc_std_cutoff).
if not mc_std_cutoff is None:
    std_filter = ref_adata.X.std(axis=0) > mc_std_cutoff
    ref_adata._inplace_subset_var(std_filter)
ref_adata  # mid-cell bare expression: not displayed (only a cell's last expression is)

# Normalize mc adata
# log mC fraction and scale features
log_scale(ref_adata, with_mean=True) ## after log: 0-0.69, after scale: z score (columns)
# reverse mC fraction so its positively corr with RNA
ref_adata.X *= -1
ref_adata

In [None]:
use_vars=list(set(ref_adata.var_names) & set(spatial_adata.var_names))
ref_adata=ref_adata[:,use_vars]
spatial_adata=spatial_adata[:,use_vars]

In [None]:
ref_adata

In [None]:
spatial_adata

## Merge + PCA

In [None]:
# NOTE(review): `ref_adata=ref_adata` is a no-op.
ref_adata=ref_adata
query_adata=spatial_adata
# Modality labels in concatenation order: reference first, then spatial.
batch_categories=[mc_batch_key, spatial_batch_key]

In [None]:
# Stack both modalities; obs["Modality"] records each cell's source. index_unique=None keeps
# the original cell names (NOTE(review): names duplicated across modalities stay duplicated).
adata = ref_adata.concatenate(
    query_adata, batch_categories=batch_categories, batch_key="Modality", index_unique=None
)
print(adata.obs.Modality.value_counts())
adata

In [None]:
# Seed the legacy global RNG so the training-cell draw is reproducible.
np.random.seed(0)

# select ref cells to fit the model: at most n_train_cell reference cells, drawn without replacement
train_cell = np.zeros(ref_adata.shape[0], dtype=bool)
if ref_adata.shape[0] > n_train_cell:
    train_cell[
        np.random.choice(np.arange(ref_adata.shape[0]), n_train_cell, False)
    ] = True
else:
    train_cell[:] = True

# Set on ref_adata after the concatenation above, so the merged `adata` has no "Train" column.
ref_adata.obs["Train"] = train_cell.copy()

In [None]:
## Run PCA on Merged Adata
# "PCA" here is a TruncatedSVD fitted on the reference training cells only (see fit below),
# with 100 components, or n_feature - 1 when there are fewer than 100 features.
n_feature = ref_adata.shape[1]
if n_feature >= 100:
    model = TruncatedSVD(n_components=100, random_state=0, algorithm='randomized')
else:
    model = TruncatedSVD(n_components=n_feature - 1, random_state=0, algorithm='randomized')

# use selected train cells to fit
model.fit(ref_adata.X[ref_adata.obs["Train"].values])
# Mask of components to keep (non-zero singular values only).
sel_dim = model.singular_values_ != 0
print(sel_dim.sum(), 'non-zero singular values')

In [None]:
# Explained-variance ratio per SVD component (scree plot).
fig, ax = plt.subplots()
ax.plot(model.explained_variance_ratio_)

In [None]:
# transform all other data: project every cell of both modalities with the reference-fitted
# SVD, chunk_size rows at a time
chunks = []
for chunk_start in range(0, adata.shape[0], chunk_size):
    chunks.append(
        model.transform(adata.X[chunk_start : (chunk_start + chunk_size)])
    )

# add NNZ PCs to adata
adata.obsm["X_pca"] = np.concatenate(chunks, axis=0)[:, sel_dim]

# remove low variance PCs; returns the number of PCs kept (given the slicing below, it is
# assumed to also truncate obsm["X_pca"] to that many columns — TODO confirm)
n_pcs = significant_pc_test(adata, p_cutoff=0.05, obsm="X_pca")

# scale PC by singular values
adata.obsm["X_pca"] /= model.singular_values_[sel_dim][:n_pcs]

In [None]:
# Stringify categorical/object obs columns so they serialise consistently, but keep missing
# values as NaN. A plain astype(str) turns them into the literal string "nan", which defeats
# the all-missing column drop below and the later back-fill of transferred labels (which
# tests pd.isna on these columns).
for col in adata.obs.select_dtypes(['category', 'object']).columns:
    missing = adata.obs[col].isna()
    adata.obs[col] = adata.obs[col].astype(str).mask(missing)
# Split the projected object back into per-modality AnnData objects (reference first).
for i, m in enumerate(batch_categories):
    adata1 = adata[adata.obs["Modality"] == m].copy()
    # Drop columns with no values in this modality (e.g. reference-only annotations on
    # spatial cells); all-NaN columns otherwise fail with "Can't implicitly convert
    # non-string objects to strings".
    all_missing = adata1.obs.columns[adata1.obs.isna().all()]
    adata1.obs = adata1.obs.drop(columns=all_missing)
    adata1.strings_to_categoricals()
    if i == 0:
        ref_adata = adata1
    elif i == 1:
        spatial_adata = adata1

# Seurat Integration 

In [None]:
# Per-modality objects for the Seurat integrator: reference at index 0, query at index 1.
query_adata = spatial_adata
adata_list = [ref_adata, query_adata]
# Dedicated loop name: iterating with `adata` rebound the global merged object to the
# last list element, which later cells could silently pick up.
for modality_adata in adata_list:
    print(modality_adata.shape)

In [None]:
# Container for integration results: an all-zero sparse X with the combined obs of both
# modalities (reference rows first, matching adata_list order).
cells = sum([a.shape[0] for a in adata_list])
features = adata_list[0].shape[1]

adata_merge = ad.AnnData(
    X=csr_matrix((cells, features), dtype=np.float32),
    obs=pd.concat([a.obs for a in adata_list]),
    var=adata_list[0].var,
)
print(adata_list[0].obsm["X_pca"].shape,adata_list[1].obsm["X_pca"].shape)

In [None]:
# CCA components: all PCs when there are fewer than 10, otherwise n_pc - 10 (at least 10).
n_pc = adata_list[0].obsm["X_pca"].shape[1]
if n_pc < 10:
    n_cca_components = n_pc
else:
    n_cca_components = max(n_pc - 10, 10)

print("CCA Components", n_cca_components)
# Size of the smaller modality; caps k_filter in find_anchor.
min_sample = adata_merge.obs["Modality"].value_counts().min()
print("Smaller Sample Size", min_sample)

In [None]:
# ALLCools' Seurat-style integrator; holds anchors between the datasets in adata_list.
integrator = SeuratIntegration()

In [None]:
# take ~2.5-3h for 300K mC + 4M 10X RNA
# Anchor search between reference (index 0) and spatial query (index 1) in CCA space.
# NOTE(review): argument semantics are defined by SeuratIntegration.find_anchor.
anchor = integrator.find_anchor(
    adata_list,
    k_local=None,
    key_local="X_pca",
    k_anchor=5,
    key_anchor="X",
    dim_red="cca",
    max_cc_cells=100000,
    k_score=30,
    k_filter=min(200, min_sample),
    scale1=False,
    scale2=False,
    n_components=n_cca_components,
    n_features=100,
    alignments=None,#[[[0], [1]]],
)

In [None]:
# Anchor-based correction of X_pca; the per-dataset results are stacked in adata_list order
# (presumably the same row order as adata_merge.obs — TODO confirm).
corrected = integrator.integrate(
    key_correct="X_pca",
    row_normalize=True,
    k_weight=50,
    sd=1,
    alignments=None, #[[[0], [1]]],
)

adata_merge.obsm["X_pca_integrate"] = np.concatenate(corrected)
# f=open("integrator.pkl",'wb')
# pickle.dump(integrator,f,True)
# f.close()

In [None]:
# Normalise key_to_transfer to a list of reference obs columns (default: the reference cell type).
if key_to_transfer is None:
    key_to_transfer = [ref_cell_type]
elif isinstance(key_to_transfer, str):
    key_to_transfer = [key_to_transfer]
key_to_transfer

In [None]:
# Transfer reference labels to the query (spatial) cells using the integration anchors.
label_transfer=integrator.label_transfer(
    ref=[0],qry=[1],categorical_key=key_to_transfer,continuous_key=None,
    key_dist='X_pca', k_weight=k_weight_transfer
)
adata_merge.uns['label_transfer']=label_transfer
for key in label_transfer:
    # One row per query cell, one column per label; values are presumably per-label
    # transfer scores (TODO confirm against SeuratIntegration.label_transfer).
    scores = label_transfer[key]
    # idxmax/max replace `labels[np.argmax(x)]` / `x[np.argmax(x)]`. The latter indexed a
    # label-indexed Series by position, which pandas deprecates (a KeyError in pandas 3).
    best_label = scores.idxmax(axis=1)
    best_prob = scores.max(axis=1)
    obs_index = adata_merge.obs.index.to_series()
    adata_merge.obs[f'infer_{key}'] = obs_index.map(best_label.to_dict())
    adata_merge.obs[f'infer_{key}_prob'] = obs_index.map(best_prob.to_dict())
    # Fill the label column with the inferred label wherever it is missing; cells that
    # already have a label keep it. Vectorised equivalent of the former row-wise apply.
    current = adata_merge.obs[key].astype(object)
    adata_merge.obs[key] = current.where(current.notna(), adata_merge.obs[f'infer_{key}'].to_numpy())

### Harmony on embeddings

In [None]:
# Harmony on the integrated PCs with Modality as the batch; the output lands in
# obsm['X_pca_harmony'], which the tSNE/neighbors cells read.
sce.pp.harmony_integrate(adata_merge,
                     key='Modality',basis='X_pca_integrate',
                     nclust=50,max_iter_harmony=30) #X_pca_harmony

### tSNE

In [None]:
# 1 hour to run 4M + 3K cell
# Joint tSNE on the Harmony-corrected embedding (ALLCools tsne, cpu parallel jobs).
tsne(adata_merge, obsm='X_pca_harmony',n_jobs=cpu)

### UMAP

In [None]:
# 15 min to run 4M + 3K cell
# kNN graph on the Harmony-corrected integrated PCs.
sc.pp.neighbors(adata_merge, use_rep='X_pca_harmony')

In [None]:
# 2 hours to run 4M + 3K cell
# Joint leiden clustering; the mapping table and plots read obs['leiden'].
print(resolution)
sc.tl.leiden(adata_merge, resolution=resolution)

In [None]:
# 2 hours to run 4M + 3K cell
# 4 hours to run 4M + 3K cell if using spectral init, the init step is very slow
# min_dist shrinks linearly with the number of cells, floored at 0.1.
min_dist = max(0.1, 1 - adata_merge.shape[0] / 60000)
try:
    # PAGA over the reference cell-type labels, used to initialise the UMAP layout.
    sc.tl.paga(adata_merge, groups=ref_cell_type) #leiden
    sc.pl.paga(adata_merge, plot=False)
    sc.tl.umap(adata_merge, min_dist=min_dist, init_pos='paga')
except Exception:
    # NOTE(review): any failure, not only PAGA-specific ones, falls back to spectral init.
    print('Init with PAGA failed, use default spectral init')
    sc.tl.umap(adata_merge, min_dist=min_dist)

In [None]:
# Persist the merged object (embeddings, clusters, transferred labels).
adata_merge.write_h5ad(integrator_path)

# Plots

In [None]:
# Reload the saved merged object so the plotting section reads the on-disk result.
adata_merge = ad.read_h5ad(integrator_path)
adata_merge

In [None]:
# For every joint leiden cluster: the most frequent reference and spatial label, its
# within-cluster proportion, and the number of cells of each modality. Saved as TSV.
ref_obs = adata_merge.obs.query(f"Modality=='{mc_batch_key}'")
spatial_obs = adata_merge.obs.query(f"Modality=='{spatial_batch_key}'")

ref_vc = (
    ref_obs.groupby('leiden')[ref_cell_type]
    .value_counts(normalize=True)
    .sort_values(ascending=False)
    .reset_index()
    .drop_duplicates('leiden', keep='first')
    .rename(columns={'proportion': f'{ref_cell_type}_proportion'})
)
spatial_vc = (
    spatial_obs.groupby('leiden')[spatial_cell_type]
    .value_counts(normalize=True)
    .sort_values(ascending=False)
    .reset_index()
    .drop_duplicates('leiden', keep='first')
    .rename(columns={'proportion': f'{spatial_cell_type}_proportion'})
)
df_map = pd.concat([ref_vc.set_index('leiden'), spatial_vc.set_index('leiden')], axis=1)

# Clusters with no cells of one modality can be missing from that modality's counts
# (e.g. under observed=True grouping), which made the int cast fail on NaN; fill with 0.
ref_counts = ref_obs.groupby('leiden')[ref_cell_type].count()
spatial_counts = spatial_obs.groupby('leiden')[spatial_cell_type].count()
df_map[f'{ref_cell_type}_cell_count'] = ref_counts.reindex(df_map.index, fill_value=0).astype(int).to_numpy()
df_map[f'{spatial_cell_type}_cell_count'] = spatial_counts.reindex(df_map.index, fill_value=0).astype(int).to_numpy()
df_map.to_csv(cell_mapping_out_path, sep='\t')
df_map.query(f"`{ref_cell_type}_cell_count` > 100 & `{spatial_cell_type}_cell_count` > 100")

In [None]:
print(adata_merge.obs['Modality'].unique())
# Label column and Modality value per panel, in matching order (reference, spatial).
hue_cols=[ref_cell_type,spatial_cell_type]
modalities=[mc_batch_key,spatial_batch_key]

In [None]:
# Reference colours: from an Excel palette (sheet named ref_cell_type, "Hex" column), from a
# colour column of ref_adata.obs, or otherwise the ref_colors built in the View section.
if ref_palette is not None:
    if os.path.isfile(os.path.expanduser(ref_palette)) and os.path.exists(os.path.expanduser(ref_palette)): #excel
        D=pd.read_excel(os.path.expanduser(ref_palette),
                    sheet_name=None, index_col=0)
        ref_colors=D[ref_cell_type].Hex.to_dict()
    else: #column from adata.obs
        obs=ref_adata.obs.copy()
        ref_colors=obs.reset_index().loc[:,[ref_cell_type,ref_palette]].drop_duplicates().dropna().set_index(ref_cell_type)[ref_palette].to_dict()
# Labels without a colour fall back to grey.
adata_merge.uns[f'{ref_cell_type}_colors']=[ref_colors.get(k,'grey') for k in adata_merge.obs[ref_cell_type].cat.categories.tolist()]

In [None]:
# Spatial colours: from an Excel palette (sheet named spatial_cell_type, "Hex" column), from a
# colour column of spatial_adata.obs, or the map built in the View section. If none of these
# exist, fall back to the defaults scanpy assigns while plotting.
if spatial_palette is not None:
    palette_file = os.path.expanduser(spatial_palette)
    if os.path.isfile(palette_file): #excel (isfile already implies the path exists)
        D=pd.read_excel(palette_file, sheet_name=None, index_col=0)
        spatial_colors=D[spatial_cell_type].Hex.to_dict()
    else: #column from adata.obs
        obs=spatial_adata.obs.copy()
        spatial_colors=obs.reset_index().loc[:,[spatial_cell_type,spatial_palette]].drop_duplicates().dropna().set_index(spatial_cell_type)[spatial_palette].to_dict()
if spatial_colors is None:
    # Plotting populates adata_merge.uns[f'{spatial_cell_type}_colors'] with default colours.
    sc.pl.umap(adata_merge,color=[spatial_cell_type],cmap='jet',
               ncols=2,wspace=0.25,show=False,vmin='p5',vmax='p95')
    # Previously stored in an unused `atac_colors`, leaving spatial_colors None so the
    # `.get` below raised AttributeError.
    spatial_colors={cluster:color for cluster,color in zip(adata_merge.obs[spatial_cell_type].cat.categories.tolist(),adata_merge.uns[f'{spatial_cell_type}_colors'])}
# Labels without a colour fall back to grey.
adata_merge.uns[f'{spatial_cell_type}_colors']=[spatial_colors.get(k,'grey') for k in adata_merge.obs[spatial_cell_type].cat.categories.tolist()]

In [None]:
# Per-modality subsets (AnnData views) of adata_merge, each plotted with its own label column;
# kept in adata_dict for the panels below.
adata_dict={}
for modality,hue_col in zip(modalities,hue_cols):
    adata=adata_merge[adata_merge.obs['Modality'] == modality]
    with plt.rc_context({"figure.figsize": (4, 4), "figure.dpi": (100)}):    
        sc.pl.tsne(adata, color=[hue_col], 
                   wspace=0.8)
        sc.pl.umap(adata, color=[hue_col], 
                   wspace=0.8)
    adata_dict[modality]=adata

In [None]:
# Integrated embeddings, one panel per modality. Each panel draws that modality's cells,
# coloured by its own labels, over the other modality's cells in light grey; saved to image_path.
for coord_base in ['tsne', 'umap']:
    # constrained_layout handles spacing. The former plt.subplots_adjust(left=0, right=0)
    # (degenerate: left must be < right) and fig.tight_layout() both conflicted with it.
    fig, axes = plt.subplots(nrows=1,ncols=2,figsize=(8, 3.5),dpi=300,constrained_layout=True)
    for ax,col,modality in zip(axes,hue_cols,modalities):
        adata=adata_dict[modality]
        colors={cluster:color for cluster,color in zip(adata.obs[col].cat.categories.tolist(),adata.uns[f'{col}_colors'])}
        # two legend columns for large label sets
        ncol=2 if len(colors)>40 else 1
        categorical_scatter(data=adata_merge[adata_merge.obs['Modality'] != modality],
                            coord_base=coord_base,
                            max_points=None,hue=None,
                            scatter_kws=dict(color='lightgrey'),ax=ax)
        categorical_scatter(data=adata,
                            ax=ax,coord_base=coord_base,
                            hue=col,
                            palette=colors,
                            max_points=None,
                            show_legend=True,
                            legend_kws=dict(
                                ncol=ncol,loc='upper left',bbox_to_anchor=(1,1),
                                borderpad=0.4, # pad between marker (text) and border
                                labelspacing=0.2, #The vertical space between the legend entries, in font-size units
                                handleheight=0.5, #The height of the legend handles, in font-size units.
                                handletextpad=0.2, # The pad between the legend handle (marker) and text, in font-size units.
                                borderaxespad=0.2, # The pad between the Axes and legend border, in font-size units
                                columnspacing=0.2, #The spacing between columns, in font-size units
                                fontsize=5.5,title_fontsize=6
                                )
                           )
        ax.set_title(modality,fontsize=14)
    fig.savefig(f"{image_path}/seurat.integrated.{coord_base}_with_legend.pdf",dpi=300,bbox_inches='tight')
    plt.show()

In [None]:
# Joint leiden clusters on the integrated tSNE and UMAP embeddings.
with plt.rc_context({"figure.figsize": (4, 4), "figure.dpi": (100)}):
    sc.pl.tsne(adata_merge, color=['leiden'], wspace=0.8)
    sc.pl.umap(adata_merge, color=['leiden'], wspace=0.8)

In [None]:
# Joint leiden clusters on the integrated embeddings, each labelled with its category code.
col="leiden"
colors={cluster:color for cluster,color in zip(adata_merge.obs[col].cat.categories.tolist(),adata_merge.uns[f'{col}_colors'])}
for coord_base in ['tsne', 'umap']:
    fig, ax = plt.subplots(figsize=(5,4), dpi=300)
    categorical_scatter(data=adata_merge,
                        ax=ax,coord_base=coord_base,
                        # Annotate from adata_merge itself; this previously read the stale global
                        # `adata` (the last per-modality subset from an earlier cell).
                        hue=col,text_anno=adata_merge.obs[col].cat.codes.map(str),
                        palette=colors,
                        max_points=None,
                        text_kws=dict(color='white',fontweight="bold",fontsize=4,
                                     bbox=dict(facecolor=colors,boxstyle='circle', #ellipse, round
                                               edgecolor='white',fill=True,linewidth=0.5,alpha=0.75)),
                        luminance=0.48,
                        show_legend=True,
                       )
    ax.set_title(col,fontsize=14)
    fig.tight_layout()
    fig.savefig(f"{image_path}/seurat.integrated.leiden.{coord_base}.pdf",dpi=300,bbox_inches='tight')
    plt.show()

In [None]:
# One figure per modality and embedding: that modality's cells labelled in place by
# cell type (no legend), over the other modality in light grey.
for coord_base in ['tsne', 'umap']:
    #fig, axes = plt.subplots(nrows=1,ncols=3,figsize=(12, 4),dpi=300,constrained_layout=True)
    for col,modality in zip(hue_cols,modalities):
        fig, ax = plt.subplots(figsize=(4.5, 4), dpi=300)
        adata=adata_dict[modality]
        colors={cluster:color for cluster,color in zip(adata.obs[col].cat.categories.tolist(),adata.uns[f'{col}_colors'])}
        # Only used by legend_kws, and the legend is disabled below.
        ncol=2 if len(colors)>20 else 1
        categorical_scatter(data=adata_merge[adata_merge.obs['Modality'] != modality],
                            coord_base=coord_base,
                            max_points=None,hue=None,
                            scatter_kws=dict(color='lightgrey'),ax=ax)
        categorical_scatter(data=adata,
                            ax=ax,coord_base=coord_base,
                            hue=col,text_anno=col,
                            palette=colors,
                            max_points=None,
                            text_kws=dict(color='black',fontweight="bold",fontsize=4,#fontstretch='ultra-condensed',
                                     bbox=dict(facecolor=colors,boxstyle='round', #ellipse, round
                                               edgecolor='white',fill=True,linewidth=0.5,alpha=0.75)),
                            luminance=0.48,
                            # dodge_text=True,dodge_kws={
                            #          "force_points": (15, 1),
                            #           'autoalign':'y'},
                            show_legend=False,legend_kws=dict(ncol=ncol,fontsize=5.5))
        ax.set_title(modality,fontsize=14)
        fig.tight_layout()
        fig.savefig(f"{image_path}/seurat.integrated.{modality}.{coord_base}_without_legend.pdf",dpi=300,bbox_inches='tight') #
        plt.show()

In [None]:
# One figure per modality and embedding: cells marked with their label's category code,
# with a "code: label" legend; the other modality is drawn in light grey.
for coord_base in ['tsne', 'umap']:
    #fig, axes = plt.subplots(nrows=1,ncols=3,figsize=(12, 4),dpi=300,constrained_layout=True)
    for col,modality in zip(hue_cols,modalities):
        fig, ax = plt.subplots(figsize=(5,4), dpi=300)
        adata=adata_dict[modality]
        # label -> categorical code, as observed in this modality's cells
        ct2code=adata.obs.assign(code=adata.obs[col].cat.codes).loc[:,[col,'code']].drop_duplicates().set_index(col).code.to_dict()
        id_colors={str(ct2code[cluster]):color for cluster,color in zip(adata.obs[col].cat.categories.tolist(),adata.uns[f'{col}_colors'])}
        colors={f"{ct2code[cluster]}: {cluster}":color for cluster,color in zip(adata.obs[col].cat.categories.tolist(),adata.uns[f'{col}_colors'])}
        ncol=2 if len(colors)>40 else 1
        # NOTE(review): adata is a view of adata_merge (from adata_dict); writing to its obs may
        # trigger an implicit copy/warning from anndata.
        adata.obs[f'code_{col}']=adata.obs[col].apply(lambda x:f"{ct2code[x]}: {x}")
        categorical_scatter(data=adata_merge[adata_merge.obs['Modality'] != modality],
                            coord_base=coord_base,
                            max_points=None,hue=None,
                            scatter_kws=dict(color='lightgrey'),ax=ax)
        categorical_scatter(data=adata,
                            ax=ax,coord_base=coord_base,
                            hue=f'code_{col}',text_anno=adata.obs[col].cat.codes.map(str),#text_anno=col,
                            palette=colors,
                            max_points=None,#text_anno_palette=id_colors,
                            #text_anno=mc_data.obs[mc_col].replace(mask_dict_mc),
                            text_kws=dict(color='white',fontweight="bold",fontsize=4,
                                     bbox=dict(facecolor=id_colors,boxstyle='round', #ellipse, round,circle
                                               edgecolor='white',fill=True,linewidth=0.5,alpha=0.75)),
                            luminance=0.48,
                            # dodge_text=True,dodge_kws={
                            #          "force_points": (15, 1),
                            #           'autoalign':'y'},
                            show_legend=True,
                            legend_kws=dict(
                                ncol=ncol,loc='upper left',bbox_to_anchor=(1,1),
                                borderpad=0.4, # pad between marker (text) and border
                                labelspacing=0.2, #The vertical space between the legend entries, in font-size units
                                handleheight=0.5, #The height of the legend handles, in font-size units.
                                handletextpad=0.2, # The pad between the legend handle (marker) and text, in font-size units.
                                borderaxespad=0.2, # The pad between the Axes and legend border, in font-size units
                                columnspacing=0.2, #The spacing between columns, in font-size units
                                fontsize=5.5,title_fontsize=6
                                )
                           )
        ax.set_title(modality,fontsize=14)
        fig.tight_layout()
        fig.savefig(f"{image_path}/seurat.integrated.{modality}.{coord_base}_with_code.pdf",dpi=300,bbox_inches='tight') #
        plt.show()

In [None]:
# Spatial cells only: show transferred labels where the transfer probability is >= 0.5.
use_adata=adata_merge[adata_merge.obs['Modality'] == spatial_batch_key].to_memory()
key=ref_cell_type
if spatial_cell_type==key:
    # The spatial and reference label columns share a name. Rename the spatial column so that
    # writing the transferred labels into obs[key] below does not overwrite it. Without the
    # rename, the lookup of the new name below raised KeyError.
    new_spatial_cell_type=f'SPATIAL.{spatial_cell_type}'
    use_adata.obs=use_adata.obs.rename(columns={spatial_cell_type:new_spatial_cell_type})
    spatial_cell_type=new_spatial_cell_type
df=use_adata.obs.loc[:,[spatial_cell_type,f'infer_{key}',f'infer_{key}_prob']]
df=df.loc[df[f'infer_{key}_prob']>=0.5]
D=df[f'infer_{key}'].to_dict()
# Confident transferred label per cell (NaN where below the threshold).
use_adata.obs[key]=use_adata.obs_names.map(D)
use_obs=df.index.tolist()
with plt.rc_context({"figure.figsize": (4, 4), "figure.dpi": (100)}):
    sc.pl.tsne(use_adata[use_obs,:], color=[f'infer_{key}',f'infer_{key}_prob'],
               wspace=0.8)
    sc.pl.umap(use_adata[use_obs,:], color=[f'infer_{key}',f'infer_{key}_prob'],
               wspace=0.8)

In [None]:
# Spatial cluster x inferred reference label: the fraction of each spatial cluster's cells
# assigned to each inferred label (each row sums to 1 over its observed labels).
df=use_adata.obs.loc[:,[spatial_cell_type,f'infer_{key}']].value_counts().reset_index()
D=df.groupby(spatial_cell_type)['count'].sum()
df['N']=df[spatial_cell_type].map(D).astype(int)
df['fraction']=df['count'] / df.N
data=df.pivot(index=spatial_cell_type,columns=f'infer_{key}',values='fraction')
data

In [None]:
# Assign each spatial cluster to its dominant inferred label (GROUP), order rows by the
# column order of `data`, and build "code: cluster" labels from the categorical codes.
df_rows=data.index.to_series().to_frame()
cols=data.columns.tolist()
max_idx=np.argmax(data.fillna(0).values,axis=1)
df_rows['GROUP']=[cols[i] for i in max_idx]
use_rows=[]
for col in data.columns.tolist():
    df1=df_rows.loc[df_rows.GROUP==col]
    if df1.shape[0]==0:
        continue
    use_rows.extend(df1[spatial_cell_type].unique().tolist())
df_rows=df_rows.loc[use_rows]
ct2code=use_adata.obs.assign(code=use_adata.obs[spatial_cell_type].cat.codes).loc[:,[spatial_cell_type,'code']].drop_duplicates().set_index(spatial_cell_type).code.to_dict()
df_rows['Label']=df_rows[spatial_cell_type].apply(lambda x:f"{ct2code[x]}: {x}")
df_rows

In [None]:
# Dominant inferred label -> list of spatial cluster codes, then -> list of spatial cluster names.
print(df_rows.groupby('GROUP').apply(lambda x:x.Label.apply(lambda x:int(x.split(':')[0])).tolist()).to_dict())
print(df_rows.groupby('GROUP').apply(lambda x:x.Label.apply(lambda x:x.split(':')[1].strip()).tolist()).to_dict())

In [None]:
# Heatmap of `data` (spatial clusters x transferred reference labels; values are row
# fractions). Rows are split by their dominant reference label and annotated with
# "code: name" labels. The figure is saved as PDF and shown inline.
# PyComplexHeatmap names come from the star import in the notebook's first cell.
row_ha = HeatmapAnnotation(
    label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
    axis=0, orientation='right',
)

plt.figure(figsize=(24, 12))
ClusterMapPlotter(
    data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
    right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
    row_split_order=df_rows['GROUP'].unique().tolist(),
    show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
    xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
    yticklabels_kws=dict(labelcolor='red', labelsize=10),
    annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
    label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
    xlabel=ref_cell_type, ylabel=spatial_cell_type,
    xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
    ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # increase labelpad (points) if the label overlaps
)
plt.savefig(f"{image_path}/confusion_matrix.pdf", dpi=300, bbox_inches='tight')
plt.show()
# Free the figure so re-runs do not accumulate open 24x12 figures.
plt.close()

In [None]:
# Integrated-space check. For each joint Leiden cluster, compute the fraction of its
# mC reference cells carrying each reference label, drawn as a grouped heatmap.
ref_col = mc_cell_type
cluster_col = "leiden"

# `.to_memory()` matches how the spatial subset of adata_merge is taken. `.copy()`
# raises when the AnnData is opened in backed mode (it then requires a filename).
use_adata = adata_merge[adata_merge.obs['Modality'] == mc_batch_key].to_memory()

# Row-normalised contingency table: cluster x reference label -> fraction of cluster.
vc = use_adata.obs.loc[:, [cluster_col, ref_col]].value_counts().reset_index()
cluster_sizes = vc.groupby(cluster_col)['count'].sum()
vc['N'] = vc[cluster_col].map(cluster_sizes).astype(int)
vc['fraction'] = vc['count'] / vc['N']
data = vc.pivot(index=cluster_col, columns=ref_col, values='fraction')
# A bare `.head()` mid-cell shows nothing; `display` is provided by the IPython kernel.
display(data.head())

# Assign each cluster to its dominant reference label and order rows block-wise by
# the column order of `data` (stable sort keeps the original order within a block).
df_rows = data.index.to_series().to_frame()
dominant_idx = np.argmax(data.fillna(0).values, axis=1)
ref_labels = data.columns.tolist()
df_rows['GROUP'] = [ref_labels[i] for i in dominant_idx]
df_rows = df_rows.iloc[np.argsort(dominant_idx, kind='stable')].copy()

# Row labels "code: cluster" from the cluster column's categorical codes.
ct2code = (
    use_adata.obs[[cluster_col]]
    .assign(code=use_adata.obs[cluster_col].cat.codes)
    .drop_duplicates()
    .set_index(cluster_col)['code']
    .to_dict()
)
df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[cluster_col]]
display(df_rows.head())

# Plot: rows split by dominant reference label, annotated with "code: cluster".
row_ha = HeatmapAnnotation(
    label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
    axis=0, orientation='right',
)

plt.figure(figsize=(24, 12))
ClusterMapPlotter(
    data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
    right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
    row_split_order=df_rows['GROUP'].unique().tolist(),
    show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
    xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
    yticklabels_kws=dict(labelcolor='red', labelsize=10),
    annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
    label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
    xlabel=ref_col, ylabel=cluster_col,
    xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
    ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # increase labelpad (points) if the label overlaps
)
plt.savefig(f"{image_path}/integrated_confusion_matrix.pdf", dpi=300, bbox_inches='tight')
plt.show()
plt.close()