In [None]:
import os
import sys

In [None]:
import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import umap
import hdbscan

In [None]:
import tqdm

In [None]:
from sklearn.preprocessing import MinMaxScaler, Normalizer

In [None]:
import joblib
from joblib import Parallel, delayed

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
from scipy import signal

In [None]:
import tqdm

In [None]:
from itertools import groupby
from collections import Counter
from operator import itemgetter

In [None]:
from sklearn.preprocessing import scale, StandardScaler

In [None]:
sys.path.append('../utils/')

In [None]:
from path_complexity import obtain_M, get_H
from dlc_helper import DLC_tracking
from preprocess_dlc import *

# Load the UMAP models fitted on the control dataset

In [None]:
# Locations of the serialized UMAP reducers: one is applied to the speed
# features, the other to the wavelet-transformed curvature components.
fn_umap_speeds = '../../results/umap_model_29072023_v1.joblib'
fn_umap_wavelets = '../../results/umap_wavelets_model_31072023.joblib'

In [None]:
# Deserialize the two pre-fit UMAP models from disk.
# NOTE: joblib.load unpickles arbitrary Python objects — only load files from a trusted source.
model_umap_speeds = joblib.load(fn_umap_speeds)
model_umap_wavelets = joblib.load(fn_umap_wavelets)

# Load the drugs data

## Load metadata

In [None]:
df_meta = pd.read_pickle('../../data/amphioxus_metadata_final500.pickle')

In [None]:
# create a filename column to match with the filename column in the dataset
df_meta['filename'] = df_meta['filename_video'].apply(lambda x: x.split('.avi')[0])

In [None]:
df_meta.columns

In [None]:
df_meta.light.unique()

## Load feature data

In [None]:
df = pd.read_hdf('../../results/featureset_v4_29072023.h5', key='features')

In [None]:
# Keep rows whose mouth speed is missing or below 20; rows at or above 20 are dropped.
# NOTE(review): 20 is a magic number whose units are not shown here — consider a named constant.
df = df[(df['speed_MOUTH'].isna())|(df['speed_MOUTH'] < 20)]

In [None]:
df_merged = df.merge(df_meta, how='left', on='filename')

In [None]:
# Drug recordings: age > 50, a drug applied, no light stimulus.
# 'none' (drugs) and 'None' (light) are string labels, not Python None.
# This is a boolean-mask slice of df_merged, not an independent copy.
df_drugs = df_merged[(df_merged['age'] > 50)&(df_merged['drugs']!='none')&(df_merged['light']=='None')]

In [None]:
# Light-stimulus recordings without drugs; no age filter is applied here.
df_light = df_merged[(df_merged['drugs']=='none')&((df_merged['light']=='Light'))]

In [None]:
# Untreated recordings with age < 50 (no drug, no light).
# NOTE(review): age == 50 matches neither this filter nor the `age > 50` filter used for df_drugs.
df_young = df_merged[(df_merged['age'] < 50)&(df_merged['drugs']=='none')&(df_merged['light']=='None')]

In [None]:
len(df_merged.filename.unique())

In [None]:
len(df_drugs.filename.unique())

In [None]:
len(df_light.filename.unique())

In [None]:
len(df_young.filename.unique())

In [None]:
cols_speed = list(df_merged.filter(like='speed').columns)
feats_to_use = cols_speed

### Drugs

In [None]:
df_drugs.drugs.unique()

In [None]:
df_drugs.groupby('drugs')['filename'].nunique()

#### using speeds

In [None]:
df_drugs_in_speeds = df_drugs[feats_to_use]

In [None]:
df_drugs_in_speeds = df_drugs_in_speeds.fillna(-1)

#### using curvatures

In [None]:
df_drugs_curv = df_drugs.filter(like='curv')
df_drugs_curv

In [None]:
# Fit a full-rank PCA on the curvature columns of the drug recordings and project them.
# NOTE(review): the PCA is fit on the drug data itself, while the wavelet UMAP model
# was loaded pre-fit — confirm the components match the ones used when that model was trained.
pca = PCA()
pca_drugs_curv = pca.fit_transform(df_drugs_curv)

