In [None]:
## Notebook to combine the annotation results from the separate integration rounds into one AnnData object.

In [None]:
# Steps: get the file paths, load the individual results, combine them, and save the combined object.
# Note: use the label2label annotations!

In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad


import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous
# White axes background for every figure drawn in this notebook.
plt.rcParams.update({'axes.facecolor': 'white'})

In [None]:
# Parameters: the experiment (region) to process and the annotation level
# whose transferred labels are collected below.
EXPERIMENT, TRANSFER_COL = "PU", "Group"

In [None]:
# Root of the BICAN data tree. Set the BICAN_DATA_DIR environment variable to
# run from another account/machine; the default is the original location, so
# the derived paths below are unchanged when it is not set.
BICAN_DATA_DIR = Path(os.environ.get("BICAN_DATA_DIR", "/home/x-aklein2/projects/aklein/BICAN/data"))

# Aggregated object that receives the combined annotations.
adata_path = BICAN_DATA_DIR / "aggregated" / f"BICAN_BG_{EXPERIMENT}_comb_cellpose_SAM_filt.h5ad"
# Folder holding the per-round annotation results (read below).
group_annot_path = BICAN_DATA_DIR / "annotated" / f"BICAN_BG_{EXPERIMENT}"
# Combined output, written into the same annotation folder.
out_path = group_annot_path / f"{EXPERIMENT}.h5ad"

In [None]:
# Load the aggregated object that the combined annotations are attached to.
adata = ad.read_h5ad(filename=adata_path)
adata

In [None]:
# Collect, from every per-round annotation file in `group_annot_path`, the
# .obs columns whose name contains TRANSFER_COL or "unk".
# Files are visited in sorted order so the concatenation order is reproducible
# (Path.glob order depends on the filesystem).
df_list = []
for _file in sorted(group_annot_path.glob("*.h5ad")):
    # Skip integrator outputs, previously combined results and non-BICAN files
    # (this also skips the combined output written to out_path).
    if "integrator" in _file.name or "combined" in _file.name or "BICAN" not in _file.name:
        continue
    print(_file)
    # Only .obs is needed: open backed so the expression data is not loaded,
    # copy the selected columns out, then release the file handle.
    _adata = ad.read_h5ad(_file, backed="r")
    try:
        cols = [col for col in _adata.obs.columns if TRANSFER_COL in col or "unk" in col]
        df_list.append(_adata.obs[cols].copy())
    finally:
        _adata.file.close()

In [None]:
# Inspect the "unknown" result file for the current experiment. The path was
# previously hard-coded to PU, so changing EXPERIMENT silently showed PU data.
_unknown_path = group_annot_path / f"BICAN_BG_{EXPERIMENT}_unknown.h5ad"
_adata = ad.read_h5ad(_unknown_path)
_adata

In [None]:
# Stack the per-file tables, keep only cells present in `adata`, then drop
# rows and columns that contain no annotation at all.
df_temp = (
    pd.concat(df_list, axis=0)
    .pipe(lambda frame: frame[frame.index.isin(adata.obs_names)])
    .pipe(lambda frame: frame[frame.notna().any(axis=1)])
    .pipe(lambda frame: frame.loc[:, frame.notna().any(axis=0)])
)
df_temp

In [None]:
# Prefix the unknown re-clustering labels with "unk_" so they remain
# distinguishable from the allcools labels they are merged with below.
# Rows without a leiden label stay missing. Labels that already carry the
# prefix are left untouched, so re-running this cell does not produce
# "unk_unk_..." labels; str() also copes with non-string (e.g. integer) labels.
_unk_labels = df_temp['unk_leiden'].astype(object)
_needs_prefix = _unk_labels.notna() & ~_unk_labels.astype(str).str.startswith("unk_")
df_temp['unk_leiden'] = _unk_labels.where(~_needs_prefix, "unk_" + _unk_labels.astype(str))
df_temp['unk_leiden'].value_counts()

