# CNS AP-axis VAE model

This notebook runs through the CNS analysis used in the VAE paper (ref: ).  

We use genes that were found to be differentially expressed and aim to understand how different groups of genes responded to PRC2 knockout.  

We do this by using a variational autoencoder (VAE).  

Our VAE model uses a three-layer network and takes as input the normalised expression matrix, which includes the wild type and the response to Eed-cKO.  


### Sections:

#### Part 1: VAE setup
    1) Function setup
    2) Read in processed dataframe and organise this
    3) Build VAE model
    
#### Part 2: VAE as a ranking method
    4) Setup data for visualisations
    5) Quantify separability of latent space
    6) Save VAE ranks to csv to run GSEA in R
    
#### Part 3: VAE to identify coordinating groups of genes
    6) Build gene groups using VAE latent space
    7) Save gene groups for ORA in R
    8) Inspecting the groups in context of chromHMM annotations 
    9) Inspect groups in terms of annotations for Epi paper
    

In [None]:
"""
--------------------------------------------------------
                     Imports
--------------------------------------------------------
"""

import os, sys
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import venn
from matplotlib_venn import venn3
from sklearn.decomposition import PCA
from scipy import stats
from sciviso import *
from sciutil import SciUtil
from scivae import Vis as VVis
import string
# Shared SciUtil helper; handed to the scivae visualiser further down.
u = SciUtil()

# Put the notebook's working directory on the import path.
module_path = os.path.abspath('')
sys.path.append(module_path)


"""
--------------------------------------------------------
                     Global variables
--------------------------------------------------------
"""
# Date stamp appended to the file name of every figure written by save_fig().
date = '20210217'

# Column names of the gene identifier / gene symbol in the input tables.
gene_id = 'entrezgene_id'
gene_name = 'external_gene_name'

# Categorical palette used for groups that have no dedicated colour.
sci_colour = ['#483873', '#1BD8A6', '#B117B7', '#AAC7E2', '#FFC107', '#016957', '#9785C0', 
             '#D09139', '#338A03', '#FF69A1', '#5930B1', '#FFE884', '#35B567', '#1E88E5', 
             '#ACAD60', '#A2FFB4', '#B618F5', '#854A9C']

hist_colour = '#483873'

# Greys returned by get_time_colour().
e11_colour = '#BEC0C2'
e13_colour = '#A7A9AC'
e15_colour = '#7D7E81'
e18_colour = '#58595B'

# Oranges/browns returned by get_tissue_colour().
fb_colour = '#ffbf80'
mb_colour = '#ff8c1a'
hb_colour = '#b35900'
sc_colour = '#663300'
# American-spelling aliases, kept so both spellings always agree.
fb_color = fb_colour
mb_color = mb_colour
hb_color = hb_colour
sc_color = sc_colour

# Colours returned by get_mark_colour().
h3k36me3_colour = '#CEF471'
h3k27me3_colour = '#9F71F4'
h3k4me3_colour = '#9F00FA'
h3k4me2_colour = '#5930B1'
h3k4me1_colour = '#FFE884'
h3k27ac_colour = '#35B567'
h3k9me3_colour = '#1E88E5'
h3k9ac_colour = '#A2FFB4'

# Colours returned by get_cond_colour().
wt_colour = '#AADFF1'
ko_colour = '#A53736'

sns.palplot(sci_colour)
# NOTE(review): the returned palette is discarded, so this line has no lasting
# effect; if the intent was to make it the default, sns.set_palette is needed.
sns.color_palette(sci_colour)

project_name = '6_node_consistent_genes'

# Directory layout; every path is derived from data_dir / project_name so
# that renaming the project only requires changing project_name.
data_dir = '../../data/'
r_dir = f'{data_dir}results/deseq2/'
fig_dir = f'../../figures/{project_name}/vae/'
output_dir = f'{data_dir}results/{project_name}/vae/'
input_dir = f'{data_dir}results/prelim/'
rna_dir = f'{input_dir}feature-counts_04052020/'
supp_dir = f'{data_dir}input/supps/'
logging_dir = 'logging/'
ora_dir = f'{data_dir}results/{project_name}/functional/'

# Colour maps; plot_gene_heatmap() picks fixed value ranges based on these.
grey = '#bdbdbd'
hist_cmap = 'Greens'
rna_cmap = 'Purples'
div_cmap = 'seismic'

experiment_name = project_name
from_saved = True  # Load from a saved VAE

In [None]:
# Load the pre-processed gene tables from `input_dir`. df_training (the
# "consistent" table, going by its file name) is the one the VAE input is
# built from later on; df_all is presumably the unfiltered table -- confirm
# against the script that wrote these files.
df_all = pd.read_csv(f'{input_dir}df-all_epi-2500_20210124.csv')
df_training = pd.read_csv(f'{input_dir}df-consistent_epi-2500_20210124.csv')

## 1) Function setup

Here we just initialise the functions that will be used in the notebook.

In [None]:
"""
--------------------------------------------------------
                     Colour/style getters
--------------------------------------------------------
"""
def get_time_colour(c):
    """Return the time-point colour for the label ``c``.

    The label is searched for the substrings '10'/'11', '12'/'13',
    '14'/'15' and '16'/'18' (in that order); the first pair that matches
    selects e11_colour, e13_colour, e15_colour or e18_colour respectively.
    White ('#FFFFFF') is returned when nothing matches.
    """
    if any(tag in c for tag in ('11', '10')):
        return e11_colour
    if any(tag in c for tag in ('13', '12')):
        return e13_colour
    if any(tag in c for tag in ('15', '14')):
        return e15_colour
    if any(tag in c for tag in ('18', '16')):
        return e18_colour
    return '#FFFFFF'

def get_tissue_colour(c):
    """Return the tissue colour for the label ``c``.

    Substrings are checked in the order spinal cord ('sc'/'spinal'),
    hindbrain ('hb'/'hindbrain'), forebrain ('fb'/'forebrain') and
    midbrain ('mb'/'midbrain'); the first hit wins. White ('#FFFFFF') is
    returned when nothing matches.
    """
    if any(tag in c for tag in ('sc', 'spinal')):
        return sc_colour
    if any(tag in c for tag in ('hb', 'hindbrain')):
        return hb_colour
    if any(tag in c for tag in ('fb', 'forebrain')):
        return fb_colour
    if any(tag in c for tag in ('mb', 'midbrain')):
        return mb_colour
    return '#FFFFFF'

def get_cond_colour(c):
    """Return ko_colour if ``c`` contains 'ko', wt_colour if it contains 'wt'.

    'ko' takes precedence when both substrings are present; white
    ('#FFFFFF') is returned when neither is.
    """
    if 'ko' in c:
        return ko_colour
    return wt_colour if 'wt' in c else '#FFFFFF'

def get_mark_colour(c):
    """Return the histone-mark colour for the label ``c``.

    The label is tested against each mark substring in a fixed order and
    the colour of the first match is returned; white ('#FFFFFF') is the
    fallback when no mark is recognised.
    """
    if '36me3' in c:
        return h3k36me3_colour
    if '27me3' in c:
        return h3k27me3_colour
    if 'K4me3' in c:
        return h3k4me3_colour
    if 'K4me2' in c:
        return h3k4me2_colour
    if 'K4me1' in c:
        return h3k4me1_colour
    if '27ac' in c:
        return h3k27ac_colour
    if 'K9me3' in c:
        return h3k9me3_colour
    if 'K9ac' in c:
        return h3k9ac_colour
    return '#FFFFFF'

"""
--------------------------------------------------------
                     Plotting
--------------------------------------------------------
"""
def pplot():
    """No-op hook; plot_gene_heatmap() calls it just before saving a figure."""
    return
        
def cplot():
    """No-op placeholder; does nothing and returns None."""
    return
    
def save_fig(title):
    """Save the current matplotlib figure as an SVG in ``fig_dir``.

    The file name is the title with spaces replaced by hyphens, followed by
    the experiment name and the global date stamp.
    """
    file_stem = title.replace(" ", "-")
    plt.savefig(f'{fig_dir}{file_stem}_{experiment_name}_{date}.svg')
    
"""
--------------------------------------------------------
                     Heatmap and vis function
--------------------------------------------------------
"""

def plot_gene_heatmap(df, fb_genes, mb_genes, hb_genes, sc_genes, 
                      input_cols, vae_data, vae, gene_name, mark, method='input', 
                      title="", merge_reps=False, input_col_ids=None, cmap='', genes=None, 
                      values=None):
    """Plot, save and show a heatmap for four groups of marker genes.

    Rows are the marker genes that occur in ``genes``, listed group by group
    (``fb_genes``, ``mb_genes``, ``hb_genes``, ``sc_genes``) and, within a
    group, in the order in which they occur in ``genes``.

    Args:
        df: DataFrame holding the ``gene_name`` column and, for
            ``method='input'``, the ``input_cols`` columns.
        fb_genes, mb_genes, hb_genes, sc_genes: Collections of gene names;
            each group contributes one run of rows to the heatmap.
        input_cols: Labels of the heatmap columns. Position ``i`` labels
            column ``i`` of the value matrix.
        vae_data: Latent encoding with one row per entry of ``genes``. It is
            decoded for ``method='decoding'`` and plotted as-is for any
            method other than 'input' / 'decoding'.
        vae: Model exposing ``decoder.predict``; only used when decoding.
        gene_name: Name of the gene-label column.
        mark: Unused. Retained so existing callers keep working.
        method: 'input' plots ``df[input_cols]``, 'decoding' plots the
            min-max scaled decoder output, anything else plots ``vae_data``.
        title: Figure title; also used by ``save_fig`` for the file name.
        merge_reps: If True, every plotted column is rescaled per gene with
            the row minimum / maximum taken over *all* columns of the value
            matrix. Rows with no spread are plotted as 0.
        input_col_ids: Optional positions selecting a subset of
            ``input_cols`` (and of the value-matrix columns) to plot.
        cmap: Colour map name. The notebook's histone/RNA maps get a fixed
            0..10 range and the diverging map a fixed -3..3 range.
        genes: Gene labels aligned with the rows of the value matrix;
            defaults to ``df[gene_name]``.
        values: Optional matrix that replaces the values selected by
            ``method``. When given, no decoding is performed.
    """
    genes = genes if genes is not None else df[gene_name].values

    # Collect the marker genes one group at a time so that rows of the same
    # group stay adjacent in the heatmap.
    brain_genes = []
    gene_idxs = []
    for group in (fb_genes, mb_genes, hb_genes, sc_genes):
        for idx, g in enumerate(genes):
            if g in group:
                brain_genes.append(g)
                gene_idxs.append(idx)
    heatmap_df = pd.DataFrame()
    heatmap_df[gene_name] = brain_genes

    if values is not None:
        vals = values[gene_idxs]
    elif method == 'decoding':
        decoded = vae.decoder.predict(vae_data)
        # A private scaler is used so the notebook-level `scaler` fitted on
        # the VAE input is not silently re-fitted on decoder output.
        decoded = MinMaxScaler(copy=True).fit_transform(decoded)
        vals = decoded[gene_idxs]
    elif method == 'input':
        vals = df[input_cols].values[gene_idxs]
    else:
        vals = vae_data[gene_idxs]
    vals = np.nan_to_num(vals)

    if merge_reps:
        # Columns are taken in the order given by input_col_ids.
        col_ids = range(len(input_cols)) if input_col_ids is None else list(input_col_ids)
        row_min = np.min(vals, axis=1)
        row_span = np.max(vals, axis=1) - row_min
        # Constant rows would give 0/0 = NaN; divide by 1 so they plot as 0.
        row_span = np.where(row_span == 0, 1, row_span)
    else:
        # Columns are taken in input_cols order, filtered by input_col_ids.
        col_ids = [i for i in range(len(input_cols))
                   if input_col_ids is None or i in input_col_ids]

    heatmap_cols = []
    for col_id in col_ids:
        col_name = input_cols[col_id]
        column = vals[:, col_id]
        heatmap_df[col_name] = (column - row_min) / row_span if merge_reps else column
        heatmap_cols.append(col_name)

    print(heatmap_df.head())

    # Fixed colour ranges for the notebook's standard colour maps; any other
    # colour map is auto-scaled.
    vmin = None
    vmax = None
    if cmap == hist_cmap or cmap == rna_cmap:
        vmin = 0
        vmax = 10
    elif cmap == div_cmap:
        vmin = -3
        vmax = 3
    heatmap = Heatmap(heatmap_df, heatmap_cols, gene_name,
                      title=title, cluster_cols=False, cluster_rows=False, vmin=vmin, vmax=vmax,
                      figsize=(5, 5),
                      cmap=cmap)
    heatmap.cmap = cmap
    heatmap.plot()
    pplot()
    save_fig(title)
    plt.show()


## 2) Read in processed dataframe

Here we import the results from running the AP-axis dataset generation script.  

