In [None]:
# parameters
# --- Inputs / outputs --------------------------------------------------------
# Annotated AnnData; in this notebook only its per-cell metadata (adata.obs) is read.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
# Root of the SpatialData zarr stores, laid out as <sd_store>/<experiment>/<region>.
sd_store = "/home/x-aklein2/projects/aklein/BICAN/data/zarr_store"
# Destination for <donor>_<brain_region>_<replicate>_wm_regions.gpkg files.
output_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/wm"
# Destination for PNG figures (only written when save_figs is True).
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/regions/wm"
save_figs = False

# --- White-matter detection ----------------------------------------------------
# Marker genes whose per-hex transcript counts vote for white matter.
# A string form (e.g. injected by a parameterising tool) is normalised in a later cell.
transfer_genes = ["BCAS1", "OPALIN", "MOBP", "PLEKHH1"]
# Hexagon size / overlap passed to create_hexagonal_grid
# (presumably in the same units as the transcript x/y coordinates — confirm).
hex_size = 50
hex_overlap = 0
# sklearn GaussianMixture settings; covariance_type accepts "full", "tied", "diag" or "spherical".
gmm_cov_type = "full"
gmm_ncomp = 3
# Minimum fraction of transfer_genes that must flag a hex for it to be kept.
gene_agreement_thr = 0.75
# Connected groups of kept hexes smaller than this are discarded.
dsc_comp_min_size = 5

# When True, sections whose output .gpkg already exists are not recomputed.
skip_existing = False



In [None]:
# imports
import os
from pathlib import Path
import itertools 
from tqdm import tqdm

import math
import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd

import geopandas as gpd
from shapely import Polygon, Point, box
from sklearn.mixture import GaussianMixture
import libpysal as lps
import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
import spatialdata_plot as sdp # type: ignore
plt.rcParams['figure.dpi'] = 150

from spida.utilities.tiling import create_hexagonal_grid

In [None]:
# Load the annotated AnnData; the region loop below relies only on its .obs metadata.
adata = ad.read_h5ad(filename=ad_path)
adata

In [None]:
# Distinct values of each metadata column, in order of first appearance.
# donors x brain_regions x replicates defines the sections the loop below visits.
donors, replicates, brain_regions, experiments = (
    adata.obs[column].unique().tolist()
    for column in ("donor", "replicate", "brain_region", "experiment")
)
# (donor, brain_region, replicate) tuples that the loop must not process.
skip = [("UWA7648", "CAT", "ucsd"), ("UWA7648", "CAT", "salk")]

In [None]:
# Create the output directories, then normalise `transfer_genes` to a list of gene names.
Path(image_path).mkdir(parents=True, exist_ok=True)
Path(output_path).mkdir(parents=True, exist_ok=True)
if isinstance(transfer_genes, str):
    # A string form (presumably from parameter injection) may look like
    # '["BCAS1", "MOBP"]', "[BCAS1, MOBP]", "BCAS1,MOBP" or just "BCAS1".
    _genes_str = transfer_genes.strip()
    if _genes_str.startswith("[") and _genes_str.endswith("]"):
        _genes_str = _genes_str[1:-1]
    # Strip whitespace and surrounding quotes; drop empty entries produced by
    # "[]" or a trailing comma.
    transfer_genes = [
        gene.strip().strip('"').strip("'") for gene in _genes_str.split(",")
    ]
    transfer_genes = [gene for gene in transfer_genes if gene]
if not transfer_genes:
    # The region loop needs at least one gene (it builds one subplot per gene).
    raise ValueError("transfer_genes is empty; provide at least one marker gene.")
print(transfer_genes)

In [None]:
## Iterating for all elements
# For every donor x brain_region x replicate section:
#   1. bin transcripts into a hexagonal grid and count each transfer gene per hex,
#   2. binarise each gene's per-hex counts with a Gaussian mixture,
#   3. keep hexes flagged by at least `gene_agreement_thr` of the genes,
#   4. merge Queen-connected groups of kept hexes into convex hulls,
#   5. save the result as a GeoPackage and as a "wm_regions" shapes element.


def _plot_hex_maps(grid, genes, suffix, save_to=None):
    """Draw one hex map per gene, coloured by the ``f"{gene}{suffix}"`` column of ``grid``.

    The figure is written to ``save_to`` when given, otherwise displayed. It is always
    closed, so figures do not accumulate across loop iterations.
    """
    ncols = len(genes)
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5), squeeze=False)
    for ax, gene in zip(axes[0], genes):
        ax.set_title(gene)
        grid.plot(ax=ax, column=f"{gene}{suffix}", cmap="YlOrRd", edgecolor="k", alpha=0.7, linewidth=0.5)
    if save_to is not None:
        fig.savefig(save_to, bbox_inches="tight")
    else:
        plt.show()
    plt.close(fig)


