# Mouse Brain Atlas

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 22.01.2023
- **Date of Last Modification:** 11.01.2025 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- To run this notebook, the outputs of a trained model — in particular the AnnData file ```anndata_umap_with_clusters.h5ad``` loaded in section 2 — need to be stored under ```../../../artifacts/{dataset}/models/{model_label}/{load_timestamp}```
    - dataset: ```mouse_brain_atlas```
    - model_label: ```reference```
    - load_timestamp: ```220824_000000_1```
- Run this notebook in the nichecompass-reproducibility environment, which can be installed from ```../../../envs/environment.yaml```.

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# (e.g. the local analysis_utils helpers) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the repository's shared helper modules (e.g. analysis_utils) importable.
sys.path.extend(["../../../utils"])

In [None]:
import math
import os
import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib import ticker

from nichecompass.models import NicheCompass
from nichecompass.utils import create_new_color_dict

from analysis_utils import plot_category_in_latent_and_physical_space

### 1.2 Define Parameters

In [None]:
# Dataset identifier; it is interpolated into every artifact path of this notebook.
dataset: str = "mouse_brain_atlas"

#### 1.2.1 Generic Parameters

In [None]:
## Model
# AnnData keys holding the names of all GPs and of the active GPs.
gp_names_key: str = "nichecompass_gp_names"
active_gp_names_key: str = "nichecompass_active_gp_names"

#### 1.2.2 Dataset-specific Parameters

In [None]:
# Trained model run whose artifacts are analyzed below.
load_timestamp = "220824_000000_1"
model_label = "reference"

# Observation annotation keys.
cell_type_key = "cell_type"
sample_key = "batch"

# Sample identifiers "batch1" ... "batch239".
samples = list(map("batch{}".format, range(1, 240)))

# Plotting and clustering settings.
spot_size = 50
latent_leiden_resolution = 0.2
latent_cluster_key = f"latent_leiden_{latent_leiden_resolution}"

### 1.3 Run Notebook Setup

In [None]:
# Default scanpy figure settings for this notebook: 6 x 6 inch canvases.
square_figsize = (6, 6)
sc.set_figure_params(figsize=square_figsize)

In [None]:
# Ignore future, user and runtime warnings for the rest of the session.
# (The previous comment omitted RuntimeWarning although it was ignored too.)
for _ignored_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter(action="ignore", category=_ignored_category)
del _ignored_category

In [None]:
# Small Helvetica text for publication figures.
plt.rcParams.update({"font.family": "Helvetica", "font.size": 5})

In [None]:
# Fixed color for each NicheCompass niche label "0" ... "16"; the niche label
# is the position of the color in this list.
_niche_colors = [
    "#66C5CC", "#F6CF71", "#F89C74", "#DCB0F2", "#87C55F", "#9EB9F3",
    "#FE88B1", "#C9DB74", "#8BE0A4", "#B497E7", "#D3B484", "#B3B3B3",
    "#276A8C", "#DAB6C4", "#9B4DCA", "#9D88A2", "#FF4D4D",
]
niche_color_map = {str(niche): color for niche, color in enumerate(_niche_colors)}

### 1.4 Configure Paths and Create Directories

In [None]:
# Define paths: every artifact of this run follows the layout
# ../../../artifacts/<dataset>/<kind>/<model_label>/<load_timestamp>.
_artifacts_root = f"../../../artifacts/{dataset}"
_run_suffix = f"{model_label}/{load_timestamp}"
figure_folder_path = f"{_artifacts_root}/figures/{_run_suffix}"
model_folder_path = f"{_artifacts_root}/models/{_run_suffix}"
result_folder_path = f"{_artifacts_root}/results/{_run_suffix}"

# Create the output directories; the model folder is only read from.
for _output_folder in (figure_folder_path, result_folder_path):
    os.makedirs(_output_folder, exist_ok=True)

## 2. Data

In [None]:
# Load the AnnData object stored with the trained model; later cells expect it to
# contain a UMAP embedding and the "nichecompass_latent_cluster" annotation.
adata_path = f"{model_folder_path}/anndata_umap_with_clusters.h5ad"
adata = sc.read_h5ad(adata_path)

## 3. Analysis

### 3.1 Create Figures

In [None]:
# Preprocess data (niche selection and filtering): keep only niches with more
# than 100,000 cells (listed in value_counts order, i.e. largest first).
niche_sizes = adata.obs["nichecompass_latent_cluster"].value_counts()
niche_cell_counts = niche_sizes.to_dict()
retained_niches = [niche for niche, n_cells in niche_cell_counts.items()
                   if n_cells > 100_000]
