In [None]:
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import anndata
import squidpy as sq
import pandas as pd
from scipy.sparse.csgraph import connected_components
from collections import defaultdict
from math import ceil

In [None]:
# Read the four annotated AnnData files: two from the `mgc/` folder and two
# from the `protein_annot/` folder, in the same order as before.
adata1, adata2, adata3, adata4 = (
    sc.read_h5ad(h5ad_path)
    for h5ad_path in (
        "../nanostring/Prot_sopa/mgc/adata_mgc_14_annot.h5ad",
        "../nanostring/Prot_sopa/protein_annot/14_annot.h5ad",
        "../nanostring/Prot_sopa/mgc/adata_mgc_10_annot.h5ad",
        "../nanostring/Prot_sopa/protein_annot/13_annot.h5ad",
    )
)

In [None]:
# Stack the four objects along the observation axis. `join='outer'` keeps the
# union of variables; the origin of every cell is written to obs['sample'].
adata = anndata.concat(
    [adata1, adata2, adata3, adata4],
    join='outer',
    label='sample',
)

In [None]:
# Inspect which level-1 annotation labels exist in the merged object.
categories = adata.obs['annot_level1'].cat.categories
print(categories)

In [None]:
# Level-1 annotation labels that are dropped before any statistic is computed.
populations_to_remove = ["Artifact", "Autofluorence", "Autofluo/collagen", "RBC"]

# Boolean indexing on an AnnData returns a *view*. Later cells write into this
# object (spatial graph in `.obsp`, new `.obs` columns), so materialise the
# subset with `.copy()` instead of relying on an implicit view-to-copy
# conversion at the first write.
adata = adata[~adata.obs.annot_level1.isin(populations_to_remove)].copy()

In [None]:
# Names of the `adata.obs` columns used throughout the notebook.
NICHES_KEY = "global_niche"
CT0_KEY = "annot_level0"
CT1_KEY = "annot_level1"
ID_KEY = "ID"

# Labels of the two groups being compared.
PCR = "MGC"
NPCR = "non_MGC"

# Name given to the group level that is added to the feature-table index.
GROUP_KEY = "RCB"

# Image identifiers belonging to each group.
GROUPS = {
    PCR: [
        '14H007030716H0216858_up',
        '14H007030716H0216858_down',
        "10H064210813H0208914_up",
        "10H064210813H0208914_down",
    ],
    NPCR: [
        'Up_14H042050615H0377937',
        'Down_14H042050615H0377937',
        "Up_13H079750314H0227010",
        "Down_13H079750314H0227010",
    ],
}

# Reverse lookup: image identifier -> group label.
TO_GROUP = dict(
    (image_id, group_label)
    for group_label, image_ids in GROUPS.items()
    for image_id in image_ids
)

In [None]:
# Build the spatial graph separately for each image (`library_key`), using a
# Delaunay triangulation on generic coordinates with the (0, 70) radius
# interval. Results are stored in `adata.obsp`.
sq.gr.spatial_neighbors(
    adata,
    library_key=ID_KEY,
    coord_type="generic",
    radius=(0, 70),
    delaunay=True,
)

# Cell type proportions

In [None]:
# Fraction of each level-0 population within every image
# (rows: images, columns: populations).
df_proportions_ct0 = (
    adata.obs
    .groupby(ID_KEY)[CT0_KEY]
    .value_counts(normalize=True)
    .unstack()
)
df_proportions_ct0.columns = [
    f"{population} ratio " for population in df_proportions_ct0.columns
]
df_proportions_ct0

In [None]:
# Fraction of each level-1 population within every image
# (rows: images, columns: populations).
df_proportions_ct1 = (
    adata.obs
    .groupby(ID_KEY)[CT1_KEY]
    .value_counts(normalize=True)
    .unstack()
)
df_proportions_ct1.columns = [
    f"{population} ratio " for population in df_proportions_ct1.columns
]
df_proportions_ct1

## Niche enrichment

In [None]:
# Fraction of each level-0 population inside every (image, niche) pair,
# pivoted so that each (niche, population) combination becomes one column.
df_niche_infiltration0 = (
    adata.obs
    .groupby([ID_KEY, NICHES_KEY])[CT0_KEY]
    .value_counts(normalize=True)
    .unstack(level=[1, 2])
)
df_niche_infiltration0.columns = [
    f"{population} ratio in niche {niche_name}"
    for niche_name, population in df_niche_infiltration0.columns
]
df_niche_infiltration0

