In [1]:
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
from tqdm import tqdm
from collections import OrderedDict

In [2]:
# Apply seaborn's default theme and make 11.7 x 8.27 inches (width x height)
# the default size of figures created after this cell.
sns.set(
    rc={
        "figure.figsize": (11.7, 8.27),
    }
)

In [3]:
# Metadata table: tab-separated, gzip-compressed; the first column of the
# file becomes the row index.
meta = pd.read_csv(
    "meta_all_disease_jsd.tsv.gz",
    sep="\t",
    index_col=0,
)

# Matrix table, read the same way and then cast to 32-bit floats.
# NOTE(review): presumably a square sample-by-sample JSD distance matrix
# (going by the file name) -- confirm against whatever produced the file.
mat = pd.read_csv(
    "jsd_matrix_all_disease.tsv.gz",
    sep="\t",
    index_col=0,
).astype(np.float32)

In [4]:
# Collapse the different spellings of the healthy label into the single
# value 'Health' so they are treated as one phenotype downstream.
healthy_spellings = ['Health', 'Healthy', 'healthy']
is_healthy = meta['phenotype'].isin(healthy_spellings)
meta.loc[is_healthy, 'phenotype'] = 'Health'

# Index labels of `meta`, ordered by phenotype, as a Series.
sorted_idx = meta.sort_values('phenotype').index.to_series()

# Lookup table: index label of `meta` -> phenotype.
SampleID_Env = meta['phenotype'].to_dict()

In [5]:
# Relabel both axes of `mat` from '<sample id>' to '<phenotype>_<sample id>'.
#
# The relabelling is written so that this cell can be re-run safely: labels
# that already carry their phenotype prefix are passed through unchanged
# instead of raising KeyError on the second execution.

# Every label this cell can produce, used to recognise an already-relabelled axis.
_prefixed_labels = {'{}_{}'.format(env, sid) for sid, env in SampleID_Env.items()}


def _prefix_with_env(label):
    """Return '<phenotype>_<label>' for a raw sample ID.

    A label that is not a known sample ID but is one of the labels this cell
    generates is returned as-is (the cell was already run). Any other unknown
    label raises KeyError, exactly as a plain `SampleID_Env[label]` lookup would.
    """
    if label not in SampleID_Env and label in _prefixed_labels:
        return label
    return SampleID_Env[label] + '_' + label


mat.columns = mat.columns.to_series().apply(_prefix_with_env)
mat.index = mat.index.to_series().apply(_prefix_with_env)

In [6]:
# Translate the phenotype-sorted sample IDs into the '<phenotype>_<sample id>'
# labels used on the axes of `mat`.
ordered_labels = sorted_idx.apply(lambda sid: SampleID_Env[sid] + '_' + sid)

# Put rows and columns of the matrix, and the rows of the metadata table,
# into that same phenotype-sorted order.
mat = mat.loc[ordered_labels, ordered_labels]
meta = meta.loc[sorted_idx, :]

In [None]:
# Draw `mat` as an unclustered heatmap with a phenotype colour strip along
# both axes, add a phenotype legend, and export the figure as SVG, PNG and PDF.

# Phenotype of every axis label. The labels have the form
# '<phenotype>_<sample id>', so the phenotype is looked up through
# SampleID_Env rather than taken as the text before the first '_', which
# would truncate any phenotype name that itself contains an underscore.
# Labels not found in the lookup fall back to the split.
label_to_env = {'{}_{}'.format(env, sid): env for sid, env in SampleID_Env.items()}
labels = mat.columns.to_series().map(
    lambda lab: label_to_env.get(lab, lab.split('_')[0])
)
labels_unique = labels.unique()

# One colour per phenotype, in order of first appearance along the axis.
# NOTE(review): seaborn cycles a named palette when asked for more colours
# than it contains, so phenotypes may share a colour if there are more of
# them than 'Set2' provides -- confirm this is acceptable.
mat_pal = sns.color_palette('Set2', labels_unique.size)
cmapper = OrderedDict(zip(labels_unique, mat_pal))
colors = labels.map(cmapper)

# clustermap creates its own figure, so no separate plt.figure() call is
# made here; doing so only leaves an extra, empty figure open.
g = sns.clustermap(mat,
                   row_cluster=False, col_cluster=False,
                   row_colors=colors, col_colors=colors, linewidths=0,
                   xticklabels=False, yticklabels=False, cmap='RdYlGn')

# Zero-sized bars are drawn only to register one legend entry per phenotype.
for label in labels_unique:
    g.ax_col_dendrogram.bar(0, 0, color=cmapper[label], label=label, linewidth=0)
g.ax_col_dendrogram.legend(loc="upper left", ncol=2, fontsize=8, bbox_to_anchor=(1.04, 0))

# Reposition the colour bar: [left, bottom, width, height] in figure coordinates.
g.cax.set_position([.08, .2, .02, .45])

# Save through the grid object so the export does not depend on which
# figure pyplot currently considers active.
for ext in ('svg', 'png', 'pdf'):
    g.savefig('disease_clustermap.' + ext, dpi=200, bbox_inches='tight')