In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import anndata 
import os
from ark.utils.plot_utils import cohort_cluster_plot
from scipy.cluster.hierarchy import dendrogram
import matplotlib
import quiche as qu
import shutil
from supplementary_plot_helpers import *

%reload_ext autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Spain TNBC overview

In [None]:
# Markers shown in the "Phenotypic" panel of Figure 3c.
phenotypic_markers = [
    'ECAD', 'CK17', 'CD45', 'CD3', 'CD4', 'CD8', 'FOXP3', 'CD20',
    'CD56', 'CD14', 'CD68', 'CD163', 'CD11c', 'HLADR', 'ChyTr',
    'Calprotectin', 'FAP', 'SMA', 'Vim', 'Fibronectin', 'Collagen1', 'CD31',
]

# Markers shown in the "Functional" panel of Figure 3c.
functional_markers = [
    'PDL1', 'Ki67', 'GLUT1', 'CD45RO', 'CD69', 'PD1', 'CD57',
    'TBET', 'TCF1', 'CD45RB', 'TIM3', 'IDO', 'LAG3', 'CD38',
]

# All markers, phenotypic first (a new list; neither input list is aliased).
var_names = [*phenotypic_markers, *functional_markers]

# Row order of cell clusters in Figure 3c; also the set of clusters drawn in the 3d-f overlays.
cell_ordering = [
    'Cancer_1', 'Cancer_2', 'Cancer_3',
    'CD4T', 'CD8T', 'Treg', 'T_Other', 'B', 'NK',
    'CD68_Mac', 'CD163_Mac', 'Mac_Other', 'Monocyte', 'APC', 'Mast', 'Neutrophil',
    'CAF', 'Fibroblast', 'Smooth_Muscle', 'Endothelium',
]

# Global scanpy figure defaults for the rest of the notebook: dpi for rendered figures,
# dpi_save for saved figures, and the base font size.
sc.set_figure_params(dpi = 400, dpi_save = 400, fontsize = 14)

# Hex color per cell cluster, used for the 3c row circles, the 3d-f phenotype-map colormap
# and the 3g stacked bars. Key order is kept unchanged because the 3d-f colormap table is
# built from the keys in insertion order.
colors_dict_cells = dict(
    APC='#700548',
    B='#005377',
    CAF='#f2cc8f',
    CD4T='#ebb3a9',
    CD8T='#ff5666',
    CD68_Mac='#ffa52f',
    CD163_Mac='#788AA3',
    Cancer_1='#66cdaa',
    Cancer_2='#3d405b',
    Cancer_3='#b49ab8',
    Endothelium='#f78e69',
    Fibroblast='#2d9bd5',
    Immune_Other='#366962',
    Mac_Other='#c7d66d',
    Mast='#E36414',
    Monocyte='#CC6690',
    NK='#9ee2ff',
    Neutrophil='#4a7c59',
    Other='#FFBF69',
    Smooth_Muscle='#f5ebe0',
    T_Other='#901C14',
    Treg='#9e8576',
)

In [None]:
# Output folder for every Figure 3 panel (created if missing).
save_directory = os.path.join('publications', 'figures', 'figure3')
qu.pp.make_directory(save_directory)

# Not referenced by the cells below; kept for compatibility with other notebooks.
directory = os.path.join('data', 'tnbc_spain', 'adata')

# Preprocessed Spain TNBC cohort; the expression matrix is replaced by its standardized
# version (the Figure 3c colorbar labels these values as z-scores).
spain_h5ad_path = os.path.join('data', 'Zenodo', 'spain_preprocessed.h5ad')
adata = anndata.read_h5ad(spain_h5ad_path)
adata.X = qu.pp.standardize(adata.X)

## Figure 3c

In [None]:
# Figure 3c: mean standardized expression per cell cluster, split into a phenotypic-marker
# panel (left) and a functional-marker panel (right) that share one row order (`cell_ordering`).
fig = plt.figure(figsize=(10, 5.5), dpi=400)
# Adjust width_ratios depending on how many phenotypic/functional markers you have.
gs = GridSpec(1, 2, width_ratios=[1, 1], wspace=0, hspace=0.45, bottom=0.15)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])

