# In [None]:
import logging
import os
import random
import sys
import time
import warnings
from collections import defaultdict, OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.backends import cudnn
from torchmetrics import Accuracy, Precision, Recall, F1

from data.load_data import *
# Select the steganography algorithm you want.
from network.End_to_end_Ste import End_to_end_Ste
# from network.GoogleNet import GoogleNet
from models.resnet import resnet18
from utils.AverageMeter import AverageMeter
from utils.helper import *
from utils.pytorchtools import E2E_EarlyStopping

warnings.filterwarnings("ignore")

cfg = load_config()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)
# Log to <run_folder>/run.log and to stdout, then print the configuration.
logging.basicConfig(level=logging.INFO, format='%(message)s',
                    handlers=[logging.FileHandler(os.path.join(run_folder, 'run.log')),
                              logging.StreamHandler(sys.stdout)])
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES：{}".format(os.getenv('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))

# to be reproducible
# Seed every RNG whenever a seed is configured, independently of CUDA availability,
# so CPU-only runs are reproducible as well.
if cfg.train.seed is not None:
    np.random.seed(cfg.train.seed)  # Numpy module.
    random.seed(cfg.train.seed)  # Python random module.
    torch.manual_seed(cfg.train.seed)  # Sets the seed for generating random numbers.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.train.seed)  # Seed for the current GPU.
        torch.cuda.manual_seed_all(cfg.train.seed)  # Seed for all GPUs.

if torch.cuda.is_available():
    # cudnn.benchmark autotunes convolution algorithms, which is itself a source of
    # non-determinism; only enable it when no seed was requested.
    cudnn.benchmark = cfg.train.seed is None
    if cfg.train.seed is not None:
        cudnn.deterministic = True
        # warnings.filterwarnings("ignore") above would silently swallow a
        # warnings.warn() here, so report the notice through the logger instead.
        logging.warning('You have choosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

# data loaders
train_loader, val_loader, test_loader = get_loader(cfg, 'train'), get_loader(cfg, 'val'), get_loader(cfg, 'test')
logging.info("train_loader:{} val_loader:{} test_loader:{}\n".format(len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset)))


def load_pretrained():
    """Optionally restore pretrained weights and optimizer state into the module-level models.

    Two config flags control what is loaded:
      * ``cfg.train.pretrained_tech`` -- weights for ``ste_model`` and state for ``optimizerH``.
      * ``cfg.train.fine_tuning`` -- weights for ``Dnnet`` and state for ``optimizerN``.

    Checkpoint paths and dictionary keys are placeholders ('') that must be filled in
    before enabling either flag. Checkpoints are mapped onto ``device`` so that a
    checkpoint saved on GPU can also be restored on a CPU-only machine.
    """
    if cfg.train.pretrained_tech is True:
        hid_ckp_path = ''  # Fill in the checkpoint of steganography (e.g. En2D or GglNet).
        hid_state_dicts = torch.load(hid_ckp_path, map_location=device)
        stegan_state_dict = hid_state_dicts['']  # model weights
        stegan_opt_dict = hid_state_dicts['']  # parameters of the optimizer
        ste_model.load_state_dict(stegan_state_dict)
        optimizerH.load_state_dict(stegan_opt_dict)

        # Compare the hyper-parameters stored in the checkpoint with those now in the optimizer.
        pg_dict = stegan_opt_dict['param_groups'][0]
        logging.info("stegan_opt_dict: lr:{} betas: {} eps: {} weight_decay:{}".format(
            pg_dict['lr'], pg_dict['betas'], pg_dict['eps'], pg_dict['weight_decay']))
        pg = optimizerH.param_groups[0]
        logging.info("optimizerH: lr:{} betas: {} eps: {} weight_decay:{}".format(
            pg['lr'], pg['betas'], pg['eps'], pg['weight_decay']))
        logging.info("Have loaded pretrained stegan_model.")

    if cfg.train.fine_tuning is True:
        dnn_ckp_path = ''  # Fill in the checkpoint of host DNN.
        dnn_ckp = torch.load(dnn_ckp_path, map_location=device)
        dnn_state_dict = dnn_ckp['']  # model weights
        dnn_opt_dict = dnn_ckp['']  # parameters of the optimizer
        Dnnet.load_state_dict(dnn_state_dict)
        optimizerN.load_state_dict(dnn_opt_dict)

        logging.info("host DNN at checkpoint" + '=' * 60)

        # Read the logged hyper-parameters from the same optimizer dict that was loaded
        # above, so logging never uses a different key than loading.
        pg_dict = dnn_opt_dict['param_groups'][0]
        logging.info("dnn_ckp: lr:{} momentum: {} weight_decay:{}".format(pg_dict['lr'], pg_dict['momentum'], pg_dict['weight_decay']))
        pg = optimizerN.param_groups[0]
        logging.info("optimizerN: lr:{} momentum:{} weight_decay:{}".format(pg['lr'], pg['momentum'], pg['weight_decay']))

        logging.info("Epoch：{} val_loss：{:.4f} acc：{:.4%} prec：{:.4%} recall：{:.4%} f1：{:.4%}".format(
            dnn_ckp['epoch'], dnn_ckp['val_loss'], dnn_ckp['acc'], dnn_ckp['prec'], dnn_ckp['recall'], dnn_ckp['f1']))

        logging.info("host DNN at checkpoint" + '=' * 60)
        logging.info("Have loaded pretrained ResNet.\n")