In [None]:
# Fraction of each level-1 population inside every (image, niche) pair,
# pivoted so that each (niche, population) combination becomes one column.
df_niche_infiltration1 = (
    adata.obs
    .groupby([ID_KEY, NICHES_KEY])[CT1_KEY]
    .value_counts(normalize=True)
    .unstack(level=[1, 2])
)
df_niche_infiltration1.columns = [
    f"{population} ratio in niche {niche_name}"
    for niche_name, population in df_niche_infiltration1.columns
]
df_niche_infiltration1

## Niche proportions

In [None]:
# Fraction of cells assigned to each niche within every image
# (rows: images, columns: niches).
df_niches_proportion = (
    adata.obs
    .groupby(ID_KEY)[NICHES_KEY]
    .value_counts(normalize=True)
    .unstack()
)
df_niches_proportion.columns = [
    f"{niche_name} niche ratio" for niche_name in df_niches_proportion.columns
]
df_niches_proportion

## Squidpy statistics

In [None]:
def get_nhood_enrichment_df(ct_key):
    """Collect per-image neighbourhood-enrichment z-scores for population pairs.

    For every image listed in ``adata.obs[ID_KEY]`` the cells of that image are
    copied out, ``sq.gr.nhood_enrichment`` is run on them, and the z-score of
    each unordered pair of distinct populations is extracted.

    NOTE: a function with the same name is defined again in the next cell;
    when the notebook runs top to bottom that later definition is the one
    actually used.

    Parameters
    ----------
    ct_key : str
        Name of the categorical ``adata.obs`` column holding the populations.

    Returns
    -------
    pandas.DataFrame
        One row per image, one column per population pair, labelled
        ``"Ngh enrighment: {p1} <-> {p2}"``. Pairs involving a population that
        is missing from an image's z-score table are NaN for that image.
    """
    series = []

    pops = adata.obs[ct_key].cat.categories
    pop_pairs = [(pops[i], pops[j]) for i in range(len(pops)) for j in range(i + 1, len(pops))]
    pair_labels = [f"Ngh enrighment: {p1} <-> {p2}" for p1, p2 in pop_pairs]

    for image_id in adata.obs[ID_KEY].cat.categories:
        adata_sub = adata[adata.obs[ID_KEY] == image_id].copy()
        sq.gr.nhood_enrichment(adata_sub, cluster_key=ct_key)

        sub_pops = adata_sub.obs[ct_key].cat.categories
        df_ = pd.DataFrame(
            adata_sub.uns[f'{ct_key}_nhood_enrichment']["zscore"],
            index=sub_pops,
            columns=sub_pops,
        )  # TODO: np.cbrt() ?
        # An image may not contain every population. Align on the global
        # category list so the `.loc` look-ups below give NaN instead of
        # raising KeyError.
        df_ = df_.reindex(index=pops, columns=pops)
        s = pd.Series([df_.loc[p1, p2] for p1, p2 in pop_pairs], index=pair_labels)
        series.append(s)

    return pd.concat(series, axis=1, keys=adata.obs[ID_KEY].cat.categories).T

In [None]:
def get_nhood_enrichment_df(ct_key, seed=0):
    """Collect per-image neighbourhood-enrichment z-scores for population pairs.

    For every image listed in ``adata.obs[ID_KEY]`` the cells of that image are
    copied out, ``sq.gr.nhood_enrichment`` is run on them, and the z-score of
    each unordered pair of distinct populations is extracted.

    Parameters
    ----------
    ct_key : str
        Name of the categorical ``adata.obs`` column holding the populations.
    seed : int or None, default 0
        Forwarded to ``sq.gr.nhood_enrichment``. The z-scores come from a
        permutation test, so a fixed seed makes the table reproducible across
        kernel restarts. Pass ``None`` to get unseeded behaviour.

    Returns
    -------
    pandas.DataFrame
        One row per image, one column per population pair, labelled
        ``"Ngh enrighment: {p1} <-> {p2}"``. Pairs involving a population that
        is missing from an image's z-score table are NaN for that image.
    """
    image_ids = adata.obs[ID_KEY].cat.categories
    pops = adata.obs[ct_key].cat.categories
    pop_pairs = [(pops[i], pops[j]) for i in range(len(pops)) for j in range(i + 1, len(pops))]
    pair_labels = [f"Ngh enrighment: {p1} <-> {p2}" for p1, p2 in pop_pairs]

    series = []
    for image_id in image_ids:
        adata_sub = adata[adata.obs[ID_KEY] == image_id].copy()
        sq.gr.nhood_enrichment(adata_sub, cluster_key=ct_key, seed=seed)

        sub_pops = adata_sub.obs[ct_key].cat.categories
        df_ = pd.DataFrame(
            adata_sub.uns[f'{ct_key}_nhood_enrichment']["zscore"],
            index=sub_pops,
            columns=sub_pops,
        )  # TODO: np.cbrt() ?
        # Align on the global category list so that populations absent from
        # this image yield NaN instead of a KeyError in the look-ups below.
        df_ = df_.reindex(index=pops, columns=pops)
        series.append(pd.Series([df_.loc[p1, p2] for p1, p2 in pop_pairs], index=pair_labels))

    return pd.concat(series, axis=1, keys=image_ids).T

