In [None]:
import os
from pathlib import Path

import string
import math

import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd
import spatialdata_plot as sdp
import shapely as shp
import geopandas as gpd

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import matplotlib.patches as mpatches
import matplotlib.cm as cm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


import seaborn as sns
from adjustText import adjust_text  # pip install adjustText

In [None]:
# Publication-style matplotlib defaults, applied once for the whole notebook.
plt.rcParams.update({
    # Resolution and default canvas
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    # Type 42 (TrueType) fonts in PDF/PS output
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    'axes.facecolor': 'white',
    'image.interpolation': 'nearest',
    # Defaults for fig.savefig
    'savefig.dpi': 300,
    'savefig.transparent': True,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.01,
})

RASTERIZED = False

In [None]:
# Output directory for exported figures (created, with parents, if missing).
image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/ex_merfish")
os.makedirs(image_path, exist_ok=True)

In [None]:
# Full annotated dataset, opened in backed read-only mode so it is not loaded into memory.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(ad_path, backed='r')

# In-memory subset for the example dataset. The following cells (sample metadata,
# table registration) read `adata_local`, so it must be defined here for Restart & Run All.
DATASET_ID = "NAC_UCI5224_salk"
adata_local = adata[adata.obs['dataset_id'] == DATASET_ID].to_memory()
adata_local

In [None]:
# Experiment / region of the selected dataset, read from its first cell; used to
# locate the matching SpatialData store.
first_cell = adata_local.obs.iloc[0]
_experiment, _region = first_cell['experiment'], first_cell['region']
_experiment, _region

In [None]:
# SpatialData store for this experiment/region.
zarr_root = Path("/home/x-aklein2/projects/aklein/BICAN/data/zarr_store")
zarr_path = zarr_root / str(_experiment) / str(_region)
sdata = sd.read_zarr(zarr_path)
sdata

In [None]:
# Element names inside the store, all derived from "<experiment>_<region>".
sample_tag = f"{_experiment}_{_region}"
image_key = f"default_{sample_tag}_z3"
points_key = f"CPS_{sample_tag}_transcripts"
shapes_key = f"CPS_{sample_tag}_polygons"

cs = "pixel"
ch = "DAPI"

tab_key1 = "CPS_tablefilt"  # existing table; template for region/instance metadata
tab_key2 = "CPS_annot"      # name under which the annotated table is stored

In [None]:
# Template table already stored with the polygons; its metadata is reused below.
adata_prev = sdata[tab_key1]
adata_prev

In [None]:
# Table-to-element linkage metadata that the annotated table will reuse.
adata_prev.uns["spatialdata_attrs"]

In [None]:
# Build the annotated table and register it in the SpatialData object.
# Instance ids take the template table's dtype so they match the polygon ids.
adata_local.obs['EntityID'] = adata_local.obs['CELL_ID'].astype(adata_prev.obs['EntityID'].dtype)

# Region labels. Copying `adata_prev`'s column positionally only works when both tables
# hold the same number of cells, which a differently-filtered table need not. When the
# template annotates a single region, broadcast that label (dtype/categories preserved).
prev_region_col = adata_prev.obs['cells_region']
if prev_region_col.nunique() == 1:
    adata_local.obs['cells_region'] = prev_region_col.iloc[np.zeros(adata_local.n_obs, dtype=int)].values
else:
    assert adata_local.n_obs == adata_prev.n_obs, (
        "adata_local and adata_prev differ in size; cannot copy 'cells_region' positionally"
    )
    adata_local.obs['cells_region'] = prev_region_col.values

# Copy (not alias) the linkage metadata so edits to one table cannot leak into the other.
adata_local.uns['spatialdata_attrs'] = dict(adata_prev.uns['spatialdata_attrs'])
adata_valid = sd.models.TableModel().validate(adata_local)
sdata[tab_key2] = adata_valid

In [None]:
# Channel names and multiscale level names of the image.
image_channels = sd.models.get_channel_names(sdata[image_key])
image_scale_keys = list(sdata[image_key].keys())

# Per-channel min/max taken from the last scale level (presumably the coarsest, hence
# cheap); used as display limits. Each result is a {channel: value} dict.
last_level_image = sdata[image_key][image_scale_keys[-1]]["image"]


