# Imports

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
from matplotlib.colors       import LinearSegmentedColormap
from matplotlib.colors       import TwoSlopeNorm
from skimage.transform       import resize
from plottify                import autosize
from sklearn                 import metrics
from PIL                     import Image
from adjustText              import adjust_text
from scipy.cluster           import hierarchy
import statsmodels.api       as sm
import matplotlib.pyplot     as plt
import numpy                 as np
import seaborn               as sns
import pandas                as pd
import scanpy                as sc
import matplotlib
import anndata
import random
import fastcluster
import copy
import umap
import h5py
import sys
import os

# Variables for data selections

In [None]:
# Workspace path.
main_path = '/media/adalberto/Disk2/PhD_Workspace'
# Put the workspace on sys.path so the project packages imported below can be resolved.
sys.path.append(main_path)
# NOTE(review): the star imports presumably provide the helpers that are later called
# without a module prefix (e.g. read_h5ad_reference, build_cohort_representations,
# read_csvs); which module defines each one is not visible from this file - confirm.
from models.clustering.cox_proportional_hazard_regression_leiden_clusters import *
from models.evaluation.folds import load_existing_split
from models.visualization.attention_maps import *
from models.clustering.data_processing import *
from data_manipulation.data import Data

In [None]:
# Image dataset variables.
dataset            = 'TCGAFFPE_LUADLUSC_5x_60pc'
additional_dataset = 'NYUFFPE_survival_5x_60pc'

############# LUAD Overall and Recurrence Free Survival
meta_field     = 'luad'
matching_field = 'samples'
resolution     = 2.0
fold_number    = 0
groupby        = 'leiden_%s' % resolution
meta_folder    = 'luad_overall_survival_nn250_fold%s_NYU_v3' % fold_number
folds_pickle   = '%s/utilities/files/LUAD/overall_survival_TCGA_folds.pkl'  % main_path

# Institutions: keep only the rows of the two lung studies.
inst_csv     = '%s/utilities/files/TCGA/TCGA_Institutions.csv' % main_path
lung_studies = ['Lung adenocarcinoma', 'Lung squamous cell carcinoma']
inst_frame   = pd.read_csv(inst_csv)
inst_frame   = inst_frame[inst_frame['Study Name'].isin(lung_studies)]

# Representations.
h5_complete_path = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/hdf5_TCGAFFPE_LUADLUSC_5x_60pc_he_complete_lungsubtype_survival_filtered.h5' % main_path
h5_additional_path = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/NYU300LUAD_Survival_5x_60pc/h224_w224_n3_zdim128/hdf5_NYU300LUAD_Survival_5x_60pc_he_train_overall_progression_free_surival_filtered.h5' % main_path

# File names: the h5 base name followed by the clustering/fold suffix
# (the '.' of the resolution is written as 'p').
run_suffix = '_%s__fold%s' % (groupby.replace('.', 'p'), fold_number)
file_name  = h5_complete_path.split('/hdf5_')[1].split('.h5')[0] + run_suffix
if h5_additional_path is not None:
    file_additional = h5_additional_path.split('/hdf5_')[1].split('.h5')[0] + run_suffix

# Setup folder: everything lives under <h5 directory>/<meta_folder>.
main_cluster_path = os.path.join(h5_complete_path.split('hdf5_')[0], meta_folder)
adatas_path       = os.path.join(main_cluster_path, 'adatas')
figures_path      = os.path.join(main_cluster_path, 'figures')
os.makedirs(figures_path, exist_ok=True)

### Correlation results

In [None]:
# These correlation values are produced by the correlation notebook.
correlation_hovernet = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/luad_overall_survival_nn250_fold0_NYU_v3/leiden_2p0_fold0/correlations/NYU300LUAD_Survival_5x_60pc_he_train_overall_progression_free_surival_filtered_leiden_2p0__fold0_luad_overall_survival_nn250_fold0_NYU_v3_hovernet_critical_coef.csv' % main_path
# The unnamed first CSV column is exposed as 'Cell Type'.
hovernet_df = pd.read_csv(correlation_hovernet).rename(columns={'Unnamed: 0':'Cell Type'})
hovernet_df

### Images

In [None]:
# Loader settings shared by the main and the additional image dataset.
data_kwargs = dict(marker='he', patch_h=224, patch_w=224, n_channels=3, batch_size=64, project_path=main_path, load=True)

# Main dataset: keep the image arrays of every split, keyed by split name.
data = Data(dataset=dataset, **data_kwargs)
img_dicts = {
    'train': data.training.images,
    'valid': data.validation.images,
    'test':  data.test.images,
}

# Additional dataset: only its training images are used.
additional_data = Data(dataset=additional_dataset, **data_kwargs)
additional_img_dicts = {'train': additional_data.training.images}

### HPC Annotations

In [None]:
csv_annotations_path = '%s/utilities/files/LUAD/HPC_annotations/LUAD_HPC_annotations.csv' % main_path

# Display-name mappings, applied one after another. The order matters: the first
# mapping produces 'no epithelium', which the second mapping then capitalizes.
infiltration_names = {'other predominant tissue':'no epithelium', 'very sparse':'Very Sparse', 'severe':'Severe', 'moderate':'Moderate', 'mild':'Mild'}
ratio_names        = {'more stroma':'More Stroma', 'more epithelium':'More Epithelium', 'no epithelium':'No Epithelium', 'roughly equal':'Roughly Equal'}
tissue_names       = {'malignant epithelium':'Malignant Epithelium', 'elastosis or collagenosis':'Elastosis/Collagenosis',
                      'near-normal lung':'Near-normal Lung', 'reactive lung changes':'Reactive Lung Changes', 'necrosis':'Necrosis',
                      'other connective tissue':'Connective Tissue', 'vessels':'Vessels', 'airway':'Airway', 'cartilage':'Cartilage'}

# One row per HPC.
annotations = pd.read_csv(csv_annotations_path).set_index('HPC')
for display_names in (infiltration_names, ratio_names, tissue_names):
    annotations = annotations.replace(display_names)


# Paper Figure - Latent Space and Cluster Network - LUAD OS

In [None]:
# Load the reference AnnData for this representation file, meta folder, clustering key and fold.
# The helper also returns an '.h5ad' path, which later cells use to derive the '_paga.h5ad' cache file name.
adata_train, h5ad_path = read_h5ad_reference(h5_complete_path, meta_folder, groupby, fold_number)

In [None]:
# Reuse the cached PAGA AnnData when its file exists; otherwise compute PAGA on the
# freshly loaded data. `done` records which of the two happened.
paga_h5ad_path = h5ad_path.replace('.h5ad', '_paga.h5ad')
done = os.path.isfile(paga_h5ad_path)
if done:
    adata_train = anndata.read_h5ad(paga_h5ad_path)
else:
    sc.tl.paga(adata_train, groups=groupby, neighbors_key='nn_leiden')

In [None]:
# HPC network visualization
layout           = 'fa'  # ‘fa’, ‘fr’, ‘rt’, ‘rt_circular’, ‘drl’, ‘eq_tree’
random_state     = 0
threshold        = 0.74

# Figure related
node_size_scale  = 7
node_size_power  = 0.5
edge_width_scale = .05
fontsize         = 15
fontoutline      = 2

