## Community State Type Analysis - Genera Level

This notebook contains the steps to perform hierarchical clustering and derive community state types.

In [None]:
import pandas as pds
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from scipy.spatial import distance
from sklearn.metrics import silhouette_samples
import matplotlib.ticker as ticker
from CST_Clustering import cluster16SMatrix, validateClusters
import os

# Make sure the folder that receives the figures saved by this notebook exists.
os.makedirs('../Analysis Figures/CSTAnalysis/', exist_ok=True)

The hierarchical clustering will be performed on the data matrix with counts aggregated per genus (sum of all OTUs for a given genus).

In [None]:
# Clinical metadata: only the sample identifier and its classification are needed here.
clinicalData = pds.read_csv('../../Data/ClinicalData.csv')[['Sample Name', 'Classification']]

# Read the 16S genus-level count table and attach it to the clinical metadata by sample name.
counts_genera = clinicalData.merge(pds.read_csv('../../Data/Genera_Counts_Table.csv'), on='Sample Name')

# Split the merged table into the leading non-count columns and the count columns.
# NOTE(review): assumes the first four columns of the merged table are metadata - confirm
# against the layout of Genera_Counts_Table.csv.
metadata_block = counts_genera.iloc[:, :4]
abundance_block = counts_genera.iloc[:, 4:]

# Sort the count columns by decreasing total count per genus - to facilitate plots downstream
countsOrder = np.argsort(abundance_block.sum())[::-1]
counts_genera = pds.concat([metadata_block, abundance_block.iloc[:, countsOrder]], axis=1)

In [None]:
# Load the taxonomy table; its 'FullName' column is used below to label the count columns.
taxonomyDataset = pds.read_csv('../../Data/Genera_Taxonomy.csv')

In [None]:
# Keep only the count columns, which begin at column position 4 of counts_genera.
counts_matrix = counts_genera.iloc[:, 4:]

# Index the rows by sample name, under the index name 'SampleID'.
counts_matrix.index = pds.Index(counts_genera['Sample Name'].values, name='SampleID')

# Label the columns with the full taxonomic names, reordered with the same permutation
# that was applied to the count columns.
# NOTE(review): assumes the rows of taxonomyDataset follow the original (unsorted) column
# order of the count table - confirm.
counts_matrix.columns = taxonomyDataset['FullName'].values[countsOrder]

# Selecting the number of clusters - CST's

In the next cell, the "optimal" number of clusters is estimated using the silhouette score. 
The validateClusters function is defined separately in the *CST_Clustering.py* file.

In [None]:
# Score a range of candidate cluster numbers with the silhouette criterion.
cluster_validation = validateClusters(counts_matrix, distanceMetric='jensenshannon',
                                      clusterMethod='ward', method='silhouette')

# Element 0 is plotted as the number of clusters, element 1 as the matching average score,
# and element 2 is inserted into the y-axis label as the name of the score.
candidate_k = cluster_validation[0]
mean_scores = cluster_validation[1]

fig, ax = plt.subplots(dpi=150, figsize=(6, 3))
ax.plot(candidate_k, mean_scores, '-o')
ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
ax.set_xlabel("Number of Clusters")
ax.set_ylabel("Average {0} score".format(cluster_validation[2]))

# Save the figure in both raster and vector formats.
for figure_path in ('../Analysis Figures/CSTAnalysis/Semen_16S_SilhouetteValidation_GeneraLevel.png',
                    '../Analysis Figures/CSTAnalysis/Semen_16S_SilhouetteValidation_GeneraLevel.svg'):
    fig.savefig(figure_path)
plt.show()

# The suggestion is the candidate with the highest average score.
print("The suggested number of clusters is {0}".format(candidate_k[np.argmax(mean_scores)]))

## Perform the Clustering of the CLR transformed 16S data matrix

Community State Types are obtained following the approach described for vaginal communities by Ravel et al.<sup>1</sup>: Ward hierarchical clustering, using the Jensen-Shannon divergence as the distance metric.

Based on the silhouette score suggestion, we select 3 clusters.

1 - doi: 10.1073/pnas.1002611107