is_retained = adata.obs["nichecompass_latent_cluster"].isin(retained_niches)
adata_filtered = adata[is_retained]
print(f"Retaining {len(adata_filtered)} of {len(adata)} cells following filtering")
print(f"Retaining {len(retained_niches)} niches")

In [None]:
# Echo the figure output folder (the cell's last expression is displayed by Jupyter).
figure_folder_path

In [None]:
### Extended Data Fig. 9a ###
# Visualize a 1% random subsample of the filtered cells in the NicheCompass GP
# embedding, colored by source dataset. random_state=0 is scanpy's default,
# stated explicitly so the subsample is visibly reproducible.
adata_filtered_subsample = sc.pp.subsample(
    adata_filtered, fraction=0.01, random_state=0, copy=True)

fig = sc.pl.umap(adata_filtered_subsample,
                 color="dataset",
                 title="NicheCompass GP embedding",
                 size=1, frameon=False,
                 return_fig=True)
# Save the returned figure explicitly. bbox_inches="tight" keeps scanpy's
# default right-margin legend, which sits outside the axes, from being cropped.
fig.savefig(f"{figure_folder_path}/e9_a.svg", bbox_inches="tight")

In [None]:
# Show the same subsample colored by NicheCompass niche with the fixed niche palette.
niche_umap_kwargs = dict(color="nichecompass_latent_cluster",
                         size=1,
                         palette=niche_color_map)
sc.pl.umap(adata_filtered_subsample, **niche_umap_kwargs)

In [None]:
### Extended Data Fig. 9b ###
# Plot niche composition: number of cells per NicheCompass niche, stacked by
# source dataset. (The former bare `freq_table` line mid-cell was never
# displayed by Jupyter and has been dropped.)
freq_table = pd.crosstab(
    adata_filtered.obs["dataset"],
    adata_filtered.obs["nichecompass_latent_cluster"]
)

fig, ax = plt.subplots()
# Niches on the x-axis, one stacked segment per dataset.
ax = freq_table.transpose().plot(kind="bar", stacked=True, ylabel="Number of cells",
                                 xlabel="NicheCompass niche", ax=ax)

ax.grid(which="major", axis="y", linestyle="--")
ax.grid(False, axis="x")
ax.spines[["right", "top"]].set_visible(False)
ax.spines[["left", "bottom"]].set_linewidth(1)
ax.spines[["left", "bottom"]].set_color("black")

plt.xticks(rotation=0)

# Thousands separators on the y-axis tick labels (e.g. 100,000).
ax.get_yaxis().set_major_formatter(ticker.FuncFormatter(lambda x, p: format(int(x), ",")))

# Save this figure explicitly and without cropping, consistent with the other panels.
fig.savefig(f"{figure_folder_path}/e9_b.svg", bbox_inches="tight")

In [None]:
### Extended Data Fig. 9c ###
# Plot spatial distribution of niches in one MERFISH and one STARmap PLUS section.
merfish_section_label = "C57BL6J-1.083"
starmap_section_label = "well11"

fig, axs = plt.subplots(1, 2)

# Copy the section subset: later cells add columns to its .obs, which on an
# AnnData view would silently trigger an implicit view-to-actual copy.
merfish_selected_section_adata = adata_filtered[
    adata_filtered.obs["section"] == merfish_section_label].copy()
sc.pl.spatial(merfish_selected_section_adata,
              spot_size=20,
              title="MERFISH",
              color="nichecompass_latent_cluster",
              palette=niche_color_map,
              ax=axs[0],
              return_fig=False,
              show=False,
              frameon=False)
# Hide the left panel's legend; a shared legend is attached to the right panel.
axs[0].legend().set_visible(False)

def rotate_origin_only(xy, radians):
    """Rotate the 2-D point ``xy`` about the origin (0, 0) by ``radians``.

    Positive angles rotate clockwise: ``(1, 0)`` rotated by ``pi / 2`` lands,
    up to floating-point error, on ``(0, -1)``.
    Returns the rotated point as a two-element list ``[x', y']``.
    """
    x, y = xy
    cos_a = math.cos(radians)
    sin_a = math.sin(radians)
    rotated_x = x * cos_a + y * sin_a
    rotated_y = -x * sin_a + y * cos_a
    return [rotated_x, rotated_y]

# Copy the section subset before overwriting its coordinates, so .obsm is not
# written on an AnnData view (which would trigger an implicit copy).
starmap_selected_section_adata = adata_filtered[
    adata_filtered.obs["section"] == starmap_section_label].copy()
