In [None]:
import os
import random
import time
import numpy as np
import torch
import logging
import sys
import warnings
import torch.nn as nn
import torch.optim as optim
from torch.backends import cudnn
from collections import defaultdict

from load_data.load_ste_data import *
from algorithm.GoogleNet import GoogleNet
from utils.AverageMeter import AverageMeter
from utils.helper import *

# Global setup: silence Python warnings, read the experiment config, pin the
# visible GPU, resolve the torch device and create the run output folder.
warnings.filterwarnings("ignore")

cfg = load_config()
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)

# Every log record goes both to <run_folder>/run.log and to stdout, as the bare
# message text; the configuration for this run is dumped first.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[
        logging.FileHandler(os.path.join(run_folder, 'run.log')),
        logging.StreamHandler(stream=sys.stdout),
    ],
)
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES：{}".format(os.environ.get('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))
# Reproducibility. Seeding must not depend on CUDA being present, otherwise
# CPU-only runs are silently non-reproducible. The cuDNN switches below only
# matter when CUDA is available.
if cfg.train.seed is not None:
    np.random.seed(cfg.train.seed)  # NumPy global RNG.
    random.seed(cfg.train.seed)  # Python `random` module.
    torch.manual_seed(cfg.train.seed)  # CPU RNG and the RNGs of all CUDA devices.
    torch.cuda.manual_seed_all(cfg.train.seed)  # Explicit for CUDA; silently ignored without CUDA.

if torch.cuda.is_available():
    logging.info('torch.cuda is available!')
    if cfg.train.seed is not None:
        # Deterministic cuDNN kernels only help if the benchmark autotuner is
        # off: it may select different algorithms from one run to the next.
        cudnn.deterministic = True
        cudnn.benchmark = False
        # Python warnings are globally ignored at the top of this script, so a
        # warnings.warn() here would never be displayed; log it instead.
        logging.warning('You have choosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')
    else:
        # Unseeded run: let cuDNN autotune for speed.
        cudnn.benchmark = True

# Steganography network. In generate(), ste_model(cover, secret) returns the
# pair (stego image, revealed secret). It is moved to `device` once and then
# wrapped in DataParallel.
# NOTE(review): load_state_dict on this wrapper presumably expects keys with
# DataParallel's 'module.' prefix; confirm against the saved checkpoint.
ste_model = nn.DataParallel(GoogleNet().to(device))

# Mean-squared error, used for both the hiding and the reveal losses.
criterion_mse = nn.MSELoss()

# Trigger data loaders built from the config.
# NOTE(review): secret_test_loader is only logged here; generate() takes its
# secret images from load_secret(cfg) instead. Confirm that this is intended.
cover_test_loader = get_loader(cfg, 'trigger_cover')
secret_test_loader = get_loader(cfg, 'trigger_secret')

logging.info("cover_test_loader:{}".format(len(cover_test_loader.dataset)))
logging.info("secret_test_loader:{}".format(len(secret_test_loader.dataset)))


def _concat_batches(chunks):
    """Concatenate per-batch CPU tensors along dim 0 with a single ``torch.cat``.

    A legacy empty ``torch.Tensor()`` is placed first, just as the previous
    incremental accumulator started from ``torch.Tensor()``. This keeps the
    result's dtype handling the same, and an empty ``chunks`` yields an empty
    float tensor.
    """
    return torch.cat([torch.Tensor(), *chunks], dim=0)