This used the output from differential analysis and also annotated histone marks to genes.  



In [None]:
# Remove genes that weren't able to map to a gene name.
# gene_mask is 1 for rows to drop and 0 for rows to keep.
gene_mask = []
for g in df_training[gene_name].values:
    try:
        float(g)
    except (TypeError, ValueError):
        # Not number-like: a usable gene symbol, unless the entry is None.
        gene_mask.append(1 if g is None else 0)
    else:
        # NaN (missing) and purely numeric entries are not gene names.
        gene_mask.append(1)

df_training['gene_mask'] = gene_mask
print(len(df_training))
print(len(df_training[df_training['gene_mask'] == 0]))
# Copy the filtered rows so the column assignments below write to an
# independent frame instead of a slice of the original.
df_training = df_training[df_training['gene_mask'] == 0].copy()

# Smooth out the columns in the data frame i.e. for the clones we only put in the mean of the two replicates.
# Replicates are assumed to be adjacent columns; columns produced by an
# earlier run of this cell are skipped so that re-running it is harmless.
cols_to_merge = [c for c in df_training.columns
                 if ('wt' in c or 'ko' in c) and 'merged-rep' not in c]
if len(cols_to_merge) % 2 != 0:
    raise ValueError(f'Expected replicate columns in pairs, found {len(cols_to_merge)} columns.')
for rep_a, rep_b in zip(cols_to_merge[0::2], cols_to_merge[1::2]):
    # The merged name is the first replicate's name minus its last character.
    df_training[f'{rep_a[:-1]}_merged-rep'] = 0.5 * (df_training[rep_a].values +
                                                     df_training[rep_b].values)
    print("merged", rep_a, rep_b)


In [None]:
# Average each pair of adjacent replicate columns of df_all. Columns produced
# by an earlier run of this cell are skipped so that re-running it is harmless.
cols_to_merge = [c for c in df_all.columns
                 if ('wt' in c or 'ko' in c) and 'merged-rep' not in c]
# Not used below; kept because later notebook cells may still read them.
col_names = []
col_values = []
values = df_all[cols_to_merge].values
if len(cols_to_merge) % 2 != 0:
    raise ValueError(f'Expected replicate columns in pairs, found {len(cols_to_merge)} columns.')
for rep_a, rep_b in zip(cols_to_merge[0::2], cols_to_merge[1::2]):
    # The merged name is the first replicate's name minus its last character.
    df_all[f'{rep_a[:-1]}_merged-rep'] = 0.5 * (df_all[rep_a].values +
                                                df_all[rep_b].values)
    print("merged", rep_a, rep_b)


## 3) Build VAE model

Since VAEs are stochastic rather than deterministic, we want to create our clusters based on a number of iterations.
We are also interested in how much the genes differ in various latent spaces.

In [None]:
from scivae import VAE

# Substring that identifies histone-signal columns; read as a global by
# get_median().
hist_metric = 'signal'
# Gene names of the training rows (reassigned when get_input_data is called).
labels = df_training[gene_name].values


# -----------------------------------------------------------------------------------
#                         Set the input data space
# -----------------------------------------------------------------------------------
def get_input_data(df: pd.DataFrame, label_col: str):
    """Build the min-max scaled input matrix for the VAE.

    Two kinds of columns are selected, in the order they appear in ``df``:

    * columns whose name contains 'H3K27me3', 'signal' and 'brain' are
      transformed with ``log2(x + 1)`` and stored as ``<name>_log2``;
    * columns whose name contains 'log2FoldChange', 'wt' or 'ko' (but not
      'merged-rep') are used as they are.

    Missing values are replaced by 0 first, then every selected column is
    scaled to [0, 1] on its own. A column without any spread becomes all
    zeros instead of NaN.

    Args:
        df: Source table holding the feature columns and ``label_col``.
        label_col: Name of the column holding the row labels.

    Returns:
        A tuple ``(vae_values, tmp_df, labels)``: the scaled feature matrix,
        a DataFrame with the same scaled columns plus ``label_col`` as the
        last column, and the label values.
    """
    def _min_max(v):
        """Scale ``v`` to [0, 1]; all zeros when ``v`` has no spread."""
        lo = np.nanmin(v)
        span = np.nanmax(v) - lo
        if span == 0:
            return np.zeros(len(v), dtype=float)
        return (v - lo) / span

    filled = df.fillna(0)
    scaled = {}
    for c in df.columns:
        if 'H3K27me3' in c and 'signal' in c and 'brain' in c:
            scaled[f'{c}_log2'] = _min_max(np.log2(filled[c].values + 1))
        elif ('merged-rep' not in c) and ('log2FoldChange' in c or 'wt' in c or 'ko' in c):
            scaled[c] = _min_max(filled[c].values)

    # Build the frame in one go; dict order keeps the column order of df.
    tmp_df = pd.DataFrame(scaled)
    vae_values = np.nan_to_num(tmp_df.values)
    tmp_df[label_col] = df[label_col].values
    return vae_values, tmp_df, df[label_col].values


In [None]:
# # -----------------------------------------------------------------------------------
# #                         Run a test to work out the best number of nodes
# # -----------------------------------------------------------------------------------
# vae_input_values, df_input, labels = get_input_data(df_training, gene_name)
# tsts = [1, 2, 3, 4, 6, 8, 16, 24]
# for i in tsts:
#     for j in range(0, 5):
#         config = {'loss': {'loss_type': 'mse', 'distance_metric': 'mmd', 'mmd_weight': 2.0}, 
#                   'encoding': {'layers': [{'num_nodes': 64, 'activation_fn': 'selu'}, 
#                                           {'num_nodes': 32, 'activation_fn': 'relu'}]}, 
#                   'decoding': {'layers': [{'num_nodes': 32, 'activation_fn': 'relu'},
#                                           {'num_nodes': 64, 'activation_fn': 'selu'}]}, 
#                   'latent': {'num_nodes': i}, 'optimiser': {'params': {}, 'name': 'adam'}}

#         vae_mse = VAE(vae_input_values, vae_input_values, labels, config, f'vae_num-nodes_{i}_rep_{j}')
#         vae_mse.encode('default', epochs=100, batch_size=50,
#                        logging_dir=logging_dir, logfile=f'run_num-nodes_{i}_rep_{j}.csv')

# # -----------------------------------------------------------------------------------
# #                         Plot results
# # -----------------------------------------------------------------------------------

# # Read in each of the csv's and then add a line for each one
# tsts = [1, 2, 3, 4, 6, 8, 16, 24]
# fig, ax = plt.subplots()
# labels = []
# for i, t in enumerate(tsts):
#     c = sci_colour[i]
#     print(c, i)
#     for j in range(0, 5):
#         d = pd.read_csv(f'{logging_dir}run_num-nodes_{t}_rep_{j}.csv')
#         print(f'{logging_dir}run_num-nodes_{t}_rep_{j}.csv')
#         ax.plot(d['epoch'].values, d['loss'].values, c=c, alpha=0.5)
#         label=f'{t} node(s)'
#         if label not in labels:
#             ax.plot(d['epoch'].values, d['val_loss'].values, c=c, alpha=0.8, label=label)
#             labels.append(label)
#         else:
#             ax.plot(d['epoch'].values, d['val_loss'].values, c=c, alpha=0.8)
        
# ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# axes = plt.gca()
# axes.set_ylim([0, 5])
# save_fig(f'node_num')
# plt.show()

In [None]:
# -----------------------------------------------------------------------------------
#                         Based on the above, set the number of latent nodes,
#                         then train the final VAE or load a saved one and
#                         encode the training data with it
# -----------------------------------------------------------------------------------
num_nodes = 6
vae_input_values, df_input, labels = get_input_data(df_training, gene_name)

# Mirror-image encoder/decoder (64 -> 32 -> latent -> 32 -> 64).
config = {
    'loss': {'loss_type': 'mse', 'distance_metric': 'mmd', 'mmd_weight': 1.0,
             'mmcd_method': 'k'},
    'encoding': {'layers': [{'num_nodes': 64, 'activation_fn': 'selu'},
                            {'num_nodes': 32, 'activation_fn': 'relu'}]},
    'decoding': {'layers': [{'num_nodes': 32, 'activation_fn': 'relu'},
                            {'num_nodes': 64, 'activation_fn': 'selu'}]},
    'latent': {'num_nodes': num_nodes},
    'optimiser': {'params': {}, 'name': 'adam'},
}

vae_mse = VAE(vae_input_values, vae_input_values, labels, config, 'final')
if from_saved:
    vae_mse.u.dp(["Trying to load saved VAE..."])
    vae_mse.load()
    vae_mse.u.dp(["Loaded saved VAE :) "])
else:
    vae_mse.encode('default', epochs=250, batch_size=50,
                   logging_dir="final_vae_logs", logfile='final.csv')
    vae_mse.save()  # save the VAE
    vae_mse.u.dp(["Saved VAE to current directory."])

# Working copy that the visualisation cells add summary columns to.
df = df_training.copy()

# Encode the values using the VAE
scaler = MinMaxScaler(copy=True)
scaled_vals = scaler.fit_transform(vae_input_values)

vae_data = vae_mse.encode_new_data(scaled_vals)


## 4) Setup data for visualisations

Here we set up the data and summarise key features, for example calculating the median values. 

We do this so we can easily access these later.

In [None]:
# -----------------------------------------------------------------------------------
#                         Data setup for Vis
# -----------------------------------------------------------------------------------

# Substring identifying the histone columns that get_median() summarises.
hist_metric = 'signal'
# All df_input columns, including the trailing gene-name column.
input_cols = list(df_input.columns)


def get_median(df, mark, time="", tissue="brain"):
    """Return the per-row median over the columns matching a mark.

    A column is used when its name contains ``tissue``, ``mark``, the global
    ``hist_metric`` and ``time`` (an empty ``time`` matches everything) and
    does not contain 'median'. NaNs are ignored when taking the median, and
    any NaN left in the result is replaced by 0.
    """
    selected = [
        col for col in df.columns
        if tissue in col and mark in col and hist_metric in col
        and time in col and 'median' not in col
    ]
    row_medians = np.nanmedian(df[selected].values, axis=1)
    return np.nan_to_num(row_medians)