In [None]:
# A cell annotated in more than one file would make the merge ambiguous.
assert df_temp.index.is_unique, "There are duplicated indices!"

In [None]:
# Final label at the TRANSFER_COL level. Use the "unk_" re-clustering label
# where one exists, otherwise the filtered allcools label.
# The column names are derived from TRANSFER_COL so that they match
# `label_col = f"AIT_{TRANSFER_COL}"` used for ResolVI. With
# TRANSFER_COL == "Group" this is exactly 'AIT_Group' / 'allcools_Group_filt'.
_ait_col = f"AIT_{TRANSFER_COL}"
df_temp[_ait_col] = (
    df_temp['unk_leiden'].astype('object')
    .fillna(df_temp[f'allcools_{TRANSFER_COL}_filt'])
    .astype('category')
)

In [None]:
# Copy each combined annotation column onto adata.obs. Values align on cell
# names, and cells absent from df_temp get NaN. The subclass label is taken
# directly from the filtered allcools subclass column.
for _col in df_temp.columns:
    adata.obs[_col] = df_temp[_col]
adata.obs['AIT_Subclass'] = adata.obs['allcools_Subclass_filt'].copy()

In [None]:
# Checkpoint: save the annotated object before colours and models are added.
adata.write_h5ad(filename=out_path)

### Plots! 

In [None]:
# Quick summary: number of distinct values in each sample-metadata column.
_meta_cols = ['experiment', 'brain_region', 'donor', 'replicate']
experiments, brain_regions, donors, replicates = (adata.obs[_c].unique() for _c in _meta_cols)
print(*(len(_v) for _v in (experiments, brain_regions, donors, replicates)))

# Fixed colour per brain region (STH and SUBTH share a colour).
experiment_palette = {
    'CAB': '#F5867F',
    'CAH': '#AB4642',
    'CAT': '#430300',
    'PU': '#F98F34',
    'GP': '#6BBC46',
    'GPe': '#007600',
    'MGM1': '#FF2600',
    'NAC': '#0C4E9B',
    'STH': '#6B98C4',
    'SUBTH': '#6B98C4',
}

# Fixed colour per donor.
donor_palette = dict(
    UWA7648='#D87C79',
    UCI4723='#7A4300',
    UCI2424='#D7A800',
    UCI5224='#AB4CAA',
)

# Fixed colour per replicate.
replicate_palette = dict(ucsd='#039BE5', salk='#FFD54F')

In [None]:
# For each metadata column, store the full palette as `<col>_palette`. Also
# store `<col>_colors`: a colour list in the order of that column's categories.
# A category missing from its palette is printed and drawn grey ('#808080'),
# the same fallback add_colors uses, instead of stopping the notebook with a
# KeyError. This replaces three copy-pasted loops.
_metadata_palettes = {
    "brain_region": experiment_palette,
    "donor": donor_palette,
    "replicate": replicate_palette,
}
for _col, _palette in _metadata_palettes.items():
    adata.uns[f"{_col}_palette"] = _palette
    _colors = []
    for _cat in adata.obs[_col].cat.categories:
        if _cat not in _palette:
            print(_cat)
        _colors.append(_palette.get(_cat, '#808080'))
    adata.uns[f"{_col}_colors"] = _colors

In [None]:
## Adding color schemes for the annotations:
# Both palettes are sheets of one workbook. Read them in a single call, which
# returns a dict keyed by sheet name. Each sheet is indexed by label; its 'Hex'
# column is what add_colors reads.
BG_PALETTE_PATH = '/anvil/projects/x-mcb130189/Wubin/BG/metadata/BG_color_palette.xlsx'
_bg_palettes = pd.read_excel(BG_PALETTE_PATH, sheet_name=['Subclass', 'Group'], index_col=0)
bg_color_palette_subclass = _bg_palettes['Subclass']
bg_color_palette_group = _bg_palettes['Group']
# Only a cell's last expression is rendered (the first `.head()` used to be
# discarded), so show both previews together.
pd.concat({'Subclass': bg_color_palette_subclass.head(), 'Group': bg_color_palette_group.head()})