In [None]:
# Cluster the samples into 3 groups, using the same distance metric and linkage method
# as in the validation step above. The result is indexed by key in the cells below
# ('clusterID', 'LinkageMatrix', 'ClusterAbundances', 'SilhouetteSamples').
CST_Clustering = cluster16SMatrix(counts_matrix, nClusters=3, distanceMetric='jensenshannon', clusterMethod='ward')

### Heatmap plot with the clustering dendrogram

In [None]:
from matplotlib.colors import rgb2hex
from matplotlib.pyplot import gcf

# Number of clusters actually present in the assignment.
nCST = len(CST_Clustering['clusterID'].unique())

# Encode the clinical classification as integer codes. The number of levels is taken from
# the data rather than hard-coded, so the palette and the legend always cover every level.
ClassificationCategorical = pds.Categorical(counts_genera['Classification'])
ClassificationLevels = ClassificationCategorical.categories
ClassificationRowColor = pds.Series(ClassificationCategorical.codes)
nClassification = len(ClassificationLevels)

cmapCST = ListedColormap(sns.color_palette("Set1", nCST))
cmapClassification = ListedColormap(sns.color_palette("Set2", nClassification))

# One colour per sample for each annotation track (cluster, classification), as hex strings.
row_colors = np.c_[CST_Clustering['clusterID'].map(cmapCST), ClassificationRowColor.map(cmapClassification)]
row_colors = np.vectorize(rgb2hex)(row_colors)

# Heatmap of log(counts + 1) for the first 25 taxa (columns are sorted by decreasing total
# count), with samples as columns ordered by the precomputed clustering linkage.
cstMap = sns.clustermap(np.log(counts_matrix.values[:, 0:25].T + 1), col_linkage=CST_Clustering['LinkageMatrix'],
                        row_cluster=False, col_colors=row_colors.T, yticklabels=counts_matrix.columns[0:25], cmap='viridis', xticklabels=False)

# Add legend with cluster assignments to match the relative abundance plots for interpretation (Genera + Species)
for labelIdx in range(nCST):
    cstMap.ax_col_dendrogram.bar(0, 20, color=cmapCST(labelIdx), label='Cluster ' + str(labelIdx + 1), linewidth=0)

cstMap.ax_col_dendrogram.legend(title='Cluster', bbox_to_anchor=(1.3, 1), loc='upper center', ncol=3)

# Zero-height bars act as legend handles for the classification levels.
labelsClassification = list()
for labelIdx in range(nClassification):
    currentLabel = cstMap.ax_col_dendrogram.bar(0, 0, color=cmapClassification(labelIdx), label=ClassificationLevels[labelIdx], linewidth=0)
    labelsClassification.append(currentLabel)

# A second legend has to be added as an explicit artist so that it does not replace the first one.
legendClassification = plt.legend(labelsClassification, ClassificationLevels, loc="upper center", title='Classification', ncol=2, bbox_to_anchor=(1.01, 0.92), bbox_transform=gcf().transFigure)
plt.gca().add_artist(legendClassification)

# Reposition the colour bar and label it
cstMap.cax.set_position([0.05, .3, 0.05, .2])
cstMap.cax.set_ylabel('Log(Counts + 1)')

cstMap.savefig('../Analysis Figures/CSTAnalysis/Semen_16S_CSTClustering_Genera.png', dpi=300)
cstMap.savefig('../Analysis Figures/CSTAnalysis/Semen_16S_CSTClustering_Genera.svg', dpi=300)

plt.show()

In [None]:
# Extract the cluster palette as hex codes, for reuse in the R plots.
# Every entry of the colormap is listed (instead of a hard-coded [0, 1, 2]) so the
# output stays in sync with the number of colours cmapCST was built with.
from matplotlib.colors import rgb2hex
[rgb2hex(cmapCST(clusterIdx)) for clusterIdx in range(cmapCST.N)]

To inspect the characteristic microbial composition of each cluster, we plot the mean relative abundances of the top 6 genera per CST cluster.

In [None]:
# Per-cluster summaries. Each entry is indexed below as: [0] cluster id, [1] values used as
# bar heights, [2] values used as error bars.
# NOTE(review): [1] and [2] are presumably the mean and spread of relative abundances per
# taxon, sorted by decreasing mean - confirm in CST_Clustering.py.
clusterAbundances = list(CST_Clustering['ClusterAbundances'])
nTop = 6  # number of taxa shown per cluster

