### Imports

Background reading on dynamic time warping (DTW): https://medium.com/@markstent/dynamic-time-warping-a8c5027defb6

In [None]:
pip install dtaidistance

In [None]:
pip install h5py

In [None]:
import numpy as np
from dtaidistance import dtw, dtw_visualisation as dtwvis
import h5py
import matplotlib.pyplot as plt
from collections import defaultdict
import matplotlib.cm as cm
import matplotlib.colors as mcolors

### Helper Functions

In [None]:
def print_h5_structure(g, indent=0):
    """
    Print the tree of an HDF5 container, one line per member.

    Each member is printed as ``<name>: <type name>``, prefixed with two
    spaces per nesting level. Members that are ``h5py.Group`` instances are
    descended into recursively.

    Parameters:
      - g: an h5py Group or File object (anything with ``keys()`` and item access).
      - indent: nesting depth of ``g``; controls the leading whitespace.
    """
    prefix = "  " * indent
    for name in g.keys():
        child = g[name]
        print(prefix + f"{name}: {type(child).__name__}")
        if not isinstance(child, h5py.Group):
            continue
        print_h5_structure(child, indent + 1)

# Open the trajectory file read-only and print its group/dataset hierarchy
# so the paths used by the loaders below can be checked by eye.
with h5py.File('trajectory_data.h5', 'r') as f:
    print("HDF5 File Structure:")
    print_h5_structure(f)

In [None]:
def resample_trajectory(traj, new_length):
    """
    Linearly interpolate a 1D trajectory onto ``new_length`` samples.

    Both the original samples and the requested samples are placed on evenly
    spaced grids over [0, 1], so the first and last values are preserved and
    everything in between is interpolated with ``np.interp``.
    """
    source_grid = np.linspace(0, 1, len(traj))
    target_grid = np.linspace(0, 1, new_length)
    return np.interp(target_grid, source_grid, traj)

In [None]:
def sort_key(key):
    """Return the integer that follows the last underscore in ``key``.

    If ``key`` contains no underscore, the whole string is converted.
    """
    _, _, suffix = key.rpartition('_')
    return int(suffix)

In [None]:
def normalize_trajectory(x, y):
    """
    Standardize the x and y coordinates of a trajectory independently.

    Each coordinate is shifted to zero mean and scaled to unit standard
    deviation. A coordinate whose values are all identical has zero spread
    and cannot be scaled; it is returned as all zeros instead of the NaN/inf
    values a division by zero would produce (which would otherwise poison any
    distance computed from the result).

    Parameters:
      - x, y: 1D array-likes of coordinates.

    Returns:
      (x_norm, y_norm) as float NumPy arrays with the same lengths as the inputs.
    """
    def _standardize(values):
        values = np.asarray(values, dtype=float)
        # max == min is an exact test for "all samples equal"; comparing
        # np.std to 0 is not, because rounding in the mean can leave a tiny
        # non-zero spread.
        if values.size and values.max() == values.min():
            return np.zeros_like(values)
        return (values - np.mean(values)) / np.std(values)

    return _standardize(x), _standardize(y)

