In [2]:
import pprint
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist, squareform
# from sklearn import datasets

from scipy.cluster import hierarchy
from matplotlib import cm
import umap

## Plot dendrogram

In [2]:
def plot_dendrogram(dir_name,threshold):
    """Draw the dendrogram colored by flat cluster and save it as ``dendrogram.svg``.

    Parameters
    ----------
    dir_name : str
        Path prefix joined to file names by plain string concatenation, so it
        must end with a path separator. ``all_result.pckl`` is read from it
        and ``dendrogram.svg`` is written to it.
    threshold : float
        Cut-off passed to ``hierarchy.fcluster(Z, threshold, 'distance')``.

    Returns
    -------
    numpy.ndarray
        ``(n_leaves, 4)`` RGBA array; row ``k`` is the color of the ``k``-th
        leaf in dendrogram (left-to-right) order.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(dir_name + 'all_result.pckl','rb') as f:
        all_result = pickle.load(f)
    Z = all_result[0]['Z']

    # Flat clusters (ids start at 1) and one rainbow color per cluster, reversed.
    clusters = hierarchy.fcluster(Z,threshold,'distance')
    colors = cm.rainbow(np.linspace(0,1,max(clusters)))
    colors = np.flip(colors,axis = 0)

    # Dendrogram geometry only; the links are drawn manually below.
    R = hierarchy.dendrogram(Z, no_plot = True)
    icoord = np.array(R['icoord'])
    dcoord = np.array(R['dcoord'])
    leaves = np.array(R['leaves'])

    # Color every leaf by its own cluster id. This is the same id -> color
    # mapping plot_umap uses, and it does not depend on cluster ids being
    # ascending along the leaf order (the previous slice-overwrite did).
    leaves_colors = colors[clusters[leaves] - 1]

    for xs, ys in zip(icoord, dcoord):
        # Leaf k sits at x = 10*k + 5, so (x - 5) / 10 maps a leg's x position
        # back to the index of the leaf at (or just left of) that position.
        color_left = leaves_colors[int((xs[0]-5)/10)]
        color_right = leaves_colors[int((xs[2]-5)/10)]

        # A link whose two legs fall in the same cluster gets that cluster's
        # color; links joining different clusters are drawn gray.
        link_color = color_left if np.all(color_left == color_right) else 'gray'
        for segment in (slice(0,2), slice(1,3), slice(2,4)):
            plt.plot(xs[segment],ys[segment],c = link_color)

    plt.xlabel(dir_name + " ")
    plt.title('Dendrogram')
    plt.savefig(fname = dir_name + 'dendrogram.svg', dpi = 300)
    plt.close()

    return leaves_colors

## Plot UMAP

In [3]:
def plot_umap(dir_name,threshold,random_state=None):
    """Embed the normalized features with UMAP, color by cluster, save ``UMAP.svg``.

    Parameters
    ----------
    dir_name : str
        Path prefix joined to file names by plain string concatenation, so it
        must end with a path separator. ``feat_norm.pckl`` and
        ``all_result.pckl`` are read from it; ``UMAP.svg`` is written to it.
    threshold : float
        Cut-off passed to ``hierarchy.fcluster(Z, threshold, 'distance')``.
    random_state : int or None, optional
        Forwarded to ``umap.UMAP``. The default ``None`` keeps the previous
        (unseeded) behavior; pass an integer for a reproducible layout.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(dir_name +'feat_norm.pckl','rb') as f:
        feat_norm = pickle.load(f)

    with open(dir_name + 'all_result.pckl','rb') as f:
        all_result = pickle.load(f)
    Z = all_result[0]['Z']

    # Flat clusters (ids start at 1) and one rainbow color per cluster, reversed.
    clusters = hierarchy.fcluster(Z,threshold,'distance')
    colors = cm.rainbow(np.linspace(0,1,max(clusters)))
    colors = np.flip(colors,axis = 0)

    # One RGBA row per clip, looked up by (1-based) cluster id.
    clip_colors = colors[clusters - 1]

    um = umap.UMAP(n_neighbors=5,
                   min_dist=0.3,
                   metric='correlation',
                   random_state=random_state)
    Y = um.fit_transform(feat_norm)

    plt.scatter(Y[:,0], Y[:,1], c = clip_colors)
    plt.title('UMAP')
    plt.savefig(fname = dir_name + 'UMAP.svg', dpi = 300)
    plt.close()

