# Script for `Training`

In [None]:
import os
import sys
import copy
import traceback
import shutil
from typing import List, Dict, Tuple
from datetime import datetime
from glob import glob
import json
import toml
import tomlkit

from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
from torch import nn, utils
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision

from sklearn.model_selection import train_test_split

import imgaug as ia

sys.path.append("./../modules/") # add path to scan customized module
from logger import init_logger
from fileop import create_new_dir
from dl_utils import set_gpu, ImgDataset, caulculate_metrics, save_model, plot_training_trend, \
                     compose_transform, calculate_class_weight, get_sortedClassMapper_from_dir, \
                     rename_training_dir, save_training_logs
from misc_utils import Timer
from plt_show import plot_in_rgb # server can't use `cv.imshow()`

# print("="*100, "\n")

In [None]:
# Logger named "Training", used by all following cells of this script.
training_logger = init_logger("Training")

Load the training configuration `vit_b_16.toml` (train_ver)

In [None]:
# Read the training configuration.
# TOML documents are UTF-8 by specification, so decode explicitly instead of
# relying on the platform's locale encoding (which may not be UTF-8, e.g. on Windows).
with open("vit_b_16.toml", mode="r", encoding="utf-8") as f_reader:
    config = toml.load(f_reader)

# dataset
dataset_root       = os.path.normpath(config["dataset"]["root"])
dataset_name       = config["dataset"]["name"]
dataset_gen_method = config["dataset"]["gen_method"]
dataset_stdev      = config["dataset"]["stdev"]
dataset_param_name = config["dataset"]["param_name"]

# model
save_dir_root    = config["model"]["history_root"]
model_name       = config["model"]["model_name"]         # name of a `torchvision.models` builder
pretrain_weights = config["model"]["pretrain_weights"]   # passed as `weights=` to that builder

# train_opts
train_ratio  = config["train_opts"]["train_ratio"]
rand_seed    = config["train_opts"]["random_seed"]
epochs       = config["train_opts"]["epochs"]
batch_size   = config["train_opts"]["batch_size"]

# train_opts.optimizer
lr           = config["train_opts"]["optimizer"]["learning_rate"]
weight_decay = config["train_opts"]["optimizer"]["weight_decay"]

# train_opts.lr_schedular
use_lr_schedular   = config["train_opts"]["lr_schedular"]["enable"]
lr_schedular_step  = config["train_opts"]["lr_schedular"]["step"]
lr_schedular_gamma = config["train_opts"]["lr_schedular"]["gamma"]

# train_opts.earlystop
enable_earlystop = config["train_opts"]["earlystop"]["enable"]
max_no_improved  = config["train_opts"]["earlystop"]["max_no_improved"]

# train_opts.data
use_hsv               = config["train_opts"]["data"]["use_hsv"]
aug_on_fly            = config["train_opts"]["data"]["aug_on_fly"]
forcing_balance       = config["train_opts"]["data"]["forcing_balance"]
forcing_sample_amount = config["train_opts"]["data"]["forcing_sample_amount"]
# The two data-balancing strategies are mutually exclusive.
if aug_on_fly and forcing_balance: 
    raise ValueError("'aug_on_fly' and 'forcing_balance' can only set one to True at a time")

# train_opts.debug_mode
debug_mode        = config["train_opts"]["debug_mode"]["enable"]
debug_rand_select = config["train_opts"]["debug_mode"]["rand_select"]

# train_opts.cuda
cuda_idx = config["train_opts"]["cuda"]["index"]
use_amp  = config["train_opts"]["cuda"]["use_amp"]

Generate path variables (`path_vars`)

In [None]:
# Root directory that collects every training run of this architecture
model_save_dir = os.path.join(save_dir_root, model_name)

# Dataset layout: <root>/<name>/<gen_method>/<stdev>/<param_name>
dataset_dir = os.path.join(
    dataset_root,
    dataset_name,
    dataset_gen_method,
    dataset_stdev,
    dataset_param_name,
)
# Training images, one sub-directory per class
train_selected_dir = os.path.join(dataset_dir, "train", "selected")

In [None]:
# Set GPU
device, device_name = set_gpu(cuda_idx)
training_logger.info(f"Using '{device}', device_name = '{device_name}'")