# --- Left panel: phenotypic markers ---
mp1 = sc.pl.matrixplot(adata,
                       var_names=phenotypic_markers,
                       groupby='cell_cluster',
                       vmin=-1.5, vmax=2, cmap='vlag',
                       categories_order=cell_ordering,
                       ax=ax1,
                       colorbar_title='avg. expression \n (z-score)',
                       return_fig=True)

mp1.legend(show=False)
# add_totals(show=False) is only here so .style() applies the black cell edges; the
# per-cluster totals bar itself is shown on the right (functional) panel only.
mp1.add_totals(size=2, color='lightgrey', show=False).style(edge_color='black', cmap='vlag')
ax1 = mp1.get_axes()['mainplot_ax']
y_ticks = ax1.get_yticks()

ax1.set_title('Phenotypic', fontsize=14)
y_tick_labels = [tick.get_text() for tick in ax1.get_yticklabels()]
ax1.set_yticklabels(y_tick_labels, fontsize=12, ha='right', position=(-0.02, 0))

# Colored circle left of each row label, using the cell-type palette.
for y, cell_type in zip(y_ticks, y_tick_labels):
    if cell_type in colors_dict_cells:
        ax1.add_patch(plt.Circle((-0.5, y), 0.3, color=colors_dict_cells[cell_type],
                                 transform=ax1.transData, clip_on=False, zorder=10))

# Horizontal separators before rows 3, 9 and 16 of `cell_ordering` (CD4T, CD68_Mac, CAF),
# assuming row i of the matrix spans y in [i, i + 1]. Adjust if the ordering changes.
line_positions = [3, 9, 16]
for pos in line_positions:
    ax1.hlines(y=pos, xmin=-0.5, xmax=len(phenotypic_markers), color='black', linewidth=1)

# --- Right panel: functional markers ---
mp2 = sc.pl.matrixplot(adata,
                       var_names=functional_markers,
                       groupby='cell_cluster',
                       vmin=-1.5, vmax=2, cmap='vlag',
                       categories_order=cell_ordering,
                       ax=ax2,
                       colorbar_title='avg. expression \n (z-score)',
                       return_fig=True)

mp2.add_totals(size=2, color='lightgrey').style(edge_color='black', cmap='vlag')
ax2 = mp2.get_axes()['mainplot_ax']
ax2.set_title('Functional', fontsize=14)
# Rows are labelled on the left panel only; removing the ticks also removes their labels.
ax2.set_yticks([])

for pos in line_positions:
    ax2.hlines(y=pos, xmin=-0.5, xmax=len(functional_markers), color='black', linewidth=1)

# Left-align the count labels of the totals bar plot and shift them by fixed offsets
# (in each text's own coordinate system) so they stay visible.
ax3 = mp2.get_axes()['group_extra_ax']
for text in ax3.texts:
    text_x, text_y = text.get_position()
    text.set_horizontalalignment('left')
    text.set_x(text_x - 8)
    text.set_y(text_y - 1.5)

# Move the colorbar left and down (figure-fraction offsets).
cbar_ax = mp2.get_axes()['color_legend_ax']
cbar_box = cbar_ax.get_position()
cbar_ax.set_position([cbar_box.x0 - 0.06, cbar_box.y0 - 0.12, cbar_box.width, cbar_box.height])

# Save this figure explicitly rather than whatever pyplot considers "current".
fig.savefig(os.path.join(save_directory, 'figure3c.pdf'), bbox_inches='tight')

## Figure 3d-f

#### Plot CPM overlay

In [None]:
# Raw-image and DeepCell segmentation directories for the Spain cohort. They default to a
# mounted shared volume; set SPAIN_IMAGE_DIR / SPAIN_SEG_DIR to use a local copy instead.
# Both are also used by the expression-overlay cells below.
data_dir = os.environ.get('SPAIN_IMAGE_DIR', r'/Volumes/Shared/Noah Greenwald/TNBC_Cohorts/SPAIN/image_data/samples')
seg_dir = os.environ.get('SPAIN_SEG_DIR', r'/Volumes/Shared/Noah Greenwald/TNBC_Cohorts/SPAIN/segmentation/samples/deepcell_output')

# FOVs shown in Figure 3d-f.
fov_list = ['TMA32_R6C2', 'TMA34_R8C1', 'TMA44_R3C3']
cell_list = cell_ordering

