# Script for ```Training```

In [1]:
import os
import sys
import copy
import traceback
import shutil
from typing import List, Dict, Tuple
from datetime import datetime
from glob import glob
import json
import argparse

from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
from torch import nn, utils
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import imgaug as ia

sys.path.append(r"C:\Users\confocal_microscope\Desktop\ZebraFish_AP_POS\modules") # add path to scan customized module
from logger import init_logger
from fileop import create_new_dir
from dl_utils import set_gpu, ImgDataset, caulculate_metrics, save_model, plot_training_trend, \
                     compose_transform, calculate_class_weight, get_sortedClassMapper_from_dir, \
                     rename_training_dir, save_training_logs
from misc_utils import Timer

# print("="*100, "\n")

In [2]:
# Logger shared by every following cell of this training script
_logger_name = r"Training"
training_logger = init_logger(_logger_name)

Constant paths

In [3]:
# Root folders; '{Test}' is a literal part of the folder names, not a placeholder
_desktop_dir = r"C:\Users\confocal_microscope\Desktop"
ap_dataset_root = _desktop_dir + r"\{Test}_DataSet"      # datasets are looked up below this folder
save_dir_root = _desktop_dir + r"\{Test}_Model_history"  # training runs are saved below this folder

Arguments

In [4]:
# ---- Dataset selection ----
dataset_name = r"{20230305_NEW_STRUCT}_Academia_Sinica_i409"
dataset_gen_method = "fish_dataset_horiz_cut_1l2_Mix_AP"
dataset_param_name = "DS_SURF3C_CRPS512_SF14_INT20_DRP100_RS2022"

# ---- Device / labels / split ----
cuda_idx = 1
label_in_filename = 0
train_ratio = 0.8
rand_seed = 2022

# ---- Model / optimization ----
model_name = "vit_b_16"
pretrain_weights = "IMAGENET1K_V1"
epochs = 100
batch_size = 32
lr = 1e-5

# ---- Image loading / sampling strategy ----
use_hsv = False          # forwarded to 'ImgDataset' as 'use_hsv'
aug_on_fly = True        # build an augmentation 'transform' for the training set
forcing_balance = False  # sample 'forcing_sample_amount' images from every class
forcing_sample_amount = 2800

# ---- EarlyStop ----
enable_earlystop = True
max_no_improved = 10     # consecutive epochs without validation-loss improvement

# The two sampling strategies are mutually exclusive
if aug_on_fly and forcing_balance:
    raise ValueError("'aug_on_fly' and 'forcing_balance' can only set one to True at a time")

debug_mode = False # if True, sample 200 images only

# Derived paths
save_dir_model = os.path.join(save_dir_root, model_name)
train_selected_dir = os.path.join(ap_dataset_root, dataset_name, dataset_gen_method,
                                  dataset_param_name, "train", "selected")

# Select the GPU and report which one is used
device, device_name = set_gpu(cuda_idx)
training_logger.info(f"Using '{device}', device_name = '{device_name}'")

| 2023-04-04 04:40:14,029 | Training | INFO | Using 'cuda', device_name = 'NVIDIA GeForce RTX 2080 Ti'


In [5]:
# Create save model directory (one timestamped folder per training run)
time_stamp = datetime.now().strftime('%Y%m%d_%H_%M_%S')
save_dir = os.path.join(save_dir_model, f"Training_{time_stamp}")
create_new_dir(save_dir)


# Set random seeds
## 'np.random' drives the per-class / debug-mode image sampling below
np.random.seed(rand_seed)
## 'torch' drives the initialization of the replaced classification head and the
## DataLoader shuffle order; without seeding it those differ between runs even
## when 'rand_seed' is unchanged
torch.manual_seed(rand_seed)


# Create transform and Set 'rand_seed' if 'aug_on_fly' is True
if aug_on_fly:
    ia.seed(rand_seed) # To get consistent augmentations ( 'imgaug' package )
    transform = compose_transform()
else:
    transform = None


# Scan classes to create 'class_mapper'
num2class_list, class2num_dict = get_sortedClassMapper_from_dir(train_selected_dir)
training_logger.info(f"num2class_list = {num2class_list}, class2num_dict = {class2num_dict}")


