In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster import hierarchy
from scipy.spatial.distance import pdist
import os
import shutil
import csv
import sys

### Functions

In [2]:
def plot_curves(sample_ls, df1):
    """Overlay the curves of the requested samples on one 15x15 figure.

    Parameters
    ----------
    sample_ls : iterable
        Index labels of ``df1``; each one is looked up with ``df1.loc`` and
        also used as the legend entry of its line.
    df1 : pandas.DataFrame
        Table whose column labels are used as x values and whose rows are
        used as y values.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    x_labels = np.array(df1.columns)
    plt.figure(figsize=[15, 15])
    for sample in sample_ls:
        plt.plot(x_labels, np.array(df1.loc[sample]), label=sample)
    plt.legend()
    # Draw the x tick labels vertically.
    plt.xticks(rotation=90)
    plt.show()

### Setting up directory

In [3]:
# Which fitted-curve table to cluster: either 'competition' or 'neutralization'.
input_type = 'neutralization'
# When True, the analysis is restricted to a hand-picked list of mutations
# (see the cell that tests `selection`).
selection = False

# Fail with an explicit message rather than a silent sys.exit(), so a typo in
# `input_type` is reported instead of just stopping the notebook.
if input_type not in ('competition', 'neutralization'):
    raise ValueError(
        "input_type must be 'competition' or 'neutralization', got "
        + repr(input_type)
    )

# Both paths follow the same naming scheme, keyed on the input type.
infile = '../output/' + input_type + '_fitted_sigmoidal_readouts.csv'
outdir = '../output/' + input_type + '_clustering'

# Always start from an empty output directory: any previous results in
# `outdir` are deleted.
if os.path.exists(outdir):
    shutil.rmtree(outdir)
# makedirs (unlike mkdir) also creates a missing parent directory.
os.makedirs(outdir)

### Loading and Processing input 

In [4]:
# Column subsampling step: only every `everyxcol`-th column is kept.
everyxcol = 20

# Read the table, strip the 'LONG_' prefix wherever it occurs in string cells,
# then index the rows by sample name.
df = (
    pd.read_csv(infile)
    .replace('LONG_', '', regex=True)
    .set_index('sample_name')
)

# Thinned-out copy of the table used for clustering and plotting.
kept_columns = df.columns[::everyxcol]
df_short = df[kept_columns]

# Largest value in the thinned table plus 10; used later as the upper y limit.
ymax = df_short.max().max() + 10
outpngfile = outdir + '/clustering_every' + str(everyxcol) + '.png'

In [5]:
if selection:
    # Restrict the table to the mutations of interest, in the order listed.
    sele_ls = ['F83L', 'F83A', 'D85A', 'D85R', 'V86A' , 'F83A-D85A-V86A',
               'D102A', 'D102R', 'K104A', 'K104E', 'T105A', 'D102A-K104A-T105A',
               'I140A', 'T141A', 'D142A', 'D142R', 'L143A', 'V144A', 'Q145A',
               'I140A-T141A-D142A', 'L143A-V144A-Q145A', 'WT']
    # A single label-list lookup returns the rows in `sele_ls` order and
    # raises KeyError for a missing label. It replaces the row-by-row
    # DataFrame.append loop, which no longer exists in pandas >= 2.0.
    df_short = df_short.loc[sele_ls].copy()
display(df_short)

Unnamed: 0_level_0,-3.12,-2.73608,-2.35216,-1.96824,-1.58432,-1.2004,-0.81648,-0.43256,-0.04864,0.33528
sample_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
A78R,97.54865,96.37153,94.31818,90.83837,85.22022,76.82325,65.61642,52.72241,40.19352,29.86654
A79S,97.08836,94.29936,89.35711,81.20176,69.20685,54.24731,39.00034,26.34509,17.51218,12.06627
D102A,99.60371,95.79926,89.75141,80.73877,68.52285,53.93199,38.92133,25.6732,15.472,8.41172
D102R,99.59523,96.93745,92.45833,85.29844,74.76717,61.03039,45.66975,31.18912,19.57621,11.41215
D107R,104.0052,103.4243,102.28683,100.1001,96.0408,88.97285,77.93708,63.33198,47.71838,34.41159
D142A,101.40903,100.45906,98.63479,95.22148,89.13516,79.15975,64.86969,47.902,31.72155,19.21544
D142R,102.15373,100.71113,97.98674,93.03499,84.63131,71.90248,55.59721,38.66719,24.52111,14.69372
D80A,100.823,100.38369,99.73194,98.765,97.33045,95.20218,92.04474,87.3605,80.41128,70.10212
D85A,97.95932,97.33996,96.12371,93.79301,89.52866,82.34881,71.7997,59.03695,46.77666,37.34862
D85R,91.42603,91.42603,91.42603,91.42599,91.42555,91.42037,91.36019,90.68808,85.53651,76.76899


### Clustering

In [6]:
%%capture
sns.set(font_scale=1.4)
g = sns.clustermap(df_short, col_cluster=False, figsize=(10,35), cmap='vlag', \
                   cbar_pos=(0.06, 0.8, 0.05, 0.05), dendrogram_ratio=0.4,
                  method='average')
for a in g.ax_row_dendrogram.collections:
    a.set_linewidth(3)
g.savefig(outpngfile, dpi=300)
# Getting tree
Z = g.dendrogram_row.linkage
hierarchy.dendrogram(Z)

### Cut tree at multiple points, save clustered curves

In [7]:
# Create one sub-directory of `outdir` per cluster count (2 to 10 clusters).
for maxclusct in range(2,11):
    clusterdir = outdir + '/clusters_' + str(maxclusct)
    # exist_ok=True makes re-running the cell harmless and avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs(clusterdir, exist_ok=True)

In [8]:
# For every cluster count from 2 to 10: cut the row tree, save one PNG per
# cluster with the curves of its members, and write the sample -> cluster
# assignment to a CSV file in that cluster count's directory.
for maxclusct in range(2,11):
    print('Cutting tree into' , maxclusct, 'clusters')
    clusterdir = outdir + '/clusters_' + str(maxclusct)
    # One flat cluster id per row of the clustered table, in row order.
    clusters = hierarchy.fcluster(Z, maxclusct, criterion='maxclust')
    label_dc = {}
    # np.unique yields the ids in ascending order, so the PNG files and the
    # CSV rows are produced in a deterministic order.
    for clusterid in np.unique(clusters):
        # Row positions belonging to this cluster.
        target = np.flatnonzero(clusters == clusterid)
        df_x = df_short.iloc[target]
        # Column labels are the x coordinates. The builtin float is used
        # because the np.float alias was removed in NumPy 1.24.
        x_val = np.array(df_x.columns).astype(float)
        plt.figure(figsize=[10,5])
        for index, row in df_x.iterrows():
            # Use the iterated row directly instead of looking the label up
            # again; this also stays correct if a label occurs more than once.
            y_val = row.to_numpy(dtype=float)
            label_dc[index] = int(clusterid)
            plt.plot(x_val, y_val, label=index)
        plt.legend(fontsize=3.)
        plt.yticks(fontsize=10.)
        # Same y range on every figure so clusters can be compared.
        plt.ylim(0,ymax)
        plt.xticks(rotation=90, fontsize=10.)
        plt.xticks(np.arange(np.min(x_val), np.max(x_val)+0.6, 0.2))
        plt.savefig(clusterdir + '/clusterid_'+ str(clusterid) + '.png',dpi=300)
        # Close the figure so figures do not accumulate in memory.
        plt.close()

    # Save the assignment as "sample_name,cluster_id" rows. The context
    # manager closes the file even if writing fails; newline='' is what the
    # csv module requires to avoid doubled line endings on Windows.
    with open(clusterdir + '/cluster_assignment.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerows(label_dc.items())

Cutting tree into 2 clusters
Cutting tree into 3 clusters
Cutting tree into 4 clusters
Cutting tree into 5 clusters
Cutting tree into 6 clusters
Cutting tree into 7 clusters
Cutting tree into 8 clusters
Cutting tree into 9 clusters
Cutting tree into 10 clusters