# Subset AnnData to the selected FOVs once, then attach the obs columns that
# cohort_cluster_plot reads (fov / label / cell_cluster) to the expression table.
fov_mask = adata.obs['fov'].isin(fov_list).to_numpy()
adata_fovs = adata[fov_mask]
df_cells = adata_fovs.to_df()
for obs_col in ['cell_cluster', 'label', 'fov']:
    df_cells[obs_col] = adata_fovs.obs[obs_col]
# Keep only clusters that are part of the plotted ordering.
df_cells = df_cells[df_cells['cell_cluster'].isin(cell_list)]

# Two-column (cell_cluster, color) table in the key order of colors_dict_cells.
colormap = pd.DataFrame({'cell_cluster': list(colors_dict_cells.keys()),
                         'color': list(colors_dict_cells.values())})

save_directory_ = os.path.join(save_directory, 'overlay', 'joint')
qu.pp.make_directory(save_directory_)

cohort_cluster_plot(
    fovs=fov_list,
    seg_dir=seg_dir,
    save_dir=save_directory_,
    cell_data=df_cells,
    erode=True,
    fov_col='fov',
    label_col='label',
    cluster_col='cell_cluster',
    seg_suffix="_whole_cell.tiff",
    cmap=colormap,
    fig_file_type = 'pdf',
    display_fig=False,
)

#### Plot expression overlay

In [None]:
## tumor: ECAD / CK17 / Vim overlay for each Figure 3d-f FOV
_rgb_by_name = {
    'cyan': (0.0, 1.0, 1.0),
    'magenta': (1.0, 0.0, 1.0),
    'yellow': (1.0, 1.0, 0.0),
    'red': (1.0, 0.0, 0.0),
    'blue': (0.0, 0.0, 1.0),
    'green': (0.0, 1.0, 0.0),
}
# One RGB row per channel slot, in this order.
channel_to_rgb = np.array([_rgb_by_name[name] for name in ('cyan', 'magenta', 'yellow', 'red', 'blue', 'green')])

save_directory_ = os.path.join(save_directory, 'overlay', 'tumor')
qu.pp.make_directory(save_directory_)

# Same markers and the same fixed ["H3K27me3", "H3K9ac"] list for every FOV; output
# name is '<fov>_tumor_overlay'. Fresh list literals are built on each call.
for overlay_fov in ('TMA32_R6C2', 'TMA34_R8C1', 'TMA44_R3C3'):
    qu.pl.plot_overlay(seg_dir, data_dir, overlay_fov, ['ECAD', 'CK17', 'Vim'], ["H3K27me3", "H3K9ac"],
                       channel_to_rgb, save_directory_, f'{overlay_fov}_tumor_overlay')

In [None]:
## structural: ECAD / FAP / SMA / CD31 overlay for each Figure 3d-f FOV
_rgb_by_name = {
    'cyan': (0.0, 1.0, 1.0),
    'blue': (0.0, 0.0, 1.0),
    'yellow': (1.0, 1.0, 0.0),
    'magenta': (1.0, 0.0, 1.0),
    'red': (1.0, 0.0, 0.0),
    'green': (0.0, 1.0, 0.0),
}
# One RGB row per channel slot, in this order.
channel_to_rgb = np.array([_rgb_by_name[name] for name in ('cyan', 'blue', 'yellow', 'magenta', 'red', 'green')])

save_directory_ = os.path.join(save_directory, 'overlay', 'structural')
qu.pp.make_directory(save_directory_)

# Output name is '<fov>_struc_overlay'.
for overlay_fov in ('TMA32_R6C2', 'TMA34_R8C1', 'TMA44_R3C3'):
    qu.pl.plot_overlay(seg_dir, data_dir, overlay_fov, ['ECAD', 'FAP', 'SMA', 'CD31'], ["H3K27me3", "H3K9ac"],
                       channel_to_rgb, save_directory_, f'{overlay_fov}_struc_overlay')

In [None]:
save_directory_ = os.path.join(save_directory, 'overlay', 'immune')
qu.pp.make_directory(save_directory_)