# Create save model directory and snapshot the config used for this run.
# TOML must be UTF-8, so write it explicitly rather than with the locale encoding.
time_stamp = datetime.now().strftime('%Y%m%d_%H_%M_%S')
save_dir = os.path.join(model_save_dir, f"Training_{time_stamp}")
create_new_dir(save_dir)
with open(os.path.normpath(f"{save_dir}/train_config.toml"), mode="w", encoding="utf-8") as f_writer:
    tomlkit.dump(config, f_writer)


# Set 'np.random.seed' (makes the later np.random-based image sampling reproducible)
np.random.seed(rand_seed)


# Create transform and set the 'imgaug' seed only when augmenting on the fly
if aug_on_fly:
    ia.seed(rand_seed)  # To get consistent augmentations ( 'imgaug' package )
    transform = compose_transform()
else:
    transform = None


# Scan classes to create 'class_mapper'
num2class_list, class2num_dict = get_sortedClassMapper_from_dir(train_selected_dir)
training_logger.info(f"num2class_list = {num2class_list}, class2num_dict = {class2num_dict}")


# Scan tiff
# 'all_classes' collects every image path used for the train/valid split; in
# 'forcing_balance' mode each class name additionally keys its own sampled subset.
dataset_img_dict = { "all_classes" : [] }
if aug_on_fly:
    ## augment on the fly: take only the '*selected*' tiffs of every class sub-directory
    dataset_img_dict['all_classes'] = glob(os.path.normpath(f"{train_selected_dir}/*/*selected*.tiff"))
elif forcing_balance:
    ## random sampling from each class with a constant value (forcing balance)
    for key in class2num_dict: # key: class name
        dataset_img_dict[key] = glob(os.path.normpath(f"{train_selected_dir}/{key}/*.tiff"))
        if len(dataset_img_dict[key]) < forcing_sample_amount:
            # fail with an actionable message instead of numpy's generic sampling error
            raise ValueError(f"class '{key}' has only {len(dataset_img_dict[key])} images, "
                             f"fewer than 'forcing_sample_amount' = {forcing_sample_amount}")
        dataset_img_dict[key] = np.random.choice(dataset_img_dict[key], size=forcing_sample_amount, replace=False)
        dataset_img_dict['all_classes'].extend(dataset_img_dict[key])
else:
    ## use every tiff of every class sub-directory
    dataset_img_dict['all_classes'] = glob(os.path.normpath(f"{train_selected_dir}/*/*.tiff"))
training_logger.info(f"total = {len(dataset_img_dict['all_classes'])}")
## debug mode: randomly select [debug_rand_select] images (capped at the number available,
## since sampling without replacement cannot draw more than the population)
if debug_mode:
    dataset_img_dict['all_classes'] = np.random.choice(
        dataset_img_dict['all_classes'],
        size=min(debug_rand_select, len(dataset_img_dict['all_classes'])),
        replace=False,
    )
    training_logger.info(f"Debug mode, randomly select {len(dataset_img_dict['all_classes'])}")


# Split train / valid sets (reproducible through 'rand_seed')
train_img_list, valid_img_list = train_test_split(
    dataset_img_dict['all_classes'],
    train_size=train_ratio,
    random_state=rand_seed,
)
## record the split sizes: an empty marker file whose *name* carries the counts
training_amount = (f"{{ dataset_{len(dataset_img_dict['all_classes'])} }}"
                   f"_{{ train_{len(train_img_list)} }}"
                   f"_{{ valid_{len(valid_img_list)} }}")
with open(os.path.normpath(f"{save_dir}/{training_amount}"), mode="w") as f_writer:
    pass  # zero-byte file; nothing to write


# Create 'train_set', 'train_dataloader'
training_logger.info(f"train_data ({len(train_img_list)})")
## preview up to 5 paths (slicing avoids IndexError on small / debug splits)
for i, img_path in enumerate(train_img_list[:5]):
    training_logger.info(f"{i} : img_path = {img_path}")
train_set = ImgDataset(train_img_list, class_mapper=class2num_dict, resize=(224, 224), 
                       use_hsv=use_hsv, transform=transform, logger=training_logger)
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True) # TODO:  Dataloader shuffle consistency
training_logger.info(f"※ : total train batches: {len(train_dataloader)}")