## Plot timeline

In [4]:
def plot_timeline(dir_name, threshold=None):
    """Plot one row of cluster-colored dots per video and save ``timeline.svg``.

    Parameters
    ----------
    dir_name : str
        Path prefix joined to file names by plain string concatenation, so it
        must end with a path separator. ``all_result.pckl`` and
        ``all_info_selected.pckl`` are read from it; ``timeline.svg`` is
        written to it.
    threshold : float, optional
        Cut-off passed to ``hierarchy.fcluster(Z, threshold, 'distance')``.
        When omitted, the notebook-level ``threshold`` variable is used, which
        is what this function silently relied on before.

    Raises
    ------
    NameError
        If ``threshold`` is neither passed nor defined at notebook level.
    """
    if threshold is None:
        # Backward compatibility for calls of the form plot_timeline(dir_name).
        try:
            threshold = globals()['threshold']
        except KeyError:
            raise NameError("name 'threshold' is not defined; pass threshold explicitly") from None

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(dir_name + 'all_result.pckl','rb') as f:
        all_result = pickle.load(f)
    Z = all_result[0]['Z']

    # Flat clusters (ids start at 1) and one rainbow color per cluster, reversed.
    clusters = hierarchy.fcluster(Z,threshold,'distance')
    colors = cm.rainbow(np.linspace(0,1,max(clusters)))
    colors = np.flip(colors,axis = 0)

    # One RGBA row per clip, looked up by (1-based) cluster id.
    clip_colors = colors[clusters - 1]

    with open(dir_name + 'all_info_selected.pckl','rb') as f:
        all_info_selected = pickle.load(f)
    info_clips = all_info_selected['info_clips']

    # The third-from-last path component names the video of each clip.
    video_ids = [info_clips[i]['video_path'].split('/')[-3]
                 for i in range(len(clip_colors))]
    # De-duplicate in order of first appearance: list(set(...)) made the row
    # order vary from run to run.
    ytick_str = list(dict.fromkeys(video_ids))

    for row, video_id in enumerate(ytick_str):
        n_drawn = 0
        for i, clip_video in enumerate(video_ids):
            # Exact match on the video id rather than a substring test on the
            # whole path, so one id being a prefix of another cannot mix rows.
            if clip_video == video_id:
                # Clips are spaced 15 apart on the x axis; presumably 15 frames
                # per clip given the axis label - TODO confirm.
                plt.plot(15*n_drawn, row, '.', c = clip_colors[i] )
                n_drawn += 1

    ytick_value = range(len(ytick_str) )
    plt.title('Timeline')
    plt.yticks(ytick_value,ytick_str)
    plt.xlabel('time(frame)')
    plt.tight_layout()
    plt.savefig(fname = dir_name + 'timeline.svg', dpi = 300, format ='svg')
    plt.close()

## Plot mutual information for each feature

In [228]:
from sklearn.feature_selection import mutual_info_classif

