In [None]:
# Import libraries
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
from scipy.stats import linregress
from matplotlib.ticker import FormatStrFormatter
from matplotlib import colors as mcolors
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches

In [None]:
# Helper functions
def normalize_to_one(df, columns):
    """Return a copy of ``df`` whose ``columns`` are rescaled to sum to 1 per row.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    columns : list-like of column labels
        Columns to normalise. They are cast to float in the result.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` in which each value in ``columns`` is divided by that
        row's sum over ``columns`` (NaNs are skipped when summing, so an
        all-NaN row sums to 0). Rows whose sum is 0 get 0 in every normalised
        column. All other columns are left untouched.
    """
    normalized_df = df.copy()
    columns = list(columns)
    if not columns:
        return normalized_df

    values = normalized_df[columns].to_numpy(dtype=float)
    # nansum mirrors pandas' skipna summation used for a row's total.
    row_sums = np.nansum(values, axis=1, keepdims=True)
    # Vectorised and position-based: rows are never addressed by index label,
    # so frames with duplicated index labels are normalised row by row.
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = np.where(row_sums != 0, values / row_sums, 0.0)
    normalized_df[columns] = scaled
    return normalized_df

def read_and_adjust_sheet(filename, sheet_name, total_rows=35, rows_to_skip=(1, 2), trailing_rows=6):
    """Read one worksheet's data block and re-index it from 1.

    Parameters
    ----------
    filename, sheet_name
        Forwarded to ``pd.read_excel``.
    total_rows : int, default 35
        Row budget of the sheet layout.
    rows_to_skip : sequence of int, default (1, 2)
        0-based file rows to drop, forwarded as ``skiprows`` (row 0 is kept,
        so it is used as the header).
    trailing_rows : int, default 6
        Additional rows at the end of the budget that are excluded from the read.

    Returns
    -------
    pandas.DataFrame
        At most ``total_rows - len(rows_to_skip) - trailing_rows`` rows, with
        index 1..N.

    Raises
    ------
    ValueError
        If the parameters leave a negative number of rows to read.
    """
    skip = list(rows_to_skip)
    nrows = total_rows - len(skip) - trailing_rows
    if nrows < 0:
        raise ValueError(
            f"total_rows={total_rows} leaves no rows after skipping {len(skip)} "
            f"and trimming {trailing_rows}"
        )
    df = pd.read_excel(filename, sheet_name=sheet_name, skiprows=skip, nrows=nrows)
    df.index = range(1, len(df) + 1)
    return df

def match_two_parts(condition_x, condition_y):
    """Return True when two condition keys agree on at least two of s/a/b.

    Keys are expected in the single-digit form built for 'Condition',
    e.g. "s0a1b2"; the digits sit at character positions 1, 3 and 5.
    """
    digit_positions = (1, 3, 5)
    levels_x = [int(condition_x[pos]) for pos in digit_positions]
    levels_y = [int(condition_y[pos]) for pos in digit_positions]
    agreements = sum(lx == ly for lx, ly in zip(levels_x, levels_y))
    return agreements >= 2

def to_meters(series_mm):
    """Convert millimetre values (scalar or pandas/numpy container) to metres."""
    mm_per_m = 1e3
    return series_mm / mm_per_m

def ticks_include_ends(vmin, vmax, n=6):
    """Return n evenly spaced ticks from vmin to vmax, both ends included.

    When vmin and vmax are (numerically) equal, the range is widened to
    [vmin, vmin + 1] so the ticks are distinct.
    """
    upper = vmin + 1.0 if np.isclose(vmin, vmax) else vmax
    return np.linspace(vmin, upper, n)

def label_for(var):
    """Return the axis label for a plotted variable.

    'Head angle' gets a degree suffix; every other name is returned as-is.
    """
    return 'Head angle (°)' if var == 'Head angle' else var

In [None]:
# Setup
# Folder containing the per-subject workbooks. Set the SUBJECT_DATA_DIR
# environment variable to run on another machine; the default keeps the
# original location.
base_path = os.environ.get(
    "SUBJECT_DATA_DIR", "/Users/jordanfeldman/Desktop/Research/Subject data"
)