In [None]:
def add_colors(adata, cat_col, palette, default='#808080'):
    """Store per-category colours for ``adata.obs[cat_col]`` in ``adata.uns``.

    The colours are written to ``adata.uns[f'{cat_col}_colors']`` as a list
    aligned with ``adata.obs[cat_col].cat.categories``.

    Parameters
    ----------
    adata : AnnData-like
        Object with ``.obs`` (DataFrame) and ``.uns`` (dict-like).
        ``obs[cat_col]`` must be categorical.
    cat_col : str
        Categorical column of ``adata.obs`` to colour.
    palette : pandas.DataFrame
        Indexed by category label, with a ``'Hex'`` column of colours.
    default : str, optional
        Colour used for categories that are absent from ``palette`` or whose
        ``'Hex'`` entry is missing. Each such category is printed.
    """
    hex_by_label = palette['Hex']
    # A duplicated label would make `.loc` return a Series instead of a single
    # colour, so keep only the first occurrence of each label.
    hex_by_label = hex_by_label[~hex_by_label.index.duplicated(keep='first')]

    colors = []
    for _cat in adata.obs[cat_col].cat.categories:
        color = hex_by_label.loc[_cat] if _cat in hex_by_label.index else None
        if color is None or pd.isna(color):
            print(_cat)
            color = default
        colors.append(color)

    adata.uns[f'{cat_col}_colors'] = colors

In [None]:
# Attach the BG annotation palettes to the two final label columns.
for _label_col, _label_palette in (
    ("AIT_Subclass", bg_color_palette_subclass),
    ("AIT_Group", bg_color_palette_group),
):
    add_colors(adata, _label_col, _label_palette)

In [None]:
# Joint UMAP coloured by each final label and by the donor / replicate metadata.
for _color_col in ("AIT_Subclass", "AIT_Group", "donor", "replicate"):
    plot_categorical(
        adata,
        cluster_col=_color_col,
        coord_base="joint_umap",
        show=True,
        coding=True,
        text_anno=False,
    )

In [None]:
### Write out the AnnData object (now including the colour entries in .uns):
adata.write_h5ad(filename=out_path)

## ResolVI

In [None]:
# adata.obs[label_col]

In [None]:
# Label column used as the semi-supervised target for ResolVI.
# (Alternative: label_col = f"AIT_Subclass")
label_col = "AIT_" + TRANSFER_COL
_target_labels = adata.obs[label_col]
# Drop categories with no cells so the model does not see empty labels.
adata.obs[label_col] = _target_labels.cat.remove_unused_categories()

In [None]:
import scvi

In [None]:
# Copy the spatial coordinates under the 'X_spatial' key, then register the
# counts layer, target labels and replicate batch with ResolVI.
_spatial_coords = adata.obsm['spatial']
adata.obsm['X_spatial'] = _spatial_coords.copy()
scvi.external.RESOLVI.setup_anndata(
    adata,
    labels_key=label_col,
    layer="counts",
    batch_key="replicate",
)

In [None]:
# Fix scvi-tools' global seed before building the model, so that weight
# initialisation and the subsequent training run are reproducible under
# Restart & Run All.
RESOLVI_SEED = 0
scvi.settings.seed = RESOLVI_SEED
supervised_resolvi = scvi.external.RESOLVI(adata, semisupervised=True)

In [None]:
# Train ResolVI for a fixed number of epochs.
RESOLVI_MAX_EPOCHS = 100
supervised_resolvi.train(max_epochs=RESOLVI_MAX_EPOCHS)