def plot_mutual_information(dir_name,single_flag,threshold=None):
    """Plot the mean mutual information between each feature and the cluster labels.

    Parameters
    ----------
    dir_name : str
        Path prefix joined to file names by plain string concatenation, so it
        must end with a path separator. ``feature_clips_dict.pckl`` and
        ``all_result.pckl`` are read from it; ``mutual_info.svg`` is written
        to it.
    single_flag : bool
        If true, only the first 16 features are shown and the figure is titled
        'Individual Behavior'; otherwise all features are shown under
        'Social Behavior'.
    threshold : float, optional
        Cut-off passed to ``hierarchy.fcluster(Z, threshold, 'distance')``.
        When omitted, the notebook-level ``threshold`` variable is used, which
        is what this function silently relied on before.

    Notes
    -----
    Tick labels come from the notebook-level ``labels`` list, and figure sizes
    use the notebook-level ``cm2inch`` helper.
    NOTE(review): ``mutual_info_classif`` is called without ``random_state``,
    so values can differ slightly between runs.

    Raises
    ------
    NameError
        If ``threshold`` is neither passed nor defined at notebook level.
    """
    if threshold is None:
        # Backward compatibility for calls that do not pass a threshold.
        try:
            threshold = globals()['threshold']
        except KeyError:
            raise NameError("name 'threshold' is not defined; pass threshold explicitly") from None

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(dir_name + 'feature_clips_dict.pckl','rb') as f:
        feature_clips_dict = pickle.load(f)

    with open(dir_name + 'all_result.pckl','rb') as f:
        all_result = pickle.load(f)
    Z = all_result[0]['Z']

    # Clusters
    clusters = hierarchy.fcluster(Z,threshold,'distance')

    feature_keys = list(feature_clips_dict.keys())
    if single_flag:
        # Only the first 16 features are shown in the single-animal plot, so
        # do not spend time estimating MI for the rest.
        feature_keys = feature_keys[:16]

    # One value per feature: MI is estimated for every column of the feature
    # and then averaged over those columns.
    short = []
    for feature_key in feature_keys:
        per_clip = np.asarray(feature_clips_dict[feature_key])
        mi = mutual_info_classif(per_clip, clusters)
        short.append(np.mean(mi))
    short = np.asarray(short)

    font_size = 12
    if single_flag:
        plt.figure(figsize=(cm2inch(18), cm2inch(6)))
    else:
        plt.figure(figsize=(cm2inch(36), cm2inch(6)))
    plt.plot(short,'*')
    # Pass exactly as many labels as there are ticks; matplotlib rejects a
    # label list whose length differs from the tick list.
    tick_labels = list(labels)[:len(short)]
    plt.xticks(range(len(short)),tick_labels,fontsize = font_size,ha = 'right',rotation = 70)

    plt.ylabel('Mutual information')
    if single_flag:
        plt.title('Individual Behavior',fontsize = font_size)
    else:
        plt.title('Social Behavior',fontsize = font_size)
    plt.tight_layout()
    plt.savefig(fname = dir_name + 'mutual_info.svg', dpi = 300, format ='svg')
    plt.close()
    

## Plot feature heatmap

In [229]:
def cm2inch(value):
    """Convert a length in centimetres to inches (1 inch = 2.54 cm)."""
    centimetres_per_inch = 2.54
    inches = value / centimetres_per_inch
    return inches