# Histone marks for which summary columns are added to df.
marks = ['H3K27me3', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K27ac', 'H3K9me3', 'H3K36me3', 'H3K9ac']

# Get the median over all time points and all brain tissues
# (stored in a column named after the mark itself).
for m in marks:
    df[f'{m}'] = get_median(df, m, time="", tissue="brain")
    
# Median per mark and time point over all columns whose name contains 'brain'.
times = ['10.5', '11.5', '12.5', '13.5', '14.5', '15.5', '16.5']
for m in marks:
    for e in times:
        df[f'median_brain_{e}_signal_{m}'] = get_median(df, m, time=e, tissue="brain")
    
# Median per mark, time point and individual tissue.
# The new column names contain 'median', so get_median() never feeds them
# back into a later median.
tissues = ['hindbrain', 'midbrain', 'forebrain']
for m in marks:
    for e in times:
        for t in tissues:
            df[f'median_{t}_{e}_signal_{m}'] = get_median(df, m, time=e, tissue=t)


# Mean expression per gene over all WT and all KO columns, and their
# difference. Columns starting with 'mean_' are excluded so that re-running
# this cell does not fold the previous 'mean_wt' / 'mean_ko' / 'mean_ko-wt'
# results into the new means.
wt_expr_cols = [c for c in df.columns if 'wt' in c and not c.startswith('mean_')]
ko_expr_cols = [c for c in df.columns if 'ko' in c and not c.startswith('mean_')]
df['mean_wt'] = np.mean(df[wt_expr_cols].values, axis=1)
df['mean_ko'] = np.mean(df[ko_expr_cols].values, axis=1)
df['mean_ko-wt'] = df['mean_ko'].values - df['mean_wt'].values

# Add one column per latent node: VAE0, VAE1, ...
for v in range(len(vae_data[0])):
    df[f'VAE{v}'] = vae_data[:, v]

# Copy the log2 fold-change columns under shorter, figure-friendly names,
# e.g. 'log2FoldChange_fb' -> 'Log2FC FB'.
for cond in ('fb', 'mb', 'hb', 'sc',
             'a11', 'a13', 'a15', 'a18',
             'p11', 'p13', 'p15', 'p18'):
    df[f'Log2FC {cond.upper()}'] = df[f'log2FoldChange_{cond}'].values

# Every summary column added to df above (medians per mark, expression means,
# renamed fold changes and the first three latent nodes).
all_cols = [
    'H3K9me3',
    'H3K4me1', 
    'H3K4me2', 
    'H3K4me3', 
    'H3K27ac',
    'H3K36me3', 
    'H3K9ac',
    'mean_wt',
    'mean_ko',
    'mean_ko-wt',
    'H3K27me3',
    'Log2FC FB',
    'Log2FC MB',
    'Log2FC HB',
    'Log2FC SC',
    'VAE0',
    'VAE1',
    'VAE2',
    'Log2FC A11',
    'Log2FC A13',
    'Log2FC A15',
    'Log2FC A18',
    'Log2FC P11',
    'Log2FC P13',
    'Log2FC P15',
    'Log2FC P18'
]
# Columns used for the "unobserved features" correlation plots.
unobserved_cols = [
    'H3K9me3',
    'H3K4me1', 
    'H3K4me2', 
    'H3K4me3', 
    'H3K27ac',
    'H3K36me3', 
    'H3K9ac',
    'mean_wt',
    'mean_ko',
    'mean_ko-wt'
]

# Columns used for the "observed features" correlation plots.
observed_cols = [
    'H3K4me1', 
    'H3K4me2', 
    'H3K4me3', 
    'H3K27ac',
    'H3K36me3', 
    'H3K27me3',
    'Log2FC FB',
    'Log2FC MB',
    'Log2FC HB',
    'Log2FC SC',
    'VAE0',
    'VAE1',
    'VAE2'
]

# Histone-mark medians; plotted on the latent space with hist_cmap.
hist_cols = [
    'H3K9me3',
    'H3K4me1', 
    'H3K4me2', 
    'H3K4me3', 
    'H3K27ac',
    'H3K36me3', 
    'H3K27me3',
    'H3K9ac'
]

# Fold-change columns; plotted on the latent space with the diverging div_cmap.
div_cols = [
    'Log2FC FB',
    'Log2FC MB',
    'Log2FC HB',
    'Log2FC SC',
    'Log2FC A11',
    'Log2FC A13',
    'Log2FC A15',
    'Log2FC A18',
    'Log2FC P11',
    'Log2FC P13',
    'Log2FC P15',
    'Log2FC P18',
]

# Mean expression columns; plotted on the latent space with rna_cmap.
raw_cols = [
    'mean_wt',
    'mean_ko'
]
# vae_cols is a second name for the observed_cols list (same object).
vae_cols = observed_cols
df_unobserved = df[unobserved_cols]
df_observed = df[observed_cols]

# Column name -> value array, for quick access later on.
observed_dict = {name: df_observed[name].values for name in observed_cols}
unobserved_dict = {name: df_unobserved[name].values for name in unobserved_cols}
    

## 4) Plot the correlations for the latent nodes with data features

Here we want to plot how correlated each node is with the mean values of different epigenetic marks & also 
the means of input features in the dataset.  

We also plot the latent space coloured by each of the features (both observed and unobserved).  

Lastly, we plot heatmaps of the input, latent and decoding results for key marker genes so that we can check whether our model behaves sensibly. These same genes are also projected onto the latent space, where we can see that the different classes of markers (e.g. Hox genes vs brain genes) are easily discernible.



In [None]:
# -----------------------------------------------------------------------------------
#                         Plot scatter of the features
# -----------------------------------------------------------------------------------
# Keep the full latent encoding and restrict vae_data to the first three
# nodes for plotting.
# NOTE(review): if this cell is re-run on its own, vae_data is already the
# 3-column slice, so all_vae_data would then also hold only three columns.
all_vae_data = np.copy(vae_data)
vae_data = all_vae_data[:, 0:3]
vae_vis = VVis(vae_mse, u, None)
div_cmap = 'seismic'

# Latent space coloured by the histone-mark medians.
feature_obs_ax = vae_vis.plot_feature_scatters(df, "", columns=hist_cols, vae_data=vae_data, 
                                               show_plt=False, output_dir=fig_dir, fig_type="svg", 
                                               save_fig=False, title=f'{experiment_name}',
                                               angle_plot=315, cmap=hist_cmap, vmin=0, vmax=10)

# Latent space coloured by the log2 fold changes.
feature_obs_ax = vae_vis.plot_feature_scatters(df, "", columns=div_cols, vae_data=vae_data, 
                                               show_plt=False, output_dir=fig_dir, fig_type="svg", 
                                               save_fig=False, title=f'{experiment_name}',
                                               angle_plot=315, cmap=div_cmap, vmin=-10, vmax=10)

# Latent space coloured by the mean WT / KO expression.
feature_obs_ax = vae_vis.plot_feature_scatters(df, "", columns=raw_cols, vae_data=vae_data, 
                                               show_plt=False, output_dir=fig_dir, fig_type="svg", 
                                               save_fig=False, title=f'{experiment_name}',
                                               angle_plot=315, cmap=rna_cmap, vmin=0, vmax=10)

# -----------------------------------------------------------------------------------
#                         Plot histograms
# -----------------------------------------------------------------------------------
vae_vis.plot_node_hists(show_plt=True, save_fig=False)
vae_vis.plot_input_distribution(df_observed, show_plt=True, save_fig=True, output_dir=fig_dir)
vae_vis.plot_input_distribution(df_unobserved, show_plt=True, output_dir=fig_dir, save_fig=True)

# -----------------------------------------------------------------------------------
#                         Plot correlations
# -----------------------------------------------------------------------------------
vae_vis.plot_node_correlation(show_plt=True, save_fig=False)

# Latent nodes vs. the observed feature columns.
vae_vis.plot_node_feature_correlation(df_observed, '', columns=observed_cols, vae_data=vae_data, 
                                  show_plt=True, output_dir=fig_dir, save_fig=True, 
                                      title=f'Node vs observed features {experiment_name}', cmap=div_cmap)

vae_vis.plot_feature_correlation(df_observed, '', columns=observed_cols, show_plt=True, cmap=div_cmap,
                                 title=f'Feature vs observed features {experiment_name}',
                             output_dir=fig_dir, save_fig=True)

# Latent nodes vs. the unobserved feature columns.
vae_vis.plot_node_feature_correlation(df_unobserved, '', columns=unobserved_cols, vae_data=vae_data, 
                                      cmap=div_cmap,
                                      show_plt=True, output_dir=fig_dir, save_fig=True,
                                      title=f'Node vs unobserved features {experiment_name}')

vae_vis.plot_feature_correlation(df_unobserved, '', columns=unobserved_cols, show_plt=True, cmap=div_cmap,
                             output_dir=fig_dir, save_fig=True, title=f'feature un-observed features {experiment_name}')

# Feature-vs-feature correlation over both column groups together.
vae_vis.plot_feature_correlation(df, '', columns=unobserved_cols + observed_cols, show_plt=True, cmap=div_cmap,
                             output_dir=fig_dir, save_fig=True, title=f'features all {experiment_name}')
        
# vae_cols is the same list as observed_cols, so this repeats the observed
# feature correlation on the full df.
vae_vis.plot_feature_correlation(df, '', columns=vae_cols, show_plt=True, cmap=div_cmap,
                             output_dir=fig_dir, save_fig=True, title=f'features VAE {experiment_name}')
# -----------------------------------------------------------------------------------
#                         Plot heatmaps of marker genes
# -----------------------------------------------------------------------------------
# One representative marker per region for a quick sanity check.
fb_genes = ['Foxg1']
mb_genes = ['En2']
hb_genes = ['Sox3']
sc_genes = ['Hoxc9']
input_cols = [c for c in df_input.columns if c != gene_name]

merged_cols = [c for c in df.columns if "merged" in c]

# Each call gets its own title: the title is part of the file name used by
# save_fig(), so untitled calls would all overwrite the same file.

# Plot the input space
plot_gene_heatmap(df, fb_genes, mb_genes, hb_genes, sc_genes, 
                  merged_cols, vae_data, vae_mse,
                  gene_name, mark='H3K27me3', method='input', cmap=hist_cmap,
                  title=f'Marker genes input {experiment_name}')

# Plot the decoding space
plot_gene_heatmap(df_input, fb_genes, mb_genes, hb_genes, sc_genes, 
                  input_cols, all_vae_data, vae_mse, 
                  gene_name, mark='H3K27me3', method='decoding', cmap=div_cmap,
                  title=f'Marker genes decoding {experiment_name}')

# Plot the latent space
plot_gene_heatmap(df, fb_genes, mb_genes, hb_genes, sc_genes, 
                  ['VAE0', 'VAE1', 'VAE2'], all_vae_data, vae_mse, 
                  gene_name, mark='H3K27me3', method='latent', cmap=hist_cmap,
                  title=f'Marker genes latent {experiment_name}')

# Also do the main markers (used by the heatmaps in the next cell)
fb_genes = ['Emx1', 'Eomes', 'Tbr1', 'Foxg1', 'Lhx6']
mb_genes = ['En1', 'En2', 'Lmx1a', 'Bhlhe23', 'Sall4']
hb_genes = ['Phox2b', 'Krox20', 'Fev', 'Hoxb1',  'Hoxd3']
sc_genes = ['Hoxd8', 'Hoxd9', 'Hoxd10', 'Hoxd11','Hoxa7', 'Hoxa9', 'Hoxa10',
              'Hoxb9', 'Hoxb13',  'Hoxc8', 'Hoxc9', 'Hoxc10', 'Hoxc11', 'Hoxc12', 'Hoxc13']


## Visualisations

In [None]:
ko_cols = [c for c in df_input.columns if 'ko' in c]
wt_cols = [c for c in df_input.columns if 'wt' in c]
# Positions of the KO / WT columns within df_input. The decoding heatmaps
# select columns of the decoder output by position, so they are given the
# full column list plus these indices rather than the WT/KO names alone
# (which would label whatever the first len(wt_cols) decoded columns are).
ko_col_ids = [i for i, c in enumerate(df_input.columns) if 'ko' in c]
wt_col_ids = [i for i, c in enumerate(df_input.columns) if 'wt' in c]
decoded_cols = list(df_input.columns)

bin_cmap = 'bone'
# Plot the input space WT and KO
merged_wt = [c for c in merged_cols if 'wt' in c]
merged_ko = [c for c in merged_cols if 'ko' in c]
mb_genes = ['En1', 'En2', 'Lmx1a', 'Otx2', 'Sall4']

plot_gene_heatmap(df, fb_genes, mb_genes, hb_genes, sc_genes, 
                  merged_wt, all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', merge_reps=True, cmap=bin_cmap,
                  method='input', title=f'All markers input WT {experiment_name}')

plot_gene_heatmap(df, fb_genes, mb_genes, hb_genes, sc_genes, 
                  merged_ko, all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', merge_reps=True,  cmap=bin_cmap,
                  method='input', title=f'All markers input KO {experiment_name}')

# Plot the decoding space for WT
plot_gene_heatmap(df_input, fb_genes, mb_genes, hb_genes, sc_genes, 
                  decoded_cols, all_vae_data, vae_mse, 
                  gene_name, merge_reps=True, mark='H3K27me3', cmap=bin_cmap,
                  method='decoding', title=f'All markers decoding WT {experiment_name}',
                  input_col_ids=wt_col_ids)


# Plot the latent space
plot_gene_heatmap(df, fb_genes, mb_genes, hb_genes, sc_genes, 
                  ['VAE0', 'VAE1', 'VAE2'], vae_data, vae_mse,
                  gene_name, 
                   cmap=div_cmap,
                  mark='H3K27me3', method='latent', title=f'All markers latent {experiment_name}')


# Marker sets for cell state rather than region; they are passed in the
# fb/mb/hb slots of plot_gene_heatmap with an empty fourth group.
progenitors = ['Sox2', 'Sox1', 'Sox3', 'Hes1', 'Hes5']
neurons = ['Snap25', 'Syt1', 'Slc32a1','Slc17a6', 'Syn1']
glia = ['Cspg4', 'Aqp4', 'Slc6a11', 'Olig1', 'Igfbp3']


# Plot the input space
plot_gene_heatmap(df, progenitors, neurons, glia, [], 
                  merged_wt, vae_data, vae_mse,
                  gene_name, mark='H3K27me3', method='input', cmap=bin_cmap,
                  merge_reps=True, title=f'Time markers input WT {experiment_name}')
plot_gene_heatmap(df, progenitors, neurons, glia, [], 
                  merged_ko, vae_data, vae_mse,
                  gene_name, mark='H3K27me3', method='input', cmap=bin_cmap,
                  merge_reps=True, title=f'Time markers input KO {experiment_name}')
# Plot the decoding space
plot_gene_heatmap(df_input, progenitors, neurons, glia, [], 
                  decoded_cols, all_vae_data, vae_mse, 
                  gene_name, mark='H3K27me3', method='decoding', cmap=bin_cmap,
                  merge_reps=True, title=f'Time markers decoding WT {experiment_name}',
                  input_col_ids=wt_col_ids)
# NOTE(review): the disabled call below indexes ko_cols with positions taken
# from df_input.columns and decodes the 3-column vae_data; it would need
# decoded_cols / all_vae_data to work.
# plot_gene_heatmap(df_input, progenitors, neurons, glia, [], 
#                   ko_cols, vae_data, vae_mse, 
#                   gene_name, mark='H3K27me3', method='decoding', cmap=bin_cmap,
#                   merge_reps=True, title=f'Time markers decoding KO {experiment_name}', 
#                   input_col_ids=ko_col_ids)

# Plot the latent space

plot_gene_heatmap(df, progenitors, neurons, glia, [], 
                  ['VAE0', 'VAE1', 'VAE2'], all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', method='latent', 
                   cmap=div_cmap,  title=f'Time markers latent {experiment_name}')
                  
# -----------------------------------------------------------------------------------
#                         Plot the genes on a scatter plot
# -----------------------------------------------------------------------------------
# Marker groups; entry k of gene_markers_ns is labelled by entry k of
# marker_labels_ns.
gene_markers_ns = [
    ['Slc32a1', 'Gad1', 'Gad2', 'Slc6a1'],
    ['Slc17a6', 'VGLUT2', 'Slc17a7', 'VGLUT1', 'Slc17a8', 'VGLUT3', 'Slc1a1', 'Slc1a2', 'Slc1a6'],
    ['Chat', 'Slc5a7', 'Slc18a3', 'Ache'],
    ['Tubb3', 'Snap25', 'Syt1'],
    ['Gfap', 'Olig2'],
    ['Sox2', 'Hes1', 'Hes5', 'Vim'],
]

marker_labels_ns = [
    'NS GABAergic', 'NS Glutamatergic', 'NS Cholinergic',
    'NS Neurons', 'NS Glia', 'NS Progenitors',
]

vae_vis.plot_values_on_scatters(
    df_input, gene_name, marker_labels_ns, gene_markers_ns,
    output_dir=fig_dir, plt_bg=False, vae_data=vae_data,
    title=f'Markers NS {experiment_name}', fig_type="pdf",
    show_plt=True, save_fig=True)

# Second marker collection: regional markers, cell-cycle genes, cell-state
# markers and five single genes highlighted on their own.
gene_markers_sep = [
    ['Emx1', 'Eomes', 'Tbr1', 'Foxg1', 'Lhx6'],
    ['En1', 'En2', 'Lmx1a', 'Bhlhe23', 'Sall4'],
    ['Hoxb1', 'Krox20', 'Fev', 'Hoxd3', 'Phox2b'],
    ['Hoxd8', 'Hoxd9', 'Hoxd10', 'Hoxd11', 'Hoxd12', 'Hoxd13', 'Hoxa7', 'Hoxa9', 'Hoxa10', 'Hoxa11',
     'Hoxa13', 'Hoxb9', 'Hoxb13', 'Hoxc8', 'Hoxc9', 'Hoxc10', 'Hoxc11', 'Hoxc12', 'Hoxc13'],
    ['Ccna1', 'Ccna2', 'Ccnd1', 'Ccnd2', 'Ccnd3', 'Ccne1', 'Ccne2', 'Cdc25a',
     'Cdc25b', 'Cdc25c', 'E2f1', 'E2f2', 'E2f3', 'Mcm10', 'Mcm5', 'Mcm3', 'Mcm2', 'Cip2a'],
    ['Cdkn1a', 'Cdkn1b', 'Cdkn1c', 'Cdkn2a', 'Cdkn2b', 'Cdkn2c', 'Cdkn2d'],
    ['Sox2', 'Sox1', 'Sox3', 'Hes1', 'Hes5'],
    ['Snap25', 'Syt1', 'Slc32a1', 'Slc17a6', 'Syn1'],
    ['Cspg4', 'Aqp4', 'Slc6a11', 'Olig1', 'Igfbp3'],
    ['Foxg1'],
    ['En2'],
    ['Phox2b'],
    ['Hoxc9'],
    ['Sox3'],
]

marker_labels_sep = [
    'forebrain', 'midbrain', 'hindbrain', 'spinalcord',
    'Proliferation', 'Negative regulators of Cell Cycle',
    'progenitors', 'neurons', 'glia',
    'Foxg1', 'En2', 'Phox2b', 'Hoxc9', 'Sox3',
]

# Tissue labels get their tissue colour; every other label takes the palette
# entry at its own position offset by four.
color_map = {}
i = 4
for c in marker_labels_sep:
    if 'spinal' in c:
        color_map[c] = sc_colour
    elif 'brain' in c:
        color_map[c] = get_tissue_colour(c.lower())
    else:
        color_map[c] = sci_colour[i]
    i += 1

vae_vis.plot_values_on_scatters(
    df_input, gene_name, marker_labels_sep, gene_markers_sep,
    output_dir=fig_dir, plt_bg=False, vae_data=vae_data, fig_type="svg",
    title=f'Markers significant {experiment_name}',
    show_plt=False, save_fig=False, color_map=color_map, angle_plot=15)

def _plot_decoding_heatmap(col_ids, cmap, label):
    """Call plot_gene_heatmap with method='decoding' for Sox3 / Foxg1 / En2 / Hoxc9.

    *col_ids* are positions within df_input.columns, *label* is the leading
    part of the figure title.
    """
    plot_gene_heatmap(df_input, ['Sox3'], ['Foxg1'], ['En2'], ['Hoxc9'],
                      df_input.columns, all_vae_data, vae_mse,
                      gene_name, mark='H3K27me3', merge_reps=False, cmap=cmap,
                      input_col_ids=col_ids, method='decoding',
                      title=f'{label} DEC {experiment_name}')


def _is_brain_h3k27me3_signal_col(col):
    """True for non-median H3K27me3 signal columns whose name contains 'brain'."""
    return 'H3K27me3' in col and 'signal' in col and 'median' not in col and 'brain' in col


# NOTE(review): col_id is assigned here but not read anywhere in the visible
# code; the assignments are kept in their original positions in case
# plot_gene_heatmap reads it as a global - confirm before removing.
col_id = len(wt_cols)
ko_cols_i = [pos for pos, col in enumerate(df_input.columns)
             if 'ko' in col and 'merged' not in col]
_plot_decoding_heatmap(ko_cols_i, rna_cmap, 'RNA KO')

col_id = 0
wt_cols_i = [pos for pos, col in enumerate(df_input.columns)
             if 'wt' in col and 'merged' not in col]
_plot_decoding_heatmap(wt_cols_i, rna_cmap, 'RNA WT')

# Names and positions of the log2 fold-change columns
log2_cols = [col for col in df_input.columns if 'log2Fold' in col]
log2_cols_i = [pos for pos, col in enumerate(df_input.columns) if 'log2Fold' in col]
_plot_decoding_heatmap(log2_cols_i, div_cmap, 'log2FC')

# Names and positions of the brain H3K27me3 signal columns
h3k_cols = [col for col in df_input.columns if _is_brain_h3k27me3_signal_col(col)]
h3k_cols_i = [pos for pos, col in enumerate(df_input.columns)
              if _is_brain_h3k27me3_signal_col(col)]
_plot_decoding_heatmap(h3k_cols_i, hist_cmap, 'H3K WT')

def _brain_h3k27me3_signal_cols(frame):
    """Non-median H3K27me3 signal columns of *frame* whose name contains 'brain'."""
    return [col for col in frame.columns
            if 'H3K27me3' in col and 'signal' in col and 'median' not in col and 'brain' in col]


def _plot_input_heatmap(cols, cmap, label):
    """Call plot_gene_heatmap with method='input' for Sox3 / Foxg1 / En2 / Hoxc9 over *cols* of df."""
    plot_gene_heatmap(df, ['Sox3'], ['Foxg1'], ['En2'], ['Hoxc9'],
                      cols, all_vae_data, vae_mse,
                      gene_name, mark='H3K27me3', merge_reps=True, cmap=cmap,
                      method='input', title=f'{label} {experiment_name}')


# Plot all the raw H3K values on the scatter plot (missing values filled with 0)
h3k_cols = _brain_h3k27me3_signal_cols(df)
tmp_df = df.fillna(0)
feature_obs_ax = vae_vis.plot_feature_scatters(
    tmp_df, "", columns=h3k_cols, vae_data=vae_data,
    show_plt=False, output_dir=fig_dir, fig_type="svg",
    save_fig=False, title=f'{experiment_name}',
    angle_plot=315, cmap=hist_cmap, vmin=0, vmax=10,
)

# Plot genes on heatmaps of the input values: RNA KO, RNA WT, H3K27me3, log2FC
_plot_input_heatmap(ko_cols, rna_cmap, 'RNA KO')
_plot_input_heatmap(wt_cols, rna_cmap, 'RNA WT')

h3k_cols = _brain_h3k27me3_signal_cols(df)
print(h3k_cols)
_plot_input_heatmap(h3k_cols, hist_cmap, 'H3K WT')

log2_cols = [col for col in df.columns if 'log2Fold' in col]
print(log2_cols)
_plot_input_heatmap(log2_cols, div_cmap, 'log2FC')

def _merged_cols_of(tag):
    """Columns of df_all whose name contains both *tag* and 'merged'."""
    return [col for col in df_all if tag in col and 'merged' in col]


# Keyword arguments shared by the two Sox heatmaps
sox_heatmap_kwargs = dict(mark='H3K27me3', merge_reps=True, cmap=rna_cmap, method='input')

# From here on ko_cols / wt_cols refer to the 'merged' columns of df_all
ko_cols = _merged_cols_of('ko')
plot_gene_heatmap(df_all, ['Sox1'], ['Sox2'], ['Sox3'], [],
                  ko_cols, all_vae_data, vae_mse, gene_name,
                  title=f'Sox RNA KO {experiment_name}', **sox_heatmap_kwargs)

wt_cols = _merged_cols_of('wt')
plot_gene_heatmap(df_all, ['Sox1'], ['Sox2'], ['Sox3'], [],
                  wt_cols, all_vae_data, vae_mse, gene_name,
                  title=f'Sox RNA WT {experiment_name}', **sox_heatmap_kwargs)

# Heatmap of the three latent columns for the four single marker genes
latent_cols = ['VAE0', 'VAE1', 'VAE2']
plot_gene_heatmap(df, ['Sox3'], ['Foxg1'], ['En2'], ['Hoxc9'],
                  latent_cols, all_vae_data, vae_mse, gene_name,
                  mark='H3K27me3', merge_reps=False, cmap=bin_cmap,
                  method='input', title=f'Latent {experiment_name}')

# Plot feature correlation
vae_vis.plot_feature_correlation(
    df_observed, '', columns=observed_cols, show_plt=True, cmap="RdBu_r",
    title=f'Feature vs observed features {experiment_name}',
    output_dir=fig_dir, save_fig=True,
)

# Plot feature correlations for heatmap
vae_vis.plot_feature_correlation(
    df, '', columns=vae_cols, show_plt=True, cmap='RdBu_r',
    output_dir=fig_dir, save_fig=True,
    title=f'features VAE {experiment_name}',
)


"""
Plot the heatmaps of the Hox and other tissue-specific genes
"""
# Pax2 / Hoxb9 / Hoxc9 / Hoxd9 across the WT and KO columns of df_all.
# NOTE(review): the WT title starts with 'Sox' while the KO title does not;
# this looks like a copy-paste leftover but is kept because the title may be
# used to name the saved figure - confirm before renaming.
plot_gene_heatmap(df_all, ['Pax2'], ['Hoxb9'], ['Hoxc9'], ['Hoxd9'],
                  wt_cols, all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', merge_reps=False,  cmap=rna_cmap,
                  method='input', title=f'Sox pax2 RNA WT {experiment_name}')

plot_gene_heatmap(df_all, ['Pax2'], ['Hoxb9'], ['Hoxc9'], ['Hoxd9'],
                  ko_cols, all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', merge_reps=False,  cmap=rna_cmap,
                  method='input', title=f'pax2 RNA KO {experiment_name}')

# Hindbrain / midbrain gene lists. Neither list is used by the plotting calls
# in this cell.
hb_genes = ['Phox2b', 'Krox20', 'Fev', 'Hoxb1',  'Hoxd3']
# The earlier five-gene mb_genes assignment was overwritten immediately and has
# been dropped. 'Pouf4f2' is corrected to 'Pou4f2', and repeated entries
# ('Gata3', 'Dmbx1', 'Pax7') are listed once, keeping first-occurrence order.
# NOTE(review): 'Patx3' is not a symbol I can confirm (Pitx3 or Pax3?) - left as is.
mb_genes = ['En1', 'En2', 'Lmx1a', 'Bhlhe23', 'Sall4', 'Pax7', 'Lhx9', 'Evx1', 'Sox14',
            'Gata3', 'Dmbx1', 'Tal1', 'Pou4f3', 'Pou4f2', 'Barhl1', 'Pax5',
            'Pou4f1', 'Irx4', 'Pax1', 'Pax8', 'Ebf2', 'Patx3', 'Onecut1', 'Otx2',
            'Lhx5', 'Irx3', 'Irx5', 'Ebf3', 'Foxb1', 'Irx2', 'Shox2', 'Foxa2', 'Irx1', 'Tcfap2b']

# Otx2 / En2 / Phox2b / Hoxd3 across the WT and KO columns of df_all
plot_gene_heatmap(df_all, ['Otx2'], ['En2'], ['Phox2b'], ['Hoxd3'],
                  wt_cols, all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', merge_reps=False,  cmap=rna_cmap,
                  method='input', title=f'Mid RNA WT {experiment_name}')


plot_gene_heatmap(df_all, ['Otx2'], ['En2'], ['Phox2b'], ['Hoxd3'],
                  ko_cols, all_vae_data, vae_mse,
                  gene_name, mark='H3K27me3', merge_reps=False,  cmap=rna_cmap,
                  method='input', title=f'Mid RNA KO {experiment_name}')


# Plot individual genes on the scatter: each gene is its own single-member set
marker_labels_sep = [
    'Cdkn1a', 'Cdkn2a', 'Cdkn2b', 'Cip2a',
    'Cdc25c', 'Mcm10',
    'E2f1', 'Ccna2', 'Ccnd1',
    'Aqp4', 'Snap25', 'Hes5',
]
gene_markers_sep = [[gene] for gene in marker_labels_sep]

# None of these labels contains 'brain' or 'spinal', so every label takes the
# sci_colour entry at (position in the label list + 4).
color_map = {
    label: sci_colour[palette_pos]
    for palette_pos, label in enumerate(marker_labels_sep, start=4)
}

vae_vis.plot_values_on_scatters(
    df_input, gene_name, marker_labels_sep, gene_markers_sep,
    output_dir=fig_dir, plt_bg=False, vae_data=vae_data, fig_type="svg",
    title=f'Markers specific {experiment_name}',
    show_plt=False, save_fig=False, color_map=color_map, angle_plot=315,
)

In [None]:
""" Print out the gene lists. """

# Label -> genes; insertion order fixes the order of the two parallel lists.
sep_marker_sets = {
    'forebrain': ['Emx1', 'Eomes', 'Tbr1', 'Foxg1', 'Lhx6'],
    'midbrain': ['En1', 'En2', 'Lmx1a', 'Bhlhe23', 'Sall4'],
    'hindbrain': ['Hoxb1', 'Krox20', 'Fev', 'Hoxd3', 'Phox2b'],
    'spinalcord': ['Hoxd8', 'Hoxd9', 'Hoxd10', 'Hoxd11', 'Hoxd12', 'Hoxd13', 'Hoxa7',
                   'Hoxa9', 'Hoxa10', 'Hoxa11', 'Hoxa13', 'Hoxb9', 'Hoxb13', 'Hoxc8',
                   'Hoxc9', 'Hoxc10', 'Hoxc11', 'Hoxc12', 'Hoxc13'],
    'Proliferation': ['Ccna1', 'Ccna2', 'Ccnd1', 'Ccnd2', 'Ccnd3', 'Ccne1', 'Ccne2',
                      'Cdc25a', 'Cdc25b', 'Cdc25c', 'E2f1', 'E2f2', 'E2f3', 'Mcm10',
                      'Mcm5', 'Mcm3', 'Mcm2', 'Cip2a'],
    'Negative regulators of Cell Cycle': ['Cdkn1a', 'Cdkn1b', 'Cdkn1c', 'Cdkn2a',
                                          'Cdkn2b', 'Cdkn2c', 'Cdkn2d'],
    'progenitors': ['Sox2', 'Sox1', 'Sox3', 'Hes1', 'Hes5'],
    'neurons': ['Snap25', 'Syt1', 'Slc32a1', 'Slc17a6', 'Syn1'],
    'glia': ['Cspg4', 'Aqp4', 'Slc6a11', 'Olig1', 'Igfbp3'],
    'Foxg1': ['Foxg1'],
    'En2': ['En2'],
    'Phox2b': ['Phox2b'],
    'Hoxc9': ['Hoxc9'],
    'Sox3': ['Sox3'],
}
gene_markers_sep = list(sep_marker_sets.values())
marker_labels_sep = list(sep_marker_sets)


def _genes_in_dataset(candidates):
    """Names from df[gene_name], in dataframe order, that appear in *candidates*."""
    return [g for g in df[gene_name].values if g in candidates]


# Print genes that were actually in our dataset, one comma-separated line per
# set: spinal cord (Hox), progenitors, neurons, glia.
for set_label in ('spinalcord', 'progenitors', 'neurons', 'glia'):
    cx = _genes_in_dataset(sep_marker_sets[set_label])
    print(', '.join(cx))

# Same for the proliferation and anti-proliferation (cell-cycle inhibitor) sets
prolif = list(sep_marker_sets['Proliferation'])
anti_prolif = list(sep_marker_sets['Negative regulators of Cell Cycle'])
pl = _genes_in_dataset(prolif)
apl = _genes_in_dataset(anti_prolif)
print(', '.join(pl))
print(', '.join(apl))


# Reset the vae data
vae_data = all_vae_data

## 5) Quantify separability of latent space

In the next section we compare the distances in the latent space to those using other summaries of the data. 

The distances are computed as the Euclidean distance to the centroid, where the centroid is either the center of the cluster of genes of interest OR the center of the two groups. We expect the reference to have a greater spread, and our genes to be closer to the centroid of the gene set they belong to.

Datasets tested:
1) VAE deep (3 layer)
2) VAE shallow (1 layer)
3) PCA (3 PCs)
4) RNAseq data (normalised TMM 64 features)
5) VAE input data (normalised TMM and normalised H3K27me3 and LogFC, 100 features)
6) logFC forebrain and H3K27me3 signal.



