In [None]:
import numpy as np
import pandas as pd
import sys
from sklearn.cluster import SpectralClustering
from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
import scipy.spatial.distance as ssd
import matplotlib
matplotlib.use('AGG')
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(style='ticks', font_scale=1)

In [None]:
def getLinkage(model):
    """Build a SciPy-style linkage matrix from a fitted hierarchical model.

    Parameters
    ----------
    model : object
        Any fitted estimator exposing ``children_``, an ``(n_merges, 2)``
        integer array in which ids below ``n_samples`` are original samples
        and id ``n_samples + i`` refers to the cluster created by row ``i``
        (presumably a scikit-learn ``AgglomerativeClustering`` -- verify
        against the caller).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_merges, 4)`` with columns
        ``[child_a, child_b, height, n_observations]``.

    Notes
    -----
    The real merge distances are not read from the model, so the height
    column is simply ``0, 1, 2, ...`` in merge order; the resulting
    dendrogram shows the merge topology, not true distances.
    """
    children = np.asarray(model.children_)
    n_merges = children.shape[0]
    # A complete binary merge tree over n leaves has n - 1 merges.
    n_samples = n_merges + 1

    # Placeholder heights: monotonically increasing so the tree can be drawn.
    distance = np.arange(n_merges)

    # Actual number of original observations under each merged node.
    # (The previous arange(2, n_merges + 2) was only correct when every merge
    # added exactly one new sample to a single growing cluster.)
    no_of_observations = np.zeros(n_merges)
    for merge_idx, pair in enumerate(children):
        size = 0
        for child in pair:
            if child < n_samples:
                size += 1  # leaf: a single original sample
            else:
                size += no_of_observations[child - n_samples]
        no_of_observations[merge_idx] = size

    linkage_matrix = np.column_stack([children, distance, no_of_observations]).astype(float)

    return linkage_matrix


def plotDendrogram(model, **kwargs):
    """Draw a dendrogram for a fitted hierarchical clustering model.

    The model is first converted to a linkage matrix with ``getLinkage``;
    every keyword argument is forwarded unchanged to
    ``scipy.cluster.hierarchy.dendrogram``, whose return value is returned.
    """
    return dendrogram(getLinkage(model), **kwargs)


def preprocess(data, features, normalize, log, direction):
    """Transform a samples-by-timepoints table before clustering.

    The steps are applied in this order, each only if its flag is set:

    1. ``features == 'rate'``: replace values by the difference between
       consecutive columns and drop every column whose name contains
       ``'Uninfected'``.
    2. ``direction``: keep only the sign (positive -> 1, negative -> -1).
    3. ``normalize``: min-max scale each row to ``[0, 1]``.
    4. ``log``: signed log10, i.e. ``sign(x) * log10(|x|)`` with 0 mapped to 0.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric table; it is never modified in place.
    features : str
        ``'rate'`` selects column-to-column differences; any other value
        keeps the values as given.
    normalize, log, direction : bool
        Switches for the steps described above.

    Returns
    -------
    pandas.DataFrame
        A new frame holding the transformed values.

    Notes
    -----
    * A row whose values are all equal yields NaN under ``normalize``
      (division by a zero range).
    * Under ``log``, magnitudes strictly between 0 and 1 have a negative
      log10, so their sign comes out opposite to the input's.
    """
    # Select using actual values or rate of change
    if features == 'rate':
        uninfected_cols = [col for col in data.columns if 'Uninfected' in col]
        data = data.diff(axis=1).drop(columns=uninfected_cols)
    else:
        # The steps below assign in place; work on a private copy so the
        # caller's frame is left untouched.
        data = data.copy()

    if direction:
        data[data>0] = 1
        data[data<0] = -1

    # Normalize by row
    if normalize:
        data = data.apply(lambda x: (x-x.min())/(x.max()-x.min()), axis=1)

    # Take log value
    if log:
        # Remember which entries were negative before taking magnitudes
        mask = (data < 0)
        # log10(1) == 0, so exact zeros stay zero instead of becoming -inf
        data[data==0] = 1
        data = np.log10(np.abs(data))
        # Restore the sign of originally negative entries
        data[mask] *= -1

    return data


