In [None]:
### Notebook to take the SPIDA outputs and add BICAN_BG-relevant metadata ###
### Also runs 2 rounds of Leiden clustering with Harmony integration ###

In [None]:
import os
import json
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad


import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical
plt.rcParams.update({'axes.facecolor': 'white'})  # force a white axes background in all later figures

from datetime import datetime

# Timestamp used as a prefix for saved figure filenames. Avoid ':' in it: it is
# invalid in Windows filenames and scp/rsync treat "name:rest" as host:path.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M")

In [None]:
# --- Parameters ---
EXPERIMENT = "PU"
prefix = "BICAN_BG"
suffix = "proseg_fv38_filt"

# Every input/output lives under one project root; change it here to relocate the run.
BICAN_ROOT = "/home/x-aklein2/projects/aklein/BICAN"
input_dir = BICAN_ROOT + "/data/aggregated"
output_dir = BICAN_ROOT + "/BG/data/annotation"
image_path = BICAN_ROOT + "/BG/images/annotations"
gene_rename_path = BICAN_ROOT + "/data/reference/AIT/BG_gene_rename.json"

In [None]:
# Concrete input/output locations derived from the parameters cell.
salk_path = Path(input_dir) / f"{prefix}_{EXPERIMENT}_salk_{suffix}.h5ad"
ucsd_path = Path(input_dir) / f"{prefix}_{EXPERIMENT}_ucsd_{suffix}.h5ad"
out_path = Path(output_dir) / f"{prefix}_{EXPERIMENT}_{suffix}" / f"{EXPERIMENT}.h5ad"

# `image_path` arrives as the image root and is rebound to the per-experiment
# folder. Without the guard, re-running this cell would nest another level
# (.../PU/PU). After changing EXPERIMENT, re-run the parameters cell first.
image_path = Path(image_path)
if image_path.name != EXPERIMENT:
    image_path = image_path / EXPERIMENT
image_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Ensure the output folder exists, then load the salk and ucsd inputs.
out_path.parent.mkdir(parents=True, exist_ok=True)
adata_salk, adata_ucsd = (sc.read_h5ad(_p) for _p in (salk_path, ucsd_path))

adata_salk

In [None]:
# Stack the two datasets. The new `dataset` obs column records the source
# ("salk"/"ucsd"), and cell names are kept unchanged (index_unique=None).
# NOTE(review): AnnData.concatenate is deprecated in newer anndata in favour of
# anndata.concat, which merges var columns differently; check before migrating.
# NOTE(review): write_h5ad is believed to cast string obs columns to categoricals
# in place; later `.cat` access may rely on that, so confirm before dropping this write.
adata = adata_salk.concatenate(
    adata_ucsd,
    batch_key="dataset",
    batch_categories=["salk", "ucsd"],
    index_unique=None,
)
adata.write_h5ad(out_path)
adata

In [None]:
# Convert any DataFrame-valued obsm entries to plain arrays. Iterate over a
# snapshot of the keys because entries are reassigned inside the loop.
for _key in list(adata.obsm.keys()):
    if isinstance(adata.obsm[_key], pd.DataFrame):
        print(_key, type(adata.obsm[_key]))
        adata.obsm[_key] = adata.obsm[_key].values
        print(_key, type(adata.obsm[_key]))

# Build cell names as "<CELL_ID>.<dataset_id>".
# Presumably CELL_ID alone can collide across datasets; confirm.
adata.obs.index = adata.obs['CELL_ID'].astype(str) + "." + adata.obs['dataset_id'].astype(str)
assert adata.obs_names.is_unique, "CELL_ID + dataset_id does not uniquely identify cells"

# Handling the gene renaming
if gene_rename_path is not None:
    # Read the {old_name: new_name} map with json directly. pd.read_json would
    # run dtype/axis inference over the gene-name keys.
    with open(gene_rename_path) as _fh:
        gene_rename_map = json.load(_fh)
    adata.var_names = adata.var_names.map(lambda name: gene_rename_map.get(name, name))
    # Two genes renamed to the same symbol would make var-based indexing ambiguous.
    _dup_genes = adata.var_names[adata.var_names.duplicated()].unique().tolist()
    assert not _dup_genes, f"gene renaming produced duplicate var_names: {_dup_genes}"

adata