## 6) Save ranks for GSEA

We perform GSEA on the ranks of the VAE.

In [None]:

# -----------------------------------------------------------------------------------
#                         Save genes sorted by latent node (GSEA)
# -----------------------------------------------------------------------------------
# Two passes over the latent nodes: first the files keyed by gene id
# ('genes'), then the files keyed by gene name ('gene-names'). Each file has a
# header row followed by one "<gene>,<latent value>" row per gene, ordered by
# descending latent value.
for key_col, file_tag in ((gene_id, 'genes'), (gene_name, 'gene-names')):
    for n in range(0, num_nodes):
        with open(f'{ora_dir}vae{n}_{file_tag}_{experiment_name}_{date}.csv', 'w+') as f:
            f.write(key_col + ',value\n')
            node_values = vae_data[:, n]
            desc_sorted = (-node_values).argsort()  # Sort the genes by descending order
            genes = df[key_col].values[desc_sorted]
            vae_data_sorted = node_values[desc_sorted]
            f.writelines(f'{g},{value}\n' for g, value in zip(genes, vae_data_sorted))

## 7) Compute clusters driving VAE 

In the next section, we use the values of the VAE latent nodes to group genes into clusters in the latent space.

In [None]:
def get_q1_q3(node_idx, cutoff=1.25):
    """Return the row indices in the low and high tail of one latent node.

    Reads the module-level ``vae_data`` array and thresholds column
    *node_idx* at a fixed value (not at quartiles, despite the name).

    Args:
        node_idx: column of ``vae_data`` to threshold.
        cutoff: rows with a value <= -cutoff form the low tail and rows with a
            value >= cutoff form the high tail. Defaults to 1.25, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(low_idxs, high_idxs)`` of integer index arrays as returned by
        ``np.where``, i.e. in ascending row order.
    """
    node_values = vae_data[:, node_idx]
    q1_genes_idxs = np.where(node_values <= -cutoff)[0]
    q3_genes_idxs = np.where(node_values >= cutoff)[0]
    return q1_genes_idxs, q3_genes_idxs

