In [None]:
import sys
sys.path.append("D:\\ASGaze")

import os
import warnings
import time
import datetime

import numpy as np
import json
import torch
import torch.nn.functional as f
from torch.utils.tensorboard import SummaryWriter
# Suppress FutureWarning messages raised by the TensorBoard summary writer
warnings.filterwarnings('ignore',category=FutureWarning)

import import_ipynb
from iris_boundary_detector.data_sources.ASGaze_data import ASGaze_data
from iris_boundary_detector.graph.vgg_unet import get_model
from iris_boundary_detector.graph.losses import adjust_learning_rate,MissingLoss,DistanceMapLoss,GeneralizedDiceLoss
from iris_boundary_detector.utils.metrics import multi_acc,ComputeIoU
from iris_boundary_detector.utils.load_model import data_gpu,save_checkpoint,AverageMeter

In [None]:
def train_epoch(train_loader,model,optimizer,epoch,lr,criterion_missing_loss,criterion_distance_map_loss,criterion_dice_loss,alpha):
    """Run one training epoch and write the epoch metrics to TensorBoard.

    Relies on the module-level globals ``device`` and ``writer`` (both are
    assigned by ``main`` before this function is called).

    Args:
        train_loader: iterable yielding 8-tuples ``(data_id, img, gt, one_hot,
            rotated_rect, rect_trans, iris_missing_weights, distMap)``. Only
            ``img``, ``gt``, ``iris_missing_weights`` and ``distMap`` are used.
        model: network to train; it is switched to train mode and called on
            the image batch.
        optimizer: optimizer stepped once per batch. The learning rate of its
            first param group is the value that gets logged.
        epoch: epoch index; selects ``alpha[epoch]`` and is used as the
            TensorBoard step.
        lr: learning rate logged only when the loader yields no batch
            (otherwise it is replaced by the optimizer's current value).
        criterion_missing_loss: callable ``(out, gt)`` whose result is
            multiplied by ``iris_missing_weights`` and then averaged.
        criterion_distance_map_loss: callable ``(out, distMap)``.
        criterion_dice_loss: callable ``(out, gt)``.
        alpha: per-epoch blending weights. ``alpha[epoch]`` scales the dice
            loss and ``1 - alpha[epoch]`` scales the distance-map loss.

    Returns:
        ``losses.avg``: the running average of the per-batch total loss.
    """
    losses,batch_time,accuracy = AverageMeter(),AverageMeter(),AverageMeter()
    compute_iou = ComputeIoU(3)
    length = len(train_loader)
    # Report progress roughly three times per epoch. The max() guard keeps the
    # modulo valid when the loader has fewer than three batches (length//3 == 0).
    log_every = max(1, length//3)
    dice_weight = alpha[epoch]

    model.train()
    end = time.time()

    for i, (_, img, gt, _, _, _, iris_missing_weights, distMap) in enumerate(train_loader):
        optimizer.zero_grad()
        out = model(data_gpu(img, device))
        # Transfer the ground truth once and reuse it for every loss and metric.
        gt_dev = data_gpu(gt, device)

        # Missing loss, re-weighted by the supplied weights and reduced to a scalar.
        missing_loss = criterion_missing_loss(out, gt_dev)
        iris_missing_loss = missing_loss*iris_missing_weights.to(torch.float32).to(device)
        iris_missing_loss = torch.mean(iris_missing_loss).to(torch.float32).to(device)

        distance_map_loss = criterion_distance_map_loss(out, distMap.to(device))
        dice_loss = criterion_dice_loss(out, gt_dev)

        # Overall loss
        loss = iris_missing_loss + (1-dice_weight)*distance_map_loss + dice_weight*dice_loss

        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only, so keep them out of the autograd graph.
        with torch.no_grad():
            probs = f.softmax(out, dim=1)
            acc = multi_acc(probs, gt_dev)
            compute_iou(torch.argmax(probs, dim=1), gt_dev)

        losses.update(loss.item())
        accuracy.update(acc.item())
        batch_time.update(time.time() - end)
        end = time.time()

        lr = optimizer.param_groups[0]['lr']
        if i % log_every == 0:
            print("Train epoch {} ({}/{}): [Loss: {} Learning rate: {} batch_time: {}]".
                format(epoch,i, length, losses.avg, lr, batch_time.avg))

    ious = compute_iou.get_ious()
    miou = compute_iou.get_miou()

    # Train log
    for item in ious:
        writer.add_scalar('Train/IoU:{}'.format(item),ious[item],epoch)
    writer.add_scalar('Train/LR', lr, epoch)
    writer.add_scalar('Train/Loss', losses.avg, epoch)
    writer.add_scalar("Train/BatchTime", batch_time.avg, epoch)
    writer.add_scalar("Train/Pixel Accuracy", accuracy.avg, epoch)
    writer.add_scalar('Train/mIoU',miou,epoch)

    return losses.avg

In [None]:
def main(cfile):
    """Train the model returned by ``get_model("A")`` as described by a JSON config.

    Creates a timestamped run directory (with ``checkpoints`` and ``log``
    sub-directories), assigns the module-level globals ``device`` and
    ``writer`` that ``train_epoch`` relies on, then trains for
    ``config['trainer']['epochs']`` epochs and saves a checkpoint after each.

    Config keys read:
        ``data``; ``trainer.runs_dir``, ``trainer.epochs``;
        ``optimizer.type``, ``optimizer.args.lr``, ``optimizer.args.weight_decay``;
        ``train_data.dir``, ``.batch_size``, ``.shuffle``, ``.num_workers``, ``.drop_last``;
        ``lr_scheduler.args.gamma``, ``lr_scheduler.args.step_size``.

    Args:
        cfile: path to the JSON configuration file.
    """
    global device, writer

    # Load config file (context manager so the handle is closed deterministically)
    with open(cfile) as config_file:
        config = json.load(config_file)
    data_name = config['data']
    # -------------------------------------Save Dir initialization----------------------------------------------------- #
    runs_dir = os.path.join(config['trainer']['runs_dir'], data_name+"-"+datetime.datetime.now().strftime('%m%d-%H%M'))
    save_dir, log_dir = os.path.join(runs_dir,'checkpoints'), os.path.join(runs_dir,'log')
    for t in (runs_dir,save_dir,log_dir):
        os.makedirs(t, exist_ok=True)

    # -------------------------------------Gpu Device && Tensorboard--------------------------------------------------- #
    writer = SummaryWriter(log_dir)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    try:
        # -------------------------------------Model and Optimizer Initialize-------------------------------------------------- #
        # Move the model to the selected device; an unconditional .cuda() would
        # fail on hosts where the CPU fallback above was taken.
        model = get_model("A").to(device)
        # Initial optimizer from config file
        op_type, op = config['optimizer']['type'], config['optimizer']['args']
        optimizer = getattr(torch.optim,op_type)(model.parameters(), lr=op['lr'], weight_decay=op['weight_decay'])

        criterion_missing_loss = MissingLoss(gamma=2)
        criterion_distance_map_loss = DistanceMapLoss()
        criterion_dice_loss = GeneralizedDiceLoss(softmax=True, reduction=True)

        # -------------------------------------Model and Data Initialize--------------------------------------------------- #
        data_flag = 1
        td = config['train_data'] # Prepare train dataset
        train_set = ASGaze_data(datapath=td['dir'], name=data_name,split='train',flag=data_flag)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=td['batch_size'],
                                        shuffle=td['shuffle'], num_workers=td['num_workers'], drop_last=td['drop_last'])

        # Train initialization
        t_c, best_loss, start_epoch = config['trainer'], np.inf, 0
        print("outputs:",os.path.abspath(t_c['runs_dir']))

        # Dice-loss weight per epoch: decays linearly from 1 - 1/epochs and is
        # floored at 0.5; the distance-map loss receives the complement.
        n_epochs = t_c['epochs']
        alpha = np.maximum(1 - np.arange(1, n_epochs+1)/n_epochs, 0.5)

        lr_c = config['lr_scheduler']['args']
        diff = 0.2 # Learning difficulty setting

        for epoch in range(start_epoch, n_epochs):
            print("epoch",epoch)

            # Learning rate adjustment
            lr = adjust_learning_rate(optimizer, op['lr'], lr_c['gamma'], lr_c['step_size'], epoch)

            train_set.set_difficulty(diff)
            train_loss = train_epoch(train_loader,model,optimizer,epoch,lr,criterion_missing_loss,criterion_distance_map_loss,criterion_dice_loss,alpha) # epoch train iteration

            # Track the lowest epoch loss so the checkpoint's 'loss' entry is
            # meaningful (it was previously always left at its initial inf).
            best_loss = min(best_loss, train_loss)

            # Save newest model
            save_checkpoint({'epoch': epoch, 'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),'loss':best_loss},save_dir,flag=False)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

### Trigger

In [None]:
if __name__ == "__main__":
    # Entry point: train with the default segmentation config (path is
    # resolved relative to the current working directory), then report completion.
    seg_train_config = "./configs/segmentation_train.json"
    main(seg_train_config)
    print("DONE")