In [None]:
import preamble

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
from collections import defaultdict
from scipy.stats import ttest_rel, wilcoxon
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Suppress DeprecationWarning messages for the remainder of the session.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
# Apply the seaborn "whitegrid" theme (deep palette, 1.2x fonts) globally.
sns.set(font_scale=1.2, style='whitegrid', palette='deep')

###########################
### Utility Functions #####
###########################

def load_all_user_trials(baseline_dir):
    """
    Load every ``.pkl`` trial file in ``baseline_dir`` and group them by user.

    Filenames are expected to look like ``{user_id}_P=*,V=*.pkl``; the user
    id is everything before the first underscore. Files are visited in
    sorted filename order so each user's trial list (and the "Trial N"
    numbering printed downstream) is reproducible — ``os.listdir`` makes no
    ordering guarantee.

    Files that fail to open or unpickle are reported and skipped.

    Args:
        baseline_dir: directory containing the trial pickle files.

    Returns:
        ``defaultdict(list)`` mapping ``user_id -> [obj1, obj2, ...]`` where
        each object is whatever the pickle contains (presumably a pandas
        DataFrame per trial — confirm against the data writer).

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on trusted,
        locally produced experiment data.
    """
    user_trials = defaultdict(list)
    for fname in sorted(os.listdir(baseline_dir)):
        if not fname.endswith(".pkl"):
            continue
        user_id = fname.split('_')[0]
        path = os.path.join(baseline_dir, fname)
        try:
            with open(path, 'rb') as f:
                df = pickle.load(f)
        except Exception as e:
            # Deliberate best-effort: one corrupt file should not stop the
            # whole analysis, but make the failure visible.
            print(f"Error loading {fname}: {e}")
            continue
        user_trials[user_id].append(df)
    return user_trials

def time_weighted_rmse(error, t):
    """
    Time-weighted root-mean-square of a 1-D series of scalar errors.

    Each sample ``error[i]`` is held over the interval ``[t[i], t[i+1])``
    (left Riemann sum), so the result is

        sqrt( sum_i error[i]**2 * (t[i+1] - t[i]) / (t[-1] - t[0]) )

    The last sample carries no weight. Typical inputs are norms of error
    vectors or per-sample alignment errors.

    Args:
        error: array-like of scalar error values, same length as ``t``.
        t: array-like of sample timestamps, assumed non-decreasing.

    Returns:
        The weighted RMS value, or ``nan`` when there are fewer than two
        samples or the total time span is not positive (the weighted mean is
        undefined in that case).
    """
    error = np.asarray(error)
    t = np.asarray(t)

    if t.size < 2:
        return np.nan
    duration = t[-1] - t[0]
    if duration <= 0:
        return np.nan

    dt = np.diff(t)
    weighted_sum = np.sum(error[:-1] ** 2 * dt)
    return np.sqrt(weighted_sum / duration)

def velocity_alignment_rmse(x_dot, x_dot_ref, t):
    """
    Time-weighted RMSE of the directional mismatch between two velocity series.

    Per sample, the mismatch is ``1 - cos(angle)`` between the measured and
    reference velocity vectors: 0 when parallel, 2 when opposite. Speeds are
    ignored. ``1e-8`` is added to every norm so zero-length vectors do not
    divide by zero (such a vector normalizes to zero, giving a mismatch of 1).

    Args:
        x_dot: sequence of per-sample velocity vectors (stacked row-wise).
        x_dot_ref: sequence of reference velocity vectors, same length.
        t: sample timestamps, forwarded to ``time_weighted_rmse``.

    Returns:
        ``time_weighted_rmse`` of the per-sample mismatch.
    """
    eps = 1e-8

    def _to_unit(vectors):
        arr = np.stack(vectors)
        lengths = np.linalg.norm(arr, axis=1) + eps
        return arr / lengths[:, None]

    actual_dir = _to_unit(x_dot)
    ref_dir = _to_unit(x_dot_ref)
    # Row-wise dot product of (near-)unit vectors is the cosine similarity.
    cosine = (actual_dir * ref_dir).sum(axis=1)
    return time_weighted_rmse(1 - cosine, t)

