In [71]:
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import TwoSlopeNorm
from matplotlib.pyplot import rc_context
from mpl_toolkits.axes_grid1 import ImageGrid
from skimage.transform import resize
import matplotlib.pyplot as plt
import matplotlib

from plottify import autosize
from PIL import Image
import seaborn as sns
import numpy as np
import pandas as pd
import scanpy as sc
import fastcluster
import anndata
import random
import h5py
import sys
import os

main_path = '/media/adalberto/Disk2/PhD_Workspace'
sys.path.append(main_path)
from models.clustering.correlations import *
from models.clustering.data_processing import *
from models.visualization.attention_maps import *
from data_manipulation.data import Data

# Figure method

In [72]:
def _cox_column_colors(cox_clusters, cluster_col, p_th, name):
    """Turn a per-cluster Cox table into clustermap column colours.

    Parameters
    ----------
    cox_clusters : pandas.DataFrame
        Must provide the columns ``cluster_col``, ``'coef'`` and ``'p'``.
    cluster_col : str
        Column holding the cluster identifier. Rows are sorted by it and its
        string form becomes the index of both returned Series.
    p_th : float
        Rows with ``p >= p_th`` receive the neutral (coefficient 0) colour in
        the masked Series.
    name : str
        Name assigned to both returned Series.

    Returns
    -------
    tuple of (pandas.Series, pandas.Series)
        ``(colors, colors_masked)``, one colour per row of ``cox_clusters``.
    """
    cox_sorted = cox_clusters.sort_values(by=cluster_col)
    coefs      = cox_sorted['coef'].astype(float)
    cmap_PiYG  = plt.cm.PiYG_r

    # TwoSlopeNorm raises unless vmin < vcenter < vmax, which fails when every
    # coefficient has the same sign; clamp the limits so 0 is always strictly inside.
    eps  = np.finfo(float).eps
    norm = TwoSlopeNorm(vmin=min(coefs.min(), -eps), vcenter=0, vmax=max(coefs.max(), eps))

    neutral       = cmap_PiYG(norm(0))[:3]
    colors        = pd.Series([cmap_PiYG(norm(coef)) for coef in coefs], name=name)
    colors_masked = pd.Series([cmap_PiYG(norm(coef)) if p < p_th else neutral for p, coef in zip(cox_sorted['p'], coefs)], name=name)

    # Index by this table's own cluster labels so the colours line up with the heatmap columns.
    colors.index        = cox_sorted[cluster_col].astype(str)
    colors_masked.index = cox_sorted[cluster_col].astype(str)
    return colors, colors_masked