In [None]:
plt.plot(np.cumsum(pca.explained_variance_ratio_))

In [None]:
# Parameters for the Morlet wavelet transform

fs = 30  # sampling rate used to turn frequencies into widths (presumably the video frame rate — TODO confirm)
omg0 = 6  # passed to signal.morlet2 as its `w` argument
fs_ny = fs/2  # half the sampling rate; not referenced again in the visible cells

# One frequency channel per integer from 1 to fs inclusive (30 channels, linearly spaced).
# NOTE(review): channels above fs_ny lie beyond the Nyquist frequency.

f_channels = np.arange(1, fs +1, 1)
# Width for each channel: s = w * fs / (2 * pi * f)
widths = omg0*fs / (2*f_channels*np.pi) 

print(f_channels, len(widths))

In [None]:
def wv_transform(df, widths, omg0):
    """Compute Morlet wavelet amplitudes for every feature column of one recording.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'filename', 'frame' and 'drugs'. Every other column is
        treated as a time series and transformed. The frame is NOT modified
        (the previous version dropped the three identifier columns in place).
    widths : sequence of float
        Wavelet widths handed to ``signal.cwt``.
    omg0 : float
        Forwarded to ``signal.morlet2`` as its ``w`` argument.

    Returns
    -------
    pandas.DataFrame
        Same index as ``df``; columns 'filename', 'frames', 'drugs' followed by
        ``len(widths)`` amplitude columns named ``'<feature>_wv<k>'`` per feature.

    Notes
    -----
    ``scipy.signal.cwt`` / ``morlet2`` are deprecated in recent SciPy releases;
    this function needs a SciPy version that still ships them.
    """
    id_cols = ['filename', 'frame', 'drugs']

    # Identifier columns, with 'frame' exposed as 'frames' in the output.
    id_part = df[id_cols].rename(columns={'frame': 'frames'})

    # drop() returns a new frame, so the caller's DataFrame keeps its columns.
    feature_frame = df.drop(columns=id_cols)

    # Collect all wavelet columns first and build the frame once; inserting
    # them one by one fragments the DataFrame.
    wv_columns = {}
    for feat in feature_frame.columns:
        ts = feature_frame[feat].values
        cwtm = signal.cwt(ts, signal.morlet2, widths, w=omg0)
        wv_feat = np.abs(cwtm)

        for f in range(len(widths)):
            # NOTE(review): column '<feat>_wv{f}' takes row f-1, so '_wv0' holds the
            # LAST width and every other column is shifted by one. This looks like an
            # off-by-one, but it is kept as-is because the pre-fit UMAP model consumes
            # these columns by position — confirm against the training pipeline
            # before changing it.
            wv_columns[f'{feat}_wv{f}'] = wv_feat[f - 1, :]

    wv_part = pd.DataFrame(wv_columns, index=df.index)
    return pd.concat([id_part, wv_part], axis=1)

In [None]:
# Attach the first five curvature principal components as pca_0..pca_4.
# df_drugs is a boolean-mask slice of df_merged, so writing columns into it in
# place triggers SettingWithCopyWarning; .assign builds a new frame instead and
# keeps the cell safe to re-run.
df_drugs = df_drugs.assign(**{f'pca_{i}': pca_drugs_curv[:, i] for i in range(5)})

In [None]:
df_drugs_files = df_drugs.groupby(by = 'filename')

In [None]:
# One DataFrame per recording, in the order the groupby yields them.
df_drug_exps = [group for _, group in df_drugs_files]

In [None]:
# Columns handed to the wavelet transform: curvature PCs plus identifier columns.
# Only pca_0..pca_4 are created in this notebook. The previous list also named
# 'pca_5', which DataFrame.filter(items=...) silently ignored, so the selected
# columns are unchanged.
# NOTE(review): if six components were intended, the cell that adds the pca_*
# columns has to produce pca_5 as well.
feats_list = [f'pca_{i}' for i in range(5)]
cols_selected = feats_list + ['filename', 'frame', 'drugs']