def time_weighted_std(error, t):
    """
    Time-weighted standard deviation of a 1-D series of scalar errors.

    Sample ``error[i]`` is weighted by ``(t[i+1] - t[i]) / (t[-1] - t[0])``
    (left Riemann sum); the last sample carries no weight. The spread is taken
    around the time-weighted mean, in population form (no bias correction).

    Args:
        error: array-like of scalar error values, same length as ``t``.
        t: array-like of sample timestamps, assumed non-decreasing.

    Returns:
        The weighted standard deviation, or ``nan`` when there are fewer than
        two samples or the total time span is not positive. Previously such
        input silently yielded ``0.0``, which disagreed with the ``nan`` RMSE
        produced for the same degenerate trial.
    """
    error = np.asarray(error)
    t = np.asarray(t)

    if t.size < 2:
        return np.nan
    dt = np.diff(t)
    total = np.sum(dt)
    if total <= 0:
        return np.nan

    weights = dt / total
    values = error[:-1]
    mu = np.sum(values * weights)
    var = np.sum(weights * (values - mu) ** 2)
    return np.sqrt(var)

def _velocity_alignment_error(x_dot, x_dot_ref, eps=1e-8):
    """Per-sample ``1 - cos(angle)`` between measured and reference velocities.

    ``eps`` is added to each vector norm so a zero-length vector normalizes
    to zero (mismatch of 1) instead of dividing by zero.
    """
    x = np.stack(x_dot)
    x_ref = np.stack(x_dot_ref)
    x_unit = x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)
    x_ref_unit = x_ref / (np.linalg.norm(x_ref, axis=1, keepdims=True) + eps)
    return 1 - np.sum(x_unit * x_ref_unit, axis=1)


def compute_metrics(df):
    """
    Compute the per-trial summary metrics for one (already trimmed) trial.

    Args:
        df: trial DataFrame with columns ``'t'`` (timestamps, a pandas
            Series), ``'e'`` and ``'e_dot'`` (per-sample error vectors), and
            ``'x_dot'`` / ``'x_dot_ref'`` (per-sample measured and reference
            velocity vectors). Vector columns are stacked row-wise.

    Returns:
        dict with keys ``completion_time`` (``t[-1] - t[0]``),
        ``rmse_position``, ``rmse_velocity``, ``rmse_velocity_alignment``
        (time-weighted RMSE of the error norms / alignment error) and the
        matching ``std_*`` time-weighted standard deviations.
    """
    t = df['t']
    e_norms = np.linalg.norm(np.stack(df['e']), axis=1)
    edot_norms = np.linalg.norm(np.stack(df['e_dot']), axis=1)
    # Computed once and shared by the RMSE and STD alignment entries.
    align_err = _velocity_alignment_error(df['x_dot'], df['x_dot_ref'])

    return {
        'completion_time': t.iloc[-1] - t.iloc[0],
        'rmse_position': time_weighted_rmse(e_norms, t),
        'rmse_velocity': time_weighted_rmse(edot_norms, t),
        'rmse_velocity_alignment': time_weighted_rmse(align_err, t),
        'std_position': time_weighted_std(e_norms, t),
        'std_velocity': time_weighted_std(edot_norms, t),
        'std_velocity_alignment': time_weighted_std(align_err, t),
    }

def trim_to_first_three_turns(df, n_turns=3):
    """
    Truncate a trial just before the ``n_turns``-th 0 -> 1 transition of
    ``df['ref_going_forward']``.

    A transition is a row ``i`` whose flag (cast to int) is 1 while row
    ``i-1`` is 0. The result is ``df.iloc[:i]`` for the ``n_turns``-th such
    row, so that row itself is excluded. A trial with fewer than ``n_turns``
    transitions is returned unchanged.

    Args:
        df: trial DataFrame with a ``ref_going_forward`` column.
        n_turns: which transition to cut at (default 3, matching the name).

    Returns:
        A positional slice of ``df``, or ``df`` itself.

    Raises:
        ValueError: if ``n_turns`` is less than 1.
    """
    if n_turns < 1:
        raise ValueError(f"n_turns must be >= 1, got {n_turns}")
    ref = df['ref_going_forward'].astype(int).values
    # Row indices i >= 1 where the flag goes exactly 0 -> 1.
    rising = np.flatnonzero((ref[:-1] == 0) & (ref[1:] == 1)) + 1
    if len(rising) < n_turns:
        return df  # not enough turns -> return full trial
    return df.iloc[:rising[n_turns - 1]]