In [None]:
    # --- Analysis settings read by the cells below ---
    # how:       how replicate trials are combined ('first', 'mean' or 'median')
    # features:  'rate' clusters on changes between time points instead of raw values
    # normalize / log / direction: preprocessing switches passed to preprocess()
    # metric / method: distance metric and linkage method for hierarchical clustering
    # plot_log:  apply a signed log to the values only when plotting
    how, features = 'median', 'rate'
    normalize, log, direction = False, True, False
    metric, method = 'correlation', 'complete'
    plot_log = False

    # Load the metabolomics table, discard the metadata rows and make
    # everything numeric in one pass
    raw_data = (
        pd.read_csv('./data/metabolomics-data.csv', index_col=0)
        .drop(index=['Replicate', 'Group', 'Mass (g)', 'Sample #'])
        .astype(float)
    )

    # Keep only the first occurrence of every row label
    # (TODO: find out why the file contains duplicated labels at all)
    is_repeat = raw_data.index.duplicated(keep='first')
    raw_data = raw_data.loc[~is_repeat]

    # Time points to analyse, in chronological order
    cols = [
        'Uninfected', '1 hour post infection', '12 hours alive',
        '24 hours alive', '24 hours dead', '2 days', '4 days',
        '6 days', '8 days', '10 days', '12 days', '16 days',
    ]

In [None]:
    # ---- Plot the spread across trials for every time point ----
    stds = []
    for col in cols:
        # Replicate columns are named 'name', 'name.1', 'name.2', ... (the same
        # scheme the per-trial loop below relies on). Match the exact name or
        # that suffixed form: a plain substring test would also pull
        # '12 days*' into '2 days' and '16 days*' into '6 days'.
        replicate_cols = [c for c in raw_data.columns
                          if c == col or c.startswith(col + '.')]
        replicates = raw_data[replicate_cols]
        # Relative spread: std scaled by (mean + 1); the +1 presumably guards
        # against division by zero for all-zero rows.
        stds.append((replicates.std(axis=1)/(replicates.mean(axis=1)+1)).values)

    plt.boxplot(stds, whis=[5,95], showfliers=False)
    plt.xticks(np.arange(len(cols))+1, cols, rotation=45, ha='right')
    plt.savefig('./stds.png', bbox_inches='tight')
    plt.clf()


    # ---- Cluster each trial separately and compare with the adjusted Rand index ----
    n_trials = 4
    hier_labels = []    # hierarchical clustering labels, one array per trial
    kmeans_labels = []  # k-means labels, one array per trial
    for i in range(n_trials):
        # Trial 0 uses the bare column names, trial i > 0 the '.i' suffixed ones
        keep_cols = cols if i == 0 else ['{}.{}'.format(col, i) for col in cols]
        trial = preprocess(raw_data[keep_cols], features, normalize, log, direction)
        hier_labels.append(fcluster(linkage(trial.values, method=method, metric=metric),
                                    criterion='maxclust', t=20))
        # Fixed seed so the trial-to-trial comparison is reproducible between runs
        kmeans_labels.append(KMeans(n_clusters=10, random_state=0).fit(trial.values).labels_)

    # Pairwise agreement between trials; float dtype so means/heatmaps work directly
    ari_hier = pd.DataFrame(index=range(n_trials), columns=range(n_trials), dtype=float)
    ari_kmeans = pd.DataFrame(index=range(n_trials), columns=range(n_trials), dtype=float)
    for i in range(n_trials):
        for j in range(n_trials):
            ari_hier.loc[i,j] = adjusted_rand_score(hier_labels[i], hier_labels[j])
            ari_kmeans.loc[i,j] = adjusted_rand_score(kmeans_labels[i], kmeans_labels[j])

    # Average agreement over distinct trial pairs only (the diagonal is always 1)
    off_diagonal = ~np.eye(n_trials, dtype=bool)
    print(ari_hier, ari_hier.where(off_diagonal).mean().mean())
    print(ari_kmeans, ari_kmeans.where(off_diagonal).mean().mean())
    sns.heatmap(ari_hier, vmin=0, vmax=1)
    plt.savefig('./compare_trials.png', bbox_inches='tight')
    plt.clf()

