# In [None]:
import os
import random
import time
import numpy as np
import torch
import logging
import sys
import warnings
import torch.nn as nn
import torch.optim as optim
from torch.backends import cudnn
from collections import defaultdict

from load_data.load_ste_data import *
from algorithm.GoogleNet import GoogleNet
from utils.AverageMeter import AverageMeter
from utils.helper import *
from utils.pytorchtools import EarlyStopping

# Silence everything that is routed through the `warnings` module.
warnings.filterwarnings("ignore")

# Load the experiment configuration, restrict CUDA to device index 0 and
# resolve the compute device and the run folder from the configuration.
cfg = load_config()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)

# Send every log record both to <run_folder>/run.log and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[
        logging.FileHandler(os.path.join(run_folder, 'run.log')),
        logging.StreamHandler(sys.stdout),
    ],
)

# Print the configuration.
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES：{}".format(os.getenv('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))
# to be reproducible
if cfg.train.seed is not None:
    # Seed the CPU-side generators regardless of CUDA availability so that
    # CPU-only runs are reproducible as well.
    np.random.seed(cfg.train.seed)  # Numpy module.
    random.seed(cfg.train.seed)  # Python random module.
    torch.manual_seed(cfg.train.seed)  # Sets the seed for generating random numbers.

if torch.cuda.is_available():
    logging.info('torch.cuda is available!')
    if cfg.train.seed is None:
        # No seed requested: let cuDNN auto-tune its algorithms for speed.
        cudnn.benchmark = True
    else:
        torch.cuda.manual_seed(cfg.train.seed)  # Sets the seed for generating random numbers for the current GPU.
        torch.cuda.manual_seed_all(cfg.train.seed)  # Sets the seed for generating random numbers on all GPUs.
        # Deterministic runs need benchmark mode off: auto-tuning may pick
        # different cuDNN algorithms from one run to the next.
        cudnn.deterministic = True
        cudnn.benchmark = False

        # Emitted through logging because this script silences the
        # `warnings` module at start-up, which would hide warnings.warn().
        logging.warning('You have choosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

# model: build the network on the target device and wrap it in DataParallel
ste_model = nn.DataParallel(GoogleNet().to(device))

# optimization: MSE criterion, Adam, and a plateau scheduler that halves the LR
criterion_mse = nn.MSELoss()
optimizer = optim.Adam(ste_model.parameters(), lr=cfg.train.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5)

# data loaders for the cover images
cover_train_loader = get_loader(cfg, 'cover_train')
cover_val_loader = get_loader(cfg, 'cover_val')
logging.info("cover_train_loader:{} cover_val_loader:{}".format(
    len(cover_train_loader.dataset),
    len(cover_val_loader.dataset),
))

# data loaders for the secret images
secret_train_loader = get_loader(cfg, 'secret_train')
secret_val_loader = get_loader(cfg, 'secret_val')
logging.info("secret_train_loader:{} secret_val_loader:{}".format(
    len(secret_train_loader.dataset),
    len(secret_val_loader.dataset),
))


def train(epoch):
    """Train ``ste_model`` for one epoch on the paired cover/secret loaders.

    Every step feeds one cover batch and one secret batch to the model, which
    returns ``(ste_img, ex_secret)``. The optimized loss ``loss_ste`` is the
    sum of two MSE terms: cover vs. ``ste_img`` (``loss_hid``) and secret vs.
    ``ex_secret`` (``loss_rev``). After the last step the tensors of the final
    batch are handed to ``save_cat_image`` and the epoch metrics to
    ``write_scalars`` together with the path of ``train.csv``.

    Args:
        epoch: Number of the current epoch; used for logging and forwarded to
            the helper functions.

    Returns:
        A tuple ``(losses_dict, img_quality_dict)`` of ``defaultdict`` objects
        holding ``AverageMeter`` instances keyed by ``'loss_hid'``,
        ``'loss_rev'``, ``'loss_ste'`` and ``'ste_psnr'``, ``'ex_psnr'``.

    Raises:
        RuntimeError: If the loaders yield no batch at all.
    """
    epoch_start_time = time.time()

    logging.info('Training epoch: %d' % epoch)
    ste_model.train()

    # zip() stops at the shorter loader, so that is the real number of steps.
    num_steps = min(len(cover_train_loader), len(secret_train_loader))
    if num_steps == 0:
        # Fail early with a clear message instead of an unbound-variable
        # error when the last batch is used after the loop.
        raise RuntimeError('No training batches: the cover or secret train loader is empty.')

    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    img_quality_dict = defaultdict(AverageMeter)  # PSNR

    paired_batches = zip(cover_train_loader, secret_train_loader)
    for step, (cover_batch, secret_batch) in enumerate(paired_batches, start=1):
        # Only the first element of each batch is used by this loop.
        cover_img, secret_img = cover_batch[0].to(device), secret_batch[0].to(device)

        ste_img, ex_secret = ste_model(cover_img, secret_img)

        loss_hid = criterion_mse(cover_img, ste_img)
        loss_rev = criterion_mse(secret_img, ex_secret)
        loss_ste = loss_hid + loss_rev

        optimizer.zero_grad()
        loss_ste.backward()
        optimizer.step()

        temp_losses_dict = {
            'loss_hid': loss_hid.item(),
            'loss_rev': loss_rev.item(),
            'loss_ste': loss_ste.item()
        }
        for tag, metric in temp_losses_dict.items():
            losses_dict[tag].update(metric, cover_img.size(0))

        ste_psnr = cal_psnr(cover_img, ste_img.detach())
        ex_psnr = cal_psnr(secret_img, ex_secret.detach())
        img_quality_dict['ste_psnr'].update(ste_psnr, ste_img.size(0))
        img_quality_dict['ex_psnr'].update(ex_psnr, ex_secret.size(0))

        # Losses are running averages; the PSNR values are those of this batch.
        if step % cfg.train.print_freq == 0 or step == num_steps:
            logging.info('[{}/{}][{}/{}] loss_ste: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f}) ste_PSNR: {:.2f} ex_PSNR: {:.2f}'.format(
                    epoch, cfg.train.num_epochs, step, num_steps,
                    losses_dict['loss_ste'].avg, losses_dict['loss_hid'].avg, losses_dict['loss_rev'].avg, ste_psnr, ex_psnr))

    epoch_duration = time.time() - epoch_start_time
    logging.info('Training duration of epoch {}: {:.2f} sec'.format(epoch, epoch_duration))

    # The tensors below belong to the last processed batch.
    save_cat_image(cfg, epoch, run_folder, cover_img, ste_img, secret_img, ex_secret, 'train')
    write_scalars(epoch, os.path.join(run_folder, 'train.csv'), losses_dict, img_quality_dict, epoch_duration)

    return losses_dict, img_quality_dict


def validation(epoch):
    """Evaluate ``ste_model`` for one epoch on the paired validation loaders.

    Mirrors ``train`` without the optimizer update: the model is switched to
    eval mode and the forward passes run under ``torch.no_grad()`` so no
    autograd graph is built. The same two MSE terms (``loss_hid``,
    ``loss_rev``) and their sum (``loss_ste``) are tracked, together with the
    PSNR values returned by ``cal_psnr``. After the last step the tensors of
    the final batch are handed to ``save_cat_image`` and the epoch metrics to
    ``write_scalars`` together with the path of ``val.csv``.

    Args:
        epoch: Number of the current epoch; used for logging and forwarded to
            the helper functions.

    Returns:
        A tuple ``(losses_dict, img_quality_dict)`` of ``defaultdict`` objects
        holding ``AverageMeter`` instances keyed by ``'loss_hid'``,
        ``'loss_rev'``, ``'loss_ste'`` and ``'ste_psnr'``, ``'ex_psnr'``.

    Raises:
        RuntimeError: If the loaders yield no batch at all.
    """
    epoch_start_time = time.time()

    logging.info("========================validation========================")
    ste_model.eval()

    # zip() stops at the shorter loader, so that is the real number of steps.
    num_steps = min(len(cover_val_loader), len(secret_val_loader))
    if num_steps == 0:
        # Fail early with a clear message instead of an unbound-variable
        # error when the last batch is used after the loop.
        raise RuntimeError('No validation batches: the cover or secret val loader is empty.')

    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    img_quality_dict = defaultdict(AverageMeter)  # PSNR

    # Gradients are never needed here; disabling them avoids building the
    # autograd graph for every validation batch.
    with torch.no_grad():
        paired_batches = zip(cover_val_loader, secret_val_loader)
        for step, (cover_batch, secret_batch) in enumerate(paired_batches, start=1):
            # Only the first element of each batch is used by this loop.
            cover_img, secret_img = cover_batch[0].to(device), secret_batch[0].to(device)

            ste_img, ex_secret = ste_model(cover_img, secret_img)

            loss_hid = criterion_mse(cover_img, ste_img)
            loss_rev = criterion_mse(secret_img, ex_secret)
            loss_ste = loss_hid + loss_rev

            temp_losses_dict = {
                'loss_hid': loss_hid.item(),
                'loss_rev': loss_rev.item(),
                'loss_ste': loss_ste.item()
            }
            for tag, metric in temp_losses_dict.items():
                losses_dict[tag].update(metric, cover_img.size(0))

            ste_psnr = cal_psnr(cover_img, ste_img.detach())
            ex_psnr = cal_psnr(secret_img, ex_secret.detach())
            img_quality_dict['ste_psnr'].update(ste_psnr, ste_img.size(0))
            img_quality_dict['ex_psnr'].update(ex_psnr, ex_secret.size(0))

            # Losses are running averages; the PSNR values are those of this batch.
            if step % cfg.train.print_freq == 0 or step == num_steps:
                logging.info('[{}/{}][{}/{}] loss_ste: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f}) ste_PSNR: {:.2f} ex_PSNR: {:.2f}'.format(
                    epoch, cfg.train.num_epochs, step, num_steps,
                    losses_dict['loss_ste'].avg, losses_dict['loss_hid'].avg, losses_dict['loss_rev'].avg, ste_psnr, ex_psnr))

    logging.info("========================validation========================")

    epoch_duration = time.time() - epoch_start_time
    logging.info('Validation duration of epoch{}: {:.2f} sec'.format(epoch, epoch_duration))

    # The tensors below belong to the last processed batch.
    save_cat_image(cfg, epoch, run_folder, cover_img, ste_img, secret_img, ex_secret, 'val')
    write_scalars(epoch, os.path.join(run_folder, 'val.csv'), losses_dict, img_quality_dict, epoch_duration)

    return losses_dict, img_quality_dict


def main():
    """Run the full training loop with validation and early stopping.

    Optionally restores model and optimizer state from a checkpoint, then for
    every epoch from ``cfg.train.start_epoch`` to ``cfg.train.num_epochs``
    (inclusive) calls ``train`` and ``validation``, steps the learning-rate
    scheduler on the validation loss, plots the metrics via ``plot_scalars``
    and invokes the ``EarlyStopping`` object. The loop ends early once
    ``early_loss.early_stop`` is set.
    """
    # Initialize the early stopping object.
    early_loss = EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info)
    # The checkpoint is not necessary when you decide to train the algorithm from scratch.
    if cfg.train.has_ckp is True:
        # NOTE: the path and the two dictionary keys below are placeholders
        # that must be filled in before enabling cfg.train.has_ckp.
        ste_ckp = torch.load("")  # Fill in the path of your prepared checkpoint.
        ste_model.load_state_dict(ste_ckp[''])  # model weight
        optimizer.load_state_dict(ste_ckp[''])  # parameters of the optimizer
        logging.info("Have loaded steganography checkpoint.")

    for epoch in range(cfg.train.start_epoch, cfg.train.num_epochs + 1):

        train_losses_dict, train_img_quality_dict = train(epoch)
        val_losses_dict, val_img_quality_dict = validation(epoch)

        # Feed the validation loss to ReduceLROnPlateau; without this call the
        # scheduler created at module level never lowers the learning rate.
        scheduler.step(val_losses_dict['loss_ste'].avg)

        # visualize
        plot_scalars(epoch, run_folder, train_losses_dict, train_img_quality_dict, val_losses_dict, val_img_quality_dict)

        # Save checkpoints when loss decreased.
        early_loss(epoch, run_folder, ste_model, optimizer, val_losses_dict, val_img_quality_dict)
        if early_loss.early_stop:
            logging.info("The training has gained minimum loss at {}th epoch".format(epoch))
            break

    logging.info("################## finished ##################")


# Start the training loop only when this file is executed as a script.
if __name__ == '__main__':
    main()