# The loop variable is deliberately not called `df`: as a top-level loop target
# that name would overwrite the feature table loaded earlier in the notebook.
df_feats_selected = [df_exp.filter(items=cols_selected) for df_exp in df_drug_exps]

In [None]:
df_feats_selected[0]

In [None]:
# Wavelet-transform every recording in its own joblib task; results come back in input order.
# NOTE(review): n_jobs=40 is hard-coded for one machine; n_jobs=-1 would use the available cores.
df_drugs_wv = Parallel(n_jobs=40, verbose = 5)(delayed(wv_transform)(df, widths, omg0) 
                                                for df in df_feats_selected)

In [None]:
df_drugs_wv_comb = pd.concat(df_drugs_wv)

In [None]:
df_drugs_wv_comb

# Find the UMAP projections

# control

In [None]:
# embedding =  df_results.filter(like = 'umap_raw').values
# embedding.shape

## Drugs

### UMAP x 2

In [None]:
embedding_drugs_speeds = {}
embedding_drugs_wavelets = {}

In [None]:
# Project each drug group's speed features with the pre-fit speed UMAP model.
# Missing speeds are encoded as -1 before the transform.
for drug_name, drug_rows in df_drugs.groupby('drugs'):
    speed_matrix = drug_rows[feats_to_use].fillna(-1).values
    embedding_drugs_speeds[drug_name] = model_umap_speeds.transform(speed_matrix)

In [None]:
# Project each drug group's wavelet features (every column containing 'pca')
# with the pre-fit wavelet UMAP model.
for drug_name, drug_rows in df_drugs_wv_comb.groupby('drugs'):
    wavelet_matrix = drug_rows.filter(like='pca').values
    embedding_drugs_wavelets[drug_name] = model_umap_wavelets.transform(wavelet_matrix)

In [None]:
embedding_drugs_speeds.keys()

In [None]:
embedding_drugs_wavelets.keys()

In [None]:
fig, axes = plt.subplots(2,3, figsize=(23,14))

# Top row: scatter of the speed embedding per drug; bottom row: its 2-D histogram.
for col, (drug_name, emb) in enumerate(embedding_drugs_speeds.items()):
    xs, ys = emb[:, 0], emb[:, 1]
    scatter_ax, density_ax = axes[0][col], axes[1][col]

    scatter_ax.scatter(xs, ys, s=0.2)
    scatter_ax.set_title(drug_name)
    scatter_ax.set_aspect('equal')

    density_ax.hist2d(xs, ys, bins=(150,150), density=True)
    density_ax.set_title(drug_name)
    density_ax.set_aspect('equal')

In [None]:
fig, axes = plt.subplots(2,3, figsize=(23,14))

# Top row: scatter of the wavelet embedding per drug; bottom row: its 2-D histogram.
for col, (drug_name, emb) in enumerate(embedding_drugs_wavelets.items()):
    xs, ys = emb[:, 0], emb[:, 1]
    scatter_ax, density_ax = axes[0][col], axes[1][col]

    scatter_ax.scatter(xs, ys, s=0.2)
    scatter_ax.set_title(drug_name)
    scatter_ax.set_aspect('equal')

    density_ax.hist2d(xs, ys, bins=(150,150), density=True)
    density_ax.set_title(drug_name)
    density_ax.set_aspect('equal')

### Cluster in the 4D embedding (speed UMAP + wavelet UMAP)

In [None]:
fn_hdbscan = f'../../results/hdbscan_4Dumap_model_31072023.joblib'

In [None]:
model_hdbscan = joblib.load(fn_hdbscan)

In [None]:
labels_drugs_dict = {}