In [None]:
# Neighbourhood-enrichment tables for the level-0 and level-1 annotations
# (computed in that order).
df_nhood_enrichment0, df_nhood_enrichment1 = (
    get_nhood_enrichment_df(annotation_key)
    for annotation_key in (CT0_KEY, CT1_KEY)
)

In [None]:
# Display the level-1 neighbourhood-enrichment table built in the previous cell.
df_nhood_enrichment1

## Centrality scores

In [None]:
def get_centrality_scores_df(ct_key):
    """Collect per-image squidpy centrality scores for every population.

    For each image in ``adata.obs[ID_KEY]`` the image's cells are copied out,
    ``sq.gr.centrality_scores`` is run with ``ct_key`` as cluster key, and the
    resulting table is flattened to one value per (score, population) pair.

    Parameters
    ----------
    ct_key : str
        Name of the ``adata.obs`` column holding the populations.

    Returns
    -------
    pandas.DataFrame
        One row per image; columns are named ``"{population} {score name}"``
        with underscores in the score name replaced by spaces.
    """
    image_ids = adata.obs[ID_KEY].cat.categories
    per_image_scores = []

    for image_id in image_ids:
        adata_sub = adata[adata.obs[ID_KEY] == image_id].copy()
        sq.gr.centrality_scores(adata_sub, ct_key)

        # Unstacking the score table gives a Series indexed by
        # (score name, population); flatten that into readable labels.
        flat_scores = adata_sub.uns[f"{ct_key}_centrality_scores"].unstack().copy()
        flat_scores.index = [
            f'{population} {score_name.replace("_", " ")}'
            for score_name, population in flat_scores.index
        ]
        per_image_scores.append(flat_scores)

    return pd.concat(per_image_scores, axis=1, keys=image_ids).T

In [None]:
# Centrality-score tables for the level-0 and level-1 annotations
# (computed in that order).
df_centrality_scores0, df_centrality_scores1 = (
    get_centrality_scores_df(annotation_key)
    for annotation_key in (CT0_KEY, CT1_KEY)
)

## Distances to niches

In [None]:
# Histogram of the values stored in the sparse `spatial_distances` matrix.
sns.displot(adata.obsp['spatial_distances'].data)

In [None]:
# For every niche, compute each cell's distance to that niche counted in graph
# hops: a breadth-first expansion over `spatial_connectivities`, starting from
# the cells that belong to the niche (distance 0). Cells never reached by the
# expansion keep NaN. The result is stored in
# `adata.obs["distance_to_niche_<niche>"]`.
connectivities = adata.obsp['spatial_connectivities']

for niche_id in adata.obs[NICHES_KEY].unique():
    frontier = np.where(adata.obs[NICHES_KEY] == niche_id)[0]

    hop_distances = np.full(adata.n_obs, np.nan)
    seen = set(frontier)
    n_hops = 0

    while len(frontier) > 0:
        # Every cell on the current frontier is exactly `n_hops` away.
        hop_distances[frontier] = n_hops

        # Next frontier: neighbours of the current one not reached before.
        reached = set(connectivities[frontier].indices)
        frontier = np.array(list(reached - seen))
        seen |= reached

        n_hops += 1

    adata.obs[f"distance_to_niche_{niche_id}"] = hop_distances

In [None]:
# Distribution of the strictly positive hop distances to the "Tumour" niche.
# `v > 0` drops the zeros and also the NaN entries (NaN > 0 is False).
v = adata.obs['distance_to_niche_Tumour'].values
sns.displot(v[v > 0])

