In [None]:
import os
import pandas as pd
import numpy as np
import quiche as qu
import matplotlib.pyplot as plt
import seaborn as sns
from supplementary_plot_helpers import *
import mudata
from scipy.stats import pointbiserialr
from statsmodels.stats.multitest import multipletests
from ark.utils.plot_utils import cohort_cluster_plot
import anndata


sns.set_style('ticks')

%reload_ext autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Output folder shared by all Supplementary Figure 20 panels saved below.
figure20_parts = ('publications', 'supplementary_figures', 'supplementary_figure20')
save_directory = os.path.join(*figure20_parts)
qu.pp.make_directory(save_directory)

In [None]:
# Image and segmentation inputs live on a mounted shared volume.
data_dir = r'/Volumes/Shared/Noah Greenwald/TNBC_Cohorts/SPAIN/image_data/samples'
seg_dir = r'/Volumes/Shared/Noah Greenwald/TNBC_Cohorts/SPAIN/segmentation/samples/deepcell_output'

adata_spain = anndata.read_h5ad(os.path.join('data', 'Zenodo', 'spain_preprocessed.h5ad'))

channel_to_rgb = np.array([
    [0.0, 1.0, 1.0],  # Cyan
    [1.0, 0.0, 1.0],  # Magenta
    [1.0, 1.0, 0.0],  # Yellow
    [1.0, 0.0, 0.0],  # Red
    [0.0, 0.0, 1.0],  # Blue
    [0.0, 1.0, 0.0]   # Green
])

# FOVs rendered below, both as cell-phenotype overlays and as compartment overlays.
fov_list = ['TMA32_R10C4', 'TMA32_R2C6', 'TMA32_R8C8', 'TMA32_R5C7']

cell_ordering = ['Cancer_1', 'Cancer_2', 'Cancer_3', 'CD4T', 'CD8T', 'Treg', 'T_Other', 'B', 
                 'NK', 'CD68_Mac', 'CD163_Mac', 'Mac_Other', 'Monocyte', 'APC','Mast', 'Neutrophil',
                 'CAF', 'Fibroblast', 'Smooth_Muscle', 'Endothelium']

# NOTE(review): `sc` is not imported explicitly in the import cell; presumably it is scanpy,
# re-exported by `from supplementary_plot_helpers import *` — confirm, or import it explicitly.
sc.set_figure_params(dpi = 400, dpi_save = 400, fontsize = 14)

colors_dict_cells = {'APC': '#700548',
 'B': '#005377',
 'CAF': '#f2cc8f',
 'CD4T': '#ebb3a9',
 'CD8T': '#ff5666',
 'CD68_Mac': '#ffa52f',
 'CD163_Mac': '#788AA3',
 'Cancer_1': '#66cdaa',
 'Cancer_2': '#3d405b',
 'Cancer_3': '#b49ab8',
 'Endothelium': '#f78e69',
 'Fibroblast': '#2d9bd5',
 'Immune_Other': '#366962',
 'Mac_Other': '#c7d66d',
 'Mast': '#E36414',
 'Monocyte': '#CC6690',
 'NK': '#9ee2ff',
 'Neutrophil': '#4a7c59',
 'Other': '#FFBF69',
 'Smooth_Muscle': '#f5ebe0',
 'T_Other': '#901C14',
 'Treg': '#9e8576'} 

cell_list = cell_ordering

# Subset to the overlay FOVs once, then attach the per-cell annotations the plot needs
# (expression matrix + cluster, segmentation label and FOV), keeping only listed cell types.
in_overlay_fovs = np.isin(adata_spain.obs['fov'], fov_list)
adata_overlay = adata_spain[in_overlay_fovs]
df_cells = adata_overlay.to_df()
for obs_col in ['cell_cluster', 'label', 'fov']:
    df_cells[obs_col] = adata_overlay.obs[obs_col]
df_cells = df_cells[np.isin(df_cells.cell_cluster, cell_list)]