bmh_files = [
    'data_BMH01.xlsx',
    'data_BMH02.xlsx',
    'data_BMH21.xlsx',
    'data_BMH06.xlsx',
    'data_BMH07.xlsx',
    'data_BMH08.xlsx',
    'data_BMH09.xlsx',
    'data_BMH10.xlsx',
    'data_BMH13.xlsx',
    'data_BMH19.xlsx',
    'data_BMH20.xlsx',
    'data_BMH17.xlsx'
]

# NOTE(review): not read by any cell visible here.
filepath_survey = os.path.join(base_path, "Subjective_Responses.xlsx")

# Prompt labels -> integer codes 0/1/2.
walking_speed_mapping = {'Slow': 0, 'Medium': 1, 'Fast': 2}
accuracy_mapping = {'Low': 0, 'Medium': 1, 'High': 2}
balance_mapping = {'Low': 0, 'Medium': 1, 'High': 2}

# Predictors followed by the three prompt-code columns; later cells rely on
# the prompt codes being the last three entries (pairplot_columns[:-3]).
pairplot_columns = [
    'Mean Error Straights', 'Mean Width Straights (mm)', 'Straights Width Variability (mm)',
    'Mean Length Straights (mm)', 'Straights Length Variability (mm)',
    'Average Speed (m/s)', 'EE', 'Head Angle (deg)','Walking Speed', 'Accuracy', 'Balance'
]
targets_columns = ['Balance', 'Foot Placement', 'Walking Speed']

# NOTE(review): these accumulators are not used by any cell visible here.
actual_points = []
predicted_points = []
s_vals, a_vals, b_vals = [], [], []
combined_targets = pd.DataFrame()

# Column -> label-to-code mapping, applied in this order to every workbook.
prompt_encodings = {
    'Walking Speed': walking_speed_mapping,
    'Accuracy': accuracy_mapping,
    'Balance': balance_mapping,
}

subject_frames = []
for bmh_file in bmh_files:
    data_path = os.path.join(base_path, bmh_file)
    df_data = pd.read_excel(data_path)
    for column, mapping in prompt_encodings.items():
        encoded = df_data[column].map(mapping)
        # Fail with the offending labels instead of an opaque int-cast error.
        unknown = df_data.loc[encoded.isna(), column].unique()
        if len(unknown):
            raise ValueError(
                f"{bmh_file}: unrecognised {column!r} labels {list(unknown)}; "
                f"expected one of {list(mapping)}"
            )
        df_data[column] = encoded.astype(int)
    # Build the key with vectorised string ops. DataFrame.apply(axis=1) hands
    # each row over as one Series, which pandas can upcast to float64 (e.g.
    # when every column is numeric), yielding "s0.0a1.0b2.0" instead of "s0a1b2".
    df_data['Condition'] = (
        's' + df_data['Walking Speed'].astype(str)
        + 'a' + df_data['Accuracy'].astype(str)
        + 'b' + df_data['Balance'].astype(str)
    )
    subject_frames.append(df_data)

# Concatenate once: growing a frame with pd.concat inside the loop is
# quadratic, and concatenating onto an empty DataFrame triggers a pandas
# FutureWarning about dtype inference.
combined_data = (
    pd.concat(subject_frames, ignore_index=True) if subject_frames else pd.DataFrame()
)

In [None]:
# Pairplots for all trials (one point per trial)
pairplot = combined_data[pairplot_columns]
predictors_columns = pairplot_columns[:-3]  # last three columns are the prompt codes used as hues
hue_cols = ['Walking Speed', 'Accuracy', 'Balance']
palettes = ['Set1', 'Set2', 'Set3']
margin_pct = .05

# Axis limits: each predictor's observed range, padded on both sides by
# margin_pct of that range.
limits = {}
for col in predictors_columns:
    lo, hi = pairplot[col].min(), pairplot[col].max()
    pad = margin_pct * (hi - lo)
    limits[col] = (lo - pad, hi + pad)