In [None]:
# Mean hop distance to each niche, per image. Only cells with a strictly
# positive distance are averaged, which leaves out distance-0 and NaN cells.
niche_ids = adata.obs[NICHES_KEY].unique()

series = []
for niche_id in niche_ids:
    key = f"distance_to_niche_{niche_id}"
    cells_outside = adata.obs[adata.obs[key] > 0]
    series.append(cells_outside.groupby(ID_KEY)[key].mean())

df_distances_niches = pd.concat(
    series,
    axis=1,
    keys=[f"Mean distance to niche: {niche}" for niche in niche_ids],
)

In [None]:
# Display the per-image mean distance to each niche.
df_distances_niches

In [None]:
def get_distance_niche(ct_key):
    """Mean hop distance to every niche, per image and per population.

    Reads the ``distance_to_niche_<niche>`` columns of ``adata.obs`` and, for
    each niche, averages the strictly positive distances within every
    (image, population) group.

    Parameters
    ----------
    ct_key : str
        Name of the ``adata.obs`` column holding the populations.

    Returns
    -------
    pandas.DataFrame
        One row per image; columns are named
        ``"{population} mean distance to niche: {niche}"``.
    """
    per_niche_tables = []

    for niche_id in adata.obs[NICHES_KEY].unique():
        distance_col = f"distance_to_niche_{niche_id}"

        # Keep only cells with a strictly positive distance; this also
        # discards cells whose distance is NaN.
        cells_outside = adata.obs[adata.obs[distance_col] > 0]

        table = cells_outside.groupby([ID_KEY, ct_key])[distance_col].mean().unstack()
        table.columns = [
            f"{population} mean distance to niche: {niche_id}"
            for population in table.columns
        ]
        per_niche_tables.append(table)

    return pd.concat(per_niche_tables, axis=1)

In [None]:
# Per-population distance tables for the level-0 and level-1 annotations
# (computed in that order).
df_distances_niches0, df_distances_niches1 = (
    get_distance_niche(annotation_key)
    for annotation_key in (CT0_KEY, CT1_KEY)
)

# Global analysis

In [None]:
# Every per-image feature table computed above, keyed by a family name that
# becomes the first level of the combined column index.
dataframes_dict = {
    "Proportions ct0": df_proportions_ct0,
    "Proportions ct1": df_proportions_ct1,
    "Niche Infiltration ct0": df_niche_infiltration0,
    "Niche Infiltration ct1": df_niche_infiltration1,
    "Niches proportion": df_niches_proportion,
    "Nnhood enrichment ct0": df_nhood_enrichment0,
    "Nnhood enrichment ct1": df_nhood_enrichment1,
    "Centrality scores ct0": df_centrality_scores0,
    "Centrality scores ct1": df_centrality_scores1,
    "Distances niches": df_distances_niches,
    "Distances niches ct0": df_distances_niches0,
    "Distances niches ct1": df_distances_niches1,
}

# Side-by-side concatenation: columns become (family name, feature name).
df = pd.concat(
    list(dataframes_dict.values()),
    axis=1,
    keys=list(dataframes_dict.keys()),
)

# Keep only the first column for each feature name (second column level).
df = df.loc[:, ~df.columns.get_level_values(1).duplicated()].copy()

# Per-group slices, taken while the index still holds plain image ids.
dfs = {group_label: df.loc[image_ids] for group_label, image_ids in GROUPS.items()}

# Append the group label as a second index level.
df.index = pd.MultiIndex.from_arrays(
    [df.index, df.index.map(TO_GROUP)],
    names=[df.index.name, GROUP_KEY],
)

In [None]:
# Keep only the rows whose group label is one of the two compared groups.
filtered_res = df.loc[
    df.index.get_level_values(GROUP_KEY).isin([PCR, NPCR])
]
filtered_res


In [None]:
# Box colour for each group label.
palette = {
    NPCR: "#F77189",
    PCR: "#36ADA4",
}


