# Introduction
## The Data Set
In today's workshop, we will revisit the data set you worked with in the Machine Learning workshop. As a refresher: this data set is GEO series GSE53987, which covers bipolar disorder (BD), major depressive disorder (MDD), and schizophrenia:

Lanz TA, Joshi JJ, Reinhart V, Johnson K et al. STEP levels are unchanged in pre-frontal cortex and associative striatum in post-mortem human brain samples from subjects with schizophrenia, bipolar disorder and major depressive disorder. PLoS One 2015;10(3):e0121744. PMID: 25786133

This is microarray data on platform GPL570 (HG-U133_Plus_2, Affymetrix Human Genome U133 Plus 2.0 Array) consisting of 54675 probes.

The raw CEL files of the GEO series were downloaded and frozen-RMA normalized, and the probes were converted to HUGO gene symbols using the annotate package, averaging probes that map to the same gene. The sample clinical data (meta-data) was parsed from the series matrix file. You can download it **here**.  

In total there are 205 rows covering 19 individuals diagnosed with BD, 19 with MDD, 19 with schizophrenia, and 19 controls. Each individual has gene expression from up to 3 tissues (post-mortem brain). There are a total of 13768 genes (numeric features), 10 meta features, and 1 ID (GEO sample accession):

- Age
- Race (W for white and B for black)
- Gender (F for female and M for male)
- Ph: pH of the brain tissue
- Pmi: post-mortem interval
- Rin: RNA integrity number
- Patient: Unique ID for each patient. Each patient has up to 3 tissue samples. The patient ID is written as disease followed by a number from 1 to 19
- Tissue: tissue the expression was obtained from.
- Disease.state: class of disease the patient belongs to: bipolar, schizophrenia, depression or control.
- Source.name: combination of the Tissue and Disease.state
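
To make the layout concrete, here is a minimal sketch of how the metadata columns and the gene columns can be separated, and how samples can be counted per tissue and disease state. The table below is a tiny synthetic stand-in (made-up values, placeholder gene names `GENE_A` and `GENE_B`), not the real data set.

```python
import numpy as np
import pandas as pd

# Tiny synthetic stand-in for the real table: 10 metadata columns, then genes.
# All values are made up for illustration.
rng = np.random.default_rng(0)
meta_cols = ["Patient", "Source.name", "Age", "Gender", "Race",
             "Pmi", "Ph", "Rin", "Tissue", "Disease.state"]
toy = pd.DataFrame({
    "Patient": ["control_1", "control_1", "bipolar_1", "bipolar_1"],
    "Source.name": ["hippocampus, control", "Associative striatum, control",
                    "hippocampus, bipolar", "Associative striatum, bipolar"],
    "Age": [51, 51, 40, 40],
    "Gender": ["M", "M", "F", "F"],
    "Race": ["W", "W", "B", "B"],
    "Pmi": [24.2, 24.2, 16.6, 16.6],
    "Ph": [6.6, 6.6, 6.8, 6.8],
    "Rin": [7.8, 7.5, 7.9, 8.0],
    "Tissue": ["hippocampus", "Associative striatum"] * 2,
    "Disease.state": ["control", "control", "bipolar", "bipolar"],
    "GENE_A": rng.normal(6, 1, 4),   # placeholder gene columns
    "GENE_B": rng.normal(6, 1, 4),
}, index=["S1", "S2", "S3", "S4"])

meta = toy[meta_cols]                # clinical / technical annotations
expr = toy.drop(columns=meta_cols)   # numeric gene expression only

# Number of samples for each tissue / disease state combination
counts = toy.groupby(["Tissue", "Disease.state"]).size()
print(counts)
```

The same split should apply to the real table as long as the ten metadata columns come first, which is an assumption worth checking with `data.columns[:10]` after loading.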

## Workshop Goals
This workshop will walk you through an analysis of the GSE53987 microarray data set. It has the following three tasks:  
    1. Visualize the demographics of the data set  
    2. Compute differential gene expression and visualize the differential expression
    3. Cluster gene expression data and appropriately visualize the cluster results

Each task has a __required__ section and a __bonus__ section. Focus on completing the three __required__ sections first, then if you have time at the end, revisit the __bonus__ sections. 

Finally, as this is your final workshop, we hope that you will use this as an opportunity to integrate the different concepts that you have learned in previous workshops.

## Workshop Logistics
As mentioned in the pre-workshop documentation, you can do this workshop either in a Jupyter Notebook or in a Python script. Please make sure you have set up the appropriate environment for yourself. This workshop will be completed using pair programming, and the "driver" will switch every 15 minutes. We will be using the Python plotting libraries matplotlib and seaborn. 

## TASK 0: Import Libraries and Data
- Download the data set (above) as a .csv file
- Initialize your script by loading the following libraries.

In [2]:
# Import Necessary Libraries
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn import cluster, metrics, decomposition
from scipy import stats  # needed by differential_expression() in Task 2
from matplotlib import pyplot as plt
import itertools
data = pd.read_csv('GSE53987_combined.csv', index_col=0)
genes = data.columns[10:]

## TASK 1: Visualize Dataset Demographics  
### Required Workshop Task:  
##### Use the skeleton code to write 3 plotting functions:  
    1. plot_distribution()  
        - Returns a distribution plot object given a dataframe and one observation
    2. plot_relational()
        - Returns a relational plot object given a dataframe and (x,y) observations  
    3. plot_categorical()
        - Returns a categorical plot object given a dataframe and (x,y) observations
##### Use these functions to produce the following plots:
    1. Histogram of patient ages
    2. Histogram of gene expression for 1 gene
    3. Scatter plot of gene expression for 1 gene by age 
    4. Scatter plot of gene expression for 1 gene by disease state 
Your plots should satisfy the following critical components:  
    - Axis titles
    - Figure title
    - Legend (if applicable)
    - Be readable
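
As an illustration of these components, the sketch below labels a seaborn scatter plot through the returned matplotlib axes. The data are synthetic and `GENE_A` is a placeholder gene name, so treat this as one possible pattern rather than the expected solution.

```python
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# Synthetic data; GENE_A is a placeholder, not a real column in the data set.
rng = np.random.default_rng(1)
toy = pd.DataFrame({
    "Age": rng.integers(20, 70, size=40),
    "GENE_A": rng.normal(6, 0.5, size=40),
    "Disease.state": np.tile(["control", "bipolar"], 20),
})

fig, ax = plt.subplots(figsize=(5, 4))
sns.scatterplot(data=toy, x="Age", y="GENE_A", hue="Disease.state", ax=ax)
ax.set_xlabel("Age (years)")                 # axis titles
ax.set_ylabel("GENE_A expression")
ax.set_title("GENE_A expression by age")     # figure title
ax.legend(title="Disease state")             # legend with a readable title
fig.tight_layout()
```

Because the plotting functions in this task return an axes object, the same `set_xlabel`, `set_ylabel`, and `set_title` calls can be applied to whatever they return.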
    
### Bonus Task: 
1. Return to these functions and include functionality to customize color palettes, axis legends, etc. You can choose to define your own plotting "style" and keep that consistent for all of your plotting functions.  
2. Faceting your plots. Modify your functions to take a "facet" argument: when facet is an observation, the function creates a facet grid and facets on that observation. Read more about faceting here:
Faceting generates multi-plot grids by __mapping a dataset onto multiple axes arrayed in a grid of rows and columns that correspond to levels of variables in the dataset.__  
    - In order to use faceting, your data __must be__ in a Pandas DataFrame and it must take the form of what Hadley Wickham calls “tidy” data. 
    - In brief, that means your dataframe should be structured such that each column is a variable and each row is an observation. There are figure-level functions (e.g. relplot() or catplot()) that create facet grids automatically and can be used in place of axes-level functions like distplot() or scatterplot(). 
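
A minimal sketch of faceting with a figure-level function, assuming tidy synthetic data (the values and the `GENE_A` column are placeholders): `relplot()` draws one panel per level of the `col` variable and returns a `FacetGrid`.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Tidy synthetic data: one row per sample, one column per variable.
rng = np.random.default_rng(2)
toy = pd.DataFrame({
    "Age": rng.integers(20, 70, size=60),
    "GENE_A": rng.normal(6, 0.5, size=60),
    "Tissue": np.repeat(["hippocampus", "striatum", "cortex"], 20),
    "Disease.state": np.tile(["control", "bipolar"], 30),
})

# One scatter panel per tissue, colored by disease state
g = sns.relplot(data=toy, x="Age", y="GENE_A", hue="Disease.state",
                col="Tissue", kind="scatter", height=3)
g.set_axis_labels("Age (years)", "GENE_A expression")
g.set_titles("{col_name}")
```

Note that the return value is a grid rather than a single axes object, so a plotting function with a "facet" argument may need to document two possible return types.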

In [63]:
# Import the data (.csv file) as a data frame
# Adjust the path to wherever you saved the file
data = pd.read_csv("GSE53987_combined.csv", index_col=0)

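# Function to Plot a Distribution
def plot_distribution(df, obs1, obs2=''):
    """
    Create a distribution plot for at least one observation
    
    Arguments:
        df (pandas data frame): data frame containing at least 1 column of numerical values
        obs1 (string): observation to plot distribution on
        obs2 (string, optional): observation to color (group) the distributions by
    Returns:
        axes object (or a seaborn FacetGrid when obs2 is given)
    """
    # Note: sns.distplot is deprecated in seaborn >= 0.11; histplot/kdeplot
    # are its replacements if your installed version no longer provides it.
    if obs2 == '':
        ax = sns.distplot(df[obs1])
    else:
        # Build the grid first, then map the plotting function onto it
        g = sns.FacetGrid(df, hue=obs2)
        ax = g.map(sns.distplot, obs1, hist=False)
    return ax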

# Function to Plot Relational (x,y) Plots 
def plot_relational(df, x, y, hue=None, kind=None):
    """
    Create a plot for an x,y relationship (default = scatter plot)
    Optional functionality for additional observations.
    
    Arguments:
        df (pandas data frame): data frame containing at least 2 columns of numerical values
        x (string): observation for the independent variable
        y (string): observation for the dependent variable
        hue (string, optional): additional observation to color the plot on
        kind (string, optional): type of plot to create [scatter, line]
    Returns:
        axes object
    """
    if kind is None or kind == "scatter":
        ax = sns.scatterplot(data=df, x=x, y=y, hue=hue)
    else:
        ax = sns.lineplot(data=df, x=x, y=y, hue=hue)
    return ax

def plot_categorical(df, x, y, hue=None, kind=None):
    """
    Create a plot for an x,y relationship where x is categorical (not numerical)
    
    Arguments:
        df (pandas data frame): data frame containing at least 2 columns of numerical values
        x (string): observation for the independent variable (categorical)
        y (string): observation for the dependent variable
        hue (string, optional): additional observation to color the plot on
        kind (string, optional): type of plot to create. Options should include at least: 
        strip (default), box, and violin
    Returns:
        axes object
    """
    if kind is None or kind == "strip":
        ax = sns.stripplot(data=df, x=x, y=y, hue=hue)
    elif kind == "violin":
        ax = sns.violinplot(data=df, x=x, y=y, hue=hue)
    elif kind == "box":
        ax = sns.boxplot(data=df, x=x, y=y, hue=hue)
    else:
        # Fail loudly instead of hitting an undefined `ax` below
        raise ValueError("Unsupported plot kind: {}".format(kind))
    return ax

def main():
    """
    Generate the following plots:
    1. Histogram of patient ages
    2. Histogram of gene expression for 1 gene
    3. Scatter plot of gene expression for 1 gene by ages 
    4. Scatter plot of gene expression for 1 gene by disease state 
    """
    

## TASK 2: Differential Expression Analysis

Differential expression analysis is a fancy way of saying, "We want to find which genes exhibit increased or decreased expression compared to a control group". Neat. Because the dataset we're working with is microarray data, which is approximately normally distributed, we'll be using a simple one-way ANOVA. If, however, you were working with sequencing count data (e.g. RNA-seq), which is usually modeled with a negative binomial distribution, you would need more specialized tools. A helper function is provided below.
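
Before reading the helper, it may help to see the underlying test on its own. The sketch below runs `scipy.stats.f_oneway` on three synthetic groups (made-up numbers, four hypothetical genes), testing each gene column independently, and then applies a Bonferroni correction.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(3)
n_per_group, n_genes = 15, 4

# Three groups of samples (rows) by genes (columns), all drawn from N(6, 0.5)
control = rng.normal(6.0, 0.5, size=(n_per_group, n_genes))
group_a = rng.normal(6.0, 0.5, size=(n_per_group, n_genes))
group_b = rng.normal(6.0, 0.5, size=(n_per_group, n_genes))
group_b[:, 0] += 2.0   # shift gene 0 in one group so it truly differs

# With 2-D inputs, f_oneway tests each column (gene) separately
f_values, p_values = stats.f_oneway(control, group_a, group_b)

# Bonferroni correction: multiply by the number of tests, cap at 1
p_adjusted = np.minimum(p_values * n_genes, 1.0)
print(p_adjusted)
```

Only the shifted gene is expected to stand out here; the other three come from identical distributions, so any small p-values among them would be false positives.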

In [7]:
def differential_expression(data, group_col, features, reference=None):
    """
    Perform a one-way ANOVA across all provided features for a given grouping.
    
    Arguments
    ---------
    
        data : (pandas.DataFrame)
            DataFrame containing group information and feature values.
        group_col : (str)
            Column in `data` containing sample group labels.
        features : (list, numpy.ndarray):
            Columns in `data` to test for differential expression. Having them
            be gene names would make sense. :thinking:
        reference : (str, optional)
            Value in `group_col` to use as the reference group. Default is None,
            in which case the first group found in `group_col` is used.
            
    Returns
    -------
        pandas.DataFrame
            A DataFrame of differential expression results with columns for
            fold changes between groups, maximum fold change from reference,
            f values, p values, and adjusted p-values by Bonferroni correction.
    """
    if group_col not in data.columns:
        raise ValueError("`group_col` {} not found in data".format(group_col))
    if any([x not in data.columns for x in features]):
        raise ValueError("Not all provided features found in data.")
    if reference is None:
        reference = data[group_col].unique()[0]
        print("No reference group provided. Using {}".format(reference))
    elif reference not in data[group_col].unique():
        raise ValueError("Reference value {} not found in column {}.".format(
                         reference, group_col))
    by_group = data.groupby(group_col)
    reference_avg = by_group.get_group(reference).loc[:,features].mean()
    values = []
    results = {}
    for each, index in by_group.groups.items():
        values.append(data.loc[index, features])
        if each !=  reference:
            key = "{}.FoldChange".format(each)
            results[key] = data.loc[index, features].mean()\
                         / reference_avg
    fold_change_cols = list(results.keys())
    fvalues, pvalues = stats.f_oneway(*values)
    results['f.value'] = fvalues
    results['p.value'] = pvalues
    # Bonferroni correction; cap at 1 so adjusted values remain valid p-values
    results['p.value.adj'] = np.minimum(pvalues * len(features), 1.0)
    results_df = pd.DataFrame(results)
    def largest_deviation(x):
        i = np.where(abs(x) == max(abs(x)))[0][0]
        return x[i]
    results_df['Max.FoldChange'] = results_df[fold_change_cols].apply(
                                       lambda x: largest_deviation(x.values), axis=1)

    return results_df  

In [15]:
# Here's some pre-subsetted data
hippocampus = data[data["Tissue"] == "hippocampus"]
pf_cortex = data[data["Tissue"] == "Pre-frontal cortex (BA46)"]
as_striatum = data[data["Tissue"] == "Associative striatum"]
# Here's how we can subset a dataset by two conditions.
# You might find it useful :thinking:
data[(data["Tissue"] == 'hippocampus') & (data['Disease.state'] == 'control')]

| | Patient | Source.name | Age | Gender | Race | Pmi | Ph | Rin | Tissue | Disease.state | ... | ZSWIM8.AS1 | ZW10 | ZWILCH | ZWINT | ZXDA | ZXDB | ZXDC | ZYX | ZZEF1 | ZZZ3 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| GSM1304870 | control_10 | hippocampus, control | 51 | M | W | 24.2 | 6.6 | 7.8 | hippocampus | control | ... | 5.901962 | 6.924233 | 4.883176 | 7.037155 | 4.221055 | 4.071727 | 6.101329 | 6.840671 | 5.802425 | 7.951339 |
| GSM1304871 | control_11 | hippocampus, control | 51 | F | W | 7.8 | 6.6 | 7.2 | hippocampus | control | ... | 5.891827 | 6.686104 | 4.919824 | 7.019892 | 4.449879 | 4.110077 | 6.372282 | 7.096796 | 5.972458 | 7.995525 |
| GSM1304872 | control_12 | hippocampus, control | 36 | F | W | 14.5 | 6.4 | 8.0 | hippocampus | control | ... | 6.162939 | 7.245197 | 4.95024 | 7.558985 | 4.500566 | 4.214102 | 6.515646 | 6.374378 | 5.741482 | 7.969294 |
| GSM1304873 | control_13 | hippocampus, control | 65 | F | W | 18.5 | 6.5 | 7.0 | hippocampus | control | ... | 5.926336 | 6.451683 | 4.3526 | 6.130494 | 4.395073 | 4.175186 | 6.513069 | 7.194028 | 5.851711 | 7.977704 |
| GSM1304874 | control_14 | hippocampus, control | 55 | M | W | 28.0 | 6.1 | 6.8 | hippocampus | control | ... | 5.973318 | 6.760423 | 5.264255 | 7.481936 | 4.387548 | 4.218758 | 6.709437 | 6.740813 | 5.929435 | 7.820045 |
| GSM1304875 | control_15 | hippocampus, control | 22 | M | W | 20.1 | 6.8 | 7.1 | hippocampus | control | ... | 5.8603 | 6.748603 | 4.718314 | 7.282907 | 4.234804 | 4.009726 | 6.335518 | 6.874673 | 6.012384 | 7.792208 |
| GSM1304876 | control_16 | hippocampus, control | 52 | F | W | 22.6 | 7.1 | 7.0 | hippocampus | control | ... | 6.233145 | 6.715545 | 4.307412 | 6.938351 | 4.361942 | 4.094315 | 6.42252 | 6.849899 | 5.829819 | 8.066799 |
| GSM1304877 | control_17 | hippocampus, control | 58 | F | W | 22.7 | 6.4 | 6.3 | hippocampus | control | ... | 6.029479 | 6.41822 | 4.516272 | 6.932229 | 4.317337 | 4.087772 | 6.366974 | 7.096214 | 6.021788 | 7.811885 |
| GSM1304878 | control_18 | hippocampus, control | 40 | F | B | 16.6 | 6.8 | 7.9 | hippocampus | control | ... | 5.925501 | 6.953431 | 4.801928 | 7.410219 | 4.367789 | 4.151118 | 6.216399 | 6.73773 | 5.838513 | 7.889143 |
| GSM1304879 | control_19 | hippocampus, control | 41 | F | W | 15.4 | 6.6 | 8.5 | hippocampus | control | ... | 5.946981 | 6.942262 | 5.14493 | 7.33129 | 4.142431 | 4.094786 | 6.351937 | 6.892228 | 5.911674 | 7.77813 |


### Task 2a: Volcano Plots

Volcano plots showcase the differentially expressed genes found during a high-throughput expression analysis. Log fold changes are plotted along the x-axis, while p-values are plotted along the y-axis (conventionally as $-\log_{10}$ of the p-value, so that more significant genes sit higher). Genes are marked significant if they exceed some absolute log fold change threshold **as well as** some p-value level for significance. This can be seen in the plot below.

![](https://galaxyproject.github.io/training-material/topics/transcriptomics/images/rna-seq-viz-with-volcanoplot/volcanoplot.png)
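
The coordinates and the significance flag of a volcano plot can be sketched in a few lines. The results table below is hypothetical (column names `fold_change` and `p_adj` and all values are made up), and the thresholds are example choices rather than recommended settings.

```python
import numpy as np
import pandas as pd

# Hypothetical results table: raw fold changes and adjusted p-values
res = pd.DataFrame(
    {"fold_change": [4.0, 0.25, 1.1, 2.5, 0.9],
     "p_adj": [1e-6, 1e-4, 0.5, 0.2, 1e-3]},
    index=["G1", "G2", "G3", "G4", "G5"],
)

fc_thresh, sig_thresh = 1.0, 0.05   # |log2 FC| > 1 and adjusted p < 0.05

res["log2_fc"] = np.log2(res["fold_change"])     # x-axis
res["neg_log10_p"] = -np.log10(res["p_adj"])     # conventional y-axis
# A gene is flagged only if it passes BOTH thresholds
res["significant"] = (res["log2_fc"].abs() > fc_thresh) & (res["p_adj"] < sig_thresh)
print(res)
```

Here G4 has a large fold change but a non-significant p-value, and G5 has a small p-value but almost no fold change, so neither is flagged.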

Your first task will be to generate some Volcano plots:

**Requirements**
1. Use the provided function to perform an ANOVA (analysis of variance) between control and experimental groups in each tissue.
    - Perform a separate analysis for each tissue.
2. Implement the skeleton function to create a volcano plot to visualize both the log fold change in expression values and the adjusted p-values from the ANOVA
3. Highlight significant genes with distinct colors

In [None]:
def volcano_plot(data, sig_col, fc_col, sig_thresh, fc_thresh):
    """
    Generate a volcano plot showcasing differentially expressed genes.
    
    Parameters
    ----------
        data : (pandas.DataFrame)
            A data frame containing differential expression results
        sig_col : str
            Column in `data` with adjusted p-values.
        fc_col : str
            Column in `data` with fold changes.
        sig_thresh : float
            Threshold for statistical significance.
        fc_thresh : float
            Threshold on the absolute log2 fold change.
    """
    # Work on a copy so the caller's data frame is not modified in place
    data = data.copy()
    data['significant'] = False
    data[fc_col] = np.log2(data[fc_col])
    de_genes = (data[sig_col] < sig_thresh) & (data[fc_col].abs() > fc_thresh)
    data.loc[de_genes, 'significant'] = True
    ax = sns.scatterplot(x=fc_col, y=sig_col, hue='significant', data=data,
                         palette={False: 'black', True: 'red'}, alpha=0.75)
    linewidth = plt.rcParams['lines.linewidth'] - 1
    plt.axvline(x=fc_thresh, linestyle='--', linewidth=linewidth,
                color='#4D4E4F')
    plt.axvline(x=-fc_thresh, linestyle='--', linewidth=linewidth,
                color='#4D4E4F')
    plt.axhline(y=sig_thresh, linestyle='--', linewidth=linewidth,
                color='#4D4E4F')
    ax.legend().set_visible(False)
    ylabel = sig_col
    if sig_col.lower() == 'fdr':
        ylabel = 'False Discovery Rate'
    plt.xlabel(r"$log_2$ Fold Change")
    plt.ylabel(ylabel)
    for spine in ['right', 'top']:
        ax.spines[spine].set_visible(False)
    plt.tight_layout()
    return ax

### Task 2b: Plot the Top 1000 Differentially Expressed Genes

Clustered heatmaps are hugely popular for displaying differences in gene expression values. For a reference plot, look back at the introductory material. Here we will plot the 1000 most differentially expressed genes for each of the analyses performed before.

**Requirements**
- Implement the skeleton function below
- Z normalize gene values
- Use a diverging and perceptually uniform colormap
- Generate plots for each of the DE results above

**Hint**: Look over all the options for [sns.clustermap()](https://seaborn.pydata.org/generated/seaborn.clustermap.html). It might make things easier.
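
The two preparation steps, picking the most differentially expressed genes and z-normalizing them, can be sketched with pandas alone. The expression matrix and results table below are synthetic, and ranking by the `p.value.adj` column is one reasonable choice rather than the only one.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(4)
genes = ["G{}".format(i) for i in range(6)]
# Synthetic (samples x genes) matrix and a made-up DE results table
expr = pd.DataFrame(rng.normal(6, 1, size=(12, 6)), columns=genes)
de = pd.DataFrame({"p.value.adj": [0.5, 0.001, 0.2, 0.0001, 0.9, 0.03]},
                  index=genes)

top_n = 3  # the workshop asks for 1000; 3 keeps the toy example small
top_genes = de["p.value.adj"].nsmallest(top_n).index

# Z-normalize each gene (column): subtract its mean, divide by its std
subset = expr.loc[:, top_genes]
z = (subset - subset.mean(axis=0)) / subset.std(axis=0)
print(z.round(2).head())
```

Passing `z_score=1` to `sns.clustermap()` performs a comparable column-wise standardization, so the manual step is mainly useful for checking what the heatmap colors represent.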

In [None]:
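def heatmap(data, genes, group_col):
    """
    Plot a clustered heatmap of z-normalized expression for selected genes.
    
    Parameters
    ----------
    data : pd.DataFrame
        A (sample x gene) data matrix containing gene expression values for each sample.
    genes : list, str
        List of genes to plot
    group_col : str
        Column in `data` with sample group labels. Currently unused; see the
        bonus task below.
    """

    # Keep only the requested gene columns
    plot_data = data.loc[:, genes]
    # z_score=1 standardizes each column (gene); RdBu_r is a diverging colormap
    ax = sns.clustermap(plot_data, cmap='RdBu_r', z_score=1)
    return ax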

**Bonus** There's nothing denoting which samples belong to which experimental group. Fix it.

*Bonus hint*: Look real close at the documentation.

## TASK 3: Clustering Analysis

You've seen clustering in the previous machine learning workshop. Some basic plots were generated for you, including plotting the clusters on the principal components. While we can certainly do more of that, we will also be introducing two new plots: elbow plots and silhouette plots.

### Elbow Plots

Elbow plots help answer the perennial question of K-means clustering: how do I choose K? To create the graph, you plot the number of clusters on the x-axis and some evaluation of "cluster goodness" on the y-axis. Looking at the name of the plot, you might guess that we're looking for an "elbow". This is the point in the graph where we start getting diminishing returns in performance, and specifying more clusters may lead to over-clustering the data. An example plot is shown below.

![](https://upload.wikimedia.org/wikipedia/commons/c/cd/DataClustering_ElbowCriterion.JPG)

You can see that the selected K (K = 3) sits right before diminishing returns start to kick in. Mathematically, the elbow is defined as the point at which curvature is maximized; the inflection point is also a decent, though more conservative, estimate. We'll just stick to eye-balling it for this workshop. If you would like to know how to automatically find the elbow point, more information can be found [here](https://raghavan.usc.edu/papers/kneedle-simplex11.pdf).
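
For a rough automatic estimate, one simple heuristic is to pick the K with the largest discrete second difference of the score curve, which approximates the point of maximum curvature. This is a simplified stand-in, not the Kneedle algorithm from the linked paper, and it assumes evenly spaced K values; the function name and scores below are made up for illustration.

```python
import numpy as np

def elbow_by_second_difference(ks, scores):
    """Return the k whose score has the largest absolute second difference.

    Assumes `ks` is evenly spaced and sorted, with at least three values.
    """
    ks = np.asarray(ks)
    scores = np.asarray(scores, dtype=float)
    # Discrete second derivative at every interior point of the curve
    second_diff = scores[:-2] - 2 * scores[1:-1] + scores[2:]
    return int(ks[1:-1][np.argmax(np.abs(second_diff))])

# Made-up scores (negative within-cluster distances) with a bend at k = 3
ks = np.arange(1, 8)
scores = np.array([-100.0, -60.0, -20.0, -15.0, -12.0, -10.0, -9.0])
best_k = elbow_by_second_difference(ks, scores)
print(best_k)
```

On noisy curves this heuristic can latch onto a local wiggle, so it is best used as a sanity check next to the plot rather than as a replacement for looking at it.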

### Task 3a: Implement a function that creates an elbow plot

Skeleton code is provided below. The function expects a list of k-values and their associated scores. An optional "ax" parameter is also provided. This parameter should be an axes object and can be created by issuing the following command:

```ax = plt.subplot()```

While we won't need the parameter right now, we'll likely use it in the future.

**Function Requirements**
- Generate plot data by clustering the entire dataset on the first 50 principal components. Vary K values from 2 to 10.
    - While you've been supplied a helper function for clustering, you'll need to supply the principal components yourself. Refer to your machine learning workshop along with the scikit-learn [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html)
- Plots each k and its associated value.
- Plots lines connecting each data point.
- Produces a plot with correctly labelled axes.

**Hint:** Working with an axis object is similar to base matplotlib, except `plt.scatter()` might become something like `ax.scatter()`.
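
The data preparation step can be sketched as follows, assuming a synthetic matrix in place of the real expression data: reduce to 50 principal components, then record the K-means score for each K from 2 to 10. The `n_init` and `random_state` settings are choices made here for reproducibility, not requirements of the task.

```python
import numpy as np
from sklearn import cluster, decomposition

rng = np.random.default_rng(5)
# Synthetic stand-in for the expression matrix: 90 samples x 200 "genes",
# built from three shifted groups so that some cluster structure exists.
X = np.vstack([rng.normal(loc, 1.0, size=(30, 200)) for loc in (0.0, 2.0, 4.0)])

# Project onto the first 50 principal components
pcs = decomposition.PCA(n_components=50).fit_transform(X)

ks = np.arange(2, 11)
scores = []
for k in ks:
    model = cluster.KMeans(n_clusters=k, n_init=10, random_state=0).fit(pcs)
    scores.append(model.score(pcs))   # negative within-cluster sum of squares
scores = np.array(scores)
print(dict(zip(ks.tolist(), scores.round(1).tolist())))
```

The resulting `ks` and `scores` arrays have the shape the elbow plot function expects as input.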

#### Helper Function

In [6]:
def cluster_data(X, k):
    """
    Cluster data using K-Means.
    
    Parameters
    ----------
        X : (numpy.ndarray)
            Data matrix to cluster samples on. Should be (samples x features).
        k : int
            Number of clusters to find.
    Returns
    -------
        tuple (numpy.ndarray, float)
            A tuple where the first value is the assigned cluster labels for
            each sample, and the second value is the score associated with
            the particular clustering.
    """
    model = cluster.KMeans(n_clusters=k).fit(X)
    score = model.score(X)
    return (model.labels_, score)

#### Task 3a Implementation

In [3]:
def elbow_plot(ks, scores, ax=None):
    """
    Create a scatter plot to aid in choosing the number of clusters using K-means.
    
    
    Arguments
    ---------
        ks : (numpy.ndarray)
            Tested values for the number of clusters.
        scores: (numpy.ndarray)
            Cluster scores associated with each number K.
        ax: plt.Axes Object, optional
    """
    if ax is None:
        fig, ax = plt.subplots()
    ax.scatter(ks, scores)
    ax.plot(ks, scores)
    ax.set_xlabel("Number of Clusters")
    ax.set_ylabel("Negative Distance From Mean")
    return ax

Once you've created the base plotting function, you'll probably realize we have no indication of where the elbow point is. Fix this by adding another optional parameter (`best`) to your function. The parameter `best` should be the K value that produces the elbow point.

**Function Requirements**

- Add an optional parameter `best` that if supplied denotes the elbow point with a vertical, dashed line.
- If `best` is not supplied, the plot should still be produced but without denoting the elbow point.

**Hint**: `plt.axvline` and `plt.axhline` can be used to produce vertical and horizontal lines, respectively. More information [here](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.axvline.html)

**Note**: You are not required to have the line end at the associated score value.

In [2]:
def elbow_plot(ks, scores, best=None, ax=None):
    """
    Create a scatter plot to aid in choosing the number of clusters using K-means.
    
    
    Arguments
    ---------
        ks : (numpy.ndarray)
            Tested values for the number of clusters.
        scores: (numpy.ndarray)
            Cluster scores associated with each number K.
        best: int, optional
            The best value for K. Determined by the K that falls at the elbow. If
            passed, a black dashed line will be plotted to indicate the best.
            Default is no line.
        ax: plt.Axes Object, optional
    """
    if ax is None:
        fig, ax = plt.subplots()
    ax.scatter(ks, scores)
    ax.plot(ks, scores)
    ax.set_xlabel("Number of Clusters")
    ax.set_ylabel("Negative Distance From Mean")
    if best is not None:
        if best not in ks:
            raise ValueError("{} not included in provided number "
                             "of clusters.".format(best))
        idx = np.where(np.array(ks) == best)[0][0]
        ymin, ymax = ax.get_ylim()
        # axvline takes ymax as a fraction of the axis height
        point = (scores[idx] - ymin) / (ymax - ymin)
        ax.axvline(x=best, ymax=point, linestyle="--", c='black')
        # Circle the elbow point itself
        ax.scatter([best], [scores[idx]], edgecolor='black', facecolor="none",
                    s=200)
        ax.set_title("Elbow at K={}".format(best), loc='left')
    return ax

### Silhouette Plots

Silhouette plots are another way to visually diagnose cluster performance. They are created by finding the [silhouette coefficient](https://en.wikipedia.org/wiki/Silhouette_(clustering)) for each sample in the data, and plotting an area graph for each cluster. The silhouette coefficient measures how well a sample fits its assigned cluster compared with the nearest other cluster. The value lies in $[-1, 1]$, where values near 1 indicate good separation, values near 0 indicate a sample on the boundary between two clusters, and negative values indicate a sample that is closer to a neighboring cluster than to its own. An example is posted below.

![](https://scikit-plot.readthedocs.io/en/stable/_images/plot_silhouette.png)

As you can see, each sample in each cluster has the area filled from some minimal point (usually 0 or the minimum score in the dataset) and clusters are separated to produce distinct [silhouettes](https://www.youtube.com/watch?v=-TcUvXzgwMY).
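
To see what the coefficient measures, the sketch below computes it by hand as $(b - a) / \max(a, b)$, where $a$ is the mean distance from a sample to the other members of its cluster and $b$ is the mean distance to the nearest other cluster, and compares the result with `sklearn.metrics.silhouette_samples`. The six points are made up, and the helper `silhouette_of` is a hypothetical name that assumes every cluster has at least two members.

```python
import numpy as np
from sklearn import metrics

# Six made-up points in two obvious groups
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
labels = np.array([0, 0, 0, 1, 1, 1])

def silhouette_of(i, X, labels):
    """Silhouette coefficient of sample i (Euclidean distance)."""
    dists = np.linalg.norm(X - X[i], axis=1)
    same = labels == labels[i]
    same[i] = False                      # exclude the sample itself
    a = dists[same].mean()               # mean distance within own cluster
    b = min(dists[labels == other].mean()
            for other in np.unique(labels) if other != labels[i])
    return (b - a) / max(a, b)

manual = np.array([silhouette_of(i, X, labels) for i in range(len(X))])
reference = metrics.silhouette_samples(X, labels)
print(np.round(manual, 3))
```

Both groups are tight and far apart, so every coefficient lands well above zero.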

### Task 3b: Implement a function to plot silhouette coefficients

Because the code for creating a silhouette plot can be a little bit involved, we've created a skeleton function with documentation and provided the following pseudo-code:

```
- Calculate scores for each sample.
- Get a set of unique sample labels.
- Set a score minimum
- Initialize variables y_lower, and y_step
    - y_lower is the lower bound on the y-axis for the first cluster's silhouette
    - y_step is the distance between cluster silhouettes
- Initialize variable, breaks
    - breaks are the middle point of each cluster silhouette and will be used to
      position the axis label
- Iterate through each cluster label, for each cluster:
    - Calculate the variable y_upper by adding the number of samples to y_lower
    - Fill the area between y_lower and y_upper using the silhouette scores for
      each sample
    - Calculate middle point of y distance. Append it to the variable breaks.
    - Calculate new y_lower value
- Label axes with appropriate names and tick marks
- Create dashed line at the average silhouette score over all samples
```

**Hint**: you might find [ax.fill_betweenx()](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.fill_betweenx.html)
and [ax.set_yticks()](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_yticks.html?highlight=set_yticks#matplotlib.axes.Axes.set_yticks)/
[ax.set_yticklabels()](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_yticklabels.html?highlight=set_yticklabels#matplotlib.axes.Axes.set_yticklabels) useful.

In [4]:
def silhouette_plot(X, y, ax=None):
    """
    Plot silhouette scores for all samples across clusters. 
    
    Parameters
    ----------
    X : numpy.ndarray
        Numerical data used to cluster the data.
    y : numpy.ndarray
        Cluster labels assigned to each sample.
    ax : matplotlib.Axes
        Axis object to plot scores onto. Default is None, and a new axis will
        be created.
    
    Returns
    -------
    matplotlib.Axes
    """
    if ax is None:
        ax = plt.subplot()
    scores = metrics.silhouette_samples(X, y)
    clusters = sorted(np.unique(y))
    score_min = 0
    y_lower, y_step = 5, 5
    props = plt.rcParams['axes.prop_cycle']
    colors = itertools.cycle(props.by_key()['color'])
    breaks = []
    for each, color in zip(clusters, colors):
        # Aggregate the silhouette scores for samples, sort scores for
        # area filling
        cluster_scores = scores[y == each]
        cluster_scores.sort()
        y_upper = y_lower + len(cluster_scores)
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                        score_min, cluster_scores,
                        facecolor=color, edgecolor=color, alpha=0.7,
                        label=each)
        breaks.append((y_upper + y_lower) / 2)
        # Compute the new y_lower for next plot
        y_lower = y_upper + y_step
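    # Use the axes' own legend so this also works when `ax` is a subplot
    ax.legend()
    ax.set_xlabel("Silhouette Coefficient")
    ax.set_ylabel("Cluster")

    # Label each silhouette with its cluster at the silhouette's midpoint
    ax.set_yticks(breaks)
    ax.set_yticklabels(clusters)
    # Solid line at zero, dashed line at the mean silhouette score
    ax.axvline(x=0, linestyle="-", linewidth=2, c='black')
    ax.axvline(x=np.mean(scores), linestyle='--', c='black')
    plt.tight_layout()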
    return ax

### Task 3c: Put it all together!

**Requirements**
- Create a function `cluster_and_plot` that will cluster a provided dataset for a range of k-values
- The function should return a single figure with two subplots:
    - An elbow plot with the "best" K value distinguished
    - A silhouette plot associated with clustering determined by the provided K value.
- Appropriate axes labels

**Hint**: You will likely find [plt.subplots()](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.subplots.html?highlight=subplots#matplotlib.pyplot.subplots) useful.
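
A minimal sketch of the figure layout, assuming placeholder content in each panel: `plt.subplots()` returns the figure and an array of axes, and each axes object can be handed to a plotting function through its `ax` parameter.

```python
from matplotlib import pyplot as plt

# One row, two columns; `axes` is a 1-D array with one Axes per panel
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

# Placeholder content standing in for the elbow and silhouette plots
axes[0].plot([1, 2, 3], [-9, -4, -3], marker="o")
axes[0].set_title("Left panel")
axes[1].barh([0, 1], [0.6, 0.4])
axes[1].set_title("Right panel")

fig.tight_layout()
```

Setting `figsize` explicitly is optional, but a wider figure tends to keep two side-by-side panels readable.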

In [5]:
def cluster_and_plot(X, best=3, kmax=10):
    """
    Cluster samples using KMeans and display the results.
    
    Results are displayed in a (1 x 2) figure, where the
    first subplot is an elbow plot and the second subplot
    is a silhouette plot.
    
    Parameters
    ----------
        X : (numpy.ndarray)
            A (sample x features) data matrix used to cluster
            samples.
        best : int, optional
            Final value of K to use for K-Means clustering.
            Default is 3.
        kmax : int, optional
            Maximum number of clusters to plot in the elbow
            plot. Default is 10.
    Returns
    -------
        matplotlib.Figure
            Clustering results.
    """
    fig, axes = plt.subplots(nrows=1, ncols=2)
    scores = np.array([cluster_data(X, k)[1] for k in np.arange(1, kmax+1)])
    y, score = cluster_data(X, best)
    elbow_plot(np.arange(1, kmax + 1), scores, best, axes[0])
    silhouette_plot(X, y, axes[1])
    return fig