# model
# Steganography network (replace with GoogleNet() to use the alternative algorithm)
# and the ResNet-18 host DNN, constructed in this order, then moved to `device`
# and wrapped in DataParallel.
ste_model = End_to_end_Ste()  # alternatively: GoogleNet()
Dnnet = resnet18()
ste_model = nn.DataParallel(ste_model.to(device))
Dnnet = nn.DataParallel(Dnnet.to(device))

# optimization
criterion_mse = nn.MSELoss()
criterionN = nn.CrossEntropyLoss()
# Steganography optimizer: tune these per algorithm -- the learning rate in particular
# has a dramatic effect on how well images are hidden and extracted.
optimizerH = optim.Adam(ste_model.parameters(), lr=1e-3)
schedulerH = optim.lr_scheduler.ReduceLROnPlateau(optimizerH, mode='min', factor=0.5)
# Host DNN optimizer, with hyper-parameters from the config; LR decays 10x every 20 epochs.
optimizerN = optim.SGD(
    Dnnet.parameters(),
    lr=cfg.train.lr,
    momentum=cfg.train.momentum,
    weight_decay=cfg.train.weight_decay,
)
schedulerN = optim.lr_scheduler.StepLR(optimizerN, step_size=20, gamma=0.1)


def train(epoch, train_cover_imgs, train_cover_img_labels, train_wms, train_wm_labels):
    """Run one training epoch of the steganography model and the host DNN.

    For each real batch from ``train_loader``, a cover image, secret image and their
    labels are picked cyclically (``batch_idx % len(...)``) from the lists passed in.
    ``ste_model(cover_img, secret_img)`` yields a trigger (compared with the cover by
    ``loss_hid``) and an extracted image (compared with the secret by ``loss_rev``).
    ``ste_model`` is stepped with ``optimizerH`` on the weighted ``loss_H``; then
    ``Dnnet`` is stepped with ``optimizerN`` on the real batch concatenated with the
    detached triggers.

    Returns:
        (losses_dict, metrics_dict, triggers, train_wm_labels, trigger_ext_output):
        epoch-averaged losses and metrics, CPU copies of the triggers produced in the
        final ``len(train_wm_labels)`` batches, the trigger labels as passed in, and the
        last batch's extracted images.
    """
    epoch_start_time = time.time()

    logging.info('\nTraining epoch: %d' % epoch)

    ste_model.train()
    Dnnet.train()

    num_classes = cfg.dataset.num_classes
    # Keep the triggers of the last full pass over the watermark list. len(train_loader)
    # is the real number of batches, so this stays right whether or not the loader
    # keeps an incomplete last batch.
    first_saved_batch = len(train_loader) - len(train_wm_labels)
    triggers = torch.Tensor()
    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict()

    real_acc = Accuracy(num_classes=num_classes, average='weighted')
    precision = Precision(num_classes=num_classes, average='weighted')
    recall = Recall(num_classes=num_classes, average='weighted')
    f1 = F1(num_classes=num_classes, average='weighted')
    trigger_acc = Accuracy(num_classes=num_classes, average='weighted')
    cover_acc = Accuracy(num_classes=num_classes, average='weighted')

    for batch_idx, (input, label) in enumerate(train_loader):
        step = batch_idx + 1

        input, label = input.to(device), label.to(device)
        # Split point between real samples and triggers in the concatenated batch. Use the
        # actual size: the last batch may be smaller than cfg.train.batchsize.
        n_real = input.size(0)
        cover_img = train_cover_imgs[batch_idx % len(train_cover_imgs)]
        cover_img_label = train_cover_img_labels[batch_idx % len(train_cover_img_labels)]
        secret_img = train_wms[batch_idx % len(train_wms)]
        trigger_label = train_wm_labels[batch_idx % len(train_wm_labels)]

        # ------------------------------ ste_model ------------------------------
        trigger, trigger_ext_output = ste_model(cover_img, secret_img)
        if batch_idx >= first_saved_batch:
            triggers = torch.cat([triggers, trigger.detach().cpu()], dim=0)
        # NOTE(review): the trigger is detached before entering the host DNN, so loss_dnn
        # (weighted by loss_hyper_param[2]) provides no gradient to ste_model through
        # loss_H. Confirm this is intended.
        trigger_dnn_output = Dnnet(trigger.detach())
        # loss for steganography
        loss_hid = criterion_mse(cover_img, trigger)
        loss_rev = criterion_mse(secret_img, trigger_ext_output)
        # loss for host DNN and watermarking
        loss_dnn = criterionN(trigger_dnn_output, trigger_label)
        loss_H = cfg.train.loss_hyper_param[0] * loss_hid + cfg.train.loss_hyper_param[1] * loss_rev + cfg.train.loss_hyper_param[2] * loss_dnn

        optimizerH.zero_grad()
        loss_H.backward()
        optimizerH.step()

        # -------------------------------- Dnnet --------------------------------
        # It's extremely crucial to detach gradients here in case triggers are erroneously
        # updated by back-propagation of host DNN.
        inputs = torch.cat([input, trigger.detach()], dim=0)
        labels = torch.cat([label, trigger_label], dim=0)

        dnn_cat_output = Dnnet(inputs)
        real_output = dnn_cat_output[:n_real]
        trigger_output = dnn_cat_output[n_real:]
        # Cover predictions feed only the metrics, so no autograd graph is needed.
        with torch.no_grad():
            cover_output = Dnnet(cover_img)

        loss_cat_Dnn = criterionN(dnn_cat_output, labels)
        loss_real = criterionN(real_output, label)
        loss_trigger = criterionN(trigger_output, trigger_label)

        optimizerN.zero_grad()
        loss_cat_Dnn.backward()
        optimizerN.step()

        # ------------------------------- metrics -------------------------------
        real_pred = real_output.argmax(dim=1).cpu()
        trigger_pred = trigger_output.argmax(dim=1).cpu()
        cover_pred = cover_output.argmax(dim=1).cpu()
        label_cpu = label.cpu()

        real_acc.update(real_pred, label_cpu)
        precision.update(real_pred, label_cpu)
        recall.update(real_pred, label_cpu)
        f1.update(real_pred, label_cpu)
        trigger_acc.update(trigger_pred, trigger_label.cpu())
        cover_acc.update(cover_pred, cover_img_label.cpu())

        batch_losses = {
            'loss_hid': loss_hid.item(),
            'loss_rev': loss_rev.item(),
            'loss_dnn': loss_dnn.item(),
            'loss_H': loss_H.item(),
            'loss_cat_Dnn': loss_cat_Dnn.item(),
            'loss_real': loss_real.item(),
            'loss_trigger': loss_trigger.item()
        }
        # Weight each running average by the number of samples its loss was averaged over.
        sample_counts = {'loss_cat_Dnn': inputs.size(0), 'loss_real': n_real}
        for tag, value in batch_losses.items():
            losses_dict[tag].update(value, sample_counts.get(tag, trigger.size(0)))

        if step % cfg.train.print_freq == 0 or step == (len(train_loader)):
            logging.info('[{}/{}][{}/{}] '
                'Loss_H: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f} loss_dnn: {:.4f}) '
                'Loss_cat_Dnn: {:.4f} Loss_real：{:.4f} Loss_trigger：{:.4f}'.format(epoch, cfg.train.num_epochs, step, len(train_loader),
                    losses_dict['loss_H'].avg, losses_dict['loss_hid'].avg, losses_dict['loss_rev'].avg,
                    losses_dict['loss_dnn'].avg, losses_dict['loss_cat_Dnn'].avg, losses_dict['loss_real'].avg, losses_dict['loss_trigger'].avg))

            logging.info("Real acc: {:.4%} Trigger acc: {:.4%} Cover acc: {:.4%} "
                         "Prec: {:.4%} Recall: {:.4%} F1: {:.4%}".format(
                             real_acc.compute(), trigger_acc.compute(), cover_acc.compute(), precision.compute(), recall.compute(), f1.compute()))
            logging.info('-' * 130)

    total_duration = time.time() - epoch_start_time
    logging.info('Epoch {} total duration: {:.2f} sec'.format(epoch, total_duration))
    logging.info('-' * 130)

    metrics_dict['real_acc'] = real_acc.compute()
    metrics_dict['precision'] = precision.compute()
    metrics_dict['recall'] = recall.compute()
    metrics_dict['f1'] = f1.compute()
    metrics_dict['trigger_acc'] = trigger_acc.compute()
    metrics_dict['cover_acc'] = cover_acc.compute()

    write_scalars(epoch, os.path.join(run_folder, 'train.csv'), losses_dict, metrics_dict, None, total_duration)

    save_cat_image(cfg, epoch, run_folder, cover_img, trigger, secret_img, trigger_ext_output, 'train')

    return losses_dict, metrics_dict, triggers, train_wm_labels, trigger_ext_output


