In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import anndata 
import os
from scipy.cluster.hierarchy import dendrogram
import matplotlib
import mudata
matplotlib.rcParams['lines.linewidth'] = 0.5
from scipy.cluster.hierarchy import linkage, leaves_list
import quiche as qu
import matplotlib.cm as cm
from supplementary_plot_helpers import *
%reload_ext autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Stanford cohort overview

In [None]:
# Markers used for the phenotypic (lineage-defining) and functional panels of the matrix plot.
phenotypic_markers = ['ECAD', 'CK17', 'CD45', 'CD3', 'CD4', 'CD8', 'FOXP3', 'CD20', 'CD56', 'CD14', 'CD68',
                    'CD163', 'CD11c', 'HLADR', 'ChyTr', 'Calprotectin', 'FAP', 'SMA', 'Vim', 'Fibronectin',
                    'Collagen1', 'CD31']

functional_markers = ['PDL1','Ki67','GLUT1','CD45RO','CD69', 'PD1','CD57','TBET', 'TCF1',
                        'CD45RB', 'TIM3','IDO', 'LAG3', 'CD38']

var_names = phenotypic_markers+functional_markers

# Row order for all per-cell-type plots, grouped by lineage: tumor | lymphoid | myeloid | structural.
cell_ordering = ['Cancer_1', 'Cancer_2', 'Cancer_3', 'CD4T', 'CD8T', 'Treg', 'T_Other', 'B', 
                 'NK', 'CD68_Mac', 'CD163_Mac', 'Mac_Other', 'Monocyte', 'APC','Mast', 'Neutrophil',
                 'CAF', 'Fibroblast', 'Smooth_Muscle', 'Endothelium']

sc.set_figure_params(dpi = 400, dpi_save = 400, fontsize = 14)

# Per-cell-type colours (also contains 'Immune_Other'/'Other', which are not in cell_ordering).
colors_dict_cells = {'APC': '#700548',
 'B': '#005377',
 'CAF': '#f2cc8f',
 'CD4T': '#ebb3a9',
 'CD8T': '#ff5666',
 'CD68_Mac': '#ffa52f',
 'CD163_Mac': '#788AA3',
 'Cancer_1': '#66cdaa',
 'Cancer_2': '#3d405b',
 'Cancer_3': '#b49ab8',
 'Endothelium': '#f78e69',
 'Fibroblast': '#2d9bd5',
 'Immune_Other': '#366962',
 'Mac_Other': '#c7d66d',
 'Mast': '#E36414',
 'Monocyte': '#CC6690',
 'NK': '#9ee2ff',
 'Neutrophil': '#4a7c59',
 'Other': '#FFBF69',
 'Smooth_Muscle': '#f5ebe0',
 'T_Other': '#901C14',
 'Treg': '#9e8576'}

# Per-lineage colours (keys are the values of lineage_dict).
colors_dict = {'myeloid':'#4DCCBD',
               'lymphoid':'#279AF1',
               'tumor':'#FF8484',
               'structural':'#F9DC5C'}

# Cell type -> lineage; used for the matrix-plot separators and the niche-network donuts.
lineage_dict = {'APC':'myeloid',
 'B':'lymphoid',
 'CAF': 'structural',
 'CD4T': 'lymphoid',
 'CD8T': 'lymphoid',
 'CD68_Mac': 'myeloid',
 'CD163_Mac': 'myeloid',
 'Cancer_1': 'tumor',
 'Cancer_2': 'tumor',
 'Cancer_3': 'tumor',
 'Endothelium':'structural',
 'Fibroblast': 'structural',
 'Mac_Other': 'myeloid',
 'Mast':'myeloid',
 'Monocyte':'myeloid',
 'NK':'lymphoid',
 'Neutrophil':'myeloid',
 'Smooth_Muscle':'structural',
 'T_Other':'lymphoid',
 'Treg':'lymphoid'}

# Invariants the plotting cells below rely on; fail fast here instead of with a KeyError mid-figure.
assert set(cell_ordering) <= set(colors_dict_cells), 'every cell type in cell_ordering needs a colour'
assert set(cell_ordering) == set(lineage_dict), 'lineage_dict must cover exactly the cell types in cell_ordering'
assert set(lineage_dict.values()) <= set(colors_dict), 'every lineage needs a colour in colors_dict'
assert len(set(var_names)) == len(var_names), 'phenotypic and functional markers must not overlap'

