In [None]:
'''
This notebook takes a trained BAMS (https://multiscale-behavior.github.io/) model and pose data (DeepLabCut CSVs) and returns features for behavioral classification and studies of dynamic movement.
The embeddings have shape (n samples, n frames per sample, n BAMS features), where each sample is one DLC CSV. Frame-level embeddings, in which each frame from the pose
CSVs is its own data point, have shape (n samples * n frames per sample, n BAMS features). These frame-level data points are embedded into a 3D space using UMAP and clustered in the UMAP
space using DBSCAN. Sequence-level embeddings can also be computed by averaging each feature over the frames of a sample; they have shape (n samples, n BAMS features).

Created by:
Mary Beth Cassity @ mary.beth.cassity@cornell.edu
Sarvestani Lab, Cornell University

Last updated: 
9/13/2024
'''

# Get BAMS feature embeddings from trained BAMS model + pose (DLC csvs)
#### Outline:
#### 1. Load the data and trained model and get an embedding
##### You can stop after this step if you only need the BAMS feature embeddings in 64-dimensional (short term, long term) or 128-dimensional (all) space. The next steps reduce the dimensions (UMAP) and cluster the data points in the 3D UMAP space (DBSCAN).
#### 2. Use UMAP to reduce dimensions and DBSCAN to cluster in UMAP space for frame level embeddings
#### 3. Save gifs for data points in UMAP space organized by cluster (view the movement latents)

##### Import required libraries and functions

In [1]:
import torch

from bams.data import KeypointsDataset
from bams.models import BAMS
from bams import compute_representations
from custom_dataset_w_labels import load_data, load_annotations

import numpy as np
import os 
import seaborn as sns
import umap
from sklearn.cluster import DBSCAN
import pandas as pd

from extra_fun import plot_3d_umap, save_gifs

## 1. Load the data and trained model and get an embedding


<span style="color:Red; font-size:24px;">User input</span>

In [2]:
# Paths to the trained BAMS model folder and to the DeepLabCut (DLC) data folder.
#
# Model folder: must contain the BAMS model file. The model is recognized by its file
# name starting with 'bams-custom' and ending with '.pt', so keep only one such file
# per folder.
#
# Data folder: organized into subfolders named after the species (the label). Each
# subfolder holds DLC csvs, and every csv should contain the same number of data
# points (frames).

### input path to model folder here ###
### change the model to match the features you want to extract -- see the ppt Documentation for the model:feature match ###
model_folder = r"X:\MaryBeth\BAMS\Visuomotor-Latents\models\bams-custom-2024-08-29-15-48-48_0.8"

### input path to data folder here ###
### you only need to change this if you want to change the input dataset ###
dlc_data_folder = r"X:\MaryBeth\BAMS\Visuomotor-Latents\data\threshold_0.8\movement"

### pose-feature preprocessing mode, passed to load_data ###
feature_processing = 'subtract_centroid'

<span style="color:Red; font-size:24px;">User input</span>
##### load an embedding (optional)

In [None]:
### If you want to load a previously computed embedding, uncomment the line below ###
### NOTE: the next cell always recomputes `embeddings` from the model and overwrites ###
### anything loaded here -- skip that cell when reusing a saved embedding.           ###
### The saved file (see "save the embedding") holds only 'short_term' and 'long_term'. ###

# embeddings = torch.load(os.path.join(model_folder,'embeddings.pth'))

In [None]:
### locate the trained model file ###

def _find_model_file(folder):
    """Return the file name of the single 'bams-custom*.pt' model inside `folder`.

    Raises:
        FileNotFoundError: if no matching model file exists.
        RuntimeError: if more than one matching file exists. os.scandir order is
            arbitrary, so silently picking one could load the wrong weights.
    """
    with os.scandir(folder) as entries:
        candidates = sorted(
            entry.name
            for entry in entries
            if entry.is_file() and entry.name.startswith('bams-custom') and entry.name.endswith('.pt')
        )
    if not candidates:
        raise FileNotFoundError(f"No 'bams-custom*.pt' model file found in {folder}")
    if len(candidates) > 1:
        raise RuntimeError(
            f"Expected exactly one 'bams-custom*.pt' model file in {folder}, "
            f"found {len(candidates)}: {candidates}"
        )
    return candidates[0]


model_name = _find_model_file(model_folder)
print("Loading model", model_name)
print()

model_path = os.path.join(model_folder, model_name)
annotations_path = os.path.join(model_folder, "video_labels.csv")

