In [1]:
import importlib

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import tensor

import data_processing as dp
import model_evaluation as me
from models.mlp import MLP
from models.stm_regressor import STMRegressor
# Re-import data_processing so edits to the module are picked up without
# restarting the kernel. On a fresh kernel this is redundant (the import above
# already loaded the current source); IPython's `%autoreload 2` is an
# alternative that reloads modules automatically.
importlib.reload(dp)

In [2]:
# ---------------------------------------------------------------------------
# Load the segmented gait cycles that the evaluation below is run on.
# ---------------------------------------------------------------------------
DATA_DIR = "../segmented_data/"     # root folder of the segmented recordings
SUBJECTS = ['AT']
SCENES = [
    'FlatWalkStraight',
    'FlatWalkCircular',
    'FlatWalkStatic',
]
# TRIALS is a plain string, not a one-element tuple; presumably
# dp.read_gait_cycles interprets 'all' as "every trial" - confirm.
TRIALS = 'all'

gait_cycles = dp.read_gait_cycles(
    DATA_DIR, SUBJECTS, SCENES, TRIALS,
    drop_emgs=True,   # ask the loader to drop EMG channels
)

In [3]:
# --- Filtering --------------------------------------------------------------
# All loaded gait cycles go through dp.filter_together in a single call.
df_filtered = dp.filter_together(
    gait_cycles,
)

In [4]:
# --- Feature extraction -----------------------------------------------------
# Model inputs: features computed from the filtered gait cycles.
X_test = dp.extract_features(df_filtered)

# Targets: force/torque channels for the left foot, then the right foot.
# The column list is the single source of truth for the target width, so it
# is no longer repeated as a hard-coded 8 in a reshape.
GRF_COLUMNS = ['Fx_l', 'Fy_l', 'Fz_l', 'Tz_l',
               'Fx_r', 'Fy_r', 'Fz_r', 'Tz_r']

# Selecting the columns already yields a (n_samples, len(GRF_COLUMNS)) array,
# so no reshape is needed before converting to a float32 tensor.
Y_test = tensor(df_filtered[GRF_COLUMNS].to_numpy(), dtype=torch.float32)
assert Y_test.shape[1] == len(GRF_COLUMNS)

In [5]:
# Timestamped training-run directory; the estimator, scaler.pkl, PCA.pkl and
# the Tz MLP used in the cells below are all loaded from here.
DIR = "results/excl_emg/all/20240601-191218"

In [6]:
# Restore the full ground-reaction-force estimator from the training run and
# predict on the test features. This is inference only, so autograd is turned
# off: no computation graph is built, and any tensors produced do not carry
# requires_grad (which would otherwise break a later .numpy() call).
full_grf_estimator = STMRegressor(DIR)
with torch.no_grad():
    Y_pred = full_grf_estimator(X_test)

In [7]:
# Plot the correlation between measured targets (Y_test) and the full
# estimator's predictions (Y_pred) using the project's evaluation helper.
me.plot_correlations(Y_test, Y_pred)

In [8]:
# ---------------------------------------------------------------------------
# Load gait cycles into `perturbations`.
# TODO(review): this configuration is identical to the earlier data-loading
# cell (same subject, same flat-walking scenes, same trials), so
# `perturbations` should currently hold the same data as `gait_cycles`. If it
# is meant to load perturbation trials, SCENES likely needs other values.
# ---------------------------------------------------------------------------
DATA_DIR = "../segmented_data/"
SUBJECTS = ['AT']
SCENES = [
    'FlatWalkStraight',
    'FlatWalkCircular',
    'FlatWalkStatic',
]
TRIALS = 'all'   # a plain string, not a tuple

perturbations = dp.read_gait_cycles(
    DATA_DIR, SUBJECTS, SCENES, TRIALS,
    drop_emgs=True,
)

In [18]:
# ---------------------------------------------------------------------------
# Evaluate the single-target 'Tz' MLP on the left-side samples.
# Pipeline: dp.homogenize -> scaler -> PCA -> MLP. The preprocessing objects
# come from the same training run (DIR) as the model.
# ---------------------------------------------------------------------------
# dp.homogenize returns two frames (presumably left- and right-side features).
X_l, X_r = dp.homogenize(X_test)

# NOTE: joblib.load unpickles arbitrary objects - only load artefacts from a
# trusted training run.
scaler = joblib.load(f'{DIR}/scaler.pkl')
pca    = joblib.load(f'{DIR}/PCA.pkl')

# Normalize features with the same fitted scaler for both sides
X_l = scaler.transform(X_l.values)
X_r = scaler.transform(X_r.values)

# Project onto the fitted principal components
X_l_pc = pca.transform(X_l)
X_r_pc = pca.transform(X_r)

# Convert to tensors (the right-side tensor is not evaluated in this cell)
X_pc_l_tensor = tensor(X_l_pc, dtype=torch.float32)
X_pc_r_tensor = tensor(X_r_pc, dtype=torch.float32)

model = MLP.load(DIR, 'Tz')
# Inference only: put the model in eval mode (assumes MLP is a torch.nn.Module,
# as it exposes state_dict()) and disable autograd so y_pred has no graph.
model.eval()
with torch.no_grad():
    y_pred = model(X_pc_l_tensor)

# Summarize the architecture as a list of layer widths
# [input features, hidden widths..., output features]. Assumes nn.Linear-style
# weights of shape (out_features, in_features).
layer_widths = []
for name, weights in model.state_dict().items():
    if 'weight' not in name:
        continue
    if 'input' in name or 'hidden' in name:
        layer_widths.append(weights.shape[1])
    if 'output' in name:
        layer_widths.extend([weights.shape[1], weights.shape[0]])
print(f"MLP layer widths: {layer_widths}")

# Column 3 of Y_test is 'Tz_l' (target order: Fx_l, Fy_l, Fz_l, Tz_l,
# Fx_r, Fy_r, Fz_r, Tz_r), matching the left-side inputs used above.
TZ_L_COL = 3
me.print_metrics(Y_test[:, TZ_L_COL].reshape(-1, 1), y_pred)