In [None]:
# Output folder for every panel of supplementary figure 12.
figure_parts = ('publications', 'supplementary_figures', 'supplementary_figure12')
save_directory = os.path.join(*figure_parts)
qu.pp.make_directory(save_directory)

# NOTE(review): the niche-detection section reads 'Stanford_preprocessed.h5ad' (capital S);
# on a case-sensitive file system only one spelling can exist — confirm the Zenodo file name.
stanford_h5ad = os.path.join('data', 'Zenodo', 'stanford_preprocessed.h5ad')
adata = anndata.read_h5ad(stanford_h5ad)

# Overwrites the expression matrix in place; the matrix plots below label it as z-scores.
adata.X = qu.pp.standardize(adata.X)

### Supplementary Figure 12a

In [None]:
# One row per distinct (patient, recurrence label) pair, then the number of patients per label
# (value_counts orders the labels from most to least frequent).
patient_label_cols = ['Patient_ID', 'RECURRENCE_LABEL']
count_df = adata.obs.loc[:, patient_label_cols].drop_duplicates()
label_counts = count_df.loc[:, 'RECURRENCE_LABEL'].value_counts()

# Pie-chart palette: light blue, light red.
colors = list(('#66b3ff', '#ff9999'))


def autopct_format(pct, all_vals):
    """Build a pie-wedge label showing both the absolute count and the percentage.

    pct: wedge percentage (0-100), as passed by matplotlib's ``autopct`` callback.
    all_vals: the raw wedge values; their sum corresponds to 100%.
    Returns a two-line string such as 'N = 12\\n(40.0%)'.
    """
    # Recover the integer count from the percentage matplotlib hands back.
    count = int(round(pct * sum(all_vals) / 100.0))
    return 'N = {}\n({:.1f}%)'.format(count, pct)

# Pie chart of patients per recurrence label.
# Colours are tied to the label rather than to the count rank (value_counts sorts by frequency),
# so NEGATIVE is always blue and POSITIVE always red, matching the rest of the figure.
# Any other label falls back to light grey.
label_palette = dict(zip(['NEGATIVE', 'POSITIVE'], colors))
pie_colors = [label_palette.get(label, 'lightgrey') for label in label_counts.index]

fig, ax = plt.subplots(figsize=(8, 6))
ax.pie(
    label_counts,
    labels=label_counts.index,
    autopct=lambda pct: autopct_format(pct, label_counts),
    startangle=90,
    colors=pie_colors,
    wedgeprops={'edgecolor': 'black', 'linewidth': 1.5},
    textprops={'fontsize': 20}
)
fig.tight_layout()  # Ensure everything fits without clipping
fig.savefig(os.path.join(save_directory, 'supplementary_figure12a.pdf'))

### Supplementary Figure 12b

In [None]:
# Histogram (with KDE) of how many cells each image / FOV contributes.
_, cells_per_fov = np.unique(adata.obs['fov'], return_counts=True)

plt.figure(figsize=(6, 4))
sns.set_style('ticks')
hist_ax = sns.histplot(data=cells_per_fov, kde=True, color='#A5A9B6')
hist_ax.tick_params(labelsize=16)
sns.despine()
hist_ax.set_xlabel("Cells Per Image", fontsize=16)
hist_ax.set_ylabel("Count", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(save_directory, "supplementary_figure12b.pdf"), dpi=300)

### Supplementary Figure 12c

In [None]:
# Mean marker expression (z-score) per cell type, split into a phenotypic panel (left, with
# cell-type colour dots) and a functional panel (right, with per-type cell totals).
fig = plt.figure(figsize=(10, 5.5), dpi = 400)
gs = GridSpec(1, 2, width_ratios=[1, 1], wspace=0, hspace=0.45, bottom=0.15)  # adjust width_ratios depending on how many phenotypic/functional markers you have
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])

# Restrict to the plotted cell types once and reuse the same view for both panels.
adata_plot = adata[np.isin(adata.obs['cell_cluster'], cell_ordering)]

# Row boundaries where the lineage changes along cell_ordering (tumor | lymphoid | myeloid | structural).
# Derived rather than hard-coded so reordering cell types keeps the separators correct;
# for the current ordering this evaluates to [3, 9, 16].
line_positions = [row for row in range(1, len(cell_ordering))
                  if lineage_dict[cell_ordering[row]] != lineage_dict[cell_ordering[row - 1]]]

