# In [None]:
import logging
import os
import random
import sys
import time
import warnings
from collections import defaultdict, OrderedDict
from itertools import cycle

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torchmetrics import Accuracy, Precision, Recall, F1

from data.load_data import *
from network.resnet import resnet18
from utils.AverageMeter import AverageMeter
from utils.helper import *
from utils.pytorchtools import tuned_EarlyStopping

# Silence every Python warning for the rest of the run.
warnings.filterwarnings("ignore")

cfg = load_config()
# Restrict this process to GPU index 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)

# Send every INFO-or-above record both to <run_folder>/run.log and to stdout,
# as the bare message text.
_log_handlers = [
    logging.FileHandler(os.path.join(run_folder, 'run.log')),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=_log_handlers)

# Record the experiment configuration at the top of the log.
logging.info("Experiment Configuration:")
logging.info(f"CUDA_VISIBLE_DEVICES：{os.getenv('CUDA_VISIBLE_DEVICES')}")
logging.info(cfg)
logging.info(f"run_folder:{run_folder}")
# Reproducibility.
# Seeding used to happen only when CUDA was available, so CPU-only runs were
# never seeded; the RNGs are now seeded whenever a seed is configured.
if cfg.train.seed is not None:
    np.random.seed(cfg.train.seed)  # NumPy global RNG.
    random.seed(cfg.train.seed)  # Python `random` module.
    torch.manual_seed(cfg.train.seed)  # PyTorch RNG (CPU and all devices).
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.train.seed)  # Every visible GPU, explicitly.
        # cudnn.benchmark=True lets cuDNN auto-tune and choose convolution
        # algorithms non-deterministically, which defeats deterministic=True.
        # Keep it off whenever a seed is requested.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Log instead of warnings.warn(): the module calls
    # warnings.filterwarnings("ignore") earlier, which silenced this notice.
    logging.warning('You have choosen to seed training. '
                    'This will turn on the CUDNN deterministic setting, '
                    'which can slow down your training considerably! '
                    'You may see unexpected behavior when restarting '
                    'from checkpoints.')
elif torch.cuda.is_available():
    # No reproducibility requested: let cuDNN benchmark algorithms for speed.
    torch.backends.cudnn.benchmark = True

# Data loaders: the clean splits first, then the matching trigger splits.
train_loader = get_loader(cfg, 'train')
val_loader = get_loader(cfg, 'val')
test_loader = get_loader(cfg, 'test')
train_trigger_loader = get_loader(cfg, 'trigger_train')
val_trigger_loader = get_loader(cfg, 'trigger_val')
test_trigger_loader = get_loader(cfg, 'trigger_test')
# Log the number of samples in each clean split's dataset.
logging.info("train_loader:{} val_loader:{} test_loader:{}\n".format(
    *(len(loader.dataset) for loader in (train_loader, val_loader, test_loader))))