# The PAGA graph is only drawn when it was computed in this run (no cache hit).
if not done:
    paga_plot_kwargs = dict(layout=layout, random_state=random_state, threshold=threshold,
                            node_size_scale=node_size_scale, node_size_power=node_size_power,
                            edge_width_scale=edge_width_scale, fontsize=fontsize, fontoutline=fontoutline,
                            frameon=False, show=False)
    fig = plt.figure(figsize=(100,10))
    ax  = fig.add_subplot(1, 3, 1)
    sc.pl.paga(adata_train, ax=ax, **paga_plot_kwargs)
    plt.show()

In [None]:
# Fresh run only: compute a UMAP initialised from the PAGA positions, then write the
# result to the '_paga.h5ad' file that the cache check looks for.
if not done:
    sc.tl.umap(adata_train, init_pos="paga", neighbors_key='nn_leiden')
    paga_output_path = h5ad_path.replace('.h5ad', '_paga.h5ad')
    adata_train.write(paga_output_path)

In [None]:
# When True, negative coefficients are clipped to 0 before being assigned.
cap_depletion = False

# Copy the per-HPC coefficient of every cell type in hovernet_df onto all tiles of
# that HPC, as one new obs column per cell type. cell_types collects the column names.
cell_types    = list()
cluster_names = np.unique(adata_train.obs[groupby])
for cell_type in hovernet_df['Cell Type']:
    cell_types.append(cell_type)
    # Row of coefficients for this cell type; its columns are looked up by cluster name.
    cell_type_row = hovernet_df[hovernet_df['Cell Type']==cell_type]
    for cluster in cluster_names:
        value = cell_type_row[cluster].values[0]
        if cap_depletion and value < 0:
            value = 0
        # Boolean-mask assignment is a .loc operation; .at is the scalar (single
        # row label, single column label) accessor and must not be given a mask.
        adata_train.obs.loc[adata_train.obs[groupby]==str(cluster), cell_type] = value


In [None]:
# Graph visualization related
layout           = 'fa'  # ‘fa’, ‘fr’, ‘rt’, ‘rt_circular’, ‘drl’, ‘eq_tree’
random_state     = 0
threshold        = 0.74

# Figure related
node_size_scale  = 7
node_size_power  = 0.5
edge_width_scale = .05
fontsize         = 15
fontoutline      = 2

cmap = sns.diverging_palette(250, 20, as_cmap=True)

sns.set_theme(style='white')
fig = plt.figure(figsize=(30,10))

# One PAGA panel per cell type: (panel title, obs column used to colour the nodes).
enrichment_panels = [
    ('Cell Neoplastic\nEnrichment',   'cell neoplastic'),
    ('Cell Inflammatory\nEnrichment', 'cell inflammatory'),
    ('Cell Dead\nEnrichment',         'cell dead'),
]
for panel_position, (panel_title, obs_column) in enumerate(enrichment_panels, start=1):
    ax = fig.add_subplot(1, 3, panel_position)
    ax.set_title(panel_title, fontweight='bold', fontsize=20)
    sc.pl.paga(adata_train, layout=layout, random_state=random_state, threshold=threshold,
               node_size_scale=node_size_scale, node_size_power=node_size_power,
               edge_width_scale=edge_width_scale, fontsize=fontsize, fontoutline=fontoutline,
               frameon=False, show=False, ax=ax, color=obs_column, cmap=cmap, colorbar=False)

plt.show()


In [None]:
# Give every tile the annotation of its HPC, one obs column per annotation feature.
# Missing annotations (values whose string form is 'nan') are written as 'N/A'.
for feature in annotations.columns:
    per_tile_values = list()
    for hpc_id in adata_train.obs[groupby].values.astype(int):
        entry = annotations.loc[hpc_id, feature]
        per_tile_values.append('N/A' if str(entry)=='nan' else entry)
    adata_train.obs[feature] = np.array(per_tile_values).astype(str)
annotations.head(5)


In [None]:
def fix_umap_annotations(ax1):
    """Nudge selected on-data labels of the UMAP so that they do not overlap.

    The offsets are hand-tuned constants. Each entry addresses one element of
    ``ax1.texts`` by index; the original notes refer to index ``i`` as "HPC i"
    (presumably the texts are in cluster order - confirm against the plot).

    Args:
        ax1: Matplotlib axes whose ``texts`` list holds the labels to move.

    Raises:
        IndexError: If ``ax1.texts`` has fewer than 42 entries.
    """
    # (text index, coordinate, operation, amount)
    #   'scale' multiplies the coordinate by amount, 'shift' adds amount to it.
    nudges = (
        (15, 'y', 'scale', 1.02),
        (14, 'y', 'scale', 1.022),
        (3,  'y', 'scale', 0.975),
        (13, 'y', 'scale', 1.5),
        (12, 'y', 'scale', 0.8),
        (21, 'y', 'scale', 1.02),
        (11, 'y', 'scale', 0.98),
        (39, 'y', 'scale', 0.97),
        (41, 'x', 'scale', 1.1),
        (9,  'y', 'scale', 1.02),
        (40, 'y', 'scale', 0.95),
        (27, 'y', 'scale', 0.8),
        (24, 'x', 'shift', -0.5),
        (35, 'y', 'shift', -0.2),
    )

    for text_index, coordinate, operation, amount in nudges:
        label = ax1.texts[text_index]
        old_x, old_y = label.get_position()
        old_value = old_x if coordinate == 'x' else old_y
        new_value = old_value * amount if operation == 'scale' else old_value + amount
        # Use the public setters (rather than writing the private _x/_y fields) so
        # the artist is marked as needing a redraw.
        if coordinate == 'x':
            label.set_x(new_value)
        else:
            label.set_y(new_value)

# Graph visualization related
layout           = 'fa'  # ‘fa’, ‘fr’, ‘rt’, ‘rt_circular’, ‘drl’, ‘eq_tree’
random_state     = 0
threshold        = 0.74

# Figure related
node_size_scale  = 7
node_size_power  = 0.5
edge_width_scale = .05
fontoutline = 4
marker_size = 2

fontsize       = 20
fontsize_title = 22

sns.set_theme(style='white')

# Figure 1 - Leiden clusters (HPCs), labelled on the data and nudged by fix_umap_annotations.
print('UMAP_leiden')
fig = plt.figure(figsize=(10,10))
# One tab20 colour per distinct cluster value.
colors = sns.color_palette('tab20', len(np.unique(adata_train.obs[groupby].values)))
ax1  = fig.add_subplot(1, 1, 1)
ax1 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, legend_loc='on data', frameon=False, show=False, ax=ax1, color=groupby, palette=colors, size=marker_size)
# adjust_text(ax1.texts)
ax1.set_title('LUAD\nHistomorphological Phenotype Clusters', fontweight='bold', fontsize=fontsize_title)
fix_umap_annotations(ax1)
plt.show()

# Figure 2 - Tissue morphologies.
print('UMAP_tissue_morphologies')
fig = plt.figure(figsize=(10,10))
feature_3 = 'Tissue Morphologies'
colors = sns.color_palette('Set1', len(np.unique(adata_train.obs[feature_3].values)))
# Hand-picked colour per category; this list has exactly six entries.
colors = ['Grey', colors[5], colors[0], 'purple', colors[2], colors[1]]
ax3  = fig.add_subplot(1, 1, 1)
ax3 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, legend_loc='right margin', frameon=False, show=False, ax=ax3, color=feature_3, palette=colors, size=marker_size)
adjust_text(ax3.texts)
ax3.set_title('Tissue Morphologies', fontweight='bold', fontsize=fontsize_title)
# Rebuild the legend with the entries in a hand-chosen order.
handles, labels = ax3.get_legend_handles_labels()
handles = [handles[5], handles[4], handles[2], handles[3], handles[1], handles[0]]
labels  = [labels[5],  labels[4],  labels[2],  labels[3],  labels[1],  labels[0]]
ax3.legend(handles, labels, loc='upper right', frameon=False, bbox_to_anchor=(1.52,0.75))
for text in ax3.legend_.get_texts():
    text.set_size(fontsize)
    text.set_fontweight('bold')
