In [12]:
# import opensim as osim
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from OA_utils.data_utils import *
import random
import pickle

Load the resampled segment data for the older-adult (OA) and young-adult (YA) cohorts from pickle files

In [15]:
# TODO(review): absolute, machine-specific paths — move to a configurable data-dir setting.
OA_data_dir = "C:\\Users\\bakel\\Desktop\\GRFMuscleModel\\Old_Young_Walking_Data\\"
YA_data_dir = "C:\\Users\\bakel\\Desktop\\GRFMuscleModel\\data\\"


def _load_pickle(path):
    """Unpickle and return the single object stored in the file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on
    trusted, locally produced files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# OA is read first, then YA (same order as before).
OA_segs = _load_pickle(OA_data_dir + 'resampled_compiled_segments')
YA_segs = _load_pickle(YA_data_dir + 'resampled_segments.pkl')

Count the segments available for each subject (OA11 is excluded from the OA cohort)

In [18]:
# Per-subject segment tally for the OA cohort. "time_resampled" is a non-subject
# entry (presumably the shared resampled time vector) and OA11 is skipped, matching
# the exclusion used when splitting subjects.
# NOTE(review): the reason for excluding OA11 is not recorded here — document it.
overall_OA_segs = 0
for subj, data in OA_segs.items():
    if subj in ("time_resampled", "OA11"):
        continue
    # All signal keys hold one entry per segment, so any key gives the count.
    num_segments = len(data["grf_y"])
    print(f"{subj}: {num_segments} segments")
    overall_OA_segs += num_segments
print(overall_OA_segs, 'segments overall')

OA1: 42 segments
OA2: 43 segments
OA4: 41 segments
OA5: 16 segments
OA7: 45 segments
OA8: 44 segments
OA9: 42 segments
OA10: 39 segments
OA12: 32 segments
OA13: 44 segments
OA14: 43 segments
OA17: 21 segments
OA18: 39 segments
OA19: 24 segments
OA20: 22 segments
OA22: 44 segments
OA23: 44 segments
OA24: 38 segments
OA25: 44 segments
707 segments overall


In [19]:
# Per-subject segment tally for the YA cohort; only the non-subject
# "time_resampled" entry is skipped.
overall_YA_segs = 0
for subj, data in YA_segs.items():
    if subj != "time_resampled":
        num_segments = len(data["grf_y"])
        overall_YA_segs += num_segments
        print(f"{subj}: {num_segments} segments")
print(overall_YA_segs, 'segments overall')

Subject1: 1357 segments
Subject2: 1517 segments
Subject3: 1526 segments
Subject4: 1338 segments
Subject5: 1482 segments
Subject6: 1544 segments
Subject7: 1539 segments
Subject8: 1581 segments
Subject9: 1450 segments
13334 segments overall


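Split data at the subject level: shuffle the subject IDs (seed 42) and assign roughly 80% / 10% / 10% of subjects to train / validation / test, so no subject appears in more than one split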

In [39]:
# Subject-level split: all segments of a subject land in exactly one of
# train / val / test, so no subject leaks across splits. OA and YA are split
# independently with the same seed and cut points.
SPLIT_SEED = 42
TRAIN_FRAC = 0.8      # shuffled subjects [0, floor(N*0.8)) -> train
TRAIN_VAL_FRAC = 0.9  # [floor(N*0.8), floor(N*0.9)) -> validation; the rest -> test
# NOTE(review): the reason OA11 is dropped is not recorded in this notebook — document it.
EXCLUDED_OA_SUBJECTS = ("OA11",)


def split_subjects(subjects, seed=SPLIT_SEED, train_frac=TRAIN_FRAC,
                   train_val_frac=TRAIN_VAL_FRAC):
    """Shuffle subject IDs reproducibly and cut them into train/val/test lists.

    A private ``random.Random(seed)`` is used, so the global ``random`` state is
    left untouched. The resulting order is identical to
    ``random.seed(seed); random.shuffle(...)``.

    Args:
        subjects: iterable of subject IDs (not modified).
        seed: seed for the shuffle.
        train_frac: fraction of subjects (floored) assigned to train.
        train_val_frac: cumulative fraction (floored) covering train + val;
            every subject after that cut goes to test.

    Returns:
        ``(shuffled, train_subjs, val_subjs, test_subjs)`` as lists.
    """
    shuffled = list(subjects)
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    train_end = int(n * train_frac)
    val_end = int(n * train_val_frac)
    return shuffled, shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]


def select_subjects(segs, subjects):
    """Return ``{subject: segs[subject]}`` for each of ``subjects`` (values shared, not copied)."""
    return {s: segs[s] for s in subjects}


OA_subjects = [s for s in OA_segs.keys()
               if s != "time_resampled" and s not in EXCLUDED_OA_SUBJECTS]
OA_subjects_shuffled, OA_train_subjs, OA_val_subjs, OA_test_subjs = split_subjects(OA_subjects)
OA_train_data = select_subjects(OA_segs, OA_train_subjs)
OA_val_data   = select_subjects(OA_segs, OA_val_subjs)
OA_test_data  = select_subjects(OA_segs, OA_test_subjs)

YA_subjects = [s for s in YA_segs.keys() if s != 'time_resampled']
YA_subjects_shuffled, YA_train_subjs, YA_val_subjs, YA_test_subjs = split_subjects(YA_subjects)
YA_train_data = select_subjects(YA_segs, YA_train_subjs)
YA_val_data   = select_subjects(YA_segs, YA_val_subjs)
YA_test_data  = select_subjects(YA_segs, YA_test_subjs)

time_resampled = OA_segs["time_resampled"]

In [40]:
# Show which OA subjects ended up in each split.
for label, split_data in (('train', OA_train_data),
                          ('validation', OA_val_data),
                          ('test', OA_test_data)):
    print(f'{label}: {split_data.keys()}')

train: dict_keys(['OA20', 'OA19', 'OA7', 'OA13', 'OA8', 'OA22', 'OA24', 'OA9', 'OA18', 'OA23', 'OA14', 'OA2', 'OA17', 'OA4', 'OA25'])
validation: dict_keys(['OA10', 'OA12'])
test: dict_keys(['OA1', 'OA5'])


In [41]:
def count_total_segments(split_dict, key='grf_y'):
    """Total number of segments over all subjects of a split.

    Args:
        split_dict: mapping of subject -> {signal_key: list of segments}.
        key: signal key whose list length is counted (default ``'grf_y'``).

    Returns:
        Sum of ``len(split_dict[subj][key])`` over every subject.
    """
    total = 0
    for subj_data in split_dict.values():
        total += len(subj_data[key])
    return total

# Segment totals per cohort and split. The OA/YA ratios of these totals set how
# much of each YA split is kept by the downsampling cell below.
(OA_train_total, YA_train_total,
 OA_val_total, YA_val_total,
 OA_test_total, YA_test_total) = (
    count_total_segments(split)
    for split in (OA_train_data, YA_train_data,
                  OA_val_data, YA_val_data,
                  OA_test_data, YA_test_data)
)

for label, oa_total, ya_total in (
    ("Train (OA,YA):", OA_train_total, YA_train_total),
    ("Val (OA,YA):", OA_val_total, YA_val_total),
    ("Test (OA,YA):", OA_test_total, YA_test_total),
):
    print(label, oa_total, ',', ya_total)

Train (OA,YA): 578 , 10460
Val (OA,YA): 71 , 1357
Test (OA,YA): 58 , 1517


In [42]:
random.seed(42)


def downsample_split(split_dict, target_ratio, rng=None):
    """Randomly keep about ``target_ratio`` of each subject's segments.

    One set of segment indices is drawn per subject and applied to every signal
    key. Segment ``i`` of ``grf_x`` therefore still belongs to the same segment
    as entry ``i`` of every other key (e.g. ``tibpost``), so inputs and targets
    stay paired. Kept segments retain their original relative order.

    Args:
        split_dict: mapping subject -> {signal_key: list of segments}. All keys
            of a subject must hold the same number of segments.
        target_ratio: fraction of segments to keep per subject. The count is
            ``int(n * target_ratio)`` (floored) clamped to ``[1, n]``, so every
            non-empty subject keeps at least one segment and never more than ``n``.
        rng: optional ``random.Random`` instance. Defaults to the global
            ``random`` module (seeded at the top of this cell).

    Returns:
        A new dict with the same subject/key layout. Input lists are not modified.

    Raises:
        ValueError: if a subject's signal keys have differing segment counts.
    """
    rng = random if rng is None else rng
    downsampled = {}

    for subj, subj_data in split_dict.items():
        lengths = {len(seg_list) for seg_list in subj_data.values()}
        if len(lengths) > 1:
            raise ValueError(
                f"{subj}: signal keys have differing segment counts {sorted(lengths)}"
            )
        n = lengths.pop() if lengths else 0

        if n == 0:
            downsampled[subj] = {key: [] for key in subj_data}
            continue

        k = min(n, max(1, int(n * target_ratio)))
        # Sample ONCE per subject and reuse the indices for every key. Sampling
        # per key would pair signals from unrelated segments.
        keep_indices = sorted(rng.sample(range(n), k))
        downsampled[subj] = {
            key: [seg_list[i] for i in keep_indices]
            for key, seg_list in subj_data.items()
        }

    return downsampled

In [43]:
# Per-split ratios that shrink each YA split to about the size of the matching
# OA split. Counts are floored per subject, so YA can land slightly below OA.
train_ratio, val_ratio, test_ratio = (
    OA_train_total / YA_train_total,
    OA_val_total / YA_val_total,
    OA_test_total / YA_test_total,
)

# Call order matters: all three draws come from the same seeded global RNG.
YA_train_down = downsample_split(YA_train_data, train_ratio)
YA_val_down = downsample_split(YA_val_data, val_ratio)
YA_test_down = downsample_split(YA_test_data, test_ratio)

for ya_label, ya_down, oa_label, oa_data in (
    ("Downsampled YA train:", YA_train_down, "OA train:", OA_train_data),
    ("Downsampled YA val:", YA_val_down, "OA val:", OA_val_data),
    ("Downsampled YA test:", YA_test_down, "OA test:", OA_test_data),
):
    print(ya_label, count_total_segments(ya_down))
    print(oa_label, count_total_segments(oa_data))


Downsampled YA train: 575
OA train: 578
Downsampled YA val: 71
OA val: 71
Downsampled YA test: 58
OA test: 58


Display the combined (OA + downsampled YA) segment count for each signal key in each split

In [44]:
def count_segments_per_key(*split_dicts):
    """Sum segment counts per signal key over every subject of every given split dict.

    Args:
        *split_dicts: mappings subject -> {signal_key: list of segments}.

    Returns:
        dict signal_key -> total number of segments, in first-seen key order.
        A key missing for some subjects just contributes nothing for them
        (no KeyError).
    """
    totals = {}
    for split_dict in split_dicts:
        for subj_data in split_dict.values():
            for key, seg_list in subj_data.items():
                totals[key] = totals.get(key, 0) + len(seg_list)
    return totals


# Combined OA + downsampled-YA counts for each split.
train_out = count_segments_per_key(OA_train_data, YA_train_down)
val_out = count_segments_per_key(OA_val_data, YA_val_down)
test_out = count_segments_per_key(OA_test_data, YA_test_down)

print("Train:", train_out)
print("Val:",   val_out)
print("Test:",  test_out)

Train: {'grf_x': 1153, 'grf_y': 1153, 'grf_z': 1153, 'tibpost': 1153, 'tibant': 1153, 'edl': 1153, 'ehl': 1153, 'fdl': 1153, 'fhl': 1153, 'gaslat': 1153, 'gasmed': 1153, 'soleus': 1153, 'perbrev': 1153, 'perlong': 1153, 'achilles': 1153}
Val: {'grf_x': 142, 'grf_y': 142, 'grf_z': 142, 'tibpost': 142, 'tibant': 142, 'edl': 142, 'ehl': 142, 'fdl': 142, 'fhl': 142, 'gaslat': 142, 'gasmed': 142, 'soleus': 142, 'perbrev': 142, 'perlong': 142, 'achilles': 142}
Test: {'grf_x': 116, 'grf_y': 116, 'grf_z': 116, 'tibpost': 116, 'tibant': 116, 'edl': 116, 'ehl': 116, 'fdl': 116, 'fhl': 116, 'gaslat': 116, 'gasmed': 116, 'soleus': 116, 'perbrev': 116, 'perlong': 116, 'achilles': 116}


In [45]:
# Column order of every packed segment: the 3 GRF channels first (used as model
# inputs downstream), then the target channels.
# NOTE(review): 'gaslat', 'gasmed' and 'soleus' are present in the data but are
# not packed. This is presumably intentional (e.g. represented via 'achilles'); confirm.
SEGMENT_KEYS = (
    'grf_x', 'grf_y', 'grf_z',
    'tibpost', 'tibant', 'edl', 'ehl', 'fdl', 'fhl',
    'perbrev', 'perlong', 'achilles',
)


def dict_to_array(split_dict, keys=SEGMENT_KEYS):
    """Pack every segment of every subject into one array.

    Each segment is built with ``np.column_stack`` from the selected keys, so a
    segment whose signals are 1-D of length ``T`` becomes a ``(T, len(keys))``
    block. Segments are stacked subject by subject in ``split_dict`` order,
    giving ``(num_segments, T, len(keys))``.

    Args:
        split_dict: mapping subject -> {signal_key: list of segments}.
        keys: signal keys to pack, in column order (default ``SEGMENT_KEYS``).

    Returns:
        ``np.ndarray`` of the packed segments (``np.array([])`` if there are none).

    Raises:
        KeyError: if a requested key is missing for a subject.
        ValueError: if the requested keys of a subject have differing segment counts.
    """
    keys = list(keys)
    packed_segments = []
    for subj, data in split_dict.items():
        channels = [data[key] for key in keys]
        counts = {len(channel) for channel in channels}
        if len(counts) > 1:
            raise ValueError(
                f"{subj}: selected keys have differing segment counts {sorted(counts)}"
            )
        # zip pairs entry i of every key, i.e. one segment at a time.
        for segment_signals in zip(*channels):
            packed_segments.append(np.column_stack(segment_signals))
    return np.array(packed_segments)
    

In [48]:
# Pack each cohort's splits, then stack OA segments before YA segments along axis 0.
# NOTE(review): the combined arrays are ordered OA-then-YA, not shuffled —
# make sure training shuffles before batching.
OA_train_arr, OA_val_arr, OA_test_arr = (
    dict_to_array(split) for split in (OA_train_data, OA_val_data, OA_test_data)
)
YA_train_arr, YA_val_arr, YA_test_arr = (
    dict_to_array(split) for split in (YA_train_down, YA_val_down, YA_test_down)
)

train_arr = np.concatenate((OA_train_arr, YA_train_arr), axis=0)
val_arr = np.concatenate((OA_val_arr, YA_val_arr), axis=0)
test_arr = np.concatenate((OA_test_arr, YA_test_arr), axis=0)

for label, arr in (("Train:", train_arr), ("Val:", val_arr), ("Test:", test_arr)):
    print(label, arr.shape)

Train: (1153, 100, 12)
Val: (142, 100, 12)
Test: (116, 100, 12)


In [49]:
# The first N_GRF_CHANNELS columns of each packed segment (grf_x, grf_y, grf_z)
# are the model inputs; all remaining columns are the targets.
N_GRF_CHANNELS = 3

X_train = train_arr[:, :, :N_GRF_CHANNELS]
y_train = train_arr[:, :, N_GRF_CHANNELS:]
X_val = val_arr[:, :, :N_GRF_CHANNELS]
y_val = val_arr[:, :, N_GRF_CHANNELS:]
X_test = test_arr[:, :, :N_GRF_CHANNELS]
y_test = test_arr[:, :, N_GRF_CHANNELS:]

print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_test shape: {y_test.shape}")

X_train shape: (1153, 100, 3)
y_train shape: (1153, 100, 9)
X_val shape: (142, 100, 3)
y_val shape: (142, 100, 9)
X_test shape: (116, 100, 3)
y_test shape: (116, 100, 9)


In [50]:
# Persist each mixed (OA + downsampled YA) split as its own .npz archive in OA_data_dir.
for _fname, _arrays in (
    ('mixed_train_data.npz', {'X_train': X_train, 'y_train': y_train}),
    ('mixed_val_data.npz', {'X_val': X_val, 'y_val': y_val}),
    ('mixed_test_data.npz', {'X_test': X_test, 'y_test': y_test}),
):
    np.savez(OA_data_dir + _fname, **_arrays)