## Import

In [1]:
"""Train
"""
from datetime import datetime
from time import time
import numpy as np
import shutil, random, os, sys, torch
from glob import glob
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# prj_dir = os.path.dirname(os.path.abspath(__file__)) # for script
prj_dir = os.path.dirname(os.path.abspath("")) # for jupyter
sys.path.append(prj_dir)

from modules.utils import load_yaml, get_logger
from modules.metrics import get_metric_function
from modules.earlystoppers import EarlyStopper
from modules.losses import get_loss_function
from modules.optimizers import get_optimizer
from modules.schedulers import get_scheduler
from modules.scalers import get_image_scaler
from modules.datasets import SegDataset
from modules.recorders import Recorder
from modules.trainer import Trainer
from models.utils import get_model

In [2]:
prj_dir = os.path.dirname(os.path.abspath("baseline")) # for jupyter

In [3]:
prj_dir

'c:\\Dev\\2022\\maicon\\baseline'

In [4]:
yaml = 'train.yaml'

## Set configs

In [5]:
# Load config
config_path = os.path.join(prj_dir, 'config', yaml)
config = load_yaml(config_path)

In [6]:

# # Set train serial: ex) 20211004
# train_serial = datetime.now().strftime("%Y%m%d_%H%M%S")
# train_serial = 'debug' if config['debug'] else train_serial

# # Set random seed, deterministic
# torch.cuda.manual_seed(config['seed'])
# torch.manual_seed(config['seed'])
# np.random.seed(config['seed'])
# random.seed(config['seed'])
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# # Set device(GPU/CPU)
# os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu_num'])
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Create train result directory and set logger
# train_result_dir = os.path.join(prj_dir, 'results', 'train', train_serial)
# os.makedirs(train_result_dir, exist_ok=True)

# # Set logger
# logging_level = 'debug' if config['verbose'] else 'info'
# logger = get_logger(name='train',
#                     file_path=os.path.join(train_result_dir, 'train.log'),
#                     level=logging_level)



## Dataset

In [19]:
# Set data directory
train_dirs = os.path.join(prj_dir, 'data', 'train')

# Load data and create dataset for train 
# Load image scaler
train_img_paths = glob(os.path.join(train_dirs, 'x', '*.png'))
train_img_paths, val_img_paths = train_test_split(train_img_paths, test_size=config['val_size'], random_state=config['seed'], shuffle=True)

train_dataset = SegDataset(paths=train_img_paths,
                        input_size=[config['input_width'], config['input_height']],
                        scaler=get_image_scaler(config['scaler']),
                        logger=None)
val_dataset = SegDataset(paths=val_img_paths,
                        input_size=[config['input_width'], config['input_height']],
                        scaler=get_image_scaler(config['scaler']),
                        logger=None)
# Create data loader
train_dataloader = DataLoader(dataset=train_dataset,
                            batch_size=config['batch_size'],
                            num_workers=config['num_workers'], 
                            shuffle=config['shuffle'],
                            drop_last=config['drop_last'])
                            
val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=config['batch_size'],
                            num_workers=config['num_workers'], 
                            shuffle=False,
                            drop_last=config['drop_last'])

# logger.info(f"Load dataset, train: {len(train_dataset)}, val: {len(val_dataset)}")

In [22]:
for batch_id, (x,y,filename) in enumerate(train_dataloader):
    print(x.shape)

torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8, 3, 336, 768])
torch.Size([8,

KeyboardInterrupt: 

## Model, Opimizer, Scheduler, Loss and etc

In [8]:
# # Load model
# model = get_model(model_str=config['architecture'])
# model = model(classes=config['n_classes'],
#             encoder_name=config['encoder'],
#             encoder_weights=config['encoder_weight'],
#             activation=config['activation']).to(device)
# logger.info(f"Load model architecture: {config['architecture']}")

# # Set optimizer
# optimizer = get_optimizer(optimizer_str=config['optimizer']['name'])
# optimizer = optimizer(model.parameters(), **config['optimizer']['args'])

# # Set Scheduler
# scheduler = get_scheduler(scheduler_str=config['scheduler']['name'])
# scheduler = scheduler(optimizer=optimizer, **config['scheduler']['args'])

# # Set loss function
# loss_func = get_loss_function(loss_function_str=config['loss']['name'])
# loss_func = loss_func(**config['loss']['args'])

# # Set metric
# metric_funcs = {metric_name:get_metric_function(metric_name) for metric_name in config['metrics']}
# logger.info(f"Load optimizer:{config['optimizer']['name']}, scheduler: {config['scheduler']['name']}, loss: {config['loss']['name']}, metric: {config['metrics']}")

# # Set trainer
# trainer = Trainer(model=model,
#                 optimizer=optimizer,
#                 scheduler=scheduler,
#                 loss_func=loss_func,
#                 metric_funcs=metric_funcs,
#                 device=device,
#                 logger=logger)
# logger.info(f"Load trainer")

# # Set early stopper
# early_stopper = EarlyStopper(patience=config['earlystopping_patience'],
#                             logger=logger)
# # Set recorder
# recorder = Recorder(record_dir=train_result_dir,
#                     model=model,
#                     optimizer=optimizer,
#                     scheduler=scheduler,
#                     logger=logger)
# logger.info("Load early stopper, recorder")

# # Recorder - save train config
# shutil.copy(config_path, os.path.join(recorder.record_dir, yaml))


## Train

In [9]:
# config_path = os.path.join(prj_dir, 'config', "train"+".yaml")
# config = load_yaml(config_path)
# model = get_model(model_str=config['architecture'])
# model = model(classes=config['n_classes'],
#             encoder_name=config['encoder'],
#             encoder_weights=config['encoder_weight'],
#             activation=config['activation'])
# check_point_path = os.path.join(prj_dir, 'results', 'train','20221113_010847', 'model.pt')
# check_point = torch.load(check_point_path)
# model.load_state_dict(check_point['model'])

In [10]:
from collections import OrderedDict
# architectures = ['DeepLabV3Plus']
# for architecture in architectures:
#----
# Load config
# config_path = os.path.join(prj_dir, 'config', "train"+".yaml")
# config = load_yaml(config_path)

# Set train serial: ex) 20211004
train_serial = datetime.now().strftime("%Y%m%d_%H%M%S")
train_serial = 'debug' if config['debug'] else train_serial

# Set random seed, deterministic
torch.cuda.manual_seed(config['seed'])
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])
random.seed(config['seed'])
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set device(GPU/CPU)
# os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu_num'])
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create train result directory and set logger
# train_result_dir = os.path.join(prj_dir, 'results', 'train', architecture, train_serial)
train_result_dir = os.path.join(prj_dir, 'results', 'train', train_serial)
os.makedirs(train_result_dir, exist_ok=True)

# Set logger
logging_level = 'debug' if config['verbose'] else 'info'
logger = get_logger(name='train',
                    file_path=os.path.join(train_result_dir, 'train.log'),
                    level=logging_level)

#----
# Load model
# model = get_model(model_str=config['architecture'])
# model = model(classes=config['n_classes'],
#             encoder_name=config['encoder'],
#             encoder_weights=config['encoder_weight'],
#             activation=config['activation'])
# check_point_path = os.path.join(prj_dir, 'results', 'train', 'DeepLabV3Plus','20221112_222652', 'model.pt')
# check_point = torch.load(check_point_path)
# model.load_state_dict(check_point['model'])
# config_path = os.path.join(prj_dir, 'config', "train"+".yaml")
# config = load_yaml(config_path)
model = get_model(model_str=config['architecture'])
model = model(classes=config['n_classes'],
            encoder_name=config['encoder'],
            encoder_weights=config['encoder_weight'],
            activation=config['activation'])
# check_point_path = os.path.join(prj_dir, 'results', 'train', 'DeepLabV3Plus','20221112_222652', 'model.pt')
check_point_path = os.path.join(prj_dir, 'results', 'train', '20221116_173442', 'model.pt')
check_point = torch.load(check_point_path)