def aggregate_user_metrics(user_id, baseline, trial_dfs):
    """
    Trim each trial, compute its metrics, print a per-trial report, and
    return the per-metric mean across the user's trials.

    Args:
        user_id: identifier, used only in the printed report.
        baseline: condition label, used only in the printed report.
        trial_dfs: list of trial DataFrames for this user and baseline.

    Returns:
        dict mapping each metric name to the unweighted ``np.mean`` over
        trials (a ``nan`` in any trial propagates to the mean). With no
        trials, every metric maps to ``nan``. All metric keys are present
        either way, so callers that index by metric name never hit KeyError.
    """
    # Must stay in sync with the keys produced by compute_metrics.
    metric_keys = (
        'completion_time',
        'rmse_position', 'rmse_velocity', 'rmse_velocity_alignment',
        'std_position', 'std_velocity', 'std_velocity_alignment',
    )

    trimmed = [trim_to_first_three_turns(df) for df in trial_dfs]
    per_trial_metrics = [compute_metrics(df) for df in trimmed]

    print(f"\n--- Trials for User {user_id} | Baseline: {baseline} ---")
    for idx, m in enumerate(per_trial_metrics):
        print(f"  Trial {idx + 1}:")
        print(f"    Completion Time: {m['completion_time']:.4f} s")
        print(f"    TWRMSE Position: {m['rmse_position']:.4f} m")
        print(f"    TWRMSE Velocity: {m['rmse_velocity']:.4f} m/s")

    if not per_trial_metrics:
        return {k: np.nan for k in metric_keys}

    return {
        key: np.mean([m[key] for m in per_trial_metrics])
        for key in per_trial_metrics[0]
    }

##########################
### Load All Data ########
##########################

baseline_dirs = {
    'Language': '../data/experiments_no_videos/language_only',
    'Force': '../data/experiments_no_videos/physical_only',
    'Force+Language': '../data/experiments_no_videos/language_and_force'
}

metrics = [
    'completion_time',
    'rmse_position', 'rmse_velocity', 'rmse_velocity_alignment',
    'std_position', 'std_velocity', 'std_velocity_alignment'
]
# results[metric][baseline] -> per-user aggregated values, ordered like all_users.
results = {m: {b: [] for b in baseline_dirs} for m in metrics}

# Load each baseline directory exactly once; the same data feeds both the
# common-user intersection and the metric aggregation below.
user_data = {baseline: load_all_user_trials(path) for baseline, path in baseline_dirs.items()}

# Keep only users with trials under every baseline so the per-baseline lists
# in `results` line up user-by-user for the paired tests.
user_sets = [set(user_dfs) for user_dfs in user_data.values()]
all_users = sorted(set.intersection(*user_sets)) if user_sets else []

# Compute aggregated metrics
for baseline, user_dfs in user_data.items():
    for user_id in all_users:
        agg_metrics = aggregate_user_metrics(user_id, baseline, user_dfs[user_id])
        for m in metrics:
            results[m][baseline].append(agg_metrics[m])

###############################
### Debug Print Per User #####
###############################

# Dump every user's aggregated value for each metric, one table per metric,
# with columns in the fixed order Language / Force / Force+Language.
print("\n=== AGGREGATED METRICS PER USER ===")
for metric_name in metrics:
    print(f"\n--- {metric_name.upper()} ---")
    print("User\tLanguage\tForce\tLanguage+Force")
    for row_idx, uid in enumerate(all_users):
        by_baseline = results[metric_name]
        lang_val = by_baseline['Language'][row_idx]
        force_val = by_baseline['Force'][row_idx]
        both_val = by_baseline['Force+Language'][row_idx]
        print(f"{uid}\t{lang_val:.4f}\t\t{force_val:.4f}\t\t{both_val:.4f}")

#################################
### Hypothesis Tests ############
#################################