In [230]:
def plot_feature_heatmap(dir_name,threshold,single_flag):
    """Plot a cluster-by-feature heatmap of z-scored feature means and save it.

    Parameters
    ----------
    dir_name : str
        Path prefix joined to file names by plain string concatenation, so it
        must end with a path separator. ``all_result.pckl`` and
        ``feature_clips_dict.pckl`` are read from it; ``feature_heatmap.svg``
        is written to it.
    threshold : float
        Cut-off passed to ``hierarchy.fcluster(Z, threshold, 'distance')``.
    single_flag : bool
        If true, only the first 16 features are shown and the figure is titled
        'Individual Behavior'; otherwise all features are shown under
        'Social Behavior'.

    Returns
    -------
    numpy.ndarray
        One row per cluster (cluster 1 first) holding the mean of every
        z-scored feature column over the clips of that cluster.

    Notes
    -----
    Tick labels come from the notebook-level ``labels`` list, and figure sizes
    use the notebook-level ``cm2inch`` helper.
    """
    # Kept local, as in the original cell, so the notebook's import cell is unchanged.
    import seaborn as sns

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(dir_name + 'all_result.pckl','rb') as f:
        all_result = pickle.load(f)
    Z = all_result[0]['Z']

    clusters = hierarchy.fcluster(Z,threshold,'distance')

    with open(dir_name + 'feature_clips_dict.pckl','rb') as f:
        feature_clips_dict = pickle.load(f)
    feature_keys = list(feature_clips_dict.keys())

    # One row per clip: all features of that clip laid end to end.
    total = []
    for i in range(len(clusters)):
        entry = []
        for feature_key in feature_keys:
            entry.extend(feature_clips_dict[feature_key][i])
        total.append(entry)
    total = np.asarray(total)

    # Column boundaries of each feature inside a row of `total`, taken from
    # the length of the first clip's entry.
    mark = [feature_clips_dict[feature_key][0].shape[0] for feature_key in feature_keys]
    cum_mark = np.concatenate((np.array([0]),np.cumsum(mark)))

    # Z-score every column. A constant column has zero std and becomes NaN
    # (with a RuntimeWarning); np.nanmean below skips those columns.
    total = (total-np.mean(total,axis = 0))/np.std(total,axis = 0)

    # Per-cluster mean of every column. The boolean mask is used directly:
    # wrapping it in a list is deprecated multidimensional indexing and fails
    # on current NumPy.
    summary = []
    for i in range(1,max(clusters)+1):
        in_cluster = clusters == i
        summary.append(np.mean(total[in_cluster],axis = 0))
    summary = np.asarray(summary)

    # Collapse the columns of each feature to a single value per cluster.
    short = []
    for i in range(len(cum_mark)-1):
        block = summary[:,cum_mark[i]:cum_mark[i+1]]
        short.append(np.nanmean(block,axis = 1))
    short = np.asarray(short)
    short = np.swapaxes(short,0,1)

    if single_flag:
        short = short[:,:16]
    # As many display names as there are heatmap columns.
    feature_names = list(labels)[:short.shape[1]]

    font_size = 12
    if single_flag:
        plt.figure(figsize=(cm2inch(18), cm2inch(18)))
    else:
        plt.figure(figsize=(cm2inch(36), cm2inch(18)))
    ax = sns.heatmap(short,cbar = True)
    # Heatmap cells are centred on half-integers.
    xtick_pos = [0.5 + i for i in range(len(feature_names))]
    plt.xticks(xtick_pos,feature_names,rotation = 70,fontsize = font_size,ha='right')
    ytick_text = ['cluster {}'.format(i+1) for i in range(np.max(clusters))]
    ytick_pos = [0.5 + i for i in range(np.max(clusters))]
    plt.yticks(ytick_pos,ytick_text,rotation = 0,fontsize = font_size)
    # Tick label size of the colorbar.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=12)
    if single_flag:
        plt.title('Individual Behavior',fontsize = font_size)
    else:
        plt.title('Social Behavior',fontsize = font_size)

    plt.tight_layout()
    plt.savefig(fname = dir_name + 'feature_heatmap.svg', dpi = 300, format ='svg')
    plt.close()
    return summary

## Plot similarity between features

In [231]:
def seriation(Z,N,cur_index):
    '''
        input:
            - Z is a hierarchical tree (dendrogram)
            - N is the number of points given to the clustering process
            - cur_index is the position in the tree where the traversal starts
        output:
            - order implied by the hierarchical tree Z

        seriation computes the order implied by a hierarchical tree (dendrogram)

        The tree is walked with an explicit stack instead of recursion, so
        very unbalanced (chain-like) trees cannot hit Python's recursion limit.
    '''
    order = []
    pending = [cur_index]
    while pending:
        node = pending.pop()
        if node < N:
            # Indices below N are original points (leaves).
            order.append(node)
        else:
            # Row node-N of Z holds the two children of merged cluster `node`.
            # Push right first so that the left child is visited first.
            pending.append(int(Z[node-N,1]))
            pending.append(int(Z[node-N,0]))
    return order
    
def compute_serial_matrix(dist_mat,method="ward"):
    '''
        input:
            - dist_mat is a square distance matrix (numpy array)
            - method = ["ward","single","average","complete"]
        output:
            - seriated_dist is the input dist_mat,
              but with re-ordered rows and columns
              according to the seriation, i.e. the
              order implied by the hierarchical tree
            - res_order is the order implied by
              the hierarchical tree
            - res_linkage is the hierarchical tree (dendrogram)

        compute_serial_matrix transforms a distance matrix into
        a sorted distance matrix according to the order implied
        by the hierarchical tree (dendrogram)
    '''
    # Local import: inside this function `linkage` is fastcluster's, not the
    # scipy `linkage` imported at the top of the notebook.
    from fastcluster import linkage

    N = len(dist_mat)
    flat_dist_mat = squareform(dist_mat)
    res_linkage = linkage(flat_dist_mat, method=method,preserve_input=True)
    # The root of a tree over N points has index N + (N-1) - 1.
    res_order = seriation(res_linkage, N, N + N-2)

    # Fill the upper triangle with the permuted distances, then mirror it;
    # the diagonal stays zero.
    order = np.asarray(res_order)
    seriated_dist = np.zeros((N,N))
    a,b = np.triu_indices(N,k=1)
    seriated_dist[a,b] = dist_mat[order[a], order[b]]
    seriated_dist[b,a] = seriated_dist[a,b]

    return seriated_dist, res_order, res_linkage