# -----------------------------------------------------------------------------------
#                         Find the genes driving the latent space
# -----------------------------------------------------------------------------------

n_clusters = 6
# Get the indexes of the low and high tail for each of the three latent nodes
q1_n0_idxs, q3_n0_idxs = get_q1_q3(0)
q1_n1_idxs, q3_n1_idxs = get_q1_q3(1)
q1_n2_idxs, q3_n2_idxs = get_q1_q3(2)

# Create a list of the indexes: two sets per node, low tail then high tail
vae_set_idxs = [
    q1_n0_idxs, q3_n0_idxs,
    q1_n1_idxs, q3_n1_idxs,
    q1_n2_idxs, q3_n2_idxs
]

vae_set_labels = ['Set 1', 'Set 2', 'Set 3', 'Set 4', 'Set 5', 'Set 6']

# Get all the driver and non-driver idxs. driver_idxs is the concatenation of
# all sets, so a row that falls in several sets appears several times.
driver_idxs = []
for i, ids in enumerate(vae_set_idxs):
    vae_mse.u.dp([vae_set_labels[i], len(ids)])
    driver_idxs += list(ids)

# Membership is tested against a set so the scan over all rows stays linear
# instead of searching the driver list once per row.
driver_idx_lookup = {int(idx) for idx in driver_idxs}
non_driver_idxs = [row for row in range(0, len(df)) if row not in driver_idx_lookup]
# -----------------------------------------------------------------------------------
#                         Plot a VENN of how the genes in the clusters overlap
# -----------------------------------------------------------------------------------

# Gene names of each set, in dataframe row order. Indexing the name array with
# the (sorted) row indices replaces the former scan that compared every row
# index against every member of every set.
all_gene_names = df[gene_name].values
gene_sets = [list(all_gene_names[np.sort(set_idxs)]) for set_idxs in vae_set_idxs]

# NOTE(review): sets with fewer than two genes are dropped here, while venn6
# below is given six names - presumably all six sets are expected to be
# non-trivial; confirm what should happen otherwise.
set_sets = [set(g) for g in gene_sets if len(g) > 1]

set_labels = venn.get_labels(set_sets)
vae_set_labels_l = [s.replace('Set', 'Group') for s in vae_set_labels]
# Per-label lookups: member gene names, and number of members
vae_set_genes_dict = dict(zip(vae_set_labels, gene_sets))
vae_set_genes_len_dict = {l: len(genes) for l, genes in vae_set_genes_dict.items()}

venn.venn6(set_labels, names=vae_set_labels_l)
save_fig('gene_sets.svg')
plt.show()


# Gene lists in label order, as expected by the scatter plotting call
vae_set_genes_as_list = [vae_set_genes_dict.get(set_name) for set_name in vae_set_labels]

# Labels ordered from the smallest to the largest set (ties keep label order)
sorted_sets_by_size = {k: v for k, v in sorted(vae_set_genes_len_dict.items(), key=lambda item: item[1])}

# Matching is by gene NAME, as before; a set per label makes each lookup
# constant time instead of scanning the member list for every gene.
all_gene_names = df[gene_name].values
members_by_label = {label: set(genes) for label, genes in vae_set_genes_dict.items()}

# Single label per gene: the smallest set that contains the gene, or None when
# the gene is in no set.
gene_set_labels_unique = [
    next((label for label in sorted_sets_by_size if g in members_by_label[label]), None)
    for g in all_gene_names
]
df['GeneDriverSet'] =  gene_set_labels_unique

# We also want to add each of the gene sets: one boolean membership column per label
gene_labels = [
    [g in members_by_label[label] for g in all_gene_names]
    for label in vae_set_labels
]
for label, membership in zip(vae_set_labels, gene_labels):
    df[label] = membership
    
            
# -----------------------------------------------------------------------------------
#                         Plot the gene driver sets on a scatter plot
# -----------------------------------------------------------------------------------
vae_vis.plot_values_on_scatters(
    df, gene_name, vae_set_labels, vae_set_genes_as_list,
    vae_data=vae_data, output_dir=fig_dir, plt_bg=True,
    title=f'VAE gene driver sets {experiment_name}',
    show_plt=True, save_fig=True,
)

# -----------------------------------------------------------------------------------
#                         Plot histograms of the VAE nodes with reference lines
# -----------------------------------------------------------------------------------
# NOTE(review): the dashed lines sit at mean -/+ 1.0 * IQR of each node, whereas
# get_q1_q3 selects the sets with a fixed cutoff - confirm which thresholds the
# figure is meant to show.
for n in range(num_nodes):
    x = vae_data[:, n]
    iqr_val = 1.0 * stats.iqr(x)
    centre = x.mean()
    plt.hist(x, bins=20, color='grey')
    plt.ylim(0, 380)
    # Black line below the mean, red line above it
    for offset, line_colour in ((-iqr_val, 'k'), (iqr_val, 'r')):
        plt.axvline(centre + offset, color=line_colour, linestyle='dashed', linewidth=1)
    save_fig(f'hist-{n}')
    plt.show()


In [None]:
hox_genes = sc_genes
forebrain_genes = fb_genes


def _is_h3k27me_signal(col):
    """True when the column name contains both 'H3K27me' and 'signal'."""
    return 'H3K27me' in col and 'signal' in col


# Histone signal columns, in dataframe order
h3cols = [col for col in df.columns if _is_h3k27me_signal(col)]