def run_tests(metric_name):
    """
    Print one-sided paired tests of "Force+Language is lower than X" for X in
    {Force, Language} on one metric.

    Reads the module-level ``results[metric_name]``, whose per-baseline lists
    are aligned user-by-user, and runs a paired t-test and a Wilcoxon
    signed-rank test per comparison, both with ``alternative='less'``.

    Args:
        metric_name: key into ``results``.
    """
    print(f"\n=== {metric_name.upper()} ===")
    F = np.array(results[metric_name]['Force'])
    L = np.array(results[metric_name]['Language'])
    FL = np.array(results[metric_name]['Force+Language'])

    def report(test_name, stat, p):
        print(f"{test_name}: stat = {stat:.4f}, p = {p:.4f}")

    comparisons = (("Force", F), ("Language", L))

    # scipy returns the one-sided p-value directly; this equals halving the
    # two-sided p-value according to the sign of t.
    for label, other in comparisons:
        t_stat, p = ttest_rel(FL, other, alternative='less')
        report(f"Paired t-test: F+L < {label}", t_stat, p)

    for label, other in comparisons:
        test_name = f"Wilcoxon: F+L < {label}"
        try:
            w_stat, w_p = wilcoxon(FL, other, alternative='less')
        except ValueError as err:
            # e.g. every paired difference is zero; skip this comparison
            # instead of aborting the loop over all metrics.
            print(f"{test_name}: skipped ({err})")
            continue
        report(test_name, w_stat, w_p)

# Run the hypothesis tests for every metric, in the order listed in `metrics`.
for metric_name in metrics:
    run_tests(metric_name)

########################################
### Publication-Quality Bar Plots #####
########################################