for cat, palette in zip(hue_cols, palettes):
    g = sns.pairplot(pairplot, vars=predictors_columns, hue=cat, palette=palette, corner=True)

    # Limits are matched to axes by label text; axes whose label has no
    # entry in `limits` (including unlabeled ones) are left unchanged.
    present_axes = [candidate for candidate in g.axes.flat if candidate is not None]
    for ax in present_axes:
        x_range = limits.get(ax.get_xlabel())
        if x_range is not None:
            ax.set_xlim(x_range)
        y_range = limits.get(ax.get_ylabel())
        if y_range is not None:
            ax.set_ylim(y_range)

    plt.suptitle(f'All trials, hue = {cat}', y=1.02)
    plt.tight_layout()
    plt.show()

In [None]:
# Pairplots for averaged conditions (one point per condition)
pairplot = combined_data[pairplot_columns]
predictors_columns = pairplot_columns[:-3]
margin_pct = .05

avg_combined_data = combined_data.groupby('Condition')[pairplot_columns].mean()

avg_pairplot = avg_combined_data[pairplot_columns]
hue_cols = ['Walking Speed', 'Accuracy', 'Balance']
palettes = ['Set1', 'Set2', 'Set3']

# Axis limits: per-predictor range of the condition means, padded by margin_pct.
avg_limits = {}
for col in predictors_columns:
    avg_min_val = avg_pairplot[col].min()
    avg_max_val = avg_pairplot[col].max()
    avg_margin = margin_pct * (avg_max_val - avg_min_val)
    avg_limits[col] = (avg_min_val - avg_margin, avg_max_val + avg_margin)

# Number of averaged points; used in the title instead of a hard-coded count.
n_conditions = len(avg_pairplot)

for cat, palette in zip(hue_cols, palettes):
    g = sns.pairplot(avg_pairplot, vars=predictors_columns, hue=cat, palette=palette, corner=True, plot_kws={'s': 40})

    for i, row_var in enumerate(predictors_columns):
        for j, col_var in enumerate(predictors_columns[:i]):  # Only lower triangle
            ax = g.axes[i, j]
            if ax is None:
                continue
            x = avg_pairplot[col_var]
            y = avg_pairplot[row_var]
            sns.regplot(x=x, y=y, ax=ax, scatter=False, line_kws={'color': 'black', 'linewidth': 1.5})
            # regplot drops missing values; fit on the same rows so the printed
            # R² describes the drawn line. linregress returns nan on NaN input
            # and raises when every x value is identical, so skip those cases.
            valid = x.notna() & y.notna()
            if valid.sum() >= 2 and x[valid].nunique() > 1:
                r_value = linregress(x[valid], y[valid]).rvalue
                ax.text(0.05, 0.9, f"$R^2$ = {r_value**2:.2f}", transform=ax.transAxes, fontsize=9, verticalalignment='top')
            # Shared per-variable limits
            if col_var in avg_limits:
                ax.set_xlim(avg_limits[col_var])
            if row_var in avg_limits:
                ax.set_ylim(avg_limits[row_var])

    plt.suptitle(f'{n_conditions} averaged conditions, hue = {cat}', y=1.02)
    plt.tight_layout()
    plt.show()

In [None]:
# Pairplot for supplemental figure, color/shape/opacity scheme
sns.set_style("ticks")

# Per-point encodings keyed by integer prompt code 0/1/2: scatter_with_scheme
# looks up color_map with the 's' code, marker_map with 'a', alpha_map with 'b'.
color_map = dict(enumerate(['#F3776E', '#F7C21F', '#20BDC3']))
marker_map = dict(enumerate(['o', '^', 's']))
alpha_map = dict(enumerate([0.3, 0.6, 1.0]))

# Aggregate over every pairplot column here, prompt codes included; the
# plotting step further down narrows predictors_columns to pairplot_columns[:-3].
predictors_columns = pairplot_columns

