# Script for `Training`

In [None]:
import os
import sys
import copy
import shutil
import traceback
from pathlib import Path
from typing import List, Dict, Tuple
from datetime import datetime
from collections import Counter
from glob import glob
import json
import toml
import tomlkit

from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
from torch import nn, utils
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision

from sklearn.model_selection import train_test_split

import imgaug as ia

# Make the project's custom modules (`./../modules/`) importable, adding the
# directory to `sys.path` at most once and only if it actually exists.
abs_module_path = Path("./../modules/").resolve()
if abs_module_path.exists():
    if str(abs_module_path) not in sys.path:
        sys.path.append(str(abs_module_path))

from logger import init_logger
from fileop import create_new_dir
from dl_utils import caulculate_metrics, save_model, plot_training_trend, \
                     compose_transform, save_training_logs 
from dl.utils import set_gpu, gen_class2num_dict, get_fish_path, get_fish_class, \
                     calculate_class_weight, rename_training_dir
from dl.ImageDataset import ImgDataset_v2
from misc.Timer import Timer
from plt_show import plot_in_rgb # server can't use `cv.imshow()`

# Absolute path of the directory holding the project's `.toml` configuration files
config_dir = Path("./../Config/").resolve()

In [None]:
# Logger shared by every cell of this notebook (its name is fixed to 'vit_b_16')
training_logger = init_logger("Train 'vit_b_16'")

Load `db_path_plan.toml`

In [None]:
# TOML documents are UTF-8 by specification: read explicitly as UTF-8 instead of
# the platform's locale encoding, which can garble or reject non-ASCII text (e.g. on Windows).
with open(config_dir.joinpath("db_path_plan.toml"), mode="r", encoding="utf-8") as f_reader:
    dbpp_config = toml.load(f_reader)

# Root directory onto which the dataset / model paths from this plan are joined later
db_root = Path(dbpp_config["root"])

Load `(TrainModel)_vit_b_16.toml`

In [None]:
config_name = "(TrainModel)_vit_b_16.toml"

# TOML documents are UTF-8 by specification: read explicitly as UTF-8 instead of
# relying on the platform's locale encoding.
with open(config_dir.joinpath(config_name), mode="r", encoding="utf-8") as f_reader:
    config = toml.load(f_reader)

# dataset: components of the dataset spreadsheet path
dataset_name             = config["dataset"]["name"]
dataset_result_alias     = config["dataset"]["result_alias"]
dataset_gen_method       = config["dataset"]["gen_method"]
dataset_classif_strategy = config["dataset"]["classif_strategy"]
dataset_param_name       = config["dataset"]["param_name"]

# model: torchvision constructor name and the `weights=` argument passed to it
model_name         = config["model"]["model_name"]
pretrain_weights   = config["model"]["pretrain_weights"]

# train_opts
train_ratio  = config["train_opts"]["train_ratio"]   # fraction of each class used for training
rand_seed    = config["train_opts"]["random_seed"]
epochs       = config["train_opts"]["epochs"]
batch_size   = config["train_opts"]["batch_size"]

# train_opts.optimizer
lr           = config["train_opts"]["optimizer"]["learning_rate"]
weight_decay = config["train_opts"]["optimizer"]["weight_decay"]

# train_opts.lr_schedular (key spelling must match the config file)
use_lr_schedular   = config["train_opts"]["lr_schedular"]["enable"]
lr_schedular_step  = config["train_opts"]["lr_schedular"]["step"]
lr_schedular_gamma = config["train_opts"]["lr_schedular"]["gamma"]

# train_opts.earlystop
enable_earlystop = config["train_opts"]["earlystop"]["enable"]
max_no_improved  = config["train_opts"]["earlystop"]["max_no_improved"]

# train_opts.data
use_hsv               = config["train_opts"]["data"]["use_hsv"]
aug_on_fly            = config["train_opts"]["data"]["aug_on_fly"]
forcing_balance       = config["train_opts"]["data"]["forcing_balance"]
forcing_sample_amount = config["train_opts"]["data"]["forcing_sample_amount"]
# The two data strategies are mutually exclusive
if aug_on_fly and forcing_balance:
    raise ValueError("'aug_on_fly' and 'forcing_balance' can only set one to True at a time")