def plot_bar(metric_name, use_median=False, show="std", zoom=True):
    """
    Bar chart of one metric across the three baselines.

    Args:
        metric_name: key into the module-level ``results`` dict.
        use_median: plot each baseline's median instead of its mean; the
            title says "Median" accordingly.
        show: error-bar type, one of
            'std' (population std, ``np.std``),
            'iqr' (75th minus 25th percentile), or
            'ci' (half-width of the 95% t-interval around the mean).
        zoom: if True (default), set the y-limits to the span of the bars
            *including their error bars* plus a 10% margin, so differences
            are easier to see without clipping the error bars. If False,
            matplotlib's default limits are kept.

    Raises:
        ValueError: if ``show`` is not 'std', 'iqr', or 'ci'.
    """
    if show not in ('std', 'iqr', 'ci'):
        raise ValueError(f"show must be 'std', 'iqr', or 'ci', got {show!r}")

    data = results[metric_name]

    # Display-friendly labels
    baseline_labels = {
        'Language': 'Language',
        'Force': 'Force',
        'Force+Language': 'Language+Force'
    }

    ylabels = {
        'completion_time': 'Completion Time (s)',
        'rmse_position': 'Position TWRMSE (m)',
        'rmse_velocity': 'Velocity TWRMSE (m/s)',
        'rmse_velocity_alignment': 'Velocity Alignment TWRMSE',
        'std_position': 'Position TWSTD (m)',
        'std_velocity': 'Velocity TWSTD (m/s)',
        'std_velocity_alignment': 'Velocity Alignment TWSTD'
    }

    titles = {
        'completion_time': 'Mean Completion Time Across Users',
        'rmse_position': 'Mean Position TWRMSE Across Users',
        'rmse_velocity': 'Mean Velocity TWRMSE Across Users',
        'rmse_velocity_alignment': 'Mean Velocity Alignment TWRMSE Across Users',
        'std_position': 'Mean Position TWSTD Across Users',
        'std_velocity': 'Mean Velocity TWSTD Across Users',
        'std_velocity_alignment': 'Mean Velocity Alignment TWSTD Across Users'
    }

    labels = [baseline_labels[k] for k in data]
    values = [np.median(v) if use_median else np.mean(v) for v in data.values()]

    if show == 'std':
        errors = [np.std(v) for v in data.values()]
    elif show == 'iqr':
        errors = [np.percentile(v, 75) - np.percentile(v, 25) for v in data.values()]
    else:  # 'ci'
        from scipy.stats import sem, t
        confidence = 0.95
        errors = [sem(v) * t.ppf((1 + confidence) / 2., len(v) - 1) for v in data.values()]

    title = titles[metric_name]
    if use_median:
        title = title.replace('Mean', 'Median', 1)

    _, ax = plt.subplots(figsize=(6, 5))
    bars = ax.bar(labels, values, yerr=errors, capsize=8, width=0.5, color='steelblue')
    for bar in bars:
        bar.set_edgecolor('black')

    ax.set_ylabel(ylabels[metric_name], fontsize=13)
    ax.set_title(title, fontsize=14)
    ax.tick_params(axis='x', labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.spines[['top', 'right']].set_visible(False)

    if zoom:
        vals = np.asarray(values, dtype=float)
        # A NaN error (e.g. 'ci' with one sample) should not hide the bar.
        errs = np.nan_to_num(np.asarray(errors, dtype=float), nan=0.0)
        lo = np.nanmin(vals - errs) if np.any(np.isfinite(vals)) else np.nan
        hi = np.nanmax(vals + errs) if np.any(np.isfinite(vals)) else np.nan
        # set_ylim rejects NaN/inf, so only zoom when the bounds are finite.
        if np.isfinite(lo) and np.isfinite(hi):
            span = hi - lo
            margin = 0.1 * span if span > 0 else 0.01  # avoid a flat range
            ax.set_ylim(lo - margin, hi + margin)

    plt.tight_layout()
    plt.show()

##################################
### Generate All Plots ###########
##################################

# One figure per metric, all using per-baseline means with ±std error bars.
for metric_to_plot in (
    'completion_time',
    'rmse_position',
    'std_position',
    'rmse_velocity',
    'std_velocity',
    'rmse_velocity_alignment',
    'std_velocity_alignment',
):
    plot_bar(metric_to_plot, use_median=False, show='std')

In [None]:
# import os
# import pickle
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# from collections import defaultdict
# import warnings

# warnings.filterwarnings("ignore", category=DeprecationWarning)

# sns.set(style='whitegrid', palette='deep', font_scale=1.2)

# BASELINE_DIRS = {
#     'Language': '../data/experiments_no_videos/language_only',
#     'Force': '../data/experiments_no_videos/physical_only',
#     'Force+Language': '../data/experiments_no_videos/language_and_force'
# }

# BEHAVIORAL_CONDITIONS = {
#     'Language': ['V=0', 'V=1'],
#     'Force': ['P=0', 'P=1'],
#     'Force+Language': ['P=0,V=0', 'P=0,V=1', 'P=1,V=0', 'P=1,V=1']
# }

# METRICS = [
#     'completion_time',
#     'rmse_position', 'rmse_velocity', 'rmse_velocity_alignment',
#     'std_position', 'std_velocity', 'std_velocity_alignment'
# ]

# def trim_to_first_three_turns(df):
#     ref = df['ref_going_forward'].astype(int).values
#     flipped = [i for i in range(1, len(ref)) if ref[i-1] == 0 and ref[i] == 1]
#     return df.iloc[:flipped[2]] if len(flipped) >= 3 else df

# def time_weighted_rmse(error, t):
#     error = np.asarray(error)
#     t = np.asarray(t)
#     dt = np.diff(t)
#     e_sq = error[:-1] ** 2
#     weighted_sum = np.sum(e_sq * dt)
#     return np.sqrt(weighted_sum / (t[-1] - t[0]))

# def time_weighted_std(error, t):
#     error = np.asarray(error)
#     t = np.asarray(t)
#     dt = np.diff(t)
#     weights = dt / np.sum(dt)
#     mu = np.sum(error[:-1] * weights)
#     return np.sqrt(np.sum(weights * (error[:-1] - mu)**2))

# def velocity_alignment_error(df):
#     x_dot = np.stack(df['x_dot'])
#     x_dot_ref = np.stack(df['x_dot_ref'])
#     eps = 1e-8
#     x_unit = x_dot / (np.linalg.norm(x_dot, axis=1, keepdims=True) + eps)
#     x_ref_unit = x_dot_ref / (np.linalg.norm(x_dot_ref, axis=1, keepdims=True) + eps)
#     return 1 - np.sum(x_unit * x_ref_unit, axis=1)

# def compute_metrics(df):
#     t = df['t']
#     e_norms = np.linalg.norm(np.stack(df['e']), axis=1)
#     edot_norms = np.linalg.norm(np.stack(df['e_dot']), axis=1)
#     align_err = velocity_alignment_error(df)
#     return {
#         'completion_time': t.iloc[-1] - t.iloc[0],
#         'rmse_position': time_weighted_rmse(e_norms, t),
#         'rmse_velocity': time_weighted_rmse(edot_norms, t),
#         'rmse_velocity_alignment': time_weighted_rmse(align_err, t),
#         'std_position': time_weighted_std(e_norms, t),
#         'std_velocity': time_weighted_std(edot_norms, t),
#         'std_velocity_alignment': time_weighted_std(align_err, t),
#     }

# def parse_key(fname):
#     parts = fname.split(',')
#     p, v = None, None
#     for part in parts:
#         if 'P=' in part:
#             p = int(part.split('=')[1].split('.')[0])
#         if 'V=' in part:
#             v = int(part.split('=')[1].split('.')[0])
#     return p, v

# def load_all_trials():
#     data = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
#     for baseline, path in BASELINE_DIRS.items():
#         for fname in os.listdir(path):
#             if '*' in fname or not fname.endswith('.pkl'):
#                 continue
#             user = fname.split('_')[0]
#             p, v = parse_key(fname)
#             with open(os.path.join(path, fname), 'rb') as f:
#                 df = pickle.load(f)
#                 key = []
#                 if p is not None:
#                     key.append(f'P={p}')
#                 if v is not None:
#                     key.append(f'V={v}')
#                 behavior_key = ','.join(key)
#                 data[baseline][behavior_key][user].append(df)
#     return data

# def aggregate_metrics(data):
#     results = defaultdict(lambda: defaultdict(list))
#     for baseline in BASELINE_DIRS:
#         for behavior in BEHAVIORAL_CONDITIONS[baseline]:
#             user_metrics = defaultdict(list)
#             for label in data[baseline]:
#                 if behavior in label:
#                     for user, dfs in data[baseline][label].items():
#                         for df in dfs:
#                             trimmed = trim_to_first_three_turns(df)
#                             m = compute_metrics(trimmed)
#                             for k, v in m.items():
#                                 user_metrics[k].append(v)
#             for k in METRICS:
#                 if user_metrics[k]:
#                     results[k][f'{baseline}\n{behavior}'] = user_metrics[k]
#     return results

# def get_bar_style(label):
#     if 'P=1' in label and 'V=1' in label:
#         return {'color': 'green', 'hatch': None}
#     elif 'P=0' in label and 'V=0' in label:
#         return {'color': 'red', 'hatch': None}
#     elif ('P=1' in label and 'V=0' in label) or ('P=0' in label and 'V=1' in label):
#         return {'color': 'white', 'edgecolor': 'black', 'hatch': 'xx'}
#     elif 'P=1' in label or 'V=1' in label:
#         return {'color': 'green', 'hatch': None}
#     elif 'P=0' in label or 'V=0' in label:
#         return {'color': 'red', 'hatch': None}
#     else:
#         return {'color': 'gray', 'hatch': None}

# def plot_metrics(results):
#     ylabels = {
#         'completion_time': 'Completion Time (s)',
#         'rmse_position': 'Position TWRMSE (m)',
#         'rmse_velocity': 'Velocity TWRMSE (m/s)',
#         'rmse_velocity_alignment': 'Velocity Alignment TWRMSE',
#         'std_position': 'Position TWSTD (m)',
#         'std_velocity': 'Velocity TWSTD (m/s)',
#         'std_velocity_alignment': 'Velocity Alignment TWSTD'
#     }

#     for metric in METRICS:
#         data = results[metric]
#         labels = list(data.keys())
#         means = [np.mean(data[label]) for label in labels]
#         stds = [np.std(data[label]) for label in labels]

#         fig, ax = plt.subplots(figsize=(10, 6))

#         for i, (label, mean, std) in enumerate(zip(labels, means, stds)):
#             style = get_bar_style(label)
#             ax.bar(
#                 label, mean, yerr=std, capsize=6, width=0.5,
#                 color=style.get('color', 'white'),
#                 edgecolor=style.get('edgecolor', 'black'),
#                 hatch=style.get('hatch', None)
#             )

#         ax.set_ylabel(ylabels[metric])
#         ax.set_title(f'Mean {ylabels[metric]} Across Users')
#         ax.tick_params(axis='x', labelsize=11, rotation=45)
#         ax.tick_params(axis='y', labelsize=12)
#         ax.spines[['top', 'right']].set_visible(False)
#         plt.tight_layout()
#         plt.show()

# # To use:
# trial_data = load_all_trials()
# aggregated = aggregate_metrics(trial_data)
# plot_metrics(aggregated)