In [1]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import torch

from cgpdm_dynamics import CGPDM
from cgpdm_dynamics_pf import CGPDM_PF
from dataset_utils.mocap_labels import WALK_TRIALS_TRAIN, WALK_TRIALS_TEST, RUN_TRIALS_TRAIN, RUN_TRIALS_TEST
import dataset_utils.select_joints as select_joints



pygame 2.6.1 (SDL 2.30.7, Python 3.12.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Saved CGPDM config/state files, both located in <cwd>/models/cgpdm/.
config_path, state_path = (
    Path.cwd() / 'models' / 'cgpdm' / file_name
    for file_name in ('cgpdm_config.pth', 'cgpdm_state.pth')
)

In [3]:
# Column names for the simplified joint set, read from the first walking test
# trial. Note: the next cell recomputes `reduced_columns` / `DOFs` from the
# combined test set and overwrites these values.
reduced_columns = WALK_TRIALS_TEST[0].get_columns_for_joints(
    select_joints.WALKING_SIMPLIFIED_JOINTS
)
DOFs = len(reduced_columns)
print(f"Number of DOFs: {DOFs}")


Number of DOFs: 35


In [4]:
# Evaluation set: every running test trial followed by every walking test trial.
# TRUTH_LABELS[k] is the ground-truth class of TEST_TRIALS[k]: 1 = run, 0 = walk.
TEST_TRIALS = RUN_TRIALS_TEST + WALK_TRIALS_TEST
TRUTH_LABELS = [1 for _ in RUN_TRIALS_TEST] + [0 for _ in WALK_TRIALS_TEST]

# Observation columns (simplified joint set), taken from the first test trial.
reduced_columns = TEST_TRIALS[0].get_columns_for_joints(
    select_joints.WALKING_SIMPLIFIED_JOINTS
)
DOFs = len(reduced_columns)
print(f"Number of DOFs: {DOFs}")

Number of DOFs: 35


In [5]:
# --- Model sizes ---
d = 3  # dimension of the latent space
DOFs = len(reduced_columns)  # number of observed degrees of freedom
dyn_back_step = 1  # how many past time steps the dynamics GP looks back

# --- Initial hyperparameters: observation GP ---
y_lambdas_init = np.full(DOFs, 1.0)  # signal standard deviation, one per DOF
y_lengthscales_init = np.full(d, 1.0)  # lengthscales, one per latent dimension
y_sigma_n_init = 0.01  # noise standard deviation

# --- Initial hyperparameters: latent dynamics GP ---
x_lambdas_init = np.full(d, 1.0)  # signal standard deviation, one per latent dimension
x_lengthscales_init = np.full(dyn_back_step * d, 1.0)  # lengthscales, one per dynamics input
x_sigma_n_init = 0.01  # noise standard deviation
x_lin_coeff_init = np.full(dyn_back_step * d + 1, 1.0)  # linear coefficients (dyn_back_step*d + 1 entries)


In [6]:
# Two-class CGPDM built from the sizes and initial hyperparameters defined above;
# the saved model is loaded into it in the next cell.
gpdm = CGPDM(
    # sizes / structure
    D=DOFs,
    d=d,
    n_classes=2,
    dyn_back_step=dyn_back_step,
    dyn_target='full',
    # observation GP
    y_lambdas_init=y_lambdas_init,
    y_lengthscales_init=y_lengthscales_init,
    y_sigma_n_init=y_sigma_n_init,
    # latent dynamics GP
    x_lambdas_init=x_lambdas_init,
    x_lengthscales_init=x_lengthscales_init,
    x_sigma_n_init=x_sigma_n_init,
    x_lin_coeff_init=x_lin_coeff_init,
)

In [7]:
# Load the saved configuration and state files (paths defined earlier) into the model.
gpdm.load(
    config_path,
    state_path,
)

[32m
GPDM correctly loaded[0m


In [8]:
NUM_PARTICLES = 100  # particles per filter in the CGPDM particle-filter bank

# Build per-class caches on the model if the loaded object does not have them yet:
# the class-wise M and latent trajectories X, and the inverse of each class's latent
# dynamics kernel Kx. NOTE(review): presumably CGPDM_PF reads these attributes —
# confirm against cgpdm_dynamics_pf.
if not hasattr(gpdm, 'X_class'):
    gpdm.M = gpdm.get_M()
    gpdm.M_class = []
    gpdm.X_class = []
    gpdm.Kx_inv_class = []
    for class_idx in range(gpdm.n_classes):
        gpdm.M_class.append(gpdm.get_M_for_class(class_idx))
        X_c = gpdm.get_X_for_class(class_idx)
        gpdm.X_class.append(X_c)

        Xin, _, _ = gpdm.get_Xin_Xout_matrices(X_c)  # only the dynamics inputs are needed
        Kx = gpdm.get_x_kernel(Xin, Xin)

        # Kx = U^T U with U upper-triangular. cholesky_ex does NOT raise on failure:
        # a non-zero `info` means Kx is not numerically positive definite and U is
        # unusable, which would otherwise silently poison the inverse with NaN/inf.
        U, info = torch.linalg.cholesky_ex(Kx, upper=True)
        if int(info) != 0:
            raise RuntimeError(
                f"Cholesky factorisation of the latent kernel Kx for class {class_idx} "
                f"failed (info={int(info)}): Kx is not positive definite."
            )
        # Kx^{-1} = U^{-1} U^{-T}, computed directly from the Cholesky factor.
        gpdm.Kx_inv_class.append(torch.cholesky_inverse(U, upper=True))

# Build a gpdm_pf model
gpdm_pf = CGPDM_PF(gpdm, num_particles=NUM_PARTICLES)

   

In [9]:
# Classify each test trial online: frames are fed to the particle filter one at a
# time, and the per-frame most-likely class is compared with the trial's true label.
# Class index convention (matching TRUTH_LABELS and latest_lls below): 0 = walk, 1 = run.
FRAME_STRIDE = 5  # temporal subsampling: keep every 5th frame of each trial
PF_SEED = 0

# NOTE(review): a particle filter is presumably stochastic; seeding numpy and torch
# makes re-runs of this cell reproducible, assuming CGPDM_PF draws from one of these
# RNGs — confirm in cgpdm_dynamics_pf.
np.random.seed(PF_SEED)
torch.manual_seed(PF_SEED)

# TODO(review): the recorded run stopped after trial 0 with
# "IndexError: index 97 is out of bounds for dimension 0 with size 97", and trial 0
# reported a NaN walk log-likelihood. Both presumably originate inside CGPDM_PF and
# need investigating there.
for i, (trial, label) in enumerate(zip(TEST_TRIALS, TRUTH_LABELS, strict=True)):
    # 2-D float32 array of the selected joint columns, subsampled in time
    data_arr = trial.as_dataframe()[reduced_columns].to_numpy(dtype=np.float32)[::FRAME_STRIDE, :]
    if len(data_arr) == 0:
        # Without this guard the accuracy would divide by zero, and the final-state
        # variables printed below would be undefined (or stale from a previous trial).
        print(f"Trial {i}: no frames after subsampling, skipped")
        continue

    gpdm_pf.reset()
    correct = 0

    # Run the filter bank on the trial
    for data in data_arr:
        gpdm_pf.update(data)
        predicted_label = gpdm_pf.get_most_likely_class()
        latest_lls = gpdm_pf.log_likelihood_classwise()
        walk_ll = latest_lls[0]
        run_ll = latest_lls[1]

        if predicted_label == label:
            correct += 1

    # Fraction of frames on which the running prediction was correct
    accuracy = correct / len(data_arr)

    print(f"Trial {i}, Accuracy: {accuracy}")
    print(f"Trial {i} Final, True label: {label}, Predicted label: {predicted_label}, Walk LL: {walk_ll}, Run LL: {run_ll}")


Trial 0, Accuracy: 0.0
Trial 0 Final, True label: 1, Predicted label: 0, Walk LL: nan, Run LL: -1517.1725217431808


IndexError: index 97 is out of bounds for dimension 0 with size 97