colormap = pd.DataFrame({'cell_cluster': list(colors_dict_cells.keys()),
                         'color': list(colors_dict_cells.values())})

save_directory_ = os.path.join('publications', 'figures', 'TLS', 'overlay', 'joint')
qu.pp.make_directory(save_directory_)

cohort_cluster_plot(
    fovs=fov_list,
    seg_dir=seg_dir,
    save_dir=save_directory_,
    cell_data=df_cells,
    erode=True,
    fov_col='fov',
    label_col='label',
    cluster_col='cell_cluster',
    seg_suffix="_whole_cell.tiff",
    cmap=colormap,
    fig_file_type = 'pdf',
    display_fig=False,
)

compartment_colormap = pd.DataFrame({'compartment_tls_tagg': ['cancer_core', 'cancer_border', 'stroma_core', 'stroma_border', 'tls', 'tagg'], 'color': ['blue', 'deepskyblue','#8E6E96', '#8E6E96', 'orange', 'crimson']})            

# NOTE(review): 'compartment_tls_tagg' is (re)assigned from the mask CSVs in the following cell.
# On a fresh kernel this plot only works if spain_preprocessed.h5ad already carries that column —
# confirm, or move this plot after the mask-annotation cell.
save_directory_ = os.path.join('publications', 'figures', 'TLS', 'overlay', 'compartment')
qu.pp.make_directory(save_directory_)
qu.pl.cohort_cluster_plot(fovs=fov_list,
                        save_dir=save_directory_,
                        cell_data=adata_spain.obs.loc[:, ['fov', 'compartment_tls_tagg', 'label']],
                        erode=True,
                        seg_dir = seg_dir,
                        fov_col= 'fov',
                        label_col='label',
                        cluster_col='compartment_tls_tagg',
                        seg_suffix="_whole_cell.tiff",
                        unassigned_color=np.array([0, 0, 0, 1]),
                        cmap=compartment_colormap,
                        display_fig=False)

In [None]:
# Per-cell compartment annotations produced from region masks, matched to cells on (fov, label).
directory = '/Volumes/Shared/Noah Greenwald/TNBC_Cohorts/SPAIN/intermediate_files/mask_dir'

cell_table_clusters = pd.read_csv(os.path.join(directory, 'individual_masks-no_tagg_tls', 'cell_annotation_mask.csv'))
cell_table_clusters_tls = pd.read_csv(os.path.join(directory, 'individual_masks', 'cell_annotation_mask.csv'))


def _mask_names_for_cells(obs, mask_table):
    """Return mask_table['mask_name'] for every row of `obs`, matched on (fov, label), in obs order.

    A left merge yields exactly one output row per cell in the left frame's order, so the
    returned array can be assigned positionally back onto `obs`.

    Raises
    ------
    pandas.errors.MergeError (a ValueError)
        If mask_table contains duplicate (fov, label) keys.
    ValueError
        If any cell in `obs` has no row in mask_table.
    """
    merged = pd.merge(obs[['fov', 'label']], mask_table, on=['fov', 'label'],
                      how='left', validate='many_to_one', indicator=True)
    n_missing = int((merged['_merge'] == 'left_only').sum())
    if n_missing:
        raise ValueError(f'{n_missing} cells have no entry in the mask annotation table')
    return merged['mask_name'].values


adata_spain.obs['compartment_tls_tagg'] = _mask_names_for_cells(adata_spain.obs, cell_table_clusters_tls)
adata_spain.obs['compartment'] = _mask_names_for_cells(adata_spain.obs, cell_table_clusters)
adata_spain.obs['Patientcompartment'] = adata_spain.obs['Patient_ID'].astype('str') + adata_spain.obs['compartment'].astype('str')

# Per-cell '0'/'1' string flags. Assigning whole columns avoids chained assignment
# (obs['TLS'][mask] = ...), which writes to a temporary copy under pandas Copy-on-Write.
adata_spain.obs['TLS'] = np.where(adata_spain.obs['compartment_tls_tagg'] == 'tls', '1', '0')
adata_spain.obs['tagg'] = np.where(adata_spain.obs['compartment_tls_tagg'] == 'tagg', '1', '0')

