In [1]:
# import opensim as osim
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from OA_utils.data_utils import *
import random
import pickle


Load data from file

In [2]:
OA_data_dir = "C:\\Users\\bakel\\Desktop\\GRFMuscleModel\\Old_Young_Walking_Data\\"
with open(OA_data_dir + 'OA_resampled_compiled_segments', 'rb') as f:
    OA_segs = pickle.load(f)
YA_data_dir = "C:\\Users\\bakel\\Desktop\\GRFMuscleModel\\Old_Young_Walking_Data\\"
with open(YA_data_dir + 'Y_resampled_compiled_segments', 'rb') as f:
    YA_segs = pickle.load(f)

Visualize amount of data present for each subject

In [3]:
overall_OA_segs = 0
for subj, data in OA_segs.items():
    if subj == "time_resampled" or subj =='OA11':
        continue 
    
    # pick any signal key that exists in every subject
    num_segments = len(data["grf_y"])
    overall_OA_segs += num_segments
    print(f"{subj}: {num_segments} segments")
print(overall_OA_segs, 'OA segments overall')

OA1: 42 segments
OA2: 43 segments
OA4: 41 segments
OA5: 16 segments
OA7: 45 segments
OA8: 44 segments
OA9: 42 segments
OA10: 39 segments
OA12: 32 segments
OA13: 44 segments
OA14: 43 segments
OA17: 21 segments
OA18: 39 segments
OA19: 24 segments
OA20: 22 segments
OA22: 44 segments
OA23: 44 segments
OA24: 38 segments
OA25: 44 segments
707 OA segments overall


In [5]:
overall_YA_segs = 0
for subj, data in YA_segs.items():
    if subj == "time_resampled" or subj =='OA11':
        continue 
    
    # pick any signal key that exists in every subject
    num_segments = len(data["grf_y"])
    overall_YA_segs += num_segments
    print(f"{subj}: {num_segments} segments")
print(overall_YA_segs, 'YA segments overall')

Y1: 39 segments
Y2: 9 segments
Y4: 38 segments
Y5: 37 segments
Y6: 30 segments
Y7: 37 segments
Y8: 38 segments
Y9: 32 segments
Y10: 28 segments
Y11: 39 segments
Y12: 37 segments
Y13: 45 segments
Y14: 41 segments
Y15: 26 segments
Y16: 35 segments
Y17: 37 segments
Y18: 39 segments
Y19: 43 segments
Y20: 22 segments
Y21: 28 segments
Y22: 45 segments
725 YA segments overall


In [16]:
total_segs = overall_YA_segs + overall_OA_segs
print(total_segs, 'segments total')

1432 segments total


Split data by shuffling subjects

In [17]:
def count_total_segments(split_dict, key='grf_y'):
    return sum(len(split_dict[subj][key]) for subj in split_dict)

best_seed = None
best_err = float("inf")
best_counts = []
best_split = None
target_train, target_val, target_test = 0.8, 0.1, 0.1

for seed in range(40): 
    OA_subjects = [s for s in OA_segs.keys() if s != "time_resampled"]
    OA_subjects.remove('OA11')
    random.seed(seed)
    OA_shuf = OA_subjects.copy()
    random.shuffle(OA_shuf)

    N = len(OA_shuf)
    train_end =  int(N * 0.8)
    val_end = int(N * (0.9))

    OA_train = {s: OA_segs[s] for s in OA_shuf[:train_end]}
    OA_val   = {s: OA_segs[s] for s in OA_shuf[train_end:val_end]}
    OA_test  = {s: OA_segs[s] for s in OA_shuf[val_end:]}

    YA_subjects = [s for s in YA_segs.keys() if s != "time_resampled"]
    random.seed(seed)
    YA_shuf = YA_subjects.copy()
    random.shuffle(YA_shuf)

    N = len(YA_shuf)
    train_end = int(N * 0.8)
    val_end   = int(N * 0.9)

    YA_train = {s: YA_segs[s] for s in YA_shuf[:train_end]}
    YA_val   = {s: YA_segs[s] for s in YA_shuf[train_end:val_end]}
    YA_test  = {s: YA_segs[s] for s in YA_shuf[val_end:]}

    
    train_segs = count_total_segments(OA_train) + count_total_segments(YA_train)
    val_segs   = count_total_segments(OA_val)   + count_total_segments(YA_val)
    test_segs  = count_total_segments(OA_test)  + count_total_segments(YA_test)
    
    train_p = train_segs / total_segs
    val_p   = val_segs / total_segs
    test_p  = test_segs / total_segs
        # error score: L1 distance to target ratios (simple + robust)
    err = abs(train_p - target_train) + abs(val_p - target_val) + abs(test_p - target_test)

    if err < best_err:
        best_err = err
        best_seed = seed
        best_counts = (train_segs, val_segs, test_segs)
        best_split = {
            "OA_train_subjs": list(OA_train.keys()),
            "OA_val_subjs": list(OA_val.keys()),
            "OA_test_subjs": list(OA_test.keys()),
            "YA_train_subjs": list(YA_train.keys()),
            "YA_val_subjs": list(YA_val.keys()),
            "YA_test_subjs": list(YA_test.keys()),
        }

    print("Best seed:", best_seed)
    print("Counts (train/val/test):", best_counts)
    print("Ratios:",
        best_counts[0]/total_segs,
        best_counts[1]/total_segs,
        best_counts[2]/total_segs)
    print("Error:", best_err)