In [None]:
    # Select how to combine the different trials
    if how == 'first':
        data = raw_data[cols]
    elif how in ('mean', 'median'):
        combined = []
        for col in cols:
            # Replicates of a time point are named 'name', 'name.1', 'name.2', ...
            # Match exactly that: a substring test would also fold '12 days*'
            # into '2 days' and '16 days*' into '6 days'.
            replicate_cols = [c for c in raw_data.columns
                              if c == col or c.startswith(col + '.')]
            replicates = raw_data[replicate_cols]
            combined.append(replicates.mean(axis=1) if how == 'mean'
                            else replicates.median(axis=1))
        data = pd.concat(combined, axis=1, keys=cols)
    else:
        # Fail loudly instead of silently reusing a stale `data` from an earlier cell
        raise ValueError("how must be 'first', 'mean' or 'median', got {!r}".format(how))

    # Preprocess data
    data = preprocess(data, features, normalize, log, direction)

    # Cluster data points (hierarchical)
    if metric == 'correlation':
        # Pairwise distance between rows: 1 - |Pearson correlation|
        dist = (1 - data.T.corr().abs()).values.copy()
        # Remove floating-point noise so the matrix is a valid distance matrix
        dist = np.clip(dist, 0, None)
        np.fill_diagonal(dist, 0)
        # linkage() treats a 2-D input as raw observation vectors, so the square
        # matrix must be condensed first for it to be used as distances.
        link = linkage(ssd.squareform(dist, checks=False), method=method)
    else:
        link = linkage(data, method=method, metric=metric)
    clusters = fcluster(link, criterion='maxclust', t=9)

    # Plot a dendrogram of the hierarchy
    dendro_fig = plt.figure(figsize=(20,20))
    R = dendrogram(link, orientation='left', labels=list(data.index), leaf_font_size=8, count_sort='descending')
    plt.tight_layout()
    plt.savefig('./dendrogram.png', bbox_inches='tight')
    plt.close(dendro_fig)  # release the figure; the heatmap below gets its own

    # Plot a heatmap with clusters
    print(len(clusters), len(data))
    data['cluster'] = clusters  # later cells expect this column on `data`
    # Dendrogram leaves are listed bottom-to-top, the heatmap is drawn top-to-bottom
    label_order = list(reversed(R['ivl']))
    hm_data = data.reindex(index=label_order)
    hm_values = hm_data.drop(columns='cluster')
    # Take log value (same signed log as preprocess, applied for display only)
    if plot_log:
        hm_values = preprocess(hm_values, features=None, normalize=False, log=True, direction=False)
    fig, ax = plt.subplots(figsize=(20,20))
    ax = sns.heatmap(hm_values, ax=ax, cmap='coolwarm', yticklabels=True)
    # Row positions where the cluster label changes -> draw a separator line
    sep = np.where(hm_data['cluster'].diff()!=0)[0]
    ax.hlines(sep, *ax.get_xlim(), linewidth=3)
    plt.tight_layout()
    plt.savefig('./heatmap.png', bbox_inches='tight')
    plt.clf()

In [None]:
    # Feature columns only (without the 'cluster' label column); built once
    # instead of once per plotted row.
    feature_data = data.drop(columns='cluster')
    n_timepoints = feature_data.shape[1]

    # Plot each cluster feature vectors: one figure per cluster, one line per row
    for val in np.unique(clusters):
        cluster_idx = np.where(clusters == val)[0]
        for idx in cluster_idx:
            row = feature_data.iloc[idx]
            plt.plot(range(len(row)), row, label=feature_data.index[idx])

        plt.legend()
        if plot_log:
            plt.yscale('symlog')
        plt.ylabel(features)
        plt.xlabel('Time')
        plt.xticks(range(n_timepoints), feature_data.columns)
        plt.tight_layout()
        plt.savefig('./{}.png'.format(val), bbox_inches='tight')
        plt.clf()

    # Plot a cluster map
    # Placeholder column linkage; columns are not clustered (col_cluster=False)
    # and the column dendrogram is hidden below.
    dummy_link = linkage(np.array([(x**3,0) for x in range(n_timepoints, 0, -1)]))
    plot_data = feature_data
    # Take log value (same signed log as preprocess, applied for display only).
    # A copy is passed so feature_data stays intact for the correlation map below.
    if plot_log:
        plot_data = preprocess(feature_data.copy(), features=None, normalize=False, log=True, direction=False)
    cm = sns.clustermap(plot_data, row_linkage=link, col_linkage=dummy_link, figsize=(25,25), col_cluster=False, cmap='coolwarm', yticklabels=True)
    cm.ax_col_dendrogram.set_visible(False)
    plt.savefig('./clustermap.png')
    plt.clf()

    # Plot a cluster distance map: row-by-row correlation, ordered by the row linkage
    cm = sns.clustermap(feature_data.T.corr(), row_linkage=link, col_linkage=link, yticklabels=True, xticklabels=True, figsize=(25,25))
    plt.savefig('./clustermap-dist.png')
    plt.clf()