In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mudata
import os
import quiche as qu
from sklearn.metrics import roc_curve, auc

In [None]:
# Every panel of Supplementary Figure 19 is written into this folder.
figure_subdirs = ('figures', 'supplementary_figures', 'supplementary_figure19')
save_directory = os.path.join(*figure_subdirs)
qu.pp.make_directory(save_directory)

In [None]:
## Load pre-computed QUICHE results (avoids re-running the pipeline).
# The directory defaults to the original author's location; set the
# QUICHE_SPAIN_MDATA_DIR environment variable to point at a local copy
# instead of editing this path.
mdata_dir = os.environ.get(
    'QUICHE_SPAIN_MDATA_DIR',
    '/Users/jolene/Documents/Angelo_lab/quiche_outdated/quiche/data/tnbc_spain/mdata',
)
mdata = mudata.read_h5mu(os.path.join(mdata_dir, 'mdata_spain_study_corrected.h5ad'))
mdata['quiche'].var[['logFC', 'SpatialFDR']] = mdata['quiche'].var[['logFC', 'SpatialFDR']].astype('float')

# Per-niche summary over the rows of mdata['quiche'].var:
# median SpatialFDR (stored as 'pval') and mean logFC.
quiche_var = mdata['quiche'].var
scores_df_spain = quiche_var.groupby('quiche_niche').agg(
    pval=('SpatialFDR', 'median'),
    logFC=('logFC', 'mean'),
)

# Keep niches with median FDR < 0.05, at least 5 rows in mdata['quiche'].var,
# and |mean logFC| > 0.5. A boolean mask keeps the groupby order; a set
# intersection would give an order that depends on per-process string hashing
# and therefore changes between kernel restarts.
niche_counts = quiche_var['quiche_niche'].value_counts()
frequent_niches = niche_counts[niche_counts >= 5].index
scores_df_spain = scores_df_spain[
    (scores_df_spain['pval'] < 0.05)
    & scores_df_spain.index.isin(frequent_niches)
    & (scores_df_spain['logFC'].abs() > 0.5)
]
niches_spain = list(scores_df_spain.index)

cov_count_df = qu.tl.compute_patient_proportion(mdata,
                                niches = niches_spain,
                                feature_key = 'quiche',
                                annot_key = 'quiche_niche',
                                patient_key = 'Patient_ID',
                                design_key = 'Relapse',
                                patient_niche_threshold = 5)

# Negative mean logFC, patient_count >= 1, rows for Relapse == '0'
# (Relapse is compared as a string in cov_count_df).
cov_count_df_neg = cov_count_df[cov_count_df['mean_logFC'] < 0]
cov_count_df_neg = cov_count_df_neg[cov_count_df_neg['patient_count'] >= 1]
cov_count_df_neg = cov_count_df_neg[cov_count_df_neg['Relapse'] == '0']

# Positive mean logFC, patient_count >= 1, rows for Relapse == '1'.
cov_count_df_pos = cov_count_df[cov_count_df['mean_logFC'] > 0]
cov_count_df_pos = cov_count_df_pos[cov_count_df_pos['patient_count'] >= 1]
cov_count_df_pos = cov_count_df_pos[cov_count_df_pos['Relapse'] == '1']

pos_niches = list(cov_count_df_pos.quiche_niche.values)
neg_niches = list(cov_count_df_neg.quiche_niche.values)

In [None]:
# Patients x niches count table, restricted to the selected niches (pos first, then neg).
abundance_table = (
    mdata['quiche'].var
    .groupby(['Patient_ID', 'quiche_niche'])
    .size()
    .unstack(fill_value=0)
    .loc[:, pos_niches + neg_niches]
)
# Mean count per patient, and the fraction of patients in which each niche occurs at all.
niche_abundance = abundance_table.mean(axis=0)
niche_prevalence = abundance_table.gt(0).mean(axis=0)

niche_stats = pd.DataFrame({'Abundance': niche_abundance, 'Prevalence': niche_prevalence})

# Median split on each axis, giving four abundance/prevalence groups.
abund_thresh, prev_thresh = niche_stats['Abundance'].median(), niche_stats['Prevalence'].median()
niche_stats['Abundance_Group'] = np.where(niche_stats['Abundance'].ge(abund_thresh), 'HighAbund', 'LowAbund')
niche_stats['Prevalence_Group'] = np.where(niche_stats['Prevalence'].ge(prev_thresh), 'HighPrev', 'LowPrev')
group_counts = (
    niche_stats
    .groupby(['Abundance_Group', 'Prevalence_Group'])
    .size()
    .reset_index(name='Count')
)

# Context is set before the axes exist so tick/label sizes pick it up.
sns.set_context("paper", font_scale=1.5)
fig, ax = plt.subplots(figsize=(5, 5), dpi=600)
sns.scatterplot(data=niche_stats, x='Prevalence', y='Abundance', color='cornflowerblue', ax=ax)
ax.set_xlabel('Patient prevalence', fontsize=14)
ax.set_ylabel('Mean Niche Abundance', fontsize=14)
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
fig.savefig(os.path.join(save_directory, 'supplementary_figure19a.pdf'), bbox_inches='tight')


