In [None]:
import os
import sys
import random

In [None]:
import numpy as np     
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import cv2

In [None]:
# Make the project-local helpers (dlc_helper, video_utils) importable. The membership check keeps
# the cell idempotent: re-running it no longer appends duplicate entries to sys.path.
UTILS_DIR = '../utils/'
if UTILS_DIR not in sys.path:
    sys.path.append(UTILS_DIR)

In [None]:
from dlc_helper import DLC_tracking

In [None]:
from video_utils import find_square_bounding

In [None]:
import ipywidgets.widgets as widgets
from ipywidgets import interact, interact_manual

# Import the results

In [None]:
# Feature table that was fed to UMAP; later cells attach the embedding and cluster labels to it.
df_results_control = pd.read_hdf(
    path_or_buf='../../results/featureset_used_for_UMAPclustering_18072023.h5',
)

In [None]:
# Column index of the feature table (DataFrame.keys() is the column Index).
df_results_control.keys()

In [None]:
# UMAP coordinates and HDBSCAN labels produced by the clustering run.
df_results = pd.read_hdf(
    path_or_buf='../../results/UMAP_HDBSCANclustering_29072023_1832.h5',
    key='features_with_results',
)

In [None]:
# df_results_control.rename(columns={'hdbscan_clusters':'hdbscan_old', 'umap_raw_0':'umap_old_raw_0', 'umap_raw_1':'umap_old_raw_1'}, inplace=True)

In [None]:
# df_results_control = df_results_control.join(df_results)

In [None]:
# Attach the UMAP coordinates and HDBSCAN labels to the feature table.
# Values are aligned on the index, so rows absent from df_results would end up as NaN.
df_results_control = df_results_control.assign(
    umap_raw_0=df_results['umap_raw_0'],
    umap_raw_1=df_results['umap_raw_1'],
    hdbscan_clusters=df_results['hdbscan'],
    # alternative labelling: hdbscan_clusters=df_results['hdbscan_plus'],
)

# Plot the UMAP & clustering results

In [None]:
# Per-row HDBSCAN labels as a plain Python list.
clusters_control = df_results_control['hdbscan_clusters'].tolist()

In [None]:
# Every column whose name contains 'umap_raw', as a (n_rows, n_matching_columns) array.
embedding = df_results_control.filter(like='umap_raw').to_numpy()
embedding.shape

In [None]:
# Number of rows carrying each HDBSCAN label (presumably -1 marks HDBSCAN noise).
# `clusters_control` is a list, so `clusters_control == i` evaluates to a single False rather than an
# element-wise mask, and every count came out 0. Count with np.unique instead.
cluster_ids, cluster_sizes = np.unique(clusters_control, return_counts=True)
dict_clusters = {f'cluster_{i}': n for i, n in zip(cluster_ids, cluster_sizes)}
dict_clusters

In [None]:
# One tab10 colour per cluster label. The +1 offset gives label -1 the first colour.
# Only labels -1..8 fit in the 10-colour palette.
c_pal = sns.color_palette('tab10', 10)
c_dict = {}
for label in np.unique(clusters_control):
    c_dict[label] = c_pal[label + 1]
# Per-row colour, in the same order as clusters_control / embedding.
labels_c = [c_dict[lab] for lab in clusters_control]

In [None]:
# Left: raw UMAP embedding. Right: the same embedding coloured by HDBSCAN cluster.
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
axes = axes.ravel()
axes[0].scatter(embedding[:, 0], embedding[:, 1], s=0.2)
axes[0].set_title('UMAP embedding')
axes[1].scatter(embedding[:, 0], embedding[:, 1], c=labels_c, s=1)
axes[1].set_title('UMAP embedding coloured by HDBSCAN cluster')

# Proxy artists give one legend entry per cluster label. Attaching the legend to axes[1]
# explicitly avoids depending on pyplot's implicit "current axes".
markers = [plt.Line2D([0, 0], [0, 0], color=color, marker='o', linestyle='') for color in c_dict.values()]
axes[1].legend(markers, c_dict.keys(), numpoints=1)