# Candidate RNA columns: 'wt' columns that are not histone signal columns come
# first, then every 'ko' column, so the lists below are ordered WT before KO.
_wt_then_ko = (
    [col for col in df.columns if not _is_h3k27me_signal(col) and 'wt' in col]
    + [col for col in df.columns if 'ko' in col]
)
fb_rna_cols = [col for col in _wt_then_ko if 'fb' in col]
sc_rna_cols = [col for col in _wt_then_ko if 'sc' in col]

def plot_gene_boxplot(df, title, cluster_id, cols, idxs):
    """Boxplot of the values in *cols* for the genes of one cluster.

    Args:
        df: dataframe with a gene-name column (``gene_name``) and the columns in *cols*.
        title: figure title; also passed to ``save_fig`` as the figure name.
        cluster_id: position of the cluster within *idxs*.
        cols: columns to plot, in display order. Boxes in the first half are
            coloured ``wt_colour`` and the rest ``ko_colour``.
        idxs: sequence of row-index collections, one per cluster.
    """
    cluster_rows = idxs[cluster_id]
    # The first Boxplot instance is only used to reshape the data
    formatter = Boxplot(df, gene_name, cols[0])
    cluster_genes = df[gene_name].values[cluster_rows]
    box_df = formatter.format_data_for_boxplot(df, cols, gene_name, cluster_genes)
    sns.set_style("white")
    # Flag samples whose name contains 'Hox'
    box_df['is_hox'] = ['Hox' if 'Hox' in sample else 'Hox like'
                        for sample in box_df['Samples'].values]

    boxplot = Boxplot(box_df, "Conditions", "Values", add_stats=False, add_dots=False,
                      figsize=(1, 0.5), order=cols)
    boxplot.palette = sci_colour

    ax = boxplot.plot()
    plt.title(title)
    ax.set_ylim(0, 12)
    sns.set_style("white")
    # Recolour the boxes: first half of the columns vs. the remainder
    half = len(cols) / 2
    for box_pos, box in enumerate(ax.artists):
        box.set_facecolor(wt_colour if box_pos < half else ko_colour)
    save_fig(f'{title}')

    plt.show()

def _replicate_pairs(tissue):
    """Sample-name pairs [replicate 1, replicate 2] for one tissue code.

    Order: the four 'wt' stages (11, 13, 15, 18) followed by the four 'ko' stages.
    """
    return [[f'{condition}{stage}{tissue}1', f'{condition}{stage}{tissue}2']
            for condition in ('wt', 'ko')
            for stage in ('11', '13', '15', '18')]


avgs_fb = _replicate_pairs('fb')
avgs_sc = _replicate_pairs('sc')
avgs_hb = _replicate_pairs('hb')
avgs_mb = _replicate_pairs('mb')

avgs_df = pd.DataFrame()
avgs_df[gene_name] = df[gene_name].values


def _add_replicate_means(pairs, tissue_tag):
    """Add one mean-of-two-replicates column per pair to avgs_df; return the new column names.

    Column names are '<condition> <stage> <tissue_tag>', e.g. 'wt 11 FB'.
    """
    new_cols = []
    for first, second in pairs:
        new_col = f'{first[:2]} {first[2:4]} {tissue_tag}'
        avgs_df[new_col] = (df[first].values + df[second].values) / 2.0
        new_cols.append(new_col)
    return new_cols


# Column order in avgs_df: FB, SC, MB, HB
fb_cols = _add_replicate_means(avgs_fb, 'FB')
sc_cols = _add_replicate_means(avgs_sc, 'SC')
mb_cols = _add_replicate_means(avgs_mb, 'MB')
hb_cols = _add_replicate_means(avgs_hb, 'HB')

# One boxplot per cluster and tissue, plotted in the order FB, MB, HB, SC
for i, gs in enumerate(vae_set_idxs):
    for tissue_tag, tissue_cols in (('FB', fb_cols), ('MB', mb_cols),
                                    ('HB', hb_cols), ('SC', sc_cols)):
        plot_gene_boxplot(avgs_df, f'Cluster {i + 1} {tissue_tag}', i, tissue_cols, vae_set_idxs)

## Plot the heatmaps

In [None]:
def _draw_heatmap_strip(ax, values, x_pos, n_rows, vmin, vmax, cmap):
    """Draw *values* as a one-unit-wide image strip centred on *x_pos*; return the image."""
    return ax.imshow(np.vstack([values, values]).T,
                     aspect='auto',
                     extent=[x_pos - 0.5, x_pos + 0.5, -0.5, n_rows - 0.5],
                     vmin=vmin, vmax=vmax,
                     origin='lower', cmap=cmap)


def plot_vae_rna_cols(vis_df, cols, num_nodes, fb_logfc, mb_logfc, hb_logfc, sc_logfc,
                      h3k27me3_fb_e11, h3k27me3_fb_e18, h3k27me3_hb_e11, h3k27me3_hb_e18, title='Heatmap',
                      num_values=10):
    """Heatmaps of logFC and H3K27me3 values for the extreme genes of each column in *cols*.

    For every column in *cols* two figures are drawn and passed to ``save_fig``:
    one for the *num_values* rows with the largest values ('up') and one for
    the rows with the smallest values ('down'). Each figure has eight strips:
    the four logFC columns followed by the four H3K27me3 columns.

    Args:
        vis_df: dataframe holding *cols*, the eight value columns and the
            ``gene_name`` column used for the y tick labels.
        cols: columns to rank the rows by.
        num_nodes: unused; kept so existing calls keep working.
        fb_logfc, mb_logfc, hb_logfc, sc_logfc: names of the logFC columns.
        h3k27me3_fb_e11, h3k27me3_fb_e18, h3k27me3_hb_e11, h3k27me3_hb_e18:
            names of the H3K27me3 columns.
        title: prefix of the figure title and of the saved figure name.
        num_values: number of rows shown per figure.
    """
    # Colour scale limits and colour maps of the two strip groups
    max_hk3 = 35
    max_rna = 10
    min_rna = -10
    cmap_rna = 'seismic'
    cmap_hk3 = 'Greens'
    rna_columns = [fb_logfc, mb_logfc, hb_logfc, sc_logfc]
    h3k_columns = [h3k27me3_fb_e11, h3k27me3_fb_e18, h3k27me3_hb_e11, h3k27me3_hb_e18]

    sns.set_style("ticks")
    for r_c in cols:
        for direction in ['up', 'down']:
            # Apply the figure settings BEFORE creating the figure: rc values
            # such as the figure size only affect figures created afterwards,
            # so setting them after plt.subplots() left the first figure with
            # whatever size was active before this function was called.
            sns.set(rc={'figure.figsize': (3.5, 2), 'font.family': 'sans-serif',
                        'font.sans-serif': 'Arial', 'font.size': 6}, style="ticks")
            fig, ax = plt.subplots()
            rcm_df_sorted = vis_df.sort_values(by=[r_c])
            if direction == 'up':
                rcm_df_tails = rcm_df_sorted.nlargest(num_values, r_c)
            else:
                rcm_df_tails = rcm_df_sorted.nsmallest(num_values, r_c)

            N = len(rcm_df_tails)
            ax.set_xlim(-0.5, 7.5)
            ax.set_xticks([0, 1, 2, 3, 4, 5, 6, 7])
            ax.set_xticklabels(['FB', 'MB', 'HB', 'SC',
                                'FB e11.5', 'FB e18.5', 'HB e11.5', 'HB e18'])
            ax.set_yticks(range(N))
            ax.set_yticklabels(list(rcm_df_tails[gene_name].values))
            ax.set_ylabel('Genes')

            # logFC strips at x = 0..3; the last image feeds the colour bar
            for x_pos, column in enumerate(rna_columns):
                im1 = _draw_heatmap_strip(ax, rcm_df_tails[column], x_pos, N,
                                          min_rna, max_rna, cmap_rna)
            # H3K27me3 strips at x = 4..7
            for x_pos, column in enumerate(h3k_columns, start=len(rna_columns)):
                im3 = _draw_heatmap_strip(ax, rcm_df_tails[column], x_pos, N,
                                          0, max_hk3, cmap_hk3)

            cbar1 = fig.colorbar(im1, ax=ax, label='RNA logFC')
            cbar2 = fig.colorbar(im3, ax=ax, label='H3K27me3 signal')

            fig.tight_layout()
            plt.xticks(rotation=45, horizontalalignment='right')
            plt.title(f'{title} {r_c}')
            save_fig(f'{title.replace(" ", "_")}_{direction}_{r_c}')

In [None]:
def _first_h3k27me3_signal_col(stage, tissue):
    """First df column whose name contains *stage*, *tissue*, 'H3K27me3' and 'signal' but not 'median'.

    Raises IndexError when no column matches.
    """
    matches = [c for c in df.columns
               if stage in c and tissue in c and 'H3K27me3' in c and 'signal' in c
               and 'median' not in c]
    return matches[0]


# NOTE(review): the variables are named e11 / e18 but select columns tagged
# '10.5' / '16.5' - presumably the closest available time points; confirm.
h3k27me3_fb_e11 = _first_h3k27me3_signal_col('10.5', 'forebrain')
h3k27me3_fb_e18 = _first_h3k27me3_signal_col('16.5', 'forebrain')

h3k27me3_hb_e11 = _first_h3k27me3_signal_col('10.5', 'hindbrain')
h3k27me3_hb_e18 = _first_h3k27me3_signal_col('16.5', 'hindbrain')

plot_vae_rna_cols(
    df, ["VAE0", "VAE1", "VAE2"], 3,
    'Log2FC FB', 'Log2FC MB', 'Log2FC HB', 'Log2FC SC',
    h3k27me3_fb_e11, h3k27me3_fb_e18, h3k27me3_hb_e11, h3k27me3_hb_e18,
)

In [None]:
# -----------------------------------------------------------------------------------
#                         Save background genes (ORA)
# -----------------------------------------------------------------------------------
# NOTE(review): n_clusters is kept as num_nodes for any later code that reads
# it, but it is no longer used to decide how many cluster files are written.
n_clusters = num_nodes
with open(f'{ora_dir}background_genes_{experiment_name}_{date}.csv', 'w+') as f:
    genes = df_all[gene_id].values
    f.write(gene_id + '\n')
    f.writelines(f'{g}\n' for g in genes)

# -----------------------------------------------------------------------------------
#                         Save training genes (sig genes) (ORA)
# -----------------------------------------------------------------------------------
with open(f'{ora_dir}training-df_genes_{experiment_name}_{date}.csv', 'w+') as f:
    genes = df_training[gene_id].values
    f.write(gene_id + '\n')
    f.writelines(f'{g}\n' for g in genes)

# -----------------------------------------------------------------------------------
#                         Save genes in each of the clusters we annotated (ORA)
# -----------------------------------------------------------------------------------
# vae_set_idxs holds two gene sets per latent node, so looping over
# range(num_nodes) wrote only the first num_nodes of them. Iterate over the
# sets themselves so every annotated cluster gets its file.
for i, cluster_idxs in enumerate(vae_set_idxs):
    with open(f'{ora_dir}cluster_{i}_genes_{experiment_name}_{date}.csv', 'w+') as f:
        genes = df[gene_id].values[cluster_idxs]
        f.write(gene_id + '\n')
        f.writelines(f'{g}\n' for g in genes)
    
# -----------------------------------------------------------------------------------
#                         Save genes sorted by latent node (GSEA) and only the top
# -----------------------------------------------------------------------------------
def plot_top_vae_genes(vae_data_sorted, gene_ids_sorted, gene_names_sorted, label="", num=20):
    """Write the leading genes of a pre-sorted ranking to the open file and plot them.

    NOTE: this function relies on two module-level names set by the calling
    loop: ``f`` (the open output file the rows are written to) and ``n`` (the
    node index used in the saved figure name).

    Args:
        vae_data_sorted: latent values, already sorted by the caller.
        gene_ids_sorted: gene ids in the same order.
        gene_names_sorted: gene names in the same order.
        label: suffix of the saved figure name.
        num: number of leading genes to write and plot.
    """
    gs = []
    for row, name in enumerate(gene_names_sorted):
        # A name of "0" is replaced by the gene id (presumably a placeholder
        # for a missing name - TODO confirm)
        shown = name if name != "0" else gene_ids_sorted[row]
        f.write(f'{shown},{vae_data_sorted[row]}\n')
        gs.append(shown)
        if row + 1 >= num:
            break

    # Single-column heatmap of the written values
    heatmap_df = pd.DataFrame()
    heatmap_df[gene_name] = gs
    heatmap_df['values'] = vae_data_sorted[:num]
    heatmap = Heatmap(heatmap_df, ['values'], gene_name, vmin=-4, vmax=4,
                      title=f'top {num} genes', cluster_cols=False, cluster_rows=False)
    heatmap.palette = sns.color_palette(sci_colour)
    heatmap.plot()
    pplot()
    save_fig(f'vae{n}_genes-top-{num}-{label}')
    plt.show()
        