# Scan tiff
## 'all_classes' holds every image path that goes into the train/valid split;
## in 'forcing_balance' mode the sampled paths are also kept per class under the class name
dataset_img_dict = { "all_classes" : [] }
if aug_on_fly:
    dataset_img_dict['all_classes'] = glob(os.path.normpath(f"{train_selected_dir}/*/*selected*.tiff"))
elif forcing_balance:
    ## random sampling from each class with a constant value (forcing balance)
    for key in class2num_dict: # key: class name
        class_img_list = glob(os.path.normpath(f"{train_selected_dir}/{key}/*.tiff"))
        if len(class_img_list) < forcing_sample_amount:
            # sampling without replacement cannot draw more images than exist
            raise ValueError(f"Class '{key}' has only {len(class_img_list)} images, "
                             f"fewer than 'forcing_sample_amount' = {forcing_sample_amount}")
        dataset_img_dict[key] = np.random.choice(class_img_list, size=forcing_sample_amount, replace=False)
        dataset_img_dict['all_classes'].extend(dataset_img_dict[key])
else:
    dataset_img_dict['all_classes'] = glob(os.path.normpath(f"{train_selected_dir}/*/*.tiff"))
training_logger.info(f"total = {len(dataset_img_dict['all_classes'])}")
## debug mode: randomly select (at most) 200 images
if debug_mode:
    debug_amount = min(200, len(dataset_img_dict['all_classes'])) # avoid crashing on small datasets
    dataset_img_dict['all_classes'] = np.random.choice(dataset_img_dict['all_classes'], size=debug_amount, replace=False)
    training_logger.info(f"Debug mode, randomly select {len(dataset_img_dict['all_classes'])}")


# Split train, valid dataset
train_img_list, valid_img_list = train_test_split(dataset_img_dict['all_classes'], random_state=rand_seed, train_size=train_ratio)
## save 'training_amount': an empty marker file whose *name* records the split sizes
training_amount = f"{{ dataset_{len(dataset_img_dict['all_classes'])} }}_{{ train_{len(train_img_list)} }}_{{ valid_{len(valid_img_list)} }}"
with open(os.path.normpath(f"{save_dir}/{training_amount}"), mode="w") as f_writer: pass


# Create 'train_set', 'train_dataloader'
training_logger.info(f"train_data ({len(train_img_list)})")
## log (up to) the first 5 paths; slicing avoids IndexError on tiny splits
for i, img_path in enumerate(train_img_list[:5]):
    training_logger.info(f"{i} : img_path = {img_path}")
train_set = ImgDataset(train_img_list, class_mapper=class2num_dict, label_in_filename=label_in_filename, 
                       use_hsv=use_hsv, transform=transform, logger=training_logger)
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True) # TODO:  Dataloader shuffle consistency
training_logger.info(f"※ : total train batches: {len(train_dataloader)}")


# Create 'valid_set', 'valid_dataloader' (no augmentation 'transform' for validation)
training_logger.info(f"valid_data ({len(valid_img_list)})")
for i, img_path in enumerate(valid_img_list[:5]):
    training_logger.info(f"{i} : img_path = {img_path}")
valid_set = ImgDataset(valid_img_list, class_mapper=class2num_dict, label_in_filename=label_in_filename, 
                       use_hsv=use_hsv)
valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)
training_logger.info(f"※ : total valid batches: {len(valid_dataloader)}")


# Read test ( debug mode only ): show the last training image to confirm files can be read
if debug_mode:
    read_test = cv2.imread(train_img_list[-1])
    training_logger.info(f"Read Test: {train_img_list[-1]}")
    if read_test is None:
        # 'cv2.imread' returns None instead of raising when a file cannot be read
        raise ValueError(f"Read Test failed, cv2.imread() returned None for '{train_img_list[-1]}'")
    cv2.imshow("Read Test", read_test)
    cv2.waitKey(0)
    cv2.destroyWindow("Read Test") # close the window so it does not linger during training


# Create model
training_logger.info(f"load model using 'torch.hub.load()', model_name: '{model_name}', pretrain_weights: '{pretrain_weights}'")
model = torch.hub.load('pytorch/vision', model_name, weights=pretrain_weights)
## modify model structure: replace the classification head with one logit per class.
## 'in_features' is taken from the existing head instead of hard-coding 768 (the vit_b_* width),
## so other ViT variants (e.g. 'vit_l_16', width 1024) work without edits.
model.heads.head = nn.Linear(in_features=model.heads.head.in_features, out_features=len(class2num_dict), bias=True)
model.to(device)