def _mode_or_first(x):
    x = pd.Series(x)
    m = x.mode()
    return m.iloc[0] if not m.empty else x.iloc[0]

# Average every selected column per condition (one row per condition).
# The s/a/b encoding columns are derived below from the averaged prompt
# codes, so no separate aggregation of pre-existing s/a/b columns is needed:
# any such result would be overwritten immediately.
agg = {c: 'mean' for c in predictors_columns}
avg = combined_data.groupby('Condition', as_index=False).agg(agg)
print("Rows in pairplot (conditions):", len(avg))

# The prompt codes are constant within a condition, so their mean is the
# code itself; round before casting so float noise cannot floor 1.999... to 1.
for code_col, source_col in (('s', 'Walking Speed'), ('a', 'Accuracy'), ('b', 'Balance')):
    avg[code_col] = avg[source_col].round().astype(int)

df_plot = avg[predictors_columns + ['s','a','b']].copy()

# ----- unit conversion (mm -> m) -----
def is_mm_like(name):
    """Return True if a column holds millimetre values to be converted to metres.

    Matches any name containing '(mm' (case-insensitive), plus the
    'Mean Error Straights' column, whose name carries no unit.
    """
    lowered = name.lower()
    return '(mm' in lowered or 'mean error straights' in lowered

# Rescale every millimetre-valued column of df_plot to metres.
for col in filter(is_mm_like, predictors_columns):
    df_plot[col] = df_plot[col] / 1000.0


def pretty_label(name: str) -> str:
    """Map a raw column name to its display label for the supplemental figure.

    Known columns are matched by case-insensitive substring, in priority order.
    'EE' is matched exactly. Any other name is returned with '(mm)' rewritten
    as '(m)'.
    """
    lowered = name.lower()
    substring_labels = (
        ('straights length variability', 'Step length\nvariability (m)'),
        ('mean length straights', 'Step length (m)'),
        ('mean width straights', 'Step width (m)'),
        ('straights width variability', 'Step width\nvariability (m)'),
        ('mean error straights', 'Foot placement\nerror (m)'),
        ('average speed', 'Walking\nspeed (m/s)'),
    )
    for needle, label in substring_labels:
        if needle in lowered:
            return label
    if name == 'EE':
        return 'Energy\nexpenditure (W)'
    # generic mm→m rename
    return name.replace('(mm)', '(m)')

def tick_fmt(name: str) -> str:
    """Return the printf-style tick format for a column.

    Returns one decimal for speeds, none for energy expenditure ('EE'), and
    two decimals for everything else.
    """
    lowered = name.lower()
    if 'average speed' in lowered or '(m/s' in lowered:
        return '%.1f'
    return '%.0f' if lowered.strip() == 'ee' else '%.2f'

# Precompute one (low, high) range per variable so every panel showing that
# variable uses the same limits. Pad is 5% of the finite range; a zero range
# falls back to 5% of max(1, |max|).
limits = {}
for v in predictors_columns:
    finite = df_plot[v].to_numpy()
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        continue
    lo, hi = np.min(finite), np.max(finite)
    pad = 0.05 * (hi - lo) or 0.05 * np.maximum(1, np.abs(hi))
    limits[v] = (lo - pad, hi + pad)