# Create 'valid_set', 'valid_dataloader' (no augmentation transform, no shuffling)
training_logger.info(f"valid_data ({len(valid_img_list)})")
for i, img_path in enumerate(valid_img_list[:5]):
    training_logger.info(f"{i} : img_path = {img_path}")
valid_set = ImgDataset(valid_img_list, class_mapper=class2num_dict, resize=(224, 224), 
                       use_hsv=use_hsv)
valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)
training_logger.info(f"※ : total valid batches: {len(valid_dataloader)}")


# Read test ( debug mode only )
if debug_mode:
    training_logger.info(f"Read Test: {train_img_list[-1]}")
    plot_in_rgb(train_img_list[-1], (512, 512))


# Create model ( ref: https://github.com/pytorch/vision/issues/7397 )
training_logger.info(f"load model from `torchvision`, model_name: '{model_name}', pretrain_weights: '{pretrain_weights}'")
model = getattr(torchvision.models, model_name)  # model builder looked up by name
model = model(weights=pretrain_weights)
## modify model structure: replace the classification head so it emits one logit per class.
## Reuse the existing head's input width instead of hard-coding 768 (vit_b_16's hidden size),
## so other ViT builders exposing `heads.head` (e.g. vit_l_16, width 1024) also work.
model.heads.head = nn.Linear(in_features=model.heads.head.in_features,
                             out_features=len(class2num_dict), bias=True)
model.to(device)


# Initial 'loss function' and 'optimizer'
if not aug_on_fly:
    loss_fn = nn.CrossEntropyLoss()
else:
    # per-class counts read from the dataset summary log
    # (presumably class name -> image count; verify against the log writer)
    logs_path = os.path.join(dataset_dir, r"{Logs}_train_selected_summary.log")
    with open(logs_path, 'r') as f_writer:
        class_counts: Dict[str, int] = json.load(f_writer)
    loss_fn = nn.CrossEntropyLoss(weight=calculate_class_weight(class_counts))  # apply 'class_weight'
loss_fn.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


# Initial 'lr_scheduler' -- only created when enabled; every later use is guarded by the same flag
if use_lr_schedular:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=lr_schedular_step,
        gamma=lr_schedular_gamma,
    )


# Automatic Mixed Precision (amp package) -- same conditional-definition pattern
if use_amp:
    scaler = GradScaler()