def plot_clustermap(all_data_rho, mask, x_label, y_label, directory, file_name, figsize, vcenter=0, annot=True, fmt='.2f', cox_os_clusters=None, cox_pfs_clusters=None, col_linkage=None,
                    fontsize_ticks=28, fontsize_labels=30, fontsize_annot=20, dendrogram_ratio=0.2, show_row_cols=False, show=False, p_th=0.05):
    """Draw a seaborn clustermap of ``all_data_rho`` with optional Cox colour bars.

    When ``cox_os_clusters`` is given, two figures are drawn: one with every
    cluster coloured by its Cox coefficient, and a ``'_masked'`` one where
    clusters with ``p >= p_th`` get the neutral colour. Without it a single
    figure with no column colours is drawn.

    NOTE: the cluster column name is read from the module-level variable
    ``groupby`` at call time; it must be defined before calling this function
    with ``cox_os_clusters``.

    Parameters
    ----------
    all_data_rho : pandas.DataFrame
        Values shown in the heatmap.
    mask : array-like or None
        Forwarded to ``sns.clustermap(mask=...)``.
    x_label, y_label : str
        Axis labels of the heatmap.
    directory, file_name : str
        Output location, joined as ``'<directory>/<file_name>'``. The masked
        figure is saved with ``'_masked'`` inserted before ``'.jpg'``.
    figsize : tuple
        Base figure size; the final figure is 1.1x this size in each dimension.
    vcenter : float
        Centre of the heatmap colour norm. For 0 the norm is symmetric around 0,
        otherwise it spans the data minimum and maximum.
    annot, fmt : forwarded to ``sns.clustermap``.
    cox_os_clusters : pandas.DataFrame or None
        Cox table with columns ``groupby``, ``'coef'`` and ``'p'``.
    cox_pfs_clusters : pandas.DataFrame or None
        Second Cox table with the same columns; only used when
        ``cox_os_clusters`` is also given.
    col_linkage : forwarded to ``sns.clustermap``.
    fontsize_ticks, fontsize_labels, fontsize_annot : int
        Font sizes for tick labels, axis labels and cell annotations.
    dendrogram_ratio : forwarded to ``sns.clustermap``.
    show_row_cols : bool
        Currently unused.
    show : bool
        If True the figure is shown with ``plt.show()``; otherwise it is saved
        and closed.
    p_th : float
        p-value threshold used for the masked colour bars.

    Returns
    -------
    seaborn.matrix.ClusterGrid
        The grid of the last figure drawn.
    """
    with rc_context({'figure.figsize': figsize}):
        colors        = None
        colors_masked = None
        if cox_os_clusters is not None:
            colors, colors_masked = _cox_column_colors(cox_os_clusters, groupby, p_th, 'Cox Coefficient\nOverall Survival')

            if cox_pfs_clusters is not None:
                pfs_colors, pfs_colors_masked = _cox_column_colors(cox_pfs_clusters, groupby, p_th, 'Cox Coefficient\nProgression Free Survival')
                colors        = pd.concat([colors, pfs_colors], axis=1)
                colors_masked = pd.concat([colors_masked, pfs_colors_masked], axis=1)

        # Without Cox colours there is nothing to mask, so only one figure is drawn.
        variants = [('', colors)]
        if colors is not None:
            variants.append(('_masked', colors_masked))

        # Theme and colour range do not depend on the variant being drawn.
        sns.set_theme(style='white')
        vref = np.max(np.abs(all_data_rho.values))

        for name, col_colors in variants:
            # Fresh norm per figure so the two figures never share a mutable norm object.
            if vcenter == 0:
                norm = TwoSlopeNorm(vmin=-vref, vcenter=vcenter, vmax=vref)
            else:
                norm = TwoSlopeNorm(vmin=all_data_rho.values.min(), vcenter=vcenter, vmax=all_data_rho.values.max())

            # col_colors=None is seaborn's default, so one call covers both cases.
            g = sns.clustermap(all_data_rho, vmin=-vref, vmax=vref, method='ward', metric='euclidean', annot=annot, mask=mask, col_colors=col_colors, col_linkage=col_linkage, fmt=fmt,
                               norm=norm, cmap=sns.diverging_palette(250, 20, as_cmap=True), dendrogram_ratio=dendrogram_ratio, annot_kws={"size": fontsize_annot},  yticklabels=True,  xticklabels=True)
            if col_colors is not None:
                g.ax_col_colors.set_yticklabels(g.ax_col_colors.get_ymajorticklabels(), fontsize=fontsize_ticks)

            g.ax_heatmap.set_ylabel('\n%s' % y_label, fontsize=fontsize_labels)
            g.ax_heatmap.set_xlabel('\n%s' % x_label, fontsize=fontsize_labels)
            g._figure.set_size_inches(figsize[0]*1.1, figsize[1]*1.1)
            g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize=fontsize_ticks)
            g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_ymajorticklabels(), fontsize=fontsize_ticks)
            g.ax_cbar.tick_params(labelsize=fontsize_ticks)
            if show:
                plt.show()
            else:
                plt.savefig('%s/%s' % (directory, file_name.replace('.jpg', '%s.jpg' % name)))
                plt.close(g._figure)
    return g

# Variables for visualization

In [75]:
# Dataset name for images.
dataset            = 'TCGAFFPE_LUADLUSC_5x_60pc'
additional_dataset = 'NYUFFPE_survival_5x_60pc'

# Clustering folder details.
meta_folder    = 'subtypes_nn250'
meta_field     = 'labels'
matching_field = 'slides'

# Leiden resolution, the matching cluster column name, and the fold to visualize.
resolution     = 9.5
groupby        = 'leiden_%s' % resolution
fold_number    = 0

folds_pickle       = '%s/utilities/files/LUAD/overall_survival_TCGA_folds.pkl' % main_path
h5_complete_path   = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/hdf5_TCGAFFPE_LUADLUSC_5x_60pc_he_complete_lungsubtype_survival_filtered.h5' % main_path
h5_additional_path = None

# Output file stems: the H5 base name (without the 'hdf5_' prefix and '.h5' suffix)
# followed by the cluster column (with '.' replaced by 'p') and the fold number.
file_name = h5_complete_path.split('/hdf5_')[1].split('.h5')[0] + '_%s__fold%s' % (groupby.replace('.', 'p'), fold_number)
# Always bind file_additional so later code can test it against None
# instead of raising NameError when no additional H5 file is configured.
file_additional = None
if h5_additional_path is not None:
    file_additional = h5_additional_path.split('/hdf5_')[1].split('.h5')[0] + '_%s__fold%s' % (groupby.replace('.', 'p'), fold_number)

