In [None]:
import copy
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad

import matplotlib.pyplot as plt
import spatialdata as sd
import spatialdata_plot as sdp
plt.rcParams['figure.dpi'] = 150  # default resolution for all matplotlib figures in this notebook

In [None]:
# Record versions of the key packages so results can be reproduced later.
# `sys.version` is used directly; `os.sys` is an undocumented re-export of `sys`.
print(f"Python: {sys.version}")
print(f"NumPy: {np.__version__}")
print(f"Pandas: {pd.__version__}")
print(f"Anndata: {ad.__version__}")
print(f"Spatialdata: {sd.__version__}")

In [None]:
# Load the annotated AnnData table and keep only the cells of one dataset.
DATASET_ID = "PU_UWA7648_salk"
adata = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad")
adata = adata[adata.obs['dataset_id'] == DATASET_ID].copy()
# Fail here with a clear message instead of an IndexError in the next cell.
assert adata.n_obs > 0, f"no cells found for dataset_id={DATASET_ID!r}"
adata

In [None]:
# The (experiment, region) pair selects the zarr store and element names below. The selected
# dataset must map to exactly one pair; otherwise taking the first row would silently pick one.
_exp_region = adata.obs[['experiment', 'region']].drop_duplicates()
assert len(_exp_region) == 1, f"expected a single experiment/region pair, found:\n{_exp_region}"
_experiment, _region = _exp_region.values[0]
_experiment, _region

In [None]:
# Open the SpatialData zarr store for this experiment/region.
ZARR_STORE_ROOT = "/home/x-aklein2/projects/aklein/BICAN/data/zarr_store"
zarr_path = Path(f"{ZARR_STORE_ROOT}/{_experiment}/{_region}")
sdata = sd.read_zarr(zarr_path)
sdata

In [None]:
# Names of the SpatialData elements/tables used in this notebook.
_proseg_prefix = f"proseg_fv38_{_experiment}_{_region}"
image_key = f"default_{_experiment}_{_region}_z3"
points_key = _proseg_prefix + "_transcripts"
shapes_key = _proseg_prefix + "_polygons"
tab_key1 = "proseg_fv38_table_filt"  # existing table; its obs/uns are reused to link the new one
tab_key2 = "proseg_fv38_annot"       # key under which the annotated table is registered
cs = "pixel"  # coordinate system for plotting (plot cells reassign it)
ch = "DAPI"   # image channel used for the display normalisation

In [None]:
# Existing table from the zarr store. Its `cell` dtype, `cells_region` column and
# `spatialdata_attrs` are reused below to link the annotated table to the same elements.
adata_prev = sdata[tab_key1]
adata_prev

In [None]:
# Inspect the SpatialData linkage metadata of the existing table
# (presumably region / region_key / instance_key — confirm in the output).
adata_prev.uns['spatialdata_attrs']

In [None]:
# Turn the annotated table into a valid SpatialData table linked like `adata_prev`:
# - cast `cell` to the dtype the existing table uses for its instance ids,
# - copy `cells_region` aligned by cell id (a positional copy would silently misalign rows
#   whenever the two tables are ordered differently),
# - copy the linkage metadata (deep copy, so the two tables do not share one mutable dict).
adata.obs['cell'] = adata.obs['cell'].astype(adata_prev.obs['cell'].dtype)
assert adata_prev.obs['cell'].is_unique, "cell ids in the existing table are not unique"
_region_by_cell = adata_prev.obs.set_index('cell')['cells_region']
adata.obs['cells_region'] = adata.obs['cell'].map(_region_by_cell).values
assert not adata.obs['cells_region'].isna().any(), "some annotated cells are missing from the existing table"
adata.uns['spatialdata_attrs'] = copy.deepcopy(adata_prev.uns['spatialdata_attrs'])
adata_valid = sd.models.TableModel().validate(adata)

In [None]:
# Register the annotated table in the SpatialData object under `tab_key2`.
# This cell does not call any write method, so the zarr store on disk is not modified here.
sdata[tab_key2] = adata_valid

In [None]:
# Per-channel intensity extrema, computed on the last pyramid level (presumably the coarsest,
# i.e. cheapest to reduce). Each result is a {channel: value} dict.
image_channels = sd.models.get_channel_names(sdata[image_key])
image_scale_keys = list(sdata[image_key].keys())
_last_level = sdata[image_key][image_scale_keys[-1]]["image"]


