# Xenium Human Breast Cancer

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 22.01.2023
- **Date of Last Modification:** 11.02.2025 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- In order to run this notebook, a trained model needs to be stored under ```../../artifacts/{dataset}/models/{model_label}/{load_timestamp}```.
    - dataset: ```xenium_human_breast_cancer```
    - model_label: ```reference```
    - load_timestamp: ```26102023_153021```
- Run this notebook in the nichecompass-reproducibility environment, which can be installed from ```../../envs/environment.yaml```.

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Automatically reload imported modules (e.g. the local analysis_utils helpers)
# before executing code, so edits to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the repository's shared helper modules (e.g. ``analysis_utils``)
# importable; the path is relative to the notebook's working directory.
UTILS_DIR = "../../utils"
sys.path.append(UTILS_DIR)

In [None]:
import ast
import argparse
import gc
import os
import random
import shutil
import warnings
from datetime import datetime

import anndata as ad
import matplotlib
import matplotlib.pyplot as plt
import mlflow
import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import scanpy as sc
import scipy.sparse as sp
import scipy.stats as stats
import seaborn as sns
import squidpy as sq
import torch
from matplotlib import gridspec
from matplotlib.pyplot import rc_context
from pywaffle import Waffle
from sklearn.preprocessing import MinMaxScaler

from nichecompass.models import NicheCompass
from nichecompass.utils import (add_gps_from_gp_dict_to_adata,
                                aggregate_obsp_matrix_per_cell_type,
                                create_cell_type_chord_plot_from_df,
                                create_new_color_dict,
                                compute_communication_gp_network,
                                generate_enriched_gp_info_plots,
                                visualize_communication_gp_network)

from analysis_utils import (add_cell_type_latent_cluster_emphasis,
                            add_sub_cell_type,
                            compute_cell_type_latent_clusters,
                            generate_gp_info_plots,
                            plot_category_in_latent_and_physical_space,
                            plot_cell_type_latent_clusters,
                            plot_latent,
                            plot_physical_latent_for_cell_types,
                            sankey,
                            store_top_gps_summary)

### 1.2 Define Parameters

In [None]:
# Dataset identifier; used to build the artifact paths and the AnnData file name.
dataset: str = "xenium_human_breast_cancer"

#### 1.2.1 Generic Parameters

In [None]:
## Model
# AnnData keys used to access NicheCompass outputs
gp_names_key = "nichecompass_gp_names"  # .uns: names of all gene programs
active_gp_names_key = "nichecompass_active_gp_names"  # .uns: names of active gene programs
latent_key = "nichecompass_latent"  # latent representation / neighbors key

#### 1.2.2 Dataset-specific Parameters

In [None]:
# Identifies which trained model run to load
load_timestamp = "26102023_153021"
model_label = "reference"

# Analysis settings
latent_leiden_resolution = 0.2
condition_key = "batch"
sample_key = "batch"
spot_size = 30
cell_type_key = "cell_type"

# obs column holding the latent Leiden clusters, e.g. "latent_leiden_0.2"
latent_cluster_key = f"latent_leiden_{latent_leiden_resolution}"

### 1.3 Run Notebook Setup

In [None]:
# Configure scanpy's figure defaults in a single call. Each call to
# ``set_figure_params`` re-applies scanpy's rcParams defaults, so a second call
# would reset the figure size requested by the first one.
sc.set_figure_params(dpi=300, figsize=(6, 6))
# White background without grid lines
sns.set_style("whitegrid", {"axes.grid": False})

In [None]:
# Silence FutureWarning, UserWarning and RuntimeWarning messages to keep the
# notebook output readable.
for ignored_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter(action="ignore", category=ignored_category)

In [None]:
# Matplotlib text settings: small Helvetica text, no LaTeX rendering, and SVG
# text kept as text objects rather than converted to paths.
plt.rcParams.update({
    "font.family": "Helvetica",
    "font.size": 5,
    "text.usetex": False,
    "svg.fonttype": "none",
})

### 1.4 Configure Paths and Create Directories

In [None]:
# Define paths. Every artifact lives below the reproducibility repository root.
# The root defaults to the original cluster location and can be overridden via
# the NICHECOMPASS_REPRODUCIBILITY_ROOT environment variable (e.g. "../.." for
# the repository-relative layout described at the top of this notebook).
repo_root_path = os.environ.get(
    "NICHECOMPASS_REPRODUCIBILITY_ROOT",
    "/lustre/groups/aih/sebastian.birk/workspace/projects/nichecompass-reproducibility")
artifacts_folder_path = f"{repo_root_path}/artifacts/{dataset}"
figure_folder_path = f"{artifacts_folder_path}/figures/{model_label}/{load_timestamp}"
model_folder_path = f"{artifacts_folder_path}/models/{model_label}/{load_timestamp}"
result_folder_path = f"{artifacts_folder_path}/results/{model_label}/{load_timestamp}"
srt_data_folder_path = f"{repo_root_path}/datasets/st_data"  # spatially resolved transcriptomics data
srt_data_gold_folder_path = f"{srt_data_folder_path}/gold"

# Create required directories
os.makedirs(figure_folder_path, exist_ok=True)
os.makedirs(result_folder_path, exist_ok=True)

## 2. Model

In [None]:
# Load the trained model. With adata=None the AnnData object is read from the
# file named below (presumably stored inside ``model_folder_path``).
model_adata_file_name = f"{dataset}_{model_label}.h5ad"
model = NicheCompass.load(
    dir_path=model_folder_path,
    adata=None,
    adata_file_name=model_adata_file_name,
    gp_names_key=gp_names_key,
)