In [None]:
# For each drug: stack the speed and wavelet embeddings side by side, rescale
# the stacked array and assign cluster labels with the pre-fit HDBSCAN model.
for drug in tqdm.tqdm(embedding_drugs_wavelets.keys()):
    print(drug)
    # hstack pairs rows by position — assumes both embeddings list the same frames in the same order (TODO confirm)
    embedding_drugs_combo = np.hstack([embedding_drugs_speeds[drug], embedding_drugs_wavelets[drug]])
    # NOTE(review): a fresh MinMaxScaler is fit on each drug's own data, so every drug is rescaled by
    # its own min/max. Confirm this matches the scaling the HDBSCAN model saw when it was fit.
    embedding_drugs_combo_scaled = MinMaxScaler().fit_transform(embedding_drugs_combo)
    # `strengths` is returned by approximate_predict but not used afterwards in this cell.
    test_labels, strengths = hdbscan.approximate_predict(model_hdbscan, embedding_drugs_combo_scaled)
    labels_drugs_dict[drug] = test_labels

In [None]:
c_pal = sns.color_palette('tab10', 10)

In [None]:
fig, axes = plt.subplots(2,3, figsize=(23,14))

# Top row: speed embedding, bottom row: wavelet embedding, one column per drug;
# points are coloured by their predicted HDBSCAN label.
for col, drug in enumerate(labels_drugs_dict.keys()):
    labels = labels_drugs_dict[drug]
    cluster_ids = np.unique(labels)

    dict_clusters_scaled = {f'cluster_{c}': np.sum(labels == c) for c in cluster_ids}
    print(dict_clusters_scaled)

    # Label -1 maps to palette slot 0 and cluster k to slot k+1; the modulo keeps
    # labels beyond the palette size from raising IndexError (colours then repeat).
    c_dict_scaled = {c: c_pal[(c + 1) % len(c_pal)] for c in cluster_ids}
    labels_c_scaled = [c_dict_scaled[lab] for lab in labels]

    axes[0][col].scatter(
        embedding_drugs_speeds[drug][:, 0],
        embedding_drugs_speeds[drug][:, 1], c=labels_c_scaled, s=1)
    axes[0][col].set_title(f'{drug} - speed UMAP')

    axes[1][col].scatter(
        embedding_drugs_wavelets[drug][:, 0],
        embedding_drugs_wavelets[drug][:, 1], c=labels_c_scaled, s=1)
    axes[1][col].set_title(f'{drug} - wavelet UMAP')

    markers = [plt.Line2D([0,0],[0,0],color=color, marker='o', linestyle='') for color in c_dict_scaled.values()]
    # Attach the legend to this drug's own panel. plt.legend() only targets the
    # current axes, so every iteration used to overwrite one legend on the last panel.
    axes[0][col].legend(markers, [str(c) for c in c_dict_scaled], numpoints=1)

for ax in axes.ravel():
    ax.set_aspect('equal', 'datalim')
    

# Results : all in one place

In [None]:
df_control = pd.read_hdf('../../results/UMAP_HDBSCANclustering_withWV_31072023_1135.h5')
df_control.columns

In [None]:
df_control['drug'] = 'control'

In [None]:
df_control.rename(columns={'hdbscan_wv_scaled':'hdbscan_clusters'}, inplace=True)

In [None]:
# Remove the other clustering columns from the control table.
# Reassigning (instead of inplace=True) and ignoring already-missing columns
# lets this cell be re-run without raising KeyError.
df_control = df_control.drop(columns=['hdbscan', 'hdbscan_plus', 'hdbscan_plus_'], errors='ignore')

In [None]:
# Columns present in exactly one of the two tables (symmetric difference).
# NOTE(review): df_results_drugs_combo is first assigned in a later cell of this
# notebook, so this cell raises NameError under Restart & Run All — move it below that cell.
(set(df_control.columns) | set(df_results_drugs_combo)) - (set(df_control.columns) & set(df_results_drugs_combo))

In [None]:
df_drugs

In [None]:
# One results table per drug: both embeddings (first two coordinates each),
# the predicted cluster label and the drug name repeated for every row.
drugs_results = []
for drug, cluster_labels in labels_drugs_dict.items():
    speed_emb = embedding_drugs_speeds[drug]
    wavelet_emb = embedding_drugs_wavelets[drug]
    drugs_results.append(pd.DataFrame({
        'umap_raw_0': speed_emb[:, 0],
        'umap_raw_1': speed_emb[:, 1],
        'umap_wv_0': wavelet_emb[:, 0],
        'umap_wv_1': wavelet_emb[:, 1],
        'hdbscan_clusters': cluster_labels,
        'drug': [drug] * len(cluster_labels),
    }))