def _channel_extrema(reducer):
    """Apply `reducer` over x/y and return the result as {channel: value}."""
    return reducer(["x", "y"]).compute().to_dataframe().to_dict()["image"]


max_int = _channel_extrema(last_level_image.max)
min_int = _channel_extrema(last_level_image.min)

In [None]:
# Keep only the elements used in the figures below.
wanted_elements = [image_key, points_key, shapes_key, tab_key2]
sdata_sub = sdata.subset(wanted_elements)

In [None]:
# 1% random subset of transcripts for quick previews. A fixed seed keeps the subset
# identical across Restart & Run All.
POINTS_SAMPLE_SEED = 0
sdata_sub['points_sub'] = sdata_sub[points_key].sample(frac=0.01, random_state=POINTS_SAMPLE_SEED)

In [None]:
# Genes of interest (placeholder list; replace with actual gene names).
genes_to_plot = "CNR1 CRYM".split()

In [None]:
# Materialize the transcripts, make `gene` categorical and store them back as a
# re-parsed points element (the next cell re-attaches coordinate transformations).
transcripts = sdata_sub[points_key].compute().reset_index()
transcripts['gene'] = transcripts['gene'].astype("category")
sdata_sub[points_key] = sd.models.PointsModel.parse(transcripts)

In [None]:
# Give the transcripts the same transformations as the polygons in both coordinate
# systems, so the two layers overlay correctly.
for coord_sys in ("global", "pixel"):
    polygon_transform = sd.transformations.get_transformation(
        sdata_sub[shapes_key], to_coordinate_system=coord_sys
    )
    sd.transformations.set_transformation(
        sdata_sub[points_key], polygon_transform, to_coordinate_system=coord_sys
    )

In [None]:
# Preview: DAPI background with one gene's transcripts overlaid, in the "global" system.
cs = "global"
fig, ax = plt.subplots()
# Display range: channel min up to 50% of channel max (brighter values saturate).
norm = plt.Normalize(vmin=min_int[ch], vmax=max_int[ch] * 0.5)
ax.set_facecolor("black")

preview = sdata_sub.pl.render_images(
    image_key, channel="DAPI", scale=image_scale_keys[-1], norm=norm, cmap="gray"
)
preview = preview.pl.render_points(
    element=points_key, color="gene", groups="BACH2", palette=["orange"], size=0.5
)
preview.pl.show(coordinate_systems=cs, ax=ax, dpi=300)

ax.set(xticks=[], yticks=[])
ax.set_title(None)
plt.show()

In [None]:
# Every MSN group in the reference palette, with colours in the same order.
msn_palette_map = adata.uns['MSN_Groups_palette']
msn_groups = list(msn_palette_map)
msn_palette = list(msn_palette_map.values())

In [None]:
# MSN groups shown in the figures (in this order) and their palette colours.
msn_groups_sub = [
    'STRd D1 Matrix MSN',
    'STRd D2 Matrix MSN',
    'STRv D1 MSN',
    'STRv D2 MSN',
    'STRd D1 Striosome MSN',
    'STRd D2 Striosome MSN',
    'STR D1D2 Hybrid MSN',
    'STRd D2 StrioMat Hybrid MSN',
]
msn_palette = list(map(adata.uns['MSN_Groups_palette'].__getitem__, msn_groups_sub))

In [None]:
# Polygons restricted to cells present in the annotated table.
shapes_key_filt = "shapes_filt"
all_polygons = sdata_sub[shapes_key]
in_table = all_polygons.index.isin(adata_valid.obs['EntityID'])
sdata_sub[shapes_key_filt] = all_polygons.loc[in_table]

In [None]:
## Set up scale bar function

