In [None]:
import os
import sys
import random

In [None]:
import numpy as np     
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import cv2

In [None]:
# Make the project helper modules in ../utils (dlc_helper, features, features_speed,
# preprocess_dlc, video_utils -- imported in the cells below) importable.
# The path is relative to the notebook's working directory.
sys.path.append('../utils/')

In [None]:
from dlc_helper import DLC_tracking

In [None]:
from features import *
from features_speed import *
from preprocess_dlc import *

In [None]:
from video_utils import find_square_bounding

In [None]:
from joblib import Parallel, delayed

In [None]:
import ipywidgets.widgets as widgets
from ipywidgets import interact, interact_manual

# Load the UMAP + HDBSCAN clustering results

In [None]:
df_results_control = pd.read_hdf('../../results/UMAP_HDBSCANclustering_withWV_31072023_1135.h5')

In [None]:
df_results_control.columns

# Plot the UMAP & clustering results

In [None]:
clusters_control = list(df_results_control['hdbscan_wv_scaled'])

In [None]:
embedding = df_results_control.filter(like = 'umap_raw').values
embedding.shape

In [None]:
dict_clusters = {f'cluster_{i}':np.sum(clusters_control==i) for i in list(np.unique(clusters_control))}
dict_clusters

In [None]:
c_pal = sns.color_palette('tab10', 10)
c_dict = {i: c_pal[i+1] for i in np.unique(clusters_control)}
labels_c = [c_dict[lab] for lab in clusters_control]

In [None]:
fig, axes = plt.subplots(1,2, figsize=(15,7))
axes= axes.ravel()
axes[0].scatter(embedding[:, 0],embedding[:, 1], s=0.2)
axes[1].scatter(
    embedding[:, 0],
    embedding[:, 1], c=labels_c, s=1)

markers = [plt.Line2D([0,0],[0,0],color=color, marker='o', linestyle='') for color in c_dict.values()]
plt.legend(markers, c_dict.keys(), numpoints=1)

for ax in axes:
    ax.set_aspect('equal', 'datalim')
    
# fig.savefig('../../results/umap_clustered.png')

In [None]:
df_results_control.groupby('hdbscan_wv_scaled').nunique()

In [None]:
df_results_control.groupby('hdbscan_wv_scaled').count()

# Check feature statistics in each cluster

In [None]:
df_results_control.columns

In [None]:
grouped_feats = df_results_control.groupby('hdbscan_wv_scaled')