# ----- plotting helpers -----
def scatter_with_scheme(frame):
    """Return a scatter function for ``PairGrid.map_lower`` using the s/a/b encodings.

    ``frame`` must contain integer-like code columns 's', 'a' and 'b', and share
    its index with the x/y Series that the grid passes in (rows are looked up
    with ``frame.loc[x.index]``). Each point is coloured by ``color_map[s]``,
    shaped by ``marker_map[a]`` and given opacity ``alpha_map[b]``. These are
    module-level dicts, read when the returned function runs. A dashed
    least-squares line, annotated with R² and r, is added when the fit is defined.
    """
    def _fn(x, y, **kws):
        ax = plt.gca()
        rows = frame.loc[x.index]
        x = np.asarray(x)
        y = np.asarray(y)
        # Drop points where either coordinate is missing or infinite.
        mask = np.isfinite(x) & np.isfinite(y)
        x, y, rows = x[mask], y[mask], rows[mask]

        # One scatter call per marker shape; colour and alpha vary per point.
        for a_val in sorted(marker_map.keys()):
            sel = rows['a'].astype(int).values == a_val
            if not np.any(sel):
                continue
            rgba = [mcolors.to_rgba(color_map[int(sv)], alpha=alpha_map[int(bv)])
                    for sv, bv in zip(rows.loc[sel, 's'].astype(int), rows.loc[sel, 'b'].astype(int))]
            ax.scatter(x[sel], y[sel], s=80, marker=marker_map[a_val], c=rgba, edgecolors='none')

        # linregress raises ValueError when every x value is identical; skip
        # the fit line in that case instead of aborting the whole grid.
        if len(x) >= 2 and np.ptp(x) > 0:
            slope, intercept, r, p, _ = linregress(x, y)
            xf = np.linspace(np.nanmin(x), np.nanmax(x), 200)
            yf = slope*xf + intercept
            ax.plot(xf, yf, color='black', linestyle='--', linewidth=1.2)
            ax.text(0.05, 0.92, f"$R^2$ = {r**2:.2f}\n$r$ = {r:.2f}", transform=ax.transAxes,
                    fontsize=16, va='top', ha='left', color='black')
    return _fn

