In [None]:
import sys
sys.path.append("D:\\WETrak")

import os
import configparser
import numpy as np
import torch
from torch.utils.data import DataLoader

import import_ipynb
import branch_adapter.classification.classification_model_utils as utils
from branch_adapter.classification.classification_loss import DistanceLoss
from branch_adapter.classification.classification_network import CNN_classify
from branch_adapter.classification.classification_data import classification_data

In [None]:
def train_epoch(model,train_loader,epoch,optimizer,lr,classify_num):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: Network whose forward pass returns ``(features, out)``.
        train_loader: Iterable yielding ``(emg, gt)`` batches.
        epoch: Current epoch index (unused here; kept for interface compatibility).
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        lr: Learning rate (unused here; the optimizer already holds it).
        classify_num: Number of classes, forwarded to ``DistanceLoss``.

    Returns:
        ``(losses, errors)``: the two ``utils.AverageMeter`` instances, updated
        with every batch's loss value and ``utils.multi_acc`` value. Earlier
        versions returned None and discarded these statistics.

    Note:
        Batches are moved with ``utils.data_gpu(..., device)`` where ``device``
        is the module-level global assigned in ``main``; calling this function
        before ``main`` has set it raises ``NameError``.
    """
    losses, errors = utils.AverageMeter(), utils.AverageMeter()

    # Build the loss once per epoch rather than once per batch: it does not
    # depend on the batch, and re-creating it each step would also reset any
    # internal state it may hold (TODO confirm whether DistanceLoss has any).
    loss_cl = DistanceLoss(num_classes=classify_num)

    model.train()

    for emg, gt in train_loader:
        # unsqueeze(1) inserts a new axis at position 1 before the forward pass
        # (presumably the channel axis expected by CNN_classify — confirm).
        emg = utils.data_gpu(emg.unsqueeze(1), device)
        gt = utils.data_gpu(gt, device)

        features, out = model(emg)
        loss = loss_cl(features, out, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item())

        # NOTE(review): stored as "error" but computed by multi_acc — confirm
        # whether this is an accuracy or an error rate.
        error = utils.multi_acc(gt, out)
        errors.update(error.item())

    return losses, errors

In [None]:
def main(user_name,classify_num,config):
    """Train a ``CNN_classify`` model for one user and one class count.

    Reads paths and hyper-parameters from ``config``, creates the output
    directories, builds the model, optimizer and data loader, then runs the
    training epochs, saving a "newest" checkpoint every other epoch and after
    the final epoch.

    Args:
        user_name: User identifier; selects the training data and names the
            output sub-directory.
        classify_num: Number of classes; forwarded to the model, the dataset
            and ``train_epoch``, and used in the output path.
        config: Parsed configuration with a ``[main]`` section (``data_path``,
            ``output_path``) and a ``[Classification]`` section (``epochs``,
            ``batch_size``, ``learning_rate``).

    Side effects:
        Creates ``<output_path>/<user_name>/classify/<classify_num>/checkpoints``,
        assigns the module-level ``device`` global, and writes checkpoints via
        ``utils.save_checkpoint``.
    """
    data_path = config['main']['data_path']
    output_path = config['main']['output_path']

    EPOCHS = int(config['Classification']['epochs'])
    BATCH_SIZE = int(config['Classification']['batch_size'])
    lr = float(config['Classification']['learning_rate'])

    finger_type_list = ['Index','Middle','Ring']
    # -------------------------------------Save Dir initialization----------------------------------------------------- #
    runs_dir = os.path.join(output_path, user_name, "classify", str(classify_num))
    save_dir = os.path.join(runs_dir, 'checkpoints')
    # makedirs also creates runs_dir (an intermediate directory); exist_ok
    # replaces the racy isdir()-then-create check.
    os.makedirs(save_dir, exist_ok=True)
    print("runs_dir", os.path.abspath(runs_dir))

    # -------------------------------------Gpu Device------------------------------------------------------------------ #
    # train_epoch reads this module-level name, so it must remain a global.
    global device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # -------------------------------------Model and Optim Initialize-------------------------------------------------- #
    # Move the model to the selected device instead of calling .cuda()
    # unconditionally: this keeps the model on the same device as the batches
    # moved in train_epoch, and avoids a crash on CPU-only machines.
    model = CNN_classify(classify_num=classify_num, finger_type=finger_type_list).to(device)
    # weight_decay: L2 penalty
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)

    # -------------------------------------Data Initialize--------------------------------------------------- #
    train_set = classification_data(data_path, user_name, "train", finger_type_list, classify_num)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    # -------------------------------------Training--------------------------------------------------- #
    # NOTE(review): the checkpoint's 'error' entry used to reference an
    # undefined name, raising NameError on the very first save. No validation
    # error is computed in this function, so None is stored until one is.
    best_avg_error = None
    for epoch in range(EPOCHS):
        train_epoch(model, train_loader, epoch, optimizer, lr, classify_num)

        # Save every other epoch, and always after the last one so the final
        # weights are kept even when EPOCHS is even.
        if epoch % 2 == 0 or epoch == EPOCHS - 1:
            # save newest model:
            utils.save_checkpoint({'epoch': epoch, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'error': best_avg_error}, save_dir, flag="newest")

In [None]:
user_name = "user1"
classify_num_list = [15,5,3]

# Read configuration.
# NOTE(review): absolute, machine-specific Windows path; consider making it configurable.
CONFIG_PATH = 'D:\\WETrak\\config.ini'
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. Fail fast here instead of with an opaque
# KeyError('main') later inside main().
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Configuration file not found or unreadable: {CONFIG_PATH}")

# Train one model per class count.
for classify_num in classify_num_list:
    main(user_name,classify_num,config)