def validation(epoch, val_cover_imgs, val_cover_img_labels, val_wms, val_wm_labels):
    """Evaluate ste_model and Dnnet on ``val_loader`` for one epoch without updating them.

    Uses the same cyclic pairing of cover/secret images with real batches, and the same
    losses and metrics, as ``train``, but runs under ``torch.no_grad()``. It additionally
    tracks PSNR of the triggers against the cover images (``ste_psnr``) and of the
    extracted images against the secret images (``rev_psnr``).

    Returns:
        (losses_dict, metrics_dict, img_quality_dict, triggers, val_wm_labels,
        trigger_ext_output): epoch-averaged losses, metrics and PSNRs, CPU copies of the
        triggers from the final ``len(val_wm_labels)`` batches, the trigger labels as
        passed in, and the last batch's extracted images.
    """
    epoch_start_time = time.time()
    logging.info('#' * 130)
    logging.info('Running validation for epoch {}/{}'.format(epoch, cfg.train.num_epochs))

    Dnnet.eval()
    ste_model.eval()

    num_classes = cfg.dataset.num_classes
    # Keep the triggers of the last full pass over the watermark list.
    first_saved_batch = len(val_loader) - len(val_wm_labels)
    triggers = torch.Tensor()
    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict()
    img_quality_dict = defaultdict(AverageMeter)

    real_acc = Accuracy(num_classes=num_classes, average='weighted')
    precision = Precision(num_classes=num_classes, average='weighted')
    recall = Recall(num_classes=num_classes, average='weighted')
    f1 = F1(num_classes=num_classes, average='weighted')
    trigger_acc = Accuracy(num_classes=num_classes, average='weighted')
    cover_acc = Accuracy(num_classes=num_classes, average='weighted')

    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(val_loader):

            input, label = input.to(device), label.to(device)
            # Actual number of real samples; the last batch may be smaller than cfg.train.batchsize.
            n_real = input.size(0)
            cover_img = val_cover_imgs[batch_idx % len(val_cover_imgs)]
            cover_img_label = val_cover_img_labels[batch_idx % len(val_cover_img_labels)]
            secret_img = val_wms[batch_idx % len(val_wms)]
            trigger_label = val_wm_labels[batch_idx % len(val_wm_labels)]

            # ---------------------------- ste_model ----------------------------
            trigger, trigger_ext_output = ste_model(cover_img, secret_img)
            if batch_idx >= first_saved_batch:
                triggers = torch.cat([triggers, trigger.detach().cpu()], dim=0)
            trigger_dnn_output = Dnnet(trigger)
            # loss for steganography
            loss_hid = criterion_mse(cover_img, trigger)
            loss_rev = criterion_mse(secret_img, trigger_ext_output)
            # loss for host DNN and watermarking
            loss_dnn = criterionN(trigger_dnn_output, trigger_label)
            loss_H = cfg.train.loss_hyper_param[0] * loss_hid + cfg.train.loss_hyper_param[1] * loss_rev + cfg.train.loss_hyper_param[2] * loss_dnn

            # ------------------------------ Dnnet ------------------------------
            inputs = torch.cat([input, trigger], dim=0)
            labels = torch.cat([label, trigger_label], dim=0)

            dnn_cat_output = Dnnet(inputs)
            real_output = dnn_cat_output[:n_real]
            trigger_output = dnn_cat_output[n_real:]
            cover_output = Dnnet(cover_img)

            loss_cat_Dnn = criterionN(dnn_cat_output, labels)
            loss_real = criterionN(real_output, label)
            loss_trigger = criterionN(trigger_output, trigger_label)

            # ----------------------------- metrics -----------------------------
            real_pred = real_output.argmax(dim=1).cpu()
            trigger_pred = trigger_output.argmax(dim=1).cpu()
            cover_pred = cover_output.argmax(dim=1).cpu()
            label_cpu = label.cpu()

            real_acc.update(real_pred, label_cpu)
            precision.update(real_pred, label_cpu)
            recall.update(real_pred, label_cpu)
            f1.update(real_pred, label_cpu)
            trigger_acc.update(trigger_pred, trigger_label.cpu())
            cover_acc.update(cover_pred, cover_img_label.cpu())

            batch_losses = {
                'loss_hid': loss_hid.item(),
                'loss_rev': loss_rev.item(),
                'loss_dnn': loss_dnn.item(),
                'loss_H': loss_H.item(),
                'loss_cat_Dnn': loss_cat_Dnn.item(),
                'loss_real': loss_real.item(),
                'loss_trigger': loss_trigger.item()
            }
            # Weight each running average by the number of samples its loss was averaged over.
            sample_counts = {'loss_cat_Dnn': inputs.size(0), 'loss_real': n_real}
            for tag, value in batch_losses.items():
                losses_dict[tag].update(value, sample_counts.get(tag, trigger.size(0)))

            ste_psnr = cal_psnr(cover_img, trigger)
            rev_psnr = cal_psnr(secret_img, trigger_ext_output)
            img_quality_dict['ste_psnr'].update(ste_psnr, trigger.size(0))
            img_quality_dict['rev_psnr'].update(rev_psnr, trigger_ext_output.size(0))

    logging.info(
        '[{}/{}] Loss_H: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f} loss_dnn: {:.4f}) Loss_cat_Dnn: {:.4f}  Loss_real：{:.4f} Loss_trigger：{:.4f}'.format(
            epoch, cfg.train.num_epochs,
            losses_dict['loss_H'].avg, losses_dict['loss_hid'].avg, losses_dict['loss_rev'].avg, losses_dict['loss_dnn'].avg, losses_dict['loss_cat_Dnn'].avg,
            losses_dict['loss_real'].avg, losses_dict['loss_trigger'].avg))
    # Report the epoch-averaged PSNR (the same values written to val.csv), not the last batch's.
    logging.info("Real acc: {:.4%} Trigger acc: {:.4%} Cover acc: {:.4%} "
                 "Precision: {:.4%} Recall: {:.4%} F1: {:.4%} ste_psnr: {:.2f} rev_psnr: {:.2f}".format(
                     real_acc.compute(), trigger_acc.compute(), cover_acc.compute(),
                     precision.compute(), recall.compute(), f1.compute(),
                     img_quality_dict['ste_psnr'].avg, img_quality_dict['rev_psnr'].avg))

    total_duration = time.time() - epoch_start_time
    logging.info('Epoch {} total duration: {:.2f} sec'.format(epoch, total_duration))
    logging.info('#' * 130)

    metrics_dict['real_acc'] = real_acc.compute()
    metrics_dict['precision'] = precision.compute()
    metrics_dict['recall'] = recall.compute()
    metrics_dict['f1'] = f1.compute()
    metrics_dict['trigger_acc'] = trigger_acc.compute()
    metrics_dict['cover_acc'] = cover_acc.compute()

    write_scalars(epoch, os.path.join(run_folder, 'val.csv'), losses_dict, metrics_dict, img_quality_dict, total_duration)
    save_cat_image(cfg, epoch, run_folder, cover_img, trigger, secret_img, trigger_ext_output, 'val')

    return losses_dict, metrics_dict, img_quality_dict, triggers, val_wm_labels, trigger_ext_output