def make_scale_bar(ax, x_coords, y_coords, microns_per_pixel=0.106, x_pct=100, y_pct=100):
    """
    Add a scale bar (black line + label on a translucent white box) to a matplotlib plot.

    The bar length is about 1/6 of the x-extent of ``x_coords`` (converted to microns),
    rounded to 1, 2, 5 or 10 times a power of ten. Its right end sits just left of the
    ``x_pct`` percentile of ``x_coords``; vertically it is placed near the ``y_pct``
    percentile of ``y_coords``.

    Parameters:
    - ax: matplotlib Axes object
    - x_coords: array-like of x coordinates (data units); non-finite values are ignored
    - y_coords: array-like of y coordinates (data units); non-finite values are ignored
    - microns_per_pixel: conversion factor from data units (pixels) to microns
    - x_pct: percentile of x-coordinates to position the scale bar horizontally
    - y_pct: percentile of y-coordinates to position the scale bar vertically

    Returns:
    - (background_patch, bar_line, label_text) so callers can restyle the bar, or None
      when the coordinates have no finite, positive extent (nothing is drawn then).

    Example usage:
    - fig, ax = plt.subplots()
    - ax.scatter(x, y)
    - make_scale_bar(ax, x, y)
    """
    x_vals = np.asarray(x_coords, dtype=float)
    y_vals = np.asarray(y_coords, dtype=float)
    x_vals = x_vals[np.isfinite(x_vals)]
    y_vals = y_vals[np.isfinite(y_vals)]
    if x_vals.size == 0 or y_vals.size == 0:
        return None

    # Data width in microns; the target bar length is ~1/6 of it.
    x_length_um = (x_vals.max() - x_vals.min()) * microns_per_pixel
    target = x_length_um / 6
    # A zero/negative/non-finite target would make log10 below produce -inf/NaN.
    if not np.isfinite(target) or target <= 0:
        return None

    # Split the target into order of magnitude and mantissa in [1, 10).
    order = 10 ** np.floor(np.log10(target))
    mantissa = target / order

    # Round mantissa to nearest 1, 2, 5 (or up to 10).
    if mantissa < 1.5:
        nice_mantissa = 1
    elif mantissa < 3.5:
        nice_mantissa = 2
    elif mantissa < 7.5:
        nice_mantissa = 5
    else:
        nice_mantissa = 10

    # Final scale length in microns and in data units.
    scale_length_um = nice_mantissa * order
    scale_length_px = scale_length_um / microns_per_pixel

    # "\u00b5" is the micro sign, written as an escape so the source encoding cannot
    # garble it (it had been mis-decoded to "Âµm"). ':g' keeps sub-micron lengths
    # readable (e.g. "0.5") instead of rounding them to "0".
    if scale_length_um >= 1000:
        scale_label = f"{scale_length_um / 1000:.1f} mm"
    else:
        scale_label = f"{scale_length_um:g} \u00b5m"

    # Bar placement in data coordinates.
    x_anchor = np.percentile(x_vals, x_pct)
    x_start = x_anchor - scale_length_px * 1.1
    x_end = x_anchor - scale_length_px * 0.1
    y_pos = np.percentile(y_vals, y_pct) - scale_length_px * 0.1
    bar_y = y_pos + scale_length_px * 0.3

    # Translucent white background behind bar and label.
    scale_bg = mpatches.Rectangle(
        (x_start - scale_length_px * 0.05, y_pos - scale_length_px * 0.05),
        width=(x_end - x_start) + scale_length_px * 0.1,
        height=scale_length_px * 0.4,
        color='white', alpha=0.8,
        zorder=10,
    )
    ax.add_patch(scale_bg)

    (bar_line,) = ax.plot([x_start, x_end], [bar_y, bar_y], color='black', linewidth=2, zorder=11)
    label_text = ax.text((x_start + x_end) / 2, bar_y, scale_label,
                         color='black', ha='center', va='bottom', zorder=12)
    return scale_bg, bar_line, label_text

In [None]:
# Crop window as a rectangle (minx, miny, maxx, maxy), stored as its own shapes element.
box = shp.box(31000, 30000, 34000, 33000)
box_gdf = gpd.GeoDataFrame(geometry=[box], index=['box1'])
sdata_sub['bounding_box'] = sd.models.ShapesModel.parse(box_gdf)