In [None]:
# For both the full and the active gene program name lists: drop names that
# contain "Add-on " and replace spaces with underscores in the remaining ones.
for gp_list_key in (gp_names_key, active_gp_names_key):
    kept_gp_names = [name for name in model.adata.uns[gp_list_key]
                     if "Add-on " not in name]
    model.adata.uns[gp_list_key] = np.array(
        [name.replace(" ", "_") for name in kept_gp_names])

In [None]:
# Dataset-specific metadata
model.adata.obs["batch"] = model.adata.obs["batch"].replace(
    {"sample1": "Replicate 1", "sample2": "Replicate 2"})

# Coarse cell type annotation: each fine-grained cell state listed in
# trans_from[i] is mapped to trans_to[i]; unlisted states keep their name.
trans_from = [
    ['Epi_ABCC11+', 'Epi_FOXA1+', 'Epi_AGR3+', 'Epi_CENPF+', 'mgEpi_KRT14+', 'Epi_KRT14+'],
    ['EC_CLEC14A+', 'EC_CAVIN2+'],
    ['adipo_FB', 'GJB2+iKC-FB'],
    ['EMT-Epi1_CEACAM6+', 'EMT-Epi2_CEACAM6+', 'EMT-Epi_SERPINA3+', 'EMT-Epi_KRT23+'],
    ['DERL3+B', 'BANK1+B', 'B'],
    ['eff_CD8+T1', 'eff_CD8+T2'],
    ['tcm_CD4+T', 'CD161+FOXP3+T'],
    ['NK/T'],
    ['ADIPOQ+Mast'],
    ['M2MØ', 'MMP12+miMØ'],
    ['DC1'],
]
trans_to = ['Epithelial', 'Endothelial', 'Fibroblast', 'EMT', 'B_cells', 'CD8+T', 'CD4+T', 'NK/T', 'Mast', 'MØ', 'DC']
if len(trans_from) != len(trans_to):
    # zip() would silently drop unmatched groups
    raise ValueError("trans_from and trans_to must have the same length.")

cell_state_to_type = {
    cell_state: cell_type
    for cell_states, cell_type in zip(trans_from, trans_to)
    for cell_state in cell_states
}

# Assign the whole column at once. Chained indexing
# (``obs["cell_type"][mask] = value``) raises SettingWithCopyWarning and has no
# effect at all under pandas copy-on-write.
cell_state_labels = model.adata.obs["cell_states"].astype(str)
model.adata.obs[cell_type_key] = (
    cell_state_labels.map(cell_state_to_type).fillna(cell_state_labels))

model.adata.obs[cell_type_key] = model.adata.obs[cell_type_key].replace("MØ", "Mɸ")
model.adata.obs["cell_states"] = model.adata.obs["cell_states"].replace("M2MØ", "M2Mɸ").replace("MMP12+miMØ", "MMP12+miMɸ")

In [None]:
# Check number of active gene programs
active_gps = model.get_active_gps()
print(f"Number of total gene programs: {len(model.adata.uns[gp_names_key])}.")
print(f"Number of active gene programs: {len(active_gps)}.")

# Summary of the active gene programs, kept for interactive inspection. A bare
# expression in the middle of a cell is evaluated and discarded, so store it.
gp_summary_df = model.get_gp_summary()
active_gp_summary_df = gp_summary_df[gp_summary_df["gp_active"].eq(True)]

samples = model.adata.obs[sample_key].unique().tolist()
model.add_active_gp_scores_to_obs()

In [None]:
# Compute Leiden clustering
sc.tl.leiden(adata=model.adata,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key,
             neighbors_key=latent_key)

## 3. Analysis

### 3.1 Create Figures

In [None]:
### Fig. 4a ###
# Name the latent Leiden clusters as niches. The list order must match the
# category order of ``obs[latent_cluster_key]``.
niche_names = ['FB-Epi', 'CD4+T', 'EMT-Immune', 'Epi-Immune', 'FB-EMT',
               'FB-Lymphoid', 'FB-Myeloid', 'FB-Endo', 'Mast-Stromal', 'EMT-Mɸ',
               'EMT-Endo', 'Epi-Bcells', 'Stromal', 'Endo-Lymphoid']
# Use latent_cluster_key rather than a hard-coded column name so the cell stays
# consistent with latent_leiden_resolution.
model.adata.obs['niche'] = model.adata.obs[latent_cluster_key].copy()
model.adata.obs['niche'] = model.adata.obs['niche'].cat.rename_categories(niche_names)

niche_colors = create_new_color_dict(
    adata=model.adata,
    cat_key='niche')

save_fig = True
file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
            "niches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Niches",
    cat_key="niche",
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=niche_colors,
    size=(720000 / len(model.adata)),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
### Fig. 4b ###
# Cell type annotations in latent and physical space
cell_type_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=cell_type_key,
    color_palette="cell_type_20")

save_fig = True
file_path = f"{figure_folder_path}/cell_types_latent_physical_space.svg"

dot_size = 720000 / len(model.adata)
plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Cell Types",
    cat_key=cell_type_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=cell_type_colors,
    size=dot_size,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
### Fig. 4c ###
condition_colors = create_new_color_dict(
    adata=model.adata,
    color_palette="batch",
    cat_key=condition_key)

save_fig = True
file_path = f"{figure_folder_path}/" \
            "batches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Batches",
    cat_key=condition_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=condition_colors,
    size=(720000 / len(model.adata)),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