# Rotate the STARmap PLUS coordinates by pi/2: with rotate_origin_only's formula
# each (x, y) maps to approximately (y, -x).
spatial_coordinates = starmap_selected_section_adata.obsm["spatial"].tolist()
rotated_spatial_coordinates = [rotate_origin_only(xy, math.pi/2) for xy in spatial_coordinates]
starmap_selected_section_adata.obsm["spatial"] = np.array(rotated_spatial_coordinates)
sc.pl.spatial(starmap_selected_section_adata,
              spot_size=0.12,
              title="STARmap PLUS",
              color="nichecompass_latent_cluster",
              palette=niche_color_map,
              ax=axs[1],
              return_fig=False,
              show=False,
              frameon=False)

# Shared legend: one color patch per entry of niche_color_map.
# NOTE(review): this lists every niche in the map, including niches removed by
# the >100,000-cell filter — confirm that is intended.
legend_elements = [matplotlib.patches.Patch(facecolor=y, edgecolor=y, label=x) for x, y in niche_color_map.items()]

leg = axs[1].legend(handles=legend_elements,
                    loc="right",
                    bbox_to_anchor=(1.5, 0.5),
                    frameon=False)

# The legend is anchored well outside the right panel; bbox_inches="tight"
# expands the saved area so it is not cropped.
fig.savefig(f"{figure_folder_path}/e9_c.svg", bbox_inches="tight")

In [None]:
### Extended Data Fig. 9d ###
# Visualize niches: for every retained niche, highlight its cells (blue) against
# all other cells (light grey) in the MERFISH and STARmap PLUS section subsets
# created in the Fig. 9c cell.
color_map = {"True": "blue", "False": "lightgrey"}

for selected_nichecompass_latent_cluster in retained_niches:

    fig, axs = plt.subplots(1, 2)

    # plot the merfish cluster; membership is stored as the strings
    # "True"/"False" so it is categorical and matches the color_map keys
    merfish_selected_section_adata.obs["is_cluster"] = (
        merfish_selected_section_adata.obs["nichecompass_latent_cluster"]
        == selected_nichecompass_latent_cluster).astype("str")
    sc.pl.spatial(merfish_selected_section_adata,
                  spot_size=20,
                  return_fig=False,
                  title="MERFISH",
                  color="is_cluster",
                  show=False,
                  ax=axs[0],
                  palette=color_map,
                  frameon=False)

    # plot the starmap cluster
    starmap_selected_section_adata.obs["is_cluster"] = (
        starmap_selected_section_adata.obs["nichecompass_latent_cluster"]
        == selected_nichecompass_latent_cluster).astype("str")
    sc.pl.spatial(starmap_selected_section_adata,
                  spot_size=0.12,
                  return_fig=False,
                  title="STARmap PLUS",
                  color="is_cluster",
                  show=False,
                  ax=axs[1],
                  palette=color_map,
                  frameon=False)

    axs[0].legend().set_visible(False)
    axs[1].legend().set_visible(False)

    fig.suptitle(f"niche {selected_nichecompass_latent_cluster}")
    # One file per niche: saving every iteration to the same "e9_d.svg"
    # overwrote the previous niche, leaving only the last one on disk.
    fig.savefig(
        f"{figure_folder_path}/e9_d_niche_{selected_nichecompass_latent_cluster}.svg",
        bbox_inches="tight")

In [None]:
### Extended Data Fig. 9e ###
# Plot spatial distribution of niches in a second MERFISH / STARmap PLUS section pair.
merfish_section_label = "C57BL6J-3.015"
starmap_section_label = "sagittal1"

fig, axs = plt.subplots(1, 2)

# Work on an explicit copy of the section subset rather than an AnnData view,
# so any later writes to it do not trigger an implicit copy.
merfish_selected_section_adata = adata_filtered[
    adata_filtered.obs["section"] == merfish_section_label].copy()
sc.pl.spatial(merfish_selected_section_adata,
              spot_size=20,
              title="MERFISH",
              color="nichecompass_latent_cluster",
              palette=niche_color_map,
              ax=axs[0],
              return_fig=False,
              show=False,
              frameon=False)
# Hide the left panel's legend; a shared legend is attached to the right panel.
axs[0].legend().set_visible(False)

def rotate_origin_only(xy, radians):
    """Rotate the 2-D point ``xy`` about the origin (0, 0) by ``radians``.

    Positive angles rotate clockwise: ``(1, 0)`` rotated by ``pi / 2`` lands,
    up to floating-point error, on ``(0, -1)``.
    Returns the rotated point as a two-element list ``[x', y']``.
    """
    x, y = xy
    cos_a = math.cos(radians)
    sin_a = math.sin(radians)
    rotated_x = x * cos_a + y * sin_a
    rotated_y = -x * sin_a + y * cos_a
    return [rotated_x, rotated_y]