# Initial 'loss function' and 'optimizer'
if not aug_on_fly:
    loss_fn = nn.CrossEntropyLoss()
else:
    ## per-class image counts are read (as JSON) from the dataset summary log and
    ## turned into class weights for the loss
    logs_path = os.path.join(ap_dataset_root, dataset_name, dataset_gen_method, 
                             dataset_param_name, r"{Logs}_train_selected_summary.log")
    with open(logs_path, 'r') as f_reader:
        class_counts: Dict[str, int] = json.load(f_reader)
    loss_fn = nn.CrossEntropyLoss(weight=calculate_class_weight(class_counts)) # apply 'class_weight'
loss_fn.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) # TODO:  use momentum, lr_scheduler


# Training
## training variables
training_timer = Timer()
train_logs = []
valid_logs = []
## progress bars (created once, rewound at the start of every epoch)
pbar_n_epoch = tqdm(total=epochs, desc=f"Epoch ")
pbar_n_train = tqdm(total=len(train_dataloader), desc="Train ")
pbar_n_valid = tqdm(total=len(valid_dataloader), desc="Valid ")
## best validation condition: the "best" model is the one with the highest validation 'average_f1'
best_val_log = { "Best": time_stamp, "epoch": 0 }
best_val_avg_f1 = 0.0
best_val_avg_loss = np.inf # only used by the EarlyStop check
## the initial weights are saved as "best" if no epoch ever improves 'average_f1'
best_model_state_dict = copy.deepcopy(model.state_dict())
best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
## early stop: the check sets this flag and 'break's, instead of raising 'SystemExit'
## (which would also mislabel any unrelated SystemExit as "EarlyStop")
accum_no_improved = 0
early_stopped = False
## final state of the run, used in 'finally' to rename 'save_dir'
training_state = None
try:
    
    ## TODO:  recover training
    training_timer.start()
    for epoch in range(1, epochs+1):

        
        # Update progress bar description; rewind the train/valid bars for this epoch
        pbar_n_epoch.desc = f"Epoch {epoch} "
        pbar_n_epoch.refresh()
        pbar_n_train.n = 0
        pbar_n_train.refresh()
        pbar_n_valid.n = 0
        pbar_n_valid.refresh()
        
        
        
        # Start training
        ## reset variables
        epoch_train_log = { "Train": "", "epoch": epoch }
        pred_list = []
        gt_list = []
        accum_batch_loss = 0.0
        ## set to training mode
        model.train()
        for data in train_dataloader:
            x_train, y_train, fish_ID_list = data
            x_train, y_train = x_train.to(device), y_train.to(device) # move to GPU
            preds = model(x_train)
            loss_value = loss_fn(preds, y_train)
            
            ## update model parameters by back-propagation
            loss_value.backward()
            optimizer.step()
            optimizer.zero_grad() # clean gradients after step
            
            ## extend 'pred_list', 'gt_list'
            _, pred_train = torch.max(preds, 1) # index of the highest logit = predicted class
            pred_list.extend(pred_train.cpu().numpy().tolist()) # conversion flow: Tensor --> ndarray --> list
            gt_list.extend(y_train.cpu().numpy().tolist())
            ## add current batch loss
            accum_batch_loss += loss_value.item() # get value of Tensor
            
            ## update 'pbar_n_train'
            pbar_n_train.update(1)
            pbar_n_train.refresh()
        
        ## epoch loss = mean of the per-batch losses
        caulculate_metrics(epoch_train_log, (accum_batch_loss/len(train_dataloader)),
                        gt_list, pred_list, class2num_dict)
        train_logs.append(epoch_train_log)
        ## update postfix of 'pbar_n_train'
        pbar_n_train.postfix = f" {'{'} Loss: {epoch_train_log['average_loss']}, Avg_f1: {epoch_train_log['average_f1']} {'}'} "
        pbar_n_train.refresh()
        # End training
        
        
        
        # Start validating
        ## reset variables
        epoch_valid_log = { "Valid": "", "epoch": epoch }
        pred_list = []
        gt_list = []
        accum_batch_loss = 0.0
        ## set to evaluation mode, no gradient tracking
        model.eval() 
        with torch.no_grad(): 
            for data in valid_dataloader:
                x_valid, y_valid, fish_ID_list = data
                x_valid, y_valid = x_valid.to(device), y_valid.to(device) # move to GPU
                preds = model(x_valid)
                loss_value = loss_fn(preds, y_valid)
                
                ## extend 'pred_list', 'gt_list'
                _, pred_valid = torch.max(preds, 1)
                pred_list.extend(pred_valid.cpu().numpy().tolist())
                gt_list.extend(y_valid.cpu().numpy().tolist())
                ## add current batch loss
                accum_batch_loss += loss_value.item()
                
                ## update 'pbar_n_valid'
                pbar_n_valid.update(1)
                pbar_n_valid.refresh()

        caulculate_metrics(epoch_valid_log, (accum_batch_loss/len(valid_dataloader)),
                        gt_list, pred_list, class2num_dict)
        valid_logs.append(epoch_valid_log)
        ## update postfix of 'pbar_n_valid'
        pbar_n_valid.postfix = f" {'{'} Loss: {epoch_valid_log['average_loss']}, Avg_f1: {epoch_valid_log['average_f1']} {'}'} "
        pbar_n_valid.refresh()
        # End validating
        
        
        
        # Check best condition, average_f1 = (macro_f1 + micro_f1)/2
        if epoch_valid_log["average_f1"] > best_val_avg_f1:
            best_val_avg_f1 = epoch_valid_log["average_f1"]
            ## update 'best_val_log'
            best_val_log["epoch"] = epoch
            caulculate_metrics(best_val_log, (accum_batch_loss/len(valid_dataloader)),
                            gt_list, pred_list, class2num_dict)
            tqdm.write((f"Epoch: {epoch:0{len(str(epochs))}}, "
                        f"☆★☆ BEST_VALIDATION ☆★☆, "
                        f"best_val_avg_loss = {best_val_log['average_loss']}, "
                        f"best_val_avg_f1 = {best_val_log['average_f1']}"))
            ## deep copies, so later training steps do not modify the saved "best" weights
            best_model_state_dict = copy.deepcopy(model.state_dict())
            best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
        
        # Update figure 
        plot_training_trend_kwargs = {
            "plt"        : plt,
            "save_dir"   : save_dir,
            "loss_key"   : "average_loss",
            "score_key"  : "average_f1",
            "train_logs" : pd.DataFrame(train_logs),
            "valid_logs" : pd.DataFrame(valid_logs),
        }
        plot_training_trend(**plot_training_trend_kwargs)
        
        # Update 'pbar_n_epoch'
        pbar_n_epoch.update(1)
        pbar_n_epoch.refresh()
        
        
        # Check 'EarlyStop': stop after 'max_no_improved' consecutive epochs without
        # a new lowest validation 'average_loss'
        if enable_earlystop:
            if epoch_valid_log["average_loss"] < best_val_avg_loss:
                best_val_avg_loss = epoch_valid_log["average_loss"]
                accum_no_improved = 0
            else: 
                accum_no_improved += 1
                if accum_no_improved >= max_no_improved:
                    early_stopped = True
                    break