# Hierarchical ordering of the niches in the latent space; the resulting
# uns["dendrogram_niche"] ordering is also used by later figures.
sc.tl.dendrogram(adata=model.adata,
                 use_rep=latent_key,
                 linkage_method="ward",
                 groupby="niche")

# Batch composition of each niche (row-normalized proportions), in dendrogram order
niche_batch_props = pd.crosstab(model.adata.obs["niche"],
                                model.adata.obs[condition_key],
                                normalize="index")
niche_batch_props = niche_batch_props.reindex(
    model.adata.uns["dendrogram_niche"]["categories_ordered"])

# plot.barh returns the Axes (chaining .legend() would return a Legend), so keep
# the Axes and configure legend, label and output on it explicitly.
ax = niche_batch_props.plot.barh(color=condition_colors, stacked=True, figsize=(3, 4))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
ax.set_xlabel("Data Source Proportions")
ax.figure.savefig(f"{figure_folder_path}/niche_batch_proportions.svg", bbox_inches="tight")

In [None]:
### Fig. 4e ###
gps = ["Ptprc_ligand_receptor_target_gene_GP",
       'Add-on_37_GP',
       'Add-on_86_GP',
       'Add-on_68_GP',
       'Add-on_51_GP',
       'Add-on_66_GP']
# GP scores that are displayed with flipped sign
sign_flipped_gps = ["Ptprc_ligand_receptor_target_gene_GP", "Add-on_68_GP"]

# Work on a copy so the sign flip never modifies model.adata.obs. An in-place
# flip/unflip leaves the scores negated if the cell stops in between.
gp_scores = model.adata.obs[["niche"] + gps].copy()
gp_scores[sign_flipped_gps] = -1 * gp_scores[sign_flipped_gps]

# Mean GP score per niche, ordered by the reversed niche dendrogram
df = gp_scores.groupby("niche", observed=False).mean()
df = df.reindex(model.adata.uns["dendrogram_niche"]["categories_ordered"][::-1])

# Min-max scale every GP across niches to [0, 1]
scaler = MinMaxScaler()
normalized_df = pd.DataFrame(scaler.fit_transform(df),
                             columns=df.columns,
                             index=df.index)
# Shorten GP names for display, e.g. "Add-on_37_GP" -> "De novo 37 GP"
normalized_df.columns = [col.split("_lig")[0]
                         .split("_met")[0]
                         .replace("_", " ")
                         .replace("9-cis-Retinoic", "9-cis-Ret.")
                         .replace("Add-on", "De novo")
                         .replace("GP", "") + " GP" for col in normalized_df.columns]

plt.figure(figsize=(6, 2))
ax = sns.heatmap(normalized_df.transpose(),
                 cmap='viridis',
                 annot=False,
                 linewidths=0)
plt.xticks(rotation=45, ha="right")
plt.savefig(f"{figure_folder_path}/enriched_gps_heatmap.svg",
            bbox_inches="tight")

In [None]:
### Fig. 4f ###
# Spatial maps for Replicate 1: the Add-on_37_GP score followed by the
# expression of selected genes, filling a 4x2 grid row by row.
adata_rep1 = model.adata[model.adata.obs["batch"] == "Replicate 1"]
panel_genes = ["KRT16", "KRT14", "KRT5", "KRT6B", "KRT15"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
flat_axs = axs.ravel()  # row-major: [0, 0], [0, 1], [1, 0], ...
sc.pl.spatial(
    adata=adata_rep1,
    color="Add-on_37_GP",
    color_map="RdGy_r",
    spot_size=spot_size,
    title="Add-on_37_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=flat_axs[0],
    show=False)
for panel_idx, gene in enumerate(panel_genes, start=1):
    # NOTE(review): the first gene panel keeps scanpy's default use_raw while the
    # remaining gene panels pass use_raw=False; confirm this is intended.
    gene_kwargs = {} if panel_idx == 1 else {"use_raw": False}
    sc.pl.spatial(
        adata=adata_rep1,
        color=gene,
        color_map="RdPu",
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=flat_axs[panel_idx],
        show=False,
        **gene_kwargs)
# Hide grid cells without a panel so no empty axes end up in the saved figure
for unused_ax in flat_axs[len(panel_genes) + 1:]:
    unused_ax.set_axis_off()
plt.show()
fig.savefig(f"{figure_folder_path}/addon37_gp_genes_rep1.svg",
            bbox_inches="tight")

In [None]:
### Fig. 4g ###
# Spatial maps for Replicate 1: the Add-on_86_GP score (top left) followed by
# the expression of selected genes, filling the 4x2 grid row by row.
rep1_adata = model.adata[model.adata.obs["batch"] == "Replicate 1"]
gene_panels = ["MLPH", "EPCAM", "FOXA1", "ELF3", "KRT8", "KRT7", "ABCC11"]

fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(20, 15))
grid_axes = axs.ravel()

sc.pl.spatial(
    adata=rep1_adata,
    color="Add-on_86_GP",
    color_map="RdGy_r",
    spot_size=spot_size,
    title="Add-on_86_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=grid_axes[0],
    show=False)

for position, gene_name in enumerate(gene_panels, start=1):
    # Every gene panel except the first passes use_raw=False explicitly
    raw_option = {"use_raw": False} if position > 1 else {}
    sc.pl.spatial(
        adata=rep1_adata,
        color=gene_name,
        color_map="RdPu",
        spot_size=spot_size,
        title=gene_name,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=grid_axes[position],
        show=False,
        **raw_option)

