In [1]:
from modules.utils import load_yaml, save_yaml, get_logger, str2bool
from modules.earlystoppers import EarlyStopper
from modules.recorders import Recorder
from modules.datasets import *
from modules.trainer import Trainer
from modules.optimizers import get_optimizer
from modules.schedulers import PolyLR
from models.utils import get_model, EMA
import torch

from datetime import datetime, timezone, timedelta
import numpy as np
import random
import os
import copy
import pandas as pd
import argparse

import wandb
import warnings
warnings.filterwarnings('ignore')

# parser = argparse.ArgumentParser(description='Semi-supervised Segmentation for AICompetition')
# parser.add_argument('--is_trained', help='boolean flag', default=False, type=str2bool)
# parser.add_argument('--is_colab', help='boolean flag', default=False, type=str2bool)
# args = parser.parse_args()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Colab mounts Google Drive here; used as the results root when is_colab is set.
GDRIVE_DIR = '/content/drive/MyDrive'

# A notebook has no real __file__, so it is pinned to the working directory;
# PROJECT_DIR is therefore the head of './', i.e. '.'.
__file__ = './'
PROJECT_DIR = os.path.split(__file__)[0]

# The training configuration lives at <PROJECT_DIR>/config/train_config.yml.
config_path = os.path.join(PROJECT_DIR, 'config', 'train_config.yml')
config = load_yaml(config_path)

# Train serial: either resume an existing run (its serial comes from the
# current config and its saved train_config.yml replaces `config`), or start a
# fresh run identified by the current KST timestamp.
# is_trained = args.is_trained
is_trained = False
if is_trained:
    # Keep the freshly loaded config before it is replaced by the run's copy.
    pre_config = config
    # NOTE(review): the resume flag is re-read from the current config, so a
    # config with TRAINER.is_trained = False disables epoch resumption later.
    is_trained = config['TRAINER']['is_trained']
    train_serial = config['TRAINER']['train_serial']
    config = load_yaml(os.path.join(PROJECT_DIR, 'results', 'train', train_serial, 'train_config.yml'))
else:
    # Fresh run: serial is the current time in KST (UTC+9), e.g. 20220617_112400.
    # A plain `else` guarantees train_serial is always bound before it is printed.
    kst = timezone(timedelta(hours=9))
    train_serial = datetime.now(tz=kst).strftime("%Y%m%d_%H%M%S")
print(train_serial)


# Run artifacts go to <root>/results/train/<train_serial>, where <root> is
# GDRIVE_DIR on Colab and PROJECT_DIR otherwise. The directory is created here.
#if args.is_colab == True:
is_colab = False
RECORDER_DIR = os.path.join(GDRIVE_DIR if is_colab else PROJECT_DIR,
                            'results', 'train', train_serial)
os.makedirs(RECORDER_DIR, exist_ok=True)

# Dataset directory: <PROJECT_DIR>/data/<DIRECTORY.dataset>.
# NOTE(review): not referenced by the data-loading cell, which passes the raw
# dataset name instead — confirm it is used elsewhere.
DATA_DIR = os.path.join(PROJECT_DIR, 'data', config['DIRECTORY']['dataset'])

# GPU selection. CUDA_VISIBLE_DEVICES only takes effect if it is set before
# CUDA is initialised in this kernel, so it is exported before anything else.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config['TRAINER']['gpu'])

# Seed torch, NumPy's legacy global RNG and Python's `random` with one value,
# then force deterministic cuDNN kernels (no autotuning).
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(config['TRAINER']['seed'])
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Train on the (first visible) GPU when one is available, otherwise on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


20220617_112400


In [4]:
"""
00. Set Logger
"""
# Per-run logger pointed at RECORDER_DIR via dir_.
# NOTE(review): stream=False presumably disables console echo — confirm in
# modules.utils.get_logger.
logger = get_logger(name='train', dir_=RECORDER_DIR, stream=False)
logger.info("Set Logger {}".format(RECORDER_DIR))


# Load data

In [None]:
"""
01. Load data
"""
# Build the semi-supervised loaders. Per the unpacking below these are the
# labeled train, unlabeled train and labeled validation loaders; the fourth
# loader returned by build() is discarded.
data_loader = BuildDataLoader(
    num_labels=config['MODEL']['num_labels'],
    dataset_path=config['DIRECTORY']['dataset'],
    batch_size=config['DATALOADER']['batch_size'],
)
train_l_loader, train_u_loader, valid_l_loader, _ = data_loader.build(supervised=False)