In [None]:
# Supplementary Figure 19b: representative AUC as the most prevalent niches are
# removed one at a time. Each patient is scored by comparing the fraction of
# "pos" niches (from cov_count_df_pos) vs. "neg" niches (from cov_count_df_neg)
# whose abundance is high for that patient; the ROC AUC of that score against
# Relapse is recorded after every removal step.
std_multiplier = 1.0  # "high" means abundance >= niche mean + std_multiplier * niche SD
min_niches = 5  # NOTE(review): defined but not used in this cell
pos_niches = list(cov_count_df_pos.quiche_niche.values)
neg_niches = list(cov_count_df_neg.quiche_niche.values)

abundance_table = mdata['quiche'].var.groupby(['Patient_ID', 'quiche_niche']).size().unstack(fill_value=0)
abundance_table = abundance_table.loc[:, pos_niches + neg_niches]
relapse_df = mdata['expression'].obs[['Patient_ID', 'Relapse']].drop_duplicates()
relapse_df['Patient_ID'] = relapse_df['Patient_ID'].astype(str)
relapse_df['Relapse'] = relapse_df['Relapse'].astype(int)

# Only niches present in at least 10% of patients take part in this analysis.
niche_prevalence = (abundance_table > 0).mean(axis=0)
niche_prevalence = niche_prevalence[niche_prevalence >= 0.1]
prevalent_niches = set(niche_prevalence.index)

# Ordered, de-duplicated filtering (a set intersection gives a run-dependent order).
pos_niches = list(dict.fromkeys(n for n in pos_niches if n in prevalent_niches))
neg_niches = list(dict.fromkeys(n for n in neg_niches if n in prevalent_niches))

niche_mean = abundance_table.mean(axis=0)
niche_std = abundance_table.std(axis=0)


def _representative_auc(above_thresh, pos_cols, neg_cols, labels_df):
    """ROC AUC of the per-patient pos-vs-neg "high niche" score against Relapse.

    Parameters
    ----------
    above_thresh : DataFrame
        0/1 indicator table (patients x niches), index named 'Patient_ID'.
    pos_cols, neg_cols : list
        Niche columns counted as positive / negative.
    labels_df : DataFrame
        Columns 'Patient_ID' and 'Relapse' (assumed 0/1 integers).

    Returns
    -------
    float or None
        The AUC, or None when fewer than two distinct scores remain
        (an ROC curve cannot be formed).
    """
    pos_counts = above_thresh[pos_cols].mean(axis=1)
    neg_counts = above_thresh[neg_cols].mean(axis=1)
    log_ratio = np.log1p(neg_counts + 1e-5) - np.log1p(pos_counts + 1e-5)
    ratio_df = pd.DataFrame(log_ratio).merge(labels_df, on='Patient_ID')
    ratio_df.columns = ['Patient_ID', 'log_ratio', 'Relapse']
    # Patients with equal pos/neg fractions carry no signal and are excluded.
    ratio_df = ratio_df[ratio_df['log_ratio'] != 0]
    if ratio_df['log_ratio'].nunique() <= 1:
        return None
    # Negated so a larger score (relatively more "pos" niches) ranks toward Relapse == 1.
    fpr, tpr, _ = roc_curve(ratio_df['Relapse'], -ratio_df['log_ratio'])
    return auc(fpr, tpr)


# Most -> least prevalent. Sorting by index first and then stably by prevalence
# breaks ties by niche label, so the removal order is reproducible across runs.
ordered_niches = list(
    niche_prevalence.sort_index().sort_values(ascending=False, kind='mergesort').index
)

# The threshold test is column-wise, so compute it once and subset at each step.
above_thresh = (
    abundance_table[ordered_niches]
    >= niche_mean[ordered_niches] + std_multiplier * niche_std[ordered_niches]
).astype(int)

auc_results = []
for n_removed in range(len(ordered_niches)):
    # Step n drops the n most prevalent niches.
    remaining_niches = ordered_niches[n_removed:]
    remaining_set = set(remaining_niches)
    remaining_pos_niches = [n for n in pos_niches if n in remaining_set]
    remaining_neg_niches = [n for n in neg_niches if n in remaining_set]
    auc_val = _representative_auc(above_thresh, remaining_pos_niches, remaining_neg_niches, relapse_df)
    if auc_val is None:
        continue
    # NOTE(review): step 0 is placed at the lowest retained prevalence; later steps
    # at the prevalence of the niche just removed (kept from the original analysis).
    plotted_niche = ordered_niches[-1] if n_removed == 0 else ordered_niches[n_removed - 1]
    auc_results.append({
        'Niches_Removed': n_removed,
        'Remaining_Niches': len(remaining_niches),
        'Prevalence': niche_prevalence[plotted_niche],
        'AUC': auc_val,
    })

