# STARmap PLUS Mouse CNS

- **Creator**: Adam Boxall (<ab70@sanger.ac.uk>)
- **Date of Creation:** 27.07.2024
- **Date of Last Modification:** 27.12.2024 (Sebastian Birk; <sebastian.birk@helmholtz-munich.de>)

- To run this notebook, a trained model must be stored under `../../artifacts/{dataset}/models/{model_label}/{load_timestamp}`, with:
    - dataset: `starmap_plus_mouse_cns`
    - model_label: `reference`
    - load_timestamp: `02022024_170500_1`

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Automatically reload imported modules before executing code (mode 2 reloads
# all modules), so edits to local helper modules are picked up without a restart
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the shared notebook helpers in ../../utils (e.g. analysis_utils) importable
sys.path.extend(["../../utils"])

In [None]:
import os
import math
import pickle
import re
import warnings

import anndata as ad
import matplotlib.pyplot as plt
import nrrd
import numpy as np
import pandas as pd
import scanpy as sc
import skimage
import torch
from matplotlib import ticker
from matplotlib import pyplot as plt
from matplotlib.collections import PatchCollection
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from seaborn import color_palette, light_palette
from STalign import STalign

from nichecompass.models import NicheCompass
from nichecompass.utils import create_new_color_dict

from analysis_utils import plot_category_in_latent_and_physical_space

### 1.2 Define Parameters

In [None]:
# Dataset identifier; it is part of every artifact path built in this notebook
dataset: str = "starmap_plus_mouse_cns"

#### 1.2.1 Generic Parameters

In [None]:
## Model
# AnnData `.uns` keys holding the GP name lists
gp_names_key = "nichecompass_gp_names"  # all GPs
active_gp_names_key = "nichecompass_active_gp_names"  # active GPs only

#### 1.2.2 Dataset-specific Parameters

In [None]:
# Which trained model run to load
load_timestamp = "02022024_170500_1"
model_label = "reference"

# Visualisation / clustering settings
spot_size = 50
latent_leiden_resolution = 0.2
sample_key = "biosample_id"

# Key name of the latent Leiden clustering at the chosen resolution
latent_cluster_key = f"latent_leiden_{latent_leiden_resolution}"

### 1.3 Run Notebook Setup

In [None]:
# Default scanpy figure size
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Ignore future warnings and user warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

In [None]:
# Small Helvetica text for all matplotlib figures
plt.rcParams.update({"font.family": "Helvetica", "font.size": 5})

### 1.4 Configure Paths and Create Directories

In [None]:
# Define paths: every artifact of this run lives under
# ../../artifacts/{dataset}/{figures|models|results}/{model_label}/{load_timestamp}
artifacts_root = f"../../artifacts/{dataset}"
run_suffix = f"{model_label}/{load_timestamp}"
figure_folder_path = f"{artifacts_root}/figures/{run_suffix}"
model_folder_path = f"{artifacts_root}/models/{run_suffix}"
result_folder_path = f"{artifacts_root}/results/{run_suffix}"
srt_data_folder_path = "../../datasets/st_data"  # spatially resolved transcriptomics data
srt_data_bronze_folder_path = f"{srt_data_folder_path}/bronze"

# Create the output directories (model_folder_path must already exist)
for output_folder_path in (figure_folder_path, result_folder_path):
    os.makedirs(output_folder_path, exist_ok=True)

## 2. Model

In [None]:
# NOTE: this cell previously held a temporary debugging stub that wrapped an
# AnnData object named ``adata`` in a throwaway ``model`` class. ``adata`` was
# never defined (its ``sc.read_h5ad`` call was commented out), so running the
# notebook top to bottom stopped here with a NameError. The trained model, and
# with it ``model.adata``, is loaded with ``NicheCompass.load`` in the next cell.

In [None]:
# Load the trained NicheCompass model. With adata=None the AnnData is read from
# the given file name (presumably located inside dir_path — confirm).
model = NicheCompass.load(
    dir_path=model_folder_path,
    adata=None,
    adata_file_name=f"{dataset}_{model_label}.h5ad",
    gp_names_key=gp_names_key,
)

In [None]:
# Drop "Add-on " GPs and replace spaces in GP names with underscores, for both
# the full and the active GP name lists
for uns_key in (gp_names_key, active_gp_names_key):
    model.adata.uns[uns_key] = np.array(
        [gp.replace(" ", "_") for gp in model.adata.uns[uns_key] if "Add-on " not in gp]
    )

## 3. Analysis

### 3.1 Create Figures

In [None]:
# Report the dimensions of the model's AnnData (cells x genes)
n_cells, n_genes = model.adata.shape
print(f"This model ran using {n_genes} genes and {n_cells} cells.")

In [None]:
# Keep only cells with a latent cluster assignment (anndata returns a view here)
is_assigned = model.adata.obs["nichecompass_latent_cluster"] != "unassigned"
model.adata = model.adata[is_assigned]

#### 3.1.1 Niche Composition

Visualise the latent (UMAP) space coloured by primary niche and by sub-niche.

In [None]:
# UMAP coordinates of the remaining cells; columns 0 and 1 are plotted below
umap_embedding = model.adata.obsm["X_umap"]

In [None]:
# Colour for cells that are not highlighted (background in per-niche plots)
inactive_color = "#f0f0f0"

# Primary niche labels, as strings
niche_labels = (
    model.adata.obs["nichecompass_latent_cluster"]
    .astype(str)
    .unique()
    .tolist()
)

def natural_sort(l):
    """Return the strings in *l* sorted in natural (human) order.

    Runs of ASCII digits compare numerically and all other text compares
    case-insensitively, so ``["10", "2", "1"]`` sorts as ``["1", "2", "10"]``.

    Args:
        l: Iterable of strings.

    Returns:
        A new sorted list; *l* itself is not modified.
    """
    def alphanum_key(key):
        # ``re.split`` with a capturing group alternates text and digit runs,
        # so odd positions are always ASCII digit runs. Converting by position
        # (instead of ``str.isdigit``, which is also True for e.g. "²" and then
        # makes ``int()`` raise) keeps the key element types aligned.
        parts = re.split(r"([0-9]+)", key)
        return [int(part) if i % 2 else part.lower() for i, part in enumerate(parts)]

    return sorted(l, key=alphanum_key)

# One husl colour per primary niche, assigned in natural label order
niche_labels = natural_sort(niche_labels)
niche_color_map = dict(zip(niche_labels, color_palette("husl", len(niche_labels))))

# Sub-niches get shades of their primary niche's colour. With reverse=True the
# light_palette starts at the niche colour and blends towards a light tone; two
# extra shades are requested so the palest ones are never assigned.
sub_niche_color_map = {}
for niche_label in niche_labels:
    in_niche = model.adata.obs["nichecompass_latent_cluster"] == niche_label
    sub_niche_labels = sorted(
        model.adata[in_niche].obs["nichecompass_latent_sub_cluster_label"].astype(str).unique().tolist()
    )
    shades = light_palette(niche_color_map[niche_label], len(sub_niche_labels) + 2, reverse=True)
    sub_niche_color_map.update(zip(sub_niche_labels, shades))

In [None]:
# UMAP of all cells, coloured by primary niche
cell_colors = model.adata.obs["nichecompass_latent_cluster"].astype(str).map(niche_color_map)

fig, ax = plt.subplots()
ax.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=cell_colors, s=1)

# Minimal frame: no grid or ticks, black left/bottom spines only
ax.grid(False)
ax.spines[['right', 'top']].set_visible(False)
ax.spines[['left', 'bottom']].set_linewidth(1)
ax.spines[['left', 'bottom']].set_color("black")
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')
ax.margins(0.15)

hfont = {'fontname': 'Helvetica'}
ax.set_xlabel("UMAP 1", labelpad=7, **hfont)
ax.set_ylabel("UMAP 2", **hfont)

In [None]:
# Bar chart of the number of cells per primary niche (value_counts: largest first)
niche_cell_frequency = model.adata.obs["nichecompass_latent_cluster"].value_counts()
bar_colors = [niche_color_map[niche] for niche in niche_cell_frequency.index.tolist()]

fig, ax = plt.subplots()
ax.bar(niche_cell_frequency.index, niche_cell_frequency, color=bar_colors, edgecolor="none")

# Dashed horizontal grid lines only; black left/bottom spines
ax.grid(which='major', axis='y', linestyle='--')
ax.grid(False, axis='x')
ax.spines[['right', 'top']].set_visible(False)
ax.spines[['left', 'bottom']].set_linewidth(1)
ax.spines[['left', 'bottom']].set_color("black")

# Thousands separators on the y axis, e.g. 12,345
ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda value, _pos: f"{int(value):,}"))

hfont = {'fontname': 'Helvetica'}
ax.set_xlabel("NicheCompass primary niche", labelpad=7, **hfont)
ax.set_ylabel("Number of cells", **hfont)

In [None]:
# Cells per primary niche, stacked by sub-niche
cluster_composition = (
    model.adata.obs
    .groupby(["nichecompass_latent_cluster", "nichecompass_latent_sub_cluster_label"], observed=False)
    .size()
    .unstack()
)
categories = cluster_composition.columns.tolist()
clusters = cluster_composition.index.tolist()

fig, ax = plt.subplots()
bottom = np.zeros(len(clusters))
for sub_niche in categories:
    counts = cluster_composition[sub_niche]
    ax.bar(clusters, counts, color=sub_niche_color_map[sub_niche], label=sub_niche, bottom=bottom)
    bottom += counts

# Dashed horizontal grid lines only; black left/bottom spines
ax.grid(which='major', axis='y', linestyle='--')
ax.grid(False, axis='x')
ax.spines[['right', 'top']].set_visible(False)
ax.spines[['left', 'bottom']].set_linewidth(1)
ax.spines[['left', 'bottom']].set_color("black")

# Thousands separators on the y axis
ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda value, _pos: f"{int(value):,}"))

hfont = {'fontname': 'Helvetica'}
ax.set_xlabel("NicheCompass primary niche", labelpad=7, **hfont)
ax.set_ylabel("Number of cells", **hfont)


In [None]:
# One panel per primary niche: all cells in light grey, the niche's own cells
# coloured by sub-niche
clusters = natural_sort(model.adata.obs["nichecompass_latent_cluster"].unique().tolist())

# 4 rows; at least one column so an empty cluster list does not break subplots
fig, axs = plt.subplots(4, max(1, math.ceil(len(clusters) / 4)))
hfont = {'fontname': 'Helvetica'}

for cluster, ax in zip(clusters, axs.flat):
    adata_subset = model.adata[model.adata.obs["nichecompass_latent_cluster"] == cluster]
    umap_embedding_subset = adata_subset.obsm["X_umap"]

    ax.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=inactive_color, s=1)
    ax.scatter(
        umap_embedding_subset[:, 0],
        umap_embedding_subset[:, 1],
        c=adata_subset.obs["nichecompass_latent_sub_cluster_label"].astype(str).map(sub_niche_color_map),
        s=1,
    )

    ax.grid(False)
    ax.spines[['right', 'top']].set_visible(False)
    ax.spines[['left', 'bottom']].set_linewidth(1)
    ax.spines[['left', 'bottom']].set_color("black")
    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_title(f"Niche {cluster}")
    ax.set_aspect('equal')
    ax.margins(0.15)
    # Label this panel. plt.xlabel/plt.ylabel act on the *current* axes (the
    # last subplot created), so they never labelled ``ax``.
    ax.set_xlabel("UMAP 1", labelpad=7, **hfont)
    ax.set_ylabel("UMAP 2", **hfont)

# Hide grid positions that have no niche to show
for unused_ax in axs.ravel()[len(clusters):]:
    unused_ax.axis("off")

fig.set_figheight(10)
fig.set_figwidth(10)

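Show the number of cells in each primary niche per tissue section: first for the 'spinalcord' and 'well*' sections, then for the three sagittal sections. Circle area is proportional to the cell count.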

In [None]:
# Row order of the tissue sections in the plot (first entry at y = 0)
sample_order = [
                'spinalcord',
                'well01OB',
                'well1_5',
                'well01brain',
                'well2_5',
                'well03',
                'well3_5',
                'well04',
                'well05',
                'well11',
                'well06',
                'well07',
                'well7_5',
                'well08',
                'well09',
                'well10',
                'well10_5'
]

# Cells per (tissue section, primary niche), rows in ``sample_order``
cluster_composition = (
    model.adata.obs
    .groupby(["label", "nichecompass_latent_cluster"], observed=False)
    .size()
    .unstack()
    .loc[sample_order]
)
cluster_composition_np = cluster_composition.to_numpy()

N, M = cluster_composition_np.shape
xlabels = cluster_composition.columns.tolist()
ylabels = cluster_composition.index.values.tolist()

# One circle per grid cell: area proportional to the cell count, scaled so the
# largest circle has radius 0.5 (half a grid cell); coloured by niche
x, y = np.meshgrid(np.arange(M), np.arange(N))
area = cluster_composition_np
radius = np.sqrt(area / np.pi)
R = radius / radius.max() / 2
c = [niche_color_map[xlabels[i]] for i in x.flat]

fig, ax = plt.subplots()
circles = [
    plt.Circle((col, row), radius=r, color=colour)
    for r, col, row, colour in zip(R.flat, x.flat, y.flat, c)
]
collection = PatchCollection(circles, match_original=True)
ax.add_collection(collection)

ax.set(xticks=np.arange(M), yticks=np.arange(N), xticklabels=xlabels, yticklabels=ylabels)
ax.set_xticks(np.arange(M + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(N + 1) - 0.5, minor=True)
ax.grid(False)

hfont = {'fontname': 'Helvetica'}
ax.set_ylabel("Tissue Section", **hfont)
ax.set_xlabel("Niche", **hfont)

ax.set_aspect('equal')
ax.spines[['left', 'bottom', "right", "top"]].set_linewidth(1)
ax.spines[['left', 'bottom', "right", "top"]].set_color("black")
ax.margins(0.03)

fig.set_figheight(20)
fig.set_figwidth(5)
fig.savefig(f'{figure_folder_path}/b2.svg')
plt.show()

In [None]:
# Row order of the sagittal sections in the plot (first entry at y = 0)
sample_order = [
        'sagittal1',
        'sagittal2',
        'sagittal3'
]

# Cells per (tissue section, primary niche), rows in ``sample_order``
cluster_composition = (
    model.adata.obs
    .groupby(["label", "nichecompass_latent_cluster"], observed=False)
    .size()
    .unstack()
    .loc[sample_order]
)
cluster_composition_np = cluster_composition.to_numpy()

N, M = cluster_composition_np.shape
xlabels = cluster_composition.columns.tolist()
ylabels = cluster_composition.index.values.tolist()

# One circle per grid cell: area proportional to the cell count, scaled so the
# largest circle has radius 0.5 (half a grid cell); coloured by niche
x, y = np.meshgrid(np.arange(M), np.arange(N))
area = cluster_composition_np
radius = np.sqrt(area / np.pi)
R = radius / radius.max() / 2
c = [niche_color_map[xlabels[i]] for i in x.flat]

fig, ax = plt.subplots()
circles = [
    plt.Circle((col, row), radius=r, color=colour)
    for r, col, row, colour in zip(R.flat, x.flat, y.flat, c)
]
collection = PatchCollection(circles, match_original=True)
ax.add_collection(collection)

ax.set(xticks=np.arange(M), yticks=np.arange(N), xticklabels=xlabels, yticklabels=ylabels)
ax.set_xticks(np.arange(M + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(N + 1) - 0.5, minor=True)
ax.grid(False)

hfont = {'fontname': 'Helvetica'}
ax.set_ylabel("Tissue Section", **hfont)
ax.set_xlabel("Niche", **hfont)

ax.set_aspect('equal')
ax.spines[['left', 'bottom', "right", "top"]].set_linewidth(1)
ax.spines[['left', 'bottom', "right", "top"]].set_color("black")
ax.margins(0.03)

fig.set_figheight(20)
fig.set_figwidth(5)
fig.savefig(f'{figure_folder_path}/a2.svg')
plt.show()

Show the cell-type composition of each niche (number of cells of each main molecular cell type).

In [None]:
# Cells per (main molecular cell type, primary niche)
cluster_composition = (
    model.adata.obs
    .groupby(["Main_molecular_cell_type", "nichecompass_latent_cluster"], observed=False)
    .size()
    .unstack()
)
cluster_composition_np = cluster_composition.to_numpy()

N, M = cluster_composition_np.shape
xlabels = cluster_composition.columns.tolist()
ylabels = cluster_composition.index.values.tolist()

# One circle per grid cell: area proportional to the cell count, scaled so the
# largest circle has radius 0.5 (half a grid cell); coloured by niche
x, y = np.meshgrid(np.arange(M), np.arange(N))
area = cluster_composition_np
radius = np.sqrt(area / np.pi)
R = radius / radius.max() / 2
c = [niche_color_map[xlabels[i]] for i in x.flat]

fig, ax = plt.subplots()
circles = [
    plt.Circle((col, row), radius=r, color=colour)
    for r, col, row, colour in zip(R.flat, x.flat, y.flat, c)
]
collection = PatchCollection(circles, match_original=True)
ax.add_collection(collection)

ax.set(xticks=np.arange(M), yticks=np.arange(N), xticklabels=xlabels, yticklabels=ylabels)
ax.set_xticks(np.arange(M + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(N + 1) - 0.5, minor=True)
ax.grid(False)

hfont = {'fontname': 'Helvetica'}
ax.set_ylabel("Cell type", **hfont)
ax.set_xlabel("Niche", **hfont)

ax.set_aspect('equal')
ax.spines[['left', 'bottom', "right", "top"]].set_linewidth(1)
ax.spines[['left', 'bottom', "right", "top"]].set_color("black")
ax.margins(0.03)

fig.set_figheight(20)
fig.set_figwidth(5)
fig.savefig(f'{figure_folder_path}/c1.svg')
plt.show()

Generate 3D plots of the stacked tissue sections, based on a manual per-section alignment (rotation, mirroring and centering of each section).

In [None]:
# Manual per-section alignment: each section is rotated about the origin by
# the given angle (degrees) and then mirrored along x and/or y (factor -1) so
# that all sections share a common orientation.
# NOTE: the scale dicts are deliberately not called ``scale_x``/``scale_y``:
# the STalign cell further down rebinds those names to scalar atlas scale
# factors, which would make re-running this cell afterwards fail.
section_rotation = {
    "sagittal1": 270,
    "sagittal2": 90,
    "sagittal3": 90,
    "spinalcord": 0,
    "well01OB": 90,
    "well1_5": 0,
    "well01brain": 90,
    "well2_5": 180,
    "well03": 90,
    "well3_5": 180,
    "well04": 270,
    "well05": 270,
    "well06": 90,
    "well07": 270,
    "well7_5": 0,
    "well08": 270,
    "well09": 270,
    "well10": 270,
    "well10_5": 0,
    "well11": 90
}

section_scale_x = {
    "sagittal1": -1,
    "sagittal2": -1,
    "sagittal3": -1,
    "spinalcord": -1,
    "well01OB": -1,
    "well1_5": -1,
    "well01brain": -1,
    "well2_5": -1,
    "well03": 1,
    "well3_5": -1,
    "well04": -1,
    "well05": -1,
    "well06": -1,
    "well07": -1,
    "well7_5": -1,
    "well08": -1,
    "well09": -1,
    "well10": -1,
    "well10_5": -1,
    "well11": -1
}

section_scale_y = {
    "sagittal1": 1,
    "sagittal2": 1,
    "sagittal3": 1,
    "spinalcord": 1,
    "well01OB": 1,
    "well1_5": 1,
    "well01brain": -1,
    "well2_5": 1,
    "well03": -1,
    "well3_5": 1,
    "well04": 1,
    "well05": 1,
    "well06": 1,
    "well07": 1,
    "well7_5": 1,
    "well08": -1,
    "well09": 1,
    "well10": 1,
    "well10_5": -1,
    "well11": 1
}

labels = model.adata.obs["label"].tolist()
xys = [(item[0], item[1]) for item in model.adata.obsm["spatial"].tolist()]


def rotate_origin_only(xy, radians):
    """Rotate the point ``xy`` around the origin (0, 0) by ``radians``.

    The rotation is clockwise in a y-up coordinate system.
    """
    x, y = xy
    xx = x * math.cos(radians) + y * math.sin(radians)
    yy = -x * math.sin(radians) + y * math.cos(radians)
    return xx, yy


rotated_xys = [rotate_origin_only(xy, math.radians(section_rotation[label])) for xy, label in zip(xys, labels)]
scaled_xys = [
    (x * section_scale_x[label], y * section_scale_y[label])
    for (x, y), label in zip(rotated_xys, labels)
]
unified_xys = np.array([[x, y] for x, y in scaled_xys])

# Translate every section so that the centre of its bounding box lies at the
# origin. Each section's cells are selected once with a single boolean mask.
label_array = np.asarray(labels)
center_translation_x = {}
center_translation_y = {}
for label in model.adata.obs["label"].cat.categories.tolist():
    section_xys = unified_xys[label_array == label]
    center_translation_x[label] = (section_xys[:, 0].min() + section_xys[:, 0].max()) / 2
    center_translation_y[label] = (section_xys[:, 1].min() + section_xys[:, 1].max()) / 2

centered_xys = np.array(
    [[xy[0] - center_translation_x[label], xy[1] - center_translation_y[label]]
     for xy, label in zip(unified_xys.tolist(), labels)]
)

In [None]:
# 3D stack of the sections from category position 4 onwards, coloured by
# sub-niche. With alphabetically ordered categories this presumably skips the
# three sagittal sections and the spinal cord — confirm.
included_labels = model.adata.obs["label"].cat.categories.tolist()[4:]
included_idx = model.adata.obs.label.isin(included_labels).tolist()

# Sections are placed ``z_spacing`` apart along the stacking axis, in this order
z_spacing = 8000
z_order = ['sagittal1',
           'sagittal2',
           'sagittal3',
           'spinalcord',
           'well01OB',
           'well1_5',
           'well01brain',
           'well2_5',
           'well03',
           'well3_5',
           'well04',
           'well05',
           'well11',
           'well06',
           'well07',
           'well7_5',
           'well08',
           'well09',
           'well10',
           'well10_5']
zs = np.array([z_order.index(label) * z_spacing for label in labels])

xs = np.array(centered_xys.T.tolist()[0])
ys = np.array(centered_xys.T.tolist()[1])

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
fig.set_figheight(5)
fig.set_figwidth(30)

color = model.adata.obs["nichecompass_latent_sub_cluster_label"].astype(str).fillna("Unknown").map(sub_niche_color_map)
ax.scatter(xs[included_idx], zs[included_idx], ys[included_idx], s=0.2, facecolors=color[included_idx], edgecolor=None, linewidth=0, alpha=1)

ax.set_aspect('equal')
ax.view_init(20, -20)

ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

fig.tight_layout()
fig.savefig(f'{figure_folder_path}/b1.png', dpi=300)
plt.show()

In [None]:
# 3D stack of the first three label categories (presumably the sagittal
# sections with alphabetically ordered categories — confirm), coloured by
# sub-niche
included_labels = model.adata.obs["label"].cat.categories.tolist()[:3]
included_idx = model.adata.obs.label.isin(included_labels).tolist()

# Sections are placed ``z_spacing`` apart along the stacking axis, in this order
z_spacing = 100000
z_order = ['sagittal3',
           'sagittal2',
           'sagittal1',
           'spinalcord',
           'well01OB',
           'well1_5',
           'well01brain',
           'well2_5',
           'well03',
           'well3_5',
           'well04',
           'well05',
           'well11',
           'well06',
           'well07',
           'well7_5',
           'well08',
           'well09',
           'well10',
           'well10_5']
zs = np.array([z_order.index(label) * z_spacing for label in labels])

xs = np.array(centered_xys.T.tolist()[0])
ys = np.array(centered_xys.T.tolist()[1])

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
fig.set_figheight(5)
fig.set_figwidth(30)

color = model.adata.obs["nichecompass_latent_sub_cluster_label"].astype(str).fillna("Unknown").map(sub_niche_color_map)
ax.scatter(zs[included_idx], xs[included_idx], ys[included_idx], s=0.2, facecolors=color[included_idx], edgecolor=None, linewidth=0, alpha=1)

ax.set_aspect('equal')
ax.view_init(20, -20)

ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

fig.tight_layout()
fig.savefig(f'{figure_folder_path}/a1.png', dpi=300)
plt.show()

#### 3.1.2 Alignment with the Allen Mouse Brain Reference Atlas

In [None]:
# Cells of tissue section "well11", the section that is aligned to the Allen atlas
adata_well11 = model.adata[model.adata.obs["label"] == "well11"]

In [None]:
# Spatial coordinates of the well11 cells as an x/y DataFrame (same cell order
# as adata_well11, since the same boolean mask is applied)
coordinates = pd.DataFrame(model.adata.obsm["spatial"][model.adata.obs["label"] == "well11"], columns=["x", "y"])

In [None]:
# Persist the well11 cell order so that the atlas-aligned table can later be
# re-indexed by cell name
cell_index = adata_well11.obs_names.tolist()
with open("alignment_cell_index.pkl", mode="wb") as cell_index_file:
    pickle.dump(cell_index, cell_index_file)

In [None]:
# Scale factor that would map the x extent of the section onto the 7200 um
# y size of a mid coronal section in the Allen reference atlas (the x
# direction of the input corresponds to the atlas y direction).
# NOTE(review): this value is only displayed; the coordinates are scaled by
# a fixed factor of 0.2 in the next cell — confirm which is intended.
scale_factor = 7200 / (max(coordinates["x"]) - min(coordinates["x"]))
scale_factor

In [None]:
# NOTE(review): fixed factor 0.2 (not the computed ``scale_factor``),
# presumably converting raw coordinates to um — confirm.
# Not idempotent: re-running this cell scales the coordinates again.
coordinates[["x", "y"]] = coordinates[["x", "y"]] * 0.2

In [None]:
# Inspect the rescaled coordinates
coordinates

In [None]:
# Download the Allen adult mouse brain structure ontology to allen_ontology.csv;
# ``namesdict`` is passed to STalign.analyze3Dalign when annotating aligned cells
url = 'http://api.brain-map.org/api/v2/data/query.csv?criteria=model::Structure,rma::criteria,[ontology_id$eq1],rma::options[order$eq%27structures.graph_order%27][num_rows$eqall]'
ontology_name,namesdict = STalign.download_aba_ontology(url, 'allen_ontology.csv') #url for adult mouse

In [None]:
# Download the Allen CCF Nissl image volume and the 2017 annotation volume
# (the "_50" in the file names presumably denotes 50 um resolution)
imageurl = 'http://download.alleninstitute.org/informatics-archive/current-release/mouse_ccf/ara_nissl/ara_nissl_50.nrrd'
labelurl = 'http://download.alleninstitute.org/informatics-archive/current-release/mouse_ccf/annotation/ccf_2017/annotation_50.nrrd'
imagefile, labelfile = STalign.download_aba_image_labels(imageurl, labelurl, 'aba_nissl.nrrd', 'aba_annotation.nrrd')

In [None]:
# Rasterize the cell positions into an image with pixel size ``dx`` and blur
# parameter ``blur`` (no plotting; draw=False)
dx = 15
blur = 1
X_, Y_, W = STalign.rasterize(coordinates["x"], coordinates["y"], dx=dx, blur=blur, draw=False)

In [None]:
# Plot the cell positions next to their rasterized image
fig, ax = plt.subplots(1, 2)
ax[0].scatter(coordinates["x"], coordinates["y"], s=0.5, alpha=0.25)
ax[0].invert_yaxis()
ax[0].set_title('List of cells')
ax[0].set_aspect('equal')

# STalign.rasterize returns a stacked image; keep its first channel. Guarded so
# that re-running this cell does not index the 2D image again (which would
# silently replace it with its first row and break the later ``W[None]``).
if W.ndim == 3:
    W = W[0]
ax[1].imshow(W, origin='lower')
ax[1].invert_yaxis()
ax[1].set_title('Rasterized')

In [None]:
# Load the Allen Nissl volume (A) and annotation volume (L, also kept as ``vol``)
# and compare a candidate atlas slice, indexed along the first volume axis,
# with the rasterized target.
# NOTE: ``slice`` shadows the Python builtin; kept because later cells use it.
slice = 140

A = nrrd.read(imagefile)[0]
L, hdr = nrrd.read(labelfile)
vol = L

# Physical coordinates along each volume axis, centred on the volume's middle
dxA = np.diag(hdr['space directions'])
nxA = A.shape
xA = [np.arange(n) * d - (n - 1) * d / 2.0 for n, d in zip(nxA, dxA)]
XA = np.meshgrid(*xA, indexing='ij')

fig, ax = plt.subplots(1, 2)
extentA = STalign.extent_from_x(xA[1:])
ax[0].imshow(A[slice], extent=extentA)
ax[0].set_title('Atlas Slice')
ax[1].imshow(W, extent=extentA)
ax[1].set_title('Target Image')
fig.canvas.draw()

In [None]:
from scipy.ndimage import rotate

# Preview the atlas slice rotated in-plane by ``theta_deg`` next to the target
theta_deg = 90
extentA = STalign.extent_from_x(xA[1:])
rotated_atlas_slice = rotate(A[slice], angle=theta_deg)

fig, ax = plt.subplots(1, 2)
ax[0].imshow(rotated_atlas_slice, extent=extentA)
ax[0].set_title('Atlas Slice')
ax[1].imshow(W, extent=extentA)
ax[1].set_title('Target Image')
fig.canvas.draw()

In [None]:
# One pair of corresponding landmark points (atlas, target) from which STalign
# derives a linear part (Li) and a translation (Ti); Ti is used when building
# the initial translation T for LDDMM
points_atlas = np.array([[3000,-300]])
points_target = np.array([[0,-250]])
Li,Ti = STalign.L_T_from_points(points_atlas,points_target)

In [None]:
# LDDMM inputs: coordinates and intensities of the target (J, from the
# rasterized image W) and the atlas (I, from the Nissl volume A), each divided
# by its mean absolute value. The atlas gets a second channel: the squared
# deviation from its mean intensity.
xJ = [Y_,X_]
J = W[None]/np.mean(np.abs(W))
xI = xA
I = A[None] / np.mean(np.abs(A),keepdims=True)
I = np.concatenate((I,(I-np.mean(I))**2))

In [None]:
# Intensity model parameters passed to STalign.LDDMM_3D_to_slice
sigmaA = 2  # standard deviation of artifact intensities
sigmaB = 2  # standard deviation of background intensities
sigmaM = 2  # standard deviation of matching tissue intensities
muA = torch.tensor([3, 3, 3], device='cpu')  # average of artifact intensities
muB = torch.tensor([0, 0, 0], device='cpu')  # average of background intensities

In [None]:
# Distribution of pixel intensities in the normalised target image J
fig, ax = plt.subplots()
ax.hist(J.ravel())
ax.set_xlabel('Intensity')
ax.set_ylabel('Number of Pixels')
ax.set_title('Intensity Histogram of Target Image')

In [None]:
# initialize variables: scalar atlas scale factors (also passed to
# STalign.analyze3Dalign later) and the initial in-plane rotation angle
scale_x = 0.9 #default = 0.9
scale_y = 0.9 #default = 0.9
scale_z = 0.9 #default = 0.9
theta0 = (np.pi/180)*theta_deg  # degrees -> radians

# get an initial guess for the translation T. The first component is minus the
# physical coordinate of atlas slice ``slice`` along the first axis; the
# in-plane components start at the target centre, shifted by the landmark
# translation Ti when it was computed. (At notebook top level, locals() is the
# global namespace, so this checks whether Ti exists.)
if 'Ti' in locals():
    T = np.array([-xI[0][slice],np.mean(xJ[0])-(Ti[0]*scale_y),np.mean(xJ[1])-(Ti[1]*scale_x)])
else:
    T = np.array([-xI[0][slice],np.mean(xJ[0]),np.mean(xJ[1])])

# Initial linear transform: rotation by theta0 about the first axis, applied
# after per-axis scaling.
# NOTE: this rebinds ``L``, which previously held the annotation volume.
scale_atlas = np.array([[scale_z,0,0],
                        [0,scale_x,0],
                        [0,0,scale_y]])
L = np.array([[1.0,0.0,0.0],
              [0.0,np.cos(theta0),-np.sin(theta0)],
              [0.0,np.sin(theta0),np.cos(theta0)]])
L = np.matmul(L,scale_atlas)#np.identity(3)

In [None]:
%%time

# run LDDMM
# specify device (default device for STalign.LDDMM is cpu)
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

#returns mat = affine transform, v = velocity, xv = pixel locations of velocity points
transform = STalign.LDDMM_3D_to_slice(
    xI,I,xJ,J,
    T=T,L=L,
    nt=4,niter=10000,
    device='cpu',
    sigmaA = sigmaA, #standard deviation of artifact intensities
    sigmaB = sigmaB, #standard deviation of background intensities
    sigmaM = sigmaM, #standard deviation of matching tissue intenities
    muA = muA, #average of artifact intensities
    muB = muB #average of background intensities
)

In [None]:
# Unpack the LDDMM result.
# NOTE: ``A`` is rebound here from the Nissl atlas volume to the affine transform.
A = transform['A']  # affine transform
v = transform['v']  # velocity
xv = transform['xv']  # pixel locations of velocity points
Xs = transform['Xs']  # presumably the transformed sample positions (used with interp3D) — confirm

In [None]:
# Map the target cell coordinates into the atlas with the fitted transform and
# annotate them with Allen structures (the result provides e.g. coord0-coord2,
# struct_id and acronym columns)
df = STalign.analyze3Dalign(labelfile,  xv,v,A, xJ, dx, scale_x=scale_x, scale_y=scale_y,x=coordinates["x"],y=coordinates["y"], X_=X_, Y_=Y_, namesdict=namesdict,device='cpu')

In [None]:
import matplotlib as mpl

# Resample the atlas channels at the aligned positions
It = torch.tensor(I, device='cpu', dtype=torch.float64)
AI = STalign.interp3D(xI, It, Xs.permute(3, 0, 1, 2), padding_mode="border")

# Min-max normalise every channel of the 4D result. Both the per-channel
# minimum and the range get three trailing singleton dimensions, so they
# broadcast along the channel axis only. The original two-dimension version
# lined the channel axis up with the second axis of AI instead.
ai_min = torch.amin(AI, (1, 2, 3))
ai_range = torch.amax(AI, (1, 2, 3)) - ai_min
Ishow_source = (
    ((AI - ai_min[..., None, None, None]) / ai_range[..., None, None, None])
    .permute(1, 2, 3, 0).clone().detach().cpu()
)
Jt = torch.tensor(J, device='cpu', dtype=torch.float64)
Ishow_target = Jt.permute(1, 2, 0).cpu() / torch.max(Jt).item()

# Target (blue), aligned atlas (red) and their overlay
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(Ishow_target, cmap=mpl.cm.Blues, alpha=0.9)
ax[0].set_title('STARmap PLUS Slice')
ax[1].imshow(Ishow_source[0, :, :, 0], cmap=mpl.cm.Reds, alpha=0.2)
ax[1].set_title('z=0 slice of Aligned 3D Allen Brain Atlas')
ax[2].imshow(Ishow_target, cmap=mpl.cm.Blues, alpha=0.9)
ax[2].imshow(Ishow_source[0, :, :, 0], cmap=mpl.cm.Reds, alpha=0.3)
ax[2].set_title('Overlayed')

plt.show()


In [None]:
# 3D view: surface of the non-zero voxels of the annotation volume ``vol``
# (presumably the brain outline) with the atlas-aligned cells overlaid
verts, faces, normals, values = skimage.measure.marching_cubes(vol > 0, 0.8, spacing=dxA)
# Shift the mesh by the first coordinate of each atlas axis (xA)
verts = verts + np.array([axis_coords[0] for axis_coords in xA])

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

mesh = Poly3DCollection(verts[faces])
mesh.set_facecolor('r')
mesh.set_alpha(0.2)
ax.add_collection3d(mesh)

for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
    set_limits(-8000, 8000)

x = df['coord0']
y = df['coord1']
z = df['coord2']
ax.scatter3D(x, y, z, s=0.1)

In [None]:
# Plot the aligned cells by their assigned Allen brain region
STalign.plot_brain_regions(df)

In [None]:
# Highlight cells assigned to the Allen structure with acronym CA1 (hippocampal field CA1)
brain_regions = ['CA1']
STalign.plot_subset_brain_regions(df, brain_regions)

In [None]:
# Highlight cells in the layers of the primary somatosensory barrel field
# (Allen acronyms SSp-bfd1 to SSp-bfd6b)
brain_regions = ['SSp-bfd5', 'SSp-bfd6a', 'SSp-bfd6b', 'SSp-bfd4', 'SSp-bfd2/3', 'SSp-bfd1']
STalign.plot_subset_brain_regions(df, brain_regions)

In [None]:
# Save the atlas-aligned cell table (read back in the next section)
df.to_csv(f"{result_folder_path}/spatial_8327576_well11_aligned_to_atlas.csv")

Import the atlas-aligned coordinates and Allen structure annotations for sample 'well11'.

In [None]:
# Load the atlas-aligned table and restore the cell index that was pickled
# before the alignment (both files are written by this notebook, so unpickling
# them is safe)
well11_aligned_coordinates = pd.read_csv(f"{result_folder_path}/spatial_8327576_well11_aligned_to_atlas.csv")
with open("alignment_cell_index.pkl", "rb") as file:
    cell_index = pickle.load(file)
allen_ontology = pd.read_csv("allen_ontology.csv")

well11_aligned_coordinates["cell_index"] = cell_index
well11_aligned_coordinates = well11_aligned_coordinates.set_index("cell_index")

# Materialise the subset with .copy(): obs columns are added below, and
# writing to an AnnData view only works through an implicit copy that anndata
# performs with an ImplicitModificationWarning.
well11_adata = model.adata[
    (model.adata.obs["label"] == "well11")
    & (model.adata.obs["nichecompass_latent_cluster"] != "unassigned")
].copy()

# Reorder the atlas table to the cell order of well11_adata before copying columns
well11_aligned_coordinates = well11_aligned_coordinates.loc[well11_adata.obs_names]

well11_adata.obs["struct_id"] = well11_aligned_coordinates["struct_id"].tolist()
well11_adata.obs["acronym"] = well11_aligned_coordinates["acronym"].tolist()


Show the spatial distribution of sub-niches within the isocortex only (Allen structure ID 315 and its descendants).

In [None]:
# Isocortex structures: Allen structure 315 and all of its descendants (their
# structure_id_path starts with the path of 315)
allen_ontology_isocortex = allen_ontology[allen_ontology["structure_id_path"].str.startswith("/997/8/567/688/695/315/")]
ontology_order_isocortex = allen_ontology_isocortex["id"].tolist()

# Sub-niche x isocortex-structure cell counts; only sub-niches with more than
# 250 isocortex cells get a panel
cluster_composition = well11_adata.obs.groupby(["nichecompass_latent_sub_cluster_label", "struct_id"], observed=False).size().unstack()
cluster_composition = cluster_composition[[x for x in ontology_order_isocortex if x in cluster_composition.columns]]
clusters = cluster_composition.index[cluster_composition.sum(axis=1) > 250].tolist()

# All isocortex cells, drawn in grey shades (one per structure) as background
well11_aligned_coordinates_roi = well11_aligned_coordinates[well11_aligned_coordinates["struct_id"].isin(ontology_order_isocortex)]

categories = well11_aligned_coordinates_roi["struct_id"].astype(str).fillna("Unknown").unique().tolist()
colours = color_palette(["#b4b4b4", "#c0c0c0", "#cdcdcd", "#dadada"], len(categories))
colour_map = {key: value for key, value in zip(categories, colours)}

background_x = [x * -1 for x in well11_aligned_coordinates_roi["y"].tolist()]
background_y = [x * -1 for x in well11_aligned_coordinates_roi["x"].tolist()]
background_colours = well11_aligned_coordinates_roi["struct_id"].astype(str).fillna("Unknown").map(colour_map)

in_isocortex = well11_adata.obs["struct_id"].isin(ontology_order_isocortex)

# At least one row so an empty cluster list does not break subplots
fig, axs = plt.subplots(max(1, math.ceil(len(clusters) / 4)), 4, sharex=True, sharey=True)
fig.set_figheight(13)
fig.set_figwidth(10)

for ax, cluster in zip(axs.flat, clusters):
    # Cells of this sub-niche that lie in the isocortex (same mask for the
    # coordinate table and the AnnData, which share the same cell order)
    in_roi = ((well11_adata.obs["nichecompass_latent_sub_cluster_label"] == cluster) & in_isocortex).tolist()
    roi = well11_aligned_coordinates[in_roi]
    roi_adata = well11_adata[in_roi]

    # Colour by the sub-niche's primary niche (assumed unique per sub-niche)
    primary_cluster = roi_adata.obs["nichecompass_latent_cluster"].unique().tolist()[0]

    ax.scatter(background_x, background_y, c=background_colours, s=0.05, rasterized=True)
    ax.scatter(
        [x * -1 for x in roi["y"].tolist()],
        [x * -1 for x in roi["x"].tolist()],
        c=niche_color_map[primary_cluster],
        s=1,
        rasterized=True
    )

    ax.set_title(cluster)
    ax.set_aspect('equal')
    ax.axis('off')

# Hide grid positions without a sub-niche
for unused_ax in axs.ravel()[len(clusters):]:
    unused_ax.axis('off')

# Compute the layout only after the final figure size has been set
fig.tight_layout()
fig.savefig(f'{figure_folder_path}/e1.svg', dpi=300)
plt.show()

Show the section side-by-side with the reference annotation.

In [None]:
# Grey shades for the Allen reference structures (one per struct_id)
categories = well11_aligned_coordinates["struct_id"].astype(str).fillna("Unknown").unique().tolist()
colours = color_palette(["#b4b4b4", "#c0c0c0", "#cdcdcd", "#dadada"], len(categories))
colour_map = {key: value for key, value in zip(categories, colours)}

# Both panels draw each cell at (-y, -x) of its aligned coordinates; the
# coordinate table and well11_adata share the same cell order
section_x = [x * -1 for x in well11_aligned_coordinates["y"].tolist()]
section_y = [x * -1 for x in well11_aligned_coordinates["x"].tolist()]

fig, axs = plt.subplots(1, 2, sharey=True, sharex=True)

# Left: cells coloured by primary niche
axs[0].scatter(
    section_x,
    section_y,
    c=well11_adata.obs["nichecompass_latent_cluster"].astype(str).fillna("Unknown").map(niche_color_map),
    s=0.01,
    rasterized=True
)

# Right: cells coloured by reference structure
axs[1].scatter(
    section_x,
    section_y,
    c=well11_aligned_coordinates["struct_id"].astype(str).fillna("Unknown").map(colour_map),
    s=1,
    rasterized=True
)

hfont = {'fontname': 'Helvetica'}

axs[0].set_title("Niches", **hfont)
axs[0].set_aspect('equal')
axs[0].axis('off')

axs[1].set_title("Reference", **hfont)
axs[1].set_aspect('equal')
axs[1].axis('off')

# Proxy artists for the niche legend; unlike plt.plot they are not drawn onto
# (and therefore not added to) the current axes
legend_handles = [plt.Line2D([], [], marker="o", ls="", color=colour) for colour in niche_color_map.values()]
axs[0].legend(legend_handles,
              niche_color_map.keys(),
              bbox_to_anchor=(0, 1, 1, 0),
              loc='lower left',
              ncols=5,
              borderaxespad=5,
              title="Niches",
              frameon=False)

fig.set_figheight(6)
fig.set_figwidth(6)
fig.savefig(f'{figure_folder_path}/d1.svg', bbox_inches="tight", dpi=300)
plt.show()

### 3.2 Save Results

In [None]:
# Log-normalise counts for the cellxgene server, keeping the raw counts in a layer.
# ``.copy()`` is required: normalize_total and log1p can modify X in place, so
# a layer that merely referenced the same array would be normalised as well.
model.adata.layers['counts'] = model.adata.X.copy()
sc.pp.normalize_total(model.adata, target_sum=1e4)
sc.pp.log1p(model.adata)

# Store the GP summary in adata with every column converted to strings
gp_summary = model.get_gp_summary()
for col in gp_summary.columns:
    gp_summary[col] = gp_summary[col].astype(str)
model.adata.uns["nichecompass_gp_summary"] = gp_summary

model.adata.write(f"{result_folder_path}/{dataset}_analysis.h5ad")