In [220]:
import sys
sys.path.insert(0,'C:/Users/vlvdi/Desktop/EMG/alvi')

import os
from pathlib import Path
from functools import partial
import configparser

from typing import TypeVar
import time
# Extern libs (pip install)
import pylsl
import torch
import numpy as np
from einops import rearrange
from pylsl import StreamInfo, StreamOutlet
import pytorch_lightning as pl

# local modules
from utils.quats_and_angles import get_quats, get_angles
from utils.hand_visualize import Hand, save_animation_mp4, visualize_and_save_anim, merge_two_videos, visualize_and_save_anim_gifs #, merge_two_videos_vertically

from models import HVATNet_v2, HVATNet_v3, HVATNet_v3_FineTune, HVATNet_v3_FineTune_N

In [95]:
# parsing init params
config = configparser.ConfigParser()
conf_path = 'C:/Users/vlvdi/Desktop/EMG/alvi/conf.ini'
config.read(conf_path)
PATH_TO_SALUT_ML_DIR = Path(config['global']['salut_ml_dir'])
sys.path.insert(1, str(PATH_TO_SALUT_ML_DIR))

CKPT_PATH = (PATH_TO_SALUT_ML_DIR 
             / Path(config['inference']['init_weights_path']))
REALTIME_WEIGHTS_FOLDER = (Path(config['realtime_training']['work_path']) 
                        / Path(config['realtime_training']['weights_folder']))   
DEVICE = config['inference']['device']
# DEVICE = 'cuda'
REORDER_ELECTORDES = config['realtime_training']['reorder_electrodes']
MODEL_TYPE = config['global']['model_type']

exec_time_list = []

PathLike = TypeVar("PathLike", str, Path)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cpu


In [216]:
CKPT_PATH

WindowsPath('C:/Users/vlvdi/Desktop/EMG/alvi/weights/latest_simple_nast.pt')

In [96]:
def normalize_quats(v):
    """
    [Time, n_bones, 4]
    """
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    return v / norm


def smooth_ema(data, coef, prev=None):
    """
    This should not be applied on quanternions
    [Time, ...]
    coef in range [0, 1)
    coef = 0 -> no smooth
    """
    
    if prev is None:
        prev = data[0]
        for i in range(1, len(data)): 
            data[i] = prev * coef + data[i] * (1 - coef)
            prev = data[i]
    else: 
        for i in range(0, len(data)): 
            data[i] = prev * coef + data[i] * (1 - coef)
            prev = data[i]        
    return data


# method for quats model in same fps as EMG, NOT used in inference scripts
def myo_to_vr(emg, device=DEVICE):
    """
    Preproc emg window and predict vr points.
    Last points only.

    :return:
        Return list of  Last vr points with step parameters [step, 16*4]
    """
    # global targets_load
    try:
        # emg = (np.array(emg) + np.nanmean(np.array(emg))) / np.nanmax(np.array(emg))
        print('I AM WORKING!!!!!!!')
        # CHANGES
        # emg = (np.array(emg) + 128) / 255.
    except Exception as err:
        raise err
    
    vr_output = model.inference(emg, device=device, first_bone_is_constant=True)
    vr_output = np.reshape(vr_output, (vr_output.shape[0], 64))
    
    # vr_last_points = vr_output[-STRIDE:]  # get last points but with step border
    vr_last_points = vr_output[::MODEL_DS_RATE]  # downsample
    vr_last_points = normalize_quats(vr_last_points)

    return list(vr_last_points)

In [113]:
# for angle model use inference_v2
def myo_to_angles(emg, min=0, max=0.0001, device=DEVICE):
    """
    Preproc emg window and predict vr points.
    Last points only.

    :return:
        Return list of  Last vr points with step parameters [step, 16*4]
    """
    # global targets_load

    try:
        emg_new = emg # (np.array(emg) - min) / (max - min)
        print(np.array(emg).shape)
        # CHANGES
        # emg = (np.array(emg) + 128) / 255
    except Exception as err:
        emg_new = [[1 for j in range(8)] for i in range(256)]
        emg_new = np.array(emg_new)
        #raise err

    #emg_new = emg_new * 2 - 1

    # 0 <= emg <= 2

    if REORDER_ELECTORDES:
        emg_new = emg_new[:, (6, 5, 4, 3, 2, 1, 0, 7)]
        
    if MODEL_TYPE == 'hvatnet_v3':
        # for hvatnetmodel!!
        vr_output = model.inference_v2(emg_new, device=device, first_bone_is_constant=True)
    else: 
        raise ValueError('Wrong model_type !')
    return vr_output