# Transforms derived from the polygons' transforms, deliberately swapped: the box's
# "global" transform is the inverse of the polygons' "pixel" transform and vice versa.
# NOTE(review): this treats the box numbers as coordinates in the system opposite to the
# polygons' native one — confirm.
polygon_transforms = {
    coord_sys: sd.transformations.get_transformation(sdata_sub[shapes_key], to_coordinate_system=coord_sys)
    for coord_sys in ("pixel", "global")
}
sd.transformations.set_transformation(
    sdata_sub['bounding_box'], polygon_transforms["pixel"].inverse(), to_coordinate_system="global"
)
sd.transformations.set_transformation(
    sdata_sub['bounding_box'], polygon_transforms["global"].inverse(), to_coordinate_system="pixel"
)

In [None]:
# Overview: DAPI, MSN polygons coloured by group (scale=4) and the crop window in red.
fig, ax = plt.subplots()
ax.set_facecolor("white")

msn_overview = (
    sdata_sub.pl.render_images(image_key, channel="DAPI", scale=image_scale_keys[-1], norm=norm, cmap="gray")
    .pl.render_shapes(element=shapes_key, color='Group', groups=msn_groups_sub, palette=msn_palette,
                      fill_alpha=0.9, linewidth=0.1, outline=True, scale=4)
    .pl.render_shapes(element="bounding_box", fill_alpha=0, outline_alpha=1, outline_width=1, outline_color='red')
)
msn_overview.pl.show(coordinate_systems='pixel', ax=ax, colorbar=False)

ax.set(xticks=[], yticks=[])
ax.set_title(None)

# Scale bar from the annotated cells' centre coordinates.
cell_meta = sdata_sub[tab_key2].obs
make_scale_bar(ax, cell_meta["CENTER_X"], cell_meta["CENTER_Y"], microns_per_pixel=1)

for out_name in ("NAC_UCI5224_MSNs.png", "NAC_UCI5224_MSNs.pdf"):
    fig.savefig(image_path / out_name, dpi=300, bbox_inches='tight', pad_inches=0)
plt.show()


In [None]:
# Crop the data to the window defined by `box`.
# shapely geometries are not subscriptable, and `shp.box` takes (minx, miny, maxx, maxy)
# rather than (x, y, width, height), so read the extent from `.bounds`.
xmin, ymin, xmax, ymax = box.bounds

# The window is drawn in the "pixel" system in the overview figure; query in that system
# explicitly instead of relying on `cs`, which the preview cell reset to "global".
# NOTE(review): confirm the box numbers are pixel coordinates.
CROP_COORDINATE_SYSTEM = "pixel"
sdata_sub_crop = sdata_sub.query.bounding_box(
    axes=["x", "y"],
    min_coordinate=[xmin, ymin],
    max_coordinate=[xmax, ymax],
    target_coordinate_system=CROP_COORDINATE_SYSTEM,
)

In [None]:
# Subclass colours for the cells in the crop (used only by the commented-out fill layer).
subclasses = sdata_sub_crop[tab_key2].obs['Subclass'].unique().tolist()
subclass_palette = [adata.uns['Subclass_palette'][name] for name in subclasses]

norm = plt.Normalize(vmin=min_int[ch], vmax=max_int[ch] * 0.5)

# Crop view: DAPI, DRD1 (cyan) / DRD2 (magenta) transcripts and white cell outlines.
fig, ax = plt.subplots()
ax.set_facecolor("white")
crop_view = (
    sdata_sub_crop.pl.render_images(image_key, channel="DAPI", scale=image_scale_keys[-1], norm=norm, cmap="gray")
    # .pl.render_shapes(element=shapes_key, color="Subclass", groups=subclasses, palette=subclass_palette, fill_alpha=1, outline_width=0.5, outline_color="white", outline=True, scale=1.1)
    .pl.render_points(element=points_key, color="gene", groups=["DRD1", "DRD2"], palette=["cyan", 'magenta'], size=0.2)
    .pl.render_shapes(element=shapes_key, fill_alpha=0, outline_width=0.5, outline_color="white", outline=True, scale=1.1)
)
crop_view.pl.show(coordinate_systems='pixel', ax=ax, colorbar=False, frameon=False)

# No ticks, title or frame.
ax.set(xticks=[], yticks=[])
ax.set_title(None)
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)

