In [1]:
import sys
# Make the WETrak project root importable (hard-coded, machine-specific path);
# must run before the project imports below.
sys.path.append("D:\\WETrak")

import os
import configparser
import torch
from torch.utils.data import DataLoader

# Imported for its side effect: it lets the tracker_network.* modules below,
# which are .ipynb notebooks (this cell's output logs each one it loads), be
# imported like regular modules. Must stay before those imports.
import import_ipynb
import tracker_network.utils.model_utils as utils
from tracker_network.graph.network import Autoencoder
from tracker_network.data_sources.finger_data import finger_data_15,finger_data_234

# NOTE(review): presumably suppresses the Intel OpenMP "duplicate runtime
# initialized" abort (two libraries each loading their own OpenMP runtime).
# This is an unsafe workaround -- confirm it is still required.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

importing Jupyter notebook from D:\WETrak\tracker_network\utils\model_utils.ipynb
importing Jupyter notebook from D:\WETrak\tracker_network\graph\network.ipynb
importing Jupyter notebook from D:\WETrak\tracker_network\data_sources\finger_data.ipynb
importing Jupyter notebook from D:\WETrak\tracker_network\data_sources\data_utils.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\regression\regression_inference.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\regression\regression_model_utils.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\regression\regression_network.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\regression\regression_data.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\regression\regression_data_utils.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\classification\classification_cascade_inference.ipynb
importing Jupyter notebook from D:\WETrak\branch_adapter\classification\classific

In [2]:
def inference_15(model,test_loader,finger_type_list,device=None):
    """Evaluate a single-finger tracker (Thumb / Pinky) and print its median MAE.

    The model is run once per batch. Each finger is then scored on its own
    4-column slice of the last axis: finger ``k`` owns columns ``4*k:4*(k+1)``
    of both the prediction and the ground truth.

    Args:
        model: network mapping an EMG batch to joint angles; switched to eval mode.
        test_loader: iterable of ``(emg, gt)`` batches.
        finger_type_list: finger names, in the order their column groups appear.
        device: device batches are moved to. Defaults to the device of the
            model's parameters, so no module-level ``device`` global is needed.

    Prints one ``"Tracking error of <finger> is: <median MAE> degree"`` line per
    finger, in ``finger_type_list`` order.
    """
    with torch.no_grad():
        model.eval()
        if not finger_type_list:
            return

        if device is None:
            device = next(model.parameters()).device

        # One meter per finger. The model output does not depend on which finger
        # is scored, so a single pass over the loader serves all of them.
        meters = [utils.AverageMeter() for _ in finger_type_list]

        for emg, gt in test_loader:
            emg = utils.data_gpu(emg, device) # e.g. [1,1,1000,2]
            gt = utils.data_gpu(gt, device) # e.g. [1,1,1000,4] for one finger

            out = model(emg) # same layout as gt
            for idx, meter in enumerate(meters):
                cols = slice(idx*4, (idx+1)*4)
                mae = utils.mean_absolute_error(gt[..., cols], out[..., cols])
                meter.update(mae.item())

    for finger_type, meter in zip(finger_type_list, meters):
        print("Tracking error of {} is: {} degree".format(finger_type,meter.median()))

In [3]:
def inference_234(model,test_loader,finger_type_list,device=None):
    """Evaluate the Index/Middle/Ring tracker and print its median MAE per finger.

    The model is run once per batch. Its output is mapped back with
    ``out * (max_val - min_val) + min_val``, using the bounds yielded alongside
    each batch. Each finger is then scored on its own 4-column slice of the last
    axis: finger ``k`` owns columns ``4*k:4*(k+1)``.

    Args:
        model: network mapping an EMG batch to (normalized) joint angles;
            switched to eval mode.
        test_loader: iterable of ``(emg, gt, min_val, max_val)`` batches.
        finger_type_list: finger names, in the order their column groups appear.
        device: device batches are moved to. Defaults to the device of the
            model's parameters, so no module-level ``device`` global is needed.

    Prints one ``"Tracking error of <finger> is: <median MAE> degree"`` line per
    finger, in ``finger_type_list`` order.
    """
    with torch.no_grad():
        model.eval()
        if not finger_type_list:
            return

        if device is None:
            device = next(model.parameters()).device

        # One meter per finger. The (re-normalized) model output does not depend
        # on which finger is scored, so a single pass over the loader serves all.
        meters = [utils.AverageMeter() for _ in finger_type_list]

        for emg, gt, min_val, max_val in test_loader:
            emg = utils.data_gpu(emg, device)
            gt = utils.data_gpu(gt, device)

            out = model(emg)
            # Undo min-max scaling with this batch's bounds (computed on CPU,
            # where the loader delivers min_val / max_val), then move back.
            out = out.cpu() * (max_val-min_val)+min_val
            out = utils.data_gpu(out, device)

            for idx, meter in enumerate(meters):
                cols = slice(idx*4, (idx+1)*4)
                mae = utils.mean_absolute_error(gt[..., cols], out[..., cols])
                meter.update(mae.item())

    for finger_type, meter in zip(finger_type_list, meters):
        print("Tracking error of {} is: {} degree".format(finger_type,meter.median()))