# Echo the train loader sizes (one per line) and log all three sizes.
print(len(train_l_loader), len(train_u_loader), sep='\n')
logger.info(
    "Load data, train (labeled):{} train (unlabeled):{} val:{}".format(
        len(train_l_loader), len(train_u_loader), len(valid_l_loader)
    )
)


In [None]:

"""
02. Set model
"""
# Load model: architecture chosen by TRAINER.model, sized by MODEL.num_labels
# and MODEL.output_dim, and moved to the selected device.
model = get_model(model_name=config['TRAINER']['model'],
                  num_classes=config['MODEL']['num_labels'],
                  output_dim=config['MODEL']['output_dim']).to(device)

# Mean-teacher copy of the model. The decay used to be a hard-coded 0.99; it
# can now be overridden with MODEL.ema_decay in train_config.yml, and falls
# back to 0.99 so existing configs behave exactly as before.
ema = EMA(model, config['MODEL'].get('ema_decay', 0.99))

"""
03. Set trainer
"""
# Optimizer: get_optimizer maps the configured name to a factory (presumably
# the optimizer class), which is then instantiated over the model parameters.
optimizer_cls = get_optimizer(optimizer_name=config['TRAINER']['optimizer'])
optimizer = optimizer_cls(params=model.parameters(), lr=config['TRAINER']['learning_rate'])

# Learning-rate schedule over TRAINER.n_epochs with power 0.9
# (presumably polynomial decay — see modules.schedulers.PolyLR).
scheduler = PolyLR(optimizer, config['TRAINER']['n_epochs'], power=0.9)

# Early stopper configured by TRAINER.early_stopping_patience / _mode.
early_stopper = EarlyStopper(
    patience=config['TRAINER']['early_stopping_patience'],
    mode=config['TRAINER']['early_stopping_mode'],
    logger=logger,
)

# Trainer wiring the model, its EMA teacher, the loaders, the optimizer and
# the TRAINER section of the config; logs every LOGGER.logging_interval steps.
trainer = Trainer(
    model=model,
    ema=ema,
    data_loader=data_loader,
    optimizer=optimizer,
    device=device,
    logger=logger,
    config=config['TRAINER'],
    interval=config['LOGGER']['logging_interval'],
)

"""
Logger
"""
# Recorder rooted at RECORDER_DIR; it is handed the model, optimizer and
# scheduler (presumably for checkpointing — confirm in modules.recorders).
recorder = Recorder(record_dir=RECORDER_DIR,
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    logger=logger)

# Weights & Biases tracking, enabled by LOGGER.wandb. To log under your own
# account, set LOGGER.wandb_project / LOGGER.wandb_entity in train_config.yml;
# the previously hard-coded project and user remain the defaults.
if config['LOGGER']['wandb']:
    wandb_project_serial = config['LOGGER'].get('wandb_project', 'v3p_adamw_classmix')
    wandb_username = config['LOGGER'].get('wandb_entity', 'a22106')
    wandb.init(project=wandb_project_serial, dir=RECORDER_DIR, entity=wandb_username)
    wandb.run.name = train_serial
    wandb.config.update(config)
    wandb.watch(model)

# Snapshot the effective training config next to the run's records.
save_yaml(os.path.join(RECORDER_DIR, 'train_config.yml'), config)

"""
04. TRAIN
"""
# Train
# Epoch bookkeeping: a resumed run starts at the epoch after the last one in
# its record.csv; a fresh run starts at epoch 0.
n_epochs = config['TRAINER']['n_epochs']
if is_trained:
    # Read the record from RECORDER_DIR (the record_dir handed to Recorder),
    # so resuming also works when results live under GDRIVE_DIR (is_colab).
    pre_record_csv = pd.read_csv(os.path.join(RECORDER_DIR, 'record.csv'))
    # A record with no rows yet resumes from epoch 0 instead of raising IndexError.
    if len(pre_record_csv):
        pre_epoch = int(pre_record_csv['epoch_index'].iloc[-1]) + 1
    else:
        pre_epoch = 0
else:
    pre_epoch = 0