plt.show()
fig.savefig(f"{figure_folder_path}/addon86_gp_genes_rep1.svg",
            bbox_inches="tight")

In [None]:
### Fig. 4h ###
# Store the GP summary with all columns as strings; list-valued columns are
# parsed back with ast.literal_eval below.
gp_summary = model.get_gp_summary()
for col in gp_summary.columns:
    gp_summary[col] = gp_summary[col].astype(str)
model.adata.uns["nichecompass_gp_summary"] = gp_summary

df = model.adata.uns['nichecompass_gp_summary']
df_gp37 = df[df['gp_name'] == 'Add-on_37_GP']

# For simplicity and to make the figure easier to read, only the first (top)
# 60 source genes of the GP and their weights are used.
n_top_genes = 60
genes = ast.literal_eval(df_gp37['gp_source_genes'].values[0])[:n_top_genes]
weights = ast.literal_eval(df_gp37['gp_source_genes_weights'].values[0])[:n_top_genes]

# Gene -> weight lookup; the first occurrence of a gene wins (like list.index)
gene_weights = {}
for gene, weight in zip(genes, weights):
    gene_weights.setdefault(gene, weight)

# Add gene enrichment information.
# This can be done with any gene enrichment analysis tool or a large language
# model. Here we used ToppFun; ChatGPT agreed with the resulting classification.
classification = {
    'Cytoskeletal Proteins': {
        'Keratin Family': ['KRT16', 'KRT14', 'KRT5', 'KRT6B', 'KRT15', 'KRT23', 'KRT7'],
        'Tubulin Family': ['TUBB2B'],
        'Regulation of Actin': ['RAPGEF3'],
        'Myosins': ['MYH11']
    },
    'Cell Adhesion and Junctions': {
        'Desmosomes and Junctions': ['DSP', 'JUP', 'CLDN4', 'CEACAM6', 'CEACAM8'],
        'Tight Junctions and Adhesion': ['TACSTD2', 'AGR3', 'LYPD3']
    },
    'Signaling Receptors': {
        'Growth Factors': ['KIT', 'PDGFRA', 'PDGFRB', 'EGFR'],
        'Other': ['AVPR1A', 'OXTR', 'IL2RA']
    },
    'Gene Regulation': {
        'Forkhead Box Family': ['FOXC2', 'FOXP3'],
        'Transcription Factors': ['TCF7', 'KLF5'],
        'Open Reading Frames': ['C15orf48', 'C6orf132', 'C2orf42', 'C5orf46']
    },
    'Enzymes and Metabolic Proteins': {
        'Kinases and Phosphatases': ['MYLK', 'ERN1'],
        'Transporters': ['ABCC11'],
        'Peptidases and Hydrolases': ['SEC11C', 'USP53', 'PTRHD1', 'SH3YL1'],
        'Other': ['LGALSL']
    },
    'Immune System Genes': {
        'Chemokine Receptors': ['CXCR4', 'CX3CR1'],
        'Cell Surface Markers': ['CD14', 'GNLY', 'PIGR'],
        'Inflammatory Mediators': ['S100A14', 'EGFL7'],
        'Immune Regulation': ['TAC1']
    },
    'Secreted Proteins': {
        'Secreted Proteins': ['SERPINA3', 'SCGB2A1'],
        'ECM Proteins': ['CAV1']
    },
    'Cell Organization': {
        'Transport Proteins': ['NOSTRIN'],
        'Structural Proteins': ['SVIL', 'ANKRD29', 'GLIPR1']
    }
}

# One row per (category, subcategory, gene) with the gene's GP weight
data = []
for category, subcategories in classification.items():
    for subcategory, subcategory_genes in subcategories.items():
        for gene in subcategory_genes:
            if gene not in gene_weights:
                raise KeyError(f"{gene!r} is not among the top {n_top_genes} "
                               "source genes of Add-on_37_GP.")
            data.append([category, subcategory, gene, gene_weights[gene]])

df = pd.DataFrame(data, columns=['Category', 'Subcategory', 'Gene', 'Gene Weight'])

fig = px.sunburst(df,
                  path=['Category', 'Subcategory', 'Gene'],
                  color='Gene Weight',
                  color_continuous_scale='RdYlBu_r')

fig.update_layout(
    coloraxis_colorbar=dict(
        # Nested title properties; the flat ``titlefont``/``titleside`` aliases
        # are deprecated and were removed in plotly 6.
        title=dict(text="Gene Importance Weight", side='right', font=dict(size=12)),
        tickfont=dict(size=10),
        tickvals=[df['Gene Weight'].min(), df['Gene Weight'].max()],
        len=0.7,
        thickness=15,
        xpad=10,
        x=1.02
    ),
    width=1000,
    height=1000
)

fig.update_traces(
    textinfo="label",
    insidetextfont=dict(size=15)
)

fig.show()
fig.write_image(f'{figure_folder_path}/RdYlBu_r_add_on_37sunburst_plot.svg', scale=3)

In [None]:
### Fig. 4i ###
# Store the GP summary with all columns as strings; list-valued columns are
# parsed back with ast.literal_eval below.
gp_summary = model.get_gp_summary()
for col in gp_summary.columns:
    gp_summary[col] = gp_summary[col].astype(str)
model.adata.uns["nichecompass_gp_summary"] = gp_summary

df = model.adata.uns['nichecompass_gp_summary']
df_gp86 = df[df['gp_name'] == 'Add-on_86_GP']