def load_pretrained(ckp_path='', model_key='', optim_key='optimizer_state_dict'):
    """Restore the host network and its optimizer from a checkpoint when fine-tuning.

    Does nothing unless ``cfg.train.fine_tuning is True``. Otherwise it does
    three things with the checkpoint at *ckp_path*:

    * loads the model weights into the module-level ``Dnnet``;
    * loads the optimizer state into ``optimizerN``;
    * logs the optimizer hyper-parameters (checkpoint vs. live optimizer) and
      the epoch/validation figures stored in the checkpoint.

    Args:
        ckp_path: Path of the host-DNN checkpoint. The default '' is a
            placeholder; pass (or fill in) the real path before fine-tuning.
        model_key: Checkpoint key holding the model ``state_dict``. The
            default '' is a placeholder. TODO: set it to the key used when the
            checkpoint was written.
        optim_key: Checkpoint key holding the optimizer ``state_dict``.
            Defaults to 'optimizer_state_dict', the key this function also
            reads the logged param-group values from.

    Raises:
        ValueError: if fine-tuning is enabled but *ckp_path* is empty.
    """
    if cfg.train.fine_tuning is not True:
        return
    if not ckp_path:
        raise ValueError("cfg.train.fine_tuning is enabled but no host-DNN checkpoint path was given")

    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be restored onto the device configured for this run.
    dnn_ckp = torch.load(ckp_path, map_location=device)
    # NOTE(review): Dnnet is wrapped in nn.DataParallel, so its state_dict keys
    # carry a 'module.' prefix. Confirm the checkpoint was saved from the
    # wrapped model.
    Dnnet.load_state_dict(dnn_ckp[model_key])  # model weights
    optimizerN.load_state_dict(dnn_ckp[optim_key])  # optimizer state (lr, momentum buffers, ...)

    logging.info("At checkpoint" + '=' * 60)

    # Compare the hyper-parameters stored in the checkpoint with those now
    # active in the optimizer (load_state_dict overrides the cfg values).
    pg_dict = dnn_ckp[optim_key]['param_groups'][0]
    logging.info("dnn_ckp: lr:{} momentum: {} weight_decay:{}".format(pg_dict['lr'], pg_dict['momentum'], pg_dict['weight_decay']))
    pg = optimizerN.param_groups[0]
    logging.info("optimizerN: lr:{} momentum:{} weight_decay:{}".format(pg['lr'], pg['momentum'], pg['weight_decay']))

    logging.info("Epoch：{} val_loss：{:.4f} acc：{:.4%} prec：{:.4%} recall：{:.4%} f1：{:.4%}".format(
        dnn_ckp['epoch'], dnn_ckp['val_loss'], dnn_ckp['acc'], dnn_ckp['prec'], dnn_ckp['recall'], dnn_ckp['f1']))

    logging.info("At checkpoint" + '=' * 60)
    logging.info("Have loaded pretrained ResNet.\n")


# Host network: a ResNet-18 moved to the configured device and wrapped in
# nn.DataParallel.
Dnnet = nn.DataParallel(resnet18().to(device))

# Objective, optimiser and learning-rate schedule for the host network.
criterionN = nn.CrossEntropyLoss()
optimizerN = optim.SGD(
    Dnnet.parameters(),
    lr=cfg.train.lr,
    momentum=cfg.train.momentum,
    weight_decay=cfg.train.weight_decay,
)
# Divide the learning rate by 10 once the monitored (minimised) quantity has
# not improved for 20 consecutive scheduler steps.
schedulerN = optim.lr_scheduler.ReduceLROnPlateau(
    optimizerN, mode='min', factor=0.1, patience=20, verbose=True)