cell_meta = sdata_sub_crop[tab_key2].obs
make_scale_bar(ax, cell_meta["CENTER_X"], cell_meta["CENTER_Y"], y_pct=94, x_pct=95, microns_per_pixel=1)

for out_name in ("NAC_UCI5224_MSNs_crop_DRD1_DRD2.png", "NAC_UCI5224_MSNs_crop_DRD1_DRD2.pdf"):
    fig.savefig(image_path / out_name, dpi=300, bbox_inches='tight', pad_inches=0)

plt.show()

In [None]:
# All cells in grey, then MSN groups on top with their polygons rendered at scale=5.
fig, ax = plt.subplots()
ax.set_facecolor("black")
all_cells_layer = sdata_sub.pl.render_shapes(
    element=shapes_key, color='grey', fill_alpha=0.75, outline=True, scale=1
)
msn_layer = all_cells_layer.pl.render_shapes(
    element=shapes_key, color='Group', groups=msn_groups_sub, palette=msn_palette,
    fill_alpha=1, outline=True, scale=5,
)
msn_layer.pl.show(coordinate_systems='pixel', ax=ax)
ax.set(xticks=[], yticks=[])
ax.set_title(None)
plt.show()

## Using Geopandas

In [None]:
# Store for the example experiment/region, opened directly by path.
example_store = Path("/home/x-aklein2/projects/aklein/BICAN/data/zarr_store") / "202506151211_BICAN-4x1-PU-01_VMSC31910" / "region_UCI5224"
sdata = sd.read_zarr(example_store)
sdata

In [None]:
# Element names in the example store.
table_key = "CPS_table"
shapes_key = "CPS_202506151211_BICAN-4x1-PU-01_VMSC31910_region_UCI5224_polygons"
points_key = "CPS_202506151211_BICAN-4x1-PU-01_VMSC31910_region_UCI5224_transcripts"
image_key = "default_202506151211_BICAN-4x1-PU-01_VMSC31910_region_UCI5224_z3"

In [None]:
def plot_image_channels(sdata, image_key, cs="pixel", pdf_file=None, ncols=4):
    """
    Plot every channel of a multiscale image element as a grid of panels.

    Each panel uses [channel min, 0.5 * channel max] as its display range; the extrema
    are computed on the last scale level (presumably the coarsest) to keep it cheap.

    Parameters:
    - sdata: SpatialData object holding the image
    - image_key: name of the image element
    - cs: coordinate system passed to `.pl.show`
    - pdf_file: optional object with `savefig(fig, ...)` (e.g. matplotlib's PdfPages);
      when given the figure is written there and closed instead of being shown
    - ncols: number of panel columns

    Returns:
    - the matplotlib Figure
    """
    image = sdata[image_key]
    channels = sd.models.get_channel_names(image)
    scale_keys = list(image.keys())
    last_level = image[scale_keys[-1]]["image"]
    max_int = last_level.max(["x", "y"]).compute().to_dataframe().to_dict()["image"]
    min_int = last_level.min(["x", "y"]).compute().to_dataframe().to_dict()["image"]

    nrows = max(1, math.ceil(len(channels) / ncols))
    # squeeze=False keeps `axes` 2-D even for a single row, so flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, nrows * 5), dpi=200, squeeze=False)
    axes = axes.flatten()

    for ax, channel in zip(axes, channels):
        norm = Normalize(vmin=min_int[channel], vmax=max_int[channel] * 0.5)
        sdata.pl.render_images(
            image_key, channel=channel, cmap="grey", norm=norm
        ).pl.show(ax=ax, title=channel, coordinate_systems=cs, colorbar=False)

    # Hide grid cells left over when the channel count is not a multiple of ncols.
    for ax in axes[len(channels):]:
        ax.set_axis_off()

    fig.tight_layout()

    if pdf_file is not None:
        pdf_file.savefig(fig, bbox_inches="tight")
        plt.close(fig)  # avoid accumulating open figures when writing many pages
    else:
        plt.show()
    return fig


plot_image_channels(sdata, image_key, cs="pixel");


In [None]:
cs = "pixel"
channel = "DAPI"
image_channels = sd.models.get_channel_names(sdata[image_key])
image_scale_keys = list(sdata[image_key].keys())

