In [None]:
import os

import anndata
import matplotlib.pyplot as plt
import mudata
import numpy as np
import pandas as pd
import quiche as qu
import seaborn as sns

# The star import stays ahead of the explicit from-imports below so that the
# explicitly named functions take precedence over anything it re-exports.
from supplementary_plot_helpers import *

from ark.utils.plot_utils import cohort_cluster_plot
from scipy.stats import ranksums, spearmanr
from statsmodels.stats.multitest import multipletests
sns.set_style('ticks')

%reload_ext autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Folder that receives every panel saved by this notebook; it is handed to
# quiche's make_directory helper before any figure is written.
figure_path_parts = ('publications', 'supplementary_figures', 'supplementary_figure21')
save_directory = os.path.join(*figure_path_parts)
qu.pp.make_directory(save_directory)

In [None]:
# Preprocessed SPAIN cohort table; only the 'fov', 'Patient_ID' and 'Relapse'
# columns of its .obs are read below.
adata_spain = anndata.read_h5ad(os.path.join('data', 'Zenodo', 'spain_preprocessed.h5ad'))

# Per-FOV fiber statistics.
# NOTE(review): absolute path to a mounted share -- not portable; consider a configurable data directory.
fiber_stats = pd.read_csv('/Volumes/Shared/Noah Greenwald/TNBC_Cohorts/SPAIN/intermediate_files/fiber_segmentation_processed_data/fiber_stats_table.csv')

# Every column after the first is passed through quiche's standardize helper;
# the result is indexed by the 'fov' column.
fiber_stat_values = fiber_stats.iloc[:, 1:]
fiber_stats_standardized = pd.DataFrame(
    qu.pp.standardize(fiber_stat_values),
    index=fiber_stats['fov'],
    columns=fiber_stat_values.columns,
)
fiber_stats_standardized = fiber_stats_standardized.drop(columns=['avg_euler_number'])

# Lookup tables: patient -> relapse status and FOV -> patient.
patient_relapse = adata_spain.obs[['Patient_ID', 'Relapse']].drop_duplicates()
relapse_dict = dict(zip(patient_relapse['Patient_ID'], patient_relapse['Relapse']))
fov_patient = adata_spain.obs[['fov', 'Patient_ID']].drop_duplicates()
fov_dict = dict(zip(fov_patient['fov'], fov_patient['Patient_ID']))

# Attach the relapse status of each FOV. The box-plot cell further down groups
# on fiber_stats_standardized['Relapse'], which was never created before, so it
# failed with a KeyError on a fresh kernel. FOVs absent from adata_spain get NaN.
fov_to_relapse = {fov: relapse_dict.get(patient, np.nan) for fov, patient in fov_dict.items()}
fiber_stats_standardized['Relapse'] = (
    pd.Series(fiber_stats_standardized.index).map(fov_to_relapse).to_numpy()
)

# NOTE(review): absolute, user-specific path -- not portable; consider a configurable data directory.
mdata_relapse = mudata.read_h5mu('/Users/jolene/Documents/Angelo_lab/quiche_outdated/quiche/data/tnbc_spain/mdata/mdata_spain_study_corrected.h5ad')
mdata_relapse['quiche'].var[['logFC', 'SpatialFDR']] = mdata_relapse['quiche'].var[['logFC', 'SpatialFDR']].astype('float')

# Per-niche summary: median SpatialFDR (stored as 'pval') and mean logFC.
quiche_var = mdata_relapse['quiche'].var
niche_groups = quiche_var.groupby('quiche_niche')
scores_df_spain = pd.DataFrame({
    'pval': niche_groups['SpatialFDR'].median(),
    'logFC': niche_groups['logFC'].mean(),
})
scores_df_spain = scores_df_spain[scores_df_spain['pval'] < 0.05]