time_resampled = OA_segs["time_resampled"]

Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.09916201117318436 0.09706703910614525
Error: 0.007541899441340774
Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.09916201117318436 0.09706703910614525
Error: 0.007541899441340774
Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.09916201117318436 0.09706703910614525
Error: 0.007541899441340774
Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.09916201117318436 0.09706703910614525
Error: 0.007541899441340774
Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.09916201117318436 0.09706703910614525
Error: 0.007541899441340774
Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.09916201117318436 0.09706703910614525
Error: 0.007541899441340774
Best seed: 0
Counts (train/val/test): (1151, 142, 139)
Ratios: 0.8037709497206704 0.0991620111731843

In [18]:
def count_total_segments(split_dict, key='grf_y'):
    return sum(len(split_dict[subj][key]) for subj in split_dict)

OA_train_total = count_total_segments(OA_train)
YA_train_total = count_total_segments(YA_train)

OA_val_total = count_total_segments(OA_val)
YA_val_total = count_total_segments(OA_val)

OA_test_total = count_total_segments(OA_test)
YA_test_total = count_total_segments(OA_test)

print("Train (OA,YA):", OA_train_total, ',', YA_train_total)
print("Val (OA,YA):", OA_val_total,',', YA_val_total)
print("Test (OA,YA):", OA_test_total, ',',  YA_test_total)

Train (OA,YA): 552 , 534
Val (OA,YA): 81 , 81
Test (OA,YA): 74 , 74


In [19]:
def dict_to_array(split_dict):
    packed_segments = []
    for subj, data in split_dict.items():
        num_segs = len(data['grf_x']) 
        for i in range(num_segs):
            sample = np.column_stack([
                data['grf_x'][i],
                data['grf_y'][i],
                data['grf_z'][i],
                data['tibpost'][i],
                data['tibant'][i],
                data['edl'][i],
                data['ehl'][i],
                data['fdl'][i],
                data['fhl'][i],
                data['perbrev'][i],
                data['perlong'][i],
                data['achilles'][i],
            ])
            packed_segments.append(sample)
    return np.array(packed_segments)
    

In [20]:
def get_subject_ranges(split_dict):
    """
    Given the dict used to build a split (e.g., OA_train_data),
    return a dict mapping subject â†’ (start_idx, end_idx).
    """
    ranges = {}
    start_idx = 0
    
    for subj, data in split_dict.items():
        n = len(data['grf_x'])   # number of segments for this subject
        end_idx = start_idx + n
        ranges[subj] = (start_idx, end_idx)
        start_idx = end_idx
    
    return ranges

In [21]:
OA_test_ranges = get_subject_ranges(OA_test)
print(OA_test_ranges)

{'OA12': (0, 32), 'OA9': (32, 74)}


In [22]:
OA_train_arr = dict_to_array(OA_train)
OA_val_arr = dict_to_array(OA_val)
OA_test_arr = dict_to_array(OA_test)

YA_train_arr = dict_to_array(YA_train)
YA_val_arr = dict_to_array(YA_val)
YA_test_arr = dict_to_array(YA_test)

train_arr = np.concatenate([OA_train_arr, YA_train_arr], axis = 0)
val_arr = np.concatenate([OA_val_arr, YA_val_arr], axis = 0)
test_arr = np.concatenate([OA_test_arr, YA_test_arr], axis = 0)

print("Train:", train_arr.shape)
print("Val:", val_arr.shape)
print("Test:", test_arr.shape)

Train: (1086, 100, 12)
Val: (165, 100, 12)
Test: (181, 100, 12)


In [23]:
X_train, y_train = train_arr[:, :, :3], train_arr[:, :, 3:]
X_val, y_val = val_arr[:, :, :3], val_arr[:, :, 3:]
X_test, y_test = test_arr[:, :, :3], test_arr[:, :, 3:]

print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_test shape: {y_test.shape}")

X_train shape: (1086, 100, 3)
y_train shape: (1086, 100, 9)
X_val shape: (165, 100, 3)
y_val shape: (165, 100, 9)
X_test shape: (181, 100, 3)
y_test shape: (181, 100, 9)


In [24]:
np.savez(OA_data_dir + 'Silder_mixed_train_data.npz', X_train=X_train, y_train=y_train)
np.savez(OA_data_dir + 'Silder_mixed_val_data.npz', X_val=X_val, y_val=y_val)
np.savez(OA_data_dir + 'Silder_mixed_test_data.npz', X_test=X_test, y_test=y_test)