# Per-channel extrema ({channel: value}) from the last scale level, then a DAPI display
# range of [min, 0.5 * max].
last_level_image = sdata[image_key][image_scale_keys[-1]]["image"]
max_int, min_int = (
    getattr(last_level_image, reduction)(["x", "y"]).compute().to_dataframe().to_dict()["image"]
    for reduction in ("max", "min")
)
norm = Normalize(vmin=min_int[channel], vmax=max_int[channel] * 0.5)

In [None]:
# DAPI with red cell-polygon outlines.
# Outline styling goes through render_shapes' fill_alpha / outline_* arguments (as in
# the overview figure's crop box). coordinate_systems / colorbar / title are `.pl.show`
# options, and spatialdata-plot draws nothing until `.pl.show` is called.
sdata.pl.render_images(
    image_key, channel=channel, cmap="grey", norm=norm
).pl.render_shapes(
    shapes_key,
    fill_alpha=0,
    outline_alpha=1,
    outline_width=2,
    outline_color="red",
).pl.show(coordinate_systems='pixel', colorbar=False, title=channel)

In [None]:
# Inspect the table stored with this region.
sdata[table_key]

In [None]:
# Cells of the full dataset from this store's experiment, region and donor (each taken
# as the first unique value in the store's table), loaded into memory.
store_obs = sdata[table_key].obs
same_sample = (
    (adata.obs['experiment'] == store_obs['experiment'].unique()[0])
    & (adata.obs['region'] == store_obs['region'].unique()[0])
    & (adata.obs['donor'] == store_obs['donor'].unique()[0])
)
adata_local = adata[same_sample].to_memory()
# String ids, used below to index the polygons (presumably their index dtype — confirm).
adata_local.obs['EntityID'] = adata_local.obs['CELL_ID'].astype(str)

In [None]:
# sdata[table_key].uns

In [None]:
# adata_local.uns['spatialdata_attrs'] = {'instance_key': 'EntityID',
#   'region': 'CPS_202506151211_BICAN-4x1-PU-01_VMSC31910_region_UCI5224_polygons',
#   'region_key': 'cells_region'}
# adata_local.obs['EntityID'] = adata_local.obs['CELL_ID'].astype(int)

In [None]:
# shapes = sdata[shapes_key].copy()
# new_shapes = shapes.loc[adata_local.obs['EntityID'].values]
# new_shapes

In [None]:
# ttk = 'annot_table'
# sdata[ttk] = adata_local

In [None]:
# One polygon per annotated cell, carrying its labels and the matching palette colours.
ANNOT_COLUMNS = ['Subclass', 'Group', 'MSN_Groups']
UNKNOWN_LABEL = "unknown"
# Colour for labels absent from a palette (e.g. the "unknown" fill); without it these
# rows get NaN colours, which matplotlib/geopandas cannot draw.
UNKNOWN_COLOR = "lightgrey"

obs_by_id = adata_local.obs.set_index('EntityID')  # built once, reused per column
shapes = sdata[shapes_key].copy()
shapes = shapes.loc[adata_local.obs['EntityID'].values]

for col in ANNOT_COLUMNS:
    shapes[col] = shapes.index.map(obs_by_id[col].to_dict()).fillna(UNKNOWN_LABEL)

for col in ANNOT_COLUMNS:
    shapes[f'{col}_colors'] = shapes[col].map(adata.uns[f'{col}_palette']).fillna(UNKNOWN_COLOR)

In [None]:
# Enlarge each polygon about its own centre (GeoSeries.scale default origin) so small
# cells are visible at this zoom.
sc_fact = 3
shapes_pl = shapes.copy()
shapes_pl.geometry = shapes_pl.geometry.scale(sc_fact, sc_fact)

fig, ax = plt.subplots(figsize=(6, 6), dpi=200)
# Colour straight from the precomputed per-cell colours of the plotted frame. Passing
# `column=` as well makes geopandas warn and use only one of the two.
shapes_pl.plot(ax=ax, color=shapes_pl['Subclass_colors'], edgecolor='black', linewidth=0.05)
ax.axis("off")
plt.show()