In [2]:
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from matplotlib.colors import LinearSegmentedColormap
import statsmodels.api as sm
from skimage.transform import resize
from plottify import autosize
from sklearn import metrics
from PIL import Image
import anndata
import random
import numpy as np
import seaborn as sns
import pandas as pd
import scanpy as sc
import fastcluster
import umap
import h5py
import sys
import os

sys.path.append('/media/adalberto/Disk2/PhD_Workspace')
from models.clustering.cox_proportional_hazard_regression_leiden_clusters import *
from models.visualization.attention_maps import *
from models.clustering.data_processing import *
from data_manipulation.data import Data

  warn("Tensorflow not installed; ParametricUMAP will be unavailable")


# Paper Figure - Forest plot

In [None]:
# Root of the workspace that contains the results tree.
main_folder = '/media/adalberto/Disk2/PhD_Workspace'

# Run configuration used to locate the Cox regression statistics.
meta_field = 'luad'
matching_field = 'slides'
resolution = 2.0
fold_number = 0
groupby = 'leiden_%s' % resolution
meta_folder = 'luad_overall_survival_nn250_clusterfold0'
alpha = 1.0

# CSV with the per-cluster statistics over all folds. The directory name keeps
# the raw `groupby` (with the dot) while the file name uses 'p' instead of '.'.
coeff_csv = (
    '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered'
    '/%s/%s_%s_alpha_%s_l1ratio_0p0_mintiles_100/%s_stats_all_folds.csv'
) % (
    main_folder,
    meta_folder,
    meta_folder,
    groupby,
    str(alpha).replace('.', 'p'),
    groupby.replace('.', 'p'),
)
coeff_frame = pd.read_csv(coeff_csv)
coeff_frame

In [5]:
from matplotlib.font_manager import FontProperties
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mticker