except KeyboardInterrupt:
    training_state = "KeyboardInterrupt"
    tqdm.write("KeyboardInterrupt")
    
except Exception:
    training_state = "ExceptionError"
    tqdm.write(traceback.format_exc())
    with open(os.path.normpath(f"{save_dir}/{{Logs}}_ExceptionError.log"), mode="w") as f_writer:
        f_writer.write(traceback.format_exc())

else:
    ## the loop either ran all 'epochs' or was left via the EarlyStop 'break'
    if early_stopped:
        training_state = "EarlyStop"
        tqdm.write("EarlyStop, exit training")
    else:
        training_state = "Completed"
        tqdm.write("Training Completed")
    
finally:

    # Close 'progress_bar'
    pbar_n_epoch.close()
    pbar_n_train.close()
    pbar_n_valid.close()
    
    # Save training consume time
    training_timer.stop()
    training_timer.calculate_consume_time()
    training_timer.save_consume_time(save_dir, desc="training time")
    
    
    if valid_logs: # If 'valid_logs' is not empty, then 'train_logs' and 'best_val_log' are also not empty.
        
        # Save logs (convert to Dataframe)
        save_training_logs(save_dir, train_logs, valid_logs, best_val_log)
        
        # Save model
        save_model("best", save_dir, best_model_state_dict, best_optimizer_state_dict, best_val_log)
        save_model("final", save_dir, model.state_dict(), optimizer.state_dict(), {"train": train_logs, "valid": valid_logs})

        # Rename 'save_dir'
        ## new_name_format = {time_stamp}_{state}_{target_epochs_with_ImgLoadOptions}_{test_f1}
        ## state = {EarlyStop, Interrupt, Completed, Tested, etc.}
        rename_training_dir(state=training_state, orig_dir_path=save_dir, 
                            epochs=valid_logs[-1]["epoch"], aug_on_fly=aug_on_fly, use_hsv=use_hsv, time_stamp=time_stamp)
        
    elif training_state == "ExceptionError":
        # Keep the folder: deleting it would also delete the '{Logs}_ExceptionError.log' just written
        print(f"Less than One epoch has been completed, keep directory '{save_dir}' for the exception log ")
        
    else: # Delete folder if less than one epoch has been completed.
        
        print(f"Less than One epoch has been completed, remove directory '{save_dir}' ")
        shutil.rmtree(save_dir)