for ax in axes:
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_aspect('equal', 'datalim')

# fig.savefig('../../results/umap_clustered.png')

In [None]:
# Number of distinct values of every column within each cluster.
df_results_control.groupby(by='hdbscan_clusters').agg('nunique')

In [None]:
# Non-null count of every column within each cluster.
df_results_control.groupby(by='hdbscan_clusters').agg('count')

# Check feature statistics in each cluster

In [None]:
# Column index after the clustering results were attached.
df_results_control.keys()

In [None]:
# df_feats_with_clusters = pd.merge(df_results_control, df_feats, on=['filename', 'frame'])

In [None]:
# Rows grouped by HDBSCAN label (sorted by label), reused by every per-cluster plot below.
grouped_feats = df_results_control.groupby(by='hdbscan_clusters', sort=True)

## Speed_MOUTH

In [None]:
# Histogram of mouth speed, one panel per cluster. The panel count follows the number of
# clusters instead of a hard-coded 4, which broke or left panels empty for other cluster counts.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharex=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    sns.histplot(data=group, x='speed_MOUTH', ax=ax)
    ax.set_title(f'Cluster {cluster}')

In [None]:
# Box plot of mouth speed per cluster. The y axis is shared (as in the other box-plot cells) so
# boxes are directly comparable; sharing x is meaningless for a single-box panel.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    sns.boxplot(data=group, y='speed_MOUTH', ax=ax)
    ax.set_title(f'Cluster {cluster}')

## Speed_V(entral)

In [None]:
# Histogram of the row-wise mean over all 'speed_V*' (ventral) columns, one panel per cluster.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharex=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    # .assign builds a new frame. Assigning a column onto a groupby slice can raise SettingWithCopyWarning.
    group = group.assign(mean_speeds_ventral=group.filter(like='speed_V').mean(axis=1))
    sns.histplot(data=group, x='mean_speeds_ventral', ax=ax)
    ax.set_title(f'Cluster {cluster}')

In [None]:
# Box plot of the row-wise mean ventral speed ('speed_V*' columns) per cluster, shared y axis.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    # .assign builds a new frame. Assigning a column onto a groupby slice can raise SettingWithCopyWarning.
    group = group.assign(mean_speeds_ventral=group.filter(like='speed_V').mean(axis=1))
    sns.boxplot(data=group, y='mean_speeds_ventral', ax=ax)
    ax.set_title(f'Cluster {cluster}')

## Speed_D(orsal)

In [None]:
# Histogram of the row-wise mean over all 'speed_D*' (dorsal) columns, one panel per cluster.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    # mean() skips NaN. Filling NaN with -1 first dragged the mean of partially tracked rows
    # towards -1. Rows with no dorsal speed at all now give NaN and are left out of the histogram.
    group = group.assign(mean_speeds_dorsal=group.filter(like='speed_D').mean(axis=1))
    sns.histplot(data=group, x='mean_speeds_dorsal', ax=ax)
    ax.set_title(f'Cluster {cluster}')

In [None]:
# Box plot of the row-wise mean dorsal speed ('speed_D*' columns) per cluster, shared y axis.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    # mean() skips NaN. A -1 fill value would bias the mean of partially tracked rows.
    group = group.assign(mean_speeds_dorsal=group.filter(like='speed_D').mean(axis=1))
    sns.boxplot(data=group, y='mean_speeds_dorsal', ax=ax)
    ax.set_title(f'Cluster {cluster}')

## Speed_NT

In [None]:
# Histogram of 'speed_NT', one panel per cluster. The x axis is shared (like the other
# histograms) so distributions line up across clusters.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    sns.histplot(data=group, x='speed_NT', ax=ax)
    ax.set_title(f'Cluster {cluster}')

In [None]:
# Box plot of 'speed_NT' per cluster, shared y axis.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    sns.boxplot(data=group, y='speed_NT', ax=ax)
    ax.set_title(f'Cluster {cluster}')