def _per_channel(reduced):
    """Evaluate a lazy per-channel reduction and return it as {channel: value}."""
    return reduced.compute().to_dataframe().to_dict()["image"]


max_int = _per_channel(_last_level.max(["x", "y"]))
min_int = _per_channel(_last_level.min(["x", "y"]))

In [None]:
# Keep only the elements used below: image, transcripts, polygons and the annotated table.
sdata_sub = sdata.subset([image_key, points_key, shapes_key, tab_key2])

In [None]:
# 1% random subsample of the transcripts, stored as a separate points element
# (not referenced by later cells). Seeded so that re-running gives the same subsample.
sdata_sub['points_sub'] = sdata_sub[points_key].sample(frac=0.01, random_state=0)

In [None]:
# Genes shown in the first two plotting cells. Note: those cells pass gene names directly
# and do not read this list.
genes_to_plot = ['CNR1', 'CRYM']

In [None]:
# Materialise the transcripts as a pandas frame with a categorical `gene` column, then
# re-register them as a parsed points element (presumably so render_points can filter by gene).
fts = (
    sdata_sub[points_key]
    .compute()
    .reset_index()
    .astype({"gene": "category"})
)
sdata_sub[points_key] = sd.models.PointsModel.parse(fts)

In [None]:
# Give the re-parsed points the same transformations as the cell polygons, for both
# coordinate systems used in this notebook ("global" first, then "pixel").
for _target_cs in ("global", "pixel"):
    _shape_transform = sd.transformations.get_transformation(
        sdata_sub[shapes_key], to_coordinate_system=_target_cs
    )
    sd.transformations.set_transformation(
        sdata_sub[points_key], _shape_transform, to_coordinate_system=_target_cs
    )