def train(epoch):
    """Train ``Dnnet`` for one epoch on clean and trigger data together.

    Each clean batch from ``train_loader`` is concatenated with the next batch
    of ``train_trigger_loader`` and sent through the network in one forward
    pass. The optimizer steps on the cross-entropy of the concatenated batch
    (``loss_cat_Dnn``). The clean-only (``loss_real``) and trigger-only
    (``loss_trigger``) losses are only tracked for reporting.

    Args:
        epoch: Current epoch number. Used for logging and as the row index
            passed to ``write_scalars``.

    Returns:
        A ``(losses_dict, metrics_dict)`` pair.
        ``losses_dict`` maps 'loss_cat_Dnn', 'loss_real' and 'loss_trigger' to
        ``AverageMeter`` objects, updated with per-batch sample counts.
        ``metrics_dict`` maps 'real_acc', 'precision', 'recall', 'f1' and
        'trigger_acc' to the torchmetrics ``compute()`` results for the epoch.
    """

    def _repeat(loader):
        """Yield batches from *loader* endlessly, re-iterating it on every pass.

        ``itertools.cycle`` replays the batches it cached during its first
        pass, so within an epoch the trigger set was never reshuffled or
        re-augmented. Restarting the loader avoids that. Stops if the loader
        yields nothing, instead of looping forever.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    start_time = time.time()

    logging.info('\nTraining epoch: %d' % epoch)

    Dnnet.train()

    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict()

    num_classes = cfg.dataset.num_classes
    real_acc = Accuracy(num_classes=num_classes, average='weighted')
    precision = Precision(num_classes=num_classes, average='weighted')
    recall = Recall(num_classes=num_classes, average='weighted')
    f1 = F1(num_classes=num_classes, average='weighted')
    trigger_acc = Accuracy(num_classes=num_classes, average='weighted')

    num_steps = len(train_loader)
    batches = zip(train_loader, _repeat(train_trigger_loader))
    for step, (real_batch, trigger_batch) in enumerate(batches, start=1):
        real_input, label = real_batch[0].to(device), real_batch[1].to(device)
        trigger, trigger_label = trigger_batch[0].to(device), trigger_batch[1].to(device)
        n_real = real_input.size(0)

        # ---------------- Dnnet: one pass over clean + trigger ----------------
        inputs = torch.cat([real_input, trigger], dim=0)
        labels = torch.cat([label, trigger_label], dim=0)

        outputs = Dnnet(inputs)
        # Split at the size of *this* clean batch, not cfg.train.batchsize.
        # With a smaller final batch, the old split pulled trigger samples into
        # the "real" slice and the loss/label shapes no longer matched.
        real_output, trigger_output = outputs[:n_real], outputs[n_real:]

        loss_cat_Dnn = criterionN(outputs, labels)
        loss_real = criterionN(real_output, label)
        loss_trigger = criterionN(trigger_output, trigger_label)

        optimizerN.zero_grad()
        loss_cat_Dnn.backward()
        optimizerN.step()

        # ---------------- metrics (CPU copies made once per batch) ----------------
        real_pred = real_output.argmax(dim=1).cpu()
        trigger_pred = trigger_output.argmax(dim=1).cpu()
        label_cpu, trigger_label_cpu = label.cpu(), trigger_label.cpu()
        for metric in (real_acc, precision, recall, f1):
            metric.update(real_pred, label_cpu)
        trigger_acc.update(trigger_pred, trigger_label_cpu)

        losses_dict['loss_cat_Dnn'].update(loss_cat_Dnn.item(), inputs.size(0))
        losses_dict['loss_real'].update(loss_real.item(), n_real)
        losses_dict['loss_trigger'].update(loss_trigger.item(), trigger.size(0))

        if step % cfg.train.print_freq == 0 or step == num_steps:
            logging.info('[{}/{}][{}/{}] '
                         'Loss_cat_Dnn: {:.4f} Loss_real：{:.4f} Loss_trigger：{:.4f}'.format(
                             epoch, cfg.train.num_epochs, step, num_steps,
                             losses_dict['loss_cat_Dnn'].avg, losses_dict['loss_real'].avg,
                             losses_dict['loss_trigger'].avg))

            logging.info("Real acc: {:.4%} Trigger acc: {:.4%} "
                         "Prec: {:.4%} Recall: {:.4%} F1: {:.4%}".format(
                             real_acc.compute(), trigger_acc.compute(),
                             precision.compute(), recall.compute(), f1.compute()))
            logging.info('-' * 130)

    total_duration = time.time() - start_time
    logging.info('Epoch {} total duration: {:.2f} sec'.format(epoch, total_duration))
    logging.info('-' * 130)

    metrics_dict['real_acc'] = real_acc.compute()
    metrics_dict['precision'] = precision.compute()
    metrics_dict['recall'] = recall.compute()
    metrics_dict['f1'] = f1.compute()
    metrics_dict['trigger_acc'] = trigger_acc.compute()

    write_scalars(epoch, os.path.join(run_folder, 'train.csv'), losses_dict, metrics_dict, None, total_duration)

    return losses_dict, metrics_dict


def validation(epoch):
    """Evaluate ``Dnnet`` on the validation clean and trigger splits.

    Mirrors ``train`` without the optimizer step. Each clean batch from
    ``val_loader`` is paired with a batch from ``val_trigger_loader`` (cycled
    with ``itertools.cycle``). Both go through the network in one forward pass
    under ``torch.no_grad()``, and losses and metrics are accumulated.

    Args:
        epoch: Current epoch number. Used for logging and as the row index
            passed to ``write_scalars``.

    Returns:
        A ``(losses_dict, metrics_dict)`` pair with the same keys as ``train``:
        'loss_cat_Dnn'/'loss_real'/'loss_trigger' ``AverageMeter`` objects, and
        'real_acc'/'precision'/'recall'/'f1'/'trigger_acc' metric values.
    """
    start_time = time.time()

    logging.info('#' * 130)
    logging.info('Running validation for epoch {}/{}'.format(epoch, cfg.train.num_epochs))

    Dnnet.eval()

    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict()

    num_classes = cfg.dataset.num_classes
    real_acc = Accuracy(num_classes=num_classes, average='weighted')
    precision = Precision(num_classes=num_classes, average='weighted')
    recall = Recall(num_classes=num_classes, average='weighted')
    f1 = F1(num_classes=num_classes, average='weighted')
    trigger_acc = Accuracy(num_classes=num_classes, average='weighted')

    with torch.no_grad():
        for real_batch, trigger_batch in zip(val_loader, cycle(val_trigger_loader)):
            real_input, label = real_batch[0].to(device), real_batch[1].to(device)
            trigger, trigger_label = trigger_batch[0].to(device), trigger_batch[1].to(device)
            n_real = real_input.size(0)

            # ---------------- Dnnet: one pass over clean + trigger ----------------
            inputs = torch.cat([real_input, trigger], dim=0)
            labels = torch.cat([label, trigger_label], dim=0)

            outputs = Dnnet(inputs)
            # Split at the size of *this* clean batch, not cfg.train.batchsize.
            # A shorter final batch would otherwise mix trigger samples into the
            # "real" slice and make the loss/label shapes disagree.
            real_output, trigger_output = outputs[:n_real], outputs[n_real:]

            loss_cat_Dnn = criterionN(outputs, labels)
            loss_real = criterionN(real_output, label)
            loss_trigger = criterionN(trigger_output, trigger_label)

            # ---------------- metrics (CPU copies made once per batch) ----------------
            real_pred = real_output.argmax(dim=1).cpu()
            trigger_pred = trigger_output.argmax(dim=1).cpu()
            label_cpu, trigger_label_cpu = label.cpu(), trigger_label.cpu()
            for metric in (real_acc, precision, recall, f1):
                metric.update(real_pred, label_cpu)
            trigger_acc.update(trigger_pred, trigger_label_cpu)

            losses_dict['loss_cat_Dnn'].update(loss_cat_Dnn.item(), inputs.size(0))
            losses_dict['loss_real'].update(loss_real.item(), n_real)
            losses_dict['loss_trigger'].update(loss_trigger.item(), trigger.size(0))

    logging.info(
        '[{}/{}] Loss_cat_Dnn: {:.4f}  Loss_real：{:.4f} Loss_trigger：{:.4f}'.format(
            epoch, cfg.train.num_epochs,
            losses_dict['loss_cat_Dnn'].avg, losses_dict['loss_real'].avg, losses_dict['loss_trigger'].avg))
    logging.info("Real acc: {:.4%} Trigger acc: {:.4%} "
                 "Precision: {:.4%} Recall: {:.4%} F1: {:.4%}".format(
                     real_acc.compute(), trigger_acc.compute(),
                     precision.compute(), recall.compute(), f1.compute()))

    total_duration = time.time() - start_time
    logging.info('Epoch {} total duration: {:.2f} sec'.format(epoch, total_duration))
    logging.info('#' * 130)

    metrics_dict['real_acc'] = real_acc.compute()
    metrics_dict['precision'] = precision.compute()
    metrics_dict['recall'] = recall.compute()
    metrics_dict['f1'] = f1.compute()
    metrics_dict['trigger_acc'] = trigger_acc.compute()

    write_scalars(epoch, os.path.join(run_folder, 'val.csv'), losses_dict, metrics_dict, None, total_duration)

    return losses_dict, metrics_dict


def test():
    """Evaluate ``Dnnet`` on the test clean and trigger splits and plot confusion matrices.

    Each clean batch from ``test_loader`` is paired with a batch from
    ``test_trigger_loader`` (cycled with ``itertools.cycle``). Both are run
    through the network together under ``torch.no_grad()``. Losses and metrics
    are written to ``<run_folder>/test.csv``. Confusion matrices for the clean
    ('test_real') and trigger ('test_trigger') predictions are plotted.

    Returns:
        A ``(losses_dict, metrics_dict)`` pair with the same keys as
        ``validation``. The original returned ``None``; callers that ignore
        the result are unaffected.
    """
    start_time = time.time()

    Dnnet.eval()

    # Per-sample predictions/labels, kept for the confusion matrices.
    real_preds, real_trues, trigger_preds, trigger_trues = [], [], [], []

    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict()

    num_classes = cfg.dataset.num_classes
    real_acc = Accuracy(num_classes=num_classes, average='weighted')
    precision = Precision(num_classes=num_classes, average='weighted')
    recall = Recall(num_classes=num_classes, average='weighted')
    f1 = F1(num_classes=num_classes, average='weighted')
    trigger_acc = Accuracy(num_classes=num_classes, average='weighted')

    with torch.no_grad():
        for real_batch, trigger_batch in zip(test_loader, cycle(test_trigger_loader)):
            real_input, label = real_batch[0].to(device), real_batch[1].to(device)
            trigger, trigger_label = trigger_batch[0].to(device), trigger_batch[1].to(device)
            n_real = real_input.size(0)

            # ---------------- Dnnet: one pass over clean + trigger ----------------
            inputs = torch.cat([real_input, trigger], dim=0)
            labels = torch.cat([label, trigger_label], dim=0)

            outputs = Dnnet(inputs)
            # Split at the size of *this* clean batch, not cfg.train.batchsize,
            # so a shorter final batch does not mix trigger samples into the
            # "real" slice.
            real_output, trigger_output = outputs[:n_real], outputs[n_real:]

            loss_cat_Dnn = criterionN(outputs, labels)
            loss_real = criterionN(real_output, label)
            loss_trigger = criterionN(trigger_output, trigger_label)

            # ---------------- metrics (CPU copies made once per batch) ----------------
            real_pred = real_output.argmax(dim=1).cpu()
            trigger_pred = trigger_output.argmax(dim=1).cpu()
            label_cpu, trigger_label_cpu = label.cpu(), trigger_label.cpu()

            real_preds.extend(real_pred.numpy())
            real_trues.extend(label_cpu.numpy())
            trigger_preds.extend(trigger_pred.numpy())
            trigger_trues.extend(trigger_label_cpu.numpy())

            for metric in (real_acc, precision, recall, f1):
                metric.update(real_pred, label_cpu)
            trigger_acc.update(trigger_pred, trigger_label_cpu)

            losses_dict['loss_cat_Dnn'].update(loss_cat_Dnn.item(), inputs.size(0))
            losses_dict['loss_real'].update(loss_real.item(), n_real)
            losses_dict['loss_trigger'].update(loss_trigger.item(), trigger.size(0))

    logging.info(
        'Loss_cat_Dnn: {:.4f} Loss_real：{:.4f} Loss_trigger：{:.4f}'.format(
            losses_dict['loss_cat_Dnn'].avg, losses_dict['loss_real'].avg, losses_dict['loss_trigger'].avg))
    # Previously precision.compute() was passed twice into this 5-slot format,
    # so "Recall" printed precision and "F1" printed recall.
    logging.info("Real acc: {:.4%} Trigger acc: {:.4%} "
                 "Precision: {:.4%} Recall: {:.4%} F1: {:.4%}".format(
                     real_acc.compute(), trigger_acc.compute(),
                     precision.compute(), recall.compute(), f1.compute()))

    test_duration = time.time() - start_time
    logging.info('test duration {:.2f} sec'.format(test_duration))

    metrics_dict['real_acc'] = real_acc.compute()
    metrics_dict['precision'] = precision.compute()
    metrics_dict['recall'] = recall.compute()
    metrics_dict['f1'] = f1.compute()
    metrics_dict['trigger_acc'] = trigger_acc.compute()

    write_scalars(1, os.path.join(run_folder, 'test.csv'), losses_dict, metrics_dict, None, test_duration)
    # confusion matrix
    plot_confusion_matrix(1, run_folder, 'test_real', real_preds, real_trues, test_loader)
    plot_confusion_matrix(1, run_folder, 'test_trigger', trigger_preds, trigger_trues, None)

    return losses_dict, metrics_dict


def _test_fine_tuned_host(ckp_root='', ckp_name='', model_key=''):
    """Load a fine-tuned host checkpoint into ``Dnnet`` and run ``test()``.

    The defaults are placeholders, as in the original script. Fill in the
    checkpoint directory (*ckp_root*), the file name (*ckp_name*) and the key
    holding the model ``state_dict`` (*model_key*).

    The previous version of this phase could not run:

    * it called ``test()`` with four arguments, but ``test`` takes none and
      evaluates on ``test_loader`` / ``test_trigger_loader``;
    * it loaded weights into ``ste_model``, which is not defined in this file.

    The steganography-model and cover/watermark loading was therefore dropped.
    """
    logging.info("################## Testing... ##################")
    ckp_path = os.path.join(ckp_root, ckp_name)
    # map_location restores onto the configured device even if the
    # checkpoint was written on a different one.
    tuned_dnn = torch.load(ckp_path, map_location=device)
    Dnnet.load_state_dict(tuned_dnn[model_key])  # model weights
    logging.info("Have loaded model checkpoints from '{}'".format(ckp_root))

    # Log whichever validation summaries the checkpoint carries. Which keys are
    # present depends on how the checkpoint was saved (save_tuned_host /
    # tuned_EarlyStopping), so missing ones are skipped instead of raising.
    logging.info("Fine-tuned model at checkpoint" + '=' * 60)
    for key in ('loss_cat_Dnn', 'loss_real', 'loss_trigger',
                'real_acc', 'trigger_acc', 'precision', 'recall', 'f1'):
        if key in tuned_dnn:
            logging.info("{}: {}".format(key, tuned_dnn[key]))
    logging.info("Fine-tuned model at checkpoint" + '=' * 60)

    test()


def main():
    """Fine-tune the host network, validating every epoch, then optionally test.

    Each epoch from ``cfg.train.start_epoch`` to ``cfg.train.num_epochs``
    (inclusive) runs these steps:

    1. train, then validate;
    2. log the learning rate and plot the curves;
    3. step the LR scheduler on the validation ``loss_cat_Dnn``;
    4. feed the three early-stopping trackers (real, trigger and concatenated
       validation loss);
    5. save a checkpoint.

    A tracker that has fired is no longer updated, but training continues to
    ``num_epochs``: every epoch's checkpoint is saved regardless.

    When ``cfg.test is True``, the fine-tuned host checkpoint is evaluated
    afterwards with ``test()``.
    """
    # Early stop the training according to real_loss, trigger_loss and
    # cat_loss respectively. Each entry is
    # (validation-loss tag, tracker, name used in the log message), kept in the
    # order the trackers are created and updated.
    trackers = [
        ('loss_real', tuned_EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info), 'real loss'),
        ('loss_trigger', tuned_EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info), 'trigger loss'),
        ('loss_cat_Dnn', tuned_EarlyStopping(patience=cfg.train.es_patience, verbose=True, trace_func=logging.info), 'cat_loss'),
    ]
    # Whether each tracker is still being fed (False once it has fired).
    still_tracking = {tag: True for tag, _, _ in trackers}

    load_pretrained()

    for epoch in range(cfg.train.start_epoch, cfg.train.num_epochs + 1):

        train_losses_dict, train_metrics_dict = train(epoch)
        val_losses_dict, val_metrics_dict = validation(epoch)

        logging.info("Learning rate of ResNet:{}".format(optimizerN.param_groups[0]['lr']))

        plot_scalars(epoch, run_folder, train_losses_dict, train_metrics_dict, val_losses_dict, val_metrics_dict, None)

        schedulerN.step(val_losses_dict['loss_cat_Dnn'].avg)

        for tag, tracker, name in trackers:
            if not still_tracking[tag]:
                continue
            tracker(val_losses_dict[tag].avg, epoch, run_folder, Dnnet, optimizerN, val_losses_dict, val_metrics_dict, tag)
            if tracker.early_stop:
                logging.info("The training has gained minimum {} at {}th epoch".format(name, epoch))
                still_tracking[tag] = False

        # Due to unexpected fluctuations during training, every checkpoint is
        # saved regardless of the early-stopping trackers.
        save_tuned_host(epoch, run_folder, Dnnet, optimizerN, val_losses_dict, val_metrics_dict)

    logging.info("################## Finished ##################")

    if cfg.test is True:
        _test_fine_tuned_host()


if __name__ == "__main__":
    # Run training (and optional testing) only when executed as a script,
    # not when this module is imported.
    main()