In [None]:
# Work in log space. X is rebuilt from the 'volume_norm' layer on every run,
# so re-executing this cell does not apply log1p twice.
adata.X = adata.layers['volume_norm'].copy()
sc.pp.log1p(adata)

### Add color palettes

In [None]:
def add_colors(adata, cat_col, palette, color_col='Hex', default_color='#808080'):
    """Store a per-category colour list in ``adata.uns[f'{cat_col}_colors']``.

    The list follows the order of ``adata.obs[cat_col].cat.categories``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs[cat_col]`` is coloured. If the column is not
        categorical yet, it is converted in place so the category order (and
        therefore the colour order) is well defined.
    cat_col : str
        Name of the ``obs`` column to colour.
    palette : dict, pandas.Series or pandas.DataFrame
        Mapping from category label to colour. A DataFrame is indexed by label
        and its colours are read from the ``color_col`` column.
    color_col : str, default 'Hex'
        Colour column used when ``palette`` is a DataFrame. A missing column
        raises ``KeyError`` rather than silently colouring everything grey.
    default_color : str, default '#808080'
        Colour used (and reported) for categories absent from ``palette``.

    Returns
    -------
    None
        The result is written to ``adata.uns`` in place.
    """
    if not isinstance(adata.obs[cat_col].dtype, pd.CategoricalDtype):
        # `.cat.categories` only exists on categoricals; cast plain string columns.
        adata.obs[cat_col] = adata.obs[cat_col].astype('category')

    # Normalise every supported palette form to a plain {label: colour} dict.
    if isinstance(palette, pd.DataFrame):
        lookup = palette[color_col].to_dict()
    elif isinstance(palette, pd.Series):
        lookup = palette.to_dict()
    else:
        lookup = palette

    colors = []
    for category in adata.obs[cat_col].cat.categories:
        if category in lookup:
            colors.append(lookup[category])
        else:
            print(f"add_colors: no colour for {cat_col}={category!r}; using {default_color}")
            colors.append(default_color)

    adata.uns[f'{cat_col}_colors'] = colors

In [None]:
# Sanity check: number of distinct values in each metadata column.
experiments, brain_regions, donors, replicates = (
    adata.obs[_col].unique() for _col in ('experiment', 'brain_region', 'donor', 'replicate')
)
print(len(experiments), len(brain_regions), len(donors), len(replicates))

# Region/experiment codes -> colour. Note: this palette is applied to `brain_region`.
experiment_palette = dict(
    CAB='#F5867F',
    CAH='#AB4642',
    CAT='#430300',
    PU='#F98F34',
    GP='#6BBC46',
    GPe='#007600',
    MGM1='#FF2600',
    NAC='#0C4E9B',
    STH='#6B98C4',
    SUBTH='#6B98C4',
)

donor_palette = dict(
    UWA7648='#D87C79',
    UCI4723='#7A4300',
    UCI2424='#D7A800',
    UCI5224='#AB4CAA',
)

replicate_palette = dict(ucsd='#039BE5', salk='#FFD54F')

# obs column -> palette. All `<col>_palette` entries are registered first,
# then the ordered `<col>_colors` lists are derived.
_palettes_by_column = {
    "brain_region": experiment_palette,
    "donor": donor_palette,
    "replicate": replicate_palette,
}
for _column, _palette in _palettes_by_column.items():
    adata.uns[f"{_column}_palette"] = _palette
for _column, _palette in _palettes_by_column.items():
    add_colors(adata, _column, _palette)

In [None]:
## Colour schemes for the m3c annotations (to be changed to custom colours).
# All four levels are sheets of one workbook, so read it once instead of four
# times. With a list for sheet_name, read_excel returns {sheet_name: DataFrame}.
BG_COLOR_PALETTE_XLSX = '/anvil/projects/x-mcb130189/Wubin/BG/metadata/BG_color_palette.xlsx'
_m3c_sheets = pd.read_excel(
    BG_COLOR_PALETTE_XLSX,
    sheet_name=['Neighborhood', 'Class', 'Subclass', 'Group'],
    index_col=0,
)
bg_color_palette_neighborhood = _m3c_sheets['Neighborhood']
bg_color_palette_class = _m3c_sheets['Class']
bg_color_palette_subclass = _m3c_sheets['Subclass']
bg_color_palette_group = _m3c_sheets['Group']