def test(test_cover_imgs, test_cover_img_labels, test_wms, test_wm_labels):
    """Evaluate ste_model and Dnnet once on ``test_loader`` and persist the results.

    Computes the same losses, metrics and PSNRs as ``validation`` (under
    ``torch.no_grad()``). It then logs them, writes ``test.csv``, saves sample and
    per-trigger images, and plots confusion matrices for real and trigger predictions.
    """
    epoch_start_time = time.time()

    Dnnet.eval()
    ste_model.eval()

    num_classes = cfg.dataset.num_classes
    # Keep the triggers of the last full pass over the test watermark list.
    first_saved_batch = len(test_loader) - len(test_wm_labels)
    # Predictions/targets collected for the confusion matrices.
    real_preds, real_trues, trigger_preds, trigger_trues = [], [], [], []
    triggers = torch.Tensor()
    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict()
    img_quality_dict = defaultdict(AverageMeter)

    real_acc = Accuracy(num_classes=num_classes, average='weighted')
    precision = Precision(num_classes=num_classes, average='weighted')
    recall = Recall(num_classes=num_classes, average='weighted')
    f1 = F1(num_classes=num_classes, average='weighted')
    trigger_acc = Accuracy(num_classes=num_classes, average='weighted')
    cover_acc = Accuracy(num_classes=num_classes, average='weighted')

    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(test_loader):
            input, label = input.to(device), label.to(device)
            # Actual number of real samples; the last batch may be smaller than cfg.train.batchsize.
            n_real = input.size(0)
            cover_img = test_cover_imgs[batch_idx % len(test_cover_imgs)]
            cover_img_label = test_cover_img_labels[batch_idx % len(test_cover_img_labels)]
            secret_img = test_wms[batch_idx % len(test_wms)]
            trigger_label = test_wm_labels[batch_idx % len(test_wm_labels)]

            # ---------------------------- ste_model ----------------------------
            trigger, trigger_ext_output = ste_model(cover_img, secret_img)
            if batch_idx >= first_saved_batch:
                triggers = torch.cat([triggers, trigger.detach().cpu()], dim=0)
            trigger_dnn_output = Dnnet(trigger)
            # loss for steganography
            loss_hid = criterion_mse(cover_img, trigger)
            loss_rev = criterion_mse(secret_img, trigger_ext_output)
            # loss for host DNN and watermarking
            loss_dnn = criterionN(trigger_dnn_output, trigger_label)
            loss_H = cfg.train.loss_hyper_param[0] * loss_hid + cfg.train.loss_hyper_param[1] * loss_rev + cfg.train.loss_hyper_param[2] * loss_dnn

            # ------------------------------ Dnnet ------------------------------
            inputs = torch.cat([input, trigger], dim=0)
            labels = torch.cat([label, trigger_label], dim=0)

            dnn_cat_output = Dnnet(inputs)
            real_output = dnn_cat_output[:n_real]
            trigger_output = dnn_cat_output[n_real:]
            cover_output = Dnnet(cover_img)

            loss_cat_Dnn = criterionN(dnn_cat_output, labels)
            loss_real = criterionN(real_output, label)
            loss_trigger = criterionN(trigger_output, trigger_label)

            # ----------------------------- metrics -----------------------------
            real_pred = real_output.argmax(dim=1).cpu()
            trigger_pred = trigger_output.argmax(dim=1).cpu()
            cover_pred = cover_output.argmax(dim=1).cpu()
            label_cpu = label.cpu()
            trigger_label_cpu = trigger_label.cpu()

            real_preds.extend(real_pred.numpy())
            real_trues.extend(label_cpu.numpy())
            trigger_preds.extend(trigger_pred.numpy())
            trigger_trues.extend(trigger_label_cpu.numpy())

            real_acc.update(real_pred, label_cpu)
            precision.update(real_pred, label_cpu)
            recall.update(real_pred, label_cpu)
            f1.update(real_pred, label_cpu)
            trigger_acc.update(trigger_pred, trigger_label_cpu)
            cover_acc.update(cover_pred, cover_img_label.cpu())

            batch_losses = {
                'loss_hid': loss_hid.item(),
                'loss_rev': loss_rev.item(),
                'loss_dnn': loss_dnn.item(),
                'loss_H': loss_H.item(),
                'loss_cat_Dnn': loss_cat_Dnn.item(),
                'loss_real': loss_real.item(),
                'loss_trigger': loss_trigger.item()
            }
            # Weight each running average by the number of samples its loss was averaged over.
            sample_counts = {'loss_cat_Dnn': inputs.size(0), 'loss_real': n_real}
            for tag, value in batch_losses.items():
                losses_dict[tag].update(value, sample_counts.get(tag, trigger.size(0)))

            # rev_psnr compares the secret image with the *extracted* image (not the
            # classifier logits).
            ste_psnr = cal_psnr(cover_img, trigger)
            rev_psnr = cal_psnr(secret_img, trigger_ext_output)
            img_quality_dict['ste_psnr'].update(ste_psnr, trigger.size(0))
            img_quality_dict['rev_psnr'].update(rev_psnr, trigger_ext_output.size(0))

    logging.info(
        'Loss_H: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f} dnn: {:.4f}) '
        'Loss_cat_Dnn: {:.4f} Loss_real：{:.4f} Loss_trigger：{:.4f}'.format(
            losses_dict['loss_H'].avg, losses_dict['loss_hid'].avg, losses_dict['loss_rev'].avg, losses_dict['loss_dnn'].avg,
            losses_dict['loss_cat_Dnn'].avg, losses_dict['loss_real'].avg, losses_dict['loss_trigger'].avg))
    # One argument per placeholder; PSNRs are averaged over the whole test set.
    logging.info("Real acc: {:.4%} Trigger acc: {:.4%} Cover acc: {:.4%} "
                 "Precision: {:.4%} Recall: {:.4%} F1: {:.4%} ste_psnr: {:.2f} rev_psnr: {:.2f}".format(
                     real_acc.compute(), trigger_acc.compute(), cover_acc.compute(),
                     precision.compute(), recall.compute(), f1.compute(),
                     img_quality_dict['ste_psnr'].avg, img_quality_dict['rev_psnr'].avg))

    test_duration = time.time() - epoch_start_time
    logging.info('test duration {:.2f} sec'.format(test_duration))

    metrics_dict['real_acc'] = real_acc.compute()
    metrics_dict['precision'] = precision.compute()
    metrics_dict['recall'] = recall.compute()
    metrics_dict['f1'] = f1.compute()
    metrics_dict['trigger_acc'] = trigger_acc.compute()
    metrics_dict['cover_acc'] = cover_acc.compute()

    write_scalars(1, os.path.join(run_folder, 'test.csv'), losses_dict, metrics_dict, img_quality_dict, test_duration)
    save_cat_image(cfg, 1, run_folder, cover_img, trigger, secret_img, trigger_ext_output, 'test')
    save_separate_image(1, run_folder, triggers, test_wm_labels, trigger_ext_output, 'test')

    plot_confusion_matrix(1, run_folder, 'test_real', real_preds, real_trues, test_loader)
    plot_confusion_matrix(1, run_folder, 'test_trigger', trigger_preds, trigger_trues, None)