## Curvatures

In [None]:
# Histogram of the row-wise mean absolute value over all '*curv*' columns, one panel per cluster.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    # No fillna(-1): followed by .abs() it turned every missing curvature into 1, inflating the
    # mean. mean() skips NaN on its own.
    group = group.assign(mean_curv=group.filter(like='curv').abs().mean(axis=1))
    sns.histplot(data=group, x='mean_curv', ax=ax)
    ax.set_title(f'Cluster {cluster}')

In [None]:
# Box plot of the row-wise mean absolute curvature ('*curv*' columns) per cluster, shared y axis.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    # No fillna(-1): with .abs() it turned missing values into 1. mean() skips NaN itself.
    group = group.assign(mean_curv=group.filter(like='curv').abs().mean(axis=1))
    sns.boxplot(data=group, y='mean_curv', ax=ax)
    ax.set_title(f'Cluster {cluster}')

## Quirkiness

In [None]:
# Histogram of 'quirkiness', one panel per cluster, shared axes.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    sns.histplot(data=group, x='quirkiness', ax=ax)
    ax.set_title(f'Cluster {cluster}')

In [None]:
# Box plot of 'quirkiness' per cluster, shared y axis.
n_groups = grouped_feats.ngroups
fig, axes = plt.subplots(1, n_groups, figsize=(4.5 * n_groups, 4), sharey=True, squeeze=False)
axes = axes.ravel()
for ax, (cluster, group) in zip(axes, grouped_feats):
    sns.boxplot(data=group, y='quirkiness', ax=ax)
    ax.set_title(f'Cluster {cluster}')

# Trajectories

In [None]:
# Distinct recordings, in order of first appearance in the results table.
filenames = df_results_control['filename'].unique().tolist()

In [None]:
# Multi-select of recordings for the trajectory viewer; the first two recordings are preselected.
wid_fn = widgets.SelectMultiple(
    options=filenames, value=filenames[:2], rows=15,
    description='Filename', disabled=False,
)

In [None]:
@interact_manual
def plot_trajectory(fns=wid_fn):
    """Scatter the ``NT`` keypoint positions of each selected recording, coloured by HDBSCAN cluster.

    For every selected filename, the DeepLabCut output (``dlc_result_file``) is loaded with
    ``DLC_tracking`` and merged on ``frame`` with that recording's clustering rows. The
    ``NT_x``/``NT_y`` positions are then plotted in one panel per recording. Colours come from
    ``c_dict``, so panels, legend and the UMAP figure all use the same cluster colours.

    Parameters
    ----------
    fns : sequence of str
        Filenames selected in the ``wid_fn`` multi-select widget.
    """
    fns = list(fns)
    if not fns:
        print('Select at least one filename.')
        return

    n_cols = len(fns)
    # squeeze=False keeps `axes` an array even when a single recording is selected.
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 8, 8), sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()
    # Legend proxies are identical for every panel, so build them once.
    markers = [plt.Line2D([0, 0], [0, 0], color=color, marker='o', linestyle='') for color in c_dict.values()]

    for ax, fn in zip(axes, fns):
        df_result_fn = df_results_control[df_results_control['filename'] == fn]
        print(df_result_fn['path_to_video'].unique()[0])

        # Data from DLC.
        dlc_path = df_result_fn['dlc_result_file'].unique()[0]
        dlc_folder, dlc_filename = os.path.split(dlc_path)
        dlc_obj = DLC_tracking(dlc_filename, dlc_folder)

        # Data from clustering, joined frame by frame.
        df_cluster = pd.merge(dlc_obj.df_data, df_result_fn, on='frame')

        # palette=c_dict pins each label to its legend colour. A named palette would instead assign
        # colours by the labels present in this file, which disagrees with the legend when one is absent.
        sns.scatterplot(data=df_cluster, x='NT_x', y='NT_y', s=2, hue='hdbscan_clusters',
                        palette=c_dict, legend=False, ax=ax)
        ax.legend(markers, c_dict.keys(), numpoints=1)
        ax.set_aspect('equal')
        ax.set_title(fn)
    plt.show()
    