In [None]:
df_results_drugs = pd.concat(drugs_results).reset_index(drop=True)
df_results_drugs

In [None]:
# Join the per-frame embedding/cluster results onto the drug feature rows.
#
# Fixes: `reset_indebx` was a typo for `reset_index` (AttributeError).
#
# df_results_drugs is stacked drug by drug in groupby (sorted-key) order, and the
# concat below pairs rows purely by position. df_drugs is therefore put into the
# same drug-grouped order first; mergesort is stable, so the row order inside
# each drug is preserved (and the sort is a no-op if the frame is already grouped).
# NOTE(review): the wavelet embeddings were computed file by file — this assumes
# rows inside each drug are already ordered by filename; confirm.
df_drugs_by_drug = df_drugs.sort_values('drugs', kind='mergesort').reset_index(drop=True)

assert len(df_drugs_by_drug) == len(df_results_drugs), \
    "feature rows and result rows differ in length"
assert (df_drugs_by_drug['drugs'].values == df_results_drugs['drug'].values).all(), \
    "feature rows and result rows are not in the same drug order"

df_results_drugs_combo = pd.concat([df_drugs_by_drug, df_results_drugs], axis=1)
df_results_drugs_combo

In [None]:
df_results_all = pd.concat([df_control, df_results_drugs_combo])

In [None]:
df_results_drugs_combo.to_hdf('../../results/drugs_UMAP_HDBSCANclustering_withWV_08082023_1951.h5', key='results')

# Thigmotaxis 

# Temporal properties : transitions

## time in each cluster

In [None]:
df_files_grouped = df_results_all.groupby('filename')

In [None]:
# Count, for every recording, how many frames were assigned to each cluster.
cluster_usage = []
for name, df_file in df_files_grouped:
    frame_counts = Counter(df_file['hdbscan_clusters'].values)
    usage_row = {'filename': name, 'drug': df_file.drug.unique()[0]}
    usage_row.update({f'cluster_{k}_frames': frame_counts[k] for k in sorted(frame_counts)})
    cluster_usage.append(usage_row)

# Clusters a recording never visits show up as NaN after framing; make them 0.
df_cluster_usage = pd.DataFrame(cluster_usage).fillna(0)
df_cluster_usage

In [None]:
# Share of frames spent in each cluster per drug (each row of `res` sums to 1).
# The string 'filename' column is dropped before aggregating: on recent pandas
# groupby().sum() no longer discards non-numeric columns, and summing strings
# together with the frame counts breaks the row-wise total.
# The intermediate is not named `df`, so the feature table is not overwritten.
frames_per_drug = df_cluster_usage.drop(columns='filename').groupby('drug').sum()
res = frames_per_drug.div(frames_per_drug.sum(axis=1), axis=0)
res.mul(100)

In [None]:
res.sum(axis=1)

In [None]:
cluster_usage

## lengths of cluster stretches 

In [None]:
def make_cluster_motifs_df(fn, df_file):
    """List the uninterrupted stretches each cluster occupies in one recording.

    Parameters
    ----------
    fn : str
        Recording name, copied into the 'filename' column of every row.
    df_file : pandas.DataFrame
        Rows of one recording with 'hdbscan_clusters', 'frame' and 'drug'.

    Returns
    -------
    pandas.DataFrame
        One row per run of consecutive frame numbers sharing a cluster label,
        with columns 'start', 'stop', 'duration' (stop - start), 'cluster',
        'filename' and 'drug'. Rows are grouped by cluster in sorted label
        order and the index restarts at 0 for every cluster.
    """
    labels = df_file['hdbscan_clusters'].values
    frame_ids = df_file['frame'].values
    drug = df_file.drug.unique()[0]

    per_state = []
    for state in np.unique(labels):
        state_frames = frame_ids[labels == state]

        if state_frames.size:
            # A new run starts wherever a frame number is not its predecessor + 1.
            cut_points = np.flatnonzero(np.diff(state_frames) != 1) + 1
            runs = np.split(state_frames, cut_points)
        else:
            runs = []

        n_runs = len(runs)
        per_state.append(pd.DataFrame({
            'start': [run[0] for run in runs],
            'stop': [run[-1] for run in runs],
            'duration': [run[-1] - run[0] for run in runs],
            'cluster': [state] * n_runs,
            'filename': [fn] * n_runs,
            'drug': [drug] * n_runs,
        }))

    return pd.concat(per_state)