model_dict = check_point['model']
keys = model_dict.keys()
values = model_dict.values()
new_keys = []

for key in keys:
  new_key = key[7:]    # remove the 'module.'
  new_keys.append(new_key)
new_dict = OrderedDict(list(zip(new_keys, values)))
model.load_state_dict(new_dict)
# model.load_state_dict(check_point['model'])
logger.info(f"Load model weight, {check_point_path}")

NGPU = torch.cuda.device_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if NGPU > 1:
    print("Multi")
    model = torch.nn.DataParallel(model, device_ids=list(range(NGPU)))
torch.multiprocessing.set_start_method('spawn')
model.to(device)
logger.info(f"Load model architecture: {config['architecture']}")

# Set optimizer
optimizer = get_optimizer(optimizer_str=config['optimizer']['name'])
optimizer = optimizer(model.parameters(), **config['optimizer']['args'])

# Set Scheduler
scheduler = get_scheduler(scheduler_str=config['scheduler']['name'])
scheduler = scheduler(optimizer=optimizer, **config['scheduler']['args'])

# Set loss function
loss_func = get_loss_function(loss_function_str=config['loss']['name'])
loss_func = loss_func(**config['loss']['args'])

# Set metric
metric_funcs = {metric_name:get_metric_function(metric_name) for metric_name in config['metrics']}
logger.info(f"Load optimizer:{config['optimizer']['name']}, scheduler: {config['scheduler']['name']}, loss: {config['loss']['name']}, metric: {config['metrics']}")

# Set trainer
trainer = Trainer(model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                loss_func=loss_func,
                metric_funcs=metric_funcs,
                device=device,
                logger=logger)
logger.info(f"Load trainer")

# Set early stopper
early_stopper = EarlyStopper(patience=config['earlystopping_patience'],
                            logger=logger)
# Set recorder
recorder = Recorder(record_dir=train_result_dir,
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    logger=logger)
logger.info("Load early stopper, recorder")

# Recorder - save train config
shutil.copy(config_path, os.path.join(recorder.record_dir, yaml))

#----
# Train
print("START TRAINING")
logger.info("START TRAINING")
for epoch_id in range(config['n_epochs']):
    
    # Initiate result row
    row = dict()
    row['epoch_id'] = epoch_id
    row['train_serial'] = train_serial
    row['lr'] = trainer.scheduler.get_last_lr()

    # Train
    print(f"Epoch {epoch_id}/{config['n_epochs']} Train..")
    logger.info(f"Epoch {epoch_id}/{config['n_epochs']} Train..")
    tic = time()
    trainer.train(dataloader=train_dataloader, epoch_index=epoch_id)
    toc = time()
    # Write tarin result to result row
    row['train_loss'] = trainer.loss  # Loss
    for metric_name, metric_score in trainer.scores.items():
        row[f'train_{metric_name}'] = metric_score

    row['train_elapsed_time'] = round(toc-tic, 1)
    # Clear
    trainer.clear_history()

    # Validation
    print(f"Epoch {epoch_id}/{config['n_epochs']} Validation..")
    logger.info(f"Epoch {epoch_id}/{config['n_epochs']} Validation..")
    tic = time()
    trainer.validate(dataloader=val_dataloader, epoch_index=epoch_id)
    toc = time()
    row['val_loss'] = trainer.loss
    # row[f"val_{config['metric']}"] = trainer.score
    for metric_name, metric_score in trainer.scores.items():
        row[f'val_{metric_name}'] = metric_score
    row['val_elapsed_time'] = round(toc-tic, 1)
    trainer.clear_history()

    # Performance record - row
    recorder.add_row(row)
    
    # Performance record - plot
    recorder.save_plot(config['plot'])

    # Check early stopping
    early_stopper.check_early_stopping(row[config['earlystopping_target']])
    if early_stopper.patience_counter == 0:
        recorder.save_weight(epoch=epoch_id)
        
    if early_stopper.stop:
        print(f"Epoch {epoch_id}/{config['n_epochs']}, Stopped counter {early_stopper.patience_counter}/{config['earlystopping_patience']}")
        logger.info(f"Epoch {epoch_id}/{config['n_epochs']}, Stopped counter {early_stopper.patience_counter}/{config['earlystopping_patience']}")
        break