In [232]:
def plot_similarity(dir_name, threshold=None, method="ward"):
    """Plot the seriated matrix of distances between cluster means and save it.

    Parameters
    ----------
    dir_name : str
        Path prefix joined to file names by plain string concatenation, so it
        must end with a path separator. ``all_result.pckl`` and
        ``feature_clips_dict.pckl`` are read from it; ``similarity.svg`` is
        written to it.
    threshold : float, optional
        Cut-off passed to ``hierarchy.fcluster(Z, threshold, 'distance')``.
        When omitted, the notebook-level ``threshold`` variable is used, which
        is what this function silently relied on before.
    method : str, optional
        Linkage method forwarded to ``compute_serial_matrix`` (default
        ``"ward"``, the only method that was active before).

    Raises
    ------
    NameError
        If ``threshold`` is neither passed nor defined at notebook level.
    """
    if threshold is None:
        # Backward compatibility for calls of the form plot_similarity(dir_name).
        try:
            threshold = globals()['threshold']
        except KeyError:
            raise NameError("name 'threshold' is not defined; pass threshold explicitly") from None

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(dir_name + 'all_result.pckl','rb') as f:
        all_result = pickle.load(f)
    Z = all_result[0]['Z']

    clusters = hierarchy.fcluster(Z,threshold,'distance')

    with open(dir_name + 'feature_clips_dict.pckl','rb') as f:
        feature_clips_dict = pickle.load(f)
    feature_keys = list(feature_clips_dict.keys())

    # One row per clip, one column per feature (mean over the feature's values).
    total = np.asarray([
        [np.mean(feature_clips_dict[feature_key][i]) for feature_key in feature_keys]
        for i in range(len(clusters))
    ])

    # Mean feature vector of every cluster (row r belongs to cluster id r+1).
    # The boolean mask is used directly: wrapping it in a list is deprecated
    # multidimensional indexing and fails on current NumPy.
    summary = []
    for i in range(1,max(clusters)+1):
        in_cluster = clusters == i
        summary.append(np.mean(total[in_cluster],axis = 0))
    summary = np.asarray(summary)

    dist_mat = squareform(pdist(summary))
    N = len(summary)

    ordered_dist_mat, res_order, res_linkage = compute_serial_matrix(dist_mat,method)
    plt.pcolormesh(ordered_dist_mat)
    # Rows and columns are permuted by res_order, so label each position with
    # the cluster id that actually sits there rather than 0..N-1.
    tick_pos = [0.5 + i for i in range(N)]
    tick_text = [int(r) + 1 for r in res_order]
    plt.xticks(tick_pos,tick_text)
    plt.yticks(tick_pos,tick_text)
    plt.xlim([0,N])
    plt.ylim([0,N])
    plt.xlabel('clusters')
    plt.ylabel('clusters')
    plt.colorbar()
    plt.title('Distance between clusters')
    plt.tight_layout()
    plt.savefig(fname = dir_name + 'similarity.svg', dpi = 300, format ='svg')
    plt.close()

# Generate plots

### Single mouse behavior