auc_df = pd.DataFrame(auc_results)
if auc_df.empty:
    raise ValueError('No removal step produced a ROC AUC; nothing to plot for panel 19b.')

sns.set_context("paper", font_scale=1.5)
fig, ax = plt.subplots(figsize=(5, 5), dpi=600)
ax.scatter(auc_df['Prevalence'], auc_df['AUC'], marker='o', color='cornflowerblue')
ax.set_xlabel("Patient prevalence", fontsize=14)
ax.set_ylabel("Representative AUC", fontsize=14)
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
fig.savefig(os.path.join(save_directory, 'supplementary_figure19b.pdf'), bbox_inches='tight')


In [None]:
# Supplementary Figure 19c: representative AUC when the per-patient score is built
# only from niches in one prevalence bin (fraction of patients with the niche).
std_multiplier = 1.0  # "high" means abundance >= niche mean + std_multiplier * niche SD
min_niches = 5  # NOTE(review): defined but not used in this cell
pos_niches = list(cov_count_df_pos.quiche_niche.values)
neg_niches = list(cov_count_df_neg.quiche_niche.values)

abundance_table = mdata['quiche'].var.groupby(['Patient_ID', 'quiche_niche']).size().unstack(fill_value=0)
abundance_table = abundance_table.loc[:, pos_niches + neg_niches]
relapse_df = mdata['expression'].obs[['Patient_ID', 'Relapse']].drop_duplicates()
relapse_df['Patient_ID'] = relapse_df['Patient_ID'].astype(str)
relapse_df['Relapse'] = relapse_df['Relapse'].astype(int)

niche_prevalence = (abundance_table > 0).mean(axis=0)  # Prevalence: fraction of patients with each niche
niche_mean = abundance_table.mean(axis=0)
niche_std = abundance_table.std(axis=0)
# Column-wise threshold test; computed once for all niches and subset per bin.
above_thresh = (abundance_table >= niche_mean + std_multiplier * niche_std).astype(int)

# (label, membership mask) per bin; labels double as the plot categories, in order.
prevalence_bins = [
    ('Low (< 10%)', niche_prevalence < 0.1),
    ('Medium (10% - 20%)', (niche_prevalence >= 0.1) & (niche_prevalence < 0.2)),
    ('High (>= 20%)', niche_prevalence >= 0.2),
]

auc_results = []
for category, in_bin in prevalence_bins:
    bin_prevalence = niche_prevalence[in_bin]
    bin_niches = set(bin_prevalence.index)
    bin_pos = list(dict.fromkeys(n for n in pos_niches if n in bin_niches))
    bin_neg = list(dict.fromkeys(n for n in neg_niches if n in bin_niches))

    pos_counts = above_thresh[bin_pos].mean(axis=1)
    neg_counts = above_thresh[bin_neg].mean(axis=1)
    log_ratio = np.log1p(neg_counts + 1e-5) - np.log1p(pos_counts + 1e-5)
    ratio_df = pd.DataFrame(log_ratio).merge(relapse_df, on='Patient_ID')
    ratio_df.columns = ['Patient_ID', 'log_ratio', 'Relapse']
    # Patients with equal pos/neg fractions carry no signal and are excluded.
    ratio_df = ratio_df[ratio_df['log_ratio'] != 0]
    # An ROC curve needs at least two distinct scores; skip bins that cannot provide them.
    if ratio_df['log_ratio'].nunique() <= 1:
        continue
    # Negated so a larger score (relatively more "pos" niches) ranks toward Relapse == 1.
    fpr, tpr, _ = roc_curve(ratio_df['Relapse'], -ratio_df['log_ratio'])
    auc_results.append({
        'Niches_Removed': 0,
        'Remaining_Niches': len(bin_prevalence),
        'Prevalence': bin_prevalence.iloc[-1],
        'AUC': auc(fpr, tpr),
        'Prevalence_Category': category,
    })

auc_df = pd.DataFrame(auc_results)
if auc_df.empty:
    raise ValueError('No prevalence bin produced a ROC AUC; nothing to plot for panel 19c.')
auc_df['Prevalence_Category'] = pd.Categorical(auc_df['Prevalence_Category'],
                                               categories=[label for label, _ in prevalence_bins],
                                               ordered=True)

sns.set_context("paper", font_scale=1.5)
fig, ax = plt.subplots(figsize=(5, 5), dpi=600)
# observed=True: only bins that produced an AUC get a point and a legend entry.
for label, group in auc_df.groupby('Prevalence_Category', observed=True):
    ax.scatter([label] * len(group), group['AUC'], label=label, marker='x', s=30)

ax.set_xlabel("Prevalence", fontsize=14)
ax.set_ylabel("Representative AUC", fontsize=14)
ax.set_ylim(0.7, 0.85)  # NOTE(review): fixed window; AUCs outside [0.7, 0.85] are not visible
ax.legend(title="Prevalence Category")
fig.tight_layout()
fig.savefig(os.path.join(save_directory, 'supplementary_figure19c.pdf'), bbox_inches='tight')