In [None]:
# CNR1 transcripts (orange) over the `ch` image channel, in global coordinates.
cs = "global"
fig, ax = plt.subplots()
# Cap the display range at half of the channel maximum so dimmer signal stays visible.
norm = plt.Normalize(vmin=min_int[ch], vmax=max_int[ch] * 0.5)
ax.set_facecolor("black")
(
    # Render the same channel the norm was computed for.
    sdata_sub.pl.render_images(image_key, channel=ch, scale=image_scale_keys[-1], norm=norm, cmap="gray")
    .pl.render_points(element=points_key, color="gene", groups="CNR1", palette=["orange"], size=0.5)
    .pl.show(coordinate_systems=cs, ax=ax, dpi=300)
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(None)
plt.show()

In [None]:
# CRYM transcripts (blue) over the `ch` image channel, in pixel coordinates.
cs = "pixel"
fig, ax = plt.subplots()
# Cap the display range at half of the channel maximum so dimmer signal stays visible.
norm = plt.Normalize(vmin=min_int[ch], vmax=max_int[ch] * 0.5)
ax.set_facecolor("black")
(
    sdata_sub.pl.render_images(image_key, channel=ch, scale=image_scale_keys[-1], norm=norm, cmap="gray")
    # `size` is render_points' marker-size parameter (`markersize` is not one of its arguments).
    .pl.render_points(element=points_key, color="gene", groups="CRYM", palette=["blue"], size=0.5)
    .pl.show(coordinate_systems=cs, ax=ax, dpi=300)
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(None)
plt.show()

In [None]:
# BCAS1 transcripts (blue) over the `ch` image channel, in pixel coordinates.
cs = "pixel"
fig, ax = plt.subplots()
# Cap the display range at half of the channel maximum so dimmer signal stays visible.
norm = plt.Normalize(vmin=min_int[ch], vmax=max_int[ch] * 0.5)
ax.set_facecolor("black")
(
    sdata_sub.pl.render_images(image_key, channel=ch, scale=image_scale_keys[-1], norm=norm, cmap="gray")
    # `size` is render_points' marker-size parameter (`markersize` is not one of its arguments).
    .pl.render_points(element=points_key, color="gene", groups="BCAS1", palette=["blue"], size=0.5)
    .pl.show(coordinate_systems=cs, ax=ax, dpi=300)
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(None)
plt.show()

In [None]:
# UGT8 transcripts (blue) over the `ch` image channel, in pixel coordinates.
cs = "pixel"
fig, ax = plt.subplots()
# Cap the display range at half of the channel maximum so dimmer signal stays visible.
norm = plt.Normalize(vmin=min_int[ch], vmax=max_int[ch] * 0.5)
ax.set_facecolor("black")
(
    sdata_sub.pl.render_images(image_key, channel=ch, scale=image_scale_keys[-1], norm=norm, cmap="gray")
    # `size` is render_points' marker-size parameter (`markersize` is not one of its arguments).
    .pl.render_points(element=points_key, color="gene", groups="UGT8", palette=["blue"], size=0.5)
    .pl.show(coordinate_systems=cs, ax=ax, dpi=300)
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(None)
plt.show()

## White-matter (WM) calling

In [None]:
import seaborn as sns
import geopandas as gpd
import math
from shapely import Polygon, Point, box
from sklearn.mixture import GaussianMixture

#### Helper function: hexagonal grid

In [None]:
def create_hexagonal_grid_redblobgames(bounds, hex_size, overlap=0.0):
    """
    Create a hexagonal grid covering the given bounds using Red Blob Games standard geometry.

    For pointy-top hexagons:
    - Horizontal spacing = sqrt(3) * size
    - Vertical spacing = 3/2 * size
    Odd rows are shifted right by half the horizontal spacing.

    Parameters:
    -----------
    bounds : tuple
        (minx, miny, maxx, maxy) bounding box to cover
    hex_size : float
        Radius of hexagon (distance from center to vertex). Must be > 0.
    overlap : float, default 0.0
        Shrinks the center spacing: spacing = base_spacing * (1 - overlap).
        - 0.0: No overlap (hexagons tile the plane)
        - 0.5: centers are half the standard spacing apart
        - Negative values create gaps
        Must be < 1.0 (overlap = 1.0 would put every hexagon at the same position,
        i.e. zero spacing, and the grid size could not be computed).

    Returns:
    --------
    geopandas.GeoDataFrame
        Hexagons intersecting the bounds, with columns geometry, row, col,
        center_x, center_y and a fresh 0..n-1 index.

    Raises:
    -------
    ValueError
        If hex_size is not positive or overlap >= 1.
    """
    # Validate before any geometry work; `not x > 0` also rejects NaN.
    if not hex_size > 0:
        raise ValueError(f"hex_size must be positive, got {hex_size!r}")
    if not overlap < 1:
        raise ValueError(f"overlap must be < 1 (got {overlap!r}); overlap=1 gives zero spacing")

    minx, miny, maxx, maxy = bounds

    # Red Blob Games standard spacing for pointy-top hexagons
    horizontal_spacing = hex_size * np.sqrt(3)  # sqrt(3) * size
    vertical_spacing = hex_size * 1.5           # 3/2 * size

    # Apply overlap: spacing = base_spacing * (1 - overlap)
    actual_horizontal_spacing = horizontal_spacing * (1 - overlap)
    actual_vertical_spacing = vertical_spacing * (1 - overlap)

    # Calculate grid dimensions
    width = maxx - minx
    height = maxy - miny

    # Number of hexagons needed (+2 per axis as a buffer so the far edges are covered)
    cols = int(np.ceil(width / actual_horizontal_spacing)) + 2
    rows = int(np.ceil(height / actual_vertical_spacing)) + 2

    # Vertex offsets are the same for every hexagon, so compute them once.
    # Pointy-top: vertices at 30°, 90°, ..., 390° (7 points; the last closes the ring).
    angles = np.linspace(0, 2 * np.pi, 7) + np.pi / 6
    vertex_dx = hex_size * np.cos(angles)
    vertex_dy = hex_size * np.sin(angles)

    hexagons = []
    for row in range(rows):
        y = miny + row * actual_vertical_spacing
        # Even rows: no horizontal offset; odd rows: offset by half the horizontal spacing
        row_offset = 0.0 if row % 2 == 0 else 0.5
        for col in range(cols):
            x = minx + (col + row_offset) * actual_horizontal_spacing
            hexagon = Polygon(list(zip(x + vertex_dx, y + vertex_dy)))
            hexagons.append({
                'geometry': hexagon,
                'row': row,
                'col': col,
                'center_x': x,
                'center_y': y
            })

    # Create GeoDataFrame
    gdf = gpd.GeoDataFrame(hexagons)

    # Filter to only hexagons that intersect with bounds
    bounds_poly = box(minx, miny, maxx, maxy)
    gdf = gdf[gdf.geometry.intersects(bounds_poly)].copy()
    gdf.reset_index(drop=True, inplace=True)

    return gdf

#### Continue: bin transcripts into hexagons

In [None]:
# One point geometry per transcript, taken from its x/y columns.
_transcript_points = gpd.points_from_xy(fts['x'], fts['y'])
gdf = gpd.GeoDataFrame(fts, geometry=_transcript_points)
gdf.head()

In [None]:
# Marker genes to aggregate per hexagon. Report whether each is present in the panel,
# then add a 0/1 indicator column per gene (1 where the transcript is that gene).
transfer_genes = ['MOBP', 'BCAS1', 'OPALIN']
_panel_genes = fts['gene'].cat.categories
for _gene in transfer_genes:
    print(_gene, _gene in _panel_genes)
    gdf[_gene] = gdf['gene'].eq(_gene).astype(int)

In [None]:
# Cover the transcript extent with pointy-top hexagons of radius 30, in the same units as
# the transcript x/y coordinates.
total_bounds = gdf.total_bounds  # (minx, miny, maxx, maxy)
grid = create_hexagonal_grid_redblobgames(total_bounds, hex_size=30, overlap=0)

In [None]:
# Index hexagons by a string id, then attach every transcript lying strictly inside a hexagon
# ("contains" does not match points that fall exactly on a hexagon edge).
grid = grid.set_index(pd.Index(grid.index.astype(str), name="hex_id"))
joint_grid = gpd.sjoin(grid, gdf, how="inner", predicate="contains")

In [None]:
# Per-hexagon totals. Hexagons with no transcripts do not appear in the join, so their
# values stay NaN. NOTE: `cell_count` counts transcripts, not cells.
_by_hex = joint_grid.groupby('hex_id')
grid['cell_count'] = _by_hex.size()
for _gene in transfer_genes:
    grid[f'{_gene}_count'] = _by_hex[_gene].sum()

In [None]:
# Preview the hexagon table with its transcript and per-gene counts.
grid.head()

#### Attempting to also aggregate protein-stain information into the hexagon grid
Not working yet: we still need a good way to aggregate the xarray image values within the GeoPandas hexagons.

from: https://notebooksharing.space/view/c6c1f3a7d0c260724115eaa2bf78f3738b275f7f633c1558639e7bbd75b31456#displayOptions=

In [None]:
import xarray as xr
# import rioxarray as rxr
# import rasterio

# Work in progress: flatten one protein-stain channel into per-pixel points so it can later be
# aggregated per hexagon. Nothing here is joined to `grid` yet.
scale = "scale3"
channel = "MBP"

# Pass the image element itself; the key string alone is not the element.
xarr = sdata.transform_element_to_coordinate_system(sdata[image_key], target_coordinate_system="pixel")
available_scales = list(xarr.keys())
# A bare expression in the middle of a cell is not displayed, so print it explicitly.
print(available_scales)

image = xarr[scale]
data_var_keys = list(image.data_vars)
image = image[data_var_keys[0]]
image = image.sel(c=channel).squeeze().to_dataset()

# One entry per (x, y) pixel of the selected channel.
points = image.stack(point=['x', 'y'])
print(points)

def bounds_to_poly(x, y):
    """Return the axis-aligned rectangle spanned by x[0..1] and y[0..1] as a Polygon.

    Corners are emitted in the order (x0, y0) -> (x0, y1) -> (x1, y1) -> (x1, y0).
    Currently only referenced from commented-out code.
    """
    x_lo, x_hi = x[0], x[1]
    y_lo, y_hi = y[0], y[1]
    corners = [(x_lo, y_lo), (x_lo, y_hi), (x_hi, y_hi), (x_hi, y_lo)]
    return Polygon(corners)

# boxes = xr.apply_ufunc(
#     bounds_to_poly, 
#     points.x,
#     points.y,
#     input_core_dims=[['point'], ['point']],
#     output_dtypes=[object],
#     vectorize=True
# )
# boxes

# boxes_df = gpd.GeoDataFrame(
#     data={"geometry" : boxes.values, "x" : points.x.values, "y" : points.y.values},
#     # index=boxes.indexes['point'],
# )
# boxes_df.shape

# boxes_df.geometry[0]
# boxes_df.geometry.area.sum()
# boxes_df.sample(100000).plot()

#### Continue: per-hexagon marker counts and GMM calls

In [None]:
# Map of per-hexagon transcript counts for each marker gene. The colour range is capped at 75%
# of the gene's maximum so a few very dense hexagons do not wash out the rest.
ncols = len(transfer_genes)
# squeeze=False keeps `axes` 2-D, so this also works when only one gene is plotted.
fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5), squeeze=False)
for ax, _gene in zip(axes[0], transfer_genes):
    ax.set_title(f'{_gene}')
    count_col = f"{_gene}_count"
    grid.plot(ax=ax, column=count_col, cmap="YlOrRd", edgecolor='k', alpha=0.9, linewidth=0.2,
              vmax=grid[count_col].max() * 0.75)
    ax.axis('off')
plt.show()

In [None]:
# Distribution of per-hexagon counts for each marker gene.
ncols = len(transfer_genes)
# squeeze=False keeps `axes` 2-D, so this also works when only one gene is plotted.
fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5), squeeze=False)
for ax, _gene in zip(axes[0], transfer_genes):
    sns.histplot(grid[f'{_gene}_count'], bins=30, ax=ax)
    ax.set_title(f'{_gene} Count Distribution')