def style_axis(ax):
    """Hide the top and right spines and apply outward ticks with size-16 labels."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', which='both', direction='out', length=5, width=1, labelsize=16)
predictors_columns = pairplot_columns[:-3]  # plot predictors only; prompt codes drive colour/marker/opacity
g = sns.PairGrid(df_plot, vars=predictors_columns, corner=True)
g.map_lower(scatter_with_scheme(df_plot))

# Axis formatting: ticks on every lower-triangle panel, axis labels only on
# the bottom row and the left column.
nvars = len(predictors_columns)
bottom_row = nvars - 1
for i, row_var in enumerate(predictors_columns):
    # Diagonal panels carry no histograms, so hide them entirely.
    diag_ax = g.axes[i, i]
    if diag_ax is not None:
        diag_ax.set_visible(False)

    for j, col_var in enumerate(predictors_columns[:i]):
        ax = g.axes[i, j]
        if ax is None:
            continue
        style_axis(ax)

        # Per-variable limits so every panel showing a variable agrees.
        if col_var in limits:
            ax.set_xlim(*limits[col_var])
        if row_var in limits:
            ax.set_ylim(*limits[row_var])

        ax.xaxis.set_major_formatter(FormatStrFormatter(tick_fmt(col_var)))
        ax.yaxis.set_major_formatter(FormatStrFormatter(tick_fmt(row_var)))

        xlabel_text = pretty_label(col_var) if i == bottom_row else ''
        ylabel_text = pretty_label(row_var) if j == 0 else ''
        ax.set_xlabel(xlabel_text, fontsize=20)
        ax.set_ylabel(ylabel_text, fontsize=20)

plt.tight_layout()
plt.savefig('pairplot_sup_fig_sizing_redo.svg', bbox_inches='tight')
plt.show()

In [None]:
# Density plots and pairwise scatter plots with regression lines for Fig. 2

# ---------- Aggregate, convert, and bin ----------
avg_combined_data = combined_data.groupby('Condition')[
    ['Head Angle (deg)', 'Straights Length Variability (mm)', 'Mean Error Straights', 'Accuracy', 'Walking Speed', 'Balance']
].mean().reset_index()

# Create meter-based columns (used everywhere in plots)
avg_combined_data['Step length variability (m)'] = to_meters(avg_combined_data['Straights Length Variability (mm)'])
avg_combined_data['Foot placement error (m)']   = to_meters(avg_combined_data['Mean Error Straights'])
avg_combined_data['Head angle'] = avg_combined_data['Head Angle (deg)']

# Bins. The prompt columns are integer codes 0/1/2 that are constant within a
# condition, so each condition mean is the code itself. Map codes straight to
# ordered categories instead of using pd.qcut: qcut places bin edges at
# empirical quantiles, which only coincide with the codes when every level
# occurs equally often. With unbalanced levels it silently mislabels them, or
# raises on duplicate edges.
for code_col, bin_col, bin_labels in (
    ('Accuracy', 'AccuracyBin', ['Low', 'Medium', 'High']),
    ('Balance', 'BalanceBin', ['Low', 'Medium', 'High']),
    ('Walking Speed', 'SpeedCat', [0, 1, 2]),  # drives color
):
    codes = avg_combined_data[code_col]
    if not (codes.between(0, 2).all() and np.allclose(codes, codes.round())):
        raise ValueError(f"{code_col!r} condition means are not integer codes 0-2")
    avg_combined_data[bin_col] = pd.Categorical.from_codes(
        codes.round().astype(int), categories=bin_labels, ordered=True
    )

# ---------- Encodings ----------
# Scatter colour, keyed by the SpeedCat code used in the scatter row below.
# NOTE(review): prompt_handles labels these colours 'Ignore'/'Near'/'Accurate'
# (foot-placement prompt levels), but the scatter row colours points by
# SpeedCat (walking speed). prompt_handles is also not attached to any legend
# in this cell. Confirm which variable colour is meant to encode.
color_map = {0: '#20BDC3', 1: '#F7C21F', 2: '#F3776E'}

# Marker and opacity keyed by level name (AccuracyBin / BalanceBin).
level_names = ('Low', 'Medium', 'High')
shape_map = dict(zip(level_names, ('o', '^', 's')))
opacity_map = dict(zip(level_names, (0.3, 0.6, 1.0)))

# KDE fill colours keyed by AccuracyBin level.
kde_palette = dict(zip(level_names, ('#92278f', '#cc4899', '#234192')))

# ---------- Legends ----------
# Scatter legend entries (built but not drawn; see the NOTE above).
prompt_labels = ['Ignore', 'Near', 'Accurate']
prompt_colors = [color_map[code] for code in range(3)]
prompt_handles = []
for swatch, entry_label in zip(prompt_colors, prompt_labels):
    prompt_handles.append(
        Line2D([0], [0], marker='o', linestyle='None',
               markerfacecolor=swatch, markeredgecolor=swatch,
               markersize=10, label=entry_label)
    )

# KDE legend: one filled rectangle per accuracy level, matching the KDE fills.
kde_labels = list(level_names)
kde_handles = [
    mpatches.Patch(facecolor=kde_palette[name], edgecolor='none', label=name)
    for name in kde_labels
]

# ---------- Figure layout ----------
# KDE panels, left to right.
kde_vars = ['Head angle', 'Foot placement error (m)', 'Step length variability (m)']

# KDE x-axis limits and y-axis maxima per variable.
kde_xlim = {
    'Head angle': (10, 60),
    'Foot placement error (m)': (0.05, 0.225),
    'Step length variability (m)': (0.0, 0.3),
}
kde_ymax = {
    'Head angle': 0.14,
    'Foot placement error (m)': 50,
    'Step length variability (m)': 60,
}

# Scatter panels as (x, y) pairs, left to right, with per-panel axis limits.
scatter_pairs = [
    ('Head angle', 'Foot placement error (m)'),
    ('Head angle', 'Step length variability (m)'),
    ('Step length variability (m)', 'Foot placement error (m)'),
]
scatter_xlim = [(14, 54), (14, 54), (0.025, 0.25)]
scatter_ylim = [(0, 0.25)] * 3

fig, axs = plt.subplots(2, 3, figsize=(15, 10))

# ---------- Row 1: KDEs ----------
# One filled density per AccuracyBin level on each KDE panel.
for panel, var in enumerate(kde_vars):
    ax = axs[0, panel]
    for acc_level in ('Low', 'Medium', 'High'):
        level_rows = avg_combined_data.loc[avg_combined_data['AccuracyBin'] == acc_level]
        sns.kdeplot(
            data=level_rows,
            x=var, ax=ax,
            color=kde_palette[acc_level],
            linewidth=2, fill=True, alpha=0.3,
            label=None
        )

    lo, hi = kde_xlim[var]
    ax.set_xlim(lo, hi)
    ax.set_xticks(ticks_include_ends(lo, hi, n=5))

    # y-axis shows only 0 and the fixed maximum.
    y_top = kde_ymax[var]
    ax.set_ylim(0.0, y_top)
    ax.set_yticks([0, y_top])
    if y_top >= 10:
        top_label = f'{int(y_top)}'
    else:
        # up to 3 decimals, trailing zeros trimmed
        top_label = f"{y_top:.3f}".rstrip('0').rstrip('.')
    ax.set_yticklabels(['0', top_label])

    ax.set_xlabel(label_for(var), fontsize=20)
    ax.set_ylabel("Kernel density estimate", fontsize=20)

    x_fmt = '%.0f' if var == 'Head angle' else '%.3f'
    ax.xaxis.set_major_formatter(FormatStrFormatter(x_fmt))
    ax.tick_params(labelsize=16)
    ax.spines[['top', 'right']].set_visible(False)

# KDE legend on the rightmost KDE
axs[0, 2].legend(handles=kde_handles, title="Foot placement prompt",
                 frameon=False, fontsize=16, title_fontsize=16, loc='upper right')

# ---------- Row 2: Scatter plots ----------
for i, (x_var, y_var) in enumerate(scatter_pairs):
    ax = axs[1, i]

    # Scatter points (color = SpeedCat; marker = AccuracyBin; alpha = BalanceBin),
    # drawn one at a time because marker and opacity vary per point.
    for _, row in avg_combined_data.iterrows():
        c = color_map[row['SpeedCat']]
        ax.scatter(
            row[x_var], row[y_var],
            color=c, marker=shape_map[row['AccuracyBin']],
            alpha=opacity_map[row['BalanceBin']],
            edgecolor=c, linewidth=0.8, s=120
        )

    # Regression line. regplot drops rows with missing values, so fit on the
    # same NaN-free rows to keep the reported R²/r consistent with the drawn
    # line (linregress itself returns nan if any value is missing).
    fit_rows = avg_combined_data[[x_var, y_var]].dropna()
    slope, intercept, r, p, stderr = linregress(fit_rows[x_var], fit_rows[y_var])
    sns.regplot(data=avg_combined_data, x=x_var, y=y_var, scatter=False, ax=ax, color='black')
    r2 = r ** 2

    x_min, x_max = scatter_xlim[i]
    y_min, y_max = scatter_ylim[i]
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(ticks_include_ends(x_min, x_max, n=5))
    ax.set_yticks(ticks_include_ends(y_min, y_max, n=6))

    if x_var == 'Head angle':
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f'))
    else:
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    if '(m)' in y_var:
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    else:
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    # Labels
    ax.set_xlabel(label_for(x_var), fontsize=20)
    ax.set_ylabel(label_for(y_var), fontsize=20)

    # Anchor the R²/r label on the fitted line at mid x-range. An annotation
    # anchored in data coordinates is not drawn at all when its anchor falls
    # outside the axes, so keep the anchor inside the y-limits (2% inset).
    x_text = x_min + 0.50 * (x_max - x_min)
    y_inset = 0.02 * (y_max - y_min)
    y_text = float(np.clip(slope * x_text + intercept, y_min + y_inset, y_max - y_inset))
    ax.annotate(f"$R^2$ = {r2:.2f} \n $r$ = {r:.2f}", xy=(x_text, y_text), xytext=(5, 5),
                textcoords='offset points', fontsize=16)

    ax.tick_params(labelsize=16)
    ax.spines[['top', 'right']].set_visible(False)

plt.tight_layout()
plt.savefig('density_and_pair_redo.svg')
plt.show()
