## Creating correlation distance heat maps for samples

In this notebook, we compute the matrix of pairwise correlation distances between the gene expression profiles of samples in our data and visualize the matrix as a heat map. In addition, we will use this distance matrix to compute a hierarchical clustering dendrogram of samples with average linkage clustering.

### Import packages
First, import the necessary packages and methods. Most of these are commonly used Python packages that can be easily installed. We also import custom functions defined in `helper_functions.py`, which contains utility functions for tasks such as loading and scaling input data.

In [None]:
import os
import numpy as np
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
from matplotlib import patches
import seaborn as sns
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform
%matplotlib inline

# Helper functions
from helper_functions import loaddata, get_color_dict

### Set paths, specify file names
Next, we set the paths to the directories containing `code`, `data`, and `results`, and specify the input files. The data to be analyzed is stored in two CSV files: `clean_metadata.csv` contains the metadata, and `clean_RNAseq_OutlierRemoved.csv` contains the gene expression data.

The metadata file contains one sample per row, identified by its *SRA*. There are 3172 samples, and for each sample we have 7 descriptive attributes. We are particularly interested in three of them: plant *family*, *tissue* type, and *stress* type. There are 16 different plant families, 8 tissue types, and 10 stress types (including *healthy*) represented in the data. The RNAseq file, as the name suggests, contains the gene expression for each sample across 6335 orthogroups.

The third file, named `colors.pkl`, is a pickle file (that's a Python-specific file format!) that contains a mapping between colors (RGB and hex codes) and the different classes of the factors (*family*, *stress*, and *tissue* type). Kepler Mapper outputs an `html` file that is viewed in a browser window, while other plots and figures are created using `seaborn` and `matplotlib`. Storing the color mapping in a file ensures consistent colors across the plotting tools used in the project.

In [None]:
projdir = "../.."

datadir = projdir + "/data"
factorfile = datadir + "/clean_metadata.csv"
colorfile = datadir + "/colors.pkl"
rnafile = datadir + "/raw_RNAseq_OutlierRemoved.csv"

resdir = projdir + "/results"
heatmapdir = resdir + "/heat_maps"
os.makedirs(heatmapdir, exist_ok=True)
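
As a rough illustration of the color mapping described above, the sketch below builds a small mapping and caches it in a pickle file. The function names `build_color_dict` and `load_or_create_colors` are hypothetical and are not the helpers from `helper_functions.py`. The nested layout (factor, then class label, then `"rgb"` and `"hex"` entries) is an assumption based on how the mapping is indexed in the plotting cells of this notebook.

```python
import colorsys
import os
import pickle
import tempfile


def build_color_dict(labels_by_factor):
    """Hypothetical helper: give every class label an RGB tuple and a hex code."""
    color_dict = {}
    for factor, labels in labels_by_factor.items():
        labels = sorted(labels)
        color_dict[factor] = {}
        for i, label in enumerate(labels):
            # Evenly spaced hues so that classes are easy to tell apart
            rgb = colorsys.hsv_to_rgb(i / len(labels), 0.65, 0.9)
            hex_code = "#{:02x}{:02x}{:02x}".format(
                *(round(255 * c) for c in rgb))
            color_dict[factor][label] = {"rgb": rgb, "hex": hex_code}
    return color_dict


def load_or_create_colors(path, labels_by_factor):
    """Hypothetical helper: reuse the pickled mapping if it exists, else create it."""
    if os.path.exists(path):
        with open(path, "rb") as fh:
            return pickle.load(fh)
    color_dict = build_color_dict(labels_by_factor)
    with open(path, "wb") as fh:
        pickle.dump(color_dict, fh)
    return color_dict


# Toy labels, not the real classes in the data
toy_labels = {"tissue": ["leaf", "root"],
              "stress": ["healthy", "drought", "cold"]}
with tempfile.TemporaryDirectory() as tmp:
    toy_path = os.path.join(tmp, "toy_colors.pkl")
    first = load_or_create_colors(toy_path, toy_labels)   # creates the file
    second = load_or_create_colors(toy_path, toy_labels)  # loads the file
```

Because the second call reads the pickled file instead of regenerating colors, every plotting tool that loads the file sees the same color for a given class.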

### Loading data

The input data (metadata and gene expression) is stored in two separate files, but samples can be matched using the *SRA*. The function `loaddata` does exactly that: it loads both files into pandas dataframes and merges them using the *SRA* as the key. We perform an inner join to keep only those *SRAs* for which we have both metadata and gene expression. `get_color_dict` is another utility function, which looks for an existing `colors.pkl` file and loads it. If the file is not present, the function creates it and returns a dictionary containing the color mappings.

In [None]:
factors = ["stress", "tissue", "family"]
df, orthos = loaddata(factorfile, rnafile, factors)
color_dict = get_color_dict(df, factors, colorfile)
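
To make the inner join concrete, here is a minimal sketch on toy data. It is not the implementation of `loaddata`. The toy SRA identifiers and the orthogroup column names `OG1` and `OG2` are made up for illustration, and the key column name `sra` is assumed from how the merged dataframe is indexed later in this notebook.

```python
import pandas as pd

# Toy metadata: one row per sample
meta = pd.DataFrame({
    "sra": ["SRR1", "SRR2", "SRR3"],
    "family": ["Poaceae", "Poaceae", "Fabaceae"],
    "tissue": ["leaf", "root", "leaf"],
    "stress": ["healthy", "drought", "healthy"],
})

# Toy expression table: SRR1 is missing, SRR4 has no metadata
expr = pd.DataFrame({
    "sra": ["SRR2", "SRR3", "SRR4"],
    "OG1": [0.5, 1.2, 3.1],
    "OG2": [2.0, 0.1, 0.7],
})

# Inner join keeps only samples present in both tables
merged = meta.merge(expr, on="sra", how="inner")

# Expression columns are everything in the expression table except the key
ortho_cols = [c for c in expr.columns if c != "sra"]
```

Only `SRR2` and `SRR3` survive the merge, since they are the only samples with both metadata and expression values.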

### Frequency plots

To understand the distribution of class labels under each factor of interest (*stress*, *tissue*, and *family*), we first plot their frequencies. For categorical data, the pandas `value_counts` method is very useful.

In [None]:
fig, ax = plt.subplots(1, len(factors), figsize=(15, 6))
for i, f in enumerate(factors):
    val_counts = df[f].value_counts()
    cols = [color_dict[f][l]["hex"] for l in val_counts.index]
    val_counts.plot(kind="bar", ax=ax[i], color=cols).set_title(f)

figname = resdir + "/Factor_Frequency_Plots.png"
fig.savefig(figname, dpi=300, format="png",
            bbox_inches="tight", pad_inches=0.1, facecolor="white")
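
As a small illustration of `value_counts` on toy labels (not the real data): the result is a Series indexed by class label and sorted by descending count. That ordering is why the bar colors in the cell above are looked up from `val_counts.index` rather than from the raw column.

```python
import pandas as pd

# Toy tissue labels for six samples
toy = pd.Series(["leaf", "root", "leaf", "seed", "leaf", "root"],
                name="tissue")

# Counts per class, most frequent class first
counts = toy.value_counts()

# Labels in the same order as the bars would be drawn
bar_order = list(counts.index)
```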

### Hierarchical clustering of samples

Next, we create a dendrogram illustrating the hierarchical clustering of the samples. We use the correlation distance between gene expression profiles as the clustering metric (computed in the code as one minus the absolute Pearson correlation) and average linkage as the clustering method.

In [None]:
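def create_dendrogram(data, savepath=None, saveformat="png"):
    # Absolute Pearson correlation between samples (rows of `data`)
    abscor = np.abs(np.corrcoef(data))
    dist = 1. - abscor
    # linkage expects a condensed (1-D) distance vector; a square matrix
    # would be interpreted as a set of observation vectors instead
    dist_arr = squareform(dist, checks=False)
    Z = linkage(dist_arr, 'average')
    fig = plt.figure(figsize=(25, 10))
    dn = dendrogram(Z)
    # Save the figure only when a path is given
    if savepath is not None:
        fig.savefig(savepath, dpi=300, format=saveformat,
                    bbox_inches="tight", facecolor="white")
    plt.show()

    return None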

create_dendrogram(df[orthos])
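
The clustering step can be checked on a toy example. The sketch below uses made-up data (two pairs of nearly identical profiles) and assumes the distance is one minus the absolute Pearson correlation. It converts the square distance matrix to the condensed vector that `scipy.cluster.hierarchy.linkage` expects for precomputed distances, then cuts the tree into two clusters with `fcluster`.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

rng = np.random.default_rng(0)
base_a = rng.normal(size=50)
base_b = rng.normal(size=50)

# Four toy samples (rows) x 50 features: two noisy copies of each profile
toy = np.vstack([
    base_a + 0.05 * rng.normal(size=50),
    base_a + 0.05 * rng.normal(size=50),
    base_b + 0.05 * rng.normal(size=50),
    base_b + 0.05 * rng.normal(size=50),
])

# Square matrix of pairwise distances between samples
dist = 1.0 - np.abs(np.corrcoef(toy))

# Condensed form: the n*(n-1)/2 upper-triangle entries as a 1-D vector.
# checks=False skips the symmetry and zero-diagonal checks, which can
# otherwise fail on tiny floating-point errors.
condensed = squareform(dist, checks=False)

# Average linkage on the precomputed distances
Z = linkage(condensed, method="average")

# Cut the tree into two flat clusters
labels = fcluster(Z, t=2, criterion="maxclust")
```

The noisy copies of the same profile end up in the same cluster. Passing the square matrix to `linkage` instead of the condensed vector would not raise an error, but each row would be treated as an observation vector rather than as precomputed distances.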

### Correlation distance heat maps

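The function defined in the next cell visualizes the pairwise correlation matrix of the samples as a heat map, with rows and columns ordered by hierarchical average linkage clustering on the correlation distance. The corresponding dendrograms are drawn along the x and y axes. Note that the heat map cells show the signed correlation values (centered at 0), while the dendrograms are built from the distance. For a given factor (*stress*, *tissue*, or *family*), the function also shows a color strip along the y axis indicating the class label of each sample.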

In [None]:
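def create_heatmap(df, data_cols, row_idx, f, color_dict,
                   savepath=None, saveformat="png"):

    # Row colors and legend handles for the class labels of factor `f`
    c_map = color_dict[f]
    cols = [c_map[l]["hex"] for l in df[f]]
    handles = [patches.Patch(color=c_map[l]["hex"],
                             label=l) for l in c_map.keys()]
    # Use the `data_cols` argument rather than the global `orthos`
    cor = np.corrcoef(df[data_cols])
    dist = 1. - np.abs(cor)
    # linkage expects a condensed (1-D) distance vector, not a square matrix
    dist_arr = squareform(dist, checks=False)
    link_mat = linkage(dist_arr, 'average')
    # Heat map cells show the signed correlations; clustering uses `dist`
    data = pd.DataFrame(cor, columns=row_idx)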
    cmap = sns.color_palette("Spectral", as_cmap=True)
    clust_map = sns.clustermap(data,
                               row_cluster=True,
                               col_cluster=True,
                               center=0.,
                               row_linkage=link_mat,
                               col_linkage=link_mat,
                               cmap=cmap,
                               row_colors=cols,
                               yticklabels=False,
                               cbar_pos=(1.07, 0.51, 0.03, 0.3))

    l2 = clust_map.ax_heatmap.legend(loc='center left',
                                     bbox_to_anchor=(1.1, 0.2),
                                     handles=handles,
                                     frameon=True)
    l2.set_title(title=f, prop={'size':14})

    if savepath is not None:
        clust_map.savefig(savepath, dpi=300, format=saveformat,
                          bbox_inches="tight", facecolor="white")
    
    plt.show()

    return clust_map

#### Correlation distance heat map with family class labels

In [None]:
figpath = heatmapdir + "/HeatMap_ColorBy_family.png"
_ = create_heatmap(df, orthos, df["sra"], "family", color_dict, figpath)

#### Correlation distance heat map with tissue class labels

In [None]:
figpath = heatmapdir + "/HeatMap_ColorBy_tissue.png"
_ = create_heatmap(df, orthos, df["sra"], "tissue", color_dict, figpath)

#### Correlation distance heat map with stress class labels

In [None]:
figpath = heatmapdir + "/HeatMap_ColorBy_stress.png"
_ = create_heatmap(df, orthos, df["sra"], "stress", color_dict, figpath)