## need to import HVATNet_v2 for this to work
def load_HVATNetv2(CKPT_PATH = 'HVATNet_v2_matvey_fixed_quats.pt'):
    from models import HVATNet_v2
     # init model and load weights
    hvatnet_v2_params =dict(n_electrodes=8, n_channels_out=64,
                        n_res_blocks=3, n_blocks_per_layer=2,
                        n_filters=128, kernel_size=3,
                        strides=(2, 2, 2, 4),
                        dilation=2)
    
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = HVATNet_v2.HVATNetv2(**hvatnet_v2_params)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=torch.device(DEVICE)))
    model.to(DEVICE)
    model.eval()
    return model

def upload_weights_pl(model, path, pt):
    class Lit_Wrapper(pl.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            x = self.model(x)
            return x
    
    if pt == 'ckpt':
        ckpt = torch.load(path, map_location=torch.device('cpu'))
        model_pl = Lit_Wrapper(model)
        model_pl.load_state_dict(ckpt['state_dict'])
    else:
        try:
            ckpt = torch.load(path, map_location=torch.device('cpu'))
            model_pl = Lit_Wrapper(model)
            model_pl.load_state_dict(ckpt)
        except:
            ckpt = torch.load(path, map_location=torch.device('cpu'))
            model.load_state_dict(ckpt)
            return model
    
    return model_pl.model

def load_HVATNetv3(CKPT_PATH):
    from models import HVATNet_v3_FineTune

    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    hvatnet_v3_params =dict(n_electrodes=8, n_channels_out=20,
                        n_res_blocks=3, n_blocks_per_layer=3,
                        n_filters=128, kernel_size=3,
                        strides=(2, 2, 2), dilation=2,
                        use_angles=True)
    
    model = HVATNet_v3_FineTune.HVATNetv3(**hvatnet_v3_params)
    model = upload_weights_pl(model, CKPT_PATH, CKPT_PATH[-4:])
    
    model.eval()
    return model


# buffer of EMG inlet shold be >= WINDOW_SIZE (maybe in seconds)
def pull_predict(emg_inlet,
                 quat_buf_outlet,
                    window_size,
                    stride,
                        emg_buffer,
                            last_angle=None,
                            smooth_coef = None, min_emg=0, max_emg=0.0001):
    counter = 0 
    start_time = pylsl.local_clock()
    
    # dropping stride of old EMG
    emg_buffer = emg_buffer[stride:]
    # pulling stride new emg
    emgpull_start_time = pylsl.local_clock()
    while counter < stride:
        emg, timestamp = emg_inlet.pull_sample()
        emg_buffer.append(emg)
        counter += 1
    emgpull_end_time = pylsl.local_clock()

    # angle_buffer = myo_to_vr(emg_buffer)
    modelpred_start_time = pylsl.local_clock()
    angle_buffer = myo_to_angles(emg_buffer, min_emg, max_emg)

    modelpred_end_time = pylsl.local_clock()

    angle_buffer = angle_buffer[-(stride//MODEL_DS_RATE):]

    # smoothing only to angles
    smooth_start_time = pylsl.local_clock()
    if not smooth_coef is None:
        angle_buffer = smooth_ema(data=angle_buffer,
                                    prev=last_angle,
                                    coef=smooth_coef)
    smooth_end_time = pylsl.local_clock()    
    # angles to quats conversion
    anglequat_start_time = pylsl.local_clock()
    quat_to_push = get_quats(angle_buffer)
    quat_to_push = rearrange(quat_to_push, 't b q -> t (b q)')
    anglequat_end_time = pylsl.local_clock()
    # this be buffered again and resend with regular sample rate
    for sample in quat_to_push:
        quat_buf_outlet.push_sample(sample)
        
    # counting exec time and wait for new stride of EMG samples to come 
    end_time = pylsl.local_clock()
    exec_time = end_time  - start_time   
    wait_time = stride/VR_OUTPUT_FPS - exec_time # waiting for stride of EMG to come
    exec_time_list.append(float(exec_time))
    # logging execution time
    time_log_pull.append(emgpull_end_time - emgpull_start_time) 
    time_log_modelpred.append(modelpred_end_time - modelpred_start_time) 
    time_log_smooth.append(smooth_end_time - smooth_start_time)
    time_log_exec.append(end_time - start_time)
    print(f'''
    emg pull time = {time_log_pull[-1]}
    model pred time = {time_log_modelpred[-1]}
    smoothing time = {time_log_smooth[-1]}
    EXEC time = {time_log_exec[-1]}
    ''')
    # time.sleep(0 if wait_time < 0 else wait_time)
    # return emg_buffer and last angle for smoothing
    return emg_buffer, angle_buffer[-1]


def event_loop(model, 
               emg_inlet,
               quat_buf_outlet,
               window_size,
               stride,
               smooth_coef=None, 
               update_weights_time = 15):

    emg_buffer = [[0 for j in range(8)] for i in range(window_size)]
    prev_angle = None
    print(len(emg_buffer))
    counter = 0
    global_counter = 0
    last_weights_path = None

    min_emg = 0
    max_emg = 0.00001

    while True: 
        emg_buffer, prev_angle = pull_predict(emg_inlet=emg_inlet,
                                              quat_buf_outlet=quat_buf_outlet,
                                                window_size=window_size,
                                                stride=stride,
                                                emg_buffer=emg_buffer,
                                                last_angle=prev_angle,
                                                smooth_coef=smooth_coef, min_emg=min_emg, max_emg=max_emg)

        try:
            if global_counter < 200:
                min_new = np.nanmin(emg_buffer)
                max_new = np.nanmax(emg_buffer)

                if min_new < min_emg:
                    min_emg = min_new
                if max_new > max_emg:
                    max_emg = max_new
        except:
            None

        print('Current time', stride / VR_OUTPUT_FPS * counter)
        if stride / VR_OUTPUT_FPS * counter > update_weights_time:
            counter = 0 
            start_update_time = time.time()
            model, last_weights_path, updated = reload_weights(model, REALTIME_WEIGHTS_FOLDER, last_weights_path)
            update_time = time.time() - start_update_time
            print(f'weights {updated=} {last_weights_path.name if last_weights_path is not None else last_weights_path}')
            if updated:
                print(f'WEIGHTS UPDATE TIME = {update_time}')
        counter += 1
        global_counter += 1


def reload_weights(model, folder, last_weights_path):
    paths = sorted(folder.iterdir())
    if len(paths) > 2:
        new_path = paths[-1]
    else:
        new_path = last_weights_path
    # print(sorted(folder.iterdir()))
    print(f'{new_path=}')
    print(f'{last_weights_path=}')
    updated = (last_weights_path != new_path) 
    if updated:
        model.load_state_dict(torch.load(new_path, map_location=torch.device(DEVICE)))
        model.to(DEVICE)
        model.eval()
    return model, new_path, updated     

## Runtime

In [176]:
WINDOW_SIZE

256

In [160]:
MYO_INPUT_FPS = int(config['global']['myo_input_fps'])
VR_OUTPUT_FPS = int(config['global']['vr_output_fps'])  # fps which we want to have in vr.
MODEL_OUTPUT_FPS = int(config['global']['model_output_fps'])
WINDOW_SIZE = int(config['inference']['window_size'])
STRIDE = 256 #int(config['inference']['stride'])
SMOOTH_COEF = float(config['inference']['smooth_coef'])
UPDATE_WEIGHTS_TIME = int(config['inference']['update_weights_time'])

TOTAL_DS_RATE = MYO_INPUT_FPS // VR_OUTPUT_FPS
MODEL_DS_RATE = MYO_INPUT_FPS // MODEL_OUTPUT_FPS

VR_BUFFER = []
EMG_BUFFER = []
counter_emg = 0

In [161]:
str(CKPT_PATH)

'C:\\Users\\vlvdi\\Desktop\\EMG\\alvi\\weights\\latest_simple_nast.pt'

In [162]:
model = load_HVATNetv3(str(CKPT_PATH))

Number of parameters:  4210788


In [163]:
# buffer of EMG inlet shold be >= WINDOW_SIZE (maybe in seconds)
def pull_predict_offline(sample, window_size,
                    stride, emg_buffer, last_angle=None, smooth_coef=None, min_emg=0, max_emg=0.0001):
    counter = 0
    MODEL_DS_RATE = 10
    
    # dropping stride of old EMG
    emg_buffer = emg_buffer[stride:]
    # pulling stride new emg
    
    # buffer of EMG inlet shold be >= WINDOW_SIZE (maybe in seconds)
    emg_buffer = sample
    counter += 1

    # angle_buffer = myo_to_vr(emg_buffer)
    angle_buffer = myo_to_angles(emg_buffer, min_emg, max_emg)

    angle_buffer = angle_buffer[-(stride//MODEL_DS_RATE):]

    # smoothing only to angles
    if not smooth_coef is None:
        angle_buffer = smooth_ema(data=angle_buffer,
                                    prev=last_angle,
                                    coef=smooth_coef)
        
    # angles to quats conversion
    quat_to_push = get_quats(angle_buffer)
    quat_to_push = rearrange(quat_to_push, 't b q -> t (b q)')

    return emg_buffer, angle_buffer[-1]

In [164]:
path = f"C:/Users/vlvdi/Desktop/EMG/EMG_TRAINING/Nastya/GeneralTraining/Validation/1_1/0000.npz"
data = np.load(path)
myo = data['data_myo']
angles = data['data_vr']

In [165]:
myo.shape

(22233, 8)

In [166]:
#myo = myo[0:myo.shape[0] - sum(np.isnan(myo[:, 0])), :]

In [188]:
myo.shape

(22233, 8)

In [221]:
data = []

for plus in range(0, myo.shape[0], 256):
    data.append(myo[plus:256+plus, :])

In [222]:
(myo.shape[0] - 256)//256 + 1

86

In [223]:
len(data)

87

In [237]:
emg_buffer = [[0 for j in range(8)] for i in range(WINDOW_SIZE)]
prev_angle = None
print(len(emg_buffer))
counter = 0
global_counter = 0
last_weights_path = None

min_emg = 0
max_emg = 0.00001

angless = []

for sample in range(0, len(data)-1): 
#     emg_buffer, prev_angle = pull_predict_offline(sample=data[sample],
#                                             window_size=WINDOW_SIZE,
#                                             stride=STRIDE,
#                                             emg_buffer=emg_buffer,
#                                             last_angle=prev_angle,
#                                             smooth_coef=SMOOTH_COEF, min_emg=min_emg, max_emg=max_emg)
    pred = model.inference(data[sample], device = 'cpu')
    angless.append(pred)
#     try:
#         if global_counter < 200:
#             min_new = np.nanmin(emg_buffer)
#             max_new = np.nanmax(emg_buffer)

#             if min_new < min_emg:
#                 min_emg = min_new
#             if max_new > max_emg:
#                 max_emg = max_new
#     except:
#         None

    counter += 1
    global_counter += 1

256


In [238]:
angls = np.concatenate(angless, axis=0)

In [241]:
angls = np.array(angls)

In [248]:
angls.shape

(2752, 16, 4)

In [243]:
preds = get_angles(angls)

In [246]:
preds_quat = get_quats(preds)

In [249]:
preds_quat.shape

(2752, 16, 4)

In [251]:
NEW_FPS = 25
DRAW_EVERY = 4

visualize_and_save_anim_gifs(data=angls[::DRAW_EVERY],
                        path=Path('C:/Users/vlvdi/Desktop/EMG/test_online.gif'), 
                        fps=NEW_FPS)

Video test_online completed