# Forest Plot for Cox Proportional Hazards Regression coefficients.
class EffectMeasurePlot_Cox:
    """Forest plot with an attached statistics table for Cox regression coefficients.

    The left panel draws one point estimate with its confidence interval per row; the
    right panel is a table with, per row, the estimate, the confidence interval, the
    p-value, the tile counts, the mean and max tiles per patient and the patient
    percentage. Plots resemble the following form:
        _____________________________________________      Measure     % CI    ...
        |                                           |
    1   |        --------o-------                   |       x        n, 2n     ...
        |                                           |
    2   |                   ----o----               |       w        m, 2m     ...
        |                                           |
        |___________________________________________|
        #           #           #           #
    Public methods:
    labels(**kwargs)
        Changes the table labels, the x-axis scale and the reference line.
        KEYWORDS:
            -effectmeasure  + changes the effect measure label
            -conf_int       + changes the confidence interval label ('ci' is accepted as an alias)
            -scale          + changes the scale to either log or linear
            -center         + changes the reference line for the center
    colors(**kwargs)
        Changes the colors of points and lines, and the shape of points. Valid matplotlib
        colors and markers are required.
        KEYWORDS:
            -errorbarcolor  + changes the error bar colors
            -linecolor      + changes the color of the reference line
            -pointcolor     + changes the color of the points
            -pointshape     + changes the shape of points
    plot(bbox, figsize=(3, 3), decimal=3, size=3, ...)
        Generates the plot and returns the axes holding the point estimates.
    Example)
    >>>lab = ['One','Two'] #generating lists of data to plot
    >>>emm = [1.01,1.31]
    >>>lcl = ['0.90',1.01]
    >>>ucl = [1.11,1.53]
    >>>
    >>>x = EffectMeasurePlot_Cox(lab, emm, lcl, ucl, pvalues=[0.2, 0.01], counts=[10, 20],
    ...                          mean_tp=[3, 4], max_tp=[9, 8], perc_pat=[50.0, 60.0])
    >>>x.labels(effectmeasure='RR') #changing the table label to 'RR'
    >>>x.colors(pointcolor='r') #changing the point colors to red
    >>>x.plot(bbox=[0, 0.01, 3.5, 1.02]) #generating the effect measure plot
    """

    # Number of columns of the side table (estimate, CI, p-value, counts, mean, max, patients %).
    _N_TABLE_COLS = 7

    def __init__(self, label, effect_measure, lcl, ucl, pvalues, counts, mean_tp, max_tp, perc_pat, center=0):
        """Stores the data to plot. All lists should be the same length.

        Inputs:
        label
            -list of labels to use for y-axis
        effect_measure
            -list of numbers for point estimates to plot. If point estimate has trailing zeroes,
             input as a character object rather than a float. A NaN entry produces a blank
             table row
        lcl
            -list of numbers for lower confidence limits to plot. If the value has trailing
             zeroes, input as a character object rather than a float
        ucl
            -list of numbers for upper confidence limits to plot. If the value has trailing
             zeroes, input as a character object rather than a float
        pvalues, counts, mean_tp, max_tp, perc_pat
            -lists shown verbatim in the table columns 'P-Value', 'Tile Counts',
             'Mean Tile Per Pat.', 'Max Tile Per Pat.' and 'Patients %'. pvalues and perc_pat
             are also used by plot() to decide which rows are de-emphasized
        center
            -x position of the vertical reference line
        """
        self.df = pd.DataFrame()
        self.df['study'] = label
        self.df['OR']    = effect_measure
        self.df['LCL']   = lcl
        self.df['UCL']   = ucl
        self.df['P']     = pvalues
        self.df['C']     = counts
        self.df['M']     = mean_tp
        self.df['Ma']    = max_tp
        self.df['Pp']    = perc_pat
        # Numeric copy of the estimate; the original column may hold strings.
        self.df['OR2']   = self.df['OR'].astype(str).astype(float)
        # Error bar half-lengths. pd.to_numeric is a no-op on float columns and parses
        # string-encoded numbers otherwise, so a single code path covers both cases.
        numeric_estimate   = pd.to_numeric(self.df['OR'])
        self.df['LCL_dif'] = numeric_estimate - pd.to_numeric(self.df['LCL'])
        self.df['UCL_dif'] = pd.to_numeric(self.df['UCL']) - numeric_estimate
        self.em       = 'OR'
        self.ci       = '95% CI'
        self.p        = 'P-Value'
        self.counts   = 'Tile\nCounts'
        self.mean_tp  = 'Mean Tile\nPer Pat.'
        self.max_tp   = 'Max Tile\nPer Pat.'
        self.perc_pat = 'Patients\n%'
        self.scale    = 'linear'
        self.center   = center
        self.errc     = 'dimgrey'
        self.shape    = 'o'
        self.pc       = 'k'
        self.linec    = 'gray'

    def labels(self, **kwargs):
        """Function to change the labels of the outputted table. Additionally, the scale and reference
        value can be changed.
        Accepts the following keyword arguments:
        effectmeasure
            -changes the effect measure label
        conf_int
            -changes the confidence interval label ('ci' is accepted as an alias; 'conf_int'
             wins when both are given)
        scale
            -changes the scale to either log or linear
        center
            -changes the reference line for the center
        """
        if 'effectmeasure' in kwargs:
            self.em = kwargs['effectmeasure']
        if 'ci' in kwargs:
            self.ci = kwargs['ci']
        if 'conf_int' in kwargs:
            self.ci = kwargs['conf_int']
        if 'scale' in kwargs:
            self.scale = kwargs['scale']
        if 'center' in kwargs:
            self.center = kwargs['center']

    def colors(self, **kwargs):
        """Function to change colors and shapes.
        Accepts the following keyword arguments:
        errorbarcolor
            -changes the error bar colors
        linecolor
            -changes the color of the reference line
        pointcolor
            -changes the color of the points
        pointshape
            -changes the shape of points
        """
        if 'errorbarcolor' in kwargs:
            self.errc = kwargs['errorbarcolor']
        if 'pointshape' in kwargs:
            self.shape = kwargs['pointshape']
        if 'linecolor' in kwargs:
            self.linec = kwargs['linecolor']
        if 'pointcolor' in kwargs:
            self.pc = kwargs['pointcolor']

    def _table_rows(self, decimal):
        """Builds the table body (one 7-cell list per data row) and the y tick positions."""
        rows = []
        yticks = []
        for i in range(len(self.df)):
            if np.isnan(self.df['OR2'].iloc[i]):
                # Blank row: it must have as many cells as the table has columns.
                rows.append([' '] * self._N_TABLE_COLS)
            else:
                estimate = self.df['OR'].iloc[i]
                lower    = self.df['LCL'].iloc[i]
                upper    = self.df['UCL'].iloc[i]
                trailing = [self.df[column].iloc[i] for column in ('C', 'M', 'Ma', 'Pp')]
                if isinstance(estimate, float) and isinstance(lower, float) and isinstance(upper, float):
                    # Numeric inputs are rounded for display.
                    row = [round(self.df['OR2'].iloc[i], decimal),
                           '(' + str(round(lower, decimal)) + ', ' + str(round(upper, decimal)) + ')',
                           str(self.df['P'].iloc[i])]
                else:
                    # String inputs are shown verbatim so trailing zeroes are preserved.
                    row = [estimate, '(' + str(lower) + ', ' + str(upper) + ')', self.df['P'].iloc[i]]
                rows.append(row + trailing)
            yticks.append(i)
        return rows, yticks

    def _axis_limits(self, max_value, min_value):
        """Returns (mini, maxi) for the x-axis, deriving from the data whichever bound is None."""
        if max_value is None:
            ucl_max = pd.to_numeric(self.df['UCL']).max()
            if ucl_max < 1:
                maxi = round(ucl_max + 0.05, 2)  # setting x-axis maximum for UCL less than 1
            elif ucl_max < 9:
                maxi = round(ucl_max + 1, 0)  # setting x-axis maximum for UCL less than 10
            else:
                # Also covers a maximum of exactly 9, which previously left maxi unbound.
                maxi = round(ucl_max + 10, 0)  # setting x-axis maximum for UCL less than 100
        else:
            maxi = max_value
        if min_value is None:
            lcl_min = pd.to_numeric(self.df['LCL']).min()
            if lcl_min > 0:
                mini = round(lcl_min - 0.1, 1)  # setting x-axis minimum
            else:
                # Also covers a minimum of exactly 0, which previously left mini unbound.
                mini = round(lcl_min - 0.05, 2)  # setting x-axis minimum
        else:
            mini = min_value
        return mini, maxi

    def plot(self, bbox, figsize=(3, 3), t_adjuster=0.01, decimal=3, size=3, max_value=None, min_value=None, fontsize=12, p_th=0.05, strict=False, min_perc_pat=10):
        """Generates the matplotlib effect measure plot with the default or specified attributes.

        bbox
            -bounding box handed to the matplotlib table; controls where the table is drawn
             relative to its axes
        figsize
            -size of the matplotlib figure
        t_adjuster
            -kept for backward compatibility; this method does not read it (the table is
             positioned through bbox)
        decimal
            -number of decimal places to display in the table
        size
            -scales the marker size and the error bar line width
        max_value
            -maximum value of x-axis scale. Default is None, which automatically determines max value
        min_value
            -minimum value of x-axis scale. Default is None, which automatically determines min value
        fontsize
            -font size for tick labels and table text
        p_th
            -rows with a p-value above this threshold are rendered with a light font weight
        strict
            -when True, rows whose confidence interval crosses zero are also rendered light
        min_perc_pat
            -rows whose patient percentage is below this value are rendered light
        Returns the axes that holds the point estimates.
        """
        tval, ytick = self._table_rows(decimal)
        mini, maxi = self._axis_limits(max_value, min_value)
        col_labels = [self.em, self.ci, self.p, self.counts, self.mean_tp, self.max_tp, self.perc_pat]

        plt.figure(figsize=figsize)  # blank figure
        gspec = gridspec.GridSpec(1, 6)  # sets up grid
        plot = plt.subplot(gspec[0, 0:4])  # plot of data
        tabl = plt.subplot(gspec[0, 4:])  # table of OR & CI
        plot.set_ylim(-1, (len(self.df)))  # spacing out y-axis properly
        if self.scale == 'log':
            try:
                plot.set_xscale('log')
            except Exception as exc:
                raise ValueError('For the log scale, all values must be positive') from exc
        plot.axvline(self.center, color=self.linec, zorder=1)
        plot.errorbar(self.df.OR2, self.df.index, xerr=[self.df.LCL_dif, self.df.UCL_dif], marker='None', zorder=2, ecolor=self.errc, elinewidth=size*0.3, linewidth=0)
        plot.scatter(self.df.OR2, self.df.index, c=self.pc, s=(size * 25), marker=self.shape, zorder=3, edgecolors='None')
        plot.xaxis.set_ticks_position('bottom')
        plot.yaxis.set_ticks_position('left')
        plot.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
        plot.get_xaxis().set_minor_formatter(matplotlib.ticker.NullFormatter())
        # Text properties are only valid in set_*ticks together with labels; the font
        # size is applied by the set_*ticklabels calls below instead.
        plot.set_yticks(ytick)
        plot.set_xlim([mini, maxi])
        plot.set_xticks([mini, self.center, maxi])
        plot.set_xticklabels([mini, self.center, maxi], fontsize=fontsize, fontweight='bold')
        plot.set_yticklabels(self.df.study, fontsize=fontsize, fontweight='bold')
        plot.yaxis.set_ticks_position('none')
        plot.invert_yaxis()  # invert y-axis to align values properly with table
        tb = tabl.table(cellText=tval, cellLoc='center', loc='right', colLabels=col_labels, bbox=bbox)
        tabl.axis('off')
        tb.auto_set_font_size(False)
        tb.set_fontsize(fontsize)

        # Numeric views used for the emphasis rules; non-numeric entries become NaN,
        # and every comparison with NaN is False, so such rows are never de-emphasized.
        coeffs  = self.df['OR2'].values
        lcls    = pd.to_numeric(self.df['LCL'], errors='coerce').values
        ucls    = pd.to_numeric(self.df['UCL'], errors='coerce').values
        p_vals  = pd.to_numeric(self.df['P'], errors='coerce').values
        pat_pcs = pd.to_numeric(self.df['Pp'], errors='coerce').values
        for (row, col), cell in tb.get_celld().items():
            if row == 0:
                # Header row.
                cell.set_text_props(fontproperties=FontProperties(weight='bold', size=fontsize))
                cell.set_height(.015)
            else:
                idx = row - 1  # table row 0 is the header, data starts at row 1
                emphasized = True
                if strict:
                    if coeffs[idx] > 0:
                        if lcls[idx] < 0:
                            emphasized = False
                    else:
                        if ucls[idx] > 0:
                            emphasized = False
                if p_vals[idx] > p_th:
                    emphasized = False
                if pat_pcs[idx] < min_perc_pat:
                    emphasized = False
                if emphasized:
                    cell.set_text_props(fontproperties=FontProperties(size=fontsize))
                else:
                    cell.set_text_props(fontproperties=FontProperties(weight='light', size=fontsize))
            cell.set_linewidth(0)
        tb.auto_set_column_width(col=list(range(len(col_labels))))
        return plot

In [None]:
sns.set_theme(style='white')

# Order the clusters by coefficient and keep only the ones with an estimate.
frame = coeff_frame.sort_values(by='coef')
frame = frame[frame['coef'].notna()]

# Figure and table appearance.
figsize    = (27,30)
decimal    = 3
size       = 10
fontsize   = 35
p_th       = 0.05
strict     = False
t_adjuster = 0.007

# Bounding box of the statistics table, relative to the table axes.
bbox = [0, t_adjuster, 3.5, 1.02]

# Column holding the cluster identifier (first column whose name contains 'leiden').
groupby = [value for value in frame.columns if 'leiden' in value][0]

# Per-cluster values handed to the forest plot.
labs     = frame[groupby].tolist()
measure  = frame['coef'].round(3).tolist()
lower    = frame['coef lower 95%'].round(3).tolist()
upper    = frame['coef upper 95%'].round(3).tolist()
pvalues  = frame['p'].round(3).tolist()
subtype  = frame['Subtype'].tolist()
purity   = np.round(frame['Subtype Purity(%)'].to_numpy(), 1).tolist()
counts   = frame['Subtype Counts'].tolist()
mean_tp  = frame['mean_tile_sample'].to_numpy().astype(int).tolist()
# NOTE(review): max_tile_sample is scaled by 100 like percent_sample, while the table
# column it feeds is labelled 'Max Tile Per Pat.' - confirm the scaling is intended.
max_tp   = np.round(frame['max_tile_sample'].to_numpy() * 100, 1).tolist()
perc_pat = np.round(frame['percent_sample'].to_numpy() * 100, 1).tolist()

# Symmetric x-axis range around zero that covers every confidence interval.
max_value = max(abs(max(upper)), abs(min(lower)))

p = EffectMeasurePlot_Cox(
    label=labs,
    effect_measure=measure,
    lcl=lower,
    ucl=upper,
    pvalues=pvalues,
    counts=counts,
    mean_tp=mean_tp,
    max_tp=max_tp,
    perc_pat=perc_pat,
)
p.labels(effectmeasure='Log Hazard\nRatio')
p.colors(pointshape="o")
ax = p.plot(
    bbox=bbox,
    figsize=figsize,
    t_adjuster=t_adjuster,
    max_value=max_value,
    min_value=-max_value,
    size=size,
    decimal=decimal,
    fontsize=fontsize,
    p_th=p_th,
    strict=strict,
)
plt.suptitle("HPC\n \n ", x=0.1, y=0.89, fontsize=fontsize, fontweight='bold')

# The x-axis caption depends on the endpoint encoded in the run folder name.
if 'overall' in meta_folder:
    ax.set_xlabel("Favors Survival               Favors Death", fontsize=fontsize, x=0.5, fontweight='bold')
else:
    ax.set_xlabel("Against Recurrence    Favors Recurrence", fontsize=fontsize, x=0.49, fontweight='bold')

# Keep only the bottom spine visible.
ax.spines['bottom'].set_visible(True)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)


# Paper Figure - High Low Risk Groups

In [11]:
# Root of the workspace that contains the data, utility files and results.
main_path = '/media/adalberto/Disk2/PhD_Workspace'

# Fold / model settings. These are assigned before meta_folder because meta_folder
# is formatted from fold_number; assigning fold_number afterwards would make the
# folder name silently depend on a value left over from an earlier cell.
resolution         = 2.0
fold_number        = 0
force_fold         = fold_number
alpha              = 1.0
l1_ratio           = 0.0

# Survival task definition.
folds_pickle     = '%s/utilities/files/LUAD/overall_survival_TCGA_folds.pkl' % main_path
meta_folder      = 'luad_overall_survival_nn250_clusterfold%s' % fold_number
event_ind_field  = 'os_event_ind'
event_data_field = 'os_event_data'
additional_as_fold = False

# Input data and cluster-composition settings.
h5_complete_path   = '%s/results/BarlowTwins_3/TCGAFFPE_LUADLUSC_5x_60pc_250K/h224_w224_n3_zdim128_filtered/hdf5_TCGAFFPE_LUADLUSC_5x_60pc_he_complete_lungsubtype_survival_filtered.h5' % main_path
h5_additional_path = None
matching_field      = 'samples'
type_composition   = 'clr'
min_tiles          = 100
p_th               = 0.05

# Other features
q_buckets         = 2
max_months        = 15.0*15.0
use_conn          = False
use_ratio         = False
top_variance_feat = 0
remove_clusters   = None

In [12]:
# Name of the cluster column for the selected resolution.
groupby = 'leiden_%s' % resolution

# Reuse the folds of the existing split.
folds     = load_existing_split(folds_pickle)
num_folds = len(folds)

# No dedicated diversity key is configured, so the sample key is used.
diversity_key = matching_field

# Run folder sits next to the HDF5 file; the per-fold AnnData outputs are inside it.
main_cluster_path = os.path.join(h5_complete_path.split('hdf5_')[0], meta_folder)
adatas_path       = os.path.join(main_cluster_path, 'adatas')

# Accumulators filled while iterating over the folds.
print('')
print('\tResolution', resolution)
risk_groups     = [pd.DataFrame(), pd.DataFrame()]
additional_risk = pd.DataFrame()
cis_folds       = []
estimators      = []
for i, fold in enumerate(folds):
    # Load the train, validation, test and additional sets of this fold.
    dataframes, _, leiden_clusters = read_csvs(
        adatas_path, matching_field, groupby, i, fold,
        h5_complete_path, h5_additional_path, additional_as_fold, force_fold,
    )

    # Cluster-level and sample-level summaries computed on the first dataframe.
    frame_clusters, frame_samples = create_frames(
        dataframes[0], groupby, event_ind_field,
        diversity_key=diversity_key, reduction=2,
    )

    # Build the inputs for the Cox model.
    data, datas_all, features = prepare_data_survival(
        dataframes, groupby, leiden_clusters, type_composition, max_months,
        matching_field, event_ind_field, event_data_field, min_tiles,
        use_conn=use_conn, use_ratio=use_ratio,
        top_variance_feat=top_variance_feat, remove_clusters=remove_clusters,
    )

    # Fit the Cox regression and keep the fitted estimator of every fold.
    estimator, predictions, frame_clusters = train_cox(
        data, penalizer=alpha, l1_ratio=l1_ratio, robust=True,
        event_ind_field=event_ind_field, event_data_field=event_data_field,
        frame_clusters=frame_clusters, groupby=groupby,
    )
    estimators.append(estimator)

    # Store the first element of every evaluation result for this fold.
    cis = evalutaion_survival(data, predictions, event_ind_field=event_ind_field, event_data_field=event_data_field)
    cis_folds.append([ci[0] for ci in cis])

    # Split into risk groups and merge them with the groups of previous folds.
    high_lows = get_high_low_risks(predictions, datas_all, i, matching_field, q_buckets=q_buckets)
    risk_groups, additional_risk = combine_risk_groups(
        risk_groups, additional_risk, high_lows, i, num_folds,
        matching_field, event_ind_field, event_data_field,
    )

    print('\t\tFold', i, 'Alpha', np.round(alpha,2), 'Train/Valid/Test/Additional C-Index:', '/'.join(str(value) for value in cis_folds[i]))

# Position 2 of each fold entry is the test set (see the label order printed above).
test_ci = np.round(mean_confidence_interval([a[2] for a in cis_folds]), 2)
print('Test Set C-Index Mean & CI: %s (%s-%s)' % (test_ci[0], test_ci[1], test_ci[2]))
if h5_additional_path is not None and not additional_as_fold:
    # The last position of each fold entry is the additional set.
    test_ci = np.round(mean_confidence_interval([a[-1] for a in cis_folds]), 2)
    print('Additional Set C-Index Mean & CI: %s (%s-%s)' % (test_ci[0], test_ci[1], test_ci[2]))


	Resolution 2.0
		Fold 0 Alpha 1.0 Train/Valid/Test/Additional C-Index: 0.68/None/0.65/None
		Fold 1 Alpha 1.0 Train/Valid/Test/Additional C-Index: 0.69/None/0.55/None
		Fold 2 Alpha 1.0 Train/Valid/Test/Additional C-Index: 0.69/None/0.58/None
		Fold 3 Alpha 1.0 Train/Valid/Test/Additional C-Index: 0.67/None/0.57/None
		Fold 4 Alpha 1.0 Train/Valid/Test/Additional C-Index: 0.67/None/0.63/None
Test Set C-Index Mean & CI: 0.6 (0.56-0.63)


In [None]:
# Survival libs.
from decimal import Decimal
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test

def plot_k_fold_cv_KM(high_risk, low_risk, title, max_months, event_ind_field, event_data_field, fontsize_title, fontsize_labels, fontsize_ticks, fontsize_legend, l_markerscale, l_box_w, lw, ms, ci_show=True):
    """Draws Kaplan-Meier curves for the low and high risk groups on one axes and shows it.

    high_risk, low_risk
        -frames with one row per case; the columns named by event_data_field and
         event_ind_field are cast to float and used as duration and event indicator
    title
        -axes title
    max_months
        -upper x-axis limit, or None to keep the automatic limit. The durations are divided
         by 12 before fitting and the axis is labelled in years, so despite its name the
         value is applied on that scale
    fontsize_title, fontsize_labels, fontsize_ticks, fontsize_legend
        -font sizes of the respective elements
    l_markerscale
        -legend marker scale; also used as the line width of the legend lines
    l_box_w
        -line width of the legend frame
    lw, ms
        -line width of the curves and marker size of the censor marks
    ci_show
        -whether confidence intervals are drawn around the curves
    A log-rank test between the two groups is run and its p-value is formatted, but the
    resulting string is not placed on the figure.
    """
    # Durations and event indicators as floats (same evaluation order as the test call).
    high_time  = high_risk[event_data_field].astype(float)
    low_time   = low_risk[event_data_field].astype(float)
    high_event = high_risk[event_ind_field].astype(float)
    low_event  = low_risk[event_ind_field].astype(float)

    results = logrank_test(high_time, low_time, event_observed_A=high_event, event_observed_B=low_event)
    title_add = 'P-Value: %.2E ' % (Decimal(results.p_value))

    fig, ax = plt.subplots(ncols=1, nrows=1)
    ncols = 20
    fig.set_figheight(10)
    fig.set_figwidth(10*(ncols/4)*0.8)

    # Low risk is drawn first, then high risk, so the color cycle and legend order are fixed.
    for group_label, durations, observed in (('Low', low_time, low_event), ('High', high_time, high_event)):
        fitter = KaplanMeierFitter(label=group_label)
        fitter.fit(durations/12, event_observed=observed)
        fitter.plot_survival_function(show_censors=True, ci_show=ci_show, ax=ax, linewidth=lw, censor_styles={'ms': ms, 'marker': '+'})

    # Axis ranges.
    ax.set_ylim([0.0,1.10])
    if max_months is not None:
        ax.set_xlim([-0.1, max_months])

    # Tick labels of both axes.
    for axis_obj in (ax.xaxis, ax.yaxis):
        for tick in axis_obj.get_major_ticks():
            tick.label1.set_fontsize(fontsize_ticks)
            tick.label1.set_fontweight('bold')

    # Title, axis labels and frame.
    ax.set_title(title, fontsize=fontsize_title, fontweight='bold', y=1.02)
    ax.set_xlabel('Time (Years)',         fontsize=fontsize_labels, fontweight='bold')
    ax.set_ylabel('Survival Probability', fontsize=fontsize_labels, fontweight='bold')
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(4)

    # Legend.
    legend = ax.legend(loc='best', markerscale=l_markerscale, title='Risk Group', prop={'size': fontsize_legend})
    legend.get_title().set_fontsize(fontsize_legend)
    legend.get_frame().set_linewidth(l_box_w)
    for legend_line in legend.get_lines():
        legend_line.set_linewidth(l_markerscale)

    plt.show()

# Font sizes.
fontsize_title  = 50
fontsize_labels = 45
fontsize_ticks  = 42
fontsize_legend = 45

# Legend, line and censor-mark sizes.
l_markerscale = 15
l_box_w       = 4
lw            = 5
ms            = 20

# Test-set C-index summary over the folds (position 2 of every fold entry).
test_ci = np.round(mean_confidence_interval([a[2] for a in cis_folds]), 2)
title   = 'TCGA Cohort'

# risk_groups[1] is passed as the high risk group and risk_groups[0] as the low risk group.
plot_k_fold_cv_KM(
    risk_groups[1],
    risk_groups[0],
    title,
    np.sqrt(15.5*15.5),
    event_ind_field,
    event_data_field,
    fontsize_title=fontsize_title,
    fontsize_labels=fontsize_labels,
    fontsize_ticks=fontsize_ticks,
    fontsize_legend=fontsize_legend,
    l_markerscale=l_markerscale,
    l_box_w=l_box_w,
    lw=lw,
    ms=ms,
    ci_show=True,
)