# Copy the section subset before overwriting its coordinates, so .obsm is not
# written on an AnnData view (which would trigger an implicit copy).
starmap_selected_section_adata = adata_filtered[
    adata_filtered.obs["section"] == starmap_section_label].copy()
# Rotate the STARmap PLUS coordinates by pi/2: with rotate_origin_only's formula
# each (x, y) maps to approximately (y, -x).
spatial_coordinates = starmap_selected_section_adata.obsm["spatial"].tolist()
rotated_spatial_coordinates = [rotate_origin_only(xy, math.pi/2) for xy in spatial_coordinates]
starmap_selected_section_adata.obsm["spatial"] = np.array(rotated_spatial_coordinates)
sc.pl.spatial(starmap_selected_section_adata,
              spot_size=0.12,
              title="STARmap PLUS",
              color="nichecompass_latent_cluster",
              palette=niche_color_map,
              ax=axs[1],
              return_fig=False,
              show=False,
              frameon=False)

# Shared legend: one color patch per entry of niche_color_map.
# NOTE(review): this lists every niche in the map, including niches removed by
# the >100,000-cell filter — confirm that is intended.
legend_elements = [matplotlib.patches.Patch(facecolor=y, edgecolor=y, label=x) for x, y in niche_color_map.items()]

leg = axs[1].legend(handles=legend_elements,
                    loc="right",
                    bbox_to_anchor=(1.5, 0.5),
                    frameon=False)

# The legend is anchored well outside the right panel; bbox_inches="tight"
# expands the saved area so it is not cropped.
fig.savefig(f"{figure_folder_path}/e9_e1.svg", bbox_inches="tight")

In [None]:
### Extended Data Fig. 9e ###
# Plot spatial distribution of niches in a third MERFISH / STARmap PLUS section pair.
merfish_section_label = "C57BL6J-1.129"
starmap_section_label = "well10"

fig, axs = plt.subplots(1, 2)

# Work on an explicit copy of the section subset rather than an AnnData view,
# so any later writes to it do not trigger an implicit copy.
merfish_selected_section_adata = adata_filtered[
    adata_filtered.obs["section"] == merfish_section_label].copy()
sc.pl.spatial(merfish_selected_section_adata,
              spot_size=20,
              title="MERFISH",
              color="nichecompass_latent_cluster",
              palette=niche_color_map,
              ax=axs[0],
              return_fig=False,
              show=False,
              frameon=False)
# Hide the left panel's legend; a shared legend is attached to the right panel.
axs[0].legend().set_visible(False)

def rotate_origin_only(xy, radians):
    """Rotate the 2-D point ``xy`` about the origin (0, 0) by ``radians``.

    Positive angles rotate clockwise: ``(1, 0)`` rotated by ``pi / 2`` lands,
    up to floating-point error, on ``(0, -1)``.
    Returns the rotated point as a two-element list ``[x', y']``.
    """
    x, y = xy
    cos_a = math.cos(radians)
    sin_a = math.sin(radians)
    rotated_x = x * cos_a + y * sin_a
    rotated_y = -x * sin_a + y * cos_a
    return [rotated_x, rotated_y]

# Copy the section subset before overwriting its coordinates, so .obsm is not
# written on an AnnData view (which would trigger an implicit copy).
starmap_selected_section_adata = adata_filtered[
    adata_filtered.obs["section"] == starmap_section_label].copy()
# Rotate the STARmap PLUS coordinates by pi/2: with rotate_origin_only's formula
# each (x, y) maps to approximately (y, -x).
spatial_coordinates = starmap_selected_section_adata.obsm["spatial"].tolist()
rotated_spatial_coordinates = [rotate_origin_only(xy, math.pi/2) for xy in spatial_coordinates]
starmap_selected_section_adata.obsm["spatial"] = np.array(rotated_spatial_coordinates)
sc.pl.spatial(starmap_selected_section_adata,
              spot_size=0.12,
              title="STARmap PLUS",
              color="nichecompass_latent_cluster",
              palette=niche_color_map,
              ax=axs[1],
              return_fig=False,
              show=False,
              frameon=False)

# Shared legend: one color patch per entry of niche_color_map.
# NOTE(review): this lists every niche in the map, including niches removed by
# the >100,000-cell filter — confirm that is intended.
legend_elements = [matplotlib.patches.Patch(facecolor=y, edgecolor=y, label=x) for x, y in niche_color_map.items()]

leg = axs[1].legend(handles=legend_elements,
                    loc="right",
                    bbox_to_anchor=(1.5, 0.5),
                    frameon=False)

# The legend is anchored well outside the right panel; bbox_inches="tight"
# expands the saved area so it is not cropped.
fig.savefig(f"{figure_folder_path}/e9_e2.svg", bbox_inches="tight")