# Try annotating the video

In [None]:
# Cluster label -> behaviour name. NOTE(review): labels 2 and 3 both map to 'left', and this dict
# is not used anywhere else in the notebook — confirm the intended mapping.
dict_swims = dict(zip(range(-1, 4), ['', 'right', 'swims', 'left', 'left']))

In [None]:
# Write a copy of a video with each frame's label (from `state_seq`) and its DLC bounding box drawn on.
# NOTE(review): `state_seq` is not defined anywhere in this notebook, and `dict_bbox` is only built in
# the "Checking frames" section below. Both must exist before this cell runs. TODO confirm where
# state_seq comes from.
video_to_annotate = ""  # TODO: set the path of the video to annotate
cap = cv2.VideoCapture(video_to_annotate)
if not cap.isOpened():
    cap.release()
    raise OSError(f'Could not open video {video_to_annotate!r}')

# Take fps and frame size from the source. VideoWriter typically writes nothing for frames whose
# size differs from the declared frameSize, so a hard-coded size is fragile.
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to 30 fps if the container reports 0
frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('../../results/output.avi', fourcc, fps, frame_size)

try:
    # 0-based frame index, used as the key into dict_bbox (as in get_rois_per_video) and,
    # presumably, as the position in state_seq.
    for frame_idx in range(len(state_seq)):
        ret, frame = cap.read()
        if not ret:
            break  # end of video or decode failure; also prevents looping forever
        try:
            x, y, w, h = (int(v) for v in dict_bbox[frame_idx])
        except KeyError:
            print(f'missing key {frame_idx}')
            continue
        except ValueError:
            # e.g. NaN coordinates for a frame without a valid box
            print(f'invalid bounding box for frame {frame_idx}')
            continue
        cv2.putText(frame, str(state_seq[frame_idx]), (x, y - 15), cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 2)
        cv2.rectangle(frame, (x - 5, y - 5), (x + w + 5, y + h + 5), (0, 255, 0), 3)
        out.write(frame)
finally:
    # Release both handles even if drawing fails. No windows are opened, so destroyAllWindows
    # (which raises on headless OpenCV builds) is not needed.
    cap.release()
    out.release()

# Checking frames

In [None]:
# Pick one recording and load its DLC tracking to obtain per-frame bounding boxes.
# NOTE(review): absolute, machine-specific path — consider a configurable data root.
test_path_to_video = '/media/athira/Amphioxus1/20180724/Exp_20180724_173426_1_15m0s_None_None_None/20180724_173426_1_15m0s_None_None_None_INVERTED.avi'
df_test = df_results_control.loc[df_results_control['path_to_video'].eq(test_path_to_video)]
path_to_dlc_coords = df_test['dlc_result_file'].iloc[0]
dlc_folder, dlc_filename = os.path.split(path_to_dlc_coords)
dlc_obj = DLC_tracking(dlc_filename, dlc_folder)
dict_bbox = dlc_obj.find_bbox_dlc()

In [None]:
# For every frame of the test recording, compute the square ROI (max 150x150) around its DLC box.
# dict_rois maps frame index -> [x, y, w, h] of that ROI. Frames whose box cannot be converted are
# reported and skipped.
dict_rois = {}
for key, bbox in dict_bbox.items():
    x, y, w, h = bbox
    try:
        x_new, y_new, w_new, h_new = find_square_bounding(
            int(x), int(y), int(w), int(h), height_max=150, width_max=150)
        dict_rois[key] = [x_new, y_new, w_new, h_new]
    except ValueError:
        # e.g. int() on a NaN coordinate
        print(f'Value Error encountered for frame {key}')
    except Exception as e:
        print(f'Error encountered for frame {key}: {e}')