def _fit_gene_gmm(counts, n_components, cov_type):
    """Binarise one gene's per-hex counts with a 1-D Gaussian mixture.

    Returns an int Series aligned with ``counts``: 1 for hexes assigned to the
    component whose members have the highest mean count, 0 otherwise.
    """
    n_distinct = counts.nunique()
    if n_distinct < 2:
        # No contrast to separate (e.g. the gene is absent, so every count is 0).
        # Fitting would put every hex in the "top" component and make this gene
        # vote for every hex; let it vote for none instead.
        return pd.Series(0, index=counts.index, dtype=int)
    # More components than distinct values only yields duplicate components, and
    # GaussianMixture.fit raises when n_samples < n_components.
    values = counts.to_numpy().reshape(-1, 1)
    gmm = GaussianMixture(
        n_components=min(n_components, n_distinct), random_state=0, covariance_type=cov_type
    ).fit(values)
    labels = pd.Series(gmm.predict(values), index=counts.index)
    # Component label with the highest empirical mean count among its assigned hexes.
    top_label = counts.groupby(labels).mean().idxmax()
    return (labels == top_label).astype(int)


def _consensus_hexes(grid, genes, threshold):
    """Return rows of ``grid`` flagged (``{gene}_predict == 1``) by >= ``threshold`` of ``genes``."""
    votes = (grid[[f"{gene}_predict" for gene in genes]] == 1).sum(axis=1)
    # Hexes with zero votes are never candidates, even when threshold <= 0.
    keep = (votes > 0) & (votes / len(genes) >= threshold)
    return grid.loc[keep]


def _merge_components(chosen, min_size):
    """Merge Queen-connected hexes into convex hulls and return their union as polygons.

    Connected components with fewer than ``min_size`` hexes are dropped. The result
    has one row per polygon part and is empty if no component is large enough.
    """
    chosen = chosen.reset_index()  # positional ids line up with the weights' node ids
    graph = lps.weights.Queen.from_dataframe(chosen).to_networkx()
    hulls = [
        chosen[chosen.index.isin(comp)].union_all().convex_hull
        for comp in nx.connected_components(graph)
        if len(comp) >= min_size
    ]
    if not hulls:
        return gpd.GeoDataFrame(geometry=[])
    union = gpd.GeoDataFrame(geometry=hulls).union_all()
    # ignore_index gives every exploded part its own row label instead of
    # repeating the source row's label.
    return gpd.GeoDataFrame(geometry=[union]).explode(ignore_index=True)


_combos = list(itertools.product(donors, brain_regions, replicates))
pbar = tqdm(_combos)
for _i in pbar:
    if _i in skip:
        continue
    _donor, _brain_region, _replicate = _i

    # Not every donor x brain_region x replicate combination exists in the data;
    # a bare `.values[0]` lookup would raise IndexError on the missing ones.
    _rows = adata.obs.loc[
        (adata.obs["donor"] == _donor)
        & (adata.obs["brain_region"] == _brain_region)
        & (adata.obs["replicate"] == _replicate),
        ["experiment", "region"],
    ]
    if _rows.empty:
        continue
    _experiment, _region = _rows.to_numpy()[0]
    pbar.set_description(f"Processing {_i} ({_experiment}, {_region})")

    out_path_wm = f"{output_path}/{_donor}_{_brain_region}_{_replicate}_wm_regions.gpkg"
    if skip_existing and Path(out_path_wm).exists():
        print(f"Regions already exist for {_i}, skipping...")
        continue

    zarr_path = f"{sd_store}/{_experiment}/{_region}"
    sdata = sd.read_zarr(zarr_path)

    cs = "pixel"
    ch = "DAPI"
    image_key = f"default_{_experiment}_{_region}_z3"
    points_key = f"proseg_fv38_{_experiment}_{_region}_transcripts"
    shapes_key = f"proseg_fv38_{_experiment}_{_region}_polygons"
    tab_key1 = "proseg_fv38_table_filt"
    tab_key2 = "proseg_fv38_annot"

    # --- Transcripts -> per-gene indicator columns --------------------------------
    fts = sdata[points_key].compute().reset_index()
    fts["gene"] = fts["gene"].astype("category")
    gdf = gpd.GeoDataFrame(fts, geometry=gpd.points_from_xy(fts["x"], fts["y"]))

    _missing = [g for g in transfer_genes if g not in fts["gene"].cat.categories]
    if _missing:
        print(f"{_i}: transfer genes absent from transcripts: {_missing}")
    for _gene in transfer_genes:
        gdf[_gene] = (gdf["gene"] == _gene).astype(int)

    # --- Hexagonal binning ------------------------------------------------------
    grid = create_hexagonal_grid(gdf.total_bounds, hex_size, overlap=hex_overlap)
    grid["hex_id"] = grid.index.astype(str)
    grid = grid.set_index("hex_id")
    joint_grid = gpd.sjoin(grid, gdf, how="inner", predicate="contains")

    _hex_groups = joint_grid.groupby("hex_id")
    # Number of transcripts (not cells) per hex; hexes without transcripts stay NaN.
    grid["cell_count"] = _hex_groups.size()
    for _gene in transfer_genes:
        grid[f"{_gene}_count"] = _hex_groups[_gene].sum()

    _plot_hex_maps(grid, transfer_genes, "_count")

    # --- Per-gene GMM binarisation ------------------------------------------------
    ncols = len(transfer_genes)
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5), squeeze=False)
    for ax, _gene in zip(axes[0], transfer_genes):
        _count_col = f"{_gene}_count"
        _counts = grid[_count_col].dropna()
        _predict = _fit_gene_gmm(_counts, gmm_ncomp, gmm_cov_type)
        # Hexes without transcripts get NaN (no vote).
        grid[f"{_gene}_predict"] = _predict.reindex(grid.index)

        sns.histplot(
            ax=ax,
            data=pd.DataFrame({_count_col: _counts, "predict": _predict}),
            x=_count_col, bins=30, hue="predict", palette="viridis", edgecolor="k",
        )
        ax.set_title(f"{_gene} GMM Prediction")
    plt.show()
    plt.close(fig)

    _plot_hex_maps(
        grid, transfer_genes, "_predict",
        save_to=(
            f"{image_path}/{_donor}_{_brain_region}_{_replicate}_wm_gene_predictions.png"
            if save_figs else None
        ),
    )

    # --- Gene consensus -----------------------------------------------------------
    chosen_cells = _consensus_hexes(grid, transfer_genes, gene_agreement_thr)
    print(f"Chosen Hexes: {len(chosen_cells)}")
    if chosen_cells.empty:
        print(f"No hexes reach gene agreement >= {gene_agreement_thr} for {_i}, skipping...")
        continue

    fig, ax = plt.subplots()
    grid.plot(ax=ax, color="lightgrey", edgecolor="k", alpha=0.5)
    chosen_cells.plot(ax=ax, color="red", edgecolor="k", alpha=0.8)
    plt.show()
    plt.close(fig)

    # --- Connected components -> region polygons -----------------------------------
    gdf_geoms = _merge_components(chosen_cells, dsc_comp_min_size)
    if gdf_geoms.empty:
        print(f"No connected component of >= {dsc_comp_min_size} hexes for {_i}, skipping...")
        continue

    fig, ax = plt.subplots()
    gdf_geoms.plot(ax=ax, edgecolor="k", color="none")
    ax.axis("off")
    plt.show()
    plt.close(fig)

    # --- Save ----------------------------------------------------------------------
    gdf_geoms.to_file(out_path_wm, driver="GPKG")

    geoms_key = "wm_regions"
    sdata[geoms_key] = sd.models.ShapesModel().parse(gdf_geoms)
    # The polygons were built from transcript x/y, so reuse the points' transformations.
    for _cs_name in ("pixel", "global"):
        sd.transformations.set_transformation(
            sdata[geoms_key],
            sd.transformations.get_transformation(sdata[points_key], to_coordinate_system=_cs_name),
            to_coordinate_system=_cs_name,
        )

    # --- Final overlay figure ---------------------------------------------------------
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        (
            sdata.pl.render_images(image_key, channel=ch, cmap="gray")
            .pl.render_shapes(geoms_key, color="none", outline_color="red", outline_width=2, outline_alpha=1, fill_alpha=0.5)
            .pl.show(ax=ax, coordinate_systems=cs)
        )
        ax.set_title(f"{_donor} {_brain_region} {_replicate} - WM Regions")
        if save_figs:
            # fig.savefig targets this figure even if pl.show changed the current one.
            fig.savefig(f"{image_path}/{_donor}_{_brain_region}_{_replicate}_wm_regions.png", bbox_inches="tight")
        else:
            plt.show()
    except ValueError as e:
        print(f"Could not plot final figure for {_i}: {e}")
    finally:
        plt.close(fig)