# For every latent node write two files (highest-ranked and lowest-ranked genes)
# and plot them. plot_top_vae_genes reads the loop variable `n` and the open
# file `f` as globals, so this loop has to stay at module level with exactly
# these names.
# NOTE(review): the file names say 'top10' / 'bottom10' while
# plot_top_vae_genes is called with its default num=20.
for n in range(0, num_nodes):
    with open(f'{ora_dir}vae-{n + 1}_genes-top10_{experiment_name}_{date}.csv', 'w+') as f:
        f.write(gene_id + ',value\n')
        desc_sorted = (-vae_data[:,n]).argsort()  # Sort the genes by descending order
        gene_names_sorted = df[gene_name].values[desc_sorted]
        gene_ids_sorted = df[gene_id].values[desc_sorted]
        vae_data_sorted = vae_data[:,n][desc_sorted]
        plot_top_vae_genes(vae_data_sorted, gene_ids_sorted, gene_names_sorted, label="top")

    with open(f'{ora_dir}vae-{n + 1}_genes-bottom10_{experiment_name}_{date}.csv', 'w+') as f:
        f.write(gene_id + ',value\n')
        desc_sorted = (vae_data[:,n]).argsort()  # Sort the genes by ascending order (name kept from the block above)
        gene_names_sorted = df[gene_name].values[desc_sorted]
        vae_data_sorted = vae_data[:,n][desc_sorted]
        gene_ids_sorted = df[gene_id].values[desc_sorted]
        plot_top_vae_genes(vae_data_sorted, gene_ids_sorted, gene_names_sorted, label="bottom")

# Full ranking per node: every gene id with its latent value
for n in range(0, num_nodes):
    with open(f'{ora_dir}vae-{n + 1}_{experiment_name}_{date}.csv', 'w+') as f:
        f.write(gene_id + ',value\n')
        desc_sorted = (vae_data[:,n]).argsort()  # Sort the genes by ascending order
        # Despite the name, this holds gene ids (the gene_id column)
        gene_names_sorted = df[gene_id].values[desc_sorted]
        vae_data_sorted = vae_data[:,n][desc_sorted]
        for i, g in enumerate(gene_names_sorted):
            f.write(f'{g},{vae_data_sorted[i]}\n')

In [None]:
# -----------------------------------------------------------------------------------
#                         Save genes in each of the clusters we annotated (ORA)
# -----------------------------------------------------------------------------------
# One file per gene set in vae_set_idxs: gene name plus the value of the three
# latent nodes. Looping over the sets themselves (rather than
# range(n_clusters)) writes every set, and the row loop no longer re-uses the
# cluster index `i` as its own counter.
for i, cluster_idxs in enumerate(vae_set_idxs):
    with open(f'{ora_dir}cluster_{i}_gene-names_{experiment_name}_{date}.csv', 'w+') as f:
        genes = df[gene_name].values[cluster_idxs]
        rank_vae0 = vae_data[:, 0][cluster_idxs]
        rank_vae1 = vae_data[:, 1][cluster_idxs]
        rank_vae2 = vae_data[:, 2][cluster_idxs]
        f.write('gene_name,VAE-1,VAE-2,VAE-3\n')
        for g, v0, v1, v2 in zip(genes, rank_vae0, rank_vae1, rank_vae2):
            f.write(f'{g},{v0},{v1},{v2}\n')
    

In [None]:
import string
# Tissue names; each is matched as a substring of the signal column names.
tissues = [
    'forebrain',
    'midbrain',
    'hindbrain',
    'neural-tube',
    'embryonic-facial-prominence',
    'limb',
    'heart',
    'liver',
]
# Histone marks; each is matched as a substring of the signal column names.
marks = [
    'H3K9me3',
    'H3K36me3',
    'H3K27me3',
    'H3K27ac',
    'H3K4me1',
    'H3K4me2',
    'H3K4me3',
    'H3K9ac',
]
def get_median_naninc(df, mark, hist_metric="", time_1="", time_2="", time_3="", tissue="brain"):
    """Return the per-row NaN-ignoring median over the matching signal columns.

    A column is selected when its name contains ``tissue``, ``mark`` and
    ``hist_metric`` and at least one of ``time_1``/``time_2``/``time_3``.
    Matching is plain substring matching, so an empty string matches every
    column name.

    Returns:
        1-D numpy array with one median per row of ``df``.
    """
    time_points = (time_1, time_2, time_3)
    selected = [
        col for col in df.columns
        if tissue in col and mark in col and hist_metric in col
        and any(tp in col for tp in time_points)
    ]
    # nanmedian skips missing values within each row
    return np.nanmedian(df[selected].values, axis=1)

def plot_mark_heatmap(df, idxs, title, mark_cutoff=3.0):
    """Plot, per time window, the fraction of genes carrying each histone mark in each tissue.

    For every (mark, tissue) pair the per-gene median signal is taken from
    ``get_median_naninc`` and the share of genes whose median exceeds
    ``mark_cutoff`` is stored. One heatmap (tissues x marks) is drawn, printed
    and saved for each time window.

    Args:
        df: dataframe holding the signal columns.
        idxs: row positions of the genes to include; ``None`` uses every row.
        title: text used in the plot title and the saved figure name.
        mark_cutoff: median signal above which a gene counts as having the mark.
    """
    mark_df = pd.DataFrame()
    num_genes = len(df) if idxs is None else len(idxs) * 1.0
    # NOTE(review): '15.5' appears twice in the second window - confirm whether
    # '14.5' was intended for the first entry.
    time_windows = [['10.5', '11.5', '12.5'], ['15.5', '15.5', '16.5']]
    for t in time_windows:
        tissue_titles = []
        marks_title = []
        for m in marks:
            mark_label = string.capwords(m)
            mark_col = []
            for tissue in tissues:
                medians = get_median_naninc(df, m, "signal", t[0], t[1], t[2], tissue)
                if idxs is not None:
                    medians = medians[idxs]
                has_mark = float(np.count_nonzero(medians > mark_cutoff))
                mark_col.append(has_mark / num_genes)

                # Collect each tissue's display name once (first pass over the marks)
                pretty = string.capwords(tissue)
                if pretty in tissue_titles:
                    continue
                if tissue != 'embryonic-facial-prominence':
                    tissue_titles.append(pretty)
                elif 'E.F.P' not in tissue_titles:
                    tissue_titles.append('E.F.P')
            marks_title.append(mark_label)
            mark_df[mark_label] = mark_col

        mark_df['Tissue'] = tissue_titles

        heatmap = Heatmap(mark_df, marks_title, 'Tissue', vmin=0, vmax=None, cmap=hist_cmap, figsize=(2,2),
                          title=f'{title} {",".join(t)}', cluster_rows=False, cluster_cols=False)
        heatmap.plot()
        pplot()
        print(mark_df.head())
        save_fig(f'mark_all_tissues_signal-{title}_{"-".join(t)}')
        plt.show()
        
# Draw the mark heatmaps for every gene cluster in turn
for cluster_num, cluster_idxs in enumerate(vae_set_idxs):
    plot_mark_heatmap(df, cluster_idxs, f'Cluster {cluster_num}')

## 7) Inspecting the data in context of chromHMM annotations

Downloaded the pooled samples from http://enhancer.sdsc.edu/enhancer_export/ENCODE/chromHMM/pooled/
for forebrain, hindbrain, and midbrain at e16.5 on 04 December 2020.


http://enhancer.sdsc.edu/enhancer_export/ENCODE/chromHMM/readme

The files in this directory contain chromHMM chromatin state calls for multiple tissues during mouse embryonic development, as described here: http://www.biorxiv.org/content/early/2017/07/21/166652.

All files are in bed format. There is one file for each tissue at each developmental stage.
The columns are 1) chromosome, 2) regions start, 3)  region end, 4) chromatin state number 5) Chromatin state label. 
There are 15 chromatin states, labeled as follows:


        State	Label	Description
        1	Pr-A	Promoter, Active
        2	Pr-W	Promoter, Weak/Inactive
        3	Pr-B	Promoter, Bivalent
        4	Pr-F	Promoter, Flanking
        5	En-Sd	Enhancer, Strong TSS-distal
        6	En-Sp	Enhancer, Strong TSS-proximal
        7	En-W	Enhancer, Weak TSS-distal
        8	En-Pd	Enhancer, Poised TSS-distal
        9	En-Pp	Enhancer, Poised TSS-proximal
        10	Tr-S	Transcription, Strong
        11	Tr-P	Transcription, Permissive
        12	Tr-I	Transcription, Initiation
        13	Hc-P	Heterochromatin, Polycomb-associated
        14	Hc-H	Heterochromatin, H3K9me3-associated
        15	NS	No Chromatin Signal
        NRS	NRS	No Reproducible State (only relevant to "replicated" call set)


There are four subdirectories:

rep1\
This subdirectory contains chromatin state calls made using ChIP-seq datasets from biological replicate 1 (two biological replicates were performed for each ChIP-seq experiment).

rep2\
This subdirectory contains chromatin state calls made using ChIP-seq datasets from biological replicate 2.

replicated\
This subdirectory contains the intersection of the replicate 1 and replicate 2 chromatin state calls. 
Here, we require that a region is called in the same state in both replicates. If a region is not called in the same state in both replicates, it is labeled "NRS" for No Reproducible Signal. Note this is different from state 15, which is "No Signal." 

pooled\
This subdirectory contains chromatin state calls on ChIP-seq data pooled from both replicates.
We expect this will be the desired set of files for most users.


Each subdirectory has two additional subdirectories:

archive\
This contains an older version of the bed files in which there are only state number, no labels. Aside from the labels, the files are identical to the working versions in the main directory.

tracks\
This contains bigBed versions of the files for visualization on the UCSC Genome Browser.


There is one additional subdirectory that does not conform to the templates described above. This is 6_mark_models\.
This subdirectory contains chromHMM segmentation of the e10.5 stage where we did not have all 8 marks. 
Note that the 11-state and 16-states models used for this segmentation are distinct from the 15-state (8-mark) model described above for the other stages. 
Please see manuscript and Extended Data Figure 7 for additional information.



        header = "chr,start,end,label,annotation"
        bed = Bed(os.path.join(self.data_dir, 'e16.5_forebrain_15_segments.bed'),
                  header=None, overlap_method='in_promoter',
                  output_bed_file=os.path.join(self.data_dir, 'e16.5_forebrain_15_segments_selectedPeaks.bed'),
                  buffer_after_tss=0, buffer_before_tss=10, buffer_gene_overlap=0,
                  gene_start=3, gene_end=4, gene_chr=2, gene_direction=5, gene_name=0,
                  chr_idx=0, start_idx=1, end_idx=2, peak_value=4, header_extra="3", sep='\t'
                  )
        # Add the gene annot
        bed.set_annotation_from_file(self.mm10_annot)
        # Now we can run the assign values
        bed.assign_locations_to_genes()
        bed.save_loc_to_csv(f'{self.data_dir}test_e16.5_forebrain_15_segments.csv')

In [None]:
# -----------------------------------------------------------------------------------
#                         Merge information with annotations
# -----------------------------------------------------------------------------------

chrom_dir = os.path.join(supp_dir, "annot")
chromhmm_annot = pd.read_csv(f'{chrom_dir}/e16.5_forebrain_15_segments.csv')
print(len(chromhmm_annot), len(df))
# Make sure we only keep 1 annotation per gene.
# NOTE: GroupBy.first() returns the first non-null value of each column within a
# gene, so a gene with several rows can end up with values from different rows.
chromhmm_annot = chromhmm_annot.groupby('external_gene_name').first()
# Merge this with our dataframe (left join keeps every gene already in df)
n_genes_before_merge = len(df)
df = df.merge(chromhmm_annot, on='external_gene_name', how='left', suffixes=('', '_chmm'))
# Ensure the new length is the same as the old length
if len(df) != n_genes_before_merge:
    raise ValueError(f'ChromHMM merge changed the number of rows: '
                     f'{n_genes_before_merge} -> {len(df)}')
len(df)

In [None]:
# We also want to plot the annotations from gorkin et al for each of the clusters
from statsmodels.stats.multitest import multipletests

