# PAMM Clustering - Whole Dataset

Sample notebook showing how to use the PAMM clustering algorithm (original [paper](https://pubs.acs.org/doi/abs/10.1021/acs.jctc.7b00993)) with the GMPLabTools implementation.

The keyword **WHOLE** dataset refers to how the datasets are treated in the kernel density estimation (KDE): they are "summed" together into a single dataset.

In [None]:
import time
import warnings
import random
import seaborn as sns
import numpy as np

from scipy.cluster.hierarchy import dendrogram

from gmplabtools.analysis import DataSampler
from gmplabtools.pamm import PammGMM
from gmplabtools.pamm import Pamm
from gmplabtools.analysis import calculate_adjacency, adjancency_dendrogram
from gmplabtools.analysis import ClusterRates

import matplotlib.pyplot as plt
%matplotlib inline

## Utility Functions

In [None]:
def make_colors(clust,mode='tab20'):
    """Return a colour list that can be indexed directly by cluster label.

    Parameters
    ----------
    clust : int or array-like of int
        Either a number of clusters (scalar), in which case exactly that many
        colours are returned, or an array of cluster labels.  Labels are used
        as direct indices into the returned list (``np.array(colors)[labels]``),
        so for an array the list holds ``max(label) + 1`` colours.  If the
        labels contain -1 (noise), black is appended, so index -1 picks black.
    mode : str
        Name of the seaborn palette.

    Returns
    -------
    list-like of RGB tuples
    """
    if np.ndim(clust) == 0:
        # Scalar input: treat it as the number of clusters (this is how the
        # notebook calls it, e.g. make_colors(NUM_CLUST)).
        return sns.color_palette(mode, int(clust))

    clust = np.asarray(clust)
    # Labels are 0-based indices, so max+1 colours are needed.  The previous
    # code used max(), which left the highest label without a colour.
    n_colors = int(clust.max()) + 1
    colors = sns.color_palette(mode, n_colors)
    if clust.min() == -1:
        # Appended last so that label -1 indexes it: noise is drawn in black.
        colors = colors + [(0,0,0)]
    return colors


def get_axes(L, max_col=3, fig_frame=(5,4), res=100):
    """Create a grid of subplots able to hold ``L`` panels.

    Parameters
    ----------
    L : int
        Number of panels needed.
    max_col : int
        Maximum number of columns; extra panels wrap onto new rows.
    fig_frame : tuple (width, height)
        Size of one panel; the figure size is this times (cols, rows).
    res : int
        Figure dpi.

    Returns
    -------
    fig : matplotlib Figure
    ax : 1-D numpy array of Axes in row-major order, of length rows * cols
        (it can be longer than ``L`` when the last row is not full).
    """
    cols = min(L, max_col)
    full_rows, remainder = divmod(L, max_col)
    rows = full_rows + int(remainder != 0)
    # squeeze=False makes plt.subplots always return a 2-D array of Axes.
    # With the default squeeze=True, a 1x1 grid (L=1, as in this notebook)
    # returns a bare Axes object and .flatten() raises AttributeError.
    fig, ax = plt.subplots(rows, cols, figsize=(cols * fig_frame[0], rows * fig_frame[1]), dpi=res, squeeze=False)
    return fig, ax.flatten()


def shuffle(X, Y=None, n=None):
    """Return a random permutation (optionally a random subset) of X's rows.

    Parameters
    ----------
    X : numpy array
        Data permuted along its first axis (any number of dimensions).
    Y : numpy array, optional
        Array aligned with X along the first axis.  If given, it is permuted
        with the same order and returned too.
    n : int, optional
        Keep only the first ``n`` shuffled rows; ``None`` keeps all of them.

    Returns
    -------
    X_shuffled, or (X_shuffled, Y_shuffled) when Y is given.

    Notes
    -----
    The permutation is drawn with Python's ``random.shuffle``, so it is
    reproducible through ``random.seed`` (not ``np.random.seed``).
    """
    order = np.arange(X.shape[0])
    random.shuffle(order)
    # order[:None] is the full permutation, so one slice covers both the
    # "all rows" and "first n rows" cases.
    idx = order[:n]
    if Y is None:
        return X[idx]
    return X[idx], Y[idx]



## Dataset definition and loading

The data to be processed needs to be loaded and initialized as follows

`SYSX1 = np.loadtxt(my_dir/my_fileX1)`

and then put in a well named dictionary

`SYST = {
    'name_X1' : SYSX1,
    'name_X2' : SYSX2,
        ...   : ...  ,
}`

As stated before, in this workflow one needs to define a _wholesystemData_ and store it accordingly

`ALL = np.loadtxt(my_dir/my_wholedata)`

In [None]:
# Directory holding the PCA projection files read below.
PCA_DIR = "../pca_files/rcut65"

In [None]:
# Per-system PCA data, plus the whole-system data that is clustered below.
SYS1 = np.loadtxt(f"{PCA_DIR}/PCA_300-4np-005_CIT_rcut65_trj0-20000-50.pca")
ALL = np.loadtxt(f"{PCA_DIR}/wholesystem.pca")

In [None]:
# System name -> loaded data array.
SYST = {'300_005': SYS1}

In [None]:
# Number of columns of the whole-system data (presumably the PCA components kept).
DIM = np.shape(ALL)[1]
print(f"Data dimensions considered: {DIM}")

In [None]:
# Display the shape of the whole-system dataset: (rows, DIM).
ALL.shape

## General variables

In [None]:
# Run-wide plotting settings
CHUNK = 5000        # at most this many points are drawn per scatter plot
LABEL_SIZE = 18     # font size for titles, axis labels and tick labels
L = len(SYST)       # number of systems, i.e. number of subplot panels
SAVE_PLOT = True    # when True, the figure cells write PNG files

In [None]:
# Shuffle the whole dataset so that its first CHUNK rows are a random sample.
shffull = shuffle(ALL)
preview = shffull[:CHUNK]

fig, ax = plt.subplots()
ax.scatter(preview[:, 0], preview[:, 1])
ax.set_title('whole data visualization')
# These limits are reused by the per-system plots so all panels share one frame.
gx, gy = ax.get_xlim(), ax.get_ylim()

## Algorithm inputs

The parameters for the calculation need to be stored as follows.

The meaning of these parameters can be found in the original [paper](https://pubs.acs.org/doi/abs/10.1021/acs.jctc.7b00993).

`nm_frame` is the number of components that make up one frame of the trajectory (e.g. for a fiber of 40 monomers, `nm_frame : 40`).

In [None]:
# Parameters handed to Pamm(params) in the clustering cell; see the PAMM
# paper for their meaning.
default_inputs = {
    # cluster
    "distance": "minkowski",
    "size": 2250,
    "p": 2,
    "generate_grid": True,
    "savegrid": "grid_data",
    # cluster inputs
    "d": DIM,
    "fspread": 0.20,
    "ngrid": 2250,
    "qs": 1,
    "o": "pamm",
    "trajectory": f"{PCA_DIR}/wholesystem.pca",
    "merger": 0.005,
    "bootstrap": 73,
}

In [None]:
# Data for the clustering step: the whole-system set.  The empty dict adds no
# overrides to default_inputs.
datasets_cluster = [
    (ALL, {}),
]

# Data to label with the fitted model, as (array, metadata) pairs.
# 'sys' is the name the results are stored under.  Some later cells look
# results up by SYST's keys, so keep the two equal.
# 'nm_frame' is the number of components per trajectory frame.
# Fixed: the array used to be the undefined name SYS5 (NameError); system
# '300_005' is the array loaded as SYS1.
datasets_predict = [
    (SYS1, {'sys' : '300_005', 'nm_frame' : 1072})
]

## Original dataset plot

In [None]:
# One scatter panel per system, sharing the axis limits of the whole-data plot.
colors = sns.color_palette('tab10', L)
fig, ax = get_axes(L, max_col=L)
for i, name in enumerate(SYST):
    panel = ax[i]
    points = SYST[name][:CHUNK]
    panel.scatter(points[:, 0], points[:, 1], s=10, linewidth=1, marker="o", alpha=0.5)
    panel.set_xlim(gx)
    panel.set_ylim(gy)
    panel.set_title(f"{name}", weight='bold', size=LABEL_SIZE)
    panel.tick_params(labelsize=LABEL_SIZE, width=3, size=7)

    for side in ('bottom', 'right', 'top', 'left'):
        panel.spines[side].set_linewidth(3)

    # Every panel gets an x label and loses its right/top spines.  Only the
    # first panel gets a y label; the others pass labelleft=None (intended to
    # drop the repeated y tick labels).
    panel.set_xlabel('PCA 1', weight='bold', size=LABEL_SIZE)
    if i == 0:
        panel.set_ylabel('PCA 2', weight='bold', size=LABEL_SIZE)
    else:
        panel.tick_params(labelleft=None)
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)

fig.tight_layout()

if SAVE_PLOT:
    fig.savefig("data_set_soap_pca.png")

## PAMM - Clustering part

In [None]:
for i_dataset, (dataset, algo_params) in enumerate(datasets_cluster):
    # Shared defaults, overridden by the dataset-specific values.
    # NOTE(review): `dataset` is not passed to Pamm; the input appears to be
    # read from params["trajectory"] — confirm.
    params = {**default_inputs, **algo_params}

    # Clustering
    p = Pamm(params)
    print('\n#-----------------------------------------------')
    print(p.command_parser)

    print('\nRUNNING Clustering')
    t0 = time.time()
    p.run()
    t1 = time.time()
    elapsed = np.round(t1 - t0, 2)
    print('TIME= ' + str(elapsed) + ' s \n')

## PAMM - Prediction on data part

In [None]:
# Load the PAMM output files written by the clustering run.
gmm = PammGMM.read_clusters(
    'pamm.pamm',
    grid_file='pamm.grid',
    bootstrap_file='pamm.bs',
)
NUM_CLUST = len(np.unique(gmm.pk))
print(f"There are {NUM_CLUST} clusters")

In [None]:
# For each prediction dataset: assign every point to its most probable
# cluster, store the results per system name and write labels/rates to disk.
cluster_output = {}
grid_cluster = {}
prob_output = {}
bootstr_output = {}
systnames = []
for x, meta in datasets_predict:
    run_syst = str(meta['sys'])
    print('\nRUNNING Predict ' + run_syst)
    t0 = time.time()

    x_ = gmm.predict_proba(x)
    labels = np.argmax(x_, axis=1)

    t1 = time.time()
    print('TIME= ' + str(np.round(t1 - t0, 2)) + ' s \n')

    # The grid clusters, probabilities and bootstrap data are read from the
    # shared gmm object, not from this dataset's prediction.
    cluster_output[run_syst] = labels
    grid_cluster[run_syst] = gmm.cluster
    prob_output[run_syst] = gmm.p
    bootstr_output[run_syst] = gmm.bs
    systnames.append(run_syst)

    # Labels as a single column, both saved and fed to the rate calculation.
    column = labels.reshape((-1, 1))
    np.savetxt(run_syst + "_clusters.dat", column)

    rates = ClusterRates(meta['nm_frame'], 'label').calculate_matrix(column)
    np.savetxt(run_syst + "_rates.dat", rates)

## Output post-processing

In [None]:
def PLTmatrixrates(ax,data,s=18,labels=None):
    """Draw ``data`` as an annotated heatmap on ``ax``, with x ticks on top.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data : 2-D array-like
        Matrix to draw (e.g. a cluster rate matrix).  Every cell is annotated
        with its value formatted as ``.2f``.
    s : int, default 18
        Font size of the cell annotations, and of the tick labels when
        ``labels`` is given.
    labels : sequence of str, optional
        Bold tick labels applied to both axes.  ``None`` (the default) keeps
        seaborn's own tick labels, exactly as before this parameter existed.

    Returns
    -------
    The same ``ax``.
    """
    sns.heatmap(data, annot=True, fmt=".2f", cbar=False, ax=ax, annot_kws={"fontsize": s})
    ax.xaxis.tick_top()
    if labels is not None:
        ax.set_xticklabels(labels, size=s, weight='bold')
        ax.set_yticklabels(labels, size=s, weight='bold')
    return ax

In [None]:
# Scatter plot of every system, coloured by predicted micro cluster.
colors = make_colors(NUM_CLUST, mode='tab20')
palette = np.array(colors)
fig, ax = get_axes(L, max_col=L)
for i, name in enumerate(systnames):
    labels = cluster_output[name]
    points = datasets_predict[i][0][:CHUNK]
    panel = ax[i]
    panel.scatter(points[:, 0], points[:, 1], c=palette[labels[:CHUNK]], s=10)
    panel.set_xlim(gx)
    panel.set_ylim(gy)
    panel.set_title(f"{name}", weight='bold', size=LABEL_SIZE)
    panel.tick_params(labelsize=LABEL_SIZE, width=3, size=7)

    for side in ('bottom', 'right', 'top', 'left'):
        panel.spines[side].set_linewidth(3)

    # Shared x label and open frame; only the first panel gets a y label.
    panel.set_xlabel('PCA 1', weight='bold', size=LABEL_SIZE)
    if i == 0:
        panel.set_ylabel('PCA 2', weight='bold', size=LABEL_SIZE)
    else:
        panel.tick_params(labelleft=None)
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)

fig.tight_layout()

if SAVE_PLOT:
    fig.savefig('clusters_pamm.png')

### Cluster interconversion matrices

These $N_{clusters} \times N_{clusters}$ matrices represent all the transformations that clusters undergo during the analyzed trajectories.
Each row $n_i$ gives the probability, measured as a frequency, that cluster $n_i$ turns into the cluster $n_j$ of each column.

In [None]:
# Micro-cluster rate matrix written by the prediction cell, keyed by system.
SYS1rates = np.loadtxt("./300_005_rates.dat")
RATES = {'300_005': SYS1rates}

In [None]:
# One annotated heatmap per micro-cluster rate matrix.
labels_ = ['0', '1', '2', '3', '4', '5']  # NOTE(review): not used below
fig, ax = get_axes(L, max_col=3, fig_frame=(5, 4), res=100)
for i, matrix in enumerate(RATES.values()):
    PLTmatrixrates(ax[i], matrix, s=14)

fig.tight_layout()

if SAVE_PLOT:
    fig.savefig('micro_clusters_pamm_matrix.png')

### Clusters hierarchy

The dendrograms should all be identical, since the predictions on the single sets all come from the model fitted on the merged (whole) dataset.

In [None]:
# Display the system names under which prediction outputs were stored.
prob_output.keys()

In [None]:
# One dendrogram per predicted system, built from the probabilities, grid
# clusters and bootstrap data stored by the prediction cell.
# Iterate over systnames, which are exactly the keys of prob_output,
# grid_cluster and bootstr_output, rather than SYST.  A name mismatch between
# the two therefore cannot raise a KeyError here.
n_sys = len(systnames)
fig, ax = get_axes(n_sys, max_col=n_sys)
for i, den in enumerate(systnames):
    # Use a dedicated name: the old code bound the global `mapping`, which
    # re-running this cell would overwrite with the user's macro-cluster mapping.
    adjacency, adj_mapping = calculate_adjacency(
        prob=prob_output[den],
        clusters=grid_cluster[den],
        bootstrap=bootstr_output[den],
    )

    z = adjancency_dendrogram(adjacency)
    dendrogram(z, ax=ax[i], count_sort=True)

    ax[i].set_title(den)
    ax[i].set_yticks([])
    ax[i].yaxis.set_ticks_position('none')

    for side in ['bottom', 'right', 'top', 'left']:
        ax[i].spines[side].set_visible(False)

fig.tight_layout()

if SAVE_PLOT:
    fig.savefig('clusters_pamm_dendrogram.png')

### Cluster merging (macrocluster processing)

Macrocluster syntax definition:

`mapping = [
    ('SYSX1', {MacroCl1: [microClx,...], 
               MacroCl2: [microCly,...]})
]`

where the merging is read off the dendrogram.

In [None]:
# Per system: macro-cluster label -> micro-cluster labels merged into it
# (chosen by reading the dendrogram).
mapping = [
    ('300_005', {
        0: [1, 4],
        1: [3, 2, 5],
        2: [0, 6],
    }),
]

In [None]:
# Merge micro clusters into macro clusters.  A macro cluster's probability is
# the sum of the probabilities of its micro clusters; each point gets the most
# probable macro cluster.  Per the original note, ClusterRates gives the same
# result with or without the reshape((-1, 1)) of the labels.
macro_cluster_output = {}
rates_macro_clusters = {}
# Look the mapping up by system name instead of by list position, so the
# order of `mapping` does not have to match `systnames`.  A missing system
# now raises KeyError instead of silently using another system's mapping.
macro_maps = dict(mapping)

for s, macro_cl in enumerate(systnames):
    # Macro Cluster
    run_syst = macro_cl
    print("MACRO CLUSTERS - "+run_syst)

    y = datasets_predict[s][0]
    y_ = gmm.predict_proba(y)
    groups = macro_maps[run_syst]
    # Column k of y__ assumes the macro labels are 0..len(groups)-1.
    y__ = np.zeros((y.shape[0], len(groups)))
    for k, v in groups.items():
        y__[:, k] = y_[:, v].sum(1)

    macro_labels = np.argmax(y__, axis=1)
    macro_cluster_output[macro_cl] = macro_labels
    np.savetxt(run_syst+'_macro_cluster.dat', macro_labels.reshape((-1,1)))

    rates = ClusterRates(datasets_predict[s][1]['nm_frame'], 'label').calculate_matrix(macro_labels.reshape((-1,1)))
    rates_macro_clusters[macro_cl] = rates
    np.savetxt(run_syst+'_macro_rates.dat', rates)

In [None]:
# Display the macro-cluster rate matrices, keyed by system name.
rates_macro_clusters

In [None]:
# One colour per macro cluster: macro label k is drawn with Mcolors[k].
Mcolors = ["#ff0000","#0b0bff","#00ffff"]

colors = Mcolors
palette = np.array(colors)
n_sys = len(systnames)
fig, ax = get_axes(n_sys, max_col=n_sys)
# systnames is the order in which datasets_predict was processed and the key
# set of macro_cluster_output.  Iterating it keeps ax[i], datasets_predict[i]
# and the system name referring to the same system.
for i, name in enumerate(systnames):
    labels = macro_cluster_output[name]
    # Percentage of points assigned to each macro cluster.
    print(np.bincount(labels)/len(labels)*100)
    points = datasets_predict[i][0][:CHUNK]
    ax[i].scatter(points[:, 0], points[:, 1], c=palette[labels[:CHUNK]], s=10)
    # The title size was hard-coded to 3 (illegible); match the other figures.
    ax[i].set_title(f"{name}", weight='bold', size=LABEL_SIZE)
    ax[i].tick_params(labelsize=LABEL_SIZE, width=3, size=7)

    ax[i].set_xlim(gx)
    ax[i].set_ylim(gy)

    for side in ['bottom','right','top','left']:
        ax[i].spines[side].set_linewidth(3)

    # Shared x label and open frame; only the first panel gets a y label.
    ax[i].set_xlabel('PCA 1', weight='bold', size=LABEL_SIZE)
    if i == 0:
        ax[i].set_ylabel('PCA 2', weight='bold', size=LABEL_SIZE)
    else:
        ax[i].tick_params(labelleft=None)
    for side in ['right','top']:
        ax[i].spines[side].set_visible(False)

# NOTE(review): tight_layout is disabled here, unlike the other figure cells — confirm intent.
#fig.tight_layout()

if SAVE_PLOT:
    fig.savefig('macro_clusters_pamm.png')

In [None]:
# Macro-cluster rate matrix written by the macro-cluster cell, keyed by system.
SYS1rates = np.loadtxt("./300_005_macro_rates.dat")
RATES = {'300_005': SYS1rates}

In [None]:
def plot_matrix_nolabels(dataMatrix, axes, palette_name="viridis", s=16,lw=1):
    """Draw ``dataMatrix`` as an annotated heatmap with no tick labels.

    Parameters
    ----------
    dataMatrix : 2-D array-like
        Matrix to draw; the colour scale is fixed to [0, 1] and every cell is
        annotated in bold with its value formatted as ``.2f``.
    axes : matplotlib Axes
        Target axes.
    palette_name : str or colormap, default "viridis"
        Colormap for the cells.  It was previously accepted but ignored.
    s : int
        Font size of the cell annotations.
    lw : float
        Width of the white lines separating the cells.

    Returns
    -------
    The same ``axes``.
    """
    sns.heatmap(
        dataMatrix,
        annot=True,
        fmt=".2f",
        vmin=0.0,
        vmax=1.0,
        cmap=palette_name,
        cbar=False,
        ax=axes,
        annot_kws={"fontsize": s, "fontweight": 'bold'},
        linewidths=lw,
        linecolor='w',
    )
    axes.xaxis.tick_top()
    axes.tick_params(labeltop=False, labelleft=False)

    return axes


In [None]:
# Macro-cluster rate matrices, drawn without tick labels.
labels_ = ['0', '1', '2', '3']  # NOTE(review): not used below
fig, ax = get_axes(L, max_col=3, fig_frame=(5, 4), res=100)
for i, matrix in enumerate(RATES.values()):
    plot_matrix_nolabels(matrix, ax[i], s=25)

fig.tight_layout()

if SAVE_PLOT:
    fig.savefig('macro_clusters_pamm_matrix.png')

In [None]:
# Helper functions for the cluster-population bar plots

def histo_axe(L, fig_frame=(3,3), res=100, s=10):
    """Create a single row of ``L`` panels styled for bar plots.

    Bottom/left spines are thickened and right/top spines hidden.
    ``s`` is accepted but not used.

    Returns
    -------
    fig, axes : the figure and the flat array of Axes from get_axes.
    """
    fig, axes = get_axes(L, L, fig_frame=fig_frame, res=res)
    for panel in axes[:L]:
        for edge in ('bottom', 'left'):
            panel.spines[edge].set_linewidth(3)
        for edge in ('right', 'top'):
            panel.spines[edge].set_visible(False)
    return fig, axes

def plot_bars(percentages, width, colors, axes, label_size=16, annotate=True):
    """Draw one bar per entry of ``percentages`` on ``axes``.

    Bars sit at positions 0..N-1 shifted left by ``width``, get labels
    '1'..'N' and the given colours.  The y range is fixed to 0-110 and
    x ticks are removed.  With ``annotate`` each bar is labelled with its
    height via autolabel.

    Returns
    -------
    The same ``axes``.
    """
    n_bars = len(percentages)
    positions = np.arange(n_bars) - width
    bar_labels = [f"{k}" for k in range(1, n_bars + 1)]
    bars = axes.bar(positions, percentages, label=bar_labels, color=colors)
    axes.set_ylim(0, 110)
    axes.set_xticks([])
    axes.set_xticklabels([])
    if annotate:
        autolabel(bars, label_size=label_size, axes=axes)
    axes.tick_params(labelsize=label_size, width=3, size=7)
    return axes

def autolabel(rects, label_size, axes):
    """Attach a text label above each bar in *rects*, displaying its height.

    Bars with a height strictly between 0 and 0.5 are labelled '<0.5%'
    instead of their value.  Every label is centred on its bar and offset
    3 points above the top.

    Parameters
    ----------
    rects : iterable of bar patches (e.g. the container returned by Axes.bar)
    label_size : font size of the labels
    axes : the Axes the labels are drawn on
    """
    for rect in rects:
        height = rect.get_height()
        # The text claims "<0.5%", so use it only for heights that really are
        # below 0.5.  The old test (height < 1 and height != 0) also caught
        # 0.5..1 and negative heights.
        text = '<0.5%' if 0 < height < 0.5 else f'{height}'
        axes.annotate(text,
                      xy=(rect.get_x() + rect.get_width() / 2, height),
                      xytext=(0, 3),  # 3 points vertical offset
                      textcoords="offset points",
                      ha='center', va='bottom', fontsize=label_size)

# Per-cluster population, in percent
def get_clustersFraction(labels):
    """Return the share of samples in each cluster label, in whole percent.

    ``labels`` must be non-negative integers (an np.bincount requirement).
    Entry k of the result is the percentage of samples with label k, rounded
    to the nearest integer (returned as a float array).
    """
    counts = np.bincount(labels)
    total = np.sum(counts)
    return np.round(counts * 100. / total)

In [None]:
# Labels could also be reloaded from disk, e.g.
# labels = np.loadtxt('XXX_cluster.dat').astype(int)
L = len(macro_cluster_output)

# NOTE(review): the panel grid is hard-coded to 16 even though only L panels
# are filled — confirm whether histo_axe(L, ...) was intended.
fig, ax = histo_axe(16, fig_frame=(2.5, 2))
for i, clust_tmp in enumerate(macro_cluster_output.values()):
    plot_bars(
        get_clustersFraction(clust_tmp),
        width=.15,
        colors=Mcolors,  # the macro-cluster colours, same as in the scatter plots
        axes=ax[i],
        annotate=True,
    )
fig.tight_layout()

if SAVE_PLOT:
    fig.savefig('macro_clusters_pamm_population.png')