In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from gpdm import GPDM
from gpdm_filter import GPDM_Filter
from gpdm_ukf import GPDM_UKF

from dataset_utils.mocap_labels import RUN_TRIALS_TEST, WALK_TRIALS_TEST
import dataset_utils.select_joints as select_joints



pygame 2.6.1 (SDL 2.30.7, Python 3.12.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
TEST_TRIALS = RUN_TRIALS_TEST + WALK_TRIALS_TEST
TRUTH_LABELS = [1] * len(RUN_TRIALS_TEST) + [0] * len(WALK_TRIALS_TEST)

reduced_columns = TEST_TRIALS[0].get_columns_for_joints(select_joints.WALKING_SIMPLIFIED_JOINTS)
DOFs = len(reduced_columns)
print(f"Number of DOFs: {DOFs}")


Number of DOFs: 35


In [3]:
d = 3 # latent space dimension
DOFs = len(reduced_columns)  # number of degrees of freedom

dyn_back_step = 1 # Number of time steps to look back in the dynamics GP

# Initial values for hyperparameters
y_lambdas_init = np.ones(DOFs)  # Signal standard deviation for observation GP
y_lengthscales_init = np.ones(d)  # Lengthscales for observation GP
y_sigma_n_init = 1e-2  # Noise standard deviation for observation GP

x_lambdas_init = np.ones(d)  # Signal standard deviation for latent dynamics GP
x_lengthscales_init = np.ones(dyn_back_step*d)  # Lengthscales for latent dynamics GP
x_sigma_n_init = 1e-2  # Noise standard deviation for latent dynamics GP
x_lin_coeff_init = np.ones(dyn_back_step*d + 1)  # Linear coefficients for latent dynamics GP

In [4]:
model_dir = Path().cwd() / 'models' / 'epochs_1k'
walk_config = model_dir / 'walk_gpdm_config.pth'
walk_state = model_dir / 'walk_gpdm_state.pth'

walk_gpdm =  GPDM(
        D=DOFs,
        d=d,
        dyn_target='full',
        dyn_back_step=dyn_back_step,
        y_lambdas_init=y_lambdas_init,
        y_lengthscales_init=y_lengthscales_init,
        y_sigma_n_init=y_sigma_n_init,
        x_lambdas_init=x_lambdas_init,
        x_lengthscales_init=x_lengthscales_init,
        x_sigma_n_init=x_sigma_n_init,
        x_lin_coeff_init=x_lin_coeff_init
    )
walk_gpdm.load(walk_config, walk_state, False)
walk_gpdm_filter = GPDM_UKF(walk_gpdm)
#walk_gpdm_filter = GPDM_Filter(walk_gpdm)


[32m
GPDM correctly loaded[0m


In [5]:
run_config = model_dir / 'run_gpdm_config.pth'
run_state = model_dir / 'run_gpdm_state.pth'

run_gpdm =  GPDM(
        D=DOFs,
        d=d,
        dyn_target='full',
        dyn_back_step=dyn_back_step,
        y_lambdas_init=y_lambdas_init,
        y_lengthscales_init=y_lengthscales_init,
        y_sigma_n_init=y_sigma_n_init,
        x_lambdas_init=x_lambdas_init,
        x_lengthscales_init=x_lengthscales_init,
        x_sigma_n_init=x_sigma_n_init,
        x_lin_coeff_init=x_lin_coeff_init
    )

run_gpdm.load(run_config, run_state, False)
run_gpdm_filter = GPDM_UKF(run_gpdm)
#run_gpdm_filter = GPDM_Filter(run_gpdm)


[32m
GPDM correctly loaded[0m


In [6]:
# Iterate over the test trials and classify them

for i, trial in enumerate(TEST_TRIALS):
    label = TRUTH_LABELS[i]

    # Get the data in an array format
    data_arr = trial.as_dataframe()[reduced_columns].to_numpy(dtype=np.float32)[::5, :]

    walk_gpdm_filter.reset()
    run_gpdm_filter.reset()

    # Run each data point through the GPDM filter
    for timestep in range(data_arr.shape[0]):
        data_point = data_arr[timestep, :]
        walk_ll = walk_gpdm_filter.update(data_point)
        if walk_ll is None:
            walk_ll = walk_gpdm_filter.log_likelihood()
        run_ll = run_gpdm_filter.update(data_point)
        if run_ll is None:
            run_ll = run_gpdm_filter.log_likelihood()

    # Classify the trial using the final log likelihoods
    if walk_ll > run_ll:
        predicted_label = 0
    else:
        predicted_label = 1

    print(f"Trial {i}: True label: {label}, Predicted label: {predicted_label}, Walk LL: {walk_ll}, Run LL: {run_ll}")



  ll = -0.5 * (torch.log(torch.det(self.S)) + (self.residual).T @ torch.inverse(self.S) @ (self.residual))


Trial 0: True label: 1, Predicted label: 1, Walk LL: -90.40590557228516, Run LL: -26.409789045360085
Trial 1: True label: 1, Predicted label: 1, Walk LL: -60.008893712372, Run LL: -7.32161604676115
Trial 2: True label: 1, Predicted label: 1, Walk LL: -72.84636996842492, Run LL: -8.836298474439445
Trial 3: True label: 1, Predicted label: 1, Walk LL: -99.65919787996988, Run LL: -30.57035853880285
Trial 4: True label: 1, Predicted label: 1, Walk LL: -88.7570576179837, Run LL: -32.54148520921389
Trial 5: True label: 1, Predicted label: 1, Walk LL: -104.10049794080878, Run LL: -34.3168216374586
Trial 6: True label: 1, Predicted label: 1, Walk LL: -68.56665570835794, Run LL: -21.483580717436894
Trial 7: True label: 1, Predicted label: 1, Walk LL: -39.276945196516124, Run LL: 1.424409897516398
Trial 8: True label: 1, Predicted label: 1, Walk LL: -69.88533603666129, Run LL: -48.70731907707932
Trial 9: True label: 1, Predicted label: 1, Walk LL: -78.51274796283131, Run LL: -62.07387627748811
Tr