# Training
## training variables
training_timer = Timer()
train_logs = []
valid_logs = []
## progress bars (epochs / train batches / valid batches)
pbar_n_epoch = tqdm(total=epochs, desc="Epoch ")
pbar_n_train = tqdm(total=len(train_dataloader), desc="Train ")
pbar_n_valid = tqdm(total=len(valid_dataloader), desc="Valid ")
## best validation condition
best_val_log = { "Best": time_stamp, "epoch": 0 }  # "epoch" stays 0 until a best model is recorded
best_val_avg_loss = np.inf  # lowest validation loss so far (drives early stopping)
# Start below any achievable score so the first finished epoch is always recorded
# as best. With 0.0, a run whose weighted F1 never rose above 0 kept
# best_val_log["epoch"] == 0, and the final cleanup then deleted the whole run directory.
best_val_weighted_f1 = -np.inf
best_model_state_dict = copy.deepcopy(model.state_dict())
best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
## early stop
accum_no_improved = 0
## exception
training_state = None
# Main training loop.
# 'training_state' records how the loop ended:
#   "Completed", "KeyboardInterrupt", "EarlyStop" (signalled via sys.exit() -> SystemExit),
#   or "ExceptionError" (traceback also written to save_dir).
# The finally clause always closes the progress bars and saves the consumed time. If a best
# epoch was recorded, it saves logs and models and renames save_dir; otherwise it removes save_dir.
try:
    
    ## TODO:  recover training
    training_timer.start()
    for epoch in range(1, epochs+1):

        
        # Update progress bar description and rewind the per-epoch batch bars
        pbar_n_epoch.desc = f"Epoch {epoch} "
        pbar_n_epoch.refresh()
        pbar_n_train.n = 0
        pbar_n_train.refresh()
        pbar_n_valid.n = 0
        pbar_n_valid.refresh()
        
        
        
        # Start training
        ## reset variables
        epoch_train_log = { "Train": "", "epoch": epoch }
        pred_list = []
        gt_list = []
        accum_batch_loss = 0.0
        ## set to training mode
        model.train()
        for data in train_dataloader:
            x_train, y_train, crop_name_batch = data
            x_train, y_train = x_train.to(device), y_train.to(device) # move to GPU
            
            
            optimizer.zero_grad() # clean gradients before each iteration
            if use_amp:
                with autocast():
                    preds = model(x_train)
                    loss_value = loss_fn(preds, y_train)
                    
                scaler.scale(loss_value).backward() # scale the loss, then back-propagate the scaled gradients
                scaler.step(optimizer) # update the model parameters
                scaler.update() # update the scale factor for the next iteration
                
            else:
                preds = model(x_train)
                loss_value = loss_fn(preds, y_train)
                
                ## update mode_parameters by back_propagation
                loss_value.backward()
                optimizer.step()
            
            
            ## extend 'pred_list', 'gt_list'
            pred_prob = torch.nn.functional.softmax(preds, dim=1)
            _, pred_train = torch.max(pred_prob, 1) # get the highest probability class
            pred_list.extend(pred_train.cpu().numpy().tolist()) # conversion flow: Tensor --> ndarray --> list
            gt_list.extend(y_train.cpu().numpy().tolist())
            ## add current batch loss
            accum_batch_loss += loss_value.item() # get value of Tensor
            
            ## update 'pbar_n_train'
            pbar_n_train.update(1)
            pbar_n_train.refresh()
        
        if use_lr_schedular: lr_scheduler.step() # update 'lr' for each epoch
        
        caulculate_metrics(epoch_train_log, (accum_batch_loss/len(train_dataloader)),
                           gt_list, pred_list, class2num_dict)
        train_logs.append(epoch_train_log)
        ## update postfix of 'pbar_n_train'
        if use_lr_schedular: pbar_n_train.postfix = f" {'{'} Loss: {epoch_train_log['average_loss']}, Weighted_f1: {epoch_train_log['weighted_f1']}, lr: {lr_scheduler.get_last_lr()[0]:.0e} {'}'} "
        else: pbar_n_train.postfix = f" {'{'} Loss: {epoch_train_log['average_loss']}, Weighted_f1: {epoch_train_log['weighted_f1']} {'}'} "
        pbar_n_train.refresh()
        # End training
        
        
        
        # Start validating
        ## reset variables
        epoch_valid_log = { "Valid": "", "epoch": epoch }
        pred_list = []
        gt_list = []
        accum_batch_loss = 0.0
        ## set to evaluation mode
        model.eval() 
        with torch.no_grad(): 
            for data in valid_dataloader:
                x_valid, y_valid, crop_name_batch = data
                x_valid, y_valid = x_valid.to(device), y_valid.to(device) # move to GPU
                preds = model(x_valid)
                loss_value = loss_fn(preds, y_valid)
                
                ## extend 'pred_list', 'gt_list'
                pred_prob = torch.nn.functional.softmax(preds, dim=1)
                _, pred_valid = torch.max(pred_prob, 1)
                pred_list.extend(pred_valid.cpu().numpy().tolist())
                gt_list.extend(y_valid.cpu().numpy().tolist())
                ## add current batch loss
                accum_batch_loss += loss_value.item()
                
                ## update 'pbar_n_valid'
                pbar_n_valid.update(1)
                pbar_n_valid.refresh()

        caulculate_metrics(epoch_valid_log, (accum_batch_loss/len(valid_dataloader)),
                           gt_list, pred_list, class2num_dict)
        # NOTE: keys added to 'epoch_valid_log' below still land in 'valid_logs' (same dict object)
        valid_logs.append(epoch_valid_log)
        ## update postfix of 'pbar_n_valid'
        pbar_n_valid.postfix = f" {'{'} Loss: {epoch_valid_log['average_loss']}, Weighted_f1: {epoch_valid_log['weighted_f1']} {'}'} "
        pbar_n_valid.refresh()
        # End validating
        
        
        
        # Check best condition (best model is selected by validation weighted F1)
        epoch_valid_log["valid_state"] = ""
        if epoch_valid_log["weighted_f1"] > best_val_weighted_f1:
            best_val_weighted_f1 = epoch_valid_log["weighted_f1"]
            ## update 'best_val_log'
            caulculate_metrics(best_val_log, (accum_batch_loss/len(valid_dataloader)),
                               gt_list, pred_list, class2num_dict)
            epoch_valid_log["valid_state"] = "☆★☆ BEST_VALIDATION ☆★☆"
            tqdm.write((f"Epoch: {epoch:0{len(str(epochs))}}, "
                        f"☆★☆ BEST_VALIDATION ☆★☆, "
                        f"best_val_avg_loss = {best_val_log['average_loss']}, "
                        f"best_val_weighted_f1 = {best_val_log['weighted_f1']}"))
            best_model_state_dict = copy.deepcopy(model.state_dict())
            best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
            best_val_log["epoch"] = epoch # put here to make sure all updates of best state has been done.
        
        # Update figure 
        plot_training_trend_kwargs = {
            "plt"        : plt,
            "save_dir"   : save_dir,
            "loss_key"   : "average_loss",
            "score_key"  : "weighted_f1",
            "train_logs" : pd.DataFrame(train_logs),
            "valid_logs" : pd.DataFrame(valid_logs),
        }
        plot_training_trend(**plot_training_trend_kwargs)
        
        # Update 'pbar_n_epoch'
        pbar_n_epoch.update(1)
        pbar_n_epoch.refresh()
        
        
        # Check 'EarlyStop' (driven by validation loss, not by F1)
        epoch_valid_log["valid_improve"] = ""
        if enable_earlystop:
            if epoch_valid_log["average_loss"] < best_val_avg_loss:
                best_val_avg_loss = epoch_valid_log["average_loss"]
                accum_no_improved = 0
            else:
                epoch_valid_log["valid_improve"] = "◎㊣◎ NO_IMPROVED ◎㊣◎"
                accum_no_improved += 1
                # '>=' (not '==') so a non-positive 'max_no_improved' still stops training
                # instead of silently never triggering; SystemExit is caught below as "EarlyStop".
                if accum_no_improved >= max_no_improved: sys.exit()