def main():
    """Train and validate for ``cfg.train.start_epoch..num_epochs``, then optionally test.

    Each epoch reloads cover/watermark images (to reshuffle them), runs ``train`` and
    ``validation``, steps both LR schedulers, and saves the train/val trigger sets
    whenever validation ``loss_cat_Dnn`` reaches a new minimum. Three early-stopping
    trackers (real, trigger and concatenated loss) are fed until each one fires.
    ``save_all_models`` still checkpoints every epoch. If ``cfg.test is True``, the
    checkpoints at the placeholder paths are loaded and ``test`` is run.
    """
    min_cat_loss = np.inf
    min_cat_loss_epoch = 0

    # Early stop the training according to real_loss, trigger_loss and cat_loss respectively.
    early_real_loss = E2E_EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info)
    early_trigger_loss = E2E_EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info)
    early_cat_loss = E2E_EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info)
    # whether follow the early stopping mechanism
    real_go_on, trigger_go_on, cat_go_on = True, True, True
    # DO NOT load checkpoints, since E2E-Extraction requires to train from scratch.
    # load_pretrained()

    for epoch in range(cfg.train.start_epoch, cfg.train.num_epochs + 1):
        # To shuffle the datasets, cover images and secret images should be reloaded every epoch.
        train_cover_imgs, train_cover_img_labels, train_wms, train_wm_labels = load_cover_and_wm(cfg, 'cover_train')
        val_cover_imgs, val_cover_img_labels, val_wms, val_wm_labels = load_cover_and_wm(cfg, 'cover_val')

        logging.info("\ntrain_cover_img_labels:{} \ntrain_cover_img_labels[0]:{} \ntrain_wm_labels:{} \ntrain_wm_labels[0]:{}".format(len(train_cover_img_labels), train_cover_img_labels[0], len(train_wm_labels), train_wm_labels[0]))
        logging.info("val_cover_img_labels:{} \nval_cover_img_labels[0]:{} \nval_wm_labels:{} \nval_wm_labels[0]:{}".format(len(val_cover_img_labels), val_cover_img_labels[0], len(val_wm_labels), val_wm_labels[0]))

        train_losses_dict, train_metrics_dict, train_triggers, train_trigger_labels, train_ext = train(epoch, train_cover_imgs, train_cover_img_labels, train_wms, train_wm_labels)

        val_losses_dict, val_metrics_dict, img_quality_dict, val_triggers, val_trigger_labels, val_ext = validation(epoch, val_cover_imgs, val_cover_img_labels, val_wms, val_wm_labels)
        logging.info("Learning rate: resnet:{} stegan_model:{}".format(optimizerN.param_groups[0]['lr'], optimizerH.param_groups[0]['lr']))
        plot_scalars(epoch, run_folder, train_losses_dict, train_metrics_dict, val_losses_dict, val_metrics_dict, img_quality_dict)

        logging.info("train_triggers:{} val_triggers:{}".format(len(train_triggers), len(val_triggers)))

        schedulerH.step(val_losses_dict['loss_H'].avg)
        schedulerN.step()

        # Save training sets and validation sets of triggers according to loss_cat_Dnn.
        if val_losses_dict['loss_cat_Dnn'].avg < min_cat_loss:
            min_cat_loss = val_losses_dict['loss_cat_Dnn'].avg
            min_cat_loss_epoch = epoch

            save_separate_image(epoch, run_folder, train_triggers, train_trigger_labels, train_ext, 'train')
            save_separate_image(epoch, run_folder, val_triggers, val_trigger_labels, val_ext, 'val')

        if real_go_on is True:
            early_real_loss(val_losses_dict['loss_real'].avg, epoch, run_folder,
                            ste_model, Dnnet, optimizerH, optimizerN,
                            val_losses_dict, val_metrics_dict, img_quality_dict, 'loss_real')
            if early_real_loss.early_stop:
                logging.info("The training has gained minimum real loss at {}th epoch".format(epoch))
                real_go_on = False

        if trigger_go_on is True:
            early_trigger_loss(val_losses_dict['loss_trigger'].avg, epoch, run_folder,
                            ste_model, Dnnet, optimizerH, optimizerN,
                            val_losses_dict, val_metrics_dict, img_quality_dict, 'loss_trigger')
            if early_trigger_loss.early_stop:
                logging.info("The training has gained minimum trigger loss at {}th epoch".format(epoch))
                trigger_go_on = False

        if cat_go_on is True:
            early_cat_loss(val_losses_dict['loss_cat_Dnn'].avg, epoch, run_folder,
                            ste_model, Dnnet, optimizerH, optimizerN,
                            val_losses_dict, val_metrics_dict, img_quality_dict, 'loss_cat_Dnn')
            if early_cat_loss.early_stop:
                logging.info("The training has gained minimum cat loss at {}th epoch".format(epoch))
                cat_go_on = False
        # Due to unexpected fluctuations during training, every checkpoint is expected to be saved despite the early stopping mechanism.
        save_all_models(epoch, run_folder, ste_model, Dnnet, optimizerH, optimizerN,
                        val_losses_dict, val_metrics_dict, img_quality_dict, None, None)

    logging.info("################## Finished ##################")
    # Report which epoch's trigger sets were last saved by save_separate_image above.
    logging.info("Minimum val loss_cat_Dnn: {:.4f} at epoch {}".format(min_cat_loss, min_cat_loss_epoch))

    if cfg.test is True:
        logging.info("################## Testing... ##################")

        E2E_pt_root = ''  # where you store the checkpoints
        E2E_ste_path, E2E_dnn_path = os.path.join(E2E_pt_root, ''), os.path.join(E2E_pt_root, '')  # Fill in paths of pretrained steganography and host model respectively.
        # map_location lets GPU-saved checkpoints load on any configured device.
        E2E_ste = torch.load(E2E_ste_path, map_location=device)
        E2E_dnn = torch.load(E2E_dnn_path, map_location=device)
        # model weights
        ste_model.load_state_dict(E2E_ste[''])
        Dnnet.load_state_dict(E2E_dnn[''])
        logging.info("Have loaded ste_model checkpoint from '{}'".format(E2E_pt_root))

        logging.info("E2E-Extraction models at checkpoint" + '=' * 60)
        logging.info("Loss_H: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f} loss_dnn: {:.4f}) Loss_cat_Dnn: {:.4f} Loss_real：{:.4f} Loss_trigger：{:.4f}".format(
            E2E_ste['loss_H'], E2E_ste['loss_hid'], E2E_ste['loss_rev'], E2E_ste['loss_dnn'], E2E_dnn['loss_cat_Dnn'], E2E_dnn['loss_real'], E2E_dnn['loss_trigger']))
        logging.info("Real acc: {:.4%} Trigger acc: {:.4%} Cover acc: {:.4%} Precision: {:.4%} Recall: {:.4%} F1: {:.4%} ste_psnr: {:.2f} rev_psnr: {:.2f}".format(
            E2E_dnn['real_acc'], E2E_dnn['trigger_acc'], E2E_dnn['cover_acc'], E2E_dnn['precision'], E2E_dnn['recall'], E2E_dnn['f1'], E2E_ste['ste_psnr'], E2E_ste['rev_psnr']))
        logging.info("E2E-Extraction models at checkpoint" + '=' * 60)

        test_cover_imgs, test_cover_img_labels, test_wms, test_wm_labels = load_cover_and_wm(cfg, 'cover_test')
        logging.info("test_cover_img_labels:{} test_cover_img_labels[0]:{} test_wm_labels:{} test_wm_labels[0]:{}".format(len(test_cover_img_labels), test_cover_img_labels[0], len(test_wm_labels), test_wm_labels[0]))
        test(test_cover_imgs, test_cover_img_labels, test_wms, test_wm_labels)


if __name__ == '__main__':
    # Script entry point: run training/validation (and the optional test phase) in main().
    main()