#phenotypic markers
mp1 = sc.pl.matrixplot(adata_plot, 
                      var_names=phenotypic_markers, 
                      groupby='cell_cluster',
                      vmin=-1.5, vmax=2, cmap='vlag', 
                      categories_order=cell_ordering, 
                      ax=ax1, 
                      colorbar_title='avg. expression \n (z-score)', 
                      return_fig=True)

mp1.legend(show = False)
# add_totals is only here so .style() draws cell edges; show=False hides the totals on this panel.
mp1.add_totals(size = 2, color = 'lightgrey', show = False).style(edge_color='black', cmap = 'vlag')
ax1 = mp1.get_axes()['mainplot_ax']
y_ticks = ax1.get_yticks()

ax1.set_title('Phenotypic', fontsize = 14)
y_tick_labels = [tick.get_text() for tick in ax1.get_yticklabels()]
ax1.set_yticklabels(y_tick_labels, fontsize=12, ha='right', position=(-0.02, 0))

# cell-type colour dots just left of each row
for y, cell_type in zip(y_ticks, y_tick_labels):
    if cell_type in colors_dict_cells:
        ax1.add_patch(plt.Circle((-0.5, y), 0.3, color=colors_dict_cells[cell_type], transform=ax1.transData, clip_on=False, zorder = 10))

# lineage separators
for pos in line_positions:
    ax1.hlines(y=pos, xmin=-0.5, xmax=len(phenotypic_markers), color='black', linewidth=1)

#functional markers
mp2 = sc.pl.matrixplot(adata_plot, 
                      var_names=functional_markers, 
                      groupby='cell_cluster',
                      vmin=-1.5, vmax=2, cmap='vlag', 
                      categories_order=cell_ordering, 
                      ax=ax2, 
                      colorbar_title='avg. expression \n (z-score)', 
                      return_fig=True)

mp2.add_totals(size = 2, color = 'lightgrey').style(edge_color='black', cmap = 'vlag')
ax2 = mp2.get_axes()['mainplot_ax']
ax2.set_title('Functional', fontsize = 14)
# Rows are shared with the phenotypic panel, so this panel shows no y ticks or labels.
ax2.set_yticks([])

# lineage separators
for pos in line_positions:
    ax2.hlines(y=pos, xmin=-0.5, xmax=len(functional_markers), color='black', linewidth=1)

# shift the per-type total labels on the totals bar plot so they stay visible
ax3 = mp2.get_axes()['group_extra_ax']
for text in ax3.texts:
    text.set_horizontalalignment('left')
    text.set_x(text.get_position()[0] - 8) #xpos
    text.set_y(text.get_position()[1] - 1.5) #ypos

# lower the colour bar
cbar_ax = mp2.get_axes()['color_legend_ax']
cbar_ax.set_position([cbar_ax.get_position().x0 - 0.06, cbar_ax.get_position().y0 - 0.12, 
                      cbar_ax.get_position().width, cbar_ax.get_position().height])

plt.savefig(os.path.join(save_directory, 'supplementary_figure12c.pdf'), bbox_inches = 'tight')

### Supplementary Figure 12d

In [None]:
structural_cell_types = ['CAF', 'Fibroblast', 'Smooth_Muscle', 'Endothelium']
tumor_cell_types = ['Cancer_1', 'Cancer_2', 'Cancer_3']
immune_cell_types = ['CD4T', 'CD8T', 'Treg', 'T_Other', 'B', 'NK', 'CD68_Mac', 'CD163_Mac', 'Mac_Other', 'Monocyte', 'APC', 'Mast', 'Neutrophil']

# Panels top-to-bottom: axes[0] = tumor, axes[1] = structural, axes[2] = immune.
cell_type_list = [tumor_cell_types, structural_cell_types, immune_cell_types]
cell_labels = ['Tumor', 'Structural', 'Immune']

# Recurrence annotation strip colours.
color_relapse_0 = '#377eb8'  # RECURRENCE_LABEL == 'NEGATIVE'
color_relapse_1 = '#e41a1c'  # RECURRENCE_LABEL == 'POSITIVE'
color_relapse_unknown = 'grey'  # any other / missing label
recurrence_colors = {'NEGATIVE': color_relapse_0, 'POSITIVE': color_relapse_1}

df = adata.obs[['Patient_ID', 'cell_cluster', 'RECURRENCE_LABEL']].copy()