## Speed_MOUTH

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    sns.histplot(data=group, x='speed_MOUTH', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    sns.boxplot(data=group, y='speed_MOUTH', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

## Speed_V(entral)

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    group['mean_speeds_ventral'] = group.filter(like='speed_V').mean(axis=1)
    sns.histplot(data=group, x='mean_speeds_ventral', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    group['mean_speeds_ventral'] = group.filter(like='speed_V').mean(axis=1)
    sns.boxplot(data=group, y='mean_speeds_ventral', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

## Speed_D(orsal)

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    group =  group.fillna(value=-1)
    group['mean_speeds_dorsal'] = group.filter(like='speed_D').mean(axis=1)
    sns.histplot(data=group, x='mean_speeds_dorsal', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    group =  group.fillna(value=-1)
    group['mean_speeds_dorsal'] = group.filter(like='speed_D').mean(axis=1)
    sns.boxplot(data=group, y='mean_speeds_dorsal', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

## Speed_NT

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    sns.histplot(data=group, x='speed_NT', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    sns.boxplot(data=group, y='speed_NT', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

## Curvatures

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    group =  group.fillna(value=-1)
    group['mean_curv'] = group.filter(like='curv').abs().mean(axis=1)
    sns.histplot(data=group, x='mean_curv', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    group =  group.fillna(value=-1)
    group['mean_curv'] = group.filter(like='curv').abs().mean(axis=1)
    sns.boxplot(data=group, y='mean_curv', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

## Quirkiness

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    sns.histplot(data=group, x='quirkiness', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

In [None]:
fig, axes = plt.subplots(1,8, figsize = (32,8), sharex=True, sharey=True)
for i, (cluster, group) in enumerate(grouped_feats):
    sns.boxplot(data=group, y='quirkiness', ax=axes[i])
    axes[i].set_title(f'Cluster {cluster}')

# Trajectories

In [None]:
filenames = list(df_results_control.filename.unique())

In [None]:
wid_fn = widgets.SelectMultiple(
    options=filenames,
    value=filenames[:2],
    rows=15,
    description='Filename',
    disabled=False
)

In [None]:
@interact_manual
def plot_trajectory(fns=wid_fn):
    """Scatter the NT (x, y) trajectory of each selected file, coloured by HDBSCAN cluster.

    Prints the video path of every selected file, then draws one panel per file.

    Parameters
    ----------
    fns : sequence of str
        Filenames chosen in the ``wid_fn`` multi-select widget.
    """
    if len(fns) == 0:
        print('No file selected.')
        return

    for fn in fns:
        df_filename = df_results_control[df_results_control['filename'] == fn]
        print(df_filename['path_to_video'].unique()[0])

    n_cols = len(fns)
    # squeeze=False keeps `axes` indexable even when a single file is selected.
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols*8, 8), sharex=True, sharey=True, squeeze=False)

    for ax, fn in zip(axes.ravel(), fns):
        df_result_fn = df_results_control[df_results_control['filename'] == fn]

        # DLC tracking data for this file
        dlc_path = df_result_fn['dlc_result_file'].unique()[0]
        dlc_folder, dlc_filename = os.path.split(dlc_path)
        dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

        # Attach the cluster label of each tracked frame
        df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')
        hue = [c_dict[clus] for clus in df_cluster['hdbscan_wv_scaled']]

        xy = df_cluster[['NT_x', 'NT_y']].values
        ax.scatter(xy[:, 0], xy[:, 1], c=hue, s=2)
        ax.set_aspect('equal')
        ax.set_title(fn)

    plt.show()
    

In [None]:
@interact_manual
def plot_trajectory_line(fns=wid_fn):
    """Draw each selected file's NT trajectory as a line with segments coloured by HDBSCAN cluster.

    Prints the video path of every selected file, then draws one panel per file.

    Parameters
    ----------
    fns : sequence of str
        Filenames chosen in the ``wid_fn`` multi-select widget.
    """
    from matplotlib.collections import LineCollection

    if len(fns) == 0:
        print('No file selected.')
        return

    for fn in fns:
        df_filename = df_results_control[df_results_control['filename'] == fn]
        print(df_filename['path_to_video'].unique()[0])

    n_cols = len(fns)
    # squeeze=False keeps `axes` indexable even when a single file is selected.
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols*8, 8), sharex=True, sharey=True, squeeze=False)

    for ax, fn in zip(axes.ravel(), fns):
        df_result_fn = df_results_control[df_results_control['filename'] == fn]

        # DLC tracking data for this file
        dlc_path = df_result_fn['dlc_result_file'].unique()[0]
        dlc_folder, dlc_filename = os.path.split(dlc_path)
        dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

        # Attach the cluster label of each tracked frame
        df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')
        hue = [c_dict[clus] for clus in df_cluster['hdbscan_wv_scaled']]

        # Segment k joins point k to point k+1 and is drawn with colour hue[k].
        xy = df_cluster[['NT_x', 'NT_y']].values.reshape(-1, 1, 2)
        segments = np.hstack([xy[:-1], xy[1:]])
        ax.add_collection(LineCollection(segments, colors=hue))
        ax.autoscale_view()
        ax.set_title(fn)

    plt.show()

# Temporal properties: transitions

In [None]:
from itertools import groupby

In [None]:
from collections import Counter

## Time in each cluster

In [None]:
df_files_grouped = df_results_control.groupby('filename')

In [None]:
cluster_usage = []
for name, df_file in df_files_grouped:
    clusters_file = df_file['hdbscan_wv_scaled'].values
    test_count = Counter(clusters_file)
    dict_cluster_usage = {}
    dict_cluster_usage['filename'] = name
    for k in sorted(test_count.keys()):
        dict_cluster_usage[f'cluster_{k}_frames'] = test_count[k]
    cluster_usage.append(dict_cluster_usage)

df_cluster_usage  = pd.DataFrame(cluster_usage)      
df_cluster_usage.fillna(0, inplace=True)
df_cluster_usage

In [None]:
df_cluster_usage['acclimitization'] = df_cluster_usage['filename'].apply(lambda x: 1 if ((x.split('_')[3]=='15m0s')|(x.split('_')[3]=='15m3s')) else 0)

In [None]:
df_cluster_usage.groupby('acclimitization').sum()

In [None]:
df = df_cluster_usage.groupby('acclimitization').sum()
res = df.div(df.sum(axis=1), axis=0)
res.mul(100)

In [None]:
# res.mul(100).sum(axis=1)

## Transition counts

### tests

In [None]:
s = "1110002223344555551111"

from itertools import groupby

groups = groupby(s)
result = [(label, sum(1 for _ in group)) for label, group in groups]
result

In [None]:
clusters_file = df_file['hdbscan_wv_scaled'].values
start_list = [x for x in clusters_file[:-1]]
stop_list = [x for x in clusters_file[1:]]

trans_dict = {'start': start_list, 'stop': stop_list}
trans_df = pd.DataFrame(trans_dict)
trans_df

In [None]:
trans_df.groupby(['start', 'stop']).size().reset_index(name='counts')

In [None]:
transition_counts_control = trans_df.groupby(['start', 'stop']).size().reset_index(name='counts')
trans_mat_counts = pd.pivot_table(transition_counts_control, values='counts', index=['start'],
                columns=['stop'])
trans_mat_counts = trans_mat_counts.fillna(0)
trans_mat_probs = trans_mat_counts.div(trans_mat_counts.sum(axis=1))

In [None]:
trans_mat_probs

In [None]:
sns.heatmap(trans_mat_probs)

### Transition matrix – all, acclimatization, non-acclimatization

In [None]:
cluster_usage_acc = []
cluster_usage_non_acc = []
for name, df_file in df_files_grouped:
    clusters_file = df_file['hdbscan_wv_scaled'].values
    start_list = [x for x in clusters_file[:-1]]
    stop_list = [x for x in clusters_file[1:]]

    trans_dict = {'start': start_list, 'stop': stop_list}
    trans_df = pd.DataFrame(trans_dict)
    
    exp_duration = name.split('_')[3]
    if  (exp_duration == '15m0s')|(exp_duration == '15m3s'):
        cluster_usage_acc.append(trans_df)
    elif  (exp_duration == '5m0s')|(exp_duration == '5m3s'):
        cluster_usage_non_acc.append(trans_df)

In [None]:
transition_df_acc = pd.concat(cluster_usage_acc)
transition_df_non_acc = pd.concat(cluster_usage_non_acc)
transition_df = pd.concat(cluster_usage_acc + cluster_usage_non_acc)

In [None]:
transition_counts_acc = transition_df_acc.groupby(['start', 'stop']).size().reset_index(name='counts')
transition_counts_non_acc = transition_df_non_acc.groupby(['start', 'stop']).size().reset_index(name='counts')
transition_counts = transition_df.groupby(['start', 'stop']).size().reset_index(name='counts')

In [None]:
trans_mat_counts_acc = pd.pivot_table(transition_counts_acc, values='counts', index=['start'],
                columns=['stop'])
trans_mat_counts_non_acc = pd.pivot_table(transition_counts_non_acc, values='counts', index=['start'],
                columns=['stop'])
trans_mat_counts = pd.pivot_table(transition_counts, values='counts', index=['start'],
                columns=['stop'])

In [None]:
trans_mat_counts_acc = trans_mat_counts_acc.fillna(0)
trans_mat_counts_non_acc = trans_mat_counts_non_acc.fillna(0)
trans_mat_counts = trans_mat_counts.fillna(0)

In [None]:
trans_mat_probs_acc = trans_mat_counts_acc.div(trans_mat_counts_acc.sum(axis=1))
trans_mat_probs_non_acc = trans_mat_counts_non_acc.div(trans_mat_counts_non_acc.sum(axis=1))
trans_mat_probs = trans_mat_counts.div(trans_mat_counts.sum(axis=1))

In [None]:
trans_mat_probs

In [None]:
sns.heatmap(trans_mat_probs_acc)

In [None]:
sns.heatmap(trans_mat_probs_non_acc)

In [None]:
sns.heatmap(trans_mat_probs)

## Lengths of cluster stretches

In [None]:
from operator import itemgetter

In [None]:
# Columns of `df_file`, left over from the last iteration of the per-file transition
# loop above (hidden state: raises NameError on a fresh kernel if that loop has not run).
df_file.columns

In [None]:
def make_cluster_motifs_df(fn, df_file):
    """Split one file's cluster sequence into bouts of consecutive frames per cluster.

    Parameters
    ----------
    fn : str
        Filename, copied into the ``filename`` column.
    df_file : pandas.DataFrame
        Rows of one file with columns ``frame`` and ``hdbscan_wv_scaled``. Frames are
        assumed to be in ascending order.

    Returns
    -------
    pandas.DataFrame
        One row per bout with columns ``start``, ``stop`` (first/last frame), ``duration``
        (``stop - start``, so a single-frame bout has duration 0), ``cluster`` and
        ``filename``. Rows are grouped by cluster in ascending label order. An empty
        frame with these columns is returned for empty input.
    """
    # Local imports keep the function self-contained (also inside joblib workers).
    from itertools import groupby
    from operator import itemgetter

    columns = ['start', 'stop', 'duration', 'cluster', 'filename']
    clusters_file = df_file['hdbscan_wv_scaled'].values
    frames = df_file['frame'].values

    df_motif = []
    for state in np.unique(clusters_file):
        state_frames = frames[clusters_file == state]
        # Within a run of consecutive frames, (position - frame) is constant, so grouping
        # on that key splits the frames of this state into contiguous bouts.
        runs = [list(map(itemgetter(1), g))
                for _, g in groupby(enumerate(state_frames), lambda x: x[0] - x[1])]
        df_motif.append(pd.DataFrame({
            'start': [r[0] for r in runs],
            'stop': [r[-1] for r in runs],
            'duration': [r[-1] - r[0] for r in runs],
            'cluster': [state] * len(runs),
            'filename': [fn] * len(runs),
        }, columns=columns))

    if not df_motif:
        # pd.concat([]) raises; return an empty, correctly-shaped result instead.
        return pd.DataFrame(columns=columns)
    return pd.concat(df_motif)

    

In [None]:
# One job per file; each returns that file's bout table.
motif_jobs = (delayed(make_cluster_motifs_df)(fn, df_fn) for fn, df_fn in df_files_grouped)
df_motifs_all = Parallel(n_jobs=40, verbose=5)(motif_jobs)
df_motifs_combined = pd.concat(df_motifs_all)

In [None]:
motif_groups  = df_motifs_combined.groupby('cluster')

In [None]:
fig, axes = plt.subplots(1,8, figsize=(20, 8), sharey=True)
for i, (clus, motif_g) in enumerate(motif_groups):
    sns.boxplot(data = motif_g, x='cluster', y='duration', ax =axes[clus+1])

In [None]:
motif_groups.agg({'duration':[min, max, np.mean]})

# Path complexity

In [None]:
from sklearn.preprocessing import scale, StandardScaler

In [None]:
def obtain_M(X, Y, window):
    """Build the standardized time-delay embedding of a 2-D path, to be passed to ``get_H``.

    Row k of the x (resp. y) embedding is ``X[k:k+window]`` for k = 0 .. len(X) - window - 1
    (the final possible window is not included; callers pad ``lH`` with ``window`` NaNs
    to match). Each column (time lag) is then standardized across windows with
    ``StandardScaler``.

    Parameters
    ----------
    X, Y : array-like of float, same length
        x and y coordinates of the path (lists, arrays or pandas Series; Series are
        used positionally).
    window : int
        Embedding window length in samples.

    Returns
    -------
    numpy.ndarray, shape (len(X) - window, window, 2)
        ``M[..., 0]`` is the standardized x embedding, ``M[..., 1]`` the y embedding.

    Raises
    ------
    ValueError
        If fewer than two windows can be formed (``len(X) - window < 2``).
    """
    x_arr = np.asarray(X, dtype=float)
    y_arr = np.asarray(Y, dtype=float)
    n_rows = len(x_arr) - window
    if n_rows < 2:
        raise ValueError(f'need at least window + 2 = {window + 2} samples, got {len(x_arr)}')

    # idx[k, j] = k + j gathers every window at once; stacking rows one by one in a loop
    # re-copied the growing matrix on each iteration (quadratic in the path length).
    idx = np.arange(n_rows)[:, None] + np.arange(window)[None, :]
    Mx = StandardScaler().fit_transform(x_arr[idx])
    My = StandardScaler().fit_transform(y_arr[idx])

    return np.dstack([Mx, My])

In [None]:
def get_H(M):
    """Path-complexity entropy from the singular values of each embedding window.

    For every window matrix ``M[k]`` the singular values are normalised to sum to 1 and
    their base-2 Shannon entropy is taken (see Herbert-Read et al. 2017 on escape path
    complexity). Zero singular values contribute 0 (convention 0*log2(0) = 0).

    Parameters
    ----------
    M : numpy.ndarray, shape (n_windows, window, 2)
        Stack of per-window embedding matrices, as returned by ``obtain_M``.

    Returns
    -------
    local_H : list of float
        Entropy of each window (one value per ``M[k]``).
    H : float
        Sum of the entropy terms over all windows (equals ``sum(local_H)``).
    """
    # Only the singular values are needed; skipping U and V avoids allocating an
    # (n_windows, window, window) array.
    S = np.linalg.svd(M, compute_uv=False)
    with np.errstate(divide='ignore', invalid='ignore'):
        hats = S / S.sum(axis=-1, keepdims=True)
        terms = hats * np.log2(hats)
    # 0 * log2(0) evaluates to NaN; define it as 0. All-zero windows (NaN hats) stay NaN.
    terms[hats == 0] = 0.0
    local_H = -terms.sum(axis=-1)
    H = -terms.sum()
    return list(local_H), H


In [None]:
@interact_manual
def plot_path_complexity(fns=wid_fn):
    """Plot each selected file's NT trajectory coloured by local path complexity (lH).

    Uses a 3-second window at 30 fps. lH of a window is assigned to the window's
    centre frame; frames without a full window are NaN. Prints the filename, then the
    mean lH and total H of each file.

    Parameters
    ----------
    fns : sequence of str
        Filenames chosen in the ``wid_fn`` multi-select widget.

    Returns
    -------
    list of float
        Median lH of every file whose complexity could be computed.
    """
    if len(fns) == 0:
        print('No file selected.')
        return []

    framerate = 30
    window = framerate * 3

    n_cols = len(fns)
    # squeeze=False keeps `axes` indexable even when a single file is selected.
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols*12, 12), sharex=True, sharey=True, squeeze=False)

    complexity = []
    for ax, fn in zip(axes.ravel(), fns):
        print(fn)

        df_result_fn = df_results_control[df_results_control['filename'] == fn]

        # DLC tracking data for this file
        dlc_path = df_result_fn['dlc_result_file'].unique()[0]
        dlc_folder, dlc_filename = os.path.split(dlc_path)
        dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

        # Restrict to frames present in the clustering results, dropping untracked ones.
        df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')
        df_xy = df_cluster[['NT_x', 'NT_y']].dropna().copy()

        try:
            M = obtain_M(df_xy['NT_x'], df_xy['NT_y'], window=window)
            lH, H = get_H(M)
            # Centre each window's lH on its middle frame (same alignment as calc_path_complexity).
            df_xy['lH'] = np.hstack((np.full(window // 2, np.nan), lH,
                                     np.full(window - window // 2, np.nan)))
        except Exception as e:
            # Best effort per file: report and continue with the next one.
            print(e, fn)
            continue

        ax.scatter(df_xy['NT_x'], df_xy['NT_y'], c=df_xy['lH'], s=2, cmap='jet')
        ax.set_title(fn)

        print(df_xy.lH.mean(), H)
        complexity.append(df_xy.lH.median())

    return complexity

In [None]:
def calc_path_complexity(filename):
    """Per-frame local path complexity (lH) of the interpolated NT keypoint for one file.

    Dorsal keypoints (columns matching ``^(NT_|TT_|D).*(x|y)$``) are interpolated
    spatially then temporally with ``interpol_spatial`` / ``interpol_temporal``, but only
    on frames with fewer than 5 missing coordinates; the other frames keep NaN in
    ``NT_x_interp`` / ``NT_y_interp``. lH is computed with a 1-second window (30 fps)
    and assigned to each window's centre frame.

    Parameters
    ----------
    filename : str
        Value of the ``filename`` column in ``df_results_control``.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``filename, frame, NT_x_interp, NT_y_interp, NT_x, NT_y, lH``, or None
        if the complexity could not be computed (a warning with the reason is emitted;
        inside joblib workers it may only reach the server log).
    """
    import warnings

    df_result_fn = df_results_control[df_results_control['filename'] == filename]

    # DLC tracking data for this file
    dlc_path = df_result_fn['dlc_result_file'].unique()[0]
    dlc_folder, dlc_filename = os.path.split(dlc_path)
    dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

    # Interpolate missing dorsal datapoints on frames with < 5 missing coordinates.
    df_dorsal = dlc_obj.df_data.filter(regex='^(NT_|TT_|D).*(x|y)$')
    df_dorsal_filt = df_dorsal[df_dorsal.isna().sum(axis=1) < 5]
    df_dorsal_x_fin = interpol_temporal(interpol_spatial(df_dorsal_filt.filter(like='_x')))
    df_dorsal_y_fin = interpol_temporal(interpol_spatial(df_dorsal_filt.filter(like='_y')))
    dlc_obj.df_data.loc[df_dorsal_filt.index, 'NT_x_interp'] = df_dorsal_x_fin['NT_x']
    dlc_obj.df_data.loc[df_dorsal_filt.index, 'NT_y_interp'] = df_dorsal_y_fin['NT_y']

    # Merge to attach `filename` and keep only frames present in the clustering results.
    df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')

    framerate = 30
    window = framerate

    # .copy() so that adding `lH` writes to an independent frame, not a view of df_cluster.
    df_xy = df_cluster[['filename', 'frame', 'NT_x_interp', 'NT_y_interp', 'NT_x', 'NT_y']].copy()

    try:
        M = obtain_M(df_xy['NT_x_interp'], df_xy['NT_y_interp'], window=window)
        lH, H = get_H(M)
        df_xy['lH'] = np.hstack((np.full(window // 2, np.nan), lH,
                                 np.full(window - window // 2, np.nan)))
    except Exception as e:
        # NOTE(review): NaN left in NT_*_interp (frames with >= 5 missing coordinates)
        # typically makes the SVD fail here, dropping the whole file.
        warnings.warn(f'path complexity failed for {filename}: {e!r}')
        return None
    return df_xy

In [None]:
test_lH = calc_path_complexity(filenames[-8])

In [None]:
test_lH['lH'].median()

In [None]:
filenames[0]

In [None]:
df_lH_all[-8].lH.median()

In [None]:
df_lH_all = Parallel(n_jobs=40, verbose = 5)(delayed(calc_path_complexity)(fn) 
                                                for fn in filenames)
df_lH_combined = pd.concat(df_lH_all)
df_lH_combined

In [None]:
df_results_merged_complexity = pd.merge(df_lH_combined, df_results_control, on=['filename','frame'])

In [None]:
df_results_merged_complexity.groupby('hdbscan_wv_scaled').agg({'lH':np.nanmean})

### stats

In [None]:
import itertools

In [None]:
from scipy.stats import kruskal, mannwhitneyu

In [None]:
# Non-NaN lH values of every frame, keyed by HDBSCAN cluster label.
dict_clus_lH = {
    label: frames['lH'].dropna().values
    for label, frames in df_results_merged_complexity.groupby('hdbscan_wv_scaled')
}

In [None]:
kruskal(arr_clus_lH[1],arr_clus_lH[2],arr_clus_lH[3],arr_clus_lH[4])

In [None]:
df_mwu_stat = pd.DataFrame(index=[f'clus_{i}' for i in dict_clus_lH.keys()], columns=[f'clus_{i}' for i in dict_clus_lH.keys()])
df_mwu_pval = pd.DataFrame(index=[f'clus_{i}' for i in dict_clus_lH.keys()], columns=[f'clus_{i}' for i in dict_clus_lH.keys()])
for clus1, clus2 in itertools.product(dict_clus_lH.keys(),dict_clus_lH.keys()):
    mwu_results = mannwhitneyu(dict_clus_lH[clus1],dict_clus_lH[clus2])
    df_mwu_pval.loc[f'clus_{clus1}',f'clus_{clus2}'] = mwu_results[1]
    df_mwu_stat.loc[f'clus_{clus1}',f'clus_{clus2}'] = mwu_results[0]

In [None]:
df_mwu

In [None]:
df_results_merged_complexity['acclimitization'] = df_results_merged_complexity['filename'].apply(lambda x: 1 if ((x.split('_')[3]=='15m0s')|(x.split('_')[3]=='15m3s')) else 0)

In [None]:
df_results_merged_complexity.groupby(['acclimitization']).agg({'lH':np.mean})

In [None]:
df_results_merged_complexity.groupby(['acclimitization','hdbscan_wv_scaled']).agg({'lH':np.nanmean})

## what's cluster 3?

In [None]:
# Distinct frames assigned to cluster 3 in each file. (nunique's first positional
# argument is `dropna`, so the column has to be selected before calling it.)
df_results_merged_complexity.groupby('hdbscan_wv_scaled').get_group(3).groupby('filename')['frame'].nunique()

In [None]:
fig, axes = plt.subplots(3,3, figsize=(9,9))
axes = axes.ravel()
cluster6_by_file = df_results_merged_complexity.groupby('hdbscan_wv_scaled').get_group(6).groupby(['filename'])
# For at most 9 files (one panel each), plot the NT path of every stretch of more than
# two rows of cluster 6 whose index labels differ by less than 2 from the previous row.
for ax, (name, data) in zip(axes, cluster6_by_file):
    print(name)
    print(data['frame'])

    stretches = [[]]
    prev_idx = None
    for idx in data.index:
        if prev_idx is not None and abs(prev_idx - idx) >= 2:
            stretches.append([])
        stretches[-1].append(idx)
        prev_idx = idx

    for stretch in stretches:
        if len(stretch) > 2:
            segment = data.loc[stretch[0]:stretch[-1]]
            ax.plot(segment['NT_x'], segment['NT_y'])

## Test path complexity calculation

### Straight line

In [None]:
window = 30

x = np.linspace(0, 10, 1000)
y = x * 5

M = obtain_M(x, y, window = window)

lH,H = get_H(M)

lH_padded = np.hstack((lH, np.array([np.nan]*window)))

np.mean(lH_padded)

In [None]:
plt.plot(x, y)

### Random walk path 

In [None]:
#setting up steps for simulating 2D
dims = 2
step_n = 1000
step_set = [-1, 0, 1]
origin = np.zeros((1,dims))
#Simulate steps in 2D
step_shape = (step_n,dims)
steps = np.random.choice(a=step_set, size=step_shape)
path = np.concatenate([origin, steps]).cumsum(0)

path.shape

In [None]:
plt.plot(path[:,0], path[:,1])

In [None]:
M = obtain_M(path[:,0], path[:,1], window = window)

lH,H = get_H(M)

lH_padded = np.hstack((lH, np.array([np.nan]*window)))

np.mean(lH_padded)

In [None]:
import math
pi = math.pi

def PointsInCircum(r,n=1000):
    """Return ``(x, y)`` lists of ``n + 1`` points evenly spaced on a circle of radius ``r``.

    Angles run from 0 to 2*pi inclusive, so the first and last points coincide.
    """
    angles = [2*pi/n*k for k in range(n + 1)]
    xs = [math.cos(a)*r for a in angles]
    ys = [math.sin(a)*r for a in angles]
    return xs, ys

In [None]:
path_cir = PointsInCircum(5)

In [None]:
plt.plot(path_cir[0], path_cir[1])

In [None]:
M = obtain_M(path_cir[0], path_cir[1], window = window)

lH,H = get_H(M)

lH_padded = np.hstack((lH, np.array([np.nan]*window)))

np.mean(lH_padded)

# Discrete HMM

In [None]:
from hmmlearn.hmm import MultinomialHMM # Change to CategoricalHMM if using version 0.28 or above

In [None]:
df_files_grouped = df_results_control.groupby('filename')

# One observation sequence per file, concatenated in groupby (sorted filename) order.
# Labels are shifted by +1 so the lowest label (presumably -1, HDBSCAN noise) becomes 0,
# giving the HMM non-negative symbols. Computed on the values directly rather than by
# writing a new column into the groupby group (chained assignment).
X = []
lengths = []
for name, group in df_files_grouped:
    X.append(group['hdbscan_wv_scaled'].values + 1)
    lengths.append(len(group))
X = np.concatenate(X).reshape(-1,1)

In [None]:
X.shape

In [None]:
hmm_discrete = MultinomialHMM(n_components=8)
hmm_discrete.fit(X, lengths)

In [None]:
if hmm_discrete.monitor_.converged:
    print("Model converged")

In [None]:
sns.heatmap(hmm_discrete.emissionprob_)

In [None]:
sns.heatmap(hmm_discrete.transmat_)

In [None]:
X_predicted = hmm_discrete.predict(X, lengths)

In [None]:
df_results_control['hmm_discrete'] = X_predicted

In [None]:
@interact_manual
def plot_hmm_line(fns=wid_fn):
    """Draw each selected file's NT trajectory with segments coloured by discrete-HMM state.

    Prints the video path of every selected file, then draws one panel per file.

    Parameters
    ----------
    fns : sequence of str
        Filenames chosen in the ``wid_fn`` multi-select widget.
    """
    from matplotlib.collections import LineCollection

    if len(fns) == 0:
        print('No file selected.')
        return

    for fn in fns:
        df_filename = df_results_control[df_results_control['filename'] == fn]
        print(df_filename['path_to_video'].unique()[0])

    n_cols = len(fns)
    # squeeze=False keeps `axes` indexable even when a single file is selected.
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols*8, 8), sharex=True, sharey=True, squeeze=False)

    for ax, fn in zip(axes.ravel(), fns):
        df_result_fn = df_results_control[df_results_control['filename'] == fn]

        # DLC tracking data for this file
        dlc_path = df_result_fn['dlc_result_file'].unique()[0]
        dlc_folder, dlc_filename = os.path.split(dlc_path)
        dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

        # Attach the HMM state of each tracked frame. State s gets palette colour s
        # (same colour as c_dict[s - 1]), without needing s - 1 to be a cluster label.
        df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')
        hue = [c_pal[state] for state in df_cluster['hmm_discrete']]

        # Segment k joins point k to point k+1 and is drawn with colour hue[k].
        xy = df_cluster[['NT_x', 'NT_y']].values.reshape(-1, 1, 2)
        segments = np.hstack([xy[:-1], xy[1:]])
        ax.add_collection(LineCollection(segments, colors=hue))
        ax.autoscale_view()
        ax.set_title(fn)

    plt.show()