# Set up the folder scheme.
main_cluster_path = h5_complete_path.split('hdf5_')[0]
main_cluster_path = os.path.join(main_cluster_path, meta_folder)
figures_path      = os.path.join(main_cluster_path, 'figures')
# exist_ok=True creates the folder only when missing, without a separate existence check.
os.makedirs(figures_path, exist_ok=True)


In [89]:
# Load the AnnData reference for this clustering configuration; the helper
# also returns the path of the .h5ad file.
adata_train, h5ad_path = read_h5ad_reference(
    h5_complete_path, meta_folder, groupby, fold_number
)

# Build the slide representations for the cohort.
all_data = build_cohort_representations(
    meta_folder, meta_field, matching_field, groupby, fold_number, folds_pickle, h5_complete_path, None,
    type_composition='clr',
    min_tiles=100,
    use_conn=False,
    use_ratio=False,
    top_variance_feat=0,
    reduction=2,
)

# Unpack the five returned items; all_data itself stays bound.
(complete_df,
 additional_complete_df,
 frame_clusters,
 frame_samples,
 features) = all_data

## Clustermap: Slide representations vs Clusters

In [None]:
def clustermap_representations(features, slide_rep_df, adata, leiden_linkage_method, leiden_cor_method, method_slides, metric_slides, figsize, directory, file_name, use_leiden=False):
    """Plot, save and show a clustermap of slide representations against clusters.

    Rows are slides and columns are the ``features`` columns. Rows are always
    clustered with ``fastcluster.linkage``; columns use either the same
    procedure on the transposed data or the scanpy dendrogram of ``adata``.

    NOTE: reads the module-level variables ``meta_field`` (label column used
    for the row colours) and ``groupby`` (cluster key, only when
    ``use_leiden`` is True) at call time.

    Parameters
    ----------
    features : list
        Columns of ``slide_rep_df`` shown in the heatmap.
    slide_rep_df : pandas.DataFrame
        Slide representations; must contain ``features`` and ``meta_field``.
    adata : anndata.AnnData
        Only used when ``use_leiden`` is True; ``sc.tl.dendrogram`` stores its
        result in ``adata.uns``.
    leiden_linkage_method, leiden_cor_method : str
        Forwarded to ``sc.tl.dendrogram`` (``use_leiden=True`` only).
    method_slides, metric_slides : str
        Linkage method and distance metric for ``fastcluster.linkage``.
    figsize : tuple
        Figure size forwarded to ``sns.clustermap``.
    directory, file_name : str
        The figure is saved to ``os.path.join(directory, file_name)``.
    use_leiden : bool
        Select the column linkage source (see above).
    """
    # NOTE(review): the first row is dropped here; the reason is not visible
    # in this notebook -- confirm it is intentional.
    slide_rep_df = slide_rep_df.iloc[1:].copy(deep=True).reset_index()

    col_colors = None

    if use_leiden:
        # Leiden clusters dendrogram.
        sc.tl.dendrogram(adata, groupby, use_rep='X', linkage_method=leiden_linkage_method, cor_method=leiden_cor_method)
        col_linkage = adata.uns['dendrogram_%s' % groupby]['linkage']
    else:
        col_linkage = fastcluster.linkage(slide_rep_df[features].transpose(), method=method_slides, metric=metric_slides)

    # Row hierarchical clustering
    row_linkage = fastcluster.linkage(slide_rep_df[features], method=method_slides, metric=metric_slides)

    # Row colors: the first two labels keep the original 'blue'/'orange'; any
    # further label gets the next 'tab10' colour, so datasets with more than
    # two labels no longer end up with unmapped rows.
    labels      = np.unique(slide_rep_df[meta_field])
    row_palette = ['blue', 'orange'] + list(sns.color_palette('tab10', n_colors=max(len(labels), 2)))[2:]
    row_lut     = dict(zip(labels, row_palette))
    print(row_lut)
    row_colors = pd.Series(slide_rep_df[meta_field].map(row_lut), name='Cancer Type\nSlide\n')

    g = sns.clustermap(slide_rep_df[features].astype(float), dendrogram_ratio=(.05, .235),  row_colors=row_colors, col_colors=col_colors, row_linkage=row_linkage, col_linkage=col_linkage, xticklabels=True, cmap='rocket_r', figsize=figsize)
    g.ax_heatmap.set_xlabel('Clusters', fontsize=12)
    g.ax_heatmap.set_ylabel('Slides', fontsize=12)
    g.ax_heatmap.set_yticks([])
    g.ax_row_colors.tick_params(axis='both', length=0)
    if col_colors is not None: g.ax_col_colors.tick_params(axis='both', length=0)
    plt.savefig(os.path.join(directory, file_name))
    plt.show()
    plt.close()