In [None]:
def summarize_patient_flag(obs, flag_col, patient_key='Patient_ID'):
    """Collapse a per-cell '0'/'1' string flag to one 0/1 value per patient.

    Parameters
    ----------
    obs : pandas.DataFrame
        Per-cell table containing `patient_key` and `flag_col`.
    flag_col : str
        Column of string flags; a patient is positive if any of its cells has flag > 0.
    patient_key : str
        Column identifying the patient.

    Returns
    -------
    per_patient : pandas.DataFrame
        One row per patient with a single int column named `flag_col`.
    n_positive : pandas.Series
        Number of positive patients (indexed by `flag_col`).
    proportion_positive : pandas.Series
        `n_positive` divided by the number of patients.
    """
    per_patient = obs.loc[:, [patient_key, flag_col]].groupby(patient_key)[flag_col].apply(
        lambda flags: (flags.astype('int') > 0).any())
    per_patient = pd.DataFrame(per_patient).astype('int')
    n_positive = (per_patient > 0).sum()
    return per_patient, n_positive, n_positive / len(per_patient)


# Number and proportion of patients with at least one TLS cell.
patient_tls, patients_with_tls, proportion_with_tls = summarize_patient_flag(adata_spain.obs, 'TLS')
patient_tls_pos = patient_tls[patient_tls['TLS'] == 1]
print(patients_with_tls, proportion_with_tls)

# Number and proportion of patients with at least one tagg cell.
patient_tagg, patients_with_tagg, proportion_with_tagg = summarize_patient_flag(adata_spain.obs, 'tagg')
patient_tagg_pos = patient_tagg[patient_tagg['tagg'] == 1]
total_samples = len(patient_tagg)
print(patients_with_tagg, proportion_with_tagg)

In [None]:
# Patient -> 0/1 lookups, then broadcast each flag onto that patient's cells as a string
# so it can be used as a categorical covariate.
patient_tls_dict = {patient: flag for patient, flag in patient_tls['TLS'].items()}
patient_tagg_dict = {patient: flag for patient, flag in patient_tagg['tagg'].items()}
for status_col, status_lookup in [('TLS_status', patient_tls_dict), ('Tagg_status', patient_tagg_dict)]:
    adata_spain.obs[status_col] = adata_spain.obs['Patient_ID'].map(status_lookup).astype('str')

In [None]:
# Relapse as a '0'/'1' string so it enters the design formula as a level (contrast 'Relapse1').
adata_spain.obs['Relapse'] = adata_spain.obs['Relapse'].astype('int').astype('str')

# Replace study names with letters in category order. zip() truncates silently, so any study
# beyond the fifth would become NaN; fail loudly instead.
study_letters = ['A', 'B', 'C', 'D', 'E']
study_categories = adata_spain.obs['Study'].cat.categories
if len(study_categories) > len(study_letters):
    raise ValueError(f'Expected at most {len(study_letters)} studies, found {len(study_categories)}: {list(study_categories)}')
adata_spain.obs['Study'] = adata_spain.obs['Study'].map(dict(zip(study_categories, study_letters)))

# Distribution of cells per patient relative to the sketch size (also passed to run_quiche),
# followed by qu.pp.filter_fovs with the same threshold.
sketch_size = 1000
fig, ax = plt.subplots(figsize=(4, 4))
adata_spain.obs.groupby('Patient_ID').size().hist(bins=50, ax=ax)
ax.axvline(sketch_size, color='k', ls='--', lw=1)
ax.set_xlabel('Cells per patient')
ax.set_ylabel('Number of patients')
adata_spain = qu.pp.filter_fovs(adata_spain, 'Patient_ID', sketch_size)

In [None]:
# QUICHE differential-niche model: Relapse contrast, with Study, TLS_status and Tagg_status
# as additional design terms.
design = '~Study+TLS_status+Tagg_status+Relapse'
model_contrasts = 'Relapse1'