plt.show()

# Figure 3 - Epithelium/stroma ratio.
print('UMAP_epth_stroma_ratio')
fig = plt.figure(figsize=(10,10))
feature_2 = 'Epithelium Stroma Ratio'
# NOTE: the palette is sized from feature_3 (still 'Tissue Morphologies' at this point), not
# from feature_2. The next line indexes colors[4], so the palette needs at least five entries.
colors = sns.color_palette('Set1', len(np.unique(adata_train.obs[feature_3].values)))
colors = [colors[0], colors[1], colors[4], colors[2]]
ax2  = fig.add_subplot(1, 1, 1)
ax2 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, legend_loc='right margin', frameon=False, show=False, ax=ax2, color=feature_2, palette=colors, size=marker_size)
adjust_text(ax2.texts)
ax2.set_title('Epithelium Stroma Ratio', fontweight='bold', fontsize=fontsize_title)
# Rebuild the legend with the entries in a hand-chosen order.
handles, labels = ax2.get_legend_handles_labels()
handles = [handles[1], handles[2], handles[0], handles[3]]
labels  = [labels[1],  labels[2],  labels[0],  labels[3]]
ax2.legend(handles, labels, loc='center right', frameon=False, bbox_to_anchor=(1.3,0.55))
for text in ax2.legend_.get_texts():
    text.set_size(fontsize)
    text.set_fontweight('bold')
plt.show()

# Figure 4 - Lymphocytic infiltration (feature_3 is re-bound to that column here).
print('UMAP_inflammation')
fig = plt.figure(figsize=(10,10))
feature_3 = 'Lymphocytic Infiltration'
# Diverging palette with one colour fewer than there are categories; 'Grey' is inserted by hand below.
colors = sns.diverging_palette(250, 20, n=len(np.unique(adata_train.obs[feature_3].values))-1, center='light')
colors = [colors[1], colors[2], 'Grey', colors[3], colors[0]]
ax3  = fig.add_subplot(1, 1, 1)
ax3 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, legend_loc='right margin', frameon=False, show=False, ax=ax3, color=feature_3, palette=colors, size=marker_size)
ax3.set_title('Lymphocytic Infiltration', fontweight='bold', fontsize=fontsize_title)
# Rebuild the legend with the entries in a hand-chosen order.
handles, labels = ax3.get_legend_handles_labels()
handles = [handles[3], handles[1], handles[0], handles[-1], handles[2]]
labels  = [labels[3],  labels[1],  labels[0],  labels[-1],  labels[2]]
ax3.legend(handles, labels, loc='center right', frameon=False, bbox_to_anchor=(1.2,0.5))
for text in ax3.legend_.get_texts():
    text.set_size(fontsize)
    text.set_fontweight('bold')
plt.show()


In [None]:
# Graph visualization related
layout           = 'fa'  # ‘fa’, ‘fr’, ‘rt’, ‘rt_circular’, ‘drl’, ‘eq_tree’
random_state     = 0
threshold        = 0.74

# Figure related
node_size_scale  = 7
node_size_power  = 0.5
edge_width_scale = .05
fontsize    = 15
fontoutline = 4
marker_size = 6
only_seleted = []  # not used anywhere in this cell


# Colour scale shared by the three panels: sequential 'Reds' when the smallest value
# over all cell-type columns is exactly 0, otherwise a diverging map on [-vmax, vmax].
cmap = sns.color_palette("Reds", as_cmap=True)
vmax = np.max(adata_train.obs[cell_types].to_numpy())
vmin = np.min(adata_train.obs[cell_types].to_numpy())
if vmin != 0:
    vmin = -vmax
    # cmap = sns.color_palette("vlag", as_cmap=True)
    cmap = sns.diverging_palette(250, 20, as_cmap=True)

sns.set_theme(style='white')
fig = plt.figure(figsize=(30,10))

# Left panel, first pass: draw the Leiden clusters with on-data labels only to obtain
# the (nudged) label texts, which are then copied onto the other panels. The panel is
# cleared and redrawn with the neoplastic enrichment at the end of the cell.
ax1  = fig.add_subplot(1, 3, 1)
ax1 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, legend_loc='on data', frameon=False, show=False, ax=ax1, color=groupby, cmap=cmap, vmin=vmin, vmax=vmax, size=marker_size)
fix_umap_annotations(ax1=ax1)
prev_texts = ax1.texts

# Middle panel (ax2) - Inflammatory; its colorbar is removed.
ax2  = fig.add_subplot(1, 3, 2)
ax2 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, frameon=False, show=False, ax=ax2, color='cell inflammatory', cmap=cmap, legend_loc=None, vmin=vmin, vmax=vmax, size=marker_size)
ax2.set_title('Cell Inflammatory\nEnrichment', fontweight='bold', fontsize=20)
ax2.collections[-1].colorbar.remove()
# Re-create every cluster label on this panel with the same position and text properties.
for a in prev_texts:
    ax2.annotate(a._text, xy=(a._x,a._y), color=a._color, verticalalignment=a._verticalalignment, horizontalalignment=a._horizontalalignment,
                 fontproperties=a._fontproperties, linespacing=a._linespacing, path_effects=a._path_effects)

# Right panel (ax3) - Dead; this is the only panel that keeps its colorbar.
ax3  = fig.add_subplot(1, 3, 3)
ax3 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, frameon=False, show=False, ax=ax3, color='cell dead', cmap=cmap, vmin=vmin, vmax=vmax, size=marker_size)
ax3.set_title('Cell Dead\nEnrichment', fontweight='bold', fontsize=20)
cbar = ax3.collections[-1].colorbar
for a in prev_texts:
    ax3.annotate(a._text, xy=(a._x,a._y), color=a._color, verticalalignment=a._verticalalignment, horizontalalignment=a._horizontalalignment,
                 fontproperties=a._fontproperties, linespacing=a._linespacing, path_effects=a._path_effects)
# From here on the labels are taken from ax3.
# NOTE(review): presumably because ax1 is cleared below, which would drop its own texts - confirm.
prev_texts = ax3.texts

# Colorbar tick labels.
cbar.ax.tick_params(labelsize=fontsize*0.9)
[label.set_fontweight('bold') for label in cbar.ax.get_yticklabels()]

# Left panel (ax1), second pass - Neoplastic; replaces the temporary Leiden plot, colorbar removed.
ax1.clear()
ax1 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, frameon=False, show=False, ax=ax1, color='cell neoplastic', cmap=cmap, vmin=vmin, vmax=vmax, size=marker_size)
ax1.collections[-1].colorbar.remove()
ax1.set_title('Cell Neoplastic\nEnrichment', fontweight='bold', fontsize=20)
for a in prev_texts:
    ax1.annotate(a._text, xy=(a._x,a._y), color=a._color, verticalalignment=a._verticalalignment, horizontalalignment=a._horizontalalignment,
                 fontproperties=a._fontproperties, linespacing=a._linespacing, path_effects=a._path_effects)
plt.show()