# train_opts.debug_mode
debug_mode        = config["train_opts"]["debug_mode"]["enable"]
debug_rand_select = config["train_opts"]["debug_mode"]["rand_select"]

# train_opts.cuda
cuda_idx = config["train_opts"]["cuda"]["index"]
use_amp  = config["train_opts"]["cuda"]["use_amp"]   # automatic mixed precision

# train_opts.cpu.multiworker
num_workers = config["train_opts"]["cpu"]["num_workers"]   # 0 = load data in the main process

Generate `path_vars`

In [None]:
dataset_cropped_root = db_root.joinpath(dbpp_config["dataset_cropped_v2"])
model_cmd_root = db_root.joinpath(dbpp_config["model_cmd"])

# Dataset spreadsheet location:
#   <dataset_cropped_root>/<name>/<result_alias>/<gen_method>/<classif_strategy>/<param_name>.xlsx
dataset_xlsx_path = dataset_cropped_root.joinpath(dataset_name, dataset_result_alias, dataset_gen_method,
                                                  dataset_classif_strategy, f"{dataset_param_name}.xlsx")
# Explicit raise rather than `assert`: assertions are stripped when Python runs with `-O`
if not dataset_xlsx_path.exists():
    raise FileNotFoundError(f"Can't find `dataset_xlsx`: '{dataset_xlsx_path}'")

DataLoader reproducibility (seed the worker processes and the shuffle generator)

In [None]:
import random