def _cell_type_proportions(obs, cell_types):
    """Return a Patient_ID x cell-type table of within-group cell fractions.

    Each row is divided by the patient's total number of cells whose `cell_cluster`
    is in `cell_types`; columns follow the order of `cell_types`. Patient/type
    combinations without cells may come back as NaN, so callers fill them with 0.
    """
    subset = obs[obs['cell_cluster'].isin(cell_types)]
    counts = subset.groupby(['Patient_ID', 'cell_cluster']).size().unstack().loc[:, cell_types]
    totals = subset.groupby(['Patient_ID'])['cell_cluster'].size()
    return counts.div(totals, axis=0)


fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 8), sharex=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.4, 'bottom':0.15}, dpi = 400)
sns.set_style('ticks')

# Cluster patients on their composition across ALL cell types so every panel shares one column order.
all_cell_types = immune_cell_types + structural_cell_types + tumor_cell_types
prop_all = _cell_type_proportions(df, all_cell_types).fillna(0)
Z = linkage(prop_all, method='average', metric='euclidean')
ordered_patients = [prop_all.index[i] for i in leaves_list(Z)]

# Dendrogram above the top (tumor) panel
ax_dendro = fig.add_axes([0.1, 0.8, 0.8, 0.04])
dendro = dendrogram(Z, labels=prop_all.index, ax=ax_dendro, orientation='top', no_labels=True, color_threshold = 0, above_threshold_color='k')
ax_dendro.set_xticks([])
ax_dendro.set_yticks([])
for side in ('top', 'right', 'left', 'bottom'):
    ax_dendro.spines[side].set_visible(False)

# Re-position the panels to make room for the dendrogram
axes[0].set_position([0.1, 0.58, 0.8, 0.22])  # Tumor panel (top)
axes[1].set_position([0.1, 0.34, 0.8, 0.22])  # Structural panel
axes[2].set_position([0.1, 0.1, 0.8, 0.22])   # Immune panel (bottom)

recurrence_per_patient = df[['Patient_ID', 'RECURRENCE_LABEL']].drop_duplicates()
box_y_position = 1.2  # recurrence strip height in axes-fraction units (above the top panel)
box_height = 0.08

sorted_df_total = []
df_pivot_total = []
for i, (category, label) in enumerate(zip(cell_type_list, cell_labels)):
    ax = axes[i]
    ax.grid(False)

    # Proportions within this category, aligned to the shared patient order. reindex (instead of .loc)
    # keeps a patient with no cells of this category as an empty bar rather than raising KeyError.
    prop_df = _cell_type_proportions(df, category).reindex(ordered_patients).rename_axis('Patient_ID').reset_index()
    prop_df = pd.merge(prop_df, recurrence_per_patient, on=['Patient_ID'])

    # Long format (Patient_ID, RECURRENCE_LABEL, variable = cell type, value = proportion) in patient order.
    sorted_df = (prop_df.melt(id_vars=['Patient_ID', 'RECURRENCE_LABEL'])
                 .set_index('Patient_ID')
                 .loc[ordered_patients]
                 .reset_index())
    #sorted_df = sorted_df[np.isin(sorted_df['RECURRENCE_LABEL'], ['NEGATIVE', 'POSITIVE'])] ##keep if you only want patients with a recurrence label
    df_pivot = sorted_df.pivot(index='Patient_ID', columns='variable', values='value').fillna(0)
    df_pivot = df_pivot.loc[ordered_patients]
    sorted_index = sorted_df[['Patient_ID', 'RECURRENCE_LABEL']].drop_duplicates().set_index('Patient_ID')

    color_list = [colors_dict_cells.get(col, '#333333') for col in df_pivot.columns]
    sorted_df_total.append(sorted_df)
    df_pivot_total.append(df_pivot)

    #stacked bar plot
    g = df_pivot.plot(kind='bar', stacked=True, color=color_list, ax=ax, width=1, edgecolor='white', linewidth=0.2, legend=True)
    g.tick_params(labelsize=8)
    g.legend_.remove()  # re-added below with custom placement
    ax.set_ylim(0, 1)
    ax.set_yticks([0, 0.5, 1])
    ax.set_xticks([])
    ax.set_ylabel(f'{label}', fontsize=10)

    # recurrence strip above the top panel: one box per patient column, same order as the bars
    if i == 0:
        for j, relapse_status in enumerate(sorted_index['RECURRENCE_LABEL']):
            color = recurrence_colors.get(relapse_status, color_relapse_unknown)
            ax.add_patch(plt.Rectangle((j - 0.5, box_y_position), 1, box_height, facecolor=color, 
                                    transform=ax.get_xaxis_transform(), clip_on=False, edgecolor='white', linewidth=0.2, alpha = 0.7))
        ax.text(-1, box_y_position + (box_height / 2), 'Recurrence', fontsize=8, verticalalignment='center', horizontalalignment='right', transform=ax.get_xaxis_transform())

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, loc='upper right', prop={'size': 6}, bbox_to_anchor=(1.08, 1.05))