In [None]:
# Groups of HPCs to highlight; one UMAP figure is drawn per group.
mark_set = [
    ('UMAP_leiden_A', [2, 4, 10, 13, 32, 36]),
    ('UMAP_leiden_B', [3, 14, 16, 20, 22, 23, 31, 34, 40, 42, 42]),
    ('UMAP_leiden_C', [8, 21, 0, 26, 5, 15, 25, 29, 38, 18, 45, 19, 24, 30, 35, 28, 37]),
    ('UMAP_leiden_D', [6, 11, 12, 27, 33, 39, 41]),
    ('UMAP_leiden_E', [1,7,9,12,17,41,44]),
]

for name, mark in mark_set:

    print(name)
    fig = plt.figure(figsize=(10,10))

    # tab20 colour per cluster; every cluster outside the highlighted group is greyed out.
    colors = sns.color_palette('tab20', len(np.unique(adata_train.obs[groupby].values)))
    for hpc_id in np.unique(adata_train.obs[groupby].values.astype(int)):
        if hpc_id in mark:
            continue
        colors[hpc_id] = 'Grey'

    ax1 = fig.add_subplot(1, 1, 1)
    ax1 = sc.pl.umap(adata_train, legend_fontsize=fontsize, legend_fontoutline=fontoutline, legend_loc='on data',
                     frameon=False, show=False, ax=ax1, color=groupby, palette=colors, size=5)
    ax1.set_title('LUAD\nHistomorphological Phenotype Clusters', fontweight='bold', fontsize=fontsize_title)
    fix_umap_annotations(ax1)

    # Per-tile marker size: large for tiles of highlighted HPCs, small for the rest.
    highlighted = adata_train.obs[groupby].astype(int).isin(mark)
    ax1._children[0]._sizes = np.array([10 if flag else 0.75 for flag in highlighted])
    plt.show()


# Paper Figure - UMAP Patient vector representations

In [None]:
frames = build_cohort_representations(meta_folder, meta_field, matching_field, groupby, fold_number, folds_pickle, h5_complete_path, h5_additional_path, 'clr', 100)
complete_df, additional_complete_df, frame_clusters, frame_samples, features = frames

# Feature columns: every column of complete_df except the label and identifier columns.
columns = [col for col in complete_df.columns if col != 'luad' and col != 'samples' and col != 'slides']

# TCGA cohort: the last column is the label; features start after two leading columns.
# NOTE(review): the `1:` row slice skips the first row of each frame - confirm this is intended.
labels = complete_df.to_numpy()[1:,-1]
data   = complete_df.to_numpy()[1:,2:-1]
df     = pd.DataFrame(data, columns=columns)
df['Lung Type'] = labels
df['Cohort']       = 'TCGA'

# Additional (NYU) cohort: features start after a single leading column, so the slice
# is 1:-1 here rather than 2:-1. It must yield exactly len(columns) feature columns.
labels_add = additional_complete_df.to_numpy()[1:,-1]
data_add   = additional_complete_df.to_numpy()[1:,1:-1]
df_add     = pd.DataFrame(data_add, columns=columns)
df_add['Lung Type'] = labels_add
df_add['Cohort']       = 'NYU'

# Both cohorts stacked row-wise.
df_all = pd.concat([df, df_add], axis=0)

In [None]:
# Figure settings.
scatter_size    = 200

figsize         = (10,10)
fontsize_labels = 30
fontsize_legend = 30
l_markerscale   = 5
l_box_w         = 2
lw              = 2

# UMAP settings.
min_dist     = 0.0
n_components = 2
n_neighbors  = 25
metric       = 'euclidean'

print(metric, n_neighbors)
# UMAP: embed the patient vectors of both cohorts and keep the two coordinates as columns.
fit = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, metric=metric)
u   = fit.fit_transform(df_all[columns])
df_all['UMAP Dim. 0'], df_all['UMAP Dim. 1'] = u[:, 0], u[:, 1]

fig = plt.figure(figsize=figsize)
ax  = fig.add_subplot(1, 1, 1)

# Scatter plot: colour encodes the lung type, marker shape encodes the cohort.
cohort_markers = {'TCGA':'v', 'NYU':'s'}
sns.scatterplot(data=df_all, x='UMAP Dim. 0', y='UMAP Dim. 1', hue='Lung Type', style='Cohort', markers=cohort_markers, s=scatter_size, ax=ax)
ax.set_xlabel('UMAP Dim. 0', fontsize=fontsize_labels)
ax.set_ylabel('UMAP Dim. 1', fontsize=fontsize_labels)
ax.set_title('Patient\nVector Representations',  fontsize=fontsize_labels, fontweight='bold')
ax.tick_params(axis='both', which='major', labelsize=fontsize_labels)

# Legend: entries 1 and 2 are renamed, entries 0 and 3 are enlarged.
legend = ax.legend(loc='upper left', markerscale=l_markerscale, prop={'size': fontsize_legend-5}, ncol=2)
legend_texts = legend.get_texts()
legend_texts[1].set_text('LUSC')
legend_texts[2].set_text('LUAD')
legend_texts[0].set_size(fontsize_legend)
legend_texts[3].set_size(fontsize_legend)
legend.get_frame().set_linewidth(l_box_w)

# Bold tick labels on both axes (x first, then y).
for plot_axis in (ax.xaxis, ax.yaxis):
    for tick in plot_axis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')

# Axis labels are set again, this time in bold.
ax.set_xlabel('UMAP Dim. 0', fontsize=fontsize_labels, fontweight='bold')
ax.set_ylabel('UMAP Dim. 1', fontsize=fontsize_labels, fontweight='bold')
for side in ['top','bottom','left','right']:
    ax.spines[side].set_linewidth(4)

plt.tight_layout()
plt.show()


# Paper Figure - HPC Generalization across Institutions and Patients

In [None]:
# Read tile vector representations.
folds = load_existing_split(folds_pickle)
fold  = folds[fold_number]
dataframes, complete_df, leiden_clusters = read_csvs(adatas_path, matching_field, groupby, fold_number, fold, h5_complete_path, h5_additional_path, additional_as_fold=False, force_fold=fold_number)

# Concatenate all data. The sample id is the first three '-' separated parts of the
# slide id, and the TSS code is the second part of the sample id.
tiles_df = pd.concat([dataframes[0], dataframes[1], dataframes[2]])
tiles_df['samples']  = tiles_df['slides'].apply(lambda slide: '-'.join(slide.split('-')[:3]))
tiles_df['TSS Code'] = tiles_df['samples'].apply(lambda sample: sample.split('-')[1]).values.astype(str)

# Include counts per samples and HPC: number of tiles per group, merged back as a column.
for name, field in [('sample',matching_field), ('hpc', groupby)]:
    count_column     = 'nt_per_%s'%name
    counts_per_field = tiles_df.groupby(field).count().reset_index().rename(columns={'tiles':count_column})
    tiles_df         = tiles_df.merge(counts_per_field[[field, count_column]], on=field)

# Normalized values of percentage of total tiles in HPC
tiles_df.insert(loc=len(tiles_df.columns), column='nt_per_hpc_norm', value=tiles_df['nt_per_hpc'].values/tiles_df.shape[0])