# Only the first (top) 60 source genes of the GP and their weights are used
n_top_genes = 60
genes = ast.literal_eval(df_gp86['gp_source_genes'].values[0])[:n_top_genes]
weights = ast.literal_eval(df_gp86['gp_source_genes_weights'].values[0])[:n_top_genes]

# Gene -> weight lookup; the first occurrence of a gene wins (like list.index)
gene_weights = {}
for gene, weight in zip(genes, weights):
    gene_weights.setdefault(gene, weight)

# Each gene is listed once per subcategory: duplicate rows within the same
# sunburst path would be summed by plotly and enlarge that gene's wedge.
classification = {
    'Cytoskeletal Proteins': {
        'Keratin Family': ['KRT8', 'KRT7'],
        'Myosins': ['MYO5B'],
        'Actin-binding': ['CTTN']
    },
    'Cell Adhesion and Junctions': {
        'Cell Adhesion Molecules': ['EPCAM', 'TACSTD2', 'LYPD3', 'CEACAM8', 'CLDN4'],
        'Desmosomes and Junctions': ['DSP', 'JUP'],
        'ECM Proteins': ['DPT', 'FBLN1', 'MEDAG', 'STC1']
    },
    'Gene Regulation': {
        'Forkhead Box Family': ['FOXA1'],
        'Transcription Factors': ['ELF3', 'GATA3', 'KLF5', 'TFAP2A', 'ZEB2'],
        'Open Reading Frames': ['C6orf132']
    },
    'Metabolism': {
        'Fatty Acid Synthesis': ['FASN', 'SCD'],
        'Transporters': ['ABCC11', 'SLC5A6'],
        'Other Enzymes': ['USP53', 'SH3YL1', 'SMS', 'SQLE', 'NARS', 'SERHL2']
    },
    'Immune System Genes': {
        'Cytokines': ['IL7R'],
        'Cell Surface Markers': ['CD79A', 'RHOH', 'RTKN2'],
        'Inflammatory Mediators': ['S100A14'],
        'Immune Regulation': ['TRAF4']
    },
    'Signaling': {
        'Growth Factors': ['PDGFRA', 'ERBB2'],
        'Hormones': ['ESR1', 'AR'],
        'TNF': ['TRAF4'],
        'JAK-STAT': ['OCIAD2'],
        'Wnt/β-catenin': ['TCIM']
    },
    'Cell Cycle and Proliferation': {
        'Cell Cycle Regulators': ['CCND1', 'CENPF', 'TOP2A'],
        'Proliferation Markers': ['PCLAF', 'TPD52']
    }
}

# One row per (category, subcategory, gene) with the gene's GP weight
data = []
for category, subcategories in classification.items():
    for subcategory, subcategory_genes in subcategories.items():
        for gene in subcategory_genes:
            if gene not in gene_weights:
                raise KeyError(f"{gene!r} is not among the top {n_top_genes} "
                               "source genes of Add-on_86_GP.")
            data.append([category, subcategory, gene, gene_weights[gene]])

df = pd.DataFrame(data, columns=['Category', 'Subcategory', 'Gene', 'Gene Weight'])

fig = px.sunburst(df,
                  path=['Category', 'Subcategory', 'Gene'],
                  color='Gene Weight',
                  color_continuous_scale='RdYlBu_r')

fig.update_layout(
    coloraxis_colorbar=dict(
        # Nested title properties; the flat ``titlefont``/``titleside`` aliases
        # are deprecated and were removed in plotly 6.
        title=dict(text="Gene Importance Weight", side='right', font=dict(size=12)),
        tickfont=dict(size=10),
        tickvals=[df['Gene Weight'].min(), df['Gene Weight'].max()],
        len=0.7,
        thickness=15,
        xpad=10,
        x=1.02
    ),
    width=1000,
    height=1000
)

fig.update_traces(
    textinfo="label",
    insidetextfont=dict(size=15)
)

fig.show()
fig.write_image(f'{figure_folder_path}/RdYlBu_r_add_on_86sunburst_plot.svg', scale=3)

ChatGPT prompt:

"""To classify genes into categories and subcategories, we used the GPT-4o large language model (via ChatGPT). Here is the
prompt that we used: ‘Please classify the genes by their broad categories and subcategories
based on their biological functions and roles. Use two layers of classification: categories and
subcategories, for example, Immune system genes will be a category and Cytokines will be a
subcategory. Don't use Miscellaneous or Other classification. Use up to nine categories for the
first layers of classification. Output the results in the data frame: Category - subcategory - gene.
Here is the list of genes to classify:...’"""

In [None]:
### Supplementary Fig. 25a ###
# Fine-grained cell state annotations in latent and physical space
cell_state_colors = create_new_color_dict(
    adata=model.adata,
    cat_key="cell_states",
    skip_default_colors=50)

save_fig = True
file_path = f"{figure_folder_path}/cell_states_latent_physical_space.svg"

dot_size = 720000 / len(model.adata)
plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Cell States",
    cat_key="cell_states",
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=cell_state_colors,
    size=dot_size,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
### Supplementary Fig. 25b ###
# Cell type composition of each niche (row-normalized), in dendrogram order
niche_cell_type_props = pd.crosstab(model.adata.obs["niche"],
                                    model.adata.obs[cell_type_key],
                                    normalize="index")
niche_cell_type_props = niche_cell_type_props.reindex(
    model.adata.uns["dendrogram_niche"]["categories_ordered"])