### load the pose data and per-video labels, and build the dataset ###
# hoa_bins is used both by the dataset and in the predictor's output size (target_size * hoa_bins) below
hoa_bins = 32
model_input = load_data(dlc_data_folder, model_folder, feature_processing = feature_processing, create_csv = False) ### set to True if you want to re create the mapping from bams to the original data csvs-- you only need to do this if you want to change the input dataset ###
annotations, eval_utils = load_annotations(annotations_path)

dataset = KeypointsDataset(
    keypoints=model_input,
    cache=False,
    hoa_bins=hoa_bins,
    annotations=annotations,
    eval_utils=eval_utils,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### initiate the model ###
# The architecture must match the one the checkpoint was trained with, or load_state_dict will reject it.
model = BAMS(
    input_size=dataset.input_size,
    short_term=dict(num_channels=(64, 64, 64, 64), kernel_size=3),
    long_term=dict(num_channels=(64, 64, 64, 64, 64), kernel_size=3, dilation=4),
    predictor=dict(
        hidden_layers=(-1, 256, 512, 512, dataset.target_size * hoa_bins)
    ),
).to(device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

### compute the embeddings ###
embeddings = compute_representations(model, dataset, device)

In [None]:
# Each embedding (short, long, all) should have shape
#   (n samples (total csvs), n frames per sample, n bams features).
# Frame level embeddings, shape (n samples * n frames per sample, n bams features),
# are used to plot in 3D UMAP space and to save data points (frames) as gifs.

### retrieve embeddings ###

short_term = embeddings['short_term']
long_term = embeddings['long_term']
# concatenate along the feature axis: every frame gets its short- and long-term features side by side
all_embeddings = torch.cat([short_term, long_term], dim=2)

print("short_term: ", np.shape(short_term))
print("long_term: ", np.shape(long_term))
print("all_embeddings: ", np.shape(all_embeddings))
print("")

### compute sequence level embeddings of the bams features by averaging over the frames for each sample ###

short_term_seq = short_term.mean(dim=1)
long_term_seq = long_term.mean(dim=1)
all_embeddings_seq = torch.cat([short_term_seq, long_term_seq], dim=1)

print("short_term seq: ", np.shape(short_term_seq))
print("long_term seq: ", np.shape(long_term_seq))
print("all_embeddings seq: ", np.shape(all_embeddings_seq))
print("")

### reshape frame array to get frame level embeddings ###
# Merge the sample and frame axes: row i * n_frames + j is frame j of sample i.
# flatten() (unlike view()) also works when the tensor is not contiguous.

short_term_frame = short_term.flatten(0, 1)
long_term_frame = long_term.flatten(0, 1)
all_embeddings_frame = all_embeddings.flatten(0, 1)

print("short_term frame: ", np.shape(short_term_frame))
print("long_term frame: ", np.shape(long_term_frame))
print("all_embeddings frame: ", np.shape(all_embeddings_frame))


<span style="color:Red; font-size:24px;">User input</span>
##### save the embedding (optional)

In [None]:
### If you want to save a newly computed embedding, uncomment the line below ###
### Only 'short_term' and 'long_term' are saved; all_embeddings is rebuilt from them ###
### (concatenated on dim=2) when the "retrieve embeddings" cell is re-run.           ###
### This replaces any existing embeddings.pth in model_folder.                        ###

# torch.save({'short_term': short_term, 'long_term': long_term}, os.path.join(model_folder,'embeddings.pth'))

## 2. Use UMAP to reduce dimensions and DBSCAN to cluster in UMAP space for frame level embeddings

<span style="color:Red; font-size:24px;">User input</span>
##### choose which embedding you want to investigate

In [5]:
# Frame level embeddings: each frame from the pose csvs is its own data point,
# shape (n samples * n frames per sample, n bams features).

### choose which embedding you want to investigate: "short_term", "long_term", or "all_embeddings" ###
embedding_choice = "short_term"

# One selector keeps the per-sample embedding, the frame level embedding, and the
# name passed to plot_3d_umap / save_gifs in sync (they used to be edited separately).
_embedding_options = {
    "short_term": (short_term, short_term_frame),
    "long_term": (long_term, long_term_frame),
    "all_embeddings": (all_embeddings, all_embeddings_frame),
}
if embedding_choice not in _embedding_options:
    raise ValueError(
        f"embedding_choice must be one of {list(_embedding_options)}, got {embedding_choice!r}"
    )

embedding, embedding_frame = _embedding_options[embedding_choice]
embedding_frame_name = f"{embedding_choice}_frame"  # e.g. "short_term_frame"


In [None]:
### create mapping from bams embeddings to the video/ csv space ###
# The frame level embedding lists frame j of sample i at row i * n_frames + j,
# so the [video name, frame number] mapping below is built in that same order.

### read the order in which the csvs were passed to bams ###
# NOTE(review): assumes the rows of video_labels.csv follow the same sample order as
# `embedding` -- TODO confirm against load_data.
df = pd.read_csv(annotations_path)
n_samples, n_frames_per_sample = embedding.shape[0], embedding.shape[1]
if len(df) != n_samples:
    raise ValueError(
        f"{annotations_path} has {len(df)} rows but the embedding has {n_samples} samples; "
        "cannot map frames back to videos"
    )

### for each csv, repeat its name n frames per sample times ###
video = np.repeat(df['video_name'].to_numpy(), n_frames_per_sample)
print(np.shape(video))

### for each csv, number the frames from 0 to n frames per sample - 1 ###
repeated_array = np.tile(np.arange(n_frames_per_sample), n_samples)
print(np.shape(repeated_array))

### make mapping that contains [video name, frame number] for each frame ###
video_frames = np.column_stack((video, repeated_array))
print(np.shape(video_frames))


<span style="color:Red; font-size:24px;">User input</span>
##### choose how many data points to randomly sample

In [None]:
# Due to the large size of the frame level dataset, a random sample of the frames
# is given to UMAP and DBSCAN.

### choose how many samples you want to plot ###
sample_size = 4000

n_frames_total = len(embedding_frame)
# np.random.choice(..., replace=False) raises when asked for more items than exist,
# so cap the sample at the number of available frames.
if sample_size > n_frames_total:
    print(f"sample_size={sample_size} exceeds the {n_frames_total} available frames; using all frames")
    sample_size = n_frames_total

os.makedirs(os.path.join(model_folder, f"{sample_size}"), exist_ok=True)

### set the random state ###
# Seeding the legacy global generator keeps the same sample as earlier runs of this notebook.
np.random.seed(42)
sample_indices = np.random.choice(n_frames_total, size=sample_size, replace=False)

# Give UMAP a plain NumPy array; detach()/cpu() make this work even if the tensor
# tracks gradients or lives on the GPU.
sample = embedding_frame[sample_indices].detach().cpu().numpy()
video_frames_sample = video_frames[sample_indices]  # [video name, frame number] of each sampled frame

umap_model = umap.UMAP(n_components=3, random_state=42)
behavior_umap = umap_model.fit_transform(sample)  # shape (sample_size, 3)

<span style="color:Red; font-size:24px;">User input</span>
##### choose eps and min_samples for DBSCAN

In [8]:
# Choosing eps: it can be computed -- section 4.1 of https://dl.acm.org/doi/pdf/10.1145/3068335
# describes the knee method for eps, along with strategies for choosing min_samples -- but the
# computed value has not worked well for this purpose / dataset composition, so it is set by hand.

### DBSCAN hyperparameters: set to a computed or experimentally determined value ###
min_samples = 5   # samples (including the point itself) an eps-neighborhood needs for a core point
eps = 0.15        # max distance in UMAP space for two points to count as neighbors

# cluster the 3D UMAP coordinates; sklearn labels noise points -1
dbscan_model = DBSCAN(min_samples=min_samples, eps=eps)
dbscan_model.fit(behavior_umap)
dbscan_labels = dbscan_model.labels_

In [None]:
### plot the 3d umap using plotly ###
# Inputs: DBSCAN labels and 3D UMAP coordinates of the sampled frames, the matching
# [video name, frame number] rows for those same frames, and run parameters
# (model_folder, sample_size, eps, min_samples, embedding_frame_name) -- presumably
# used to title/name the saved figure; confirm in extra_fun.plot_3d_umap.
plot_3d_umap(dbscan_labels, behavior_umap, video_frames_sample, model_folder, sample_size, eps, min_samples, embedding_frame_name)

## 3. Save gifs for data points in UMAP space organized by cluster (view the movement latents)

In [10]:
### input path to the folder of source mp4 videos passed to save_gifs ###
mp4_folder = r"X:\MaryBeth\BAMS\Visuomotor-Latents\data\threshold_0.8\0.8_mp4s_combined"

# Check up front so a wrong (hard-coded) path gives a clear error before save_gifs runs.
if not os.path.isdir(mp4_folder):
    raise FileNotFoundError(f"mp4_folder does not exist or is not a directory: {mp4_folder}")

save_gifs(model_folder, eps, min_samples, video_frames_sample, dbscan_labels, mp4_folder, sample_size, embedding_frame_name)