In [233]:
# Display names for the features, in plotting order. The plotting functions
# take the first 16 entries when single_flag is true.
labels = (
    # Entries 0-15: first animal.
    [
        'displace-x', 'displace-y', 'displace-rho',
        'displace-phi-cos', 'displace-phi-sin',
        'body-length', 'head-length', 'head-body-angle',
        'left ear-snout', 'right ear-snout',
        'left ear-snout-phi', 'right ear-snout-phi',
        'snout-fft-amplitude', 'snout-fft-angle',
        'contourPCA-fft-amplitude', 'contourPCA-fft-angle',
    ]
    # Entries 16-31: the same quantities for the other animal.
    + [
        'displace-x-other', 'displace-y-other', 'displace-rho-other',
        'displace-phi-cos-other', 'displace-phi-sin-other',
        'body-length-other', 'head-length-other', 'head-body-angles-other',
        'left ear-snout-other', 'right ear-snout-other',
        'left ear-snout-phi-other', 'right ear-snout-phi-other',
        'snout-fft-amplitude-other', 'snout-fft-angle-other',
        'contourPCA-fft-amplitude-other', 'contourPCA-fft-angle-other',
    ]
    # Entries 32-41: quantities relating the two animals.
    + [
        'two body-angle', 'two head-angle',
        'body change-angle', 'body change-angle_other',
        'mouse2 snout-mouse1 tail-phi', 'mouse2 snout-mouse1 tail-rho',
        'mouse1 snout-mouse2 tail-phi', 'mouse1 snout-mouse2 tail-rho',
        'two snout-rho', 'two snout-phi',
    ]
)

In [234]:
# Single-mouse dataset: where the pickles live and where the figures go.
dir_name = 'results_single/'
# Distance cut-off for the flat clusters. The calls below that do not pass it
# pick up this notebook-level value.
threshold = 200

leaves_colors = plot_dendrogram(dir_name=dir_name, threshold=threshold)
plot_umap(dir_name=dir_name, threshold=threshold)
plot_timeline(dir_name=dir_name)
plot_mutual_information(dir_name=dir_name, single_flag=True)
plot_feature_heatmap(dir_name=dir_name, threshold=threshold, single_flag=True)
plot_similarity(dir_name=dir_name)

  plt.tight_layout()
  total = (total-np.mean(total,axis = 0))/np.std(total,axis = 0)
  temp = total[flag]


array([[ 1.09440098,  1.2464975 ,  1.24554667, ...,  0.19503605,
         0.19394046,  0.20478062],
       [-0.06697895, -0.13759203, -0.12290143, ...,  0.31506747,
         0.31831846,  0.27960334],
       [-0.00501798, -0.04179956, -0.02507584, ...,  0.25841355,
         0.27444205,  0.25267039],
       ...,
       [-0.02052821, -0.00626706, -0.00231935, ...,  0.07095441,
         0.0767933 ,  0.05782671],
       [-0.41010557, -0.45077453, -0.43748227, ..., -0.47572813,
        -0.49077112, -0.47563487],
       [ 0.02916275,  0.03996225, -0.01601425, ..., -0.00450209,
        -0.0197827 , -0.00483104]])

### Social behavior

In [235]:
# Social (two-mouse) dataset: where the pickles live and where the figures go.
dir_name = 'results_social/'
# Distance cut-off for the flat clusters. The calls below that do not pass it
# pick up this notebook-level value.
threshold = 500

leaves_colors = plot_dendrogram(dir_name=dir_name, threshold=threshold)
plot_umap(dir_name=dir_name, threshold=threshold)
plot_timeline(dir_name=dir_name)
plot_mutual_information(dir_name=dir_name, single_flag=False)
plot_feature_heatmap(dir_name=dir_name, threshold=threshold, single_flag=False)
plot_similarity(dir_name=dir_name)
 

  plt.tight_layout()
  total = (total-np.mean(total,axis = 0))/np.std(total,axis = 0)
  temp = total[flag]


array([[ 0.43583628,  0.49213409,  0.40176278, ...,  0.88609765,
         0.85340579,  0.79926413],
       [-0.02099014,  0.00611057, -0.01863085, ..., -0.41538252,
        -0.42495963, -0.43332807],
       [-0.09922292, -0.05876616, -0.03946179, ..., -1.14506216,
        -1.08630719, -1.09355268],
       ...,
       [ 0.27453086,  0.29192261,  0.25315706, ...,  0.55642555,
         0.57798275,  0.51249544],
       [-0.18199526, -0.14293654, -0.12031435, ..., -0.38730982,
        -0.37571432, -0.35455908],
       [-0.46679626, -0.44512369, -0.45576902, ..., -0.42153315,
        -0.4433791 , -0.41066023]])