# plot.barh returns the Axes (chaining .legend() would return a Legend)
ax = niche_cell_type_props.plot.barh(color=cell_type_colors, stacked=True)
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
ax.set_xlabel("Cell Type Proportions")
ax.figure.savefig(f"{figure_folder_path}/niche_cell_type_proportions.svg", bbox_inches="tight")

In [None]:
### Supplementary Fig. 25c ###
# Marker gene expression in the two EMT niches, then the same niches
# highlighted in latent and physical space.
adata_sub = model.adata[model.adata.obs['niche'].isin(['EMT-Endo', 'EMT-Mɸ'])]
marker_genes = ['KRT16', 'KRT14', 'KRT5', 'KRT6B', 'KRT15', 'MLPH',
                'EPCAM', 'FOXA1', 'ELF3', 'KRT8', 'KRT7', 'ABCC11']
ax_dict = sc.pl.dotplot(adata_sub,
                        marker_genes,
                        'niche',
                        swap_axes=True,
                        dendrogram=True,
                        standard_scale='var',
                        cmap='magma',
                        show=False)

# Rotate the x tick labels of every Axes returned by the dot plot
tick_labels = [label
               for dotplot_ax in ax_dict.values()
               for label in dotplot_ax.get_xticklabels()]
for tick_label in tick_labels:
    tick_label.set_rotation(45)
    tick_label.set_ha('right')

plt.savefig(f"{figure_folder_path}/addon_gps_gene_expr_dotplot.svg",
            bbox_inches="tight")
plt.show()

save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_leiden_resolution}_"
             "niches_filtered_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Niches",
    cat_key="niche",
    groups=['EMT-Endo', 'EMT-Mɸ'],
    sample_key=sample_key,
    samples=samples,
    cat_colors=niche_colors,
    size=(720000 / len(model.adata)),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
### Supplementary Fig. 25d ###
# Waffle charts of the cell type composition of selected niches
waffle_niches = ['Epi-Immune', 'FB-EMT', 'EMT-Immune', 'EMT-Mɸ', 'EMT-Endo']
adata_subset = model.adata[model.adata.obs['niche'].isin(waffle_niches)]

# Cell counts per (niche, cell type). observed=True keeps only combinations
# that occur, so absent cell types do not appear as 0% legend entries. It also
# avoids pandas' FutureWarning about the ``observed`` default.
df = (adata_subset.obs
      .groupby(['niche', cell_type_key], observed=True)
      .size()
      .reset_index(name='counts'))

# Percentage of each cell type within its niche; one waffle square = 0.1 %
df['proportions'] = df['counts'] / df.groupby('niche', observed=True)['counts'].transform('sum') * 100
df['waffle_counts'] = (df['proportions'] * 10).astype(int)

for group in df['niche'].unique():
    temp_df = df[df['niche'] == group]

    data = dict(zip(temp_df[cell_type_key], temp_df['waffle_counts']))
    colors = [cell_type_colors[cell_type] for cell_type in temp_df[cell_type_key]]
    labels = [f"{k} ({v}%)"
              for k, v in zip(temp_df[cell_type_key], temp_df['proportions'].round(2))]
    fig = plt.figure(
        FigureClass=Waffle,
        rows=10,
        values=data,
        title={'label': f'Niche {group}', 'loc': 'left', 'fontsize': 14},
        labels=labels,
        legend={'loc': 'lower left', 'bbox_to_anchor': (0, -0.4), 'ncol': len(data),
                'framealpha': 0, 'fontsize': 16},
        figsize=(40, 4),
        colors=colors
    )
    plt.savefig(f"{figure_folder_path}/{group}_cell_type_proportions_waffle.svg", bbox_inches='tight')

In [None]:
### Supplementary Fig. 26a ###
# Sign-corrected copies of GP scores that are displayed with flipped sign
model.adata.obs["Ptprc_ligand_receptor_target_gene_GP_sign_corrected"] = -1 * model.adata.obs["Ptprc_ligand_receptor_target_gene_GP"]
model.adata.obs["Add-on_68_GP_sign_corrected"] = -1 * model.adata.obs["Add-on_68_GP"]

# Replicate 1 view; created after adding the columns above so it contains them
adata_rep1 = model.adata[model.adata.obs["batch"] == "Replicate 1"]


def _plot_gp_and_genes(adata, gp_key, gp_title, genes, file_name):
    """Plot a GP score map followed by gene expression maps on a 4x2 grid.

    Panels fill the grid row by row: the GP score (``RdGy_r``) first, then one
    panel per gene (``RdPu``). Grid cells without a panel are hidden. The
    figure is shown and saved as ``figure_folder_path/file_name``.
    """
    fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(20, 15))
    flat_axs = axs.ravel()  # row-major: [0, 0], [0, 1], [1, 0], ...
    if len(genes) + 1 > len(flat_axs):
        raise ValueError(f"At most {len(flat_axs) - 1} genes fit on the grid.")
    sc.pl.spatial(
        adata=adata,
        color=gp_key,
        color_map="RdGy_r",
        spot_size=spot_size,
        title=gp_title,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=flat_axs[0],
        show=False)
    for panel_idx, gene in enumerate(genes, start=1):
        # NOTE(review): the first gene panel keeps scanpy's default use_raw while
        # the remaining gene panels pass use_raw=False; confirm this is intended.
        gene_kwargs = {} if panel_idx == 1 else {"use_raw": False}
        sc.pl.spatial(
            adata=adata,
            color=gene,
            color_map="RdPu",
            spot_size=spot_size,
            title=gene,
            legend_loc=None,
            colorbar_loc="bottom",
            ax=flat_axs[panel_idx],
            show=False,
            **gene_kwargs)
    # Hide unused grid cells so no empty axes end up in the saved figure
    for unused_ax in flat_axs[len(genes) + 1:]:
        unused_ax.set_axis_off()
    plt.show()
    fig.savefig(f"{figure_folder_path}/{file_name}", bbox_inches="tight")


