# Imports

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
from matplotlib.colors       import LinearSegmentedColormap
from matplotlib.colors       import TwoSlopeNorm
from skimage.transform       import resize
from plottify                import autosize
from sklearn                 import metrics
from PIL                     import Image
from adjustText              import adjust_text
from scipy.cluster           import hierarchy
from sklearn.metrics         import f1_score, roc_curve
import statsmodels.api       as sm
import matplotlib.pyplot     as plt
import numpy                 as np
import seaborn               as sns
import pandas                as pd
import scanpy                as sc
import matplotlib
import anndata
import random
import fastcluster
import copy
import umap
import h5py
import sys
import os

# Variables for data selections

In [70]:
# Workspace path: make the project packages importable from this notebook.
main_path = '/media/adalberto/Disk2/PhD_Workspace'
# Guard so re-running the cell does not keep appending the same sys.path entry.
if main_path not in sys.path:
    sys.path.append(main_path)
# Project helpers. The star imports presumably supply helpers used later in this notebook
# (e.g. read_h5ad_reference, build_cohort_representations) — NOTE(review): confirm source module.
from models.clustering.cox_proportional_hazard_regression_leiden_clusters import *
from models.evaluation.folds import load_existing_split
from models.visualization.attention_maps import *
from models.clustering.data_processing import *
from data_manipulation.data import Data

In [None]:
# Image dataset variables.
dataset            = 'TCGAFFPE_LUADLUSC_5x_60pc'
additional_dataset = 'NYUFFPE_survival_5x_60pc'


############# Lungsubtype
meta_field     = 'luad'
matching_field = 'slides'
resolution     = 2.0
fold_number    = 4
groupby        = 'leiden_%s' % resolution
meta_folder    = 'lungsubtype_nn250'
folds_pickle   = '%s/utilities/files/LUADLUSC/lungsubtype_Institutions.pkl' % main_path

# Institutions: keep only the LUAD and LUSC studies.
inst_csv   = '%s/utilities/files/TCGA/TCGA_Institutions.csv' % main_path
inst_frame = pd.read_csv(inst_csv)
inst_frame = inst_frame[inst_frame['Study Name'].isin(['Lung adenocarcinoma', 'Lung squamous cell carcinoma'])]

# Representations.
h5_complete_path = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/hdf5_TCGAFFPE_LUADLUSC_5x_60pc_he_complete_lungsubtype_survival_filtered.h5' % main_path
h5_additional_path = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/NYUFFPE_LUADLUSC_5x_60pc/h224_w224_n3_zdim128/hdf5_NYUFFPE_LUADLUSC_5x_60pc_he_combined_filtered.h5' % main_path

# File name and directories.
# Shared suffix: Leiden key with '.' replaced by 'p' (e.g. 'leiden_2p0') plus the fold number.
name_suffix = '_%s__fold%s' % (groupby.replace('.', 'p'), fold_number)
file_name = h5_complete_path.split('/hdf5_')[1].split('.h5')[0] + name_suffix
# Always define file_additional so later code can test it against None instead of hitting a NameError.
file_additional = None
if h5_additional_path is not None:
    file_additional = h5_additional_path.split('/hdf5_')[1].split('.h5')[0] + name_suffix

# Setup folder.
main_cluster_path = h5_complete_path.split('hdf5_')[0]
main_cluster_path = os.path.join(main_cluster_path, meta_folder)
adatas_path       = os.path.join(main_cluster_path, 'adatas')
figures_path      = os.path.join(main_cluster_path, 'figures')
os.makedirs(figures_path, exist_ok=True)

### Images

In [None]:
# TCGA tile images, collected per split.
data = Data(dataset=dataset, marker='he', patch_h=224, patch_w=224, n_channels=3, batch_size=64, project_path=main_path, load=True)
img_dicts = {
    'train': data.training.images,
    'valid': data.validation.images,
    'test':  data.test.images,
}

# Additional cohort: only its training split is collected here.
additional_data = Data(dataset=additional_dataset, marker='he', patch_h=224, patch_w=224, n_channels=3, batch_size=64, project_path=main_path, load=True)
additional_img_dicts = {'train': additional_data.training.images}

# Paper Figure - Latent Space and Cluster Network - LUAD vs LUSC

In [None]:
# Load the reference (training) AnnData for this fold. h5ad_path is presumably the file it was read from;
# the next cells derive the '_paga.h5ad' cache path from it.
adata_train, h5ad_path = read_h5ad_reference(h5_complete_path, meta_folder, groupby, fold_number)

In [None]:
# Reuse a cached PAGA AnnData when it exists; otherwise run PAGA over the Leiden clusters.
# `done` records whether the cache was found, so the following cells can skip recomputation.
paga_h5ad_path = h5ad_path.replace('.h5ad', '_paga.h5ad')
done = os.path.isfile(paga_h5ad_path)
if done:
    print('Reading PAGA H5AD')
    adata_train = anndata.read_h5ad(paga_h5ad_path)
else:
    print('Running PAGA')
    sc.tl.paga(adata_train, groups=groupby, neighbors_key='nn_leiden')

In [None]:
# HPC network visualization settings.
layout           = 'fa'  # options: 'fa', 'fr', 'rt', 'rt_circular', 'drl', 'eq_tree'
random_state     = 0
threshold        = 0.29

# Figure related.
node_size_scale  = 25
node_size_power  = 0.5
edge_width_scale = .05
fontsize         = 10
fontoutline      = 2
meta_field       = 'luad'

# Preview the PAGA network only when it was just computed (not loaded from the cache).
if not done:
    fig = plt.figure(figsize=(100, 10))
    ax = fig.add_subplot(1, 3, 1)
    paga_kwargs = dict(layout=layout, random_state=random_state, color=meta_field, threshold=threshold,
                       node_size_scale=node_size_scale, node_size_power=node_size_power,
                       edge_width_scale=edge_width_scale, fontsize=fontsize, fontoutline=fontoutline,
                       frameon=False, show=False, ax=ax)
    sc.pl.paga(adata_train, **paga_kwargs)
    plt.show()

In [None]:
# When PAGA was freshly computed: build a UMAP initialised from the PAGA layout, then cache the AnnData.
if not done:
    paga_cache = h5ad_path.replace('.h5ad', '_paga.h5ad')
    sc.tl.umap(adata_train, init_pos='paga', neighbors_key='nn_leiden')
    adata_train.write(paga_cache)

In [None]:
# Representations and Cluster Network.
def show_umap_leiden(adata, meta_field, layout, random_state, threshold, node_size_scale, node_size_power, edge_width_scale, directory, file_name,
                     fontsize=10, fontoutline=2, marker_size=2, ax_size=16, l_size=12, l_t_size=14, l_box_w=1, l_markerscale=1, palette='tab20', figsize=(30,10),
                     leiden_name=False, cluster_key=None):
    """Draw three panels (UMAP by meta_field, UMAP by cluster, PAGA cluster network), save the figure and show it.

    Parameters
    ----------
    adata : AnnData holding the UMAP embedding and PAGA results computed in earlier cells.
    meta_field : obs column used to colour panels 1 and 3. 'luad' also adds a LUSC/LUAD legend. May be None.
    layout, random_state, threshold, node_size_scale, node_size_power, edge_width_scale :
        forwarded to ``sc.pl.paga``.
    directory, file_name : where ``plt.savefig`` writes the figure.
    fontsize, fontoutline : font size and outline of the on-data cluster labels.
    marker_size : UMAP point size.
    ax_size : panel title font size.
    l_size, l_t_size, l_box_w, l_markerscale : legend text size, title size, frame width and marker scale.
    palette : seaborn palette name for the cluster colours.
    figsize : figure size.
    leiden_name : if True, titles say 'Leiden' instead of 'Histomorphological Phenotype'.
    cluster_key : obs column with cluster ids. Defaults to the notebook-level ``groupby`` (read at call time).
    """
    if cluster_key is None:
        cluster_key = groupby

    leiden_clusters = np.unique(adata.obs[cluster_key].astype(int))
    colors = sns.color_palette(palette, len(leiden_clusters))

    fig = plt.figure(figsize=figsize)

    # Panel 1: tile vector representations coloured by meta_field.
    ax = fig.add_subplot(1, 3, 1)
    ax = sc.pl.umap(adata, ax=ax, color=meta_field, size=marker_size, show=False, frameon=False, na_color='black')
    lung_handles = None
    if meta_field == 'luad':
        legend_c = ax.legend(loc='best', markerscale=l_markerscale, title='Lung Type', prop={'size': l_size})
        legend_c.get_title().set_fontsize(l_t_size)
        legend_c.get_frame().set_linewidth(l_box_w)
        # Assumes category order 0 -> LUSC, 1 -> LUAD.
        legend_c.get_texts()[0].set_text('LUSC')
        legend_c.get_texts()[1].set_text('LUAD')
        # These handles are reused by the network legend. Matplotlib 3.7 renamed `legendHandles`
        # to `legend_handles`, and 3.9 removed the old name, so support both spellings.
        lung_handles = getattr(legend_c, 'legend_handles', None)
        if lung_handles is None:
            lung_handles = legend_c.legendHandles
    ax.set_title('Tile Vector\nRepresentations', fontsize=ax_size, fontweight='bold')

    # Panel 2: same embedding coloured by cluster, with labels drawn on the data.
    ax = fig.add_subplot(1, 3, 2)
    sc.pl.umap(adata, ax=ax, color=cluster_key, size=marker_size, show=False, legend_loc='on data', legend_fontsize=fontsize, legend_fontoutline=fontoutline, frameon=False, palette=colors)
    if leiden_name:
        ax.set_title('Leiden Clusters', fontsize=ax_size, fontweight='bold')
    else:
        ax.set_title('Histomorphological Phenotype\nClusters', fontsize=ax_size, fontweight='bold')

    # Nudge overlapping cluster labels apart.
    adjust_text(ax.texts)

    # Panel 3: PAGA cluster network coloured by meta_field.
    ax = fig.add_subplot(1, 3, 3)
    names_lines = ['LUSC', 'LUAD']
    sc.pl.paga(adata, layout=layout, random_state=random_state, color=meta_field, threshold=threshold, node_size_scale=node_size_scale, node_size_power=node_size_power, edge_width_scale=edge_width_scale, fontsize=fontsize, fontoutline=fontoutline, frameon=False, show=False, ax=ax)
    if meta_field == 'luad':
        legend = ax.legend(lung_handles, names_lines, title='Lung Type', loc='upper left', prop={'size': l_size})
        legend.get_title().set_fontsize(l_t_size)
        legend.get_frame().set_linewidth(l_box_w)
    if leiden_name:
        ax.set_title('Leiden Cluster Network', fontsize=ax_size, fontweight='bold')
    else:
        ax.set_title('Histomorphological Phenotype\nCluster Network', fontsize=ax_size, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(directory, file_name))
    plt.show()