# Normalize contribution of HPC in patient: tiles of the HPC in the patient divided by
# all tiles of the patient.
hpc_pat = tiles_df[[matching_field, groupby, 'tiles']].groupby([matching_field, groupby]).count()
hpc_pat = hpc_pat.reset_index().rename(columns={'tiles':'nt_per_sample_hpc'})
tiles_per_sample = tiles_df[[matching_field, 'nt_per_sample']].drop_duplicates()
hpc_pat = hpc_pat.merge(tiles_per_sample, on=matching_field).drop_duplicates()
hpc_pat['nt_per_sample_hpc_norm'] = np.divide(hpc_pat['nt_per_sample_hpc'].values.astype(float), hpc_pat['nt_per_sample'].values.astype(float))


In [None]:
def plot_institution_distribution(data_hpc_inst, unique_values, unique_label, field, title, figsize=(30,7), fontsize_labels=22, fontsize_legend=20, show_max_min=False):
    """Plot a per-HPC percentage bar chart next to a single 'total' bar.

    The wide left panel shows one bar per HPC with height ``data_hpc_inst[field]``
    on a fixed 0-105 axis, coloured from the 'Greens_d' palette according to the
    value, with dashed reference lines at 25 and 50. The narrow right panel shows
    one bar of height ``unique_values`` labelled ``unique_label``.

    Args:
        data_hpc_inst: DataFrame with an 'HPC' column and a ``field`` column.
        unique_values: Height of the single bar in the right panel.
        unique_label: X label of the right panel.
        field: Name of the column plotted on the y axis of the left panel.
        title: Title of the left panel.
        figsize: Figure size passed to ``plt.subplots``.
        fontsize_labels: Font size of the axis and tick labels; the title uses 1.3x this.
        fontsize_legend: Not used by the function; kept so existing calls keep working.
        show_max_min: When True, also draw dashed lines at the minimum and maximum of ``field``.
    """
    def colors_from_values(values, palette_name, normalize=False):
        # `normalize` is accepted but not used.
        # Normalize the values to range [0, 1]; values whose maximum is <= 1 are used as they are.
        if values.max() > 1:
            lowest  = min(values)
            span    = max(values) - lowest
            # All values equal: avoid dividing by zero and map them to the top of the scale.
            normalized = (values - lowest) / span if span else np.ones(len(values))
        else:
            normalized = values
        # Convert to indices.
        indices = np.round(normalized * (len(values) - 1)).astype(np.int32)
        # The palette holds 1.5x as many colours as there are values, so only its first
        # two thirds can be selected by the indices.
        palette = sns.color_palette(palette_name, int(1.5*len(values)))
        return np.array(palette).take(indices, axis=0)

    # plt.subplots creates the figure itself; a separate plt.figure() call beforehand
    # would only leave an extra, empty figure behind.
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'width_ratios': [25, 1]})

    # Prepend 100 before computing the colours so the colour scale always extends up to
    # at least 100; the colour of that extra entry is dropped again.
    y = data_hpc_inst[field].values
    y_color = np.array([100] + y.tolist())
    color_palette = colors_from_values(y_color, "Greens_d")[1:]

    sns.barplot(data=data_hpc_inst, x='HPC', y=field, palette=color_palette, ax=ax0)

    ax0.tick_params(axis='x', rotation=90)
    ax0.set_ylim([0,105])
    yticks = (np.array(range(0,110,10))).tolist()
    ax0.set_yticks(yticks, yticks)

    ax0.set_title(title,  fontsize=fontsize_labels*1.3, fontweight='bold')
    ax0.set_xlabel('\nHistomorphological Phenotype Cluster (HPC)', fontsize=fontsize_labels,     fontweight='bold')
    ax0.set_ylabel(' ', fontsize=fontsize_labels, fontweight='bold')
    if show_max_min:
        max_val = np.max(data_hpc_inst[field].values)
        min_val = np.min(data_hpc_inst[field].values)
        ax0.axhline(max_val, linestyle='--')
        ax0.axhline(min_val, linestyle='--')
    ax0.axhline(50, linestyle='--', color='black')
    ax0.axhline(25, linestyle='--', color='black')

    # Right panel: a single bar with the total count; its tick values are shown as integers.
    sns.barplot(y=[unique_values], palette='Blues_r', ax=ax1)
    ax1.set_xlabel(unique_label, fontsize=fontsize_labels,     fontweight='bold')
    yticks = ax1.get_yticks().tolist()
    yticks = np.array(yticks).astype(int)
    ax1.set_yticks(yticks, yticks)

    # Same bold tick labels and thick frame on both panels.
    for ax in [ax0, ax1]:
        for tick in ax.xaxis.get_major_ticks():
            tick.label1.set_fontsize(fontsize_labels)
            tick.label1.set_fontweight('bold')
        for tick in ax.yaxis.get_major_ticks():
            tick.label1.set_fontsize(fontsize_labels)
            tick.label1.set_fontweight('bold')
        for axis in ['top','bottom','left','right']:
            ax.spines[axis].set_linewidth(4)

    plt.tight_layout()
    plt.show()
    
# Threshold minimum contribution for patient or institution.
threshold = 0.005

# Patients: keep (patient, HPC) pairs where the HPC makes up at least 0.5% of the
# patient's tiles, then express the number of patients per HPC as a percentage of all patients.
n_patients = len(np.unique(hpc_pat[matching_field]))
hpc_pat_th = hpc_pat[hpc_pat.nt_per_sample_hpc_norm >= threshold]
hpc_pat_th = hpc_pat_th.groupby([groupby]).count()['samples']/n_patients*100
hpc_pat_th = hpc_pat_th.reset_index().rename(columns={groupby:'HPC'})

# Institutions: total number of tiles per TSS code.
tiles_df = pd.concat([dataframes[0], dataframes[1], dataframes[2]])
tiles_df['TSS Code'] = tiles_df['samples'].apply(lambda sample: sample.split('-')[1]).values.astype(str)
tss_hpc = tiles_df[['TSS Code', 'tiles']].groupby('TSS Code').count()
tss_hpc = tss_hpc.reset_index().rename(columns={'tiles':'total_tiles'})

# For every HPC: percentage of institutions in which the HPC makes up at least 0.5%
# of the institution's tiles.
institution_rows = list()
for hpc in np.unique(tiles_df[groupby]):
    hpc_df = tiles_df[tiles_df[groupby]==hpc].groupby('TSS Code').count()
    hpc_df = hpc_df.reset_index()[['TSS Code', 'tiles']]
    hpc_df = hpc_df.merge(tss_hpc, on='TSS Code', how='inner')
    hpc_df.insert(len(hpc_df.columns), 'tiles_norm', np.divide(hpc_df['tiles'].values,hpc_df['total_tiles'].values))
    hpc_df = hpc_df[hpc_df['tiles_norm']>=threshold]
    institution_rows.append((hpc, hpc_df.shape[0]/tss_hpc.shape[0]*100))
data_hpc_inst = pd.DataFrame(institution_rows, columns=['HPC', 'Percentage of Institutions in HPC'])

# Both figures share the same size and font settings.
shared_plot_kwargs = dict(figsize=(30,10), fontsize_labels=35, fontsize_legend=32)
plot_institution_distribution(data_hpc_inst, unique_values=tiles_df['TSS Code'].unique().shape[0], unique_label='Total\nInstitutions',
                              field='Percentage of Institutions in HPC', title='Percentage of TCGA institutions\npresent in the HPC',
                              **shared_plot_kwargs)
plot_institution_distribution(hpc_pat_th, unique_values=tiles_df['samples'].unique().shape[0], unique_label='Total\nPatients',
                              field='samples', title='Percentage of TCGA patients\npresent in the HPC',
                              **shared_plot_kwargs)