# run_quiche is expensive and the next cell reloads its saved result from disk, so only
# recompute (and write the cache) when no cached file exists. The cache is written with
# write_h5mu despite its .h5ad extension; the name is kept so existing caches are reused.
# NOTE: `sig_niches` is only defined when the model is actually rerun.
mdata_cache_path = os.path.join('data', 'mdata_spain_study_TLS_corrected.h5ad')
if not os.path.exists(mdata_cache_path):
    mdata, sig_niches = qu.tl.run_quiche(adata_spain, radius = 200, labels_key = 'cell_cluster', spatial_key = 'spatial',
                                        fov_key = 'fov', patient_key = 'Patient_ID', n_neighbors = 30, merge = False, test_key='Patient_ID', sketch_key='Patient_ID',
                                        delaunay = False, min_cells = 3, k_sim = 100, design = design, khop = None, label_scheme='normal',
                                        model_contrasts = model_contrasts, sketch_size = sketch_size, nlargest = 3, annotation_key = 'quiche_niche', n_jobs = 8)
    # var is stored as strings; the loading cell casts logFC / SpatialFDR back to float.
    mdata['quiche'].var = mdata['quiche'].var.astype('str')
    mdata.write_h5mu(mdata_cache_path)

In [None]:
# Reload the cached QUICHE result; its var columns were stored as strings, so restore the
# numeric statistics before aggregating.
mdata = mudata.read_h5mu(os.path.join('data','mdata_spain_study_TLS_corrected.h5ad'))
mdata['quiche'].var[['logFC', 'SpatialFDR']] = mdata['quiche'].var[['logFC', 'SpatialFDR']].astype('float')

# Niche selection thresholds.
niche_fdr_threshold = 0.05     # on the median SpatialFDR across a niche's rows
min_cells_per_niche = 5        # minimum number of rows annotated with the niche
min_abs_logfc = 0.5            # on the mean logFC across a niche's rows
min_patients_per_niche = 3     # minimum `patient_count` from compute_patient_proportion

quiche_var = mdata['quiche'].var
niche_groups = quiche_var.groupby('quiche_niche')
scores_df_spain = pd.DataFrame({'pval': niche_groups['SpatialFDR'].median(),
                                'logFC': niche_groups['logFC'].mean()})

# Filter with boolean masks rather than a set intersection: set iteration order depends on
# string hash randomization, which made the niche order differ between kernels.
niche_sizes = quiche_var['quiche_niche'].value_counts()
well_populated_niches = niche_sizes.index[niche_sizes >= min_cells_per_niche]
scores_df_spain = scores_df_spain[(scores_df_spain['pval'] < niche_fdr_threshold)
                                  & scores_df_spain.index.isin(well_populated_niches)
                                  & (scores_df_spain['logFC'].abs() > min_abs_logfc)]
niches_spain = list(scores_df_spain.index)

cov_count_df = qu.tl.compute_patient_proportion(mdata,
                                niches = niches_spain,
                                feature_key = 'quiche',
                                annot_key = 'quiche_niche',
                                patient_key = 'Patient_ID',
                                design_key = 'Relapse',
                                patient_niche_threshold = 5)

cov_count_df_frequent = cov_count_df[cov_count_df['patient_count'] >= min_patients_per_niche]

In [None]:
# Rows per (patient, niche) for the frequent niches, with each patient's TLS status (0/1).
# NOTE(review): whether patients lacking a niche appear with count 0 depends on the column
# dtypes (a categorical groupby with observed=False emits empty combinations, an object-dtype
# groupby does not); this affects the downstream correlation — confirm which applies.
niche_counts = (
    mdata['quiche'].var[['Patient_ID', 'quiche_niche']]
    .groupby(['Patient_ID', 'quiche_niche'])
    .size()
    .reset_index(name='count')
)
niche_counts = niche_counts[np.isin(niche_counts['quiche_niche'], list(cov_count_df_frequent.quiche_niche.values))]
# var was cast to str before caching; astype('float') presumably recovers numeric patient IDs
# matching the keys of patient_tls_dict. .assign builds a new frame, so no
# SettingWithCopyWarning is raised for writing into the filtered slice.
niche_counts = niche_counts.assign(
    TLS_status=niche_counts['Patient_ID'].astype('float').map(patient_tls_dict))