# N is the number of clustered patients (columns), not the row count of the last panel
axes[-1].set_xlabel(f'Patients ($N = {len(ordered_patients)}$)', fontsize=10)
axes[-1].set_xticks([]) 
plt.savefig(os.path.join(save_directory, 'supplementary_figure12d.pdf'), bbox_inches = 'tight')

## Niche detection

In [None]:
## niche detection for Stanford
# This section reuses the marker lists, `cell_ordering`, `colors_dict_cells`, `colors_dict`,
# `lineage_dict` and `save_directory` defined in the "Stanford cohort overview" cells instead of
# re-declaring them. The earlier copy here also re-added 'HLADR' to functional_markers, which
# duplicated it in var_names.
_required_globals = ('cell_ordering', 'colors_dict_cells', 'colors_dict', 'lineage_dict', 'save_directory')
assert all(name in globals() for name in _required_globals), 'run the "Stanford cohort overview" cells first'

# Re-apply scanpy's figure settings: the overview plots switched seaborn to the 'ticks' style.
sc.set_figure_params(dpi = 400, dpi_save = 400, fontsize = 14)

# Reload from disk; the overview section standardized `adata.X` in place.
# NOTE(review): this path uses 'Stanford_preprocessed.h5ad' while the overview reads
# 'stanford_preprocessed.h5ad'; on a case-sensitive file system only one can resolve — confirm.
adata = anndata.read_h5ad(os.path.join('data', 'Zenodo', 'Stanford_preprocessed.h5ad'))

# Per-patient sketch size used by QUICHE (sketch_key='Patient_ID'); filter_fovs presumably drops
# patients with too few cells to sample from — TODO confirm against qu.pp.filter_fovs.
sketch_size = 1000
adata  = qu.pp.filter_fovs(adata, 'Patient_ID', sketch_size)

In [None]:
# Run QUICHE only when the cached result is missing; the next cell always loads it from disk,
# so Restart & Run All works both with and without the pre-computed file.
# NOTE: the cache holds MuData (written with write_h5mu) despite its '.h5ad' extension.
# NOTE(review): no random seed is set here — confirm whether run_quiche's sketching is stochastic.
mdata_stanford_path = os.path.join('data', 'tnbc_stanford', 'mdata', 'mdata_stanford.h5ad')
if not os.path.exists(mdata_stanford_path):
    design = '~RECURRENCE_LABEL'
    model_contrasts = 'RECURRENCE_LABELPOSITIVE-RECURRENCE_LABELNEGATIVE'
    mdata, sig_niches = qu.tl.run_quiche(adata, radius = 200, labels_key = 'cell_cluster', spatial_key = 'spatial',
                                        fov_key = 'fov', patient_key = 'Patient_ID', n_neighbors = 30, merge = False, test_key='Patient_ID', sketch_key='Patient_ID',
                                        delaunay = False, min_cells = 3, k_sim = 100, design = design, khop = None, label_scheme='normal',
                                        model_contrasts = model_contrasts, sketch_size = sketch_size, nlargest = 3, annotation_key = 'quiche_niche', n_jobs = 8)
    # var is stored as strings; the loading cell casts SpatialFDR / logFC back to float
    mdata['quiche'].var = mdata['quiche'].var.astype('str')
    os.makedirs(os.path.dirname(mdata_stanford_path), exist_ok=True)
    mdata.write_h5mu(mdata_stanford_path)

In [None]:
## load to save on runtime
mdata = mudata.read_h5mu(os.path.join('data', 'tnbc_stanford', 'mdata', 'mdata_stanford.h5ad'))
# The cached var table is stored as strings; restore the numeric columns.
mdata['quiche'].var['SpatialFDR'] = mdata['quiche'].var['SpatialFDR'].astype('float')
mdata['quiche'].var['logFC'] = mdata['quiche'].var['logFC'].astype('float')

# Niche selection thresholds
FDR_THRESHOLD = 0.05   # max median SpatialFDR per niche
MIN_NICHE_SIZE = 5     # min number of var rows annotated with the niche
MIN_ABS_LOGFC = 0.5    # min |mean logFC| per niche
MIN_PATIENTS = 3       # min patient_count for a niche to appear in the beeswarm plot

