In [None]:
import preamble
import numpy as np
from numpy.typing import NDArray
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from dataset import Dataset
from model import DualAutoencoderModel, MLPModel
from embedder import BinaryEmbedder, GloveEmbedder
from globals import FORCE_SAMPLE_COUNT, FORCE_CURVE_DURATION
from scipy.signal import savgol_filter
import os

In [None]:
os.makedirs('output', exist_ok=True)

dataset = Dataset('../data/trimmed_trials')

force_data, phrase_data = dataset.load()
merged_phrase_data = dataset.merge_directions(phrase_data)

# Study layout used to label every sample with a (user, trial) pair.
# NOTE(review): assumes samples are ordered user-major (all trials of user 0,
# then user 1, ...) — confirm against Dataset.load.
NUM_USERS = 10
TRIALS_PER_USER = 42
trial_id, user_id = np.meshgrid(np.arange(TRIALS_PER_USER), np.arange(NUM_USERS))
# One [user, trial] row per sample; the plotting cell uses it to name saved figures.
# `id` shadows the builtin but is kept because later cells index it by this name.
id = np.stack((user_id, trial_id), axis=-1).reshape(-1, 2)
# Guard against silently mislabeled figures (or an IndexError later) if the
# dataset size does not match the hard-coded study layout.
assert len(id) == len(force_data), (
    f'expected {len(id)} samples ({NUM_USERS} users x {TRIALS_PER_USER} trials), '
    f'got {len(force_data)}'
)

binary_embedder = BinaryEmbedder(np.unique(merged_phrase_data), merged_phrase_data.shape[-1])
glove_embedder = GloveEmbedder('../data/limited_vocab_embeddings_with_special_no_word_token_50d.pt', phrase_data.shape[-1])

In [None]:
def derivative(curve: NDArray, dt: "float | None" = None) -> NDArray:
    """Finite-difference time derivative of ``curve`` along axis 0.

    Args:
        curve: Samples with time along axis 0 (a single channel, or an array
            whose first axis is time).
        dt: Sample spacing in seconds. Defaults to
            ``FORCE_CURVE_DURATION / FORCE_SAMPLE_COUNT``.

    Returns:
        Array with the same shape as ``curve`` holding d(curve)/dt, computed by
        ``np.gradient`` (central differences in the interior, one-sided at the ends).
    """
    if dt is None:
        # NOTE(review): the plotting time axis is
        # np.linspace(0, FORCE_CURVE_DURATION, FORCE_SAMPLE_COUNT), whose spacing is
        # DURATION / (COUNT - 1), not DURATION / COUNT — confirm which matches the
        # recorded sampling.
        dt = FORCE_CURVE_DURATION / FORCE_SAMPLE_COUNT
    return np.gradient(curve, dt, axis=0)

def derivative_smooth(curve: NDArray, dt: "float | None" = None,
                      window_length: int = 50, polyorder: int = 3) -> NDArray:
    """Time derivative of ``curve`` along axis 0 after Savitzky-Golay smoothing.

    Args:
        curve: Samples with time along axis 0 (a single channel, or an array
            whose first axis is time).
        dt: Sample spacing in seconds. Defaults to
            ``FORCE_CURVE_DURATION / FORCE_SAMPLE_COUNT``.
        window_length: Savitzky-Golay window length, in samples.
        polyorder: Savitzky-Golay polynomial order; must be < ``window_length``.

    Returns:
        Array with the same shape as ``curve`` holding the derivative of the
        smoothed signal.
    """
    if dt is None:
        # NOTE(review): see `derivative` — the plot axis spacing is
        # DURATION / (COUNT - 1); confirm which spacing matches the data.
        dt = FORCE_CURVE_DURATION / FORCE_SAMPLE_COUNT
    smoothed = savgol_filter(curve, window_length=window_length, polyorder=polyorder,
                             axis=0, mode='nearest')
    # axis=0 matches `derivative`; without it, np.gradient on a multi-dimensional
    # input returns a list of per-axis gradients instead of a single array.
    return np.gradient(smoothed, dt, axis=0)

