In [None]:
import os
import logging
import sys
import torch.nn as nn
import torch.optim as optim
import torchvision
import warnings

from model.resnet import resnet18
from train import train
from validation import validation
from test import test
from data.data_loader import get_loader
from utils import Ipynb_importer
from utils.helper import *
from utils.pytorchtools import EarlyStopping

# Silence Python warnings globally so third-party noise stays out of the run log.
# Consequence: anything emitted through warnings.warn() below this line is hidden,
# which is why the seeding notice further down goes through logging instead.
warnings.filterwarnings("ignore")

# Experiment configuration, compute device and output folder for this run.
cfg = load_config()
# NOTE(review): this unconditionally overrides any CUDA_VISIBLE_DEVICES value
# already present in the caller's environment.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)

# Send every log record both to <run_folder>/run.log and to stdout.
logging.basicConfig(level=logging.INFO, format='%(message)s',
                    handlers=[logging.FileHandler(os.path.join(run_folder, 'run.log')),
                              logging.StreamHandler(sys.stdout)])
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES：{}".format(os.getenv('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))

# Seed the framework-independent RNGs whenever a seed is configured.
# This must not depend on CUDA being present, otherwise CPU-only runs silently
# ignore the configured seed.
if cfg.train.seed is not None:
    np.random.seed(cfg.train.seed)  # Numpy module.
    random.seed(cfg.train.seed)  # Python random module.
    torch.manual_seed(cfg.train.seed)  # Sets the seed for generating random numbers.

if torch.cuda.is_available():
    # Report CUDA availability regardless of whether a seed was configured.
    logging.info('torch.cuda is available!')
    cudnn.benchmark = False
    if cfg.train.seed is not None:
        torch.cuda.manual_seed(cfg.train.seed)  # Seed the RNG of the current GPU.
        torch.cuda.manual_seed_all(cfg.train.seed)  # Seed the RNGs of all GPUs.
        cudnn.deterministic = True

        # Logged (not warnings.warn) because all warnings are filtered out above.
        logging.warning('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

# Build one data loader per dataset split and report the size of each split.
train_loader = get_loader(cfg, 'train')
val_loader = get_loader(cfg, 'val')
test_loader = get_loader(cfg, 'test')
logging.info("train_loader:{} val_loadder:{} test_loader:{}".format(
    len(train_loader.dataset),
    len(val_loader.dataset),
    len(test_loader.dataset),
))

# Network: ResNet-18 moved to the configured device and wrapped in DataParallel.
model = nn.DataParallel(resnet18().to(device))

# Loss, optimizer and learning-rate schedule.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=cfg.train.lr,
    momentum=cfg.train.momentum,
    weight_decay=cfg.train.weight_decay,
)
# Multiply the learning rate by 0.1 once the monitored value (passed to
# scheduler.step) has stopped decreasing for cfg.train.lr_patience epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=cfg.train.lr_patience,
    verbose=True,
)

logging.info("\nModel Structure:")
logging.info(model)


def main():
    """Run the epoch loop (train, validate, checkpoint, schedule), then test.

    For every epoch from cfg.train.start_epoch through cfg.train.num_epochs
    (inclusive) this function:

    1. trains and validates the module-level ``model``,
    2. plots the metrics of both phases,
    3. saves a checkpoint when validation accuracy exceeds the best so far,
    4. steps the LR scheduler with the validation loss,
    5. invokes the early-stopping helper and leaves the loop if it fires.

    After the loop the model is evaluated on the test loader.

    NOTE(review): the test call uses ``model`` as it is after the last executed
    epoch; no checkpoint is reloaded here - confirm that is intended.
    """
    top_acc = 0.0
    top_acc_epoch = 0

    stopper = EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info)

    for epoch in range(cfg.train.start_epoch, cfg.train.num_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(epoch, cfg.train.num_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(cfg.train.batch_size, len(train_loader)))

        # One pass over the training split followed by one over the validation split.
        train_metrics = train(epoch, run_folder, cfg, device, train_loader, optimizer, criterion, model)
        val_metrics = validation(epoch, run_folder, cfg, device, val_loader, criterion, model)
        plot_scalars(epoch, run_folder, train_metrics, val_metrics)

        # Only the loss and accuracy drive decisions below; the remaining
        # metrics are still read so that a missing key fails here.
        val_loss, acc, _prec, _recall, _f1 = (
            val_metrics[key] for key in ('Loss', 'Acc', 'Prec', 'Recall', 'F1')
        )

        # Keep a checkpoint of the most accurate model seen so far.
        if acc > top_acc:
            save_checkpoint(epoch, model, optimizer, val_metrics, run_folder)
            top_acc, top_acc_epoch = acc, epoch

        # The scheduler monitors the validation loss to decide on LR changes.
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logging.info("lr after adjusting:{}".format(current_lr))

        # Hand the validation loss to the early-stopping helper; it presumably
        # also checkpoints on improvement (see utils.pytorchtools - TODO confirm).
        stopper(val_loss, model, epoch=epoch, optimizer=optimizer, val_metrics=val_metrics)
        if stopper.early_stop:
            logging.info("The training should be early stopped now.")
            break

    logging.info("################## Finished ##################")
    logging.info("In epoch {}: best acc: {:.4%}".format(top_acc_epoch, top_acc))

    logging.info("################## Testing... ##################")
    test(run_folder, cfg, device, test_loader, criterion, model)


# Entry point: start the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()