def generate():
    """Hide the secret trigger images in the trigger cover set and save them.

    For each batch of ``cover_test_loader``:
      - a secret image and a watermark label are taken from
        ``load_secret(cfg)``, cycling when there are fewer than batches;
      - the cover batch and the secret image are passed through ``ste_model``,
        which yields the stego images and the revealed secrets;
      - all of these are collected on the CPU.

    Hiding / reveal MSE losses and PSNRs are tracked as running averages.
    They are logged every ``cfg.train.print_freq`` batches and on the last
    batch, and written to ``<run_folder>/generate.csv`` via ``write_scalars``.
    The collected tensors are finally passed to ``save_trigger``.
    """
    start_time = time.time()
    ste_model.eval()
    num_batches = len(cover_test_loader)

    # Per-batch CPU tensors, concatenated once after the loop. Calling
    # torch.cat inside the loop would copy the whole accumulator every batch.
    cover_chunks, ste_chunks, rev_chunks, label_chunks = [], [], [], []
    # Running (size-weighted) averages of the losses and image-quality metrics.
    losses_dict = defaultdict(AverageMeter)
    img_quality_dict = defaultdict(AverageMeter)
    secret_imgs, wm_labels = load_secret(cfg)

    # Pure inference: no autograd graph is needed for any of the tensors below.
    with torch.no_grad():
        for batch_idx, data in enumerate(cover_test_loader):
            step = batch_idx + 1
            cover_img = data[0].to(device)
            # .to(device) is a no-op if load_secret already returns tensors on
            # `device`. Otherwise it keeps the MSE / PSNR below from mixing devices.
            secret_img = secret_imgs[batch_idx % len(secret_imgs)].to(device)
            trigger_label = wm_labels[batch_idx % len(wm_labels)]
            logging.info("trigger_label:{}".format(trigger_label))

            ste_img, ex_secret = ste_model(cover_img, secret_img)

            cover_chunks.append(cover_img.cpu())
            ste_chunks.append(ste_img.cpu())
            rev_chunks.append(ex_secret.cpu())
            label_chunks.append(trigger_label.cpu())

            loss_hid = criterion_mse(cover_img, ste_img)
            loss_rev = criterion_mse(secret_img, ex_secret)
            loss_ste = loss_hid + loss_rev

            # Insertion order (hid, rev, ste) is preserved for write_scalars.
            batch_size = cover_img.size(0)
            for tag, loss in (('loss_hid', loss_hid), ('loss_rev', loss_rev), ('loss_ste', loss_ste)):
                losses_dict[tag].update(loss.item(), batch_size)

            ste_psnr = cal_psnr(cover_img, ste_img)
            ex_psnr = cal_psnr(secret_img, ex_secret)
            img_quality_dict['ste_psnr'].update(ste_psnr, ste_img.size(0))
            img_quality_dict['ex_psnr'].update(ex_psnr, ex_secret.size(0))

            if step % cfg.train.print_freq == 0 or step == num_batches:
                logging.info('[{}/{}] loss_ste: {:.4f} (loss_hid: {:.4f} loss_rev: {:.4f}) ste_PSNR: {:.2f} ex_PSNR: {:.2f}'.format(
                    step, num_batches, losses_dict['loss_ste'].avg, losses_dict['loss_hid'].avg, losses_dict['loss_rev'].avg, ste_psnr, ex_psnr))

    cover_imgs = _concat_batches(cover_chunks)
    ste_imgs = _concat_batches(ste_chunks)
    rev_imgs = _concat_batches(rev_chunks)
    trigger_labels = _concat_batches(label_chunks)

    epoch_duration = time.time() - start_time
    logging.info('Duration {:.2f} sec'.format(epoch_duration))

    write_scalars(1, os.path.join(run_folder, 'generate.csv'), losses_dict, img_quality_dict, epoch_duration)
    save_trigger(run_folder, cover_imgs, ste_imgs, rev_imgs, trigger_labels, cover_test_loader)


def main(ckp_path='', state_key=''):
    """Load the trained steganography weights into ``ste_model`` and run ``generate()``.

    Args:
        ckp_path: Path to the trained GoogleNet steganography checkpoint. The
            empty default is a placeholder: fill it in here or pass it in.
        state_key: Key under which the checkpoint stores the model weights.
            When empty, the loaded object itself is used as the state dict.

    Raises:
        FileNotFoundError: If ``ckp_path`` does not name an existing file.
    """
    logging.info("################## Generating... ##################")
    if not os.path.isfile(ckp_path):
        raise FileNotFoundError(
            "Steganography checkpoint not found: '{}'. "
            "Set ckp_path to a trained GoogleNet checkpoint.".format(ckp_path))
    # map_location lets a checkpoint saved on GPU be restored on the configured
    # device, including CPU-only machines.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckp_path, map_location=device)
    state_dict = checkpoint[state_key] if state_key else checkpoint
    ste_model.load_state_dict(state_dict)
    logging.info("Have loaded steganography checkpoint from '{}'".format(ckp_path))

    generate()


if __name__ == '__main__':
    main()