## immune: each FOV gets its own marker set and color assignment
# (fov, markers, one RGB row per marker)
_immune_overlays = [
    ('TMA32_R6C2', ['ECAD', 'CD8', 'CD4'],
     [[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),          # cyan, red, blue
    ('TMA34_R8C1', ['ECAD', 'CD45', 'CD68'],
     [[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]),          # cyan, magenta, green
    ('TMA44_R3C3', ['ECAD', 'CD45', 'Calprotectin'],
     [[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 0.0, 0.0]]),          # cyan, magenta, red
]

for overlay_fov, overlay_markers, overlay_rgb in _immune_overlays:
    channel_to_rgb = np.array(overlay_rgb)
    qu.pl.plot_overlay(seg_dir, data_dir, overlay_fov, overlay_markers, ["H3K27me3", "H3K9ac"],
                       channel_to_rgb, save_directory_, f'{overlay_fov}_immune_overlay')

## Figure 3g

In [None]:
# Figure 3g: per-patient cell-type composition within three lineages (tumor, structural,
# immune). Patients (bars) are ordered by hierarchical clustering on their composition over
# all plotted cell types; the top panel carries relapse and treatment annotation tracks.
from scipy.cluster.hierarchy import linkage, leaves_list

# NOTE: global rcParams change; it persists into every later figure in this kernel.
matplotlib.rcParams['lines.linewidth'] = 0.5
structural_cell_types = ['CAF', 'Fibroblast', 'Smooth_Muscle', 'Endothelium']
tumor_cell_types = ['Cancer_1', 'Cancer_2', 'Cancer_3']
immune_cell_types = ['CD4T', 'CD8T', 'Treg', 'T_Other', 'B', 'NK', 'CD68_Mac', 'CD163_Mac', 'Mac_Other', 'Monocyte', 'APC', 'Mast', 'Neutrophil']

all_cell_types = immune_cell_types + structural_cell_types + tumor_cell_types

# Panels from top to bottom (axes[0], axes[1], axes[2]) and their y-axis labels.
cell_type_list = [tumor_cell_types, structural_cell_types, immune_cell_types]
cell_labels = ['Tumor', 'Structural', 'Immune']

# Annotation-track colors.
color_relapse_0 = '#377eb8'  # Relapse == 0
color_relapse_1 = '#e41a1c'  # Relapse == 1
color_relapse_unknown = 'grey'  # any other value; also used for an unrecognized treatment
color_treatment_adj = '#FB9A23'
color_treatment_neoadj = '#90BE6D'
color_treatment_none = '#483C46'

relapse_palette = {0: color_relapse_0, 1: color_relapse_1}
treatment_palette = {'adjuvant': color_treatment_adj,
                     'neoadjuvant': color_treatment_neoadj,
                     'no treatment': color_treatment_none}
meta_cols = ['Patient_ID', 'Relapse', 'Grade_grouped', 'Treatment_type']


def _draw_metadata_track(ax, statuses, palette, y, label, box_height, missing_color, report_missing=False):
    """Draw one colored box per patient above `ax`, plus a row label to its left.

    Parameters
    ----------
    ax : Axes whose x positions 0..n-1 are the patient bars.
    statuses : pandas Series of per-patient values (index = Patient_ID), in bar order.
    palette : dict mapping a status value to a color.
    y : bottom of the boxes in axes-fraction units (values > 1 sit above the axes).
    label : text drawn to the left of the track.
    box_height : box height in axes-fraction units.
    missing_color : color for any status not found in `palette` (e.g. NaN).
    report_missing : if True, print a message for each status not found in `palette`.
    """
    transform = ax.get_xaxis_transform()  # x in data units, y in axes fraction
    for j, (patient, status) in enumerate(statuses.items()):
        color = palette.get(status)
        if color is None:
            # Explicit fallback: never reuse the previous box's color for an unknown status.
            if report_missing:
                print(f'missing {label.lower()} for patient {patient}: {status!r}')
            color = missing_color
        ax.add_patch(plt.Rectangle((j - 0.5, y), 1, box_height, facecolor=color,
                                   transform=transform, clip_on=False,
                                   edgecolor='white', linewidth=0.2, alpha=0.7))
    ax.text(-1, y + (box_height / 2), label, fontsize=8, verticalalignment='center',
            horizontalalignment='right', transform=transform)


df = adata.obs[['Patient_ID', 'cell_cluster', 'Relapse', 'Grade_grouped', 'Treatment_type']].copy()

fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(14, 8), sharex=True, gridspec_kw={'hspace': 0.1, 'wspace': 0.4, 'bottom':0.15}, dpi = 400)

sns.set_style('ticks')

# Patient ordering: average-linkage clustering on proportions over all plotted cell types.
in_plot = df['cell_cluster'].isin(all_cell_types)
count_df = df[in_plot].groupby(['Patient_ID', 'cell_cluster']).size().unstack().loc[:, all_cell_types]
total_cell_count = df[in_plot].groupby(['Patient_ID'])['cell_cluster'].size()
prop_df = count_df.div(total_cell_count, axis=0).fillna(0)

Z = linkage(prop_df, method='average', metric='euclidean')
cluster_order = leaves_list(Z)
ordered_patients = [prop_df.index[i] for i in cluster_order]

# Dendrogram above the three panels (no ticks, no frame).
ax_dendro = fig.add_axes([0.1, 0.8, 0.8, 0.03])
dendro = dendrogram(Z, labels=prop_df.index, ax=ax_dendro, orientation='top', no_labels=True, color_threshold=0, above_threshold_color='k')
ax_dendro.set_xticks([])
ax_dendro.set_yticks([])
for spine_side in ('top', 'right', 'left', 'bottom'):
    ax_dendro.spines[spine_side].set_visible(False)

# Shrink the panels to make room for the dendrogram.
axes[0].set_position([0.1, 0.58, 0.8, 0.22])  # Tumor panel
axes[1].set_position([0.1, 0.34, 0.8, 0.22])  # Structural panel
axes[2].set_position([0.1, 0.1, 0.8, 0.22])   # Immune panel

sorted_df_total = []
df_pivot_total = []
for i, (category, label) in enumerate(zip(cell_type_list, cell_labels)):
    ax = axes[i]
    ax.grid(False)
    # Proportions within this lineage only (denominator = the patient's cells of this lineage).
    count_df = df.groupby(['Patient_ID', 'cell_cluster']).size().unstack().loc[:, category]
    total_cell_count = df[np.isin(df['cell_cluster'], category)].groupby(['Patient_ID'])['cell_cluster'].size()
    prop_df = count_df.div(total_cell_count, axis=0)
    prop_df = pd.merge(prop_df.reset_index(), df[meta_cols].drop_duplicates(), on=['Patient_ID'])

    # Long format, reordered to the clustering order computed above.
    sorted_df = prop_df.melt(id_vars=meta_cols)
    sorted_df = sorted_df.set_index('Patient_ID').loc[ordered_patients].reset_index()
    #sorted_df = sorted_df[np.isin(sorted_df['Relapse'], [0, 1])] ##keep if you only want patients with relapse label
    df_pivot = sorted_df.pivot(index='Patient_ID', columns='variable', values='value').fillna(0)
    df_pivot = df_pivot.loc[ordered_patients]
    sorted_index = sorted_df[meta_cols].drop_duplicates().set_index('Patient_ID')

    color_list = [colors_dict_cells.get(col, '#333333') for col in df_pivot.columns]
    sorted_df_total.append(sorted_df)
    df_pivot_total.append(df_pivot)

    # Stacked bar plot, one bar per patient.
    g = df_pivot.plot(kind='bar', stacked=True, color=color_list, ax=ax, width=1, edgecolor='white', linewidth=0.2, legend=True)
    g.tick_params(labelsize=8)
    g.legend_.remove()
    ax.set_ylim(0, 1)
    ax.set_yticks([0, 0.5, 1])
    ax.set_xticks([])
    ax.set_ylabel(label, fontsize=10)

    # Patient metadata tracks above the top panel.
    n_boxes = df_pivot.shape[0]
    box_height = 0.08
    if i == 0:
        _draw_metadata_track(ax, sorted_index['Relapse'].iloc[:n_boxes], relapse_palette,
                             1.3, 'Relapse', box_height, color_relapse_unknown)
        _draw_metadata_track(ax, sorted_index['Treatment_type'].iloc[:n_boxes], treatment_palette,
                             1.2, 'Treatment', box_height, color_relapse_unknown, report_missing=True)

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, loc='upper right', prop={'size': 6}, bbox_to_anchor=(1.08, 1.05))

# N = number of plotted bars (one per clustered patient).
axes[-1].set_xlabel(f'Patients ($N = {len(ordered_patients)}$)', fontsize=10)
axes[-1].set_xticks([])
fig.savefig(os.path.join(save_directory, 'figure3g.pdf'), bbox_inches='tight')

## Figure 3h

In [None]:
def _patient_label_map(obs, label_col):
    """Return {Patient_ID: label_col value} built from the unique (Patient_ID, label) pairs.

    If a patient has several distinct labels, the last pair encountered wins.
    """
    pairs = obs[['Patient_ID', label_col]].drop_duplicates()
    return dict(zip(pairs['Patient_ID'], pairs[label_col]))


# Spain: relapse outcome per patient.
adata_spain = anndata.read_h5ad('data/Zenodo/spain_preprocessed.h5ad')
relapse_dict = _patient_label_map(adata_spain.obs, 'Relapse')

# Stanford: recurrence label per patient.
adata_stanford = anndata.read_h5ad('data/Zenodo/Stanford_preprocessed.h5ad')
recurrence_dict = _patient_label_map(adata_stanford.obs, 'RECURRENCE_LABEL')

# NT public cohort: pCR label per patient.
adata_ntpublic = anndata.read_h5ad('data/Zenodo/nt_preprocessed.h5ad')
pcr_dict = _patient_label_map(adata_ntpublic.obs, 'pCR')

In [None]:
# Per-cohort log2 fold change in cell-type abundance between two outcome groups (plotted in
# Figure 3h). compute_logFC is not defined in this notebook; presumably it comes from the
# `supplementary_plot_helpers` wildcard import.
# NOTE(review): positional args look like (adata, patient column, cluster column, outcome
# column, patient -> outcome map, group A, group B). Which group is the numerator is
# decided inside compute_logFC — confirm before interpreting the sign.
logFC_spain = compute_logFC(adata_spain, 'Patient_ID', 'cell_cluster', 'Relapse', relapse_dict, 1.0, 0.0)  # Relapse: 1.0 vs 0.0
logFC_ntpublic = compute_logFC(adata_ntpublic, 'Patient_ID', 'cell_cluster', 'pCR', pcr_dict, 'RD', 'pCR')  # pCR column: 'RD' vs 'pCR'
logFC_stanford = compute_logFC(adata_stanford, 'Patient_ID', 'cell_cluster', 'RECURRENCE_LABEL', recurrence_dict, 'POSITIVE', 'NEGATIVE')  # 'POSITIVE' vs 'NEGATIVE'

In [None]:
# Merge dataframes on cell_cluster
merged_df = pd.merge(pd.DataFrame(logFC_spain, columns=['Spain']), pd.DataFrame(logFC_stanford, columns=['Stanford']), on='cell_cluster')

# Merge the existing merged_df with the new cohort data
merged_df = pd.merge(merged_df, pd.DataFrame(logFC_ntpublic, columns = ['NT']), on='cell_cluster', how = 'outer')

# Sort by average logFC for better visual alignment
merged_df['avg_logFC'] = merged_df.mean(axis=1, skipna=True)
merged_df = merged_df.sort_values(by='avg_logFC', ascending=False)

# Create the plot using Matplotlib
plt.figure(figsize=(3, 5))
sns.set_style('ticks')

jitter_strength = 0.02  # Adjust the strength of the jitter

for col, color in zip(merged_df.columns[:-1], ['#94C7AB', '#B589BD', '#F4A261']):
    jittered_x_positions = merged_df[col] + np.random.uniform(-jitter_strength, jitter_strength, size=len(merged_df))
    plt.scatter(jittered_x_positions, merged_df.index, color=color, s=40, edgecolor='k', zorder=5, label=col, linewidths=0.4)

for i, row in merged_df.iterrows():
    valid_x = row[:-1].dropna()  # Drop NaN values from the row
    plt.plot(valid_x, [i]*len(valid_x), '-', color='grey', linewidth=1, zorder=4)

sns.set_style('ticks')
# Customize the plot
plt.axvline(0, color='black', linestyle='--', linewidth=0.8, zorder=3)
plt.xlabel('Log2(FC) Abundance', fontsize = 10)
plt.ylabel('Cell Types', fontsize = 10)
plt.yticks(range(len(merged_df)), merged_df.index)
plt.xlim(-1.5, 1.5)
plt.gca().margins(y=0.02)  # Remove vertical margins
plt.tick_params(labelsize=8)
# Add a legend
plt.legend(prop={'size':8})
plt.savefig(os.path.join(save_directory, 'figure3h.pdf'), bbox_inches = 'tight')