| 2023-04-04 04:40:14,188 | Training | INFO | num2class_list = ['L', 'M', 'S'], class2num_dict = {'L': 0, 'M': 1, 'S': 2}
| 2023-04-04 04:40:14,200 | Training | INFO | total = 1980
| 2023-04-04 04:40:14,202 | Training | INFO | train_data (1584)
| 2023-04-04 04:40:14,203 | Training | INFO | 0 : img_path = C:\Users\confocal_microscope\Desktop\{Test}_DataSet\{20230305_NEW_STRUCT}_Academia_Sinica_i409\fish_dataset_horiz_cut_1l2_Mix_AP\DS_SURF3C_CRPS512_SF14_INT20_DRP100_RS2022\train\selected\S\S_fish_25_A_selected_2.tiff
| 2023-04-04 04:40:14,203 | Training | INFO | 1 : img_path = C:\Users\confocal_microscope\Desktop\{Test}_DataSet\{20230305_NEW_STRUCT}_Academia_Sinica_i409\fish_dataset_horiz_cut_1l2_Mix_AP\DS_SURF3C_CRPS512_SF14_INT20_DRP100_RS2022\train\selected\S\S_fish_34_A_selected_0.tiff
| 2023-04-04 04:40:14,204 | Training | INFO | 2 : img_path = C:\Users\confocal_microscope\Desktop\{Test}_DataSet\{20230305_NEW_STRUCT}_Academia_Sinica_i409\fish_dataset_horiz_cut_1l2_Mix_AP\DS_SURF3C

path: 'C:\Users\confocal_microscope\Desktop\{Test}_Model_history\vit_b_16\Training_20230404_04_40_14' is created!



Using cache found in C:\Users\confocal_microscope/.cache\torch\hub\pytorch_vision_main
  warn(f"Failed to load image Python extension: {e}")


Epoch :   0%|          | 0/100 [00:00<?, ?it/s]

Train :   0%|          | 0/50 [00:00<?, ?it/s]

Valid :   0%|          | 0/13 [00:00<?, ?it/s]

Epoch: 001, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.58604, best_val_avg_f1 = 0.71891
Epoch: 002, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.43933, best_val_avg_f1 = 0.80105
Epoch: 003, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.39075, best_val_avg_f1 = 0.82036
Epoch: 005, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.37699, best_val_avg_f1 = 0.82878
Epoch: 006, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.30444, best_val_avg_f1 = 0.85792
Epoch: 007, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.27361, best_val_avg_f1 = 0.88637
Epoch: 009, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.25615, best_val_avg_f1 = 0.89764
Epoch: 012, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.19849, best_val_avg_f1 = 0.92512
Epoch: 016, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.17788, best_val_avg_f1 = 0.92664
Epoch: 017, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.16468, best_val_avg_f1 = 0.93157
Epoch: 018, ☆★☆ BEST_VALIDATION ☆★☆, best_val_avg_loss = 0.16061, best_val_avg_f