def plot_umaps(data_df, x, y, hue, scatter_size, palette, figsize, fontsize_labels, fontsize_legend, l_box_w):
    """Scatter a 2-D embedding coloured by ``hue`` and show it, with a figure-level legend.

    The legend is moved off the axes onto the figure (two columns), titled 'TCGA Institution Code'.
    Axis labels, tick labels and spines are enlarged/bolded for publication.
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    scatter_ax = sns.scatterplot(data=data_df, x=x, y=y, hue=hue, s=scatter_size, ax=ax, palette=palette)

    # Replace the axes legend with one anchored to the figure.
    handles, labels = scatter_ax.get_legend_handles_labels()
    scatter_ax.legend_.remove()
    legend = fig.legend(handles, labels, ncol=2, bbox_to_anchor=(1.17, 0.9),
                        title=r'$\bf{TCGA\ Institution\ Code}$', prop={'weight': 'bold'})
    legend.get_title().set_fontsize(fontsize_legend)
    legend.get_frame().set_linewidth(l_box_w)

    ax.set_xlabel('UMAP Dim. 0', fontsize=fontsize_labels, fontweight='bold')
    ax.set_ylabel('UMAP Dim. 1', fontsize=fontsize_labels, fontweight='bold')

    # Bold, enlarged tick labels on both axes.
    for axis_obj in (ax.xaxis, ax.yaxis):
        for tick in axis_obj.get_major_ticks():
            tick.label1.set_fontsize(fontsize_labels)
            tick.label1.set_fontweight('bold')
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(4)
    plt.show()


In [None]:
# Graph visualization related
layout           = 'fa'  # options: 'fa', 'fr', 'rt', 'rt_circular', 'drl', 'eq_tree'
random_state     = 0
threshold        = 0.29

# Figure related
node_size_scale  = 25
node_size_power  = 0.5
edge_width_scale = .05

meta_field = 'luad'

# Styling shared by the three paper panels.
figure_kwargs = dict(directory=figures_path, fontsize=25, fontoutline=10, marker_size=5, ax_size=62, l_size=50,
                     l_t_size=55, l_box_w=4, l_markerscale=6, palette='tab20', figsize=(50, 20))

sns.set_theme(style='white')
# 1) Annotated with LUAD/LUSC.
show_umap_leiden(adata_train, meta_field, layout, random_state, threshold, node_size_scale, node_size_power, edge_width_scale,
                 file_name=file_name + '_clusternetwork_all_anno.jpg', **figure_kwargs)

# 2) and 3) Without the lung-type annotation. None is passed directly rather than rebinding the
# global meta_field: later cells read meta_field and need it to stay 'luad'. Each variant gets its
# own file name so the annotated figure is not overwritten.
show_umap_leiden(adata_train, None, layout, random_state, threshold, node_size_scale, node_size_power, edge_width_scale,
                 file_name=file_name + '_clusternetwork_all.jpg', **figure_kwargs)
show_umap_leiden(adata_train, None, layout, random_state, threshold, node_size_scale, node_size_power, edge_width_scale,
                 file_name=file_name + '_clusternetwork_all_leiden.jpg', leiden_name=True, **figure_kwargs)

# Paper Figure - Clustermap Slides vs Clusters

In [None]:
# Slide-level cluster-composition representations for TCGA plus the additional cohort.
# Positional 'percent', 100 presumably map to type_composition / min_tiles (cf. the keyword call further down) — confirm.
frames = build_cohort_representations(meta_folder, meta_field, matching_field, groupby, fold_number, folds_pickle, h5_complete_path, h5_additional_path, 'percent', 100)
complete_df, additional_complete_df, frame_clusters, frame_samples, features = frames

In [None]:
def clustermap_representations(features, complete_df, frame_clusters, method_slides, metric_slides, figsize, fontsize_labels, fontsize_ticks, dendrogram_ratio,
                               label_field=None):
    """Show a hierarchically clustered heatmap of slides (rows) x HPC clusters (columns).

    Rows get a colour strip from the slide label; columns get a colour strip from cluster subtype purity.

    Parameters
    ----------
    features : columns of complete_df with per-cluster composition (multiplied by 100, shown on a 0-100 scale).
    complete_df : slide-level frame; its first row is skipped.
    frame_clusters : per-cluster frame with a 'Subtype Purity(%)' column and a label_field column.
    method_slides, metric_slides : linkage method and distance metric passed to sns.clustermap.
    figsize, dendrogram_ratio : passed to sns.clustermap.
    fontsize_labels, fontsize_ticks : font sizes for axis/strip labels and tick labels.
    label_field : binary label column used in both frames. Defaults to the notebook-level ``meta_field``
        (read at call time).
    """
    if label_field is None:
        label_field = meta_field

    slide_rep_df = complete_df.iloc[1:].copy(deep=True)

    # Row colours: sorted unique label values map to 'blue', 'orange'.
    row_lut = dict(zip(np.unique(slide_rep_df[label_field]), ['blue', 'orange']))
    row_colors = pd.Series(slide_rep_df[label_field].map(row_lut), name='LUSC/LUAD\nWSI\n')

    # Column colours: keep purity when the cluster's flag is truthy, otherwise use 100 - purity,
    # then map 0-100 onto a blue -> orange gradient.
    purity_color_map = LinearSegmentedColormap.from_list('cluster_purity', ['blue', 'orange'])
    purities = [purity if flag else 100 - purity
                for purity, flag in zip(frame_clusters['Subtype Purity(%)'], frame_clusters[label_field])]
    # NOTE(review): seaborn aligns Series colours to data.columns by index label. This Series has a
    # 0..n-1 index, so colours line up only if `features` are labelled that way, in frame_clusters order.
    col_colors = pd.Series([matplotlib.colors.to_hex(purity_color_map(perc / 100)) for perc in purities], name='HPC\nLUSC/LUAD\nPurity\n')

    g = sns.clustermap(slide_rep_df[features].astype(float)*100, vmin=0, vmax=100, row_colors=row_colors, col_colors=col_colors, col_linkage=None, row_linkage=None, method=method_slides, metric=metric_slides, cmap='rocket_r', figsize=figsize, dendrogram_ratio=dendrogram_ratio, tree_kws=dict(linewidths=3.0))

    # Heatmap: cluster tick labels on x; no per-slide ticks on y.
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize=fontsize_ticks)
    g.ax_heatmap.set_yticks([])
    # Colour strips: hide tick marks and enlarge their labels.
    g.ax_row_colors.tick_params(axis='both', length=0)
    g.ax_row_colors.tick_params(axis='x', which='major', labelsize=fontsize_labels)
    g.ax_col_colors.tick_params(axis='both', length=0)
    g.ax_col_colors.tick_params(axis='y', which='major', labelsize=fontsize_labels)
    g.ax_cbar.tick_params(labelsize=fontsize_ticks)

    # Bold the strip and colour-bar labels (plain loops: these calls are side effects).
    annotation_labels = (list(g.ax_col_colors.get_yticklabels())
                         + list(g.ax_row_colors.get_xticklabels())
                         + list(g.ax_cbar.get_yticklabels()))
    for tick_label in annotation_labels:
        tick_label.set_fontweight('bold')

    for ticks in (g.ax_heatmap.xaxis.get_major_ticks(), g.ax_heatmap.yaxis.get_major_ticks()):
        for tick in ticks:
            tick.label1.set_fontsize(fontsize_ticks)
            tick.label1.set_fontweight('bold')

    g.ax_heatmap.set_xlabel('Histomorphological Phenotype Cluster (HPC)', fontsize=fontsize_labels, fontweight='bold')
    g.ax_heatmap.set_ylabel('Whole Slide Image (WSI)',   fontsize=fontsize_labels, fontweight='bold')

    plt.show()

# Slide clustering: Ward linkage on correlation distance.
# NOTE(review): Ward formally assumes Euclidean distances. seaborn computes this linkage via fastcluster
# when it is installed; the SciPy-only fallback presumably rejects method='ward' with a non-Euclidean metric — confirm.
method_slides = 'ward'
metric_slides = 'correlation'
sns.set_theme(style='white')
clustermap_representations(features, complete_df, frame_clusters, method_slides, metric_slides, figsize=(35,25), fontsize_labels=50, fontsize_ticks=27, dendrogram_ratio=(0.1,0.2))


# Paper Figure - Slide representations

In [None]:
# TCGA-only slide representations ('percent' composition, no additional cohort).
all_data = build_cohort_representations(meta_folder, meta_field, matching_field, groupby, fold_number, folds_pickle, h5_complete_path, None,
                                        type_composition='percent', min_tiles=100, use_conn=False, use_ratio=False, top_variance_feat=0, reduction=2)

complete_df, additional_complete_df, frame_clusters, frame_samples, features = all_data

fontsize_title = 25
fontsize_y     = 20
fontsize       = 13

# Two hand-picked slides plus positions 130-138 of complete_df. Because rows are selected with isin,
# they come out in complete_df order, not in the order of this list.
slides_heatmap = ['TCGA-33-4532-01Z-00-DX2', 'TCGA-33-4532-01Z-00-DX3'] + complete_df.slides.values.tolist()[130:139]
heatmap_df = complete_df[complete_df.slides.isin(slides_heatmap)][['slides']+features]
heatmap_df = heatmap_df.set_index('slides')
# Scale to 0-100 to match vmin/vmax below (assumes 'percent' compositions are fractions — confirm).
heatmap_df = heatmap_df*100

sns.set_theme(style='white')
# 11 stacked axes. Rows 0-9 each show one slide as a 1xN heatmap; axis 10 holds the shared
# horizontal colour bar, drawn together with row 9. An 11th slide (i == 10), if present, is not drawn.
fig, ax_dict = plt.subplots(11,1, figsize=(20,10))
flag = True
for i, slide in enumerate(heatmap_df.index):
    if i==10:
        # Axis 10 is the colour bar: stop labelling it as a slide row.
        flag = False
    elif i==9:
        ax = sns.heatmap(heatmap_df.loc[slide].values.astype(int).reshape(1,-1), vmin=0, vmax=100, linecolor='black', annot=False, linewidths=1, cmap='rocket_r', ax=ax_dict[i], cbar_ax=ax_dict[i+1], cbar_kws={"orientation": "horizontal"})
    else:
        ax = sns.heatmap(heatmap_df.loc[slide].values.astype(int).reshape(1,-1), vmin=0, vmax=100, linecolor='black', annot=False, linewidths=1, cmap='rocket_r', ax=ax_dict[i], cbar=False)

    if flag:
        if i==0:
            # The top row carries the title and the cluster (column) labels, drawn above the strip.
            ax_dict[i].set_title('WSI HPC % Contribution', fontsize=fontsize_title, fontweight='bold', y=1.5)
            # labeltop expects a bool (the legacy 'on' string is not a documented value).
            ax_dict[i].xaxis.set_tick_params(labeltop=True, labelbottom=False)
            ax_dict[i].tick_params('both', length=0, width=0, which='major')
            labels = [item.get_text() for item in ax_dict[i].get_xticklabels()]
            ax_dict[i].set_xticklabels(labels, fontsize=fontsize, fontweight='bold')
            slide += ' '
        else:
            ax_dict[i].set_xticks([])
        ax_dict[i].set_yticklabels([slide], fontsize=fontsize_y, fontweight='bold', rotation=0, ha='right')

    for tick in ax_dict[i].xaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize)
        tick.label1.set_fontweight('bold')
    for tick in ax_dict[i].yaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_y)
        tick.label1.set_fontweight('bold')
    for axis in ['top','bottom','left','right']:
        ax_dict[i].spines[axis].set_linewidth(4)

plt.tight_layout()
plt.show()

# Paper Figure - UMAP WSI vector representations

In [None]:
# WSI vector representations ('clr' composition) for TCGA and the additional (NYU) cohort.
frames = build_cohort_representations(meta_folder, meta_field, matching_field, groupby, fold_number, folds_pickle, h5_complete_path, h5_additional_path, 'clr', 100)
complete_df, additional_complete_df, frame_clusters, frame_samples, features = frames

# Feature columns: everything except the label and identifier columns.
columns = [col for col in complete_df.columns if col not in ('luad', 'samples', 'slides')]

# TCGA: skip the first row (as elsewhere in this notebook). The last column is the label and
# columns 2:-1 are the features.
labels = complete_df.to_numpy()[1:,-1]
data   = complete_df.to_numpy()[1:,2:-1]
df     = pd.DataFrame(data, columns=columns)
df['Lung Type'] = labels
df['Cohort']    = 'TCGA'

# Additional cohort: features are taken from column 1 onwards, presumably because this frame has one
# fewer leading identifier column — NOTE(review): confirm the layout.
labels_add = additional_complete_df.to_numpy()[1:,-1]
data_add   = additional_complete_df.to_numpy()[1:,1:-1]
df_add     = pd.DataFrame(data_add, columns=columns)
df_add['Lung Type'] = labels_add
df_add['Cohort']    = 'NYU'

# ignore_index avoids duplicate row labels from the two 0..n-1 indices.
df_all = pd.concat([df, df_add], axis=0, ignore_index=True)


In [None]:
# Figure styling.
scatter_size    = 300
figsize         = (50,20)
fontsize_labels = 60
fontsize_legend = 60
l_markerscale   = 10
l_box_w         = 3
lw              = 5

# UMAP hyper-parameters.
min_dist     = 0.0
n_components = 2
n_neighbors  = 25
metric       = 'euclidean'

print(metric, n_neighbors)
# Joint 2-D UMAP of both cohorts; coordinates are written back into df_all.
fit = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, metric=metric)
u   = fit.fit_transform(df_all[columns])
df_all['UMAP Dim. 0'] = u[:, 0]
df_all['UMAP Dim. 1'] = u[:, 1]

fig = plt.figure(figsize=figsize)

# Left panel: TCGA by lung type. Right panel: NYU by lung type over a small grey TCGA backdrop.
for panel_idx, cohort_name, panel_title in ((1, 'TCGA', 'TCGA Cohort'), (2, 'NYU', 'NYU Cohort')):
    ax = fig.add_subplot(1, 2, panel_idx)
    if cohort_name == 'NYU':
        sns.scatterplot(data=df_all[df_all.Cohort=='TCGA'], x='UMAP Dim. 0', y='UMAP Dim. 1', color='grey', s=scatter_size/4, ax=ax)
    sns.scatterplot(data=df_all[df_all.Cohort==cohort_name], x='UMAP Dim. 0', y='UMAP Dim. 1', hue='Lung Type', s=scatter_size, ax=ax)

    ax.set_xlabel('UMAP Dim. 0', fontsize=fontsize_labels)
    ax.set_ylabel('UMAP Dim. 1', fontsize=fontsize_labels)
    ax.set_title(panel_title, fontsize=fontsize_labels, fontweight='bold')
    ax.tick_params(axis='both', which='major', labelsize=fontsize_labels-25)

    # Legend entries are assumed to be ordered label 0 -> LUSC, 1 -> LUAD.
    legend = ax.legend(loc='lower right', title='Lung Type', markerscale=l_markerscale, prop={'size': fontsize_legend-5})
    legend.get_title().set_fontsize(fontsize_legend)
    legend_texts = legend.get_texts()
    legend_texts[0].set_text('LUSC')
    legend_texts[1].set_text('LUAD')
    legend.get_frame().set_linewidth(l_box_w)

plt.tight_layout()
plt.show()

In [None]:
# Figure styling.
scatter_size    = 2000
figsize         = (20,20)
fontsize_labels = 60
fontsize_legend = 60
l_markerscale   = 10
l_box_w         = 3
lw              = 5

# UMAP hyper-parameters.
min_dist     = 0.0
n_components = 2
n_neighbors  = 25
metric       = 'euclidean'

print(metric, n_neighbors)
# Re-fit the joint UMAP. No random_state is passed, so coordinates may differ from the previous cell.
fit = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, metric=metric)
u   = fit.fit_transform(df_all[columns])
df_all['UMAP Dim. 0'] = u[:, 0]
df_all['UMAP Dim. 1'] = u[:, 1]

fig = plt.figure(figsize=figsize)
ax  = fig.add_subplot(1, 1, 1)

# One panel: colour encodes lung type, marker encodes cohort (TCGA triangles, NYU squares).
sns.scatterplot(data=df_all, x='UMAP Dim. 0', y='UMAP Dim. 1', hue='Lung Type', style='Cohort', markers={'TCGA':'v', 'NYU':'s'}, s=scatter_size, ax=ax)
ax.set_xlabel('UMAP Dim. 0', fontsize=fontsize_labels, fontweight='bold')
ax.set_ylabel('UMAP Dim. 1', fontsize=fontsize_labels, fontweight='bold')
ax.set_title('Whole Slide Image\nVector Representations',  fontsize=fontsize_labels, fontweight='bold')
ax.tick_params(axis='both', which='major', labelsize=fontsize_labels)

# Presumed seaborn legend layout: [0] hue header, [1]/[2] lung-type values, [3] style header, then cohorts.
legend = ax.legend(loc='upper left', markerscale=l_markerscale, prop={'size': fontsize_legend-5}, ncol=2)
legend_texts = legend.get_texts()
legend_texts[1].set_text('LUSC')
legend_texts[2].set_text('LUAD')
for header_idx in (0, 3):
    legend_texts[header_idx].set_size(fontsize_legend)
legend.get_frame().set_linewidth(l_box_w)

for axis_obj in (ax.xaxis, ax.yaxis):
    for tick in axis_obj.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')

for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(4)

plt.tight_layout()
plt.show()


# Paper Figure - ROC curve - Institutions & Folds

In [None]:
# Label column for one-vs-rest and L1 penalty weight for the logistic regression.
label = 1
alpha = 10.0

# Institutions: unify source-site names that appear under several spellings.
replace_dict = {'Ontario Institute for Cancer Research (OICR)':'Ontario Institute for Cancer Research',
                'Ontario Institute for Cancer Research (OICR)/Ottawa':'Ontario Institute for Cancer Research',
                'St Joseph\'s Medical Center (MD)': 'St. Joseph\'s Medical Center (MD)',
                'Fox Chase':'Fox Chase Cancer Center'}

inst_csv   = '%s/utilities/files/TCGA/TCGA_Institutions.csv' % main_path
inst_frame = pd.read_csv(inst_csv)
inst_frame = inst_frame[inst_frame['Study Name'].isin(['Lung adenocarcinoma', 'Lung squamous cell carcinoma'])][['TSS Code', 'Source Site']]
inst_frame['Source Site'] = inst_frame['Source Site'].replace(replace_dict)
# Zero-pad single-digit TSS codes ('1' -> '01', ..., '9' -> '09') to the two-character form.
inst_frame['TSS Code']    = inst_frame['TSS Code'].replace({str(d): '0%d' % d for d in range(1, 10)})

# Paths are built from main_path (same values as the previously hard-coded absolute paths),
# so the workspace location is defined in one place.
folds_pickle = '%s/utilities/files/LUADLUSC/lungsubtype_Institutions.pkl' % main_path
h5_complete_path = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/hdf5_TCGAFFPE_LUADLUSC_5x_60pc_he_complete_lungsubtype_survival_filtered.h5' % main_path
h5_additional_path = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/NYUFFPE_LUADLUSC_5x_60pc/h224_w224_n3_zdim128/hdf5_NYUFFPE_LUADLUSC_5x_60pc_he_combined_filtered.h5' % main_path

In [None]:
# Get folds from existing split.
folds = load_existing_split(folds_pickle)

# Path for alpha Logistic Regression results.
main_cluster_path = h5_complete_path.split('hdf5_')[0]
main_cluster_path = os.path.join(main_cluster_path, meta_folder)
adatas_path       = os.path.join(main_cluster_path, 'adatas')

data_res_folds = dict()
data_res_folds[resolution] = dict()
for i, fold in enumerate(folds):
    # Read CSV files for train, validation, test, and additional sets.
    dataframes, complete_df, leiden_clusters = read_csvs(adatas_path, matching_field, groupby, i, fold, h5_complete_path, h5_additional_path, additional_as_fold=False, force_fold=fold_number)
    train_df, valid_df, test_df, additional_df = dataframes

    # Check clusters and diversity within.
    frame_clusters, frame_samples = create_frames(train_df, groupby, meta_field, diversity_key=matching_field, reduction=2)

    # Create per-sample representations from cluster composition ('clr').
    data, data_df, features = prepare_data_classes(dataframes, matching_field, meta_field, groupby, leiden_clusters, 'clr', 100, use_conn=False, use_ratio=False, top_variance_feat=0)

    # Attach TSS code and institution (Source Site) to TCGA frames.
    data_dfs = list()
    for dataframe in data_df:
        # Check for None/empty before reading the first slide. A split frame may be None
        # (presumably when no additional cohort is available).
        is_tcga = (dataframe is not None and not dataframe.empty
                   and 'TCGA' in str(dataframe['slides'].values[0]))
        if is_tcga:
            # TSS code is the second '-'-separated field of the slide name.
            dataframe.insert(0, 'TSS Code', dataframe['slides'].apply(lambda x: x.split('-')[1]))
            dataframe = pd.merge(dataframe, inst_frame, on='TSS Code', how='left')
        data_dfs.append(dataframe)

    # Include features that are not the regular leiden clusters.
    frame_clusters = include_features_frame_clusters(frame_clusters, leiden_clusters, features, groupby)

    # Store representations.
    data_res_folds[resolution][i] = {'data':data, 'data_df':data_dfs, 'complete_df':complete_df, 'features':features, 'frame_clusters':frame_clusters, 'leiden_clusters':leiden_clusters}

    # Information.
    print('\t\tFold', i, 'Features:', len(features), 'Clusters:', len(leiden_clusters))


In [None]:
# Per fold: fit an L1-regularised logistic regression (one-vs-rest on column `label`), then record
# ROC curves and weighted F1 on the test split and, when present, on the additional cohort.
folds_roc = dict()
folds_roc['test'] = dict()
if h5_additional_path is not None:
    folds_roc['additional'] = dict()
for i, fold in enumerate(folds):
    # Load data for classification.
    fold_results    = data_res_folds[resolution][i]
    data            = fold_results['data']
    data_df         = fold_results['data_df']
    features        = fold_results['features']
    frame_clusters  = fold_results['frame_clusters']
    leiden_clusters = fold_results['leiden_clusters']

    train,    valid,    test,    additional    = data
    train_df, valid_df, test_df, additional_df = data_df
    train_data, train_labels = train
    valid_data, valid_labels = valid
    test_data,  test_labels  = test

    # One-vs-rest for Logistic Regression.
    model = sm.Logit(endog=train_labels[:,label], exog=train_data).fit_regularized(method='l1', alpha=alpha, disp=0)

    # Predicted probabilities, stored on each split's frame.
    train_pred = model.predict(exog=train_data)
    valid_pred = model.predict(exog=valid_data)
    test_pred  = model.predict(exog=test_data)
    train_df['predictions'] = train_pred
    valid_df['predictions'] = valid_pred
    test_df['predictions']  = test_pred

    eval_sets = [('test', test_labels, test_pred)]
    if additional is not None:
        additional_data, additional_labels = additional
        additional_pred = model.predict(exog=additional_data)
        additional_df['predictions'] = additional_pred
        eval_sets.append(('additional', additional_labels, additional_pred))
    data_res_folds[resolution][i]['data_df'] = [train_df, valid_df, test_df, additional_df]

    # ROC curve and weighted F1 (threshold 0.5) per evaluation split.
    for split_name, split_labels, split_pred in eval_sets:
        y_true = list(split_labels[:, label])
        fpr, tpr, thresholds = roc_curve(y_true, list(split_pred))
        # setdefault also covers an additional split when h5_additional_path is None.
        folds_roc.setdefault(split_name, dict())[i] = {
            'fpr': fpr,
            'tpr': tpr,
            'f1_score': f1_score(y_true, list(split_pred > 0.5), average='weighted'),
        }

In [None]:
# Attach the TCGA institution (Source Site) to each slide, then compute cluster diversity per institution.
if 'TSS Code' not in complete_df.columns:
    # TSS code is the second '-'-separated field of the slide name.
    complete_df.insert(0, 'TSS Code', complete_df['slides'].apply(lambda x: x.split('-')[1]))
    complete_df = pd.merge(complete_df, inst_frame, on='TSS Code', how='left')
# Shorten two long institution names for display.
complete_df['Source Site'] = complete_df['Source Site'].replace({'Mary Bird Perkins Cancer Center - Our Lady of the Lake':'Mary Bird Perkins Cancer Center', 'Thoraxklinik at University Hospital Heidelberg':'University Hospital Heidelberg'})
a, frame_samples = cluster_diversity(complete_df, frame_clusters, groupby, diversity_key='Source Site')
# .copy() makes an independent frame, so the column update below is not a chained assignment
# on a slice (avoids SettingWithCopyWarning).
frame_samples    = frame_samples[[groupby, 'Source Site', 'Purity (%)', 'Counts']].copy()
frame_samples['Purity (%)'] = frame_samples['Purity (%)']/100
frame_samples

## Paper Figure - ROC AUC / F1-Score per Fold

In [None]:
def plot_auc(ax, fold, title, lw, fontsize_labels, fontsize_legend, l_box_w, f1_score=False):
    """Draw every fold's ROC curve on ``ax`` and summarize a per-fold metric in the legend.

    Args:
        ax: Matplotlib axes to draw on.
        fold: Indexable by fold number 0..n-1; each entry is a dict with 'fpr', 'tpr' and
            'f1_score' keys (as built for ``folds_roc[...]``).
        title: Cohort name, used as the first line of the axes title.
        lw: Line width for the curves and the legend handles.
        fontsize_labels: Font size for tick and axis labels (the title uses 1.2x).
        fontsize_legend: Font size of the legend title (entries use fontsize_legend - 4).
        l_box_w: Line width of the legend frame.
        f1_score: If True, the legend reports each fold's stored F1-score; otherwise the
            ROC AUC computed from its fpr/tpr. The ROC curves are drawn in both cases.

    Raises:
        ValueError: If ``fold`` is empty.
    """
    if len(fold) == 0:
        raise ValueError('plot_auc: `fold` is empty, nothing to plot.')

    metric_name = 'F1-Score' if f1_score else 'AUC'
    fig_title   = '%s\n%s ' % (title, metric_name)

    fold_metrics = list()
    # Iterate over every available fold instead of assuming exactly five.
    for i in range(len(fold)):
        if f1_score:
            metric = fold[i]['f1_score']
        else:
            metric = auc(fold[i]['fpr'], fold[i]['tpr'])
        fold_metrics.append(metric)
        ax.plot(fold[i]['fpr'], fold[i]['tpr'], lw=lw, label=" Fold %s %s = %0.3f" % (i, metric_name, metric))

    # Mean and 95% confidence bounds of the per-fold metric, shown as the legend title.
    mean, minus, plus = mean_confidence_interval(fold_metrics, confidence=0.95)
    legend = ax.legend(loc='lower right', title='Mean (CI): %s (%s-%s)' % ( np.round(mean, 3), np.round(minus, 3), np.round(plus, 3)), prop={'size': fontsize_legend-4, 'weight':'bold'})
    legend.get_title().set_fontsize(fontsize_legend)
    legend.get_title().set_fontweight('bold')
    # Match the legend handle width to the plotted curves.
    for line in legend.get_lines():
        line.set_linewidth(lw)
    legend.get_frame().set_linewidth(l_box_w)

    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')
    for tick in ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')

    ax.set_title(fig_title,  fontsize=fontsize_labels*1.2, fontweight='bold')
    ax.set_ylabel('True Positive Rate',  fontsize=fontsize_labels, fontweight='bold')
    ax.set_xlabel('False Positive Rate', fontsize=fontsize_labels, fontweight='bold')
    for axis in ['top','bottom','left','right']:
        ax.spines[axis].set_linewidth(4)


# Per-fold ROC figures, one per cohort (TCGA test folds and the NYU external cohort):
# first with F1-scores in the legend (flag=True), then with ROC AUCs (flag=False).
figsize    = (20,20)
fontsize_labels = 60
fontsize_legend = 60
l_box_w         = 3
lw              = 5

for flag in [True, False]:
    fig   = plt.figure(figsize=figsize)
    ax    = fig.add_subplot(1, 1, 1)
    title = 'TCGA Cohort'
    plot_auc(ax, folds_roc['test'], title, lw, fontsize_labels, fontsize_legend, l_box_w, f1_score=flag)
    # Same layout pass as the NYU figure so both figures are laid out consistently.
    plt.tight_layout()
    plt.show()

    fig   = plt.figure(figsize=figsize)
    ax    = fig.add_subplot(1, 1, 1)
    title = 'NYU Cohort'
    plot_auc(ax, folds_roc['additional'], title, lw, fontsize_labels, fontsize_legend, l_box_w, f1_score=flag)
    plt.tight_layout()
    plt.show()


In [None]:
# Short display names for two institutions whose full names are too long for plot labels.
naming_replacements = {'Mary Bird Perkins Cancer Center - Our Lady of the Lake':'Mary Bird Perkins Cancer Center', 'Thoraxklinik at University Hospital Heidelberg':'University Hospital Heidelberg'}

# Per-institution ROC AUC on each fold's TCGA test split.
# `data` gets one row per (institution, fold).
data = list()
institutions_roc = dict()
for i, fold in enumerate(folds):
    # data_df holds [train, valid, test, additional] frames, each with a 'predictions' column.
    train_df, valid_df, test_df, additional_df = data_res_folds[resolution][i]['data_df']
    # This column assignment modifies the test frame stored inside data_res_folds in place.
    test_df['Source Site'] = test_df['Source Site'].replace(naming_replacements)
    for institution in np.unique(test_df['Source Site']):
        test_inst_df = test_df[test_df['Source Site']==institution].copy(deep=True)

        test_labels = test_inst_df[meta_field].values.tolist()
        test_pred   = test_inst_df['predictions'].values.tolist()

        samples = len(test_labels)

        fpr = None
        tpr = None
        thresholds = None
        roc_auc = None

        # ROC AUC is undefined when an institution contributes a single class; it stays None.
        if len(np.unique(test_labels)) != 1:
            fpr, tpr, thresholds = metrics.roc_curve(test_labels, test_pred)
            roc_auc = auc(fpr, tpr)

        # NOTE(review): keyed by institution only -- if an institution appears in the test split
        # of more than one fold, only the last fold's curve is kept here.
        institutions_roc[institution] = [fpr, tpr, thresholds, roc_auc, samples]
        data.append((institution, roc_auc, samples, i))

data = pd.DataFrame(data, columns=['Institution', 'AUC', 'Sample Size', 'Fold'])
data = data.sort_values(by='Sample Size', ascending=False)

## Paper Figure - Institutions per cluster

In [None]:
# Institution composition of the whole population vs. a random sample of HPCs.
# Fraction of rows coming from each source site across the entire population.
work_df = complete_df.copy(deep=True)
work_df['Weight'] = 1
site_distribution           = work_df[['Weight', 'Source Site']].groupby('Source Site').count()
site_distribution['Weight'] = site_distribution['Weight']/site_distribution['Weight'].sum()
site_distribution['Group'] = 'All Institutions'
site_distribution = site_distribution.reset_index()
site_distribution = site_distribution.sort_values(by='Source Site', ascending=False)


figsize = (50, 20)
fontsize_labels = 38
fontsize_ticks  = 33
rotation        = 45

# Draw 12 distinct HPCs at random (no fixed seed, so the selection differs between runs).
# Each figure shows the full-population panel plus up to 4 HPC panels.
subsampled = sorted(np.random.choice(leiden_clusters, size=12, replace=False))

# The sort does not depend on the figure being drawn, so do it once.
frame_samples = frame_samples.sort_values(by='Source Site')

plotted = 0
while plotted < len(subsampled):
    f, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, sharey=True, sharex=True, figsize=figsize)

    for i, ax in enumerate((ax2, ax3, ax4, ax5)):
        # Compare against the sampled HPCs (not all clusters): once they are exhausted the
        # remaining panels are blanked instead of indexing past the end of `subsampled`.
        if plotted >= len(subsampled):
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.yaxis.label.set_visible(False)
            ax.xaxis.label.set_visible(False)
            ax.get_xaxis().set_ticks([])
        else:
            # Share of all rows that fall in this HPC, shown in the panel title.
            ratio = complete_df[complete_df[groupby]==subsampled[plotted]].shape[0]/complete_df.shape[0]*100
            ax.set_title('HPC %s\n%s%s of entire\npopulation' % (subsampled[plotted],np.round(ratio,1), '%'),  fontsize=fontsize_labels*1.2, fontweight='bold')
            work_samples_df = frame_samples[frame_samples[groupby]==subsampled[plotted]]
            sns.barplot(y='Source Site', x='Purity (%)', data=work_samples_df, ax=ax, palette='tab20')
            ax.yaxis.label.set_visible(False)
            plotted += 1

    for ax in (ax1, ax2, ax3, ax4, ax5):
        ax.set_xlim([0.0, frame_samples['Purity (%)'].max()+0.05])
        for tick in ax.xaxis.get_major_ticks():
            tick.label1.set_fontsize(fontsize_ticks)
            tick.label1.set_fontweight('bold')
            tick.label1.set_rotation(rotation)

        ax.set_ylabel('Institution',            fontsize=fontsize_labels*1.1, fontweight='bold')
        ax.set_xlabel('Institution percentage', fontsize=fontsize_labels*1.1, fontweight='bold')
        for axis in ['top','bottom','left','right']:
            ax.spines[axis].set_linewidth(4)

    # With sharey=True only ax1 shows y tick labels, so style them once per figure.
    for tick in ax1.yaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_ticks)
        tick.label1.set_fontweight('bold')

    sns.barplot(y='Source Site', x='Weight', data=site_distribution, ax=ax1, palette='tab20')
    ax1.set_title('Entire\npopulation',      fontsize=fontsize_labels*1.2, fontweight='bold')
    ax1.set_ylabel('Institution',            fontsize=fontsize_labels*1.1, fontweight='bold')
    ax1.set_xlabel('Institution percentage', fontsize=fontsize_labels*1.1, fontweight='bold')

    plt.tight_layout()
    plt.show()


## Paper Figure - ROC Institutions

In [None]:
# Short display names for two institutions whose full names are too long for plot labels.
naming_replacements = {'Mary Bird Perkins Cancer Center - Our Lady of the Lake':'Mary Bird Perkins Cancer Center', 'Thoraxklinik at University Hospital Heidelberg':'University Hospital Heidelberg'}

# Per-institution ROC AUC on each fold's TCGA test split, plus per-subtype sample counts.
# `data`   : one row per (institution, fold).
# `data_f` : two rows per (institution, fold), one per subtype (long format for a grouped bar plot).
data = list()
data_f = list()
institutions_roc = dict()
for i, fold in enumerate(folds):
    train_df, valid_df, test_df, additional_df = data_res_folds[resolution][i]['data_df']
    # This column assignment modifies the test frame stored inside data_res_folds in place.
    test_df['Source Site'] = test_df['Source Site'].replace(naming_replacements)
    for institution in np.unique(test_df['Source Site']):
        test_inst_df = test_df[test_df['Source Site']==institution].copy(deep=True)

        test_labels = test_inst_df[meta_field].values.tolist()
        test_pred   = test_inst_df['predictions'].values.tolist()

        samples = len(test_labels)
        # Count rows per label value: 0 is counted as LUSC, any other value as LUAD.
        values, counts = np.unique(test_inst_df[meta_field], return_counts=True)
        luad_samples = 0
        lusc_samples = 0
        for j, value in enumerate(values):
            if value == 0:
                lusc_samples = counts[j]
            else:
                luad_samples = counts[j]
        
        fpr = None
        tpr = None
        thresholds = None
        roc_auc = None

        # ROC AUC is undefined when an institution contributes a single class; it stays None.
        if len(np.unique(test_labels)) != 1:
            fpr, tpr, thresholds = roc_curve(test_labels, test_pred)
            roc_auc = auc(fpr, tpr)

        # NOTE(review): keyed by institution only -- if an institution appears in the test split
        # of more than one fold, only the last fold's curve is kept here.
        institutions_roc[institution] = [fpr, tpr, thresholds, roc_auc, samples]
        data.append((institution, roc_auc, samples, lusc_samples, luad_samples, i))
        data_f.append((institution, roc_auc, lusc_samples, 'LUSC', i))
        data_f.append((institution, roc_auc, luad_samples, 'LUAD', i))
        
data = pd.DataFrame(data, columns=['Institution', 'AUC', 'Sample Size', 'Sample Size LUSC', 'Sample Size LUAD', 'Fold'])
data = data.sort_values(by='Sample Size', ascending=False)

data_f = pd.DataFrame(data_f, columns=['Institution', 'AUC', 'Sample Size', 'Subtype', 'Fold'])
data_f = data_f.sort_values(by='AUC', ascending=False)

In [None]:
# Four panels sharing the institution axis: test AUC, total sample size, LUAD and LUSC sample sizes.
# Names were already shortened in the previous cell; the replace below is a no-op safeguard.
data['Institution'] = data['Institution'].replace(naming_replacements)
data = data.sort_values(by='AUC', ascending=False)
# Institutions with a single class have AUC None (NaN in the frame) and therefore show no AUC bar.
# data = data[~data['AUC'].isna()]
figsize = (70, 27)
fontsize_labels = 45

f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharey=True, figsize=figsize)
sns.barplot(y='Institution', x='AUC', data=data, ax=ax1, palette='tab20')
ax1.xaxis.set_ticks(np.arange(0, 1.05, 0.05))
ax1.set_xlim([0.0, 1.005])
# NOTE(review): hard-coded reference line at AUC 0.93 -- presumably the overall test AUC; confirm.
ax1.axvline(0.93, linestyle='--', color='black')
for tick in ax1.xaxis.get_major_ticks():
    tick.label1.set_fontsize(fontsize_labels)
    tick.label1.set_fontweight('bold')
    tick.label1.set_rotation(90)
for tick in ax1.yaxis.get_major_ticks():
    tick.label1.set_fontsize(fontsize_labels)
    tick.label1.set_fontweight('bold')

ax1.set_ylabel('Institution',  fontsize=fontsize_labels*1.1, fontweight='bold')
ax1.set_xlabel('AUC', fontsize=fontsize_labels*1.1, fontweight='bold')
for axis in ['top','bottom','left','right']:
    ax1.spines[axis].set_linewidth(4)


sns.barplot(y='Institution', x='Sample Size', data=data, ax=ax2, palette='tab20')
ax2.set_xlabel('Sample Size', fontsize=fontsize_labels*1.1, fontweight='bold')
# The per-subtype panels reuse the total-sample x range so bar lengths are comparable.
xlim = ax2.get_xlim()

sns.barplot(y='Institution', x='Sample Size LUAD', data=data, ax=ax3, palette='tab20')
ax3.set_xlabel('Sample Size LUAD', fontsize=fontsize_labels*1.1, fontweight='bold')
ax3.set_xlim(xlim)

sns.barplot(y='Institution', x='Sample Size LUSC', data=data, ax=ax4, palette='tab20')
ax4.set_xlabel('Sample Size LUSC', fontsize=fontsize_labels*1.1, fontweight='bold')
ax4.set_xlim(xlim)

# Shared styling for the three sample-size panels; their y labels are hidden (ax1 carries it).
for ax_rem in [ax2, ax3, ax4]:
    for tick in ax_rem.xaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')
    for tick in ax_rem.yaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')

    ax_rem.yaxis.label.set_visible(False)
    for axis in ['top','bottom','left','right']:
        ax_rem.spines[axis].set_linewidth(4)

plt.tight_layout()
plt.show()

In [None]:
# Grouped bar plot of test-set sample counts per institution, split by subtype (LUSC / LUAD).
fontsize_labels = 45
f, ax = plt.subplots(1, 1, sharey=True, figsize=(70, 27))
sns.barplot(x='Institution', y='Sample Size', hue='Subtype', ax=ax, data=data_f, palette=['royalblue', 'darkorange'])

# Style only this figure's axes. `ax1` belongs to the previous cell's figure and must not be touched here.
# Institution names on the x axis are rotated so they stay readable.
for tick in ax.xaxis.get_major_ticks():
    tick.label1.set_fontsize(fontsize_labels)
    tick.label1.set_fontweight('bold')
    tick.label1.set_rotation(90)
for tick in ax.yaxis.get_major_ticks():
    tick.label1.set_fontsize(fontsize_labels)
    tick.label1.set_fontweight('bold')

for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(4)

ax.legend(prop={'size': fontsize_labels-2, 'weight':'bold'})

ax.set_xlabel('Institution',  fontsize=fontsize_labels*1.1, fontweight='bold')
# The y label stays visible. Hiding it before set_ylabel meant it never rendered.
ax.set_ylabel('Sample Size',  fontsize=fontsize_labels*1.1, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Per-institution ROC curves on the TCGA test splits, legend ordered by sample size (largest first).
figsize    = (40,20)
fontsize_labels = 40
fontsize_legend = 32
l_box_w         = 3
lw              = 5

fig   = plt.figure(figsize=figsize)
ax    = fig.add_subplot(1, 1, 1)

data = data.sort_values(by='Sample Size', ascending=False)
# NOTE(review): `data` has one row per (institution, fold); an institution present in several folds
# would be drawn several times with the same (last-fold) curve from institutions_roc.
for institution in data['Institution'].values:

    fpr, tpr, thresholds, roc_auc, samples = institutions_roc[institution]
    # Single-class institutions have no ROC curve.
    if roc_auc is None: continue
    ax.plot(fpr, tpr, lw=lw, label="%s (%s) = %0.2f" % (institution, samples, roc_auc))


# Legend placed outside the axes, to the right of the plot.
legend = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., title=r'$\bf{Per\ Institution\ AUC\ and\ Sample\ Size}$', prop={'size': fontsize_legend-4, 'weight':'bold'})
legend.get_title().set_fontsize(fontsize_legend)
# set the linewidth of each legend object
for line in legend.get_lines():
    line.set_linewidth(lw)
legend.get_frame().set_linewidth(l_box_w)

for tick in ax.xaxis.get_major_ticks():
    tick.label1.set_fontsize(fontsize_labels)
    tick.label1.set_fontweight('bold')
for tick in ax.yaxis.get_major_ticks():
    tick.label1.set_fontsize(fontsize_labels)
    tick.label1.set_fontweight('bold')

ax.set_title('TCGA Cohort',  fontsize=fontsize_labels*1.2, fontweight='bold')
ax.set_ylabel('True Positive Rate',  fontsize=fontsize_labels, fontweight='bold')
ax.set_xlabel('False Positive Rate', fontsize=fontsize_labels, fontweight='bold')
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(4)

plt.tight_layout()
plt.show()


# Paper Figure - AUC across resolutions

In [None]:
# Collect ROC AUC per Leiden resolution / fold / data split, and the average number of HPCs each
# resolution yields, for the resolution-sweep figure.
# Loop variables use dedicated names so this cell does not overwrite the notebook-level
# `folds`, `resolution` and `fold` used by earlier cells.
lr_dir   = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/lungsubtype_nn250_fold4/alpha_5p0_mintiles_100' % main_path
csv_path = '%s/luad_auc_results_mintiles_100.csv' % lr_dir
results_df = pd.read_csv(csv_path)
resolutions_u = results_df['Leiden Resolution'].unique()

n_folds = 5
res_hpc = dict()
for res in resolutions_u:
    hpcs_per_fold = list()
    for i in range(n_folds):
        # Per-fold cluster assignments; '.' in the resolution name is written as 'p' in file names.
        fold_csv = '%s/%s_fold%s_clusters.csv' % (lr_dir, res.replace('.','p'), i)
        fold_df = pd.read_csv(fold_csv)
        hpcs_per_fold.append(len(fold_df[res].unique()))
    # Mean number of HPCs across folds, truncated to an int for axis labels.
    res_hpc[res] = int(np.mean(hpcs_per_fold))

all_data = list()
for _, row in results_df.iterrows():
    # Assumes the results CSV has exactly these six columns, in this order.
    res_name, fold_id, tcga_train, tcga_valid, tcga_test, nyu_additional = row.values
    n_hpc = res_hpc[str(res_name)]
    res_value = res_name.split('leiden_')[1]
    for set_name, set_auc in (('TCGA Train', tcga_train), ('TCGA Validation', tcga_valid), ('TCGA Test', tcga_test), ('NYU Cohort', nyu_additional)):
        all_data.append((res_value, n_hpc, fold_id, set_name, set_auc))

all_data = pd.DataFrame(all_data, columns=['Resolution', 'Number HPCs', 'Fold', 'Set', 'ROC AUC'])
all_data

In [None]:
# Resolution-sweep figures: ROC AUC per data split (point estimate with 95% CI over the fold rows),
# drawn once against the mean number of HPCs and once against the Leiden resolution.
fontsize_labels = 14
lw = 3
l_box_w = 3

for x_label in [ 'Number HPCs', 'Resolution']:
    sns.set_theme(style='darkgrid')
    fig, ax = plt.subplots(figsize=(20, 7), nrows=1, ncols=1)
    # dodge offsets the sets horizontally; join=False draws markers without connecting lines.
    sns.pointplot(x=x_label, hue='Set', y='ROC AUC', data=all_data, ax=ax, dodge=.3, join=False, capsize=.00, markers='o', errorbar=('ci', 95))
    ax.set_ylim([0.85, 1.05])
    ax.set_title('LUAD vs LUSC\nROC AUC', fontweight='bold', fontsize=18)
    ax.legend(loc='upper left')
    start, end = ax.get_ylim()
    ax.yaxis.set_ticks(np.arange(start, end, 0.05))
    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')
    for tick in ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')

    for axis in ['top','bottom','left','right']:
        ax.spines[axis].set_linewidth(4)

    ax.set_ylabel('ROC AUC', fontweight='bold', size=fontsize_labels+2)
    if x_label == 'Number HPCs':
        ax.set_xlabel('Number of HPCs', fontweight='bold', size=fontsize_labels+2)
    else:
        ax.set_xlabel('Leiden Resolution Parameter', fontweight='bold', size=fontsize_labels+2)

    # Restyle the legend created above (thicker handles and frame, bold entries).
    legend = ax.legend_
    for line in legend.get_lines():
        line.set_linewidth(lw)
    legend.get_frame().set_linewidth(l_box_w)
    for i in range(len(legend.get_texts())):
        legend.get_texts()[i].set_fontweight('bold')
        legend.get_texts()[i].set_fontsize(fontsize_labels)

# Called once after the loop, so both figures are shown together at the end.
plt.show()

# Paper Figure - Forest plot

In [None]:
# Path to the CSV with the logistic-regression coefficients, one row per HPC.
coeff_csv = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/lungsubtype_nn250_fold4/alpha_10p0_mintiles_100/leiden_2p0_fold4_clusters.csv' % main_path
coeff_frame = pd.read_csv(coeff_csv)

# Use the dominant subtype instead of purity. This CSV contains the HPC assignments on the train set.
train_csv = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/lungsubtype_nn250_fold4/adatas/TCGAFFPE_LUADLUSC_5x_60pc_he_complete_lungsubtype_survival_filtered_leiden_2p0__fold4.csv' % main_path
train_frame = pd.read_csv(train_csv)
x,y = 'leiden_2.0', 'luad'
# For each HPC: percentage of its train rows per `luad` value; keep only the strict majority (> 50%).
frame_clusters = train_frame.groupby(x)[y].value_counts(normalize=True).mul(100).rename('Subtype Purity(%)').reset_index()
frame_clusters = frame_clusters[frame_clusters['Subtype Purity(%)']>50.0]
# Map label values to names: 0 -> LUSC, 1 -> LUAD.
frame_clusters = frame_clusters.replace({'luad':{0:'LUSC', 1:'LUAD'}})
# NOTE(review): positional assignment. It assumes coeff_frame rows are in the same HPC order as
# frame_clusters (groupby-sorted by cluster id) and that every HPC has a strict majority class.
# An exact 50/50 HPC drops a row and raises a length mismatch. Mapping on the HPC column would
# be safer -- confirm the column name in coeff_frame.
coeff_frame['Subtype'] = frame_clusters[y].values
coeff_frame

In [None]:
from matplotlib.font_manager import FontProperties
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mticker

# Forest Plot for Logistic Regression coefficients.
class EffectMeasurePlot_LR:
    """Forest plot of logistic-regression coefficients with an aligned summary table.

    Adapted from zEpid's ``effectmeasure_plot``. Each row draws a point estimate with its
    confidence interval on the left. The table on the right shows, per row: estimate,
    confidence interval, p-value, predominant subtype, purity, tile counts, mean and max
    tiles per patient, and percentage of patients.

    Plots will resemble the following form:
        _____________________________________________      Measure     % CI
        |                                           |
    1   |        --------o-------                   |       x        n, 2n
        |                                           |
    2   |                   ----o----               |       w        m, 2m
        |                                           |
        |___________________________________________|
        #           #           #           #

    Methods:
        labels(**kwargs): change the effect-measure label ('effectmeasure'), the CI label
            ('conf_int'), the x scale ('scale': 'linear' or 'log') or the reference line ('center').
        colors(**kwargs): change error-bar ('errorbarcolor'), reference-line ('linecolor') and
            point ('pointcolor') colors, and the point marker ('pointshape').
        plot(bbox, ...): draw the figure and return the forest-plot axes.

    Example (names as used in this notebook):
        p = EffectMeasurePlot_LR(label=labs, effect_measure=measure, lcl=lower, ucl=upper,
                                 pvalues=pvalues, subtypes=subtype, purities=purity, counts=counts,
                                 mean_tp=mean_tp, max_tp=max_tp, perc_pat=perc_pat)
        p.labels(effectmeasure='Log Odds Ratio')
        ax = p.plot(bbox=[0, 0.016, 4.5, 1.03], figsize=(25, 27), fontsize=35)
    """

    def __init__(self, label, effect_measure, lcl, ucl, pvalues, subtypes, purities, counts, mean_tp, max_tp, perc_pat, center=0):
        """Store the per-row data and the default labels and colors.

        All list arguments must have the same length (one entry per plotted row).

        Args:
            label: y-axis label of each row (here, the HPC id).
            effect_measure: point estimate of each row. Numbers, or strings when trailing
                zeros must be preserved in the table.
            lcl: lower confidence limit of each row (numbers or strings, as above).
            ucl: upper confidence limit of each row (numbers or strings, as above).
            pvalues: p-value of each row. Table rows with p < p_th (see plot) are drawn in
                normal weight, all others in light weight.
            subtypes, purities, counts, mean_tp, max_tp, perc_pat: extra per-row values shown
                as-is in the summary table.
            center: x position of the vertical reference line (0 for log-odds).
        """
        self.df = pd.DataFrame()
        self.df['study'] = label
        self.df['OR']    = effect_measure
        self.df['LCL']   = lcl
        self.df['UCL']   = ucl
        self.df['P']     = pvalues
        self.df['S']     = subtypes
        self.df['Pu']    = purities
        self.df['C']     = counts
        self.df['M']     = mean_tp
        self.df['Ma']    = max_tp
        self.df['Pp']    = perc_pat
        # Numeric copy of the estimate, used for plotting even when OR was given as strings.
        self.df['OR2']   = self.df['OR'].astype(str).astype(float)
        # Distances from the estimate to each CI limit, as required by errorbar's xerr.
        if (all(isinstance(item, float) for item in lcl)) & (all(isinstance(item, float) for item in effect_measure)):
            self.df['LCL_dif'] = self.df['OR'] - self.df['LCL']
        else:
            self.df['LCL_dif'] = (pd.to_numeric(self.df['OR'])) - (pd.to_numeric(self.df['LCL']))
        if (all(isinstance(item, float) for item in ucl)) & (all(isinstance(item, float) for item in effect_measure)):
            self.df['UCL_dif'] = self.df['UCL'] - self.df['OR']
        else:
            self.df['UCL_dif'] = (pd.to_numeric(self.df['UCL'])) - (pd.to_numeric(self.df['OR']))
        self.em       = 'OR'
        self.ci       = '95% CI'
        self.p        = 'P-Value'
        self.subtype  = 'Predominant\nLung Type'
        self.purity   = 'Purity\n%'
        self.counts   = 'Tile Counts'
        self.mean_tp  = 'Mean Tiles\nPer Pat.'
        self.max_tp   = 'Max Tiles\nPer Pat.'
        self.perc_pat = 'Patients\n%'
        self.scale    = 'linear'
        self.center   = center
        self.errc     = 'dimgrey'
        self.shape    = 'o'
        self.pc       = 'k'
        self.linec    = 'gray'

    def labels(self, **kwargs):
        """Change table labels, the x scale, or the reference value.

        Keyword arguments:
            effectmeasure: header of the estimate column.
            conf_int: header of the confidence-interval column ('ci' is accepted as an alias).
            scale: 'linear' or 'log' x scale.
            center: x position of the reference line.
        """
        if 'effectmeasure' in kwargs:
            self.em = kwargs['effectmeasure']
        # The documented key is 'conf_int'. Checking 'ci' but reading 'conf_int' raised KeyError.
        if 'conf_int' in kwargs:
            self.ci = kwargs['conf_int']
        elif 'ci' in kwargs:
            self.ci = kwargs['ci']
        if 'scale' in kwargs:
            self.scale = kwargs['scale']
        if 'center' in kwargs:
            self.center = kwargs['center']

    def colors(self, **kwargs):
        """Change colors and the point marker.

        Keyword arguments:
            errorbarcolor: color of the confidence-interval bars.
            linecolor: color of the reference line.
            pointcolor: color of the point estimates.
            pointshape: matplotlib marker for the point estimates.
        """
        if 'errorbarcolor' in kwargs:
            self.errc = kwargs['errorbarcolor']
        if 'pointshape' in kwargs:
            self.shape = kwargs['pointshape']
        if 'linecolor' in kwargs:
            self.linec = kwargs['linecolor']
        if 'pointcolor' in kwargs:
            self.pc = kwargs['pointcolor']

    def _column_labels(self):
        """Return the summary-table column headers in display order."""
        return [self.em, self.ci, self.p, self.subtype, self.purity, self.counts, self.mean_tp, self.max_tp, self.perc_pat]

    def _table_rows(self, decimal):
        """Build the summary-table cell text and the y tick positions, one entry per row.

        Rows whose numeric estimate is NaN become a row of blanks that is as wide as the table.
        """
        tval = []
        ytick = []
        for i in range(len(self.df)):
            if not np.isnan(self.df['OR2'][i]):
                if isinstance(self.df['OR'][i], float) and isinstance(self.df['LCL'][i], float) and isinstance(self.df['UCL'][i], float):
                    # Numeric inputs: round estimate and CI limits for display.
                    estimate = round(self.df['OR2'][i], decimal)
                    interval = '(' + str(round(self.df['LCL'][i], decimal)) + ', ' + str(round(self.df['UCL'][i], decimal)) + ')'
                    pvalue   = str(self.df['P'][i])
                else:
                    # String inputs are shown verbatim (e.g. to keep trailing zeros).
                    estimate = self.df['OR'][i]
                    interval = '(' + str(self.df['LCL'][i]) + ', ' + str(self.df['UCL'][i]) + ')'
                    pvalue   = self.df['P'][i]
                tval.append([estimate, interval, pvalue, self.df['S'][i], self.df['Pu'][i], self.df['C'][i],
                             self.df['M'][i], self.df['Ma'][i], self.df['Pp'][i]])
            else:
                # Must match the column count, otherwise matplotlib's table raises.
                tval.append([' '] * len(self._column_labels()))
            ytick.append(i)
        return tval, ytick

    @staticmethod
    def _auto_max(ucl):
        """Return a default x-axis maximum slightly above the largest upper confidence limit."""
        top = pd.to_numeric(ucl).max()
        if top < 1:
            return round(top + 0.05, 2)
        if top <= 9:
            return round(top + 1, 0)
        return round(top + 10, 0)

    @staticmethod
    def _auto_min(lcl):
        """Return a default x-axis minimum slightly below the smallest lower confidence limit."""
        bottom = pd.to_numeric(lcl).min()
        if bottom < 0:
            return round(bottom - 0.05, 2)
        return round(bottom - 0.1, 1)

    def plot(self, bbox, figsize=(3, 3), t_adjuster=0.01, decimal=3, size=3, max_value=None, min_value=None, fontsize=12, p_th=0.05):
        """Draw the forest plot and its summary table in a new figure.

        Args:
            bbox: [left, bottom, width, height] of the table, in the table axes' coordinates.
                Use its bottom value to fine-tune row alignment with the plot.
            figsize: figure size in inches.
            t_adjuster: unused. Kept for backward compatibility; alignment is now set through ``bbox``.
            decimal: decimal places shown in the table for numeric estimates and CI limits.
            size: scales the marker size and the error-bar width.
            max_value: x-axis maximum. If None, it is derived from the upper confidence limits.
            min_value: x-axis minimum. If None, it is derived from the lower confidence limits.
            fontsize: font size of the tick labels and table text.
            p_th: rows with p-value below this threshold are drawn in normal weight, others light.

        Returns:
            The matplotlib axes holding the forest plot.

        Raises:
            ValueError: If the log scale is requested and cannot be applied.
        """
        tval, ytick = self._table_rows(decimal)
        maxi = self._auto_max(self.df['UCL']) if max_value is None else max_value
        mini = self._auto_min(self.df['LCL']) if min_value is None else min_value

        plt.figure(figsize=figsize)  # blank figure
        gspec = gridspec.GridSpec(1, 6)  # sets up grid
        plot = plt.subplot(gspec[0, 0:4])  # plot of data
        tabl = plt.subplot(gspec[0, 4:])  # table of OR & CI
        plot.set_ylim(-1, (len(self.df)))  # spacing out y-axis properly
        if self.scale == 'log':
            try:
                plot.set_xscale('log')
            except ValueError as err:
                raise ValueError('For the log scale, all values must be positive') from err
        plot.axvline(self.center, color=self.linec, zorder=1)
        plot.errorbar(self.df.OR2, self.df.index, xerr=[self.df.LCL_dif, self.df.UCL_dif], marker='None', zorder=2, ecolor=self.errc, elinewidth=size*0.3, linewidth=0)
        plot.scatter(self.df.OR2, self.df.index, c=self.pc, s=(size * 25), marker=self.shape, zorder=3, edgecolors='None')
        plot.xaxis.set_ticks_position('bottom')
        plot.yaxis.set_ticks_position('left')
        plot.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
        plot.get_xaxis().set_minor_formatter(matplotlib.ticker.NullFormatter())
        # Text styling goes through set_*ticklabels below. Depending on the Matplotlib version,
        # text kwargs passed to set_*ticks without labels are ignored or rejected.
        plot.set_yticks(ytick)
        plot.set_xlim([mini, maxi])
        plot.set_xticks([mini, self.center, maxi])
        plot.set_xticklabels([mini, self.center, maxi], fontsize=fontsize, fontweight='bold')
        plot.set_yticklabels(self.df.study, fontsize=fontsize, fontweight='bold')
        plot.yaxis.set_ticks_position('none')
        plot.invert_yaxis()  # invert y-axis to align values properly with table

        columns = self._column_labels()
        tb = tabl.table(cellText=tval, cellLoc='center', loc='right', colLabels=columns, bbox=bbox)
        tabl.axis('off')
        tb.auto_set_font_size(False)
        tb.set_fontsize(fontsize)
        for (row, col), cell in tb.get_celld().items():
            if row == 0:
                # Header row.
                cell.set_text_props(fontproperties=FontProperties(weight='bold', size=fontsize))
                cell.set_height(.015)
            else:
                # Table row r corresponds to data row r-1. Non-numeric p-values (e.g. blanks) count as not significant.
                c_pvalue = pd.to_numeric(self.df['P'].values[row - 1], errors='coerce')
                if c_pvalue < p_th:
                    cell.set_text_props(fontproperties=FontProperties(size=fontsize))
                else:
                    cell.set_text_props(fontproperties=FontProperties(weight='light', size=fontsize))
            cell.set_linewidth(0)
        tb.auto_set_column_width(col=list(range(len(columns))))
        return plot


In [None]:
sns.set_theme(style='white')

# Rows with p-value below p_th are drawn in normal weight in the table; others in light weight.
p_th        = 0.05
# Class index whose coefficient columns ('coef_1', '[0.025_1', ...) are plotted.
label       = 1

# HPCs sorted by coefficient, dropping those with a zero coefficient.
frame_label = coeff_frame.sort_values(by='coef_%s'%label)
frame_label = frame_label[frame_label['coef_%s'%label]!=0]

# The first column whose name contains 'leiden' holds the HPC ids used as row labels.
groupby   = [value for value in frame_label.columns if 'leiden' in value][0]
labs      = frame_label[groupby].values.tolist()
measure   = np.round(frame_label['coef_%s'%label],2).values.tolist()
lower     = np.round(frame_label['[0.025_%s'%label],2).values.tolist()
upper     = np.round(frame_label['0.975]_%s'%label],2).values.tolist()
pvalues   = np.round(frame_label['P>|z|_%s'%label],3).values.tolist()
subtype   = frame_label['Subtype'].values.tolist()
purity    = frame_label['Subtype Purity(%)'].values.astype(int).tolist()
counts    = frame_label['Subtype Counts'].values.tolist()
mean_tp   = frame_label['mean_tile_sample'].values.astype(int).tolist()
# NOTE(review): max_tile_sample is scaled by 100 like a percentage, although the table header reads
# 'Max Tiles Per Pat.' -- confirm the intended units.
max_tp    = np.round(frame_label['max_tile_sample'].values*100,1).tolist()
perc_pat  = np.round(frame_label['percent_sample'].values*100,1).tolist()
# Symmetric x range around 0 that covers every confidence limit.
max_value = max(abs(max(upper)), abs(min(lower)))


figsize    = (25,27)
# Only the second t_adjuster value takes effect; it sets the table's bottom offset in bbox below.
t_adjuster = 0.015
t_adjuster = 0.016
decimal    = 3
size       = 10
fontsize   = 35

bbox = [0, t_adjuster, 4.5, 1.03]

p = EffectMeasurePlot_LR(label=labs, effect_measure=measure, lcl=lower, ucl=upper, pvalues=pvalues, subtypes=subtype, purities=purity, counts=counts, mean_tp=mean_tp, max_tp=max_tp, perc_pat=perc_pat)
p.labels(effectmeasure='Log Odds\nRatio')
p.colors(pointshape="o")
ax=p.plot(figsize=figsize, bbox=bbox, t_adjuster=t_adjuster, max_value=max_value, min_value=-max_value, fontsize=fontsize, p_th=p_th, size=size)
# Title placed above the row labels, at the top left of the figure.
plt.suptitle("HPC\n \n ",x=0.1,y=0.89, fontsize=fontsize, fontweight='bold')
ax.set_xlabel("Favors LUSC               Favors LUAD", fontsize=fontsize, x=0.5, fontweight='bold')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_visible(False)
plt.show()


# Paper Figure - HPC samples

In [None]:
# Workspace path.
main_path = '/media/adalberto/Disk2/PhD_Workspace'

# Tile datasets: TCGA (train/valid/test splits) and the external NYU cohort.
dataset            = 'TCGAFFPE_LUADLUSC_5x_60pc'
additional_dataset = 'NYUFFPE_survival_5x_60pc'

# Load 224x224, 3-channel H&E tiles. Note: this rebinds the notebook-level name `data`,
# which earlier cells used for the per-institution results table.
data = Data(dataset=dataset, marker='he', patch_h=224, patch_w=224, n_channels=3, batch_size=64, project_path=main_path, load=True)
# Image arrays keyed by split name, so a tile can be looked up as img_dicts[original_set][index].
img_dicts = dict()
img_dicts['train'] = data.training.images
img_dicts['valid'] = data.validation.images
img_dicts['test'] = data.test.images

# The additional cohort is only read through its 'train' split.
additional_data = Data(dataset=additional_dataset, marker='he', patch_h=224, patch_w=224, n_channels=3, batch_size=64, project_path=main_path, load=True)
additional_img_dicts = dict()
additional_img_dicts['train'] = additional_data.training.images

In [None]:
# Reload the institution-based fold split and the cluster-annotated frames for fold `fold_number`.
# `adatas_path` and `h5_additional_path` must be defined earlier in the notebook.
folds = load_existing_split(folds_pickle)
dataframes, complete_df, leiden_clusters = read_csvs(adatas_path, matching_field, groupby, fold_number, folds[fold_number], h5_complete_path, h5_additional_path)
train_df, valid_df, test_df, additional_df = dataframes


In [None]:
def cluster_set_images(review_clusters, frame, data_dicts, groupby, batches=1, ncols=20, nrows=4, annotated=False, figures_path=None):
    """Plot grids of randomly sampled tiles for each HPC (cluster).

    Each figure has one large panel spanning the right-most 4 columns (all
    rows) plus ``nrows * (ncols - 4)`` small panels to its left; every panel
    shows one tile. Panels left over when a cluster runs out of tiles are
    filled with a white tile.

    Parameters
    ----------
    review_clusters : iterable
        Cluster ids to plot (values of ``frame[groupby]``).
    frame : pandas.DataFrame
        Tile table with columns ``groupby``, ``'indexes'`` (position of the
        tile in its image array) and ``'original_set'`` (key into ``data_dicts``).
    data_dicts : dict
        Maps each ``original_set`` value to an indexable array of images.
    groupby : str
        Name of the cluster-assignment column in ``frame``.
    batches : int
        Maximum number of figures (disjoint random tile sets) per cluster.
        Stops early once a cluster's sampled tiles are exhausted.
    ncols, nrows : int
        Grid size. ``ncols`` must be at least 4 (the last 4 columns hold the large panel).
    annotated : bool
        If True the figure title is ``'HPC <id> - TCGA'``, otherwise ``'HPC <id>'``.
    figures_path : str or None
        If None, figures are shown. Otherwise they are saved as JPEG under
        ``<figures_path>/hpc_tile_samples``.
    """
    if figures_path is not None:
        figures_path = os.path.join(figures_path, 'hpc_tile_samples')
        os.makedirs(figures_path, exist_ok=True)

    # Tiles per figure = visible panels (small ones + the big one), so every sampled tile is shown.
    n_slots = nrows * (ncols - 4) + 1

    for cluster_id in review_clusters:
        cluster_frame = frame[frame[groupby] == cluster_id]
        combined = list(zip(cluster_frame['indexes'].values.tolist(),
                            cluster_frame['original_set'].values.tolist()))
        random.shuffle(combined)
        combined_plot = sorted(combined[:n_slots * batches])
        images_cluster = [data_dicts[original_set][int(index)] / 255.
                          for index, original_set in combined_plot]

        for batch in range(batches):
            batch_images = images_cluster[batch * n_slots:(batch + 1) * n_slots]

            # squeeze=False keeps `axs` 2-D even when nrows == 1.
            fig, axs = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
            fig.set_figheight(8)
            fig.set_figwidth(8*(ncols/4)*0.8)
            if annotated:
                fig.suptitle('HPC %s - TCGA' % (cluster_id), ha='center', fontweight='bold', fontsize=65)
            else:
                fig.suptitle('HPC %s' % (cluster_id), ha='center', fontweight='bold', fontsize=65)

            # Replace the right-most 4 columns with one large panel spanning all rows.
            gs = axs[0, -4].get_gridspec()
            for ax in axs[:, -4:].flatten():
                ax.remove()
            axbig = fig.add_subplot(gs[:, -4:])
            # Only axes still attached to the figure: anything drawn on the removed ones is never rendered.
            axes_list = list(axs[:, :-4].flatten()) + [axbig]

            blank = np.ones_like(batch_images[0]) if batch_images else np.ones((224, 224, 3))
            for position, ax in enumerate(axes_list):
                ax.imshow(batch_images[position] if position < len(batch_images) else blank)
                ax.set_xticks([])
                ax.set_yticks([])
                for axis in ['top', 'bottom', 'left', 'right']:
                    ax.spines[axis].set_linewidth(4)

            fig.subplots_adjust(wspace=0.05, hspace=0.05)
            fig.tight_layout()
            if figures_path is None:
                plt.show()
            else:
                fig.savefig(os.path.join(figures_path, 'HPC_%s_TCGA_batch_%s.jpg' % (cluster_id, batch)))
                plt.close(fig)

            # The cluster's sampled tiles are exhausted; further batches would be empty.
            if len(batch_images) < n_slots:
                break

# Show the TCGA tile grids inline (figures_path=None), with 'TCGA' in each title.
sns.set_theme(style='white')
annotated = True
cluster_set_images(leiden_clusters, frame=train_df, data_dicts=img_dicts, groupby=groupby,
                   batches=1, annotated=annotated, figures_path=None)


In [None]:
def get_crosscheck_frame(hdf5_path, original_set='train', str_dtype='U13'):
    """Build a table mapping each row of an HDF5 tile file to its slide and tile name.

    Parameters
    ----------
    hdf5_path : str
        HDF5 file holding a dataset whose name contains 'slides' and one whose
        name contains 'tiles'. If several names match, the last one in
        ``content.keys()`` order is used. A name containing both words counts as 'slides'.
    original_set : str
        Value written to the 'original_set' column of every row.
    str_dtype : str or type
        NumPy dtype used to convert the slide/tile arrays. The default 'U13'
        keeps the historical behaviour, which TRUNCATES names longer than 13
        characters. Pass ``str`` to keep full-length names.

    Returns
    -------
    pandas.DataFrame
        Columns 'indexes' (row position 0..N-1 in the file), 'tiles',
        'slides' and 'original_set'.

    Raises
    ------
    KeyError
        If no dataset name contains 'slides' or none contains 'tiles'.
    """
    slides_key = None
    tiles_key = None
    with h5py.File(hdf5_path, 'r') as content:
        for key in content.keys():
            if 'slides' in key:
                slides_key = key
            elif 'tiles' in key:
                tiles_key = key
        if slides_key is None or tiles_key is None:
            raise KeyError("HDF5 file %s needs datasets whose names contain 'slides' and 'tiles'; found %s"
                           % (hdf5_path, list(content.keys())))
        tiles  = content[tiles_key][:].astype(str_dtype)
        slides = content[slides_key][:].astype(str_dtype)
    frame_cc = pd.DataFrame(list(range(tiles.shape[0])), columns=['indexes'])
    frame_cc['tiles']  = tiles
    frame_cc['slides'] = slides
    frame_cc['original_set'] = original_set
    return frame_cc

def cross_check_dfs(additional_df, frame_cc, matching_fields=('slides', 'tiles')):
    """Inner-join a tile lookup frame with a tile table on shared key columns.

    The key columns are cast to ``str`` on both sides before joining, so keys
    stored with different dtypes still compare equal. Neither input is modified.

    Parameters
    ----------
    additional_df : pandas.DataFrame
        Right-hand table; must contain every column in ``matching_fields``.
    frame_cc : pandas.DataFrame
        Left-hand table; must contain every column in ``matching_fields``.
    matching_fields : str or sequence of str
        Join key column(s).

    Returns
    -------
    pandas.DataFrame
        Rows whose keys appear in both inputs, in ``frame_cc`` key order, with
        ``frame_cc``'s columns first.
    """
    if isinstance(matching_fields, str):
        matching_fields = [matching_fields]
    else:
        # pandas treats a tuple passed to `on` as a single label, so use a list.
        matching_fields = list(matching_fields)
    # astype returns new frames, so the caller's DataFrames keep their original dtypes.
    as_str = {field: str for field in matching_fields}
    left = frame_cc.astype(as_str)
    right = additional_df.astype(as_str)
    return left.merge(right, how='inner', on=matching_fields)

def cluster_set_images_add(review_clusters, frame, hdf5_path, groupby, add_cohort, img_key='img', batches=1, ncols=20, nrows=4, figures_path=None):
    """Plot grids of randomly sampled tiles per HPC for an additional cohort, reading tiles from HDF5.

    Each figure has one large panel spanning the right-most 4 columns (all
    rows), which receives the first tile of the batch, plus
    ``nrows * (ncols - 4)`` small panels. Panels left over when a cluster runs
    out of tiles are filled with a white tile, and no further batches are drawn.

    Parameters
    ----------
    review_clusters : iterable
        Cluster ids to plot (values of ``frame[groupby]``).
    frame : pandas.DataFrame
        Tile table with columns ``groupby``, ``'indexes'`` (row position in the
        HDF5 image dataset) and ``'original_set'``.
    hdf5_path : str
        HDF5 file holding the tile images.
    groupby : str
        Name of the cluster-assignment column in ``frame``.
    add_cohort : str
        Cohort name used in the figure title and the saved file name.
    img_key : str
        Image dataset name. It is used when present in the file. Otherwise the
        first dataset whose name contains 'img' or 'images' is used.
    batches : int
        Maximum number of figures per cluster.
    ncols, nrows : int
        Grid size. ``ncols`` must be at least 4.
    figures_path : str or None
        If None, figures are shown. Otherwise they are saved as JPEG under
        ``<figures_path>/hpc_tile_samples``.
    """
    if figures_path is not None:
        figures_path = os.path.join(figures_path, 'hpc_tile_samples')
        os.makedirs(figures_path, exist_ok=True)

    # Tiles per figure = visible panels (small ones + the big one), so every sampled tile is shown.
    n_slots = nrows * (ncols - 4) + 1

    with h5py.File(hdf5_path, 'r') as content:

        # Honour the caller's img_key; fall back to auto-detection only when it is absent.
        if img_key not in content:
            for key in content.keys():
                if 'img' in key or 'images' in key:
                    img_key = key
                    break
        images_ds = content[img_key]
        # White filler tile with the same per-image shape as the dataset.
        blank = np.ones(images_ds.shape[1:])

        for cluster_id in review_clusters:
            cluster_frame = frame[frame[groupby] == cluster_id]
            combined = list(zip(cluster_frame['indexes'].values.tolist(),
                                cluster_frame['original_set'].values.tolist()))
            random.shuffle(combined)
            combined_plot = sorted(combined[:n_slots * batches])
            images_cluster = [images_ds[int(index)] / 255. for index, _ in combined_plot]

            for batch in range(batches):
                batch_images = images_cluster[batch * n_slots:(batch + 1) * n_slots]

                # squeeze=False keeps `axs` 2-D even when nrows == 1.
                fig, axs = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
                fig.set_figheight(8)
                fig.set_figwidth(8*(ncols/4)*0.8)
                fig.suptitle('HPC %s - %s' % (cluster_id, add_cohort), ha='center', fontweight='bold', fontsize=65)

                # Replace the right-most 4 columns with one large panel spanning all rows.
                gs = axs[0, -4].get_gridspec()
                for ax in axs[:, -4:].flatten():
                    ax.remove()
                axbig = fig.add_subplot(gs[:, -4:])
                # Only axes still attached to the figure: anything drawn on the removed ones is never rendered.
                axes_list = [axbig] + list(axs[:, :-4].flatten())

                for position, ax in enumerate(axes_list):
                    ax.imshow(batch_images[position] if position < len(batch_images) else blank)
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for axis in ['top', 'bottom', 'left', 'right']:
                        ax.spines[axis].set_linewidth(4)

                fig.subplots_adjust(wspace=0.05, hspace=0.05)
                fig.tight_layout()
                if figures_path is None:
                    plt.show()
                else:
                    fig.savefig(os.path.join(figures_path, 'HPC_%s_%s_batch_%s.jpg' % (cluster_id, add_cohort, batch)))
                    plt.close(fig)

                # The cluster's sampled tiles are exhausted; further batches would be empty.
                if len(batch_images) < n_slots:
                    break

sns.set_theme(style='white')

# Combined NYU image file; its row positions become the 'indexes' column below.
hdf5_path = f'{main_path}/datasets/NYUFFPE_LUADLUSC_5x_60pc/he/patches_h224_w224/hdf5_NYUFFPE_LUADLUSC_5x_60pc_he_combined.h5'

# Attach HDF5 row positions to the NYU tile table by matching (slide, tile) names.
frame_cc       = get_crosscheck_frame(hdf5_path, original_set='additional')
cross_check_df = cross_check_dfs(additional_df, frame_cc, matching_fields=['slides', 'tiles'])

cluster_set_images_add(leiden_clusters, frame=cross_check_df, hdf5_path=hdf5_path, groupby=groupby,
                       add_cohort='NYU', img_key='img', batches=3, figures_path=None)