In [4]:
def main(user_name,finger_type_list,config,classify_num=None):    
    data_path = config['main']['data_path']
    output_path = config['main']['output_path']
    
    if finger_type_list == ["Thumb"]:
        is_norm = False
        file_name = "Thumb"
    elif finger_type_list == ["Pinky"]:
        is_norm = False
        file_name = "Pinky"
    elif finger_type_list == ['Index','Middle','Ring']:
        is_norm = True
        inference_mode = True
        file_name = "Others"

    # -------------------------------------Save Dir initialization----------------------------------------------------- #
    runs_dir = os.path.join(output_path, user_name, file_name)
    checkpoint_dir = os.path.join(runs_dir, 'checkpoints')
    print("load {} checkpoingts from {}".format(file_name,os.path.abspath(checkpoint_dir)))

    # -------------------------------------Gpu Device && Tensorboard--------------------------------------------------- #
    global device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # -------------------------------------Model and Optim Initialize-------------------------------------------------- #
    model = Autoencoder(finger_type_list,is_norm).cuda()

    # -------------------------------------Data Initialize--------------------------------------------------- #
    if len(finger_type_list) == 1:
        test_set = finger_data_15(data_path, user_name, "test", finger_type_list)
    else:
        test_set = finger_data_234(data_path, user_name, "test", finger_type_list,classify_num,inference_mode,output_path)
    
    test_loader = DataLoader(test_set,batch_size = 1, shuffle = False)
    
    # -------------------------------------Inference--------------------------------------------------- #
    checkpoint = utils.load_checkpoint(os.path.join(checkpoint_dir),"median")
    model.load_state_dict(checkpoint['model_state_dict'])

    if len(finger_type_list) == 1:
        inference_15(model,test_loader,finger_type_list)
    else:
        inference_234(model,test_loader,finger_type_list)

In [5]:
# Read configuration
CONFIG_PATH = 'D:\\WETrak\\config.ini'
config = configparser.ConfigParser()
# ConfigParser.read silently skips files it cannot open and returns the list it
# did parse; fail here instead of with a KeyError('main') deep inside main().
if not config.read(CONFIG_PATH):
    raise FileNotFoundError("Could not read configuration file: {}".format(CONFIG_PATH))

user_name = "user1"
finger_type_list = ['Thumb']
main(user_name,finger_type_list,config)

load Thumb checkpoingts from D:\WETrak\Outputs\user1\Thumb\checkpoints
Tracking error of Thumb is: 3.922192692756653 degree


In [6]:
# Evaluate the Pinky tracker for user1.
finger_type_list = ['Pinky']
user_name = "user1"
main(user_name=user_name, finger_type_list=finger_type_list, config=config)

load Pinky checkpoingts from D:\WETrak\Outputs\user1\Pinky\checkpoints
Tracking error of Pinky is: 3.9536795616149902 degree


In [7]:
# Evaluate the cascaded Index/Middle/Ring tracker for user1. classify_num=15
# matches the regression/15 and classify/15 checkpoint directories loaded in
# this cell's output.
finger_type_list = ['Index','Middle','Ring']
user_name, classify_num = "user1", 15
main(user_name=user_name, finger_type_list=finger_type_list, config=config, classify_num=classify_num)

load Others checkpoingts from D:\WETrak\Outputs\user1\Others\checkpoints
load regression checkpoingts from D:\WETrak\Outputs\user1\regression\15\checkpoints
load classification checkpoingts from D:\WETrak\Outputs\user1\classify\15\checkpoints
load classification checkpoingts from D:\WETrak\Outputs\user1\classify\5\checkpoints
load classification checkpoingts from D:\WETrak\Outputs\user1\classify\3\checkpoints
Tracking error of Index is: 5.552586078643799 degree
Tracking error of Middle is: 5.982937574386597 degree
Tracking error of Ring is: 5.38001012802124 degree