kfold = KFold(n_splits=10, shuffle=True, random_state=0)

# Plot time axis; loop-invariant, so it is built once rather than once per fold.
t = np.linspace(0.0, FORCE_CURVE_DURATION, FORCE_SAMPLE_COUNT)

# Only test trials whose merged phrase contains this fragment are plotted
# (the hard-coded figure title below describes that command).
TARGET_PHRASE_FRAGMENT = 'backward up'

# (channel index, channel name, line color) for each force component.
FORCE_CHANNELS = ((0, 'X', 'red'), (1, 'Y', 'green'), (2, 'Z', 'blue'))

for train_indices, test_indices in kfold.split(force_data):
    force_train, phrase_train, merged_phrase_train = force_data[train_indices], phrase_data[train_indices], merged_phrase_data[train_indices]
    force_test, phrase_test, merged_phrase_test = force_data[test_indices], phrase_data[test_indices], merged_phrase_data[test_indices]
    id_test = id[test_indices]

    # NOTE(review): the `_b` / `_g` suffixes suggest binary vs. GloVe embedders, but
    # both models are built with glove_embedder. `dae_b` is the MLP model (plotted as
    # DMLP_S) and `dae_g` the dual autoencoder (DAE_S). Names kept in case later
    # cells reference them.
    dae_b = MLPModel(glove_embedder, phrase_mse_loss=True)
    dae_g = DualAutoencoderModel(glove_embedder, phrase_mse_loss=True)

    dae_b.train(force_train, phrase_train, epochs=1024)
    dae_g.train(force_train, phrase_train, epochs=1024)

    force_pred_b = dae_b.phrase_to_force(phrase_test)
    force_pred_g = dae_g.phrase_to_force(phrase_test)
    phrase_pred_b = dae_b.force_to_phrase(force_test)
    phrase_pred_g = dae_g.force_to_phrase(force_test)

    for i in range(force_test.shape[0]):
        reference_phrase = ' '.join(merged_phrase_test[i]).strip()
        if TARGET_PHRASE_FRAGMENT not in reference_phrase:
            continue

        fig, ax = plt.subplots(figsize=(9, 6))
        # Reference curves use the unsmoothed derivative; predictions are smoothed
        # first. Iteration order (series, then X/Y/Z) fixes the legend order.
        series = (
            (force_test[i], derivative, '-', 'Reference'),
            (force_pred_b[i], derivative_smooth, '--', '$\\text{DMLP}_S$'),
            (force_pred_g[i], derivative_smooth, ':', '$\\text{DAE}_S$'),
        )
        for curves, differentiate, linestyle, series_name in series:
            for channel, channel_name, color in FORCE_CHANNELS:
                ax.plot(t, differentiate(curves[:, channel]), linewidth=1.0, color=color,
                        linestyle=linestyle, label=f'{series_name} {channel_name}')
        ax.set_xlabel('Time (s)')
        ax.set_xlim(0.0, 2.0)
        # The curves are time derivatives of force, so the unit is N/s, not N.
        ax.set_ylabel('Force Rate (N/s)')
        ax.set_title("Predicted Force-Rate Profile for 'Ease back and up a bit more.'")
        ax.grid(alpha=0.5)
        ax.legend()

        # Row labels use the same model names as the legend above.
        table_data = [
            ["Reference", reference_phrase],
            ["$\\text{DMLP}_S$", ' '.join(phrase_pred_b[i]).strip()],
            ["$\\text{DAE}_S$", ' '.join(phrase_pred_g[i]).strip()],
        ]
        ax.table(cellText=table_data,
                 colLabels=["$\\bf{Model}$", "$\\bf{Phrase}$"],
                 cellLoc='center',
                 loc='bottom',
                 bbox=[0, -0.5, 1, 0.3])
        fig.text(0.5, -0.02, "Predicted Phrases for Reference Force Profile", ha="center", fontsize=12)
        fig.savefig(f'output/{id_test[i, 0]}_{id_test[i, 1]}.png', bbox_inches='tight')
        plt.show()
        # Release the figure so figures do not accumulate across trials and folds.
        plt.close(fig)