# Paper Figure - HPC Summary Samples

In [None]:
def cluster_set_images(leiden_clusters_order, hpc_index_map, data_dicts, groupby, fontsize_title, fontsize_label, batches=1, ncols=10, nrows=5, figsize=(30, 18), annotations=None, width=None, main_cluster_path=None):
    """Draw one exemplar tile per HPC on a grid, optionally captioned with its annotation.

    Parameters:
        leiden_clusters_order: cluster ids in plotting order (row-major over the grid). Ids
            beyond ncols*nrows are silently ignored because they are zipped with the axes.
        hpc_index_map: dict keyed by cluster id; each value holds 'index' and 'original_set',
            used as data_dicts[original_set][int(index)].
        data_dicts: mapping from set name to an indexable collection of images. Pixel values
            are divided by 255 before display (assumes 0-255 images -- TODO confirm).
        groupby: unused here; kept for signature compatibility.
        fontsize_title, fontsize_label: font sizes of the 'HPC <id>' title and the caption.
        batches: number of times the figure is produced. Every iteration renders the same grid
            (and writes the same file when saving).
        ncols, nrows, figsize: grid layout passed to plt.subplots.
        annotations: optional DataFrame indexed by cluster id with a 'Summary' column; the
            summary is wrapped and written under each tile.
        width: wrap width for the summary. None falls back to textwrap's default of 70.
        main_cluster_path: if given, the figure is saved there as 'Tiles_annotated.png' and
            closed; otherwise it is shown.
    """
    import textwrap

    # textwrap.wrap compares the width with 0, so width=None would raise a TypeError.
    wrap_width = 70 if width is None else width

    for batch in range(batches):
        # squeeze=False keeps axs two-dimensional even for a single row/column/cell grid.
        fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)
        axes_list = list(axs.flatten())

        for ax, cluster_id in zip(axes_list, leiden_clusters_order):
            index        = hpc_index_map[cluster_id]['index']
            original_set = hpc_index_map[cluster_id]['original_set']
            ax.imshow(data_dicts[original_set][int(index)]/255.)
            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top','bottom','left','right']:
                ax.spines[axis].set_linewidth(4)

            # Caption: every wrapped line starts on a new line, followed by one trailing newline.
            caption = ''
            if annotations is not None:
                lines   = textwrap.wrap(annotations.loc[cluster_id,'Summary'], wrap_width, break_long_words=False)
                caption = ''.join('\n%s' % s for s in lines) + '\n'

            ax.set_title('HPC %s' % cluster_id, fontweight='bold', fontsize=fontsize_title)
            ax.set_xlabel(caption, fontweight='bold', fontsize=fontsize_label)
            ax.xaxis.set_label_coords(0.5,0.075)

        plt.subplots_adjust(wspace=0.05, hspace=0.2)
        fig.tight_layout()
        if main_cluster_path is not None:
            plt.savefig(os.path.join(main_cluster_path, 'Tiles_annotated.png'), dpi=300)
            plt.close()
        else:
            plt.show()


# HPC annotation table and the folder holding one backtrack CSV per HPC ('set_<hpc>_train.csv').
csv_annotations_path = '/media/adalberto/Disk2/PhD_Workspace/utilities/files/LUAD/LUAD_HPC_annotations.csv'
csv_backtrack_path   = '/media/adalberto/Disk2/PhD_Workspace/utilities/files/LUAD/HPC_annotations/backtrack'
annotations          = pd.read_csv(csv_annotations_path)
annotations          = annotations.set_index('HPC')
# Normalise annotation values for display. The order of the replace calls matters: the first
# call maps 'other predominant tissue' to 'no epithelium', which the second call capitalises.
annotations          = annotations.replace({'other predominant tissue':'no epithelium', 'very sparse':'Very Sparse', 'severe':'Severe', 'moderate':'Moderate', 'mild':'Mild'})
annotations          = annotations.replace({'more stroma':'More Stroma', 'more epithelium':'More Epithelium', 'no epithelium':'No Epithelium', 'roughly equal':'Roughly Equal'})

annotations          = annotations.replace({'malignant epithelium':'Malignant Epithelium', 'elastosis or collagenosis':'Elastosis/Collagenosis',
       'near-normal lung':'Near-normal Lung', 'reactive lung changes':'Reactive Lung Changes', 'necrosis':'Necrosis',
       'other connective tissue':'Connective Tissue', 'vessels':'Vessels', 'airway':'Airway', 'cartilage':'Cartilage'})

# For every HPC, resolve the annotated exemplar tile to its (index, original_set) pair.
hpc_index_map = dict()
for hpc in leiden_clusters:
    hpc_index_map[hpc] = dict()
    csv_path = os.path.join(csv_backtrack_path, 'set_%s_train.csv' % hpc)
    hpc_df = pd.read_csv(csv_path)
    # 'Exemplar Tiles' holds two comma-separated integers.
    x,y = annotations.loc[hpc,'Exemplar Tiles'].split(',')
    # indexes_hpc      =  hpc_df.indexes.values.reshape((20,5))
    # original_set_hpc =  hpc_df.original_set.values.reshape((20,5))
    # print(hpc, x, y, indexes_hpc.shape)
    indexes_hpc      =  hpc_df.indexes.values
    original_set_hpc =  hpc_df.original_set.values
    # Flat position of the exemplar in the backtrack CSV.
    # NOTE(review): the branch on y < 5 switches between 20*y + x and 20*x + y; presumably the
    # annotators used a 20-wide sheet and the coordinate order was not consistent -- confirm.
    if int(y) < 5:
        index = 20*int(y) + int(x)
    else:
        index = 20*int(x) + int(y)
    selected_index = indexes_hpc[index]
    selected_set   = original_set_hpc[index]
    hpc_index_map[hpc]['index']        = selected_index
    hpc_index_map[hpc]['original_set'] = selected_set

fontsize_title = 24
fontsize_label = 20

# Plotting order, laid out as the 8-column x 6-row grid requested below (48 entries).
# NOTE(review): 10, 12 and 27 each appear twice and 30 does not appear although every other
# id in 0-45 does -- confirm whether the repeats are intentional.
leiden_clusters_order = [36, 32, 13, 3,  16, 34, 40, 31,
                         10, 2,  4,  43, 23, 14, 22, 42,
                         21, 0,  45, 19, 28, 15, 5,  20, 
                         8,  26, 18, 24, 37, 25, 38, 29, 
                         10, 35, 12, 27, 9,  1,  7,  27,
                         11, 39, 6,  41, 12, 44, 17, 33,]

sns.set_theme(style='white')
# main_cluster_path=None shows the figure instead of saving it.
cluster_set_images(leiden_clusters_order, hpc_index_map=hpc_index_map, data_dicts=img_dicts, groupby=groupby,
                    fontsize_title=fontsize_title, fontsize_label=fontsize_label, batches=1, ncols=8, nrows=6, figsize=(25, 27), annotations=annotations, width=17, main_cluster_path=None)
                    

# Paper Figure - C-Index across resolutions

In [None]:
# CSV path to cox results.
csv_path = '%s/results/BarlowTwins_3/TCGA_LUAD_2.016umpx_60Bkg_split4/h224_w224_n3_zdim128/v01_SOTA/OS_c_index_v01_SOTA_l1_ratio_0.0_mintiles_0.csv' % main_path

# Alpha value whose per-fold results are plotted.
selected_alpha = 1.2067926406393288

