In [15]:
import os

import torch

import model.classifiers as classifiers
import model.transforms as transforms

import visualizer.utils as utils
from config import CONFIG, TRAIN_CONFIG

In [16]:
# Evaluation settings.
exp_id = 1      # experiment whose checkpoint gets loaded
device = 'cpu'  # device for the model and the transforms
length = 115    # each sample is cut to at most this many leading rows

# The checkpoint sits at <output_data>/experiment_<zero-padded id>/checkpoint.pth.
checkpoint_path = os.path.join(
    TRAIN_CONFIG.train_params.output_data, f'experiment_{exp_id:03d}', 'checkpoint.pth'
)

# Folder the sample .npy files are read from.
samples_folder = CONFIG.mediapipe.points_pose_world_windowed_filtered_labeled

# Gesture names map to class indices 1..N in configuration order; index 0 is
# reserved for the rejection class, which is only added when it is enabled.
with_rejection = TRAIN_CONFIG.gesture_set.with_rejection
label_map = {
    name: index
    for index, name in enumerate(TRAIN_CONFIG.gesture_set.gestures, start=1)
}
if with_rejection:
    label_map['_rejection'] = 0

# Reverse lookup: class index -> gesture name.
inv_label_map = {index: name for name, index in label_map.items()}

In [17]:
# Build the classifier (constructed with the number of entries in label_map),
# move it to `device`, load the saved state dict into it and switch the module
# to evaluation mode.
# NOTE(review): torch.load unpickles the file unless weights_only is in
# effect -- only load checkpoints that come from a trusted source.
model = classifiers.LSTMClassifier(len(label_map)).to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# NOTE(review): evaluation goes through TrainTransforms -- confirm it applies
# no training-only random augmentation.
test_transforms = transforms.TrainTransforms(device=device)

In [18]:
# Select one recording by subject id, gesture name, hand and trial number.
subject, gesture, hand, trial = 120, 'start', 'right', 1

file_path = os.path.join(samples_folder, f'G{subject}_{gesture}_{hand}_trial{trial}.npy')

# Keep at most `length` leading rows. The last column is split off as the
# per-row label flag; every other column is passed on as model input.
data = utils.get_mediapipe_points(file_path)[:length]
points = data[:, :-1]
# Scale the flag by the gesture's class index
# (presumably the flag is 0/1 -- TODO confirm against the data files).
labels = torch.tensor(data[:, -1]).to(torch.int64) * label_map[gesture]

In [19]:
# Run the transformed points through the model with gradient tracking
# disabled; only the first element of the model's output is kept, the
# remaining elements are collected into `_` and not used further here.
with torch.no_grad():
    prediction, *_ = model(test_transforms(points))

In [20]:
# Reduce over the last dimension: the maximum value and the index at which it
# occurs. The index is used as the predicted class label.
# NOTE(review): the values are probabilities only if the model emits
# probabilities (rather than raw scores) -- confirm.
prediction_probs, prediction_labels = torch.max(prediction, dim=-1)
prediction_labels

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [21]:
# Display the label tensor built from the sample file.
labels

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [22]:
# Fraction of positions where the predicted label equals the reference label.
# Averaging over the compared elements (instead of dividing by the configured
# `length`) keeps the figure correct when the sample holds fewer than `length`
# rows, since the `[:length]` slice then yields a shorter sequence.
accuracy = (prediction_labels == labels).float().mean()
f'{accuracy.item():.2%}'

'98.26%'