In [None]:
# Soft predictions (one column per label), the arg-max label per cell, and the
# ResolVI latent embedding.
resolvi_probs = supervised_resolvi.predict(adata, num_samples=10, soft=True)
adata.obsm['resolvi_celltypes'] = resolvi_probs
adata.obs['resolvi_predicted'] = adata.obsm['resolvi_celltypes'].idxmax(axis=1)
adata.obsm['X_resolVI'] = supervised_resolvi.get_latent_representation(adata)

In [None]:
# Predicted labels drawn on the ResolVI latent coordinates, with text labels.
plot_categorical(
    adata,
    cluster_col="resolvi_predicted",
    coord_base="X_resolVI",
    coding=True,
    text_anno=True,
    show=True,
)

In [None]:
# Posterior median of px_rate under the corrected model, as a frame indexed by
# summary name with one column per site.
_corr_kwargs = dict(
    return_sites=["px_rate"],
    summary_fun={"post_sample_q50": np.median},
    num_samples=50,
    summary_frequency=30,
)
_corr_summary = supervised_resolvi.sample_posterior(
    model=supervised_resolvi.module.model_corrected, **_corr_kwargs
)
samples_corr = pd.DataFrame(_corr_summary).T

In [None]:
# Posterior mean of the mixture proportions under the residuals model, as a
# frame indexed by summary name with one column per site.
_mix_kwargs = dict(
    return_sites=["mixture_proportions"],
    summary_fun={"post_sample_means": np.mean},
    num_samples=50,
    summary_frequency=100,
)
_mix_summary = supervised_resolvi.sample_posterior(
    model=supervised_resolvi.module.model_residuals, **_mix_kwargs
)
samples = pd.DataFrame(_mix_summary).T

In [None]:
# Split the posterior-mean mixture proportions into three obs columns.
_mixture_means = samples.loc["post_sample_means", "mixture_proportions"]
_mixture_cols = ["true_proportion", "diffusion_proportion", "background_proportion"]
adata.obs[_mixture_cols] = _mixture_means

In [None]:
# One panel per mixture component, drawn on the ResolVI latent coordinates.
_proportion_panels = [
    ("true_proportion", "True Proportion"),
    ("diffusion_proportion", "Diffusion Proportion"),
    ("background_proportion", "Background Proportion"),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (_prop_col, _panel_title) in zip(axes, _proportion_panels):
    plot_continuous(adata, color_by=_prop_col, coord_base="X_resolVI", show=False, ax=ax, cmap='viridis_r')
    ax.set_title(_panel_title)

plt.show()

In [None]:
# Keep the posterior-median px_rate as its own layer.
_px_rate_median = samples_corr.loc["post_sample_q50", "px_rate"]
adata.layers["generated_expression"] = _px_rate_median

In [None]:
# Side by side on the joint UMAP: ResolVI labels and OPALIN from the counts layer.
fig, ax = plt.subplots(1, 2, figsize=(10, 5), dpi=200)
_left_ax, _right_ax = ax
plot_categorical(adata, cluster_col="resolvi_predicted", coord_base="joint_umap", show=False, coding=True, text_anno=False, ax=_left_ax)
plot_continuous(adata, color_by="OPALIN", coord_base="joint_umap", layer="counts", show=False, hue_portion=0.98, ax=_right_ax)
plt.show()

In [None]:
# Same comparison in physical coordinates for the salk replicate, this time
# colouring OPALIN from the generated_expression layer.
_is_salk = adata.obs['replicate'] == 'salk'
adata_salk = adata[_is_salk].copy()

fig, ax = plt.subplots(1, 2, figsize=(10, 5), dpi=200)
_left_ax, _right_ax = ax
plot_categorical(adata_salk, cluster_col="resolvi_predicted", coord_base="spatial", show=False, coding=True, text_anno=False, ax=_left_ax)
plot_continuous(adata_salk, color_by="OPALIN", coord_base="spatial", layer="generated_expression", show=False, hue_portion=0.98, ax=_right_ax)
plt.show()

In [None]:
# Final save, now including the ResolVI outputs added to obs / obsm / layers.
adata.write_h5ad(filename=out_path)