# Keep only the rows of the selected alpha.
results_df = pd.read_csv(csv_path)
resolutions = results_df.resolution.unique()
# Compare with a tolerance: a float parsed from text is not guaranteed to be bit-identical to
# the literal above, and an exact == would then silently select no rows.
results_subset_df = results_df[np.isclose(results_df.alpha, selected_alpha)]

# Reshape to long format: one row per (resolution, fold, cohort) with its c-index.
# The positional unpacking assumes the CSV columns are ordered as
# resolution, alpha, fold, TCGA train, TCGA test, NYU additional.
# NOTE(review): this loop rebinds the notebook-level `resolution` (also `alpha` and `fold`);
# cells run afterwards see the value of the last row -- confirm nothing depends on it.
all_data = list()
for row in results_subset_df.iterrows():
    resolution, alpha, fold, tcga_train, tcga_test, nyu_additional = row[1].values
    all_data.append((resolution, fold, 'TCGA Train', tcga_train))
    all_data.append((resolution, fold, 'TCGA Test', tcga_test))
    all_data.append((resolution, fold, 'NYU Cohort', nyu_additional))
    
all_data = pd.DataFrame(all_data, columns=['Resolution', 'Fold', 'Set', 'C-Index'])

fontsize_labels = 14
lw = 3
l_box_w = 3

for x_label in [ 'Resolution']:
    sns.set_theme(style='darkgrid')
    fig, ax = plt.subplots(figsize=(20, 7), nrows=1, ncols=1)
    # Point estimate per cohort and resolution with a 95% confidence interval across folds.
    sns.pointplot(x=x_label, hue='Set', y='C-Index', data=all_data, ax=ax, dodge=.3, join=False, capsize=.00, markers='o', errorbar=('ci', 95))
    ax.set_ylim([0.4, 0.8])
    ax.set_title('LUAD Overall Survival', fontweight='bold', fontsize=18)
    ax.legend(loc='upper left')
    start, end = ax.get_ylim()
    ax.yaxis.set_ticks(np.arange(start, end, 0.05))
    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')
    for tick in ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize_labels)
        tick.label1.set_fontweight('bold')

    for axis in ['top','bottom','left','right']:
        ax.spines[axis].set_linewidth(4)

    ax.set_ylabel('Concordance Index', fontweight='bold', size=fontsize_labels+2)
    # Only 'Resolution' is iterated above, so the first branch is currently never taken.
    if x_label == 'Number HPCs':
        ax.set_xlabel('Number of HPCs', fontweight='bold', size=fontsize_labels+2)
    else:
        ax.set_xlabel('Leiden Resolution Parameter', fontweight='bold', size=fontsize_labels+2)


    # Restyle the legend created above: thicker handles and frame, bold labels.
    legend = ax.legend_
    for line in legend.get_lines():
        line.set_linewidth(lw)
    legend.get_frame().set_linewidth(l_box_w)
    for i in range(len(legend.get_texts())):
        legend.get_texts()[i].set_fontweight('bold')
        legend.get_texts()[i].set_fontsize(fontsize_labels)
    plt.show()


# Paper Figure - HPC Samples

In [None]:
# Workspace path.
main_path = '/media/adalberto/Disk2/PhD_Workspace'

dataset            = 'TCGAFFPE_LUADLUSC_5x_60pc'
additional_dataset = 'NYUFFPE_survival_5x_60pc'

# Both cohorts are opened with the same tile settings.
data_kwargs = dict(marker='he', patch_h=224, patch_w=224, n_channels=3, batch_size=64, project_path=main_path, load=True)

# Main cohort: image collections of the training, validation and test sets.
data = Data(dataset=dataset, **data_kwargs)
img_dicts = {
    'train': data.training.images,
    'valid': data.validation.images,
    'test': data.test.images,
}

# Additional cohort: only its training set is used.
additional_data = Data(dataset=additional_dataset, **data_kwargs)
additional_img_dicts = {'train': additional_data.training.images}

In [None]:
# Load the stored fold split and read the cluster assignments for the selected fold.
folds = load_existing_split(folds_pickle)
dataframes, complete_df, leiden_clusters = read_csvs(adatas_path, matching_field, groupby, fold_number, folds[fold_number], h5_complete_path, h5_additional_path)
# dataframes unpacks into the train, validation and test frames plus the additional-cohort frame.
train_df, valid_df, test_df, additional_df = dataframes


In [None]:
def cluster_set_images(review_clusters, frame, data_dicts, groupby, batches=1, ncols=20, nrows=4, annotated=False, figures_path=None):
    """Show or save a random sample of tiles for each cluster in review_clusters.

    For every cluster, up to 100*batches rows of `frame` are drawn at random, sorted, and
    plotted in figures made of a grid of small tiles plus one large tile that spans the last
    four grid columns.

    Parameters:
        review_clusters: iterable of cluster ids to plot.
        frame: DataFrame with the `groupby` column plus 'indexes' and 'original_set'.
        data_dicts: mapping from set name to an indexable collection of images, read as
            data_dicts[original_set][int(index)]. Pixels are divided by 255 before display
            (assumes 0-255 images -- TODO confirm).
        groupby: name of the cluster column in `frame`.
        batches: number of figures per cluster; figure b uses sampled tiles b*100 onwards.
        ncols, nrows: grid size. The last four columns are merged into the large tile, so
            ncols must be at least 4.
        annotated: if True the title reads 'HPC <id> - TCGA', otherwise 'HPC <id>'.
        figures_path: if given, figures are saved under '<figures_path>/hpc_tile_samples'
            and closed; otherwise they are shown.
    """
    if figures_path is not None:
        figures_path = os.path.join(figures_path, 'hpc_tile_samples')
        os.makedirs(figures_path, exist_ok=True)

    title_format = 'HPC %s - TCGA' if annotated else 'HPC %s'

    for cluster_id in review_clusters:
        cluster_rows  = frame[frame[groupby]==cluster_id]
        indexes       = cluster_rows['indexes'].values.tolist()
        original_sets = cluster_rows['original_set'].values.tolist()
        combined      = list(zip(indexes, original_sets))
        random.shuffle(combined)
        combined_plot = sorted(combined[:100*batches])

        images_cluster = [data_dicts[original_set][int(index)]/255. for index, original_set in combined_plot]

        for batch in range(batches):
            # squeeze=False keeps axs two-dimensional even when nrows is 1.
            fig, axs = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
            fig.set_figheight(8)
            fig.set_figwidth(8*(ncols/4)*0.8)
            fig.suptitle(title_format % (cluster_id), ha='center', fontweight='bold', fontsize=65)

            # Replace the last four columns with a single large axes.
            gs = axs[0, -4].get_gridspec()
            for ax in axs[:, ncols-4:].flatten():
                ax.remove()
            axbig = fig.add_subplot(gs[0:, -4:])
            axbig.set_xticks([])
            axbig.set_yticks([])

            # Only axes that are still part of the figure receive tiles; drawing on the removed
            # ones would consume sampled tiles without displaying them.
            axes_list = list(axs[:, :ncols-4].flatten())
            axes_list.append(axbig)
            for ax, im in zip(axes_list, images_cluster[batch*100:(batch+1)*100]):
                ax.imshow(im)
                ax.set_xticks([])
                ax.set_yticks([])
                for axis in ['top','bottom','left','right']:
                    ax.spines[axis].set_linewidth(4)

            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            fig.tight_layout()
            if figures_path is None:
                plt.show()
            else:
                plt.savefig(os.path.join(figures_path, 'HPC_%s_TCGA_batch_%s.jpg' % (cluster_id, batch)))
                plt.close()