# Taxa appearing among the top taxa of any cluster (sorted), so that a given taxon gets
# the same colour in every panel.
uniqueValues = np.unique(np.array([x[1].index[0:nTop] for x in clusterAbundances]))

cmapTaxa = ListedColormap(sns.color_palette("tab20", len(uniqueValues)))

# One panel per cluster. The panel count follows the clustering result instead of being
# fixed at 3, and squeeze=False keeps the axes indexable even for a single cluster.
fig, ax = plt.subplots(1, len(clusterAbundances), dpi=300, figsize=(12,6), sharey=True, squeeze=False)
ax = ax[0]

for idx, cst in enumerate(clusterAbundances):

    # Position of each displayed taxon inside uniqueValues = its colour index.
    currentColorId = cst[1][0:nTop].index
    currentColorId = np.searchsorted(uniqueValues, currentColorId)

    # Values are multiplied by 100 to be shown as percentages.
    ax[idx].bar(x=np.arange(0, nTop), height=cst[1][0:nTop]*100, yerr=cst[2][0:nTop]*100,
                label='Cluster ' + str(idx + 1) , color=cmapTaxa(currentColorId), **{'error_kw':{'lolims':True}})

    ax[idx].xaxis.set_ticks(np.arange(0, nTop), labels=cst[1][0:nTop].index.values)
    ax[idx].xaxis.set_tick_params(rotation=90, labelsize=15)
    ax[idx].tick_params('y', labelsize=15)
    ax[idx].yaxis.set_ticks(np.linspace(0, 100, 5))
    ax[idx].set_title("Cluster Number: {0}".format(str(cst[0] + 1)), fontsize=20)

ax[0].set_ylabel('Relative abundance (%)', fontsize=15)

fig.tight_layout()
fig.savefig('../Analysis Figures/CSTAnalysis/Semen_16S_CSTClustering_GeneraDescription.png')
fig.savefig('../Analysis Figures/CSTAnalysis/Semen_16S_CSTClustering_GeneraDescription.svg')

### Assess the silhouette values for samples across the CST clusters

In [None]:
# One row per sample: its cluster label (shifted by one so numbering starts at 1, as in
# the other figures) and its silhouette value.
silhFrame = pds.DataFrame({'Cluster':CST_Clustering['clusterID'], 'Silhouette':CST_Clustering['SilhouetteSamples']})
silhFrame['Cluster'] = silhFrame['Cluster'] + 1

# Number of samples per cluster, ordered by cluster number. The tick labels are built
# from these counts so they cannot drift from the data when the input or the number of
# clusters changes.
clusterSizes = silhFrame['Cluster'].value_counts().sort_index()

fig, ax = plt.subplots(dpi=300)
# The explicit order guarantees that the categories are drawn in the same order as the labels.
sns.stripplot(data=silhFrame, x='Cluster', y='Silhouette', ax=ax, palette='Set1', order=list(clusterSizes.index))
ax.set_ylabel('Silhouette score', fontsize = 15)
ax.set_xlabel('', fontsize = 15)
ax.tick_params('y', labelsize=15)
ax.tick_params('x', labelsize=15)
ax.set_xticklabels(['Cluster {0}\n(n={1})'.format(cluster, size) for cluster, size in clusterSizes.items()])

fig.savefig('../Analysis Figures/CSTAnalysis/Semen_16S_CSTClustering_SilhouetteSamples_Genera.png')
fig.savefig('../Analysis Figures/CSTAnalysis/Semen_16S_CSTClustering_SilhouetteSamples_Genera.svg')

## Export the CST Assignments


After obtaining the Community State Type information with hierarchical clustering, we export a dataframe containing the results.

In [None]:
# Gather the cluster label (shifted by one) and the silhouette value of every sample.
CSTAssignments = pds.DataFrame(dict(
    CST=CST_Clustering['clusterID'] + 1,
    SilhouetteValues=CST_Clustering['SilhouetteSamples'],
))

# Label the rows with the sample identifiers used by the count matrix.
CSTAssignments.index = counts_matrix.index

In [None]:
# Create the results folder if it is missing, then write the assignments table as CSV.
os.makedirs(name='../Results/CST_Analysis', exist_ok=True)
CSTAssignments.to_csv(path_or_buf='../Results/CST_Analysis/Semen_CST_GeneraAssignments.csv')