In [None]:
# Number of distinct values of every column within each cluster, for the test recording only.
df_test.groupby(by='hdbscan_clusters').agg('nunique')

In [None]:
# Grid of randomly sampled cropped frames: one column per cluster, one row per sample.
N_SAMPLES_PER_CLUSTER = 10
SAMPLE_SEED = 0  # fixed so the grid is reproducible on re-run
# Pre-extracted PNG frames, presumably of test_path_to_video (same experiment id) — TODO confirm.
FRAMES_DIR = '/data/temp/athira/amphi_frames_19072023/Exp_20180724_173426'

df_grouped = df_test.groupby('hdbscan_clusters')
n_clusters = df_grouped.ngroups
fig, axes = plt.subplots(N_SAMPLES_PER_CLUSTER, n_clusters,
                         figsize=(5 * n_clusters, 5 * N_SAMPLES_PER_CLUSTER), squeeze=False)
for col, (clus, group) in enumerate(df_grouped):
    axes[0][col].set_title(f'Cluster {clus}')
    # sample(n) raises when a cluster has fewer than n rows, which dropped small clusters entirely.
    samples = group.sample(min(N_SAMPLES_PER_CLUSTER, len(group)), random_state=SAMPLE_SEED)
    for row, sample_f in enumerate(samples['frame']):
        image_path = os.path.join(FRAMES_DIR, f'frame_{sample_f}.png')
        if not os.path.isfile(image_path):
            continue
        if sample_f not in dict_rois:
            print(clus, f'no ROI for frame {sample_f}')
            continue
        try:
            x, y, w, h = (int(v) for v in dict_rois[sample_f])
            image = plt.imread(image_path)
            axes[row][col].imshow(image[y:y + h, x:x + w])
        except Exception as e:
            # best effort: one unreadable frame should not hide the rest of the cluster
            print(clus, sample_f, e)


In [None]:
# Other recordings worth inspecting (e.g. as test_path_to_video). Kept as a list of strings so the
# cell is valid Python; bare paths in a code cell are a SyntaxError on Run All.
candidate_video_paths = [
    '/media/athira/Amphioxus1/20180720/Exp_20180720_131830_1_5m0s_None_None_None/20180720_131830_1_5m0s_None_None_None_INVERTED.avi',
    '/media/athira/Amphioxus1/20180720/Exp_20180720_151943_1_5m0s_None_None_None/20180720_151943_1_5m0s_None_None_None_INVERTED.avi',
    '/media/athira/Amphioxus1/20180720/Exp_20180720_152448_1_5m0s_None_None_None/20180720_152448_1_5m0s_None_None_None_INVERTED.avi',
    '/media/athira/Amphioxus1/20180724/Exp_20180724_161122_1_15m0s_None_None_None/20180724_161122_1_15m0s_None_None_None_INVERTED.avi',
    '/media/athira/Amphioxus1/20180724/Exp_20180724_163657_1_5m0s_None_None_None/20180724_163657_1_5m0s_None_None_None_INVERTED.avi',
    '/media/athira/Amphioxus1/20180724/Exp_20180724_142913_1_15m0s_None_None_None/20180724_142913_1_15m0s_None_None_None_INVERTED.avi',
]


# Picking frames

In [None]:
def _crop_square_roi(image, dict_bbox, frame_idx):
    """Return the square ROI of ``image`` around the DLC box of ``frame_idx``, or None on failure.

    The box ``dict_bbox[frame_idx]`` (x, y, w, h) is converted to ints and passed through
    ``find_square_bounding`` (max 150x150). The reason is printed when the frame has no box or the
    conversion/cropping fails.
    """
    if frame_idx not in dict_bbox:
        print(f'No bounding box for frame {frame_idx}')
        return None
    x, y, w, h = dict_bbox[frame_idx]
    try:
        x_new, y_new, w_new, h_new = find_square_bounding(
            int(x), int(y), int(w), int(h), height_max=150, width_max=150)
        return image[y_new:y_new + h_new, x_new:x_new + w_new]
    except ValueError:
        # e.g. int() on a NaN coordinate
        print(f'Value Error encountered for frame {frame_idx}')
    except Exception as e:
        print(f'Error encountered for frame {frame_idx}')
        print(e)
    return None