In [None]:
# Build the stretch table for every recording in parallel, then stack the tables.
# NOTE(review): n_jobs=40 is hard-coded for one machine.
df_motifs_drugs_all = Parallel(n_jobs=40, verbose = 5)(delayed(make_cluster_motifs_df)(fn, df_fn) 
                                                for fn, df_fn in df_files_grouped)
df_motifs_drugs_combined = pd.concat(df_motifs_drugs_all)

In [None]:
motif_groups_drugs  = df_motifs_drugs_combined.groupby(['drug','cluster'])

In [None]:
# Min / max / mean stretch duration per (drug, cluster).
# String names replace the builtins and np.mean: passing those callables to
# agg is deprecated in recent pandas; the resulting column labels are the same.
motif_groups_drugs.agg({'duration': ['min', 'max', 'mean']})

In [None]:
# One boxplot of stretch durations per (drug, cluster) group, filling a 3 x 8
# grid row by row; all panels share the y axis.
fig, axes = plt.subplots(3,8, figsize=(20, 24), sharey=True)
panels = axes.ravel()  # row-major: panels[i] is axes[i//8][i%8]
for i, (drug_clus, motif_g) in enumerate(motif_groups_drugs):
    print(drug_clus)
    panel = panels[i]
    sns.boxplot(data = motif_g, x='cluster', y='duration', ax =panel)
    panel.set_title(f'cluster {drug_clus[1]}')

## transitions

In [None]:
drugs = df_results_all.drug.unique()
drugs

In [None]:
# Per recording, pair every frame's cluster label ('start') with the label of
# the following frame ('stop'); tables are collected per drug.
cluster_usage_dfs = {drug: [] for drug in drugs}

for name, df_file in df_files_grouped:
    drug = df_file.drug.unique()[0]
    labels = df_file['hdbscan_clusters'].values
    trans_df = pd.DataFrame({'start': list(labels[:-1]), 'stop': list(labels[1:])})
    cluster_usage_dfs[drug].append(trans_df)
    

In [None]:
# Transition probability matrix per drug: rows are the current cluster
# ('start'), columns the next cluster ('stop'); every row sums to 1.
trans_mat_probs = {drug:[] for drug in drugs}
fig, axes = plt.subplots(1,4, figsize=(28,6))

for i, drug in enumerate(cluster_usage_dfs.keys()):
    transition_df = pd.concat(cluster_usage_dfs[drug])
    transition_counts = transition_df.groupby(['start', 'stop']).size().reset_index(name='counts')
    trans_mat_counts = pd.pivot_table(transition_counts, values='counts', index=['start'],
                columns=['stop'])
    trans_mat_counts = trans_mat_counts.fillna(0)

    # axis=0 aligns the row sums with the row index, i.e. divides each row by its
    # own total. Without it, DataFrame.div matches the Series against the COLUMN
    # labels and divides column j by the total of row j, which is not a
    # row-normalised transition matrix.
    row_totals = trans_mat_counts.sum(axis=1)
    trans_mat_probs[drug] = trans_mat_counts.div(row_totals, axis=0)

    sns.heatmap(trans_mat_probs[drug], ax= axes[i])
    axes[i].set_title(drug)
    axes[i].set_aspect('equal')
    

## Path complexity