except KeyboardInterrupt:
    training_state = "KeyboardInterrupt"
    tqdm.write("KeyboardInterrupt")
    
except SystemExit:
    training_state = "EarlyStop"
    tqdm.write("EarlyStop, exit training")
    
except Exception:
    training_state = "ExceptionError"
    tqdm.write(traceback.format_exc())
    with open(os.path.normpath(f"{save_dir}/{{Logs}}_ExceptionError.log"), mode="w") as f_writer:
        f_writer.write(traceback.format_exc())

else:
    training_state = "Completed"
    tqdm.write("Training Completed")
    
finally:

    # Close 'progress_bar'
    pbar_n_epoch.close()
    pbar_n_train.close()
    pbar_n_valid.close()
    
    # Save training consume time
    training_timer.stop()
    training_timer.calculate_consume_time()
    training_timer.save_consume_time(save_dir, desc="training time")
    
    
    if best_val_log["epoch"] > 0: # If `best_val_log["epoch"]` > 0, all of `logs` and `state_dict` are not empty.
        
        # Save logs (convert to Dataframe)
        save_training_logs(save_dir, train_logs, valid_logs, best_val_log)
        
        # Save model
        save_model("best", save_dir, best_model_state_dict, best_optimizer_state_dict, best_val_log)
        save_model("final", save_dir, model.state_dict(), optimizer.state_dict(), {"train": train_logs, "valid": valid_logs})

        # Rename 'save_dir'
        ## new_name_format = {time_stamp}_{state}_{target_epochs_with_ImgLoadOptions}_{test_f1}
        ## state = {EarlyStop, Interrupt, Completed, Tested, etc.}
        rename_training_dir(state=training_state, orig_dir_path=save_dir, 
                            epochs=valid_logs[-1]["epoch"], aug_on_fly=aug_on_fly, use_hsv=use_hsv, time_stamp=time_stamp)
        
    else: # No best epoch recorded (no validation pass finished): delete the run directory.
        
        print(f"Less than One epoch has been completed, remove directory '{save_dir}' ")
        shutil.rmtree(save_dir)