# Show tile samples of every HPC from the training frame; annotated=True adds '- TCGA' to the
# figure titles and figures_path=None displays the figures instead of saving them.
annotated       = True
sns.set_theme(style='white')
cluster_set_images(leiden_clusters, frame=train_df, data_dicts=img_dicts, groupby=groupby, batches=1, annotated=annotated, figures_path=None)


In [None]:
def get_crosscheck_frame(hdf5_path, original_set='train'):
    """Build a frame that maps each HDF5 row position to its slide and tile identifiers.

    Parameters:
        hdf5_path: path of the HDF5 file to read.
        original_set: label stored in the 'original_set' column of every row.

    Returns:
        DataFrame with columns 'indexes' (row position in the file), 'tiles', 'slides' and
        'original_set'.

    Raises:
        KeyError: if the file has no dataset whose name contains 'slides', or none whose name
            contains 'tiles' (names containing 'slides' are never used for tiles).
    """
    with h5py.File(hdf5_path, 'r') as content:
        slides_key = None
        tiles_key  = None
        # The last matching dataset name wins for each role.
        for key in content.keys():
            if 'slides' in key:
                slides_key = key
            elif 'tiles' in key:
                tiles_key = key
        if slides_key is None or tiles_key is None:
            # Fail with context instead of an UnboundLocalError further down.
            raise KeyError('Could not find both a slides and a tiles dataset in %s; available keys: %s' % (hdf5_path, list(content.keys())))
        # Stored as fixed-width 13-character unicode; longer identifiers are truncated.
        tiles  = content[tiles_key][:].astype('U13')
        slides = content[slides_key][:].astype('U13')
        indexes = list(range(tiles.shape[0]))
    frame_cc = pd.DataFrame(indexes, columns=['indexes'])
    frame_cc['tiles']  = tiles
    frame_cc['slides'] = slides
    frame_cc['original_set'] = original_set
    return frame_cc

def cross_check_dfs(additional_df, frame_cc, matching_fields=['slides', 'tiles']):
    """Inner-join frame_cc with additional_df on matching_fields.

    The 'slides' and 'tiles' columns of both inputs are converted to str in place before
    merging, so the callers' frames are modified. The result keeps the columns of frame_cc
    first, followed by the remaining columns of additional_df.
    """
    # Same order as before: additional_df first, then frame_cc; 'slides' before 'tiles'.
    for table in (additional_df, frame_cc):
        for column in ('slides', 'tiles'):
            table[column] = table[column].astype(str)
    return frame_cc.merge(additional_df, how='inner', on=matching_fields)

def cluster_set_images_add(review_clusters, frame, hdf5_path, groupby, add_cohort, img_key='img', batches=1, ncols=20, nrows=4, figures_path=None):
    """Show or save a random sample of tiles per cluster, reading images from an HDF5 file.

    For every cluster, up to 100*batches rows of `frame` are drawn at random, sorted, and
    plotted in figures made of one large tile (spanning the last four grid columns) plus a
    grid of small tiles. Axes left without a tile are filled with a white image, and the
    remaining batches of that cluster are skipped.

    Parameters:
        review_clusters: iterable of cluster ids to plot.
        frame: DataFrame with the `groupby` column plus 'indexes' and 'original_set';
            'indexes' are row positions in the HDF5 image dataset.
        hdf5_path: HDF5 file holding the images.
        groupby: name of the cluster column in `frame`.
        add_cohort: cohort name used in the figure title and the output file name.
        img_key: dataset name to use when no key containing 'img' or 'images' exists in the
            file; otherwise the first such key replaces it.
        batches: maximum number of figures per cluster; figure b uses sampled tiles b*100 on.
        ncols, nrows: grid size. The last four columns are merged into the large tile, so
            ncols must be at least 4.
        figures_path: if given, figures are saved under '<figures_path>/hpc_tile_samples'
            and closed; otherwise they are shown.
    """
    if figures_path is not None:
        figures_path = os.path.join(figures_path, 'hpc_tile_samples')
        os.makedirs(figures_path, exist_ok=True)

    with h5py.File(hdf5_path, 'r') as content:

        for key in content.keys():
            if 'img' in key or 'images' in key:
                img_key = key
                break

        for cluster_id in review_clusters:
            cluster_rows  = frame[frame[groupby]==cluster_id]
            indexes       = cluster_rows['indexes'].values.tolist()
            original_sets = cluster_rows['original_set'].values.tolist()
            combined      = list(zip(indexes, original_sets))
            random.shuffle(combined)
            combined_plot = sorted(combined[:100*batches])

            # Pixels are divided by 255 before display (assumes 0-255 images -- TODO confirm).
            images_cluster = [content[img_key][int(index)]/255. for index, _ in combined_plot]

            for batch in range(batches):
                # squeeze=False keeps axs two-dimensional even when nrows is 1.
                fig, axs = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False)
                fig.set_figheight(8)
                fig.set_figwidth(8*(ncols/4)*0.8)
                fig.suptitle('HPC %s - %s' % (cluster_id, add_cohort), ha='center', fontweight='bold', fontsize=65)

                # Replace the last four columns with a single large axes.
                gs = axs[0, -4].get_gridspec()
                for ax in axs[:, ncols-4:].flatten():
                    ax.remove()
                axbig = fig.add_subplot(gs[0:, -4:])
                axbig.set_xticks([])
                axbig.set_yticks([])

                # The large axes gets the first tile. Only axes still attached to the figure are
                # listed; drawing on the removed ones would consume tiles without showing them.
                axes_list = [axbig] + list(axs[:, :ncols-4].flatten())
                batch_images = images_cluster[batch*100:(batch+1)*100]
                for ax, im in zip(axes_list, batch_images):
                    ax.imshow(im)
                n_drawn = min(len(axes_list), len(batch_images))
                # Not enough tiles left: pad the remaining axes with white images.
                for ax in axes_list[n_drawn:]:
                    ax.imshow(np.ones((224,224,3)))
                for ax in axes_list:
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for axis in ['top','bottom','left','right']:
                        ax.spines[axis].set_linewidth(4)

                plt.subplots_adjust(wspace=0.05, hspace=0.05)
                fig.tight_layout()
                if figures_path is None:
                    plt.show()
                else:
                    plt.savefig(os.path.join(figures_path, 'HPC_%s_%s_batch_%s.jpg' % (cluster_id, add_cohort, batch)))
                    plt.close()
                # The sample is exhausted, so later batches would be empty.
                if n_drawn != len(axes_list): break

sns.set_theme(style='white')

# HDF5 file with the tiles of the additional (NYU) cohort.
hdf5_path = '%s/datasets/NYUFFPE_survival_5x_60pc/he/patches_h224_w224/hdf5_NYUFFPE_survival_5x_60pc_he_train.h5' % main_path

# Attach the HDF5 row position of each tile to its cluster assignment by matching slide and tile.
frame_cc       = get_crosscheck_frame(hdf5_path, original_set='additional')
cross_check_df = cross_check_dfs(additional_df, frame_cc, matching_fields=['slides', 'tiles'])

# Up to three figures per HPC; figures_path=None displays them instead of saving.
cluster_set_images_add(leiden_clusters, frame=cross_check_df, hdf5_path=hdf5_path, groupby=groupby, add_cohort='NYU', img_key='img', batches=3, figures_path=None)