print("END TRAINING")
logger.info("END TRAINING")

Multi
START TRAINING
Epoch 0/100 Train..


100%|██████████| 1350/1350 [3:48:45<00:00, 10.17s/it] 


Epoch 0/100 Validation..


100%|██████████| 150/150 [24:43<00:00,  9.89s/it]


Epoch 1/100 Train..


100%|██████████| 1350/1350 [3:49:59<00:00, 10.22s/it] 


Epoch 1/100 Validation..


100%|██████████| 150/150 [24:56<00:00,  9.98s/it]


Epoch 2/100 Train..


100%|██████████| 1350/1350 [3:52:20<00:00, 10.33s/it]  


Epoch 2/100 Validation..


100%|██████████| 150/150 [24:42<00:00,  9.88s/it]


Epoch 3/100 Train..


100%|██████████| 1350/1350 [3:50:09<00:00, 10.23s/it] 


Epoch 3/100 Validation..


100%|██████████| 150/150 [25:02<00:00, 10.02s/it]


Epoch 4/100 Train..


100%|██████████| 1350/1350 [3:49:37<00:00, 10.21s/it] 


Epoch 4/100 Validation..


100%|██████████| 150/150 [24:59<00:00, 10.00s/it]


Epoch 5/100 Train..


100%|██████████| 1350/1350 [3:49:42<00:00, 10.21s/it] 


Epoch 5/100 Validation..


100%|██████████| 150/150 [24:48<00:00,  9.93s/it]


Epoch 6/100 Train..


100%|██████████| 1350/1350 [3:49:54<00:00, 10.22s/it]  


Epoch 6/100 Validation..


100%|██████████| 150/150 [24:54<00:00,  9.97s/it]


Epoch 7/100 Train..


100%|██████████| 1350/1350 [3:49:58<00:00, 10.22s/it]  


Epoch 7/100 Validation..


100%|██████████| 150/150 [24:52<00:00,  9.95s/it]


Epoch 8/100 Train..


100%|██████████| 1350/1350 [3:50:53<00:00, 10.26s/it] 


Epoch 8/100 Validation..


100%|██████████| 150/150 [25:02<00:00, 10.02s/it]


Epoch 9/100 Train..


100%|██████████| 1350/1350 [3:49:19<00:00, 10.19s/it]  


Epoch 9/100 Validation..


100%|██████████| 150/150 [25:01<00:00, 10.01s/it]


Epoch 10/100 Train..


100%|██████████| 1350/1350 [3:48:57<00:00, 10.18s/it] 


Epoch 10/100 Validation..


100%|██████████| 150/150 [25:13<00:00, 10.09s/it]


Epoch 11/100 Train..


100%|██████████| 1350/1350 [3:49:45<00:00, 10.21s/it] 


Epoch 11/100 Validation..


100%|██████████| 150/150 [25:09<00:00, 10.06s/it]


Epoch 12/100 Train..


100%|██████████| 1350/1350 [3:49:34<00:00, 10.20s/it] 


Epoch 12/100 Validation..


100%|██████████| 150/150 [25:09<00:00, 10.07s/it]


Epoch 13/100 Train..


100%|██████████| 1350/1350 [3:50:06<00:00, 10.23s/it] 


Epoch 13/100 Validation..


100%|██████████| 150/150 [25:00<00:00, 10.00s/it]


Epoch 14/100 Train..


100%|██████████| 1350/1350 [3:49:55<00:00, 10.22s/it] 


Epoch 14/100 Validation..


100%|██████████| 150/150 [24:42<00:00,  9.88s/it]


Epoch 15/100 Train..


100%|██████████| 1350/1350 [3:49:48<00:00, 10.21s/it] 


Epoch 15/100 Validation..


100%|██████████| 150/150 [24:52<00:00,  9.95s/it]