def run_annot_plot(df_bg, df_fg, title=""):
    """Bar-plot the Fisher's exact test odds ratio of each ChromHMM state for a gene set.

    For each ChromHMM label the rows of ``df_fg`` and ``df_bg`` whose
    ``peak_value`` equals that label are counted and a 2x2 Fisher's exact test
    is run on [[fg with label, bg with label], [fg without, bg without]].
    P-values are Benjamini-Hochberg adjusted and shown as stars above the bars.
    The figure is saved via ``save_fig`` and displayed.

    Args:
        df_bg: background dataframe with a ``peak_value`` column.
        df_fg: foreground (gene set) dataframe with a ``peak_value`` column.
        title: text appended to the plot title and the saved figure name.

    NOTE(review): the callers in this notebook pass a subset of ``df_bg`` as
    ``df_fg``, so foreground genes are also counted in the background column of
    the contingency table - confirm this is the intended test design.
    """
    order = ['Pr-A', 'Pr-W', 'Pr-B', 'Pr-F', 'En-Sd', 'En-Sp',
             'En-W', 'En-Pd', 'En-Pp', 'Tr-S', 'Tr-P', 'Tr-I',
             'Hc-P', 'Hc-H', 'NS']
    total = len(df_bg) * 1.0
    total_sig = len(df_fg) * 1.0

    # Number of genes carrying each annotation in the background / foreground
    bg_counts = [len(df_bg[df_bg['peak_value'] == o]) for o in order]
    fg_counts = [len(df_fg[df_fg['peak_value'] == o]) for o in order]

    odds_ratios = []
    pvalues = []
    for o, n_fg, n_bg in zip(order, fg_counts, bg_counts):
        # Do a FET on each one
        table = [[n_fg, n_bg],
                 [total_sig - n_fg, total - n_bg]]
        oddsratio, pvalue = stats.fisher_exact(table)
        print(o, oddsratio, pvalue, table)
        odds_ratios.append(oddsratio)
        pvalues.append(pvalue)

    _, padj, _, _ = multipletests(pvalues, alpha=0.1,
                                  method='fdr_bh', returnsorted=False)

    # Translate adjusted p-values into significance stars
    p_sigs = []
    for p in padj:
        if p > 0.05:
            p_sigs.append('')
        elif p > 0.01:
            p_sigs.append('*')
        elif p > 0.001:
            p_sigs.append('**')
        elif p > 0.0001:
            p_sigs.append('***')
        elif p <= 0.0001:
            p_sigs.append('****')
        else:
            # Only reached when p is NaN (every comparison above is False)
            print(p)
            p_sigs.append('')

    fig, ax = plt.subplots(figsize=(1.5,1.0))
    # One colour per bar; previously the list was hard-coded to 14 entries for 15 bars.
    c = [grey] * len(order)
    # Highlight the bivalent promoter and Polycomb heterochromatin states
    c[order.index('Pr-B')] = "green"
    c[order.index('Hc-P')] = "green"
    plt.bar(order, odds_ratios, color=c, linewidth=0.5, edgecolor='black')

    # Put the significance stars on top of each bar
    for rect, label in zip(ax.patches, p_sigs):
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2, height, label,
                ha='center', va='bottom')
    plt.ylim(0, 7)
    plt.yticks(np.arange(0, 7, 2.0))
    ax.set_title(f'Odds ratio FET for significant dataset {title} compared to all genes ChromHMM annotations')
    ax.set_xlabel('', fontsize=6)
    ax.set_ylabel('Odds ratio', fontsize=6)
    ax.set_xticklabels(order, rotation=90, ha="center")

    ax.tick_params(direction='out', length=2, width=0.5)
    ax.spines['bottom'].set_linewidth(0.5)
    ax.spines['top'].set_linewidth(0)
    ax.spines['left'].set_linewidth(0.5)
    ax.spines['right'].set_linewidth(0)
    ax.tick_params(labelsize=6)
    ax.tick_params(axis='x', which='major', pad=0)
    ax.tick_params(axis='y', which='major', pad=0)

    save_fig(f'OR-CM-{title}')

    plt.show()

# Compare every annotated gene set against the full training dataframe
for set_name in vae_set_labels:
    in_set = df[set_name] == True
    run_annot_plot(df, df[in_set], title=f'Training vs {set_name}')

In [None]:
df_bg = df_all.copy()  # Make a copy so we don't override the other one

# Also have a look at how the marks plot over genes that are not significant
chrom_dir = os.path.join(supp_dir, "annot")
chromhmm_annot = pd.read_csv(f'{chrom_dir}/e16.5_forebrain_15_segments.csv')
print(len(chromhmm_annot), len(df_bg))
# Make sure we only keep 1 annotation per gene.
# NOTE: GroupBy.first() returns the first non-null value of each column within a gene.
chromhmm_annot = chromhmm_annot.groupby('external_gene_name').first()
# Merge this with our dataframe (left join keeps every gene already in df_bg)
n_bg_before_merge = len(df_bg)
df_bg = df_bg.merge(chromhmm_annot, on='external_gene_name', how='left', suffixes=('', '_chmm'))
# Ensure the new length is the same as the old length
if len(df_bg) != n_bg_before_merge:
    raise ValueError(f'ChromHMM merge changed the number of rows: '
                     f'{n_bg_before_merge} -> {len(df_bg)}')
run_annot_plot(df_bg, df, title='Training vs All')

## 8) Epi genes annotation

Using the information from: https://epifactors.autosome.ru/ (EpiFactors: a comprehensive database of human
epigenetic factors and complexes, by Yulia A. Medvedeva et al.) we test each cluster for enrichment of epigenetic complexes and functions.

In [None]:
# Annotate our genes with the EpiFactors table (matched on the mouse gene symbol)
epi_genes = pd.read_csv(os.path.join(supp_dir, 'EpiGenes_main.csv'))
df_epi = df.merge(epi_genes, left_on='external_gene_name', right_on='MGI_symbol',
                  how='left', suffixes=('', '_eg'))


def _plot_epi_countplot(data, column, fig_label):
    """Draw, save and show a count plot of ``column`` in ``data``.

    The figure is saved through ``save_fig`` as ``Epigenes-<fig_label>``.
    """
    ax = sns.countplot(x=column, data=data)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=40, ha="right")
    plt.tight_layout()
    save_fig(f'Epigenes-{fig_label}')
    plt.show()


# Overall distribution across all genes in df
_plot_epi_countplot(df_epi, 'Function', 'function')
_plot_epi_countplot(df_epi, 'Target', 'target')

# (column in df_epi, label used in the saved file name)
epi_columns = [('Function', 'function'), ('Target', 'target'),
               ('Complex_name', 'Complex_name'), ('Modification', 'Modification')]

for set_label in vae_set_labels:
    try:
        # Use the module-level SciUtil instance, as the rest of the notebook does
        u.dp([set_label])
        df_set = df_epi[df_epi[set_label] == True]
        for column, fig_label in epi_columns:
            _plot_epi_countplot(df_set, column, f'{fig_label}-{set_label}')
    except Exception as err:
        # Best effort: keep going with the next gene set, but report what went wrong
        # (a bare `except:` would also swallow KeyboardInterrupt and hide the cause).
        print(set_label, err)

In [None]:
###############################################################################
#                                                                             #
#    This program is free software: you can redistribute it and/or modify     #
#    it under the terms of the GNU General Public License as published by     #
#    the Free Software Foundation, either version 3 of the License, or        #
#    (at your option) any later version.                                      #
#                                                                             #
#    This program is distributed in the hope that it will be useful,          #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
#    GNU General Public License for more details.                             #
#                                                                             #
#    You should have received a copy of the GNU General Public License        #
#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
#                                                                             #
###############################################################################

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sciviso import Vis


class Violinplot(Vis):
    """Seaborn violin plot wrapped in the sciviso ``Vis`` styling.

    ``x`` and ``y`` may either both be column names of ``df`` or both be
    sequences of values; in the latter case a temporary dataframe is built
    from them (and from ``hue`` if given).
    """

    def __init__(self, df: pd.DataFrame, x: object, y: object, title='', xlabel='', ylabel='', hue=None, order=None,
                 hue_order=None, showfliers=False, add_dots=False, figsize=(1.5, 1.5), title_font_size=8,
                 add_stats=False, stat_method='Mann-Whitney',
                 label_font_size=6, title_font_weight=700):
        """Store the plot configuration.

        Args:
            df: dataframe to plot from (used when ``x``/``y`` are column names).
            x, y: column names, or sequences of values.
            title, xlabel, ylabel: figure texts.
            hue: optional column name (or sequence) used to colour the violins.
            order, hue_order: category orders; sorted unique values when ``None``.
            showfliers: forwarded to ``sns.violinplot``.
            add_dots: overlay a strip plot of the individual points.
            add_stats: annotate all pairwise comparisons between ``order`` entries.
            stat_method: test name forwarded to ``add_stat_annotation``.
            figsize, title_font_size, label_font_size, title_font_weight: passed to ``Vis``.
        """
        super().__init__(df, figsize=figsize, title_font_size=title_font_size, label_font_size=label_font_size,
                         title_font_weight=title_font_weight)
        self.df = df
        self.x = x
        self.y = y
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.title = title
        self.hue = hue
        self.order = order
        self.hue_order = hue_order
        self.showfliers = showfliers
        self.add_dots = add_dots
        self.add_stats = add_stats
        self.stat_method = stat_method

    def plot(self):
        """Draw the violin plot and return the matplotlib axes."""
        from itertools import combinations

        x, y, hue, order, hue_order = self.x, self.y, self.hue, self.order, self.hue_order
        if not isinstance(self.x, str) and not isinstance(self.y, str):
            # Raw values were supplied: wrap them in a dataframe with fixed column names
            vis_df = pd.DataFrame()
            vis_df['x'] = x
            vis_df['y'] = y
            x = 'x'
            y = 'y'
            if self.hue is not None:
                vis_df['colour'] = self.hue
                hue = 'colour'
        else:
            vis_df = self.df
        # set the orders
        if hue_order is None and hue is not None:
            hue_order = sorted(set(vis_df[hue].values))
        if order is None:
            order = sorted(set(vis_df[x].values))

        ax = sns.violinplot(data=vis_df, x=x, y=y, hue=hue, hue_order=hue_order, order=order, palette=self.palette,
                            showfliers=self.showfliers, inner="quartile")
        if self.add_dots:
            ax = sns.stripplot(data=vis_df, x=x, y=y, hue_order=hue_order, order=order, color='.2')

        if self.add_stats:
            # Compare every unordered pair of categories exactly once, in plotting order.
            # (Keying pairs on concatenated labels could collide, e.g. '1'+'112' and
            # '11'+'12', and silently drop a comparison.)
            unique_order = list(dict.fromkeys(order))
            box_pairs = list(combinations(unique_order, 2))
            # Add stats annotation
            # NOTE(review): add_stat_annotation is not imported in this cell; presumably it
            # comes from the statannot package - confirm it is in scope before using add_stats.
            add_stat_annotation(ax, data=vis_df, x=x, y=y, order=order,
                                box_pairs=box_pairs,
                                test=self.stat_method, text_format='star', loc='inside', verbose=2)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
        ax.tick_params(labelsize=self.label_font_size)
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize=self.label_font_size)
        self.add_labels()
        self.set_ax_params(ax)
        return ax

## Print out the TF relationships



In [None]:
# Have a look at the TFs and their targets for regulons (i.e. does grp2 have TFs that target group3?)
# Build the path with os.path.join, as the other cells that read from supp_dir do.
regs = pd.read_csv(os.path.join(supp_dir, 'regulons_mm10_A-B.csv'))
reg_tfs = set(regs['tf'].values)
group_labels = {1: 'Prolifferation', 2: 'Posterior', 3: 'Immune response', 4: 'Development', 5: 'Forebrain'}

for i in range(1, n_clusters):
    genes = df_training[gene_name].values[vae_set_idxs[i]]
    # Flag the TFs in this group, we'll then test these to see if we have lots of the targets in another group
    tfs = list(reg_tfs & set(genes))
    # Now we want to go through each of the other groups and record how many were targeted
    u.dp([group_labels[i], f'\n{len(tfs)} Tfs: ', ', '.join(tfs)])

    # Look up each TF's regulon targets once instead of once per target group
    tf_targets = {tf: set(regs[regs['tf'] == tf]['target'].values) for tf in tfs}

    for j in range(1, n_clusters):
        group_genes = set(df_training[gene_name].values[vae_set_idxs[j]])
        total_targets = 0
        for tf in tfs:
            hits = tf_targets[tf] & group_genes
            total_targets += len(hits)
            if hits:
                print(f'{tf} -->', ', '.join(hits))
        print(f'--------------- Total for {group_labels[i]} targeting {group_labels[j]}: {total_targets} -------------')


#regs = pd.read_csv(supp_dir + 'regulons_mm10_A-B-C-D.csv')
# NOTE: this expression is not the last statement of the cell, so its result is not displayed.
regs[regs['tf'] == 'Cdkn1b'] #Tbx3, mesoderm markers Hand1, gata2
#df_sig = pd.read_csv(f'{input_dir}df-significant_epi-2500_20210124.csv')
# Print every TF that regulates Foxm1 and is also present in the training set
for g in regs[regs['target'] == 'Foxm1']['tf'].values:
    if (df_training[gene_name] == g).any():
        print(g, '\n')
        #print(df_sig[df_sig[gene_name] == g])