def seed_worker(worker_id):
    """Re-seed `numpy` and `random` inside a DataLoader worker process.

    Intended as the DataLoader ``worker_init_fn``. The worker's torch seed is
    folded into the 32-bit range accepted by ``np.random.seed`` and reused for
    Python's ``random`` module, so code drawing from either RNG in a worker is
    reproducible.

    Args:
        worker_id: index of the worker; unused, but required by the callback signature.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Generator passed to the training DataLoader when worker processes are used.
# Seed it from the config's 'random_seed' (previously a hard-coded 0) so that
# changing the seed changes the shuffle order regardless of 'num_workers'.
g = torch.Generator()
g.manual_seed(rand_seed)

Load `dataset_xlsx` as a pandas DataFrame

In [None]:
# Dataset spreadsheet; later cells use its 'image_name', 'class', 'dataset',
# 'state' and 'cut_section' columns.
df_dataset_xlsx: pd.DataFrame = pd.read_excel(io=dataset_xlsx_path, engine="openpyxl")

Test: split `train/valid`

In [None]:
# Exploratory split: sample 'train_ratio' of every class (fixed seed 2022) for
# training and keep the remainder for validation, printing the D/U cut counts.
df_training = df_dataset_xlsx[
    (df_dataset_xlsx['dataset'] == "train") & (df_dataset_xlsx['state'] == "preserve")
]
train_name_list, valid_name_list = [], []
train_class_counts = {}

for cls in ("L", "M", "S"):
    df = df_training[df_training['class'] == cls]
    train = df.sample(frac=train_ratio, replace=False, random_state=2022)
    valid = df.loc[~df.index.isin(train.index)]
    train_name_list += list(train["image_name"])
    valid_name_list += list(valid["image_name"])

    train_d, train_u = (train[train['cut_section'] == section] for section in ("D", "U"))
    valid_d, valid_u = (valid[valid['cut_section'] == section] for section in ("D", "U"))

    print(f"train_d: {len(train_d)}, train_u: {len(train_u)}")
    print(f"valid_d: {len(valid_d)}, valid_u: {len(valid_u)}")

    train_class_counts[cls] = len(train)

print(len(train_name_list), len(valid_name_list), "\n")
print(train_class_counts)

# Recorded output:
# train_d: 347, train_u: 333
# valid_d: 78, valid_u: 92
# train_d: 265, train_u: 263
# valid_d: 65, valid_u: 67
# train_d: 382, train_u: 378
# valid_d: 93, valid_u: 97
# 1968 492 
#
# {'L': 680, 'M': 528, 'S': 760}

Test: show image

In [None]:
# Sanity check: resolve the last training sample to a file and display it
sample_name = train_name_list[-1]
img_path = get_fish_path(sample_name, df_training)
training_logger.info(f"Read Test: '{img_path}'")
plot_in_rgb(str(img_path), (512, 512))

Test: inspect samples produced by the training DataLoader

In [None]:
# class2num_dict = {'L': 0, 'M': 1, 'S': 2}
# transform = compose_transform()
# train_set = ImgDataset_v2(train_name_list, df_dataset_xlsx, class_mapper=class2num_dict, resize=(224, 224),
#                           use_hsv=use_hsv, transform=transform, logger=training_logger)
# train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)

# for data in train_dataloader:
#     x_train, y_train, crop_name_batch = data
    
#     x_train = x_train[0]
#     y_train = y_train[0]
#     crop_name_batch = crop_name_batch[0]
    
#     x_train = np.moveaxis(x_train.cpu().numpy(), 0, -1)
#     path = get_fish_path(crop_name_batch, df_dataset_xlsx)
#     cls = get_fish_class(crop_name_batch, df_dataset_xlsx)
    
#     print(crop_name_batch, cls, y_train, f'{path}')
#     plot_in_rgb(str(path), (512, 512))
#     plt.imshow(x_train)

Start training

In [None]:
# Set GPU
device = set_gpu(cuda_idx, training_logger)


# Create save model directory
time_stamp = datetime.now().strftime('%Y%m%d_%H_%M_%S')
save_dir = model_cmd_root.joinpath(f"Training_{time_stamp}")
create_new_dir(save_dir)
with open(save_dir.joinpath("train_config.toml"), mode="w") as f_writer:
    tomlkit.dump(config, f_writer)


# Set 'np.random.seed', 'torch.manual_seed'
np.random.seed(rand_seed)
torch.manual_seed(rand_seed) # Dataloader shuffle consistency
ia.seed(rand_seed) # To get consistent augmentations ( 'imgaug' package )


# Create transform and Set 'rand_seed' if 'aug_on_fly' is True
if aug_on_fly: transform = compose_transform()
else: transform = None


# Scan classes to create 'class_mapper'
num2class_list = sorted(Counter(df_dataset_xlsx["class"]).keys())
class2num_dict = gen_class2num_dict(num2class_list)
training_logger.info(f"num2class_list = {num2class_list}, class2num_dict = {class2num_dict}")
# CLI >> num2class_list = ['L', 'M', 'S'], class2num_dict = {'L': 0, 'M': 1, 'S': 2}


# Split train, valid set
#  TODO:  forcing_balance

df_training = df_dataset_xlsx[(df_dataset_xlsx['dataset'] == "train") & 
                                (df_dataset_xlsx['state'] == "preserve")]
train_name_list = []
valid_name_list = []
train_class_counts = {}

for cls in num2class_list:

    df = df_training[(df_training['class'] == cls)]
    train = df.sample(frac=train_ratio, replace=False, random_state=2022)
    valid = df[~df.index.isin(train.index)]
    train_name_list.extend(list(train["image_name"]))
    valid_name_list.extend(list(valid["image_name"]))

    train_d = train[(train['cut_section'] == "D")]
    train_u = train[(train['cut_section'] == "U")]
    valid_d = valid[(valid['cut_section'] == "D")]
    valid_u = valid[(valid['cut_section'] == "U")]

    training_logger.info(f"{cls}: [ train_d: {len(train_d)}, train_u: {len(train_u)} ] "
                                f"[ valid_d: {len(valid_d)}, valid_u: {len(valid_u)} ]")
    
    train_class_counts[cls] = len(train)


## save 'training_amount'
training_amount = f"{{ dataset_{len(df_training)} }}_{{ train_{len(train_name_list)} }}_{{ valid_{len(valid_name_list)} }}"
with open(save_dir.joinpath(training_amount), mode="w") as f_writer: pass


# Create 'train_set', 'train_dataloader'
training_logger.info(f"train_data ({len(train_name_list)})")
[training_logger.info(f"{i} : img_path = {train_name_list[i]}") for i in range(5)]
train_set = ImgDataset_v2(train_name_list, df_dataset_xlsx, class_mapper=class2num_dict, resize=(224, 224),
                          use_hsv=use_hsv, transform=transform, logger=training_logger)
if num_workers > 0:
    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, 
                                  pin_memory=True, num_workers=num_workers, 
                                  worker_init_fn=seed_worker, generator=g)
else: train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
training_logger.info(f"※ : total train batches: {len(train_dataloader)}")


# Create 'valid_set', 'valid_dataloader'
training_logger.info(f"valid_data ({len(valid_name_list)})")
[training_logger.info(f"{i} : img_path = {valid_name_list[i]}") for i in range(5)]
valid_set = ImgDataset_v2(valid_name_list, df_dataset_xlsx, class_mapper=class2num_dict, resize=(224, 224),
                          use_hsv=use_hsv)
if num_workers > 0:
    valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, 
                                  pin_memory=True, num_workers=num_workers)
else: valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)
training_logger.info(f"※ : total valid batches: {len(valid_dataloader)}")


# Read test ( debug mode only )
if debug_mode:
    img_path = get_fish_path(train_name_list[-1], df_training)
    training_logger.info(f"Read Test: '{img_path}'")
    plot_in_rgb(str(img_path), (512, 512))


# Create model ( ref: https://github.com/pytorch/vision/issues/7397 )
training_logger.info(f"load model from `torchvision`, model_name: '{model_name}', pretrain_weights: '{pretrain_weights}'")
model = getattr(torchvision.models, model_name)
model = model(weights=pretrain_weights)
## modify model structure
model.heads.head = nn.Linear(in_features=768, out_features=len(class2num_dict), bias=True)
model.to(device)
# print(model)


# Initial 'loss function' and 'optimizer'
if aug_on_fly:
    # `loss_function` with `class_weight`
    loss_fn = nn.CrossEntropyLoss(weight=calculate_class_weight(train_class_counts))
else: 
    loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=weight_decay, momentum=0.9)


# Initial 'lr_scheduler'
if use_lr_schedular: 
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_schedular_step, gamma=lr_schedular_gamma)


# Automatic Mixed Precision (amp package)
if use_amp: scaler = GradScaler()


# Training
## training variables
training_timer = Timer()
train_logs = []
valid_logs = []
score_key = "maweavg_f1"; training_logger.info("※ : maweavg_f1 = ( macro_f1 + weighted_f1 ) / 2")
## progress bar
pbar_n_epoch = tqdm(total=epochs, desc=f"Epoch ")
pbar_n_train = tqdm(total=len(train_dataloader), desc="Train ")
pbar_n_valid = tqdm(total=len(valid_dataloader), desc="Valid ")
## best validation condition
best_val_log = { "Best": time_stamp, "epoch": 0 }
best_val_avg_loss = np.inf
best_val_f1 = 0.0
best_model_state_dict = copy.deepcopy(model.state_dict())
best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
## early stop
accum_no_improved = 0
## exception
training_state = None
try:
    
    ## TODO:  recover training
    training_timer.start()
    for epoch in range(1, epochs+1):

        
        # Update progress bar description
        pbar_n_epoch.desc = f"Epoch {epoch} "
        pbar_n_epoch.refresh()
        pbar_n_train.n = 0
        pbar_n_train.refresh()
        pbar_n_valid.n = 0
        pbar_n_valid.refresh()
        
        
        
        # Start training
        ## reset variables
        epoch_train_log = { "Train": "", "epoch": epoch }
        pred_list = []
        gt_list = []
        accum_batch_loss = 0.0
        ## set to training mode
        model.train()
        for data in train_dataloader:
            x_train, y_train, crop_name_batch = data
            x_train, y_train = x_train.to(device), y_train.to(device) # move to GPU
            
            
            optimizer.zero_grad() # clean gradients before each iteration
            if use_amp:
                with autocast():
                    preds = model(x_train)
                    loss_value = loss_fn(preds, y_train)
                    
                scaler.scale(loss_value).backward() # 計算並縮放損失的梯度
                scaler.step(optimizer) # 更新模型參數
                scaler.update() # 更新縮放因子
                
            else:
                preds = model(x_train)
                loss_value = loss_fn(preds, y_train)
                
                ## update mode_parameters by back_propagation
                loss_value.backward()
                optimizer.step()
            
            
            ## extend 'pred_list', 'gt_list'
            pred_prob = torch.nn.functional.softmax(preds, dim=1)
            _, pred_train = torch.max(pred_prob, 1) # get the highest probability class
            pred_list.extend(pred_train.cpu().numpy().tolist()) # conversion flow: Tensor --> ndarray --> list
            gt_list.extend(y_train.cpu().numpy().tolist())
            ## add current batch loss
            accum_batch_loss += loss_value.item() # get value of Tensor
            
            ## update 'pbar_n_train'
            pbar_n_train.update(1)
            pbar_n_train.refresh()
        
        if use_lr_schedular: lr_scheduler.step() # update 'lr' for each epoch
        
        caulculate_metrics(epoch_train_log, (accum_batch_loss/len(train_dataloader)),
                           gt_list, pred_list, class2num_dict)
        # print(json.dumps(epoch_train_log, indent=4))
        train_logs.append(epoch_train_log)
        ## update postfix of 'pbar_n_train'
        if use_lr_schedular: pbar_n_train.postfix = (f" {'{'} Loss: {epoch_train_log['average_loss']}, "
                                                     f"{score_key}: {epoch_train_log[f'{score_key}']}, "
                                                     f"lr: {lr_scheduler.get_last_lr()[0]:.0e} {'}'} ")
        else: pbar_n_train.postfix = (f" {'{'} Loss: {epoch_train_log['average_loss']}, "
                                      f"{score_key}: {epoch_train_log[f'{score_key}']} {'}'} ")
        pbar_n_train.refresh()
        # End training
        
        
        
        # Start validating
        ## reset variables
        epoch_valid_log = { "Valid": "", "epoch": epoch }
        pred_list = []
        gt_list = []
        accum_batch_loss = 0.0
        ## set to evaluation mode
        model.eval() 
        with torch.no_grad(): 
            for data in valid_dataloader:
                x_valid, y_valid, crop_name_batch = data
                x_valid, y_valid = x_valid.to(device), y_valid.to(device) # move to GPU
                preds = model(x_valid)
                loss_value = loss_fn(preds, y_valid)
                
                ## extend 'pred_list', 'gt_list'
                pred_prob = torch.nn.functional.softmax(preds, dim=1)
                _, pred_valid = torch.max(pred_prob, 1)
                pred_list.extend(pred_valid.cpu().numpy().tolist())
                gt_list.extend(y_valid.cpu().numpy().tolist())
                ## add current batch loss
                accum_batch_loss += loss_value.item()
                
                ## update 'pbar_n_valid'
                pbar_n_valid.update(1)
                pbar_n_valid.refresh()

        caulculate_metrics(epoch_valid_log, (accum_batch_loss/len(valid_dataloader)),
                           gt_list, pred_list, class2num_dict)
        # print(json.dumps(epoch_valid_log, indent=4))
        valid_logs.append(epoch_valid_log)
        ## update postfix of 'pbar_n_valid'
        pbar_n_valid.postfix = (f" {'{'} Loss: {epoch_valid_log['average_loss']}, "
                                f"{score_key}: {epoch_valid_log[f'{score_key}']} {'}'} ")
        pbar_n_valid.refresh()
        # End validating
        
        
        
        # Check best condition
        epoch_valid_log["valid_state"] = ""
        if epoch_valid_log[f"{score_key}"] > best_val_f1:
            best_val_f1 = epoch_valid_log[f"{score_key}"]
            ## update 'best_val_log'
            caulculate_metrics(best_val_log, (accum_batch_loss/len(valid_dataloader)),
                               gt_list, pred_list, class2num_dict)
            epoch_valid_log["valid_state"] = "☆★☆ BEST_VALIDATION ☆★☆"
            tqdm.write((f"Epoch: {epoch:0{len(str(epochs))}}, "
                        f"☆★☆ BEST_VALIDATION ☆★☆, "
                        f"best_val_avg_loss = {best_val_log['average_loss']}, "
                        f"best_val_{score_key} = {best_val_log[f'{score_key}']}"))
            best_model_state_dict = copy.deepcopy(model.state_dict())
            best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
            best_val_log["epoch"] = epoch # put here to make sure all updates of best state has been done.
        
        # Update figure 
        plot_training_trend_kwargs = {
            "plt"        : plt,
            "save_dir"   : save_dir,
            "loss_key"   : "average_loss",
            "score_key"  : score_key,
            "train_logs" : pd.DataFrame(train_logs),
            "valid_logs" : pd.DataFrame(valid_logs),
        }
        plot_training_trend(**plot_training_trend_kwargs)
        
        # Update 'pbar_n_epoch'
        pbar_n_epoch.update(1)
        pbar_n_epoch.refresh()
        
        
        # Check 'EarlyStop'
        epoch_valid_log["valid_improve"] = ""
        if enable_earlystop:
            if epoch_valid_log["average_loss"] < best_val_avg_loss:
                best_val_avg_loss = epoch_valid_log["average_loss"]
                accum_no_improved = 0
            else:
                epoch_valid_log["valid_improve"] = "◎㊣◎ NO_IMPROVED ◎㊣◎"
                accum_no_improved += 1
                tqdm.write(f"Epoch: {epoch:0{len(str(epochs))}}, ◎㊣◎ NO_IMPROVED ◎㊣◎, accum_no_improved = {accum_no_improved}")
                if accum_no_improved == max_no_improved: sys.exit()


except KeyboardInterrupt:
    training_state = "KeyboardInterrupt"
    tqdm.write("KeyboardInterrupt")
    
except SystemExit:
    training_state = "EarlyStop"
    tqdm.write("EarlyStop, exit training")
    
except Exception as e:
    training_state = "ExceptionError"
    tqdm.write(traceback.format_exc())
    with open(save_dir.joinpath(f"{{Logs}}_ExceptionError.log"), mode="w") as f_writer:
        f_writer.write(traceback.format_exc())

else:
    training_state = "Completed"
    tqdm.write("Training Completed")
    
finally:

    # Close 'progress_bar'
    pbar_n_epoch.close()
    pbar_n_train.close()
    pbar_n_valid.close()
    
    # Save training consume time
    training_timer.stop()
    training_timer.calculate_consume_time()
    training_timer.save_consume_time(save_dir, desc="training time")
    
    
    if best_val_log["epoch"] > 0: # If `best_val_log["epoch"]` > 0, all of `logs` and `state_dict` are not empty.
        
        # Save logs (convert to Dataframe)
        save_training_logs(save_dir, train_logs, valid_logs, best_val_log)
        
        # Save model
        save_model("best", save_dir, best_model_state_dict, best_optimizer_state_dict, best_val_log)
        save_model("final", save_dir, model.state_dict(), optimizer.state_dict(), {"train": train_logs, "valid": valid_logs})

        # Rename 'save_dir'
        ## new_name_format = {time_stamp}_{state}_{target_epochs_with_ImgLoadOptions}_{test_f1}
        ## state = {EarlyStop, Interrupt, Completed, Tested, etc.}
        rename_training_dir(save_dir, time_stamp=time_stamp, state=training_state,
                            epochs=valid_logs[-1]["epoch"], aug_on_fly=aug_on_fly, use_hsv=use_hsv)
        
    else: # Delete folder if less than one epoch has been completed.
        
        training_logger.info(f"Less than One epoch has been completed, remove directory '{save_dir}' ")
        shutil.rmtree(save_dir)