In [None]:
import os
import pandas as pd
import numpy as np
import plotly.express as px
import torch
import timeit

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from datetime import datetime
from umap import UMAP

from tqdm.notebook import tqdm

from src.Models import LSTMAutoencoder, LSTMVAE, LSTMVAE_t
from src.ModelUtils import (EarlyStopping, 
                            segment_steps_by_phase, 
                            steps_to_tensor, 
                            masked_mse_loss, 
                            masked_vae_loss, 
                            masked_vae_t_loss)
from src.DataUtils import (load_data_from_directory)

In [None]:
## Filesystem locations (relative to the notebook's working directory)
DATA_DIR = './csv/Pain_Plot_Features'
MODELS_DIR = './src/models'
FIGURES_DIR = './figures/'

## Experiment labels
# NOTE(review): presumably these mirror the fields encoded in the CSV file
# names (dataset / group / direction) - confirm against the data directory.
DATASETS = ['A', 'B', 'C', 'D', 'E']
GROUPS = ['pre', 'post']
DIRECTION = ['left', 'right']

## Model size
INPUT_DIM = None  # left unset here; presumably filled in once the data is known
HIDDEN_DIM = 64
LATENT_DIM = 16

## Training schedule
BATCH_SIZE = 32
NUM_EPOCHS = 500
LR = 1e-3

## Early stopping
PATIENCE = 50  # epochs to wait for an improvement before stopping
MIN_DELTA = 1e-4  # smallest change that still counts as an improvement

## Checkpoint to reuse; None means no checkpoint path is configured.
# Example of a previously used value:
#   os.path.join(MODELS_DIR, 'lstm_VAE_no_first_last_20250609_121841.pt')
BEST_MODEL_PATH = None

## Scatter-plot marker styling
SCATTER_SYMBOL = 'circle'
SCATTER_SIZE = 6
SCATTER_LINE_WIDTH = 1

## Figure font sizes
TITLE_FONT_SIZE = 24
LEGEND_FONT_SIZE = 18
AXIS_TITLE_FONT_SIZE = 20
AXIS_FONT_SIZE = 16

# Load the data
# Builds a nested mapping of per-run tables:
#   data['<dataset>_<group>']['<mouse>_<direction>'][<run>] -> pandas.DataFrame
# Every '.csv' file in DATA_DIR is read with its first column as the index.
data = {}

# os.listdir returns entries in arbitrary order; sorting makes the insertion
# order of `data` (and anything that iterates over it) reproducible.
directory = sorted(os.listdir(DATA_DIR))
for file in directory:
    if not file.endswith('.csv'):
        continue

    ## Assuming the file naming convention is:
    # dataset_group_mouse_direction_run.csv
    components = file.split('_')
    if len(components) < 5:
        # Fail with the offending file name instead of a bare IndexError.
        raise ValueError(
            f"Unexpected file name {file!r} in {DATA_DIR!r}: expected "
            "'dataset_group_mouse_direction_run.csv'")

    # Only the first five fields are used; any further fields are ignored.
    # NOTE: `run` is the fifth field verbatim, so it keeps the '.csv' suffix
    # whenever it is the last field of the file name.
    dataset, group, mouse, direction, run = components[:5]

    datagroup = dataset + '_' + group
    mouse_direction = mouse + '_' + direction

    # setdefault creates the intermediate dictionaries on first use.
    runs = data.setdefault(datagroup, {}).setdefault(mouse_direction, {})
    runs[run] = pd.read_csv(os.path.join(DATA_DIR, file), index_col=0)