plt.show()

In [None]:
# Two-component GMM per marker gene on the non-empty hexagon counts. The component whose
# hexagons have the higher mean count is labelled 1 ("high"), the other 0. Labels are written
# to grid[f"{gene}_predict"]; hexagons without counts stay NaN.
ncols = len(transfer_genes)
# squeeze=False keeps `axes` 2-D, so this also works when only one gene is plotted.
fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5), squeeze=False)
for ax, _gene in zip(axes[0], transfer_genes):
    print(f"Fitting GMM for {_gene}...")
    count_col = _gene + '_count'
    df_gene = grid[[count_col]].dropna().copy()
    counts = df_gene[[count_col]].values
    gmm = GaussianMixture(n_components=2, random_state=0, covariance_type="full").fit(counts)
    df_gene['predict'] = gmm.predict(counts)
    # Raw label of the component with the highest mean count. Selecting the column by name
    # avoids the deprecated positional `Series[0]` lookup on a string-indexed Series.
    high_label = df_gene.groupby("predict")[count_col].mean().idxmax()
    df_gene['predict'] = (df_gene['predict'] == high_label).astype(int)
    grid.loc[df_gene.index, _gene + '_predict'] = df_gene['predict']

    # for plotting
    sns.histplot(ax=ax, data=df_gene, x=count_col, bins=30, hue="predict", palette="viridis", edgecolor='k')
    ax.set_title(f'{_gene} GMM Prediction')