# Keep niches that occur at least 5 times. The filter preserves the row order
# of scores_df_spain; the previous set intersection produced an arbitrary order
# that could differ between kernel sessions.
niche_sizes = quiche_var['quiche_niche'].value_counts()
frequent_niches = niche_sizes.index[niche_sizes >= 5]
ids = list(scores_df_spain.index[scores_df_spain.index.isin(frequent_niches)])
scores_df_spain = scores_df_spain.loc[ids]

# Require |mean logFC| > 0.5.
scores_df_spain = scores_df_spain[scores_df_spain['logFC'].abs() > 0.5]
niches_spain = list(scores_df_spain.index)

# Per-niche patient summary from quiche for the niches selected above.
cov_count_df = qu.tl.compute_patient_proportion(
    mdata_relapse,
    niches=niches_spain,
    feature_key='quiche',
    annot_key='quiche_niche',
    patient_key='Patient_ID',
    design_key='Relapse',
    patient_niche_threshold=5,
)

# Keep rows whose 'patient_count' is at least 3.
is_frequent = cov_count_df['patient_count'] >= 3
cov_count_df_frequent = cov_count_df[is_frequent]

# Column order of the abundance table: the retained niches sorted by 'mean_logFC'.
niche_order = list(
    cov_count_df_frequent.sort_values(by='mean_logFC')['quiche_niche'].drop_duplicates()
)

# Count of rows in the quiche .var table per (patient, niche), one row per patient.
patient_niche_counts = mdata_relapse['quiche'].var.groupby(['Patient_ID', 'quiche_niche']).size()
abundance_table = patient_niche_counts.unstack(fill_value=0).loc[:, niche_order]
abundance_table.index = abundance_table.index.astype('float')

# FOV -> patient lookup taken from the cohort table.
fov_patient = adata_spain.obs[['fov', 'Patient_ID']].drop_duplicates()
patient_dict = dict(zip(fov_patient['fov'], fov_patient['Patient_ID']))

# Average the per-FOV fiber statistics within each patient, restricted to the
# patients present in the abundance table. Built without in-place edits on a
# filtered slice, which avoids pandas' SettingWithCopyWarning.
fiber_stats = fiber_stats.assign(Patient_ID=fiber_stats['fov'].map(patient_dict))
in_abundance = np.isin(fiber_stats['Patient_ID'], abundance_table.index.astype('float'))
fiber_stats = fiber_stats.loc[in_abundance].drop(columns=['fov'])
fiber_stats = fiber_stats.groupby('Patient_ID').mean()
fiber_stats.index = fiber_stats.index.astype('float')

fiber_stats = fiber_stats.dropna()
abundance_table = abundance_table.dropna()

# The correlation step pairs the two tables row by row, so restrict both to the
# shared patients and put them in the same order BEFORE the index is discarded.
# Resetting first (as before) paired rows purely by position, which mismatches
# patients whenever the two group-by orders differ or a patient is missing or
# dropped in only one table.
shared_patients = abundance_table.index.intersection(fiber_stats.index)
abundance_table = abundance_table.loc[shared_patients].reset_index(drop=True)
fiber_stats = fiber_stats.loc[shared_patients].reset_index(drop=True)
assert len(abundance_table) == len(fiber_stats), 'abundance and fiber tables must be row-aligned'

# Constant columns carry no rank information; drop them.
abundance_table = abundance_table.loc[:, abundance_table.nunique() > 1]
fiber_stats = fiber_stats.loc[:, fiber_stats.nunique() > 1]

In [None]:
# Drop constant columns (no-op when the previous cell has just run).
abundance_table = abundance_table.loc[:, abundance_table.nunique() > 1]
fiber_stats = fiber_stats.loc[:, fiber_stats.nunique() > 1]