Epoch 16/100 Train..


100%|██████████| 1350/1350 [3:51:17<00:00, 10.28s/it] 


Epoch 16/100 Validation..


100%|██████████| 150/150 [25:09<00:00, 10.07s/it]


Epoch 17/100 Train..


100%|██████████| 1350/1350 [3:49:31<00:00, 10.20s/it] 


Epoch 17/100 Validation..


100%|██████████| 150/150 [25:15<00:00, 10.10s/it]


Epoch 18/100 Train..


100%|██████████| 1350/1350 [3:49:47<00:00, 10.21s/it] 


Epoch 18/100 Validation..


100%|██████████| 150/150 [25:10<00:00, 10.07s/it]


Epoch 19/100 Train..


100%|██████████| 1350/1350 [3:50:43<00:00, 10.25s/it] 


Epoch 19/100 Validation..


100%|██████████| 150/150 [25:21<00:00, 10.14s/it]


Epoch 20/100 Train..


100%|██████████| 1350/1350 [3:49:59<00:00, 10.22s/it] 


Epoch 20/100 Validation..


100%|██████████| 150/150 [24:58<00:00,  9.99s/it]


Epoch 21/100 Train..


100%|██████████| 1350/1350 [3:49:11<00:00, 10.19s/it] 


Epoch 21/100 Validation..


100%|██████████| 150/150 [24:57<00:00,  9.99s/it]


Epoch 22/100 Train..


  4%|▍         | 59/1350 [10:21<3:46:33, 10.53s/it]


KeyboardInterrupt: 

## Prediction

In [None]:
# from datetime import datetime
# from tqdm import tqdm
# import numpy as np
# import random, os, sys, torch, cv2, warnings
# from glob import glob
# from torch.utils.data import DataLoader
#  # Create train result directory and set logger
# pred_result_dir = os.path.join(prj_dir, 'results', 'pred', 'pred_serial')
# pred_result_dir_mask = os.path.join(prj_dir, 'results', 'pred', 'pred_serial', 'mask')
# os.makedirs(pred_result_dir, exist_ok=True)
# os.makedirs(pred_result_dir_mask, exist_ok=True)

# # Set logger
# logging_level = 'debug' if config['verbose'] else 'info'
# logger = get_logger(name='train',
#                     file_path=os.path.join(pred_result_dir, 'pred.log'),
#                     level=logging_level)

# # Set data directory
# test_dirs = os.path.join(prj_dir, 'data', 'test')
# test_img_paths = glob(os.path.join(test_dirs, 'x', '*.png'))

# #! Load data & create dataset for train 
# test_dataset = SegDataset(paths=test_img_paths,
#                         input_size=[config['input_width'], config['input_height']],
#                         scaler=get_image_scaler(config['scaler']),
#                         mode='test',
#                         logger=logger)

# # Create data loader
# test_dataloader = DataLoader(dataset=test_dataset,
#                             batch_size=config['batch_size'],
#                             num_workers=config['num_workers'],
#                             shuffle=False,
#                             drop_last=False)

# # Predict
# logger.info(f"START PREDICTION")

# model.eval()

# with torch.no_grad():

#     for batch_id, (x, orig_size, filename) in enumerate(tqdm(test_dataloader)):
        
#         x = x.to(device, dtype=torch.float)
#         y_pred = model(x)
#         y_pred_argmax = y_pred.argmax(1).cpu().numpy().astype(np.uint8)
#         orig_size = [(orig_size[0].tolist()[i], orig_size[1].tolist()[i]) for i in range(len(orig_size[0]))]
#         # Save predict result
#         for filename_, orig_size_, y_pred_ in zip(filename, orig_size, y_pred_argmax):
#             resized_img = cv2.resize(y_pred_, [orig_size_[1], orig_size_[0]], interpolation=cv2.INTER_NEAREST)
#             cv2.imwrite(os.path.join(pred_result_dir_mask, filename_), resized_img)
# logger.info(f"END PREDICTION")

100%|██████████| 74/74 [05:50<00:00,  4.74s/it]