plt.show()

In [None]:
# Map of the per-gene 0/1 GMM calls (1 = high-count component).
ncols = len(transfer_genes)
# squeeze=False keeps `axes` 2-D, so this also works when only one gene is plotted.
fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5), squeeze=False)
for ax, _gene in zip(axes[0], transfer_genes):
    ax.set_title(f'{_gene}')
    grid.plot(ax=ax, column=f"{_gene}_predict", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
plt.show()

In [None]:
# For each hexagon called "high" (predict == 1) by at least one marker gene, count how many
# genes called it. The result follows grid row order, so it is deterministic across runs
# (iterating a set of string ids depends on Python's per-process hash seed).
_predict_cols = [_gene + "_predict" for _gene in transfer_genes]
_n_positive = (grid[_predict_cols] == 1).sum(axis=1)
hex_ids = _n_positive[_n_positive > 0].to_dict()

In [None]:
# Fraction of marker genes calling each hexagon "high"; keep hexagons called by at least
# MIN_GENE_FRACTION of them (with 3 genes, 0.75 means all three). Works for an empty hex_ids.
MIN_GENE_FRACTION = 0.75
hex_gene_fraction = pd.Series(hex_ids, dtype=float) / len(transfer_genes)
chosen_cells = hex_gene_fraction.index[hex_gene_fraction >= MIN_GENE_FRACTION]
len(chosen_cells)

In [None]:
# Replace the id list with the matching hexagon rows (geometry, counts and predictions).
chosen_cells = grid.loc[chosen_cells.tolist()]
chosen_cells.shape

In [None]:
# All hexagons in light grey, with the selected hexagons overlaid in red.
fig, ax = plt.subplots()
_outline = dict(edgecolor='k')
grid.plot(ax=ax, color='lightgrey', alpha=0.5, **_outline)
chosen_cells.plot(ax=ax, color='red', alpha=0.8, **_outline)

In [None]:
import libpysal as lps
import networkx as nx

In [None]:
# Queen contiguity between the selected hexagons, then split them into connected groups.
# reset_index gives 0..n-1 labels; the next cells index `chosen_cells` by the graph's node ids.
chosen_cells = chosen_cells.reset_index()
W = lps.weights.Queen.from_dataframe(chosen_cells)
G = W.to_networkx()
connected_components = list(nx.connected_components(G))
disconnected_comp = list(connected_components)  # one set of node ids per component

In [None]:
# Inspect the contiguity structure (presumably a dict: hexagon id -> list of neighbouring ids).
W.neighbors

In [None]:
# Label each connected group with at least MIN_COMPONENT_SIZE hexagons and outline it by the
# convex hull of its union. Hexagons in smaller groups keep comp == -1 and add no outline.
# (A concave hull, union_all().concave_hull(ratio=0.05), was also considered.)
MIN_COMPONENT_SIZE = 5
chosen_cells['comp'] = -1
geoms = []
for comp_id, members in enumerate(disconnected_comp):
    print(len(members))
    if len(members) >= MIN_COMPONENT_SIZE:
        chosen_cells.loc[list(members), "comp"] = comp_id
        component_rows = chosen_cells[chosen_cells.index.isin(members)]
        geoms.append(component_rows.union_all().convex_hull)

In [None]:
# Merge all component hulls into a single geometry and draw its outline.
union = gpd.GeoDataFrame(geometry=geoms).union_all()
gdf_geoms = gpd.GeoDataFrame(geometry=[union])
fig, ax = plt.subplots()
gdf_geoms.plot(ax=ax, edgecolor='k', color='none')
ax.axis("off");
plt.show()

In [None]:
# fig, axes = plt.subplots(1, 3, figsize=(14,8))
# ax=axes[0]
# ax.set_title('Tz Count')
# grid.plot(ax=ax, column="cell_count", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
# ax=axes[1]
# ax.set_title('Oligodendrocytes - MOBP')
# grid.plot(ax=ax, column="MOBP_count", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
# ax=axes[2]
# ax.set_title('Oligodendrocytes - BCAS1')
# grid.plot(ax=ax, column="BCAS1_count", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
# plt.tight_layout()
# plt.show()

In [None]:
# fig, axes = plt.subplots(1, 3, figsize=(14,8))
# ax=axes[0]
# ax.set_title('Tz Count')
# grid.plot(ax=ax, column="cell_count", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
# ax=axes[1]
# ax.set_title('CRYM')
# grid.plot(ax=ax, column="CRYM_count", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
# ax=axes[2]
# ax.set_title('DRD1')
# grid.plot(ax=ax, column="DRD1_count", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
# plt.tight_layout()
# plt.show()

In [None]:
# Side-by-side count distributions for MOBP and BCAS1.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for _panel, _count_col in zip(ax, ('MOBP_count', 'BCAS1_count')):
    sns.histplot(grid[_count_col], bins=30, ax=_panel)
plt.show()

In [None]:
# Compare GMM covariance structures (3 components, raw labels) on MOBP and BCAS1 hexagon counts.
# NOTE: this overwrites grid['MOBP_predict'] / grid['BCAS1_predict'] (the 0/1 calls of the
# 2-component GMM cell) with raw 3-component labels; after the loop they hold the 'spherical' fit.
cov_types = ['full', 'tied', 'diag', 'spherical']
for _cov in cov_types:
    print(_cov)
    _fits = {}
    for _marker in ('MOBP', 'BCAS1'):
        _count_col = f'{_marker}_count'
        _df_fit = grid[[_count_col]].dropna().copy()
        _values = _df_fit[[_count_col]].values
        gmm = GaussianMixture(n_components=3, random_state=0, covariance_type=_cov).fit(_values)
        _df_fit['predict'] = gmm.predict(_values)
        _fits[_marker] = _df_fit
    df_mobp, df_bcas1 = _fits['MOBP'], _fits['BCAS1']

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    sns.histplot(ax=axes[0], data=df_mobp, x='MOBP_count', bins=30, hue="predict", palette="viridis", edgecolor='k')
    sns.histplot(ax=axes[1], data=df_bcas1, x='BCAS1_count', bins=30, hue="predict", palette="viridis", edgecolor='k')
    plt.suptitle(_cov)
    plt.show()

    # Reset each prediction column before filling it, so hexagons without counts are NaN
    # rather than keeping labels from an earlier fit.
    grid['BCAS1_predict'] = np.nan
    grid.loc[df_bcas1.index, 'BCAS1_predict'] = df_bcas1['predict']
    grid['MOBP_predict'] = np.nan
    grid.loc[df_mobp.index, 'MOBP_predict'] = df_mobp['predict']

    fig, axes = plt.subplots(1, 2, figsize=(10, 8))
    ax = axes[0]
    ax.set_title('Oligodendrocytes - MOBP Predict')
    grid.plot(ax=ax, column="MOBP_predict", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
    ax = axes[1]
    ax.set_title('Oligodendrocytes - BCAS1 Predict')
    grid.plot(ax=ax, column="BCAS1_predict", cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
    plt.suptitle(_cov)
    plt.tight_layout()
    plt.show()

In [None]:
# Count histograms coloured by the raw labels of the most recent MOBP/BCAS1 GMM fits.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
_panels = ((df_mobp, 'MOBP_count'), (df_bcas1, 'BCAS1_count'))
for _panel, (_frame, _count_col) in zip(axes, _panels):
    sns.histplot(ax=_panel, data=_frame, x=_count_col, bins=30, hue="predict", palette="viridis", edgecolor='k')
plt.show()

In [None]:
# Store the raw labels of the most recent MOBP/BCAS1 GMM fits as per-hexagon predictions.
# Each column is reset to NaN first, so hexagons without counts do not keep stale labels.
grid['BCAS1_predict'] = np.nan
grid.loc[df_bcas1.index, 'BCAS1_predict'] = df_bcas1['predict']
grid['MOBP_predict'] = np.nan
grid.loc[df_mobp.index, 'MOBP_predict'] = df_mobp['predict']

In [None]:
# Hexagon maps of the stored MOBP and BCAS1 predictions.
fig, axes = plt.subplots(1, 2, figsize=(10, 8))
_map_panels = (
    ('Oligodendrocytes - MOBP Predict', "MOBP_predict"),
    ('Oligodendrocytes - BCAS1 Predict', "BCAS1_predict"),
)
for ax, (_title, _predict_col) in zip(axes, _map_panels):
    ax.set_title(_title)
    grid.plot(ax=ax, column=_predict_col, cmap="YlOrRd", edgecolor='k', alpha=0.7, linewidth=0.5)
plt.tight_layout()
plt.show()

## Plot For Tati

In [None]:
# adata_valid.obs['cell'] = adata_valid.obs['cell'].astype(int)
# adata.uns.keys()

In [None]:
# MSN group names and their colours, in the stored palette's own order.
_msn_palette_map = adata.uns['MSN_Groups_palette']
msn_groups = [group_name for group_name in _msn_palette_map.keys()]
msn_palette = [_msn_palette_map[group_name] for group_name in msn_groups]

In [None]:
# All cell polygons in grey, with MSN cells overlaid and coloured by the table's `Group` column.
fig, ax = plt.subplots()
ax.set_facecolor("black")
(
    sdata_sub.pl.render_shapes(element=shapes_key, color='grey', fill_alpha=0.75, outline=True)
    .pl.render_shapes(element=shapes_key, color='Group', groups=msn_groups, palette=msn_palette, fill_alpha=1, outline=True, scale=2)
    .pl.show(coordinate_systems='pixel', ax=ax)
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(None)
# Name the file after the dataset actually plotted (e.g. "PU_UWA7648" for "PU_UWA7648_salk")
# rather than a hard-coded sample, so figures for different sections are not mislabelled.
_dataset_id = str(adata.obs['dataset_id'].iloc[0])
_sample_tag = "_".join(_dataset_id.split("_")[:2])
_fig_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/sample_merfish") / f"{_sample_tag}_MSNs.png"
_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(_fig_path, dpi=300, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Recompute per-channel intensity extrema ({channel: value}) on the last pyramid level of the
# image element. Uses the notebook's `image_key` (KEYS / IMAGE_KEY are never defined here).
image_channels = sd.models.get_channel_names(sdata[image_key])
image_scale_keys = list(sdata[image_key].keys())
_last_level_image = sdata[image_key][image_scale_keys[-1]]["image"]

max_int = (
    _last_level_image
    .max(["x", "y"])
    .compute()
    .to_dataframe()
    .to_dict()["image"]
)
min_int = (
    _last_level_image
    .min(["x", "y"])
    .compute()
    .to_dataframe()
    .to_dict()["image"]
)