In [None]:
import sys
sys.path.append("D:\\WETrak")

import os
import configparser
import numpy as np
import torch
from torch.utils.data import DataLoader

import import_ipynb
import tracker_network.utils.model_utils as utils
from tracker_network.graph.loss import mse_loss
from tracker_network.graph.network import Autoencoder
from tracker_network.data_sources.finger_data import finger_data_15,finger_data_234

In [None]:
def train_epoch(model,train_loader,epoch,optimizer,lr,finger_type):
    """Run one optimisation pass over ``train_loader`` and return the mean error.

    Two batch layouts are supported, selected by ``len(finger_type)``:

    * one finger (e.g. Thumb / Pinky): each batch is ``(emg, gt)`` and the
      model output is compared with ``gt`` as-is;
    * several fingers (e.g. Index / Middle / Ring): each batch is
      ``(emg, gt, min_val, max_val)``. Both ``gt`` and the model output are
      mapped through ``x * (max_val - min_val) + min_val`` before the loss,
      so loss and error are computed on that rescaled target.

    Args:
        model: network to train; put into train mode here. Batches are moved
            to the device of its first parameter.
        train_loader: iterable of batches in one of the layouts above.
        epoch: epoch index. Not used by the computation; kept so existing
            callers keep working.
        optimizer: optimizer that steps ``model``'s parameters.
        lr: learning rate. Not used (the optimizer already holds it); kept so
            existing callers keep working.
        finger_type: sequence of finger names; only its length is inspected.

    Returns:
        The running average (``AverageMeter.avg``) of
        ``utils.mean_absolute_error`` over all batches of this epoch.

    Raises:
        ValueError: if ``finger_type`` is empty, since no batch layout applies.
    """
    if len(finger_type) == 0:
        raise ValueError("finger_type must name at least one finger")
    rescale = len(finger_type) > 1

    # Follow the model's own device instead of relying on a global `device`
    # assigned elsewhere in the notebook (hidden kernel state).
    run_device = next(model.parameters()).device

    errors = utils.AverageMeter()
    model.train()

    for batch in train_loader:
        if rescale:
            emg, gt, min_val, max_val = batch
            # NOTE(review): assumes min_val / max_val broadcast against both gt
            # and the model output — confirm shapes with finger_data_234.
            scale = max_val - min_val
            gt = utils.data_gpu(gt * scale + min_val, run_device)
        else:
            emg, gt = batch
            gt = utils.data_gpu(gt, run_device)
        emg = utils.data_gpu(emg, run_device)

        out = model(emg)
        if rescale:
            out = out * scale.to(run_device) + min_val.to(run_device)

        loss = mse_loss(out, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metric only: detach so no autograd graph is built for it.
        error = utils.mean_absolute_error(gt, out.detach())
        errors.update(error.item())

    return errors.avg

In [None]:
def main(user_name,finger_type_list,config):
    """Train the tracker network for one user and one finger group.

    Args:
        user_name: name of the user; used as a sub-directory of ``output_path``
            and passed to the dataset class.
        finger_type_list: one of ``["Thumb"]``, ``["Pinky"]`` or
            ``["Index", "Middle", "Ring"]``. Selects the ``is_norm`` flag for
            ``Autoencoder``, the output sub-folder and the dataset class.
        config: a ``configparser.ConfigParser`` providing ``[main]``
            ``data_path`` / ``output_path`` and ``[Tracker]`` ``epochs`` /
            ``batch_size`` / ``learning_rate``.

    Raises:
        ValueError: if ``finger_type_list`` is not one of the supported groups.

    Side effects:
        Creates ``<output_path>/<user_name>/<group>/checkpoints``, sets the
        module-level ``device`` global, and saves a "newest" checkpoint every
        10 epochs and after the final epoch. The checkpoint's ``'error'`` entry
        is the lowest per-epoch error returned by ``train_epoch`` so far
        (``inf`` if it returned none).
    """
    data_path = config['main']['data_path']
    output_path = config['main']['output_path']

    EPOCHS = int(config['Tracker']['epochs'])
    BATCH_SIZE = int(config['Tracker']['batch_size'])
    lr = float(config['Tracker']['learning_rate'])

    # Finger group -> (is_norm flag for Autoencoder, output sub-folder name).
    if finger_type_list == ["Thumb"]:
        is_norm, file_name = False, "Thumb"
    elif finger_type_list == ["Pinky"]:
        is_norm, file_name = False, "Pinky"
    elif finger_type_list == ['Index','Middle','Ring']:
        is_norm, file_name = True, "Others"
    else:
        # Without this branch is_norm / file_name stay unbound and the call
        # fails further down with an UnboundLocalError that hides the cause.
        raise ValueError(
            "finger_type_list must be ['Thumb'], ['Pinky'] or "
            "['Index', 'Middle', 'Ring'], got {!r}".format(finger_type_list))

    # -------------------------------------Save Dir initialization----------------------------------------------------- #
    runs_dir = os.path.join(output_path, user_name, file_name)
    save_dir = os.path.join(runs_dir, 'checkpoints')
    # makedirs also creates runs_dir; exist_ok keeps re-runs harmless.
    os.makedirs(save_dir, exist_ok=True)
    print("runs_dir",runs_dir)

    # -------------------------------------Device---------------------------------------------------------------------- #
    # Kept as a module-level global because other code in the notebook may
    # read `device`.
    global device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # -------------------------------------Model and Optim Initialize-------------------------------------------------- #
    # .to(device) rather than .cuda(), so the CPU fallback above actually works.
    model = Autoencoder(finger_type_list,is_norm).to(device)
    # weight_decay is an L2 penalty applied to every group. The attention
    # block uses a fixed lr of 0.001; every other group uses `lr` from config.
    optimizer = torch.optim.Adam([
                 {'params': model.encoder.parameters()},
                 {'params': model.max_pool.parameters()},
                 {'params': model.resnet.parameters()},
                 {'params': model.decoder.parameters()},
                 {'params': model.attention.parameters(), 'lr': 0.001}
                ], lr = lr, weight_decay=0.0005)

    # -------------------------------------Data Initialize------------------------------------------------------------- #
    # Single-finger groups and the multi-finger group use different datasets.
    dataset_cls = finger_data_234 if len(finger_type_list) > 1 else finger_data_15
    train_set = dataset_cls(data_path, user_name, "train", finger_type_list)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    # -------------------------------------Training-------------------------------------------------------------------- #
    best_avg_error = float('inf')
    for epoch in range(EPOCHS):
        epoch_error = train_epoch(model,train_loader,epoch,optimizer,lr,finger_type_list)
        # Tolerate a train_epoch that reports nothing; otherwise keep the
        # lowest average error seen so far.
        if epoch_error is not None and epoch_error < best_avg_error:
            best_avg_error = epoch_error

        # Save every 10 epochs and always after the last one, so the final
        # weights are not lost when EPOCHS - 1 is not a multiple of 10.
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            utils.save_checkpoint({'epoch': epoch, 'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),'error':best_avg_error},save_dir,flag="newest")

In [None]:
# Train the Thumb model for user1. The configuration is loaded once here and
# the `config` object is reused by the training cells that follow.
user_name = "user1"
finger_type_list = ['Thumb'] 

# Read configuration
CONFIG_PATH = 'D:\\WETrak\\config.ini'
config = configparser.ConfigParser()
# ConfigParser.read silently skips files it cannot open and returns the list
# of files it did read; fail here instead of with a KeyError inside main().
if not config.read(CONFIG_PATH):
    raise FileNotFoundError("could not read config file {!r}".format(CONFIG_PATH))

main(user_name,finger_type_list,config)

In [None]:
# Pinky model for the same user, using the `config` read in the Thumb cell.
user_name, finger_type_list = "user1", ['Pinky']
main(user_name=user_name, finger_type_list=finger_type_list, config=config)

In [None]:
# Shared model for Index, Middle and Ring, using the `config` read in the Thumb cell.
user_name, finger_type_list = "user1", ['Index', 'Middle', 'Ring']
main(user_name=user_name, finger_type_list=finger_type_list, config=config)