# Spearman correlation of every niche column against every fiber-statistic
# column; rows follow abundance_table's columns, columns follow fiber_stats'.
niche_columns = list(abundance_table.columns)
fiber_columns = list(fiber_stats.columns)
correlation_matrix = np.array(
    [
        [
            spearmanr(abundance_table[niche_col], fiber_stats[fiber_col], nan_policy='omit')[0]
            for fiber_col in fiber_columns
        ]
        for niche_col in niche_columns
    ],
    dtype=float,
).reshape(len(niche_columns), len(fiber_columns))
correlation_df = pd.DataFrame(correlation_matrix, index=abundance_table.columns, columns=fiber_stats.columns)

In [None]:
# Every column except the grouping column is treated as a fiber statistic.
stat_columns = [col for col in fiber_stats_standardized.columns if col != 'Relapse']

# Rank-sum test per statistic between Relapse == 1 and Relapse == 0.
# Missing values are removed first: a NaN in either group yields a NaN p-value,
# which would then corrupt the multiple-testing correction.
relapsed = fiber_stats_standardized['Relapse'] == 1
relapse_free = fiber_stats_standardized['Relapse'] == 0
pvals = []
for stat in stat_columns:
    group1 = fiber_stats_standardized.loc[relapsed, stat].dropna()
    group0 = fiber_stats_standardized.loc[relapse_free, stat].dropna()
    pvals.append(ranksums(group1, group0).pvalue)
_, pvals_corrected, _, _ = multipletests(pvals, method='fdr_bh')

# Plot only the rows that entered the tests.
plot_df = fiber_stats_standardized[relapsed | relapse_free]

# One panel per statistic, 8 per row; the figure height grows with the number
# of rows (a single row keeps the previous 16 x 3 size).
ncols = 8
nrows = (len(stat_columns) + ncols - 1) // ncols
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 3 * nrows), squeeze=False)
axes = axes.flatten()

for ax, stat, fdr in zip(axes, stat_columns, pvals_corrected):
    sns.boxplot(x='Relapse', y=stat, data=plot_df, ax=ax, whis=1.5, fliersize=0, boxprops=dict(alpha=0.6))
    sns.stripplot(x='Relapse', y=stat, data=plot_df, ax=ax, color='black', alpha=0.6, jitter=True)
    ax.tick_params(labelsize = 12)
    ax.set_title(stat, fontsize = 12)
    fdr_label = f'FDR pvalue = {fdr:.3f}'
    ax.text(0.5, 0.95, fdr_label, transform=ax.transAxes,
            ha='center', va='top', fontsize=10, bbox=dict(boxstyle="round", fc="white", ec="gray"))
    ax.set_xlabel('Relapse', fontsize = 12)
    ax.set_ylabel('Standardized Value', fontsize = 12)

# Hide unused subplots (does not depend on a leaked loop variable).
for unused_ax in axes[len(stat_columns):]:
    unused_ax.axis('off')

fig.tight_layout()
fig.savefig(os.path.join(save_directory, 'supplementary_figure21a.pdf'), bbox_inches = 'tight')


In [None]:
# Annotated heatmap of the niche-by-fiber-statistic Spearman correlations,
# with the colour scale clipped to [-0.5, 0.5].
fig, ax = plt.subplots(figsize=(8, 12))
sns.heatmap(
    correlation_df,
    annot=True,
    cmap='RdBu_r',
    fmt=".2f",
    linewidths=0.5,
    annot_kws={'size': 10},
    vmin=-0.5,
    vmax=0.5,
    ax=ax,
)

ax.set_ylabel('QUICHE niche neighborhoods', fontsize=14)
ax.set_xlabel('Collagen Fiber Metrics', fontsize=14)

# Slant the x tick labels so they stay readable; keep the y tick labels horizontal.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=12)
plt.setp(ax.get_yticklabels(), rotation=0, fontsize=12)

# Horizontal rule at y = 38.
# NOTE(review): hard-coded position -- presumably separates two groups of niche
# rows; confirm it still matches the current row order of correlation_df.
ax.axhline(38, color = 'k')

# Tighten the layout, then write the panel to disk.
fig.tight_layout()
fig.savefig(os.path.join(save_directory, 'supplementary_figure21b.pdf'))