In [None]:
def compare_sorted_trajectories_subplots(actual_mean_trajectories, model_mean_trajectories):
    """
    Plot actual vs. model trajectories per start position and collect DTW distances.

    For each start position in ``actual_mean_trajectories`` one figure with a
    6x3 grid (18 panels) is shown. Actual and model trajectories are paired by
    the sorted order of their keys (the keys themselves may differ), and at
    most 18 pairs are drawn per figure. Each used panel shows the normalized
    actual and model trajectories, labelled with their keys, and the DTW
    distance in the title. Panels without a pair are hidden; panels whose
    pair has an empty trajectory are titled "Missing trajectory".

    The DTW distance is computed on 1D series obtained by flattening the
    (n, 2) point arrays, i.e. x and y values interleaved.

    Parameters:
      - actual_mean_trajectories: {start_pos: {key: (x, y) or empty}, ...}
      - model_mean_trajectories: same structure; must contain every start
        position present in ``actual_mean_trajectories``.

    Returns:
        A dictionary mapping (start_position, actual_key, model_key) to the DTW distance.

    Raises:
        KeyError: if a start position is missing from ``model_mean_trajectories``.
    """
    n_rows, n_cols = 6, 3
    results = {}

    for sp in sorted(actual_mean_trajectories.keys()):
        # Resolve both lookups before creating the figure so a missing start
        # position does not leave an orphaned, never-shown figure behind.
        actual_for_sp = actual_mean_trajectories[sp]
        model_for_sp = model_mean_trajectories[sp]
        pairs = list(zip(sorted(actual_for_sp.keys()), sorted(model_for_sp.keys())))
        pairs = pairs[:n_rows * n_cols]

        _, axs = plt.subplots(n_rows, n_cols, figsize=(15, 15))
        axs = axs.flatten()

        for ax, (key_actual, key_model) in zip(axs, pairs):
            traj_actual = actual_for_sp[key_actual]
            traj_model = model_for_sp[key_model]

            if not (traj_actual and traj_model):
                ax.set_title("Missing trajectory", fontsize=10)
                continue

            x1, y1 = traj_actual
            x2, y2 = traj_model

            x1_norm, y1_norm = normalize_trajectory(np.array(x1), np.array(y1))
            x2_norm, y2_norm = normalize_trajectory(np.array(x2), np.array(y2))

            # (n, 2) point arrays, flattened below to x0, y0, x1, y1, ...
            traj1 = np.vstack((x1_norm, y1_norm)).T
            traj2 = np.vstack((x2_norm, y2_norm)).T

            distance = dtw.distance(traj1.flatten(), traj2.flatten())
            results[(sp, key_actual, key_model)] = distance

            ax.plot(x1_norm, y1_norm, label=f'Actual {key_actual}')
            ax.plot(x2_norm, y2_norm, label=f'Model {key_model}')
            ax.set_title(f'Speed {key_actual}/{key_model}\nDTW: {distance:.2f}')
            ax.legend(fontsize=8)
            ax.tick_params(axis='both', labelsize=8)

        # Hide the panels that received no pair. Slicing by the number of
        # pairs (rather than reusing the loop index) also works when there
        # are no pairs at all.
        for unused_ax in axs[len(pairs):]:
            unused_ax.axis('off')

        plt.suptitle(f'{sp} Speed Comparisons', fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    return results

In [None]:
def plot_model_trajectories(filename, deep=False,
                            title='Rulebased Mean Trajectories for Different Start Positions and Speeds'):
    """
    Load model trajectories from an HDF5 file, reduce them to mean paths and plot them.

    The file is read as ``f[start_position]['x'][key]`` and
    ``f[start_position]['y'][key]`` for the start positions LEFT, RIGHT and
    CENTER, where each ``key`` is the string form of an integer speed in
    -8..8. For every (start position, speed) the loaded trajectories are
    resampled to a common length and averaged; the averaged path is then
    binned on y (nearest integer, or nearest multiple of 2 when ``deep`` is
    true) and the x values in each bin are averaged.

    A 2x3 figure is shown: the top row holds speeds 0..8, the bottom row
    speeds -8..-1, one column per start position.

    Parameters:
      - filename: path of the HDF5 file to read.
      - deep: if true, bin y to the nearest multiple of 2 instead of 1.
      - title: figure title; the default is the historical hard-coded title.

    Returns:
        {start_pos: {speed: (x_array, y_array)}}. Speeds for which the file
        held no data keep an empty list as placeholder.

    Raises:
        KeyError: if the file contains a speed key outside -8..8 or lacks a
        start position group.
    """
    start_positions = ['LEFT', 'RIGHT', 'CENTER']
    data_speeds = list(range(-8, 9))
    trajectories = {sp: {s: [] for s in data_speeds} for sp in start_positions}
    mean_trajectories = {sp: {s: [] for s in data_speeds} for sp in start_positions}

    cmap = plt.get_cmap("viridis")
    norm = mcolors.Normalize(vmin=min(data_speeds), vmax=max(data_speeds))

    with h5py.File(filename, 'r') as f:
        for sp in start_positions:
            x_group = f[sp]['x']
            y_group = f[sp]['y']
            for key in sorted(x_group.keys(), key=int):
                trajectories[sp][int(key)].append((x_group[key][:], y_group[key][:]))

    for sp in start_positions:
        for s in data_speeds:
            traj_list = trajectories[sp][s]
            if not traj_list:
                # No data for this speed: keep the [] placeholder, which the
                # plotting loop below (and downstream comparisons) skip.
                continue

            max_length = max(len(x) for x, _ in traj_list)
            avg_x = np.mean(np.array([resample_trajectory(x, max_length) for x, _ in traj_list]), axis=0)
            avg_y = np.mean(np.array([resample_trajectory(y, max_length) for _, y in traj_list]), axis=0)

            # Collapse samples that fall into the same y bin into one point.
            grouped_x = defaultdict(list)
            for x_val, y_val in zip(avg_x, avg_y):
                rounded_y = round(y_val / 2) * 2 if deep else round(y_val)
                grouped_x[rounded_y].append(x_val)

            rounded_ys = sorted(grouped_x.keys())
            averaged_xs = [np.mean(grouped_x[ry]) for ry in rounded_ys]
            mean_trajectories[sp][s] = (np.array(averaged_xs), np.array(rounded_ys))

    fig, axs = plt.subplots(2, len(start_positions), figsize=(18, 10), sharex=True, sharey=True)
    # Row 0: speeds 0..8, row 1: speeds -8..-1.
    speed_rows = (data_speeds[8:], data_speeds[:8])
    for ax_row, row_speeds in zip(axs, speed_rows):
        for ax, sp in zip(ax_row, start_positions):
            for s in row_speeds:
                if mean_trajectories[sp][s]:
                    avg_x, avg_y = mean_trajectories[sp][s]
                    ax.plot(avg_x, avg_y, color=cmap(norm(s)), label=f'Speed {s}')
            ax.set_title(f'Start Position: {sp}')
            ax.set_xlabel('X Coordinate')
            ax.set_ylabel('Y Coordinate')
            ax.legend()

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

    return mean_trajectories

In [None]:
def plot_actual_trajectory(filename_acc_data, group_names):
    """
    Load recorded trajectories, average them per start position and speed, and plot them.

    For every group in ``group_names`` the function reads
    ``{group}/neck_x_mirror_smooth`` and ``{group}/neck_y_mirror_smooth``
    (one dataset per trial) together with
    ``{group}/trial_settings/speed_direction`` and
    ``{group}/trial_settings/initial_position_X_SIDE``. Trial keys are ordered
    by their trailing integer and the i-th key is matched with the i-th entry
    of the two settings datasets. Trial speeds are snapped to the nearest
    multiple of 2.5.

    Trajectories of all groups are pooled per (start position, speed),
    resampled to a common length, averaged, then binned on the rounded y
    value with x averaged per bin.

    A 2x3 figure is shown: the top row holds speeds 0..20, the bottom row
    speeds -20..-2.5, one column per start position.

    Parameters:
      - filename_acc_data: path of the HDF5 file to read.
      - group_names: iterable of HDF5 group paths to pool.

    Returns:
        {start_pos: {speed: (x_array, y_array)}}. Combinations without any
        trial keep an empty list as placeholder.

    Raises:
        KeyError: if a trial's start position or snapped speed is not one of
        the expected values, or a required dataset is missing.
    """
    start_positions = ['LEFT', 'RIGHT', 'CENTER']
    actual_speeds = [-20, -17.5, -15, -12.5, -10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20]
    trajectories = {sp: {s: [] for s in actual_speeds} for sp in start_positions}
    actual_mean_trajectories = {sp: {s: [] for s in actual_speeds} for sp in start_positions}

    cmap = plt.get_cmap("viridis")
    norm = mcolors.Normalize(vmin=min(actual_speeds), vmax=max(actual_speeds))

    with h5py.File(filename_acc_data, 'r') as f:
        for group in group_names:
            print(f"Processing group: {group}")
            x_group = f[f'{group}/neck_x_mirror_smooth']
            y_group = f[f'{group}/neck_y_mirror_smooth']
            speed_dataset = f[f'{group}/trial_settings/speed_direction'][:]
            direction_dataset = f[f'{group}/trial_settings/initial_position_X_SIDE'][:]

            keys = sorted(x_group.keys(), key=sort_key)

            # Settings may be stored as bytes; decode before converting.
            data_speeds = np.array([float(s.decode('utf-8')) if isinstance(s, bytes) else float(s) for s in speed_dataset])
            directions = [d.decode('utf-8') if isinstance(d, bytes) else d for d in direction_dataset]

            for i, key in enumerate(keys):
                x_arr = x_group[key][:]
                y_arr = y_group[key][:]
                rounded_speed = round(data_speeds[i] / 2.5) * 2.5
                trajectories[directions[i]][rounded_speed].append((x_arr, y_arr))

    for sp, speed_dict in trajectories.items():
        for speed, traj_list in speed_dict.items():
            if not traj_list:
                # No trial landed in this cell; max() below would raise on an
                # empty sequence, so keep the [] placeholder instead.
                continue

            max_length = max(len(x) for x, _ in traj_list)
            avg_x = np.mean(np.array([resample_trajectory(x, max_length) for x, _ in traj_list]), axis=0)
            avg_y = np.mean(np.array([resample_trajectory(y, max_length) for _, y in traj_list]), axis=0)

            # Collapse samples that share a rounded y value into one point.
            grouped_x = defaultdict(list)
            for x_val, y_val in zip(avg_x, avg_y):
                grouped_x[round(y_val)].append(x_val)

            rounded_ys = sorted(grouped_x.keys())
            averaged_xs = [np.mean(grouped_x[ry]) for ry in rounded_ys]
            actual_mean_trajectories[sp][speed] = (np.array(averaged_xs), np.array(rounded_ys))

    fig, axs = plt.subplots(2, 3, figsize=(18, 10), sharex=True, sharey=True)

    # Row 0: speeds 0..20, row 1: speeds -20..-2.5.
    speed_rows = (actual_speeds[8:], actual_speeds[:8])
    for ax_row, row_speeds in zip(axs, speed_rows):
        for ax, sp in zip(ax_row, start_positions):
            for speed in row_speeds:
                mean_traj = actual_mean_trajectories[sp][speed]
                if not mean_traj:
                    continue  # placeholder for a cell without trials
                avg_x, avg_y = mean_traj
                ax.plot(avg_x, avg_y, color=cmap(norm(speed)), label=f'Speed {speed}')
            ax.set_title(f'Start Position: {sp}')
            ax.set_xlabel('X Coordinate')
            ax.set_ylabel('Y Coordinate')
            ax.legend()

    fig.suptitle('Average Trajectories for Different Start Positions (Combined Groups)')
    plt.tight_layout()
    plt.show()

    return actual_mean_trajectories

In [None]:
def compute_auc(x, y):
    """
    Integrate ``y`` over ``x`` with the trapezoidal rule, skipping NaN samples.

    NaN entries in ``y`` split the curve into contiguous valid segments; each
    segment is integrated separately and the results are summed, so no area
    is bridged across a gap. A segment consisting of a single sample
    contributes zero.

    Parameters:
      - x: 1D array-like of sample positions.
      - y: 1D array-like of values, same length as ``x``; may contain NaN.

    Returns:
        The summed area as a float, or NaN if ``y`` has no valid sample.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    valid = ~np.isnan(y)
    if not valid.any():
        return np.nan

    # np.trapz was renamed to np.trapezoid in NumPy 2.0 and the old name is
    # deprecated; prefer the new one and fall back on older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    auc = 0.0
    start = 0
    for gap in np.flatnonzero(~valid):
        if gap > start:
            auc += integrate(y[start:gap], x[start:gap])
        start = gap + 1
    if start < len(y):
        auc += integrate(y[start:], x[start:])
    return auc

In [None]:
def plot_dtw_scores(actual_mean_trajectories, list_of_model_trajectories, model_names=None):
    """
    Compute and plot per-model DTW score curves against the actual trajectories.

    For each start position (LEFT, CENTER, RIGHT) the actual trajectories are
    paired with each model's trajectories by the numeric sort order of their
    speed keys (the keys themselves may differ), and the DTW distance of each
    normalized pair is computed. All scores of a model are laid out on one
    continuous x-axis, ordered by start position and then by speed, so that an
    area under the curve (AUC) can be computed per model.

    A slot is filled with NaN when the model lacks the start position, when
    either trajectory of a pair is empty, or when the model has fewer speed
    keys than the actual data; every curve therefore has exactly one value
    per actual trajectory.

    Two figures are shown: an overall comparison of all models, and one
    subplot per model titled with its AUC. Both include:
        - Vertical dashed lines at the group midpoints (labeled with the start position)
        - Faint vertical dashed lines at the boundaries between start positions.

    Args:
        actual_mean_trajectories (dict): {start_pos: {speed_key: (x, y), ...}, ...}
        list_of_model_trajectories (list): List of dictionaries structured like actual_mean_trajectories.
        model_names (list, optional): Names for each model. Defaults to ["Model 1", "Model 2", ...].

    Returns:
        dict: auc_dict mapping model_name -> AUC value (trapezoidal rule over
        valid segments). Empty if no models are given, in which case nothing
        is plotted.

    Raises:
        ValueError: if ``model_names`` does not have one name per model.
    """
    start_positions = ['LEFT', 'CENTER', 'RIGHT']

    list_of_model_trajectories = list(list_of_model_trajectories)
    if model_names is None:
        model_names = [f"Model {idx + 1}" for idx in range(len(list_of_model_trajectories))]
    else:
        model_names = list(model_names)
    if len(model_names) != len(list_of_model_trajectories):
        raise ValueError(
            "model_names must contain exactly one name per entry in list_of_model_trajectories"
        )
    if not model_names:
        return {}

    group_counts = [len(actual_mean_trajectories.get(sp, {})) for sp in start_positions]
    global_x = np.arange(sum(group_counts))

    # Centre index of each non-empty group on the global x-axis.
    midpoints = {}
    start_idx = 0
    for sp, count in zip(start_positions, group_counts):
        if count > 0:
            midpoints[sp] = start_idx + (count - 1) / 2.0
        start_idx += count

    # Boundaries sit halfway between the last index of one group and the
    # first index of the next.
    group_boundaries = [boundary - 0.5 for boundary in np.cumsum(group_counts)[:-1]]

    model_curves = {name: [] for name in model_names}

    for sp in start_positions:
        actual_for_sp = actual_mean_trajectories.get(sp, {})
        actual_keys = sorted(actual_for_sp.keys(), key=float)
        n_points = len(actual_keys)

        for model_name, model in zip(model_names, list_of_model_trajectories):
            curve = model_curves[model_name]
            if sp not in model:
                curve.extend([np.nan] * n_points)
                continue

            model_keys = sorted(model[sp].keys(), key=float)
            n_pairs = 0
            for a_key, m_key in zip(actual_keys, model_keys):
                n_pairs += 1
                traj_actual = actual_for_sp[a_key]
                traj_model = model[sp][m_key]

                if not (traj_actual and traj_model):
                    curve.append(np.nan)
                    continue

                x1, y1 = np.array(traj_actual[0]), np.array(traj_actual[1])
                x2, y2 = np.array(traj_model[0]), np.array(traj_model[1])
                x1_norm, y1_norm = normalize_trajectory(x1, y1)
                x2_norm, y2_norm = normalize_trajectory(x2, y2)

                # (n, 2) point arrays, flattened to x0, y0, x1, y1, ...
                traj1 = np.vstack((x1_norm, y1_norm)).T
                traj2 = np.vstack((x2_norm, y2_norm)).T

                curve.append(dtw.distance(traj1.flatten(), traj2.flatten()))

            # A model with fewer speed keys than the actual data would leave
            # the curve shorter than global_x; pad so the lengths match.
            curve.extend([np.nan] * (n_points - n_pairs))

    def _draw_group_guides(ax, label_groups=False):
        """Draw midpoint and boundary guide lines; optionally label the midpoints."""
        for sp, pos in midpoints.items():
            ax.axvline(pos, color='grey', linestyle='--', alpha=0.5)
            if label_groups:
                ax.text(pos, ax.get_ylim()[1]*0.95, sp, ha='center', va='top', fontsize=8, color='grey')
        for boundary in group_boundaries:
            ax.axvline(boundary, color='grey', linestyle='--', alpha=0.3)

    # ---------------- Overall Plot ----------------
    plt.figure(figsize=(10, 4))
    cmap = plt.get_cmap('tab10', len(list_of_model_trajectories))

    for i, model_name in enumerate(model_names):
        scores = np.array(model_curves[model_name])
        plt.plot(global_x, scores, marker='o', linestyle='-', color=cmap(i), label=model_name)

    _draw_group_guides(plt.gca())

    plt.xticks(list(midpoints.values()), list(midpoints.keys()))
    plt.xlabel("Start Position")
    plt.ylabel("DTW Score")
    plt.title("DTW Score Comparisons")
    plt.legend(title="Model")
    plt.tight_layout()
    plt.show()

    # ---------------- Compute AUC ----------------
    auc_dict = {model_name: compute_auc(global_x, np.array(scores))
                for model_name, scores in model_curves.items()}

    # ---------------- Individual Subplots ----------------
    n_models = len(model_names)
    # squeeze=False keeps a 2D array of Axes even for a single model, where
    # plt.subplots would otherwise return a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(n_models, 1, figsize=(10, 3 * n_models), sharex=True, squeeze=False)
    axs = axs[:, 0]

    for i, model_name in enumerate(model_names):
        ax = axs[i]
        scores = np.array(model_curves[model_name])
        ax.plot(global_x, scores, marker='o', linestyle='-', color=cmap(i))
        ax.set_title(f"{model_name} DTW Score Curve (AUC = {auc_dict[model_name]:.3f})")
        ax.set_ylabel("DTW Score")
        _draw_group_guides(ax, label_groups=True)

    axs[-1].set_xlabel("Global x (Continuous Index)")
    plt.tight_layout()
    plt.show()

    return auc_dict


### Actual Data Preprocessing

In [None]:
# HDF5 group paths whose trials are pooled into the "actual" mean trajectories.
# plot_actual_trajectory reads '<group>/neck_x_mirror_smooth',
# '<group>/neck_y_mirror_smooth' and '<group>/trial_settings/...' for each entry.
# NOTE(review): entries look like '<animal id>/<session date>' - confirm against the data file.
group_names = [
    'BRAC68537c/221109.0',
    'BRAC68537c/221111.0',
    'BRAC68537c/221114.0',
    'BRAC68537c/221115.0',
    'BRAC68537c/221117.0',
    'BRAC68537d/221109.0',
    'BRAC68537d/221111.0',
    'BRAC68537d/221114.0',
    'BRAC68537d/221115.0',
    'BRAC68537d/221117.0'
]

# Load, average and plot the recorded trajectories; the result is the
# reference that every model is compared against in the cells below.
actual_mean_trajectories = plot_actual_trajectory('trajectory_data.h5', group_names)

### Deep RL Data Preprocessing

In [None]:
# Load and plot the Deep RL model's mean trajectories. The second argument
# (deep=True) makes plot_model_trajectories bin y to the nearest multiple
# of 2 instead of the nearest integer.
filename_deep_rl_data = '../../scripts/deep_rl_mean_coords.h5'
deep_mean_trajectories = plot_model_trajectories(filename_deep_rl_data, True)

### Deep RL Comparison

In [None]:

# Plot actual vs. Deep RL trajectories per start position and print the
# {(start_position, actual_key, model_key): DTW distance} mapping.
results = compare_sorted_trajectories_subplots(actual_mean_trajectories, deep_mean_trajectories)
print(results)

### Rulebased Data Preprocessing

In [None]:
# Load and plot the mean trajectories stored in the "linear" model file
# (default binning: y rounded to the nearest integer).
filename_linear_data = '../../scripts/linear_mean_coords.h5'
linear_mean_trajectories = plot_model_trajectories(filename_linear_data)

In [None]:
# Load and plot the mean trajectories stored in the "random" model file
# (default binning: y rounded to the nearest integer).
filename_random_data = '../../scripts/random_mean_coords.h5'
random_mean_trajectories = plot_model_trajectories(filename_random_data)

### Rulebased Comparison

In [None]:

# Plot actual vs. linear-model trajectories per start position and print the
# {(start_position, actual_key, model_key): DTW distance} mapping.
# Note: this rebinds `results`, replacing the previous comparison's output.
results = compare_sorted_trajectories_subplots(actual_mean_trajectories, linear_mean_trajectories)
print(results)

In [None]:

# Plot actual vs. random-model trajectories per start position and print the
# {(start_position, actual_key, model_key): DTW distance} mapping.
# Note: this rebinds `results`, replacing the previous comparison's output.
results = compare_sorted_trajectories_subplots(actual_mean_trajectories, random_mean_trajectories)
print(results)

### RL Approximation Data Preprocessing

In [None]:
# Load and plot the mean trajectories stored in the "rl" model file
# (default binning: y rounded to the nearest integer).
filename_rl_data = '../../scripts/rl_mean_coords.h5'
rl_mean_trajectories = plot_model_trajectories(filename_rl_data)

### RL Approximation Comparison

In [None]:
# Plot actual vs. RL-model trajectories per start position and print the
# {(start_position, actual_key, model_key): DTW distance} mapping.
# Note: this rebinds `results`, replacing the previous comparison's output.
results = compare_sorted_trajectories_subplots(actual_mean_trajectories, rl_mean_trajectories)
print(results)

### Overall Plot Comparison

In [None]:
# Compare all four models against the actual data in one DTW score plot.
# The two lists are parallel: model_names[i] labels model_trajectories[i].
# plot_dtw_scores returns a {model_name: AUC} dict; as the last expression
# of the cell it is also displayed by the notebook.
model_names = ['Deep RL', 'Linear', 'Random', 'RL']
model_trajectories = [deep_mean_trajectories, linear_mean_trajectories, random_mean_trajectories, rl_mean_trajectories]
plot_dtw_scores(actual_mean_trajectories, model_trajectories, model_names)