def get_rois_per_video(path_to_video, path_to_dlc_coords, crop=True, frames=None):
    """Decode a video and return its frames, optionally cropped around the tracked animal.

    Parameters
    ----------
    path_to_video : str
        Video file opened with ``cv2.VideoCapture``.
    path_to_dlc_coords : str
        DeepLabCut result file; split into folder/filename for ``DLC_tracking``, whose
        ``find_bbox_dlc()`` supplies the per-frame bounding boxes.
    crop : bool, default True
        If True, store the square ROI around each frame's box; otherwise store the full frame.
    frames : iterable of int, optional
        If given, only these 0-based frame indices are stored, and reading stops after the
        largest one. This bounds memory for long videos. ``None`` keeps every frame.

    Returns
    -------
    dict
        0-based frame index -> image array (ROI or full frame). Frames that cannot be cropped
        are reported and omitted.
    """
    dlc_folder, dlc_filename = os.path.split(path_to_dlc_coords)
    dlc_obj = DLC_tracking(dlc_filename, dlc_folder)
    dict_bbox = dlc_obj.find_bbox_dlc()

    wanted = None if frames is None else set(frames)
    last_wanted = None if wanted is None else max(wanted, default=-1)

    dict_rois = {}
    cap = cv2.VideoCapture(path_to_video)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for frame_idx in range(total_frames):
            if last_wanted is not None and frame_idx > last_wanted:
                break  # every requested frame has been visited
            ok, image = cap.read()
            if not ok:
                # The container's frame count can overestimate; stop instead of handling None frames.
                print(f'Could not read frame {frame_idx} of {path_to_video}; stopping')
                break
            if wanted is not None and frame_idx not in wanted:
                continue
            if crop:
                roi = _crop_square_roi(image, dict_bbox, frame_idx)
                if roi is not None:
                    dict_rois[frame_idx] = roi
            else:
                dict_rois[frame_idx] = image
    finally:
        cap.release()  # release the capture even if decoding or cropping raises

    return dict_rois

In [None]:
@interact_manual
def plot_trajectory(filename=filenames):
    """Show up to 5 randomly sampled cropped frames for every HDBSCAN cluster of one recording.

    The recording's ROIs come from ``get_rois_per_video``. One 1x5 figure is drawn per cluster
    (sorted by label), and each panel is titled with its frame index.

    NOTE(review): this reuses the name ``plot_trajectory`` and so shadows the trajectory viewer
    defined earlier. It displays frames, not trajectories — consider renaming.

    Parameters
    ----------
    filename : str
        Recording to inspect, chosen from ``filenames`` in the interact dropdown.
    """
    n_samples = 5
    df_filename = df_results_control[df_results_control['filename'] == filename]
    path_to_video = df_filename['path_to_video'].unique()[0]
    path_to_dlc_coords = df_filename['dlc_result_file'].unique()[0]
    # Decodes (and crops) every frame of the video before sampling.
    rois = get_rois_per_video(path_to_video, path_to_dlc_coords)

    for clus, df in df_filename.groupby('hdbscan_clusters'):
        fig, axes = plt.subplots(1, n_samples, figsize=(n_samples * 8, 8))
        axes = axes.ravel()
        df_samples = df.sample(n_samples) if len(df.index) > n_samples else df
        axes[0].set_ylabel(f'cluster: {clus}')
        for ax, f in zip(axes, df_samples['frame']):
            if f not in rois:
                # get_rois_per_video omits frames whose ROI could not be built
                ax.set_title(f'{f} (no ROI)')
                continue
            img = rois[f]
            if img.ndim == 3:
                img = img[..., ::-1]  # OpenCV decodes BGR; matplotlib expects RGB
            ax.imshow(img)
            ax.set_title(f)
    plt.show()