list_niches = list(niche_counts['quiche_niche'].unique())

In [None]:
# Supplementary Figure 20c: QUICHE beeswarm for the frequent niches, split by Relapse
# (design_key), with '0' drawn in blue and '1' in red.
sns.set_style('ticks')
beeswarm_kwargs = {
    'feature_key': 'quiche',
    'annot_key': 'quiche_niche',
    'design_key': 'Relapse',
    'patient_key': 'Patient_ID',
    'niches': cov_count_df_frequent.quiche_niche,
    'alpha': 0.05,
    'xlim': [-3, 3],
    'xlim_prev': [-0.3, 0.3],
    'figsize': (6, 12),
    'fontsize': 10,
    'colors_dict': {'0': '#377eb8', '1': '#e41a1c'},
    'save_directory': save_directory,
    'filename_save': 'supplementary_figure20c.pdf',
}
qu.pl.beeswarm_prev(mdata, **beeswarm_kwargs)

In [None]:
# Supplementary Figure 20b: for each frequent niche, point-biserial correlation between
# patient TLS status (0/1) and the niche's per-patient count, BH-corrected across niches.
correlations = []
pvals = []
for niche in list_niches:
    df_subset = niche_counts[niche_counts['quiche_niche'] == niche]
    if len(df_subset) < 2:
        # Too few patients to correlate: r is undefined (NaN keeps the panel text formattable)
        # and p is set to 1.0 so the niche is never called significant.
        correlations.append(np.nan)
        pvals.append(1.0)
        continue
    corr, pval = pointbiserialr(df_subset['TLS_status'].values, df_subset['count'].values)
    correlations.append(corr)
    pvals.append(pval)

# pointbiserialr returns NaN when either input is constant (e.g. every patient with the niche
# has the same TLS status). multipletests does not skip NaNs, and they can propagate into every
# corrected p-value, so treat them as non-significant (1.0), like the too-few-patients case.
pvals_for_fdr = np.nan_to_num(np.asarray(pvals, dtype=float), nan=1.0)
corrected_pvals = multipletests(pvals_for_fdr, method='fdr_bh')[1] if len(pvals_for_fdr) else np.array([])

niches_to_plot = len(list_niches)
rows = 8
cols = max(1, (niches_to_plot + rows - 1) // rows)
# squeeze=False keeps `axes` 2-D for any grid shape, so flatten() always applies.
fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), sharey=False, squeeze=False)
axes = axes.flatten()
for ax, niche, corr, fdr in zip(axes, list_niches, correlations, corrected_pvals):
    df_subset = niche_counts[niche_counts['quiche_niche'] == niche]

    if df_subset.empty:
        ax.axis('off')
        continue

    sns.stripplot(data=df_subset, x='TLS_status', y='count', ax=ax, jitter=0.2, size=4, palette=['#4c72b0', '#55a868'], dodge=False)

    ax.text(0.05, 0.95,
            f"r = {corr:.2f}\nFDR pval = {fdr:.2e}",
            transform=ax.transAxes,
            verticalalignment='top',
            fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray", alpha=0.8))

    ax.set_title(niche, fontsize=13)
    ax.set_xlabel("TLS Status")
    ax.set_ylabel("Niche count")
    # NOTE(review): `size` sets tick-mark length; if larger tick labels were intended, use labelsize=10.
    ax.tick_params(size=10)

# Hide unused grid cells after the last niche.
for ax in axes[niches_to_plot:]:
    ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(save_directory, 'supplementary_figure20b.pdf'), bbox_inches = 'tight')