def plot_biomarker(df, label, figsize=(3, 4), level=1, ax=None):
    """Draw a box plot with overlaid points of one feature, split by group.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature table whose index has a ``GROUP_KEY`` level and whose columns
        form a MultiIndex.
    label : str or int
        Feature name looked up in column level ``level``. An integer is
        treated as a column position and replaced by the second-level name of
        that column.
    figsize : tuple, default (3, 4)
        Size of the figure created when ``ax`` is not given.
    level : int, default 1
        Column level in which ``label`` is searched.
    ax : matplotlib Axes, optional
        Axes to draw on; when omitted a new figure is opened.
    """
    if isinstance(label, int):
        label = df.columns[label][1]

    groups = df.index.get_level_values(GROUP_KEY)
    matches_label = df.columns.get_level_values(level) == label
    values = df.loc[:, matches_label].squeeze()

    if ax is None:
        plt.figure(figsize=figsize)

    sns.boxplot(x=groups, y=values, width=.2, palette=palette, ax=ax)

    # Dots use the PCR palette colour only when the label equals NPCR;
    # otherwise they are dark grey.
    dot_color = palette[PCR] if label == NPCR else ".3"
    sns.stripplot(x=groups, y=values, size=8, color=dot_color, linewidth=0, ax=ax)

    sns.despine(offset=10, trim=True, ax=ax)
    if not ax:
        plt.xlabel("")
        plt.ylabel(label)
    else:
        ax.set_xlabel("")
        ax.set_ylabel(label)
    
def plot_many(df, labels, figsize=(10, 10), level=1, ncols=4):
    """Draw one `plot_biomarker` panel per label on a grid of subplots.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature table passed through to ``plot_biomarker``.
    labels : iterable
        Feature names (or integer column positions) to plot, one per panel.
    figsize : tuple, default (10, 10)
        Size of the whole figure.
    level : int, default 1
        Column level passed through to ``plot_biomarker``.
    ncols : int, default 4
        Number of panels per row.
    """
    labels = list(labels)
    if not labels:
        # Nothing to draw; a grid with zero rows cannot be created.
        return

    nrows = ceil(len(labels) / ncols)
    # squeeze=False always returns a 2-D array of Axes, so flattening also
    # works for a 1x1 grid (where plt.subplots would otherwise return a bare
    # Axes object without `.reshape`).
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    flat_axs = axs.ravel()

    for ax, label in zip(flat_axs, labels):
        plot_biomarker(df, label, level=level, ax=ax)

    # Hide the grid cells left over when len(labels) is not a multiple of ncols.
    for unused_ax in flat_axs[len(labels):]:
        unused_ax.set_visible(False)

In [None]:
# An integer label is resolved by `plot_biomarker` to the second-level name of
# the column at that position (here: column 1 of `filtered_res`).
plot_biomarker(filtered_res, 1)

### Stats

In [None]:
from scipy import stats

In [None]:
# The two group labels, in the insertion order of GROUPS.
group1, group2 = list(GROUPS)

In [None]:
# Write the combined feature table to a CSV file in the working directory.
df.to_csv("stats_proportions_global_MGC_overall.csv")

In [None]:
def run_ttest(i):
    """Compare feature column ``i`` between the two groups with a t-test.

    Uses ``stats.ttest_ind`` with ``equal_var=False`` on the ``i``-th column
    of ``dfs[group1]`` and ``dfs[group2]``.

    Parameters
    ----------
    i : int
        Positional index of the feature column.

    Returns
    -------
    pandas.Series
        Two entries, ``"statistic"`` and ``"pvalue"``.
    """
    first_sample = dfs[group1].values[:, i]
    second_sample = dfs[group2].values[:, i]
    outcome = stats.ttest_ind(first_sample, second_sample, equal_var=False)

    return pd.Series([outcome.statistic, outcome.pvalue], index=["statistic", "pvalue"])

In [None]:
# Quick check of the test on the first feature column.
run_ttest(0)

In [None]:
# Run the test on every feature column, label each result with the feature
# name (second column level of `df`), and order features by ascending p-value.
df_stat = (
    pd.concat(
        [run_ttest(column_position) for column_position in range(df.shape[1])],
        axis=1,
        keys=df.columns.get_level_values(1),
    )
    .T
    .sort_values("pvalue")
)

In [None]:
# `best_markers` is not defined in any earlier cell, so running the notebook
# from a fresh kernel stopped here with a NameError. Rebuild it from
# `df_stat`, which is already sorted by ascending p-value, skipping features
# whose p-value is NaN.
# NOTE(review): the original selection rule is not visible in this notebook;
# the top-N cut-off below is an assumption to be confirmed.
N_BEST_MARKERS = 80
best_markers = df_stat.index[df_stat["pvalue"].notna()][:N_BEST_MARKERS].tolist()

# Guard against an empty selection: a subplot grid cannot have zero rows.
if best_markers:
    plot_many(filtered_res, best_markers, figsize=(20, 100))