In [None]:
def calc_path_complexity(filename):
    """Compute the per-frame path-complexity trace (``lH``) for one recording.

    Reads the module-level ``df_results_all`` to locate the recording's
    DeepLabCut result file, interpolates the dorsal keypoints, and runs
    ``obtain_M`` / ``get_H`` on the interpolated NT coordinates.

    Parameters
    ----------
    filename : str
        Value of the 'filename' column identifying the recording.

    Returns
    -------
    pandas.DataFrame or None
        Columns 'filename', 'frame', 'NT_x_interp', 'NT_y_interp', 'NT_x',
        'NT_y' and 'lH' (NaN-padded at both ends by the window). ``None`` is
        returned, with a warning, when the complexity computation fails.
        Errors while loading or interpolating the tracking data still propagate.
    """
    import warnings

    df_result_fn = df_results_all[df_results_all['filename'] == filename]

    # data from DLC
    dlc_path = df_result_fn['dlc_result_file'].unique()[0]
    dlc_folder, dlc_filename = os.path.split(dlc_path)
    dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

    # Interpolate missing datapoints (dorsal): keep frames with fewer than five
    # missing coordinates, then interpolate along the body and along time.
    df_dorsal = dlc_obj.df_data.filter(regex='^(NT_|TT_|D).*(x|y)$')
    df_dorsal_filt = df_dorsal[df_dorsal.isna().sum(axis=1) < 5]
    df_dorsal_x = df_dorsal_filt.filter(like='_x')
    df_dorsal_y = df_dorsal_filt.filter(like='_y')
    df_dorsal_interp_x = interpol_spatial(df_dorsal_x)
    df_dorsal_interp_y = interpol_spatial(df_dorsal_y)
    df_dorsal_x_fin = interpol_temporal(df_dorsal_interp_x)
    df_dorsal_y_fin = interpol_temporal(df_dorsal_interp_y)
    dlc_obj.df_data.loc[df_dorsal_filt.index, 'NT_x_interp'] = df_dorsal_x_fin['NT_x']
    dlc_obj.df_data.loc[df_dorsal_filt.index, 'NT_y_interp'] = df_dorsal_y_fin['NT_y']

    # data from clustering  # need not do this !
    df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')

    framerate = 30
    window = framerate

    # .copy(): the 'lH' column is added below, and writing into a column
    # selection of df_cluster would raise SettingWithCopyWarning.
    df_xy = df_cluster[['filename', 'frame', 'NT_x_interp', 'NT_y_interp', 'NT_x', 'NT_y']].copy()

    try:
        M = obtain_M(df_xy['NT_x_interp'], df_xy['NT_y_interp'], window=window)
        lH, _ = get_H(M)

        # Pad with NaN on both sides (window // 2 in front, the rest behind) so
        # the trace lines up with the frame rows.
        pad_head = np.full(window // 2, np.nan)
        pad_tail = np.full(window - (window // 2), np.nan)
        df_xy['lH'] = np.hstack((pad_head, lH, pad_tail))
        return df_xy

    except Exception as e:
        # Best effort: skip recordings that fail, but say which one and why
        # instead of dropping it silently.
        warnings.warn(f'calc_path_complexity failed for {filename!r}: {e!r}')
        return None

In [None]:
filenames = list(df_results_all.filename.unique())

In [None]:
calc_path_complexity(filenames[8])

In [None]:
# Compute path complexity for every recording in parallel and stack the results.
# calc_path_complexity returns None for recordings that fail; pd.concat silently
# skips None entries (and raises if every entry is None).
# NOTE(review): n_jobs=40 is hard-coded for one machine.
df_lH_all = Parallel(n_jobs=40, verbose = 5)(delayed(calc_path_complexity)(fn) 
                                                for fn in filenames)
df_lH_combined = pd.concat(df_lH_all)
df_lH_combined

In [None]:
df_results_complexity = pd.merge(df_lH_combined, df_results_all, on=['filename','frame'])

# Light

In [None]:
embedding_light = {}

In [None]:
# Project each light-stimulus group's speed features with the pre-fit speed UMAP model.
# `loaded_reducer` is not defined anywhere in this notebook (NameError on a fresh
# kernel). feats_to_use are the speed columns, and the drug cells feed exactly this
# input to model_umap_speeds, so that model is used here.
# NOTE(review): confirm `loaded_reducer` referred to the same saved model.
for name, group in tqdm.tqdm(df_light.groupby('stim_RGB')):
    group_in = group[feats_to_use].fillna(-1)
    embedding_light[name] = model_umap_speeds.transform(group_in.values)

In [None]:
embedding_light.keys()

In [None]:
fig, axes = plt.subplots(1,4, figsize=(32,7))
axes= axes.ravel()

# `embedding` (the control speed-UMAP coordinates) is never assigned in this
# notebook — its only definition is commented out — so it is rebuilt here from
# the control results table.
# NOTE(review): assumes df_control carries the 'umap_raw_*' columns — confirm.
embedding = df_control.filter(like='umap_raw').values

axes[0].scatter(embedding[:, 0],embedding[:, 1], s=0.2)
axes[0].set_title('Control')

# Drop the excluded stimulus before numbering the panels, so no panel is left
# empty and the index cannot run past the last axes.
keys_shown = [key for key in embedding_light.keys() if key != 'v0310000']

for i, key in enumerate(keys_shown):
    axes[i+1].scatter(embedding_light[key][:, 0],embedding_light[key][:, 1], s=0.2)
    axes[i+1].set_title(key)

In [None]:
fig, axes = plt.subplots(1,4, figsize=(32,7), sharex=True, sharey=True)
axes= axes.ravel()

# `embedding` (the control speed-UMAP coordinates) is never assigned in this
# notebook — its only definition is commented out — so it is rebuilt here from
# the control results table.
# NOTE(review): assumes df_control carries the 'umap_raw_*' columns — confirm.
embedding = df_control.filter(like='umap_raw').values

axes[0].hist2d(embedding[:, 0],embedding[:, 1], bins=(150,150), density=True)
axes[0].set_title('Control')

# Drop the excluded stimulus before numbering the panels, so no panel is left
# empty and the index cannot run past the last axes.
keys_shown = [key for key in embedding_light.keys() if key != 'v0310000']

for i, key in enumerate(keys_shown):
    axes[i+1].hist2d(embedding_light[key][:, 0],embedding_light[key][:, 1], bins=(150,150), density=True)
    axes[i+1].set_title(key)

# Age

In [None]:
df_young_in = df_young[feats_to_use]
df_young_in = df_young_in.fillna(-1)

In [None]:
# `loaded_reducer` is not defined anywhere in this notebook (NameError on a fresh
# kernel). The input is the speed feature matrix, so the pre-fit speed UMAP model is used.
# NOTE(review): confirm `loaded_reducer` referred to the same saved model.
embedding_age = model_umap_speeds.transform(df_young_in.values)

In [None]:
# `embedding` (the control speed-UMAP coordinates) is never assigned in this
# notebook — its only definition is commented out — so it is rebuilt here from
# the control results table.
# NOTE(review): assumes df_control carries the 'umap_raw_*' columns — confirm.
embedding = df_control.filter(like='umap_raw').values

fig, axes = plt.subplots(1,2, figsize=(15,7))
axes[0].scatter(embedding[:, 0],embedding[:, 1], s=0.2)
axes[0].set_title('Control')
axes[1].scatter(embedding_age[:, 0],embedding_age[:, 1], s=0.2)
axes[1].set_title('Young larvae')

In [None]:
# `embedding` (the control speed-UMAP coordinates) is never assigned in this
# notebook — its only definition is commented out — so it is rebuilt here from
# the control results table.
# NOTE(review): assumes df_control carries the 'umap_raw_*' columns — confirm.
embedding = df_control.filter(like='umap_raw').values

fig, axes = plt.subplots(1,2, figsize=(15,7))
axes= axes.ravel()

axes[0].hist2d(embedding[:, 0],embedding[:, 1], bins=(150,150), density=True)
axes[0].set_title('Control')

axes[1].hist2d(embedding_age[:, 0],embedding_age[:, 1], bins=(150,150), density=True)
axes[1].set_title('Young larvae')