# Column linkage source: False clusters the columns from the slide
# representations themselves, True uses the scanpy dendrogram of the clusters.
use_leiden            = False
# Settings for the scanpy dendrogram (only read when use_leiden is True).
leiden_linkage_method = 'average'
leiden_cor_method     = 'spearman'
# Linkage method and metric for the hierarchical clustering of slides.
method_slides         = 'ward'
metric_slides         = 'correlation'

clustermap_representations(
    features=features,
    slide_rep_df=complete_df,
    adata=adata_train,
    leiden_linkage_method=leiden_linkage_method,
    leiden_cor_method=leiden_cor_method,
    method_slides=method_slides,
    metric_slides=metric_slides,
    figsize=(30,10),
    directory=figures_path,
    file_name=file_name+'_representations.jpg',
    use_leiden=use_leiden,
)


## Cluster Network

In [4]:
# Reuse the cached '_paga.h5ad' file when it exists; otherwise run PAGA on the
# AnnData object already in memory. `done` records which path was taken.
_paga_cache_path = h5ad_path.replace('.h5ad', '_paga.h5ad')
done = os.path.isfile(_paga_cache_path)
if done:
    adata_train = anndata.read_h5ad(_paga_cache_path)
else:
    sc.tl.paga(adata_train, groups=groupby, neighbors_key='nn_leiden')

In [None]:
# Graph layout settings; tune these to make the cluster network easier to read.
#   layout:    technique used to lay out the cluster network,
#              one of 'fa', 'fr', 'rt', 'rt_circular', 'drl', 'eq_tree'.
#   threshold: cut-off for drawing an edge between two nodes; higher values
#              keep only the stronger connections.
layout           = 'fa'
random_state     = 3
threshold        = 1.0

# Figure appearance.
node_size_scale  = 5
node_size_power  = 0.5
edge_width_scale = .01
fontsize    = 10
fontoutline = 2

# The network is drawn in the first panel of a 1x3 grid on a wide figure.
fig = plt.figure(figsize=(100,10))
ax  = fig.add_subplot(1, 3, 1)
sc.pl.paga(
    adata_train,
    layout=layout,
    random_state=random_state,
    color=meta_field,
    threshold=threshold,
    node_size_scale=node_size_scale,
    node_size_power=node_size_power,
    edge_width_scale=edge_width_scale,
    fontsize=fontsize,
    fontoutline=fontoutline,
    frameon=False,
    show=False,
    ax=ax,
)
plt.show()

In [29]:
# When no cached PAGA file was loaded, compute the UMAP with init_pos="paga"
# and write the AnnData object to the '_paga.h5ad' cache path.
if not done:
    sc.tl.umap(adata_train, init_pos="paga", neighbors_key='nn_leiden')
    _paga_output_path = h5ad_path.replace('.h5ad', '_paga.h5ad')
    adata_train.write(_paga_output_path)

In [None]:
# Figure appearance: marker size, legend styling and title font size.
marker_size   = 10
l_markerscale = 2
l_size        = 18
l_t_size      = 20
l_box_w       = 3
ax_size       = 30

# The UMAP is drawn in the first panel of a 1x3 grid on a wide figure.
fig = plt.figure(figsize=(100,10))
ax  = fig.add_subplot(1, 3, 1)
ax = sc.pl.umap(
    adata_train,
    ax=ax,
    color=meta_field,
    size=marker_size,
    show=False,
    frameon=False,
    na_color='black',
)
if meta_field is not None:
    legend_c = ax.legend(loc='best', markerscale=l_markerscale, title='Subtype', prop={'size': l_size})
    legend_c.get_title().set_fontsize(l_t_size)
    legend_c.get_frame().set_linewidth(l_box_w)
    # TODO: Legend labels are hard-coded for exactly two classes; this needs
    # to be adapted for a different number of labels.
    legend_texts = legend_c.get_texts()
    legend_texts[0].set_text('LUAD')
    legend_texts[1].set_text('LUSC')
ax.set_title('Tile Representations', fontsize=ax_size, fontweight='bold')
plt.show()