## Applying dimension reduction to the data

In this notebook, we apply two standard dimension reduction methods to the data: principal component analysis (PCA) and t-distributed stochastic neighbor embedding (t-SNE). We project the 6335-dimensional gene expression data onto 2 dimensions and visualize the projections as scatter plots.

### Import packages
First, we import all the necessary packages, classes, and functions. Most of these are commonly used Python packages that are easy to install. We also import custom functions defined in `helper_functions.py`. As the name suggests, this file contains several utility functions for tasks such as loading and scaling the input data.

In [None]:
import os
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl

# Dimension Reduction
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Helper functions
from helper_functions import loaddata, get_color_dict

### Set paths and specify file names
Next, we set the path to the project root and, from it, the paths to the `data` and `results` directories. We also specify the input files. The data to be analyzed is stored in two CSV files: `clean_metadata.csv` contains the metadata, and `clean_RNAseq_OutlierRemoved.csv` contains the gene expression data.

The metadata file contains one sample per row, identified by its *SRA*. There are 3172 samples, each with 7 descriptive attributes. We are particularly interested in three of them: the plant *family*, the *tissue* type, and the *stress* type. The data cover 16 plant families, 8 tissue types, and 10 stress types (including *healthy*). The RNAseq file, as its name suggests, contains the gene expression of each sample across 6335 orthogroups.

The third file, `colors.pkl`, is a pickle file (a Python-specific serialization format) that maps each class of the factors (*family*, *stress*, and *tissue* type) to a color, stored as both RGB and hex codes. Kepler Mapper outputs an `html` file that is viewed in a browser window, while the other plots and figures are created with `seaborn` and `matplotlib`. Storing the color mapping in a file keeps the colors consistent across all the plotting tools used in the project.

In [None]:
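projdir = "../.."  # project root, relative to this notebook's directory

# Input files
datadir = os.path.join(projdir, "data")
factorfile = os.path.join(datadir, "clean_metadata.csv")          # sample metadata
colorfile = os.path.join(datadir, "colors.pkl")                   # cached color mapping
rnafile = os.path.join(datadir, "raw_RNAseq_OutlierRemoved.csv")  # gene expression

# Output directory for the figures (created if it does not exist)
resdir = os.path.join(projdir, "results")
figdir = os.path.join(resdir, "dimension_reduction")
os.makedirs(figdir, exist_ok=True)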

### Loading data

The input data (metadata and gene expression) is stored in two separate files, but samples can be matched by their *SRA*. The function `loaddata` does exactly that: it loads both files into pandas DataFrames and merges them using the *SRA* as the key. We perform an inner join to keep only those *SRAs* for which we have both metadata and gene expression. `get_color_dict` is another utility function: it looks for an existing `colors.pkl` file and loads it. If the file is not present, the function creates it. In either case, it returns a dictionary containing the color mappings.
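The inner join that `loaddata` performs can be sketched with pandas. The function below is a hypothetical stand-in, not the actual helper: it assumes both files have a column named `SRA`, that every other column of the expression file is an orthogroup, and that only the requested factor columns are kept. It runs on tiny in-memory CSVs rather than the project data.

```python
import io

import pandas as pd


def loaddata_sketch(factorfile, rnafile, factors, key="SRA"):
    """Hypothetical stand-in for helper_functions.loaddata.

    Assumes both CSV files contain a `key` column and that every other
    column of the RNAseq file holds one orthogroup's expression values.
    """
    meta = pd.read_csv(factorfile)
    rna = pd.read_csv(rnafile)

    # Orthogroup columns: everything in the expression table except the key
    orthos = [c for c in rna.columns if c != key]

    # Inner join: keep only SRAs present in both tables
    df = meta[[key] + list(factors)].merge(rna, on=key, how="inner")
    return df, orthos


# Toy input: "s3" has no expression data and "s4" has no metadata
meta_csv = io.StringIO("\n".join([
    "SRA,stress,tissue,family",
    "s1,healthy,leaf,A",
    "s2,drought,root,B",
    "s3,heat,leaf,A",
]))
rna_csv = io.StringIO("\n".join([
    "SRA,OG1,OG2",
    "s1,0.1,0.2",
    "s2,0.3,0.4",
    "s4,0.5,0.6",
]))

df_demo, orthos_demo = loaddata_sketch(meta_csv, rna_csv, ["stress", "tissue", "family"])
print(df_demo["SRA"].tolist())  # ['s1', 's2']
print(orthos_demo)              # ['OG1', 'OG2']
```

Only the SRAs found in both tables survive the merge, which is why the sample count after loading can be smaller than either input file.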

In [None]:
factors = ["stress", "tissue", "family"]
df, orthos = loaddata(factorfile, rnafile, factors)
color_dict = get_color_dict(df, factors, colorfile)
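The load-or-create behavior of `get_color_dict` can be sketched as follows. This is a hypothetical stand-in, not the project's helper: the nested layout `{factor: {class_label: {"rgb": ..., "hex": ...}}}` is inferred from how `make_plots` indexes `color_dict` (`color_dict[colorby][label]["hex"]`), and the `tab20` palette is an assumption. The real function may choose colors differently.

```python
import os
import pickle
import tempfile

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import colors as mcolors


def get_color_dict_sketch(df, factors, colorfile):
    """Hypothetical stand-in for helper_functions.get_color_dict.

    Loads the mapping from `colorfile` if it exists; otherwise builds
    {factor: {class_label: {"rgb": (r, g, b), "hex": "#rrggbb"}}}
    and pickles it so that later runs reuse exactly the same colors.
    """
    if os.path.exists(colorfile):
        with open(colorfile, "rb") as fh:
            return pickle.load(fh)

    cmap = plt.get_cmap("tab20")  # assumed palette with 20 distinct colors
    color_dict = {}
    for factor in factors:
        color_dict[factor] = {}
        for i, label in enumerate(sorted(df[factor].unique())):
            rgb = cmap(i % cmap.N)[:3]  # drop the alpha channel
            color_dict[factor][label] = {"rgb": rgb, "hex": mcolors.to_hex(rgb)}

    with open(colorfile, "wb") as fh:
        pickle.dump(color_dict, fh)
    return color_dict


demo = pd.DataFrame({"stress": ["heat", "healthy", "heat"]})
path = os.path.join(tempfile.mkdtemp(), "colors.pkl")
first = get_color_dict_sketch(demo, ["stress"], path)   # builds and saves the file
second = get_color_dict_sketch(demo, ["stress"], path)  # loads it back from disk
print(first == second)  # True
```

Caching the mapping on disk, rather than regenerating it, is what keeps a class such as *healthy* the same color in the Kepler Mapper output and in the `matplotlib` figures.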

### Computing projections

The function defined below computes the PCA and t-SNE projections of the gene expression profiles, which we then visualize as 2D scatter plots. For t-SNE, we use the correlation distance as the metric.
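The correlation distance between two expression profiles $u$ and $v$ is one minus their Pearson correlation. To our understanding, scikit-learn computes `metric="correlation"` through SciPy's implementation. The toy check below uses random data, not the project data. It confirms the definition and shows why the metric suits expression profiles: a profile and an affine rescaling of it (here $2u + 3$) are at distance 0.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 50))  # 5 toy "samples" x 50 toy "orthogroups"

# Correlation distance between samples u and v: 1 - Pearson correlation(u, v)
d_scipy = squareform(pdist(x, metric="correlation"))
d_manual = 1.0 - np.corrcoef(x)  # np.corrcoef treats each row as one variable

print(np.allclose(d_scipy, d_manual))  # True

# Invariance to scale and offset: u and 2u + 3 have (numerically) zero distance
d_affine = pdist(np.vstack([x[0], 2 * x[0] + 3]), metric="correlation")[0]
print(np.isclose(d_affine, 0.0))  # True
```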

In [None]:
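def make_projections(x):
    """Return a dict with 2D PCA and t-SNE embeddings of the rows of x."""
    proj = dict()

    # Two components: the plots below use exactly two axes, and a 2D t-SNE
    # is not the same as the first two axes of a 3D t-SNE embedding
    print("Computing PCA Projections...")
    proj["PCA"] = PCA(n_components=2).fit_transform(x)

    # `square_distances` was deprecated in scikit-learn 1.1 and removed in 1.3
    # (squared distances are now always used), so it is no longer passed
    print("Computing t-SNE Projections...")
    proj["TSNE"] = TSNE(n_components=2, init="pca", learning_rate="auto",
                        metric="correlation").fit_transform(x)

    return proj

projections = make_projections(df[orthos])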
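PCA itself is a linear projection onto the directions of largest variance. A minimal sketch on toy random data (not the project data) shows that scikit-learn's `PCA` matches centering followed by a singular value decomposition, up to the sign of each component. It also shows that PCA components are nested: the first two columns of a 3-component fit match a 2-component fit. t-SNE has no such property, because its embedding is optimized jointly for the requested number of dimensions.

```python
import numpy as np
from sklearn.decomposition import PCA


def equal_up_to_sign(a, b):
    """True if every column of a equals the matching column of b or its negation."""
    return all(np.allclose(a[:, k], b[:, k]) or np.allclose(a[:, k], -b[:, k])
               for k in range(a.shape[1]))


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 20))  # toy data: 100 samples x 20 features

# PCA by hand: center each feature, take the SVD,
# and project onto the top-2 right singular vectors
xc = x - x.mean(axis=0)
_, _, vt = np.linalg.svd(xc, full_matrices=False)
manual = xc @ vt[:2].T

pca2 = PCA(n_components=2).fit_transform(x)
pca3 = PCA(n_components=3).fit_transform(x)

print(equal_up_to_sign(manual, pca2))       # True: PCA = centering + SVD
print(equal_up_to_sign(pca3[:, :2], pca2))  # True: components are nested
```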

### Creating scatter plots

The function defined below creates the scatter plots from the PCA and t-SNE projections and colors the points either by class label (*stress*, *tissue*, *family*) or by lens function value.
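When points are colored by a continuous value and the colorbar is drawn from a separate `ScalarMappable`, the two agree only if the scatter uses the same normalization. The toy sketch below (random points and made-up lens values, not the saved lenses) shows that without `norm=`, `matplotlib` rescales the colormap to the data's own minimum and maximum. With an explicit `Normalize(0, 1)`, a given color means the same lens value in every figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize

rng = np.random.default_rng(0)
xy = rng.normal(size=(50, 2))
lens = rng.uniform(0.2, 0.6, size=50)  # toy lens values that do not span [0, 1]

fig, (ax_auto, ax_fixed) = plt.subplots(1, 2)
cmap = plt.get_cmap("viridis_r")

# Without norm=, the color limits are taken from the data itself
auto = ax_auto.scatter(xy[:, 0], xy[:, 1], c=lens, cmap=cmap)
# With an explicit norm, the color limits are fixed at [0, 1]
fixed = ax_fixed.scatter(xy[:, 0], xy[:, 1], c=lens, cmap=cmap, norm=Normalize(0, 1))

print(np.allclose(auto.get_clim(), (lens.min(), lens.max())))  # True
print(fixed.get_clim() == (0.0, 1.0))                          # True
plt.close(fig)
```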

In [None]:
def make_plots(dataframe, projs, to_plot, colorby, savedir, color_dict):
    """Scatter-plot the first two components of each projection in `to_plot`,
    colored by a class label or a lens value, and save the figure to `savedir`."""
    is_lens = colorby in ["stress_lens", "root_lens", "leaf_lens"]

    fig = plt.figure(figsize=(18, 6))
    # One panel per projection method; squeeze=False keeps a 1-element array
    # of Axes even when only one method is plotted
    axs = fig.subplots(nrows=1, ncols=len(to_plot), squeeze=False)[0]

    if is_lens:
        # Lens values are assumed to lie in [0, 1]; the same norm is used for
        # the points and the colorbar so that both show the same scale
        cmap = plt.get_cmap("viridis_r")
        norm = mpl.colors.Normalize(0, 1)
    else:
        # Class label -> hex color, plus one legend patch per class
        c_map = color_dict[colorby]
        handles = [mpl.patches.Patch(color=c_map[label]["hex"], label=label)
                   for label in c_map.keys()]

    for ax, mthd in zip(axs, to_plot):
        proj = projs[mthd]
        if is_lens:
            ax.scatter(proj[:, 0], proj[:, 1], c=dataframe[colorby], s=10.0,
                       cmap=cmap, norm=norm)
        else:
            cols = [c_map[label]["hex"] for label in dataframe[colorby]]
            ax.scatter(proj[:, 0], proj[:, 1], c=cols, s=10.0)

        ax.set_xlabel(mthd + " 1")
        ax.set_ylabel(mthd + " 2")
        ax.set_title(mthd + " 2D Scatter Plot")

    # Leave room on the right for the legend or colorbar
    fig.subplots_adjust(right=0.9)
    if is_lens:
        fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                     ax=axs[-1], label=colorby)
    else:
        fig.legend(handles=handles, loc="center right",
                   borderaxespad=0.2, title=colorby)

    figname = os.path.join(savedir, "DimReduction_ColorBy_" + colorby + ".png")
    fig.savefig(figname, dpi=300, format="png", bbox_inches="tight",
                pad_inches=0.1, facecolor="white")

    plt.show()

### Plots colored by stress, tissue, and family labels

In [None]:
to_plot = ["PCA", "TSNE"]
for f in factors:
    make_plots(df, projections, to_plot, f, figdir, color_dict)

### Plots colored by lens function values

We assume that the lens function values used to construct the Mapper graphs have also been saved to a `.csv` file. We use them to color the scatter plots, in the same way as the factor (*stress*, *tissue*, or *family*) class labels above.
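`DataFrame.assign` aligns a Series on the index, so the lens values land on the right samples only if the rows read from `saved_lenses.csv` share `df`'s index. Whether that holds depends on how `loaddata` builds `df` and how the lenses were saved, which this notebook does not show. The toy frames below are made up for illustration and demonstrate the pitfall and the positional alternative, which in turn assumes both tables list the samples in the same order.

```python
import pandas as pd

# Toy frames: df_demo has a non-default index (e.g. left over from filtering),
# while a freshly read CSV gets a default RangeIndex 0..n-1
df_demo = pd.DataFrame({"stress": ["heat", "healthy", "cold"]}, index=[10, 11, 12])
lenses_demo = pd.DataFrame({"stress_lens": [0.9, 0.1, 0.5]})

# assign() aligns a Series on index labels: none match here, so all values are NaN
aligned = df_demo.assign(stress_lens=lenses_demo["stress_lens"])
print(aligned["stress_lens"].isna().all())  # True

# A NumPy array is assigned by position instead (assumes identical row order)
positional = df_demo.assign(stress_lens=lenses_demo["stress_lens"].to_numpy())
print(positional["stress_lens"].tolist())  # [0.9, 0.1, 0.5]
```

If the lens file also stored the *SRA* of each row, merging on that key would avoid relying on either the index or the row order.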

In [None]:
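lensfile = datadir + "/saved_lenses.csv"
lenses = pd.read_csv(lensfile)
# assign() aligns on the index: this assumes the rows of saved_lenses.csv
# share df's index (i.e. list the samples in the same order)
df = df.assign(stress_lens=lenses["stress_lens"],
               root_lens=lenses["root_lens"],
               leaf_lens=lenses["leaf_lens"])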

to_plot = ["PCA", "TSNE"]
for f in ["stress_lens", "root_lens", "leaf_lens"]:
    make_plots(df, projections, to_plot, f, figdir, color_dict)