_plot_gp_and_genes(adata_rep1,
                   gp_key="Ptprc_ligand_receptor_target_gene_GP_sign_corrected",
                   gp_title="Ptprc_ligand_receptor_target_gene_GP",
                   genes=["PTPRC", "MRC1", "CD4", "CD247"],
                   file_name="ptprc_gp_genes_rep1.svg")
_plot_gp_and_genes(adata_rep1,
                   gp_key="Add-on_68_GP_sign_corrected",
                   gp_title="Add-on_68_GP",
                   genes=["MMP1", "GJB2", "C15orf48", "C5orf46", "FOXC2"],
                   file_name="addon68_gp_genes_rep1.svg")
_plot_gp_and_genes(adata_rep1,
                   gp_key="Add-on_51_GP",
                   gp_title="Add-on_51_GP",
                   genes=["OPRPN", "APOC1", "LDHB", "PTN", "KLF5"],
                   file_name="addon51_gp_genes_rep1.svg")
_plot_gp_and_genes(adata_rep1,
                   gp_key="Add-on_66_GP",
                   gp_title="Add-on_66_GP",
                   genes=["ALDH1A3", "KRT23", "KIT", "KRT15", "PIGR"],
                   file_name="addon66_gp_genes_rep1.svg")

In [None]:
### Supplementary Fig. 26b ###
# Ptprc ligand-receptor target-gene GP (Replicate 2): GP activity in the
# top-left panel, expression of selected GP genes in the following panels
# (row-major order).
adata_rep = model.adata[model.adata.obs["batch"] == "Replicate 2"]
gp_genes = ["PTPRC", "MRC1", "CD4", "CD247"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
panel_axes = axs.ravel()
color_map = "RdGy_r"
sc.pl.spatial(
    adata=adata_rep,
    color="Ptprc_ligand_receptor_target_gene_GP",
    color_map=color_map,
    spot_size=spot_size,
    title="Ptprc_ligand_receptor_target_gene_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=panel_axes[0],
    show=False)
color_map = "RdPu"
for gene_ax, gene in zip(panel_axes[1:], gp_genes):
    # use_raw=False on every gene panel so all genes are read from adata.X
    # rather than adata.raw (when present).
    sc.pl.spatial(
        adata=adata_rep,
        color=gene,
        use_raw=False,
        color_map=color_map,
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=gene_ax,
        show=False)
# All GP figures share the same 4x2 grid and figsize; hide grid cells that
# received no panel instead of saving empty frames.
for unused_ax in panel_axes[1 + len(gp_genes):]:
    unused_ax.axis("off")
fig.savefig(f"{figure_folder_path}/ptprc_gp_genes_rep2.svg",
            bbox_inches="tight")
plt.show()

# Add-on_68_GP (Replicate 2): sign-corrected GP activity in the top-left
# panel, expression of selected GP genes in the following panels
# (row-major order).
adata_rep = model.adata[model.adata.obs["batch"] == "Replicate 2"]
gp_genes = ["MMP1", "GJB2", "C15orf48", "C5orf46", "FOXC2"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
panel_axes = axs.ravel()
color_map = "RdGy_r"
sc.pl.spatial(
    adata=adata_rep,
    # Plots the sign-corrected obs column but titles it with the plain GP name.
    color="Add-on_68_GP_sign_corrected",
    color_map=color_map,
    spot_size=spot_size,
    title="Add-on_68_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=panel_axes[0],
    show=False)
color_map = "RdPu"
for gene_ax, gene in zip(panel_axes[1:], gp_genes):
    # use_raw=False on every gene panel so all genes are read from adata.X
    # rather than adata.raw (when present).
    sc.pl.spatial(
        adata=adata_rep,
        color=gene,
        use_raw=False,
        color_map=color_map,
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=gene_ax,
        show=False)
# All GP figures share the same 4x2 grid and figsize; hide grid cells that
# received no panel instead of saving empty frames.
for unused_ax in panel_axes[1 + len(gp_genes):]:
    unused_ax.axis("off")
fig.savefig(f"{figure_folder_path}/addon68_gp_genes_rep2.svg",
            bbox_inches="tight")
plt.show()

# Add-on_51_GP (Replicate 2): GP activity in the top-left panel, expression
# of selected GP genes in the following panels (row-major order).
adata_rep = model.adata[model.adata.obs["batch"] == "Replicate 2"]
gp_genes = ["OPRPN", "APOC1", "LDHB", "PTN", "KLF5"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
panel_axes = axs.ravel()
color_map = "RdGy_r"
sc.pl.spatial(
    adata=adata_rep,
    color="Add-on_51_GP",
    color_map=color_map,
    spot_size=spot_size,
    title="Add-on_51_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=panel_axes[0],
    show=False)
color_map = "RdPu"
for gene_ax, gene in zip(panel_axes[1:], gp_genes):
    # use_raw=False on every gene panel so all genes are read from adata.X
    # rather than adata.raw (when present).
    sc.pl.spatial(
        adata=adata_rep,
        color=gene,
        use_raw=False,
        color_map=color_map,
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=gene_ax,
        show=False)
# All GP figures share the same 4x2 grid and figsize; hide grid cells that
# received no panel instead of saving empty frames.
for unused_ax in panel_axes[1 + len(gp_genes):]:
    unused_ax.axis("off")
fig.savefig(f"{figure_folder_path}/addon51_gp_genes_rep2.svg",
            bbox_inches="tight")
plt.show()

# Add-on_66_GP (Replicate 2): GP activity in the top-left panel, expression
# of selected GP genes in the following panels (row-major order).
adata_rep = model.adata[model.adata.obs["batch"] == "Replicate 2"]
gp_genes = ["ALDH1A3", "KRT23", "KIT", "KRT15", "PIGR"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
panel_axes = axs.ravel()
color_map = "RdGy_r"
sc.pl.spatial(
    adata=adata_rep,
    color="Add-on_66_GP",
    color_map=color_map,
    spot_size=spot_size,
    title="Add-on_66_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=panel_axes[0],
    show=False)
color_map = "RdPu"
for gene_ax, gene in zip(panel_axes[1:], gp_genes):
    # use_raw=False on every gene panel so all genes are read from adata.X
    # rather than adata.raw (when present).
    sc.pl.spatial(
        adata=adata_rep,
        color=gene,
        use_raw=False,
        color_map=color_map,
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=gene_ax,
        show=False)
# All GP figures share the same 4x2 grid and figsize; hide grid cells that
# received no panel instead of saving empty frames.
for unused_ax in panel_axes[1 + len(gp_genes):]:
    unused_ax.axis("off")
fig.savefig(f"{figure_folder_path}/addon66_gp_genes_rep2.svg",
            bbox_inches="tight")
plt.show()

In [None]:
### Supplementary Fig. 26c ###
# Add-on_37_GP (Replicate 2): GP activity in the top-left panel, expression
# of selected GP genes in the following panels (row-major order).
adata_rep = model.adata[model.adata.obs["batch"] == "Replicate 2"]
gp_genes = ["KRT16", "KRT14", "KRT5", "KRT6B", "KRT15"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
panel_axes = axs.ravel()
color_map = "RdGy_r"
sc.pl.spatial(
    adata=adata_rep,
    color="Add-on_37_GP",
    color_map=color_map,
    spot_size=spot_size,
    title="Add-on_37_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=panel_axes[0],
    show=False)
color_map = "RdPu"
for gene_ax, gene in zip(panel_axes[1:], gp_genes):
    # use_raw=False on every gene panel so all genes are read from adata.X
    # rather than adata.raw (when present).
    sc.pl.spatial(
        adata=adata_rep,
        color=gene,
        use_raw=False,
        color_map=color_map,
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=gene_ax,
        show=False)
# All GP figures share the same 4x2 grid and figsize; hide grid cells that
# received no panel instead of saving empty frames.
for unused_ax in panel_axes[1 + len(gp_genes):]:
    unused_ax.axis("off")
fig.savefig(f"{figure_folder_path}/addon37_gp_genes_rep2.svg",
            bbox_inches="tight")
plt.show()

# Add-on_86_GP (Replicate 2): GP activity in the top-left panel, expression
# of selected GP genes in the following panels (row-major order). The seven
# genes plus the GP panel fill all 8 cells of the 4x2 grid.
adata_rep = model.adata[model.adata.obs["batch"] == "Replicate 2"]
gp_genes = ["MLPH", "EPCAM", "FOXA1", "ELF3", "KRT8", "KRT7", "ABCC11"]

fig, axs = plt.subplots(nrows=4,
                        ncols=2,
                        figsize=(20, 15))
panel_axes = axs.ravel()
color_map = "RdGy_r"
sc.pl.spatial(
    adata=adata_rep,
    color="Add-on_86_GP",
    color_map=color_map,
    spot_size=spot_size,
    title="Add-on_86_GP",
    legend_loc=None,
    colorbar_loc="bottom",
    ax=panel_axes[0],
    show=False)
color_map = "RdPu"
for gene_ax, gene in zip(panel_axes[1:], gp_genes):
    # use_raw=False on every gene panel so all genes are read from adata.X
    # rather than adata.raw (when present).
    sc.pl.spatial(
        adata=adata_rep,
        color=gene,
        use_raw=False,
        color_map=color_map,
        spot_size=spot_size,
        title=gene,
        legend_loc=None,
        colorbar_loc="bottom",
        ax=gene_ax,
        show=False)
fig.savefig(f"{figure_folder_path}/addon86_gp_genes_rep2.svg",
            bbox_inches="tight")
plt.show()

### 3.2 Save Results

In [None]:
# Log-normalize the counts in place for the cellxgene server: scale every
# cell to 1e4 total counts, then apply log1p.
# NOTE: both calls mutate model.adata.X, so re-running this cell would
# transform already log-normalized data a second time.
sc.pp.normalize_total(model.adata, target_sum=1e4)
sc.pp.log1p(model.adata)

# Attach the GP summary to .uns with every column cast to str, so the frame
# can be written to .h5ad (presumably some columns hold lists / mixed types
# that would not serialize otherwise; confirm against get_gp_summary()).
gp_summary = model.get_gp_summary().astype(str)
model.adata.uns["nichecompass_gp_summary"] = gp_summary

analysis_h5ad_path = f"{result_folder_path}/{dataset}_analysis.h5ad"
model.adata.write(analysis_h5ad_path)