quiche_var = mdata['quiche'].var
niche_groups = quiche_var.groupby('quiche_niche')
scores_df_stanford = niche_groups['SpatialFDR'].median().to_frame('pval')
scores_df_stanford['logFC'] = niche_groups['logFC'].mean()

# Filter with boolean masks (instead of intersecting Python sets) so the surviving niches keep the
# groupby order; set iteration order made niches_stanford non-deterministic across runs.
niche_sizes = quiche_var['quiche_niche'].value_counts()
is_significant = scores_df_stanford['pval'] < FDR_THRESHOLD
is_frequent = scores_df_stanford.index.isin(niche_sizes.index[niche_sizes >= MIN_NICHE_SIZE])
is_large_effect = scores_df_stanford['logFC'].abs() > MIN_ABS_LOGFC
scores_df_stanford = scores_df_stanford[is_significant & is_frequent & is_large_effect]
niches_stanford = list(scores_df_stanford.index)

cov_count_df = qu.tl.compute_patient_proportion(mdata,
                                niches = niches_stanford,
                                feature_key = 'quiche',
                                annot_key = 'quiche_niche',
                                patient_key = 'Patient_ID',
                                design_key = 'RECURRENCE_LABEL',
                                patient_niche_threshold = 5)

cov_count_df_frequent = cov_count_df[cov_count_df['patient_count'] >= MIN_PATIENTS]

### Supplementary Figure 12e

In [None]:
# Beeswarm of niche logFC with patient prevalence, coloured by recurrence label.
# NOTE(review): quiche_niche may repeat here (cov_count_df has one row per niche and RECURRENCE_LABEL);
# confirm beeswarm_prev de-duplicates it.
recurrence_palette = {'NEGATIVE': '#377eb8', 'POSITIVE': '#e41a1c'}
frequent_niches = cov_count_df_frequent.quiche_niche
qu.pl.beeswarm_prev(
    mdata,
    feature_key="quiche",
    annot_key='quiche_niche',
    design_key='RECURRENCE_LABEL',
    patient_key='Patient_ID',
    niches=frequent_niches,
    alpha=0.05,
    xlim=[-5.5, 5.5],
    xlim_prev=[-0.3, 0.3],
    figsize=(6, 6),
    fontsize=10,
    colors_dict=recurrence_palette,
    save_directory=save_directory,
    filename_save='supplementary_figure12e',
)

### Supplementary Figure 12f-g

In [None]:
# Niches enriched in non-recurrent (negative logFC) vs recurrent (positive logFC) patients,
# each restricted to rows of the matching recurrence label.
cov_count_df_neg = cov_count_df[(cov_count_df['mean_logFC'] < 0)
                                & (cov_count_df['patient_count'] >= 1)
                                & (cov_count_df['RECURRENCE_LABEL'] == 'NEGATIVE')]
cov_count_df_pos = cov_count_df[(cov_count_df['mean_logFC'] > 0)
                                & (cov_count_df['patient_count'] >= 1)
                                & (cov_count_df['RECURRENCE_LABEL'] == 'POSITIVE')]

# Donut-plot settings shared by panels f and g.
donut_kwargs = dict(figsize=(5.5, 5.5), node_order=cell_ordering, buffer=1.5, weightscale=0.3,
                    edge_color='#1D265E', centrality_measure='eigenvector', colors_dict=colors_dict,
                    curvature=0.2, save_directory=save_directory, min_node_size=20, max_node_size=850,
                    lineage_dict=lineage_dict, donut_radius_inner=1.15, donut_radius_outer=1.25,
                    vmin=0, vmax=10, edge_cmap=cm.bone_r, edge_label='Patients')

G1 = qu.tl.compute_niche_network(cov_count_df=cov_count_df_neg, colors_dict=colors_dict,
                                 lineage_dict=lineage_dict, annot_key='quiche_niche')
qu.pl.plot_niche_network_donut(G=G1, filename_save='supplementary_figure12f', **donut_kwargs)

G2 = qu.tl.compute_niche_network(cov_count_df=cov_count_df_pos, colors_dict=colors_dict,
                                 lineage_dict=lineage_dict, annot_key='quiche_niche')
# NOTE(review): only panel g sets font_size=10; confirm whether panel f should match.
qu.pl.plot_niche_network_donut(G=G2, font_size=10, filename_save='supplementary_figure12g', **donut_kwargs)