# Each sheet is indexed by label; its 'Hex' column holds the colour.
adata.uns['m3c_neighborhood_palette'] = bg_color_palette_neighborhood['Hex'].to_dict()
adata.uns['m3c_class_palette'] = bg_color_palette_class['Hex'].to_dict()
adata.uns['m3c_subclass_palette'] = bg_color_palette_subclass['Hex'].to_dict()
adata.uns['m3c_group_palette'] = bg_color_palette_group['Hex'].to_dict()

# add_colors(adata, "AIT_Subclass", bg_color_palette_subclass)


In [None]:
# add_colors(adata, "AIT_Group", bg_color_palette_group)

In [None]:
### Colour scheme for the annotations - AIT
# Open the AIT reference backed and read-only ('r') rather than loading it fully into memory.
AIT_REF_H5AD = "/anvil/projects/x-mcb130189/Wubin/BICAN/adata/HMBA_v2/Human_HMBA_basalganglia_AIT_pre-print.h5ad"
ref = ad.read_h5ad(AIT_REF_H5AD, backed='r')
ref

In [None]:
def _first_color_per_label(obs, label_col, color_col):
    """Return ``{label: colour}`` using the first non-null ``color_col`` value per label.

    ``observed=True`` keeps only labels that actually occur. Unused categorical
    levels would otherwise appear as NaN entries. Passing it explicitly also
    avoids pandas' FutureWarning about the changing ``observed`` default when
    the label column is categorical.
    """
    return obs.groupby(label_col, observed=True)[color_col].first().to_dict()


ait_color_palette_neighb = _first_color_per_label(ref.obs, 'Neighborhood', 'color_hex_neighborhood')
ait_color_palette_class = _first_color_per_label(ref.obs, 'Class', 'color_hex_class')
ait_color_palette_subclass = _first_color_per_label(ref.obs, 'Subclass', 'color_hex_subclass')
ait_color_palette_group = _first_color_per_label(ref.obs, 'Group', 'color_hex_group')
adata.uns['AIT_neighborhood_palette'] = ait_color_palette_neighb
adata.uns['AIT_class_palette'] = ait_color_palette_class
adata.uns['AIT_subclass_palette'] = ait_color_palette_subclass
adata.uns['AIT_group_palette'] = ait_color_palette_group

# Only the reference obs table was needed; release the backed file handle.
ref.file.close()

## Calculate Embeddings

In [None]:
from spida.P.setup_adata import multi_round_clustering

In [None]:
# Two rounds of Leiden clustering on the 'volume_norm' layer, with Harmony run
# over replicate and donor. Later cells read the results back as
# `base_round*_leiden` and `X_base_round1_umap`.
base_clustering_kwargs = dict(
    layer="volume_norm",
    key_added="base_",
    num_rounds=2,
    leiden_res=[0.75, 0.5],  # one resolution per round
    min_dist=0.25,
    knn=50,
    min_group_size=50,
    run_harmony=True,
    batch_key=["replicate", "donor"],
    harmony_nclust=20,
    max_iter_harmony=20,
)
multi_round_clustering(adata, **base_clustering_kwargs)
adata

In [None]:
# Round-1 and round-2 Leiden clusters, both drawn on the round-1 UMAP.
# NOTE(review): confirm that reusing X_base_round1_umap for the round-2 panel is intended.
fig, axes = plt.subplots(1, 2, figsize=(8, 4), dpi=200)
for _ax, _cluster_col in zip(axes, ["base_round1_leiden", "base_round2_leiden"]):
    plot_categorical(adata, cluster_col=_cluster_col, coord_base="X_base_round1_umap", show=False, ax=_ax)
plt.savefig(image_path / f"{current_datetime}_00_base_umap_leiden.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Batch check: donor and replicate coloured on the round-1 UMAP after Harmony.
fig, axes = plt.subplots(1, 2, figsize=(8, 4), dpi=200)
for _ax, _meta_col in zip(axes, ["donor", "replicate"]):
    plot_categorical(adata, cluster_col=_meta_col, coord_base="X_base_round1_umap", show=False, ax=_ax)
plt.savefig(image_path / f"{current_datetime}_00_base_umap_donor_replicate.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Persist the annotated, clustered object. This overwrites the file written
# right after concatenation at the same out_path.
adata.write_h5ad(filename=out_path)
adata