In [None]:
from functools import reduce
from operator import mul

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray  as xr

In [None]:
import plot

## Load Data ##

In [None]:
# GSEA results for the "all-fa" / "c2.cgp" run (presumably MSigDB C2 CGP sets — inferred from the file name, confirm).
gsea_cgp = xr.open_dataset("../analyses/gsea/mri-features-all-fa_c2.cgp_F.nc").load()
# Gene-set labels come back as bytes, so decode them to str (needed for `.sel(gene_set=[...])` with str names).
# MRI-feature labels are coerced to Python ints.
gsea_cgp = gsea_cgp.assign_coords(
    gene_set=np.array([label.decode() for label in gsea_cgp['gene_set'].values], dtype='object'),
    mri_feature=np.array([int(label) for label in gsea_cgp['mri_feature'].values]),
)

In [None]:
# GSEA results for the "all-fa" / "c2.cp" run (presumably MSigDB C2 CP sets — inferred from the file name, confirm).
gsea_cp = xr.open_dataset("../analyses/gsea/mri-features-all-fa_c2.cp_T.nc").load()
# Decode bytes gene-set labels to str and coerce MRI-feature labels to Python ints.
gsea_cp = gsea_cp.assign_coords(
    gene_set=np.array([label.decode() for label in gsea_cp['gene_set'].values], dtype='object'),
    mri_feature=np.array([int(label) for label in gsea_cp['mri_feature'].values]),
)

In [None]:
# GSEA results for the "er-fa" / "c2.cp" run (the "er" subset meaning is not shown here — confirm).
gsea_cp_er = xr.open_dataset("../analyses/gsea/mri-features-er-fa_c2.cp_T.nc").load()
# Decode bytes gene-set labels to str and coerce MRI-feature labels to Python ints.
gsea_cp_er = gsea_cp_er.assign_coords(
    gene_set=np.array([label.decode() for label in gsea_cp_er['gene_set'].values], dtype='object'),
    mri_feature=np.array([int(label) for label in gsea_cp_er['mri_feature'].values]),
)

## Plots ##

In [None]:
def sgn_square(x):
    """Return the signed square of ``x``: ``sign(x) * x**2``.

    Large magnitudes are exaggerated while the sign (direction) is kept,
    e.g. -2 -> -4, 0 -> 0, 3 -> 9. Accepts scalars and array-likes that
    support ``**`` (NumPy arrays, xarray DataArrays).
    """
    direction = np.sign(x)
    magnitude = x ** 2
    return direction * magnitude

In [None]:
# Gene sets of interest; reused below by the histogram and the waterfall plots.
gene_sets = [
    'EGUCHI_CELL_CYCLE_RB1_TARGETS',
    'FINETTI_BREAST_CANCER_KINOME_RED',
    'KALMA_E2F1_TARGETS',
    'SMID_BREAST_CANCER_LUMINAL_A_DN',
    'CHANG_CYCLING_GENES',
    'ZHOU_CELL_CYCLE_GENES_IN_IR_RESPONSE_24HR',
]

# Select the highlighted gene sets once; every heatmap below is a view of this subset.
cgp_hl = gsea_cgp.sel(gene_set=gene_sets)
nes_hl = cgp_hl['nes']
fdr_hl = cgp_hl['fdr']
p_hl = cgp_hl['p']

plot.heatmap(nes_hl.T)                              # raw NES
plot.heatmap(np.abs(nes_hl.T))                      # |NES|
plot.heatmap((fdr_hl.T < 0.25).astype('f8'))        # FDR < 0.25 indicator (1.0 / 0.0)
# NOTE(review): -log10 of an FDR or p of exactly 0 is +inf — confirm plot.heatmap copes with that.
plot.heatmap(-np.log10(fdr_hl).T)                   # -log10(FDR)
plot.heatmap(-np.log10(p_hl).T)                     # -log10(p)
plot.heatmap(sgn_square(nes_hl.T))                  # signed NES^2
plot.heatmap(np.sign(nes_hl).T * -np.log10(fdr_hl).T);  # signed -log10(FDR)

In [None]:
with plot.subplots(1, 1) as (fig, ax):
    # Background: histogram of every NES value in the CGP results.
    plot.hist(gsea_cgp['nes'], ax=ax)
    # Overlay: one vertical line per highlighted gene set, at its largest NES over MRI features.
    highlighted_max_nes = gsea_cgp.sel(gene_set=gene_sets)['nes'].max('mri_feature').values
    for nes_value in highlighted_max_nes:
        ax.axvline(nes_value)

In [None]:
def wf_plot(vals, highlight, ylabel="", yscale="linear", ylim=None, xbaseline=None):
    """Draw a rank ("waterfall") plot of one statistic, marking selected gene sets.

    The values are sorted ascending and drawn as a grey filled area against
    their rank, expressed as a percentage of the number of values. Entries whose
    ``gene_set`` label is in ``highlight`` are overlaid as red vertical lines.

    Parameters
    ----------
    vals : xarray.DataArray
        1-D statistic carrying a ``gene_set`` coordinate (e.g. one MRI
        feature's row of NES or p-values).
    highlight : sequence of str
        Gene-set names to draw in red.
    ylabel : str
        Y-axis label.
    yscale : str
        A Matplotlib scale name. The pseudo-scale ``'mlog10'`` is also accepted:
        a log axis with the y direction inverted, so the smallest values are on top.
    ylim : pair or None
        Forwarded to ``ax.set_ylim`` when not None.
    xbaseline : float or None
        Despite its name, a level on the *y* axis: the filled area and the
        highlight lines start from it. Defaults to 0 on non-log scales and to
        ``max(rank %)`` on log scales.
    """
    order = np.argsort(vals.values)
    vals = vals[order]

    # Rank of each sorted value as a percentage of the total count.
    x = np.arange(len(vals)) / len(vals) * 100

    # Use the caller-supplied list, not the notebook-global `gene_sets`.
    hl_mask = np.isin(vals['gene_set'].values, highlight)
    x_hl = x[hl_mask]
    vals_hl = vals[hl_mask]

    # 'mlog10' is a log axis drawn upside down.
    y_invert = yscale == 'mlog10'
    if y_invert:
        yscale = 'log'

    if xbaseline is None:
        # 0 cannot be shown on a log axis, so a positive baseline is needed there.
        # NOTE(review): np.max(x) is a rank percentage (~100), not a y-data value; it only
        # renders sensibly because callers clip with `ylim`. Confirm whether np.max(vals) was meant.
        xbaseline = np.max(x) if yscale == 'log' else 0

    with plot.subplots(1, 1) as (fig, ax):
        ax.set_yscale(yscale)
        if y_invert:
            ax.invert_yaxis()

        ax.fill_between(x, xbaseline, vals, color='#777777')
        ax.vlines(x_hl, xbaseline, vals_hl, colors='#ff4444')

        if ylim is not None:
            ax.set_ylim(ylim)

        ax.set_xlabel("Rank (%)")
        ax.set_ylabel(ylabel)

# Waterfall plots of several statistics, all taken at index 0 of the first axis
# (presumably the first MRI feature — confirm the dimension order of these variables).
wf_plot(gsea_cgp['nes'][0, :], gene_sets, 'NES')
# Start the max-ES-position plot from half of the largest observed position.
mesa_mid = int(gsea_cgp['max_es_at'].max() / 2)
wf_plot(gsea_cgp['max_es_at'][0, :], gene_sets, 'Max. ES at', xbaseline=mesa_mid)
wf_plot(gsea_cgp['le_prop'][0, :]*100, gene_sets, 'Leading Edge (%)')
# On a log axis Matplotlib ignores a non-positive limit (with a warning), so the former upper
# limit of 0 never took effect. `None` keeps the current limit, giving the same plot without the warning.
wf_plot(gsea_cgp['p'][0, :], gene_sets, 'p', yscale='mlog10', ylim=[1, None])