# Inference on validation data

This notebook provides how to make prediction and visualize movement prediction to gif files.

Also you can stack prediction horizontally for comparing difference between prediction and real movements 

Default parameters: 

- FPS = 200
- NEW_FPS = 25
- WINDOW_SIZE = 512
- STRIDE = 512


In [1]:
import wandb, sys, os  
import numpy as np
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint


sys.path.insert(1, os.path.realpath(os.path.pardir))

from utils import data_utils, losses
from utils.hand_visualize import Hand, save_animation_mp4, visualize_and_save_anim, merge_two_videos, visualize_and_save_anim_gifs #, merge_two_videos_vertically
from utils.inference_utils import make_inference, calculcate_latency, get_angle_degree
from utils.quats_and_angles import get_quats, get_angles

from scipy import signal

from models import HVATNet_v2, HVATNet_v3, HVATNet_v3_FineTune, HVATNet_v3_FineTune_N

%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib qt

In [3]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        
def upload_weights_pl(model, path, pt):
    class Lit_Wrapper(pl.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            x = self.model(x)
            return x
    
    if pt == 'ckpt':
        ckpt = torch.load(path, map_location=torch.device('cpu'))
        model_pl = Lit_Wrapper(model)
        model_pl.load_state_dict(ckpt['state_dict'])
    else:
        try:
            ckpt = torch.load(path, map_location=torch.device('cpu'))
            model_pl = Lit_Wrapper(model)
            model_pl.load_state_dict(ckpt)
        except:
            ckpt = torch.load(path, map_location=torch.device('cpu'))
            model.load_state_dict(ckpt)
            return model
    
    return model_pl.model

In [9]:
class TrainConfig:
    WANDB_NOTES = 'HVATNet v3 FT train on all new data + no augs'

    datasets = ['../../data/general_set']
    # datasets = ['../../data/processed/dataset_v2_blocks']
    
    # hand_type = ['right']
    # human_type = ['health']
    hand_type = ['right']
    human_type = ['amputant']

    test_dataset_list = ['5']

    use_preproc_data = True # use preproc data (faster preparation  
    use_angles = True # use angels as target.
    
    original_fps = 250  # TODO describtion
    delay_ms = 0  # Shift vr vs EMG parameter. Do not work with preproc data. Fix it!!
    start_crop_ms = 0  # bad values in the beginning of recordign in ms to delete.
    window_size = 256
    down_sample_target=8 # None

    max_epochs = 3000
    samples_per_epoch = 5000*256
    train_bs = 2048
    val_bs = 2048
    device = [2]
    optimizer_params = dict(lr=1e-4,
                            wd=0)
config = TrainConfig()


In [14]:
rootdir = Path('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train')
files = list(rootdir.glob('*'))

In [15]:
train_paths = files

In [19]:
train_paths

[WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/1_1'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/1_2'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/1_3'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/1_4'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/2_1'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/2_2'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/2_3'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/2_4'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/3_1'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/3_2'),
 WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Train/3_3'),
 WindowsPa

In [16]:
rootdir = Path('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Validation')
files = list(rootdir.glob('*'))

In [17]:
val_paths = files

In [18]:
val_paths

[WindowsPath('C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Validation/1_1')]

## Download data

In [20]:
# Init train and val dataset and dataloaders
val_datasets = []
for val_folder in val_paths:
    val_dataset = data_utils.create_dataset(data_folder=val_folder,
                                            original_fps=config.original_fps,
                                            delay_ms=config.delay_ms,
                                            start_crop_ms=config.start_crop_ms,
                                            window_size=config.window_size,

                                            down_sample_target = config.down_sample_target, 
                                            use_preproc_data=config.use_preproc_data, 
                                            use_angles=config.use_angles, 

                                            random_sampling = False,
                                            samples_per_epoch = None, 
                                            transform = None)
    
    val_datasets.append(val_dataset)

val_dataset = torch.utils.data.ConcatDataset(val_datasets)

Number of moves: 8 | Dataset: GeneralTraining


In [21]:
val_dataset = torch.utils.data.ConcatDataset(val_datasets)

## Init model 

In [24]:
root = 'C:/Users/vlvdi/Desktop/EMG/MainScripts/weights/'

CKPT_PATH = root + 'latest_simple_nast.pt'
#CKPT_PATH = artifact_dir / 'epoch=23_val_angle_degree=8.682.ckpt'

params =dict(n_electrodes=8, n_channels_out=20,
                        n_res_blocks=3, n_blocks_per_layer=3,
                        n_filters=128, kernel_size=3,
                        strides=(2, 2, 2),
                        dilation=2, use_angles = False)

model = HVATNet_v3_FineTune_N.HVATNetv3(**params)
model.use_angles = True

model = upload_weights_pl(model, CKPT_PATH, CKPT_PATH[-4:])

#model.load_state_dict(torch.load(CKPT_PATH, map_location=torch.device('cpu')))
model.eval()
print()

# TO SAVE MODEL:
#torch.save(model.state_dict(), 'C:/Users/vlvdi/Desktop/EMG/model_nast_simple.pt')

Number of parameters:  4210788



In [25]:
model.state_dict()['tune_module.spatial_weights']

tensor([[ 0.9349,  0.0073,  0.2019,  0.1432,  0.0932,  0.1879,  0.0873,  0.0327],
        [ 0.0045,  0.9120,  0.0723,  0.1233,  0.1115,  0.1561,  0.1423,  0.1124],
        [ 0.0797, -0.0847,  0.7601, -0.0369,  0.1082,  0.1474,  0.1955,  0.1603],
        [ 0.0771,  0.0754,  0.0791,  0.8439,  0.0322,  0.1315,  0.1551,  0.1556],
        [ 0.0760,  0.1414,  0.2265,  0.0222,  0.8946, -0.0027,  0.0666,  0.1134],
        [ 0.0690,  0.1321,  0.2582,  0.1089, -0.0251,  0.8681,  0.0077,  0.0501],
        [ 0.0108,  0.0942,  0.2647,  0.1532,  0.0895,  0.0313,  0.8688,  0.0154],
        [-0.0199,  0.0586,  0.2243,  0.1930,  0.1227,  0.1052,  0.0147,  0.8787]])

## Apply inference for each moves.  

In [124]:
val_dataset = torch.utils.data.ConcatDataset(val_datasets)

In [125]:
model.__dict__['training']

False

In [126]:
model.use_angles

True

In [127]:
model = model.to('cpu')

In [128]:
FPS = 25
NEW_FPS = 25
DRAW_EVERY = 200 // NEW_FPS

WINDOW_SIZE = 256
STRIDE = 256

corr_list = []
angle_degree_list = []
preds_per_person = []
targets_per_person = []
abs_diff_per_person = []
angle_diff_per_person = []

for val_dataset_sample in val_dataset.datasets:
    
    all_move_targets = []
    all_move_preds = []
    all_myo = []
    
    for move_data in tqdm(val_dataset_sample.exps_data): 
        data_myo, data_vr = move_data['data_myo'], move_data['data_angles']

        preds, targets = make_inference(data_myo, data_vr, model,
                                        window_size=WINDOW_SIZE, 
                                        stride=STRIDE, 
                                        device='cpu', 
                                        return_angles=True)
        preds = get_angles(preds)
        if config.down_sample_target is not None:
            targets = targets[::config.down_sample_target]

        all_move_targets.append(targets)
        all_move_preds.append(preds)
        all_myo.append(data_myo)

    targets = np.concatenate(all_move_targets, axis=0)
    preds = np.concatenate(all_move_preds, axis=0)
    all_myo = np.concatenate(all_myo, axis=0)

    preds_per_person.append(all_move_preds)
    targets_per_person.append(all_move_targets)
    diff = np.abs(targets - preds)
    abs_diff_per_person.append(diff)
    mean_diff_angle_per_joint = np.rad2deg(np.mean(diff, axis = 1))
    angle_diff_per_person.append(mean_diff_angle_per_joint)
    
    # our metrics:
    dif = np.mean(diff)
    angle_degree = np.round(np.rad2deg(dif), 3)

    corr_coef = torch.mean(F.cosine_similarity(torch.from_numpy(targets), torch.from_numpy(preds), dim=0, eps=1e-8))
    corr_coef = np.round(corr_coef.item(), 3)
    
    corr_list.append(corr_coef)
    angle_degree_list.append(angle_degree)
    

# angle_degree = np.round(F.preds, targets), 3)

print(f"Size of targets {targets.shape} || Size of preds {preds.shape} ") 
print(f"Window size {WINDOW_SIZE} || Stride {STRIDE}")
print('Angle degree: ', np.mean(angle_degree_list), np.std(angle_degree_list))
print('Cosine similarity: ', np.mean(corr_list), np.std(corr_list))

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:14<00:00, 14.45s/it]

Size of targets (2752, 20) || Size of preds (2752, 20) 
Window size 256 || Stride 256
Angle degree:  3.702 0.0
Cosine similarity:  0.893 0.0





In [129]:
pred_quats = get_quats(preds)
tar_quats = get_quats(targets)

In [130]:
NEW_FPS = 25
DRAW_EVERY = 4
visualize_and_save_anim_gifs(data=pred_quats[::DRAW_EVERY],
                        path=Path('C:/Users/vlvdi/Desktop/EMG/test_predict_offline.gif'), 
                        fps=NEW_FPS)

Video test_predict_offline completed


In [None]:
visualize_and_save_anim_gifs(data=tar_quats[::DRAW_EVERY],
                        path=Path('C:/Users/vlvdi/Desktop/EMG/test_predict_offline.gif'), 
                        fps=NEW_FPS)