# In [None]:
import os
import random
import time
import numpy as np
import torch
import logging
import sys
import warnings
import torch.nn as nn
import torch.optim as optim
from torch.backends import cudnn
from collections import defaultdict

from load_data.trigger_loader import get_loader
from algorithm.GoogleNet_decoder import GoogleNet_decoder
from utils.AverageMeter import AverageMeter
from utils.helper import *

warnings.filterwarnings("ignore")

cfg = load_config()
# Expose only the first GPU to this process. Importing torch does not
# initialise CUDA, so setting this here still takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)

# Log to both <run_folder>/run.log and stdout, then print the configuration.
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[
    logging.FileHandler(os.path.join(run_folder, 'run.log')), logging.StreamHandler(sys.stdout)])
logging.info("Experiment Configuration:")
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))

# Reproducibility: seed the host-side RNGs whenever a seed is configured,
# regardless of whether CUDA is present, so CPU runs are reproducible as well.
if cfg.train.seed is not None:
    np.random.seed(cfg.train.seed)  # NumPy RNG.
    random.seed(cfg.train.seed)  # Python `random` module RNG.
    torch.manual_seed(cfg.train.seed)  # Torch RNG.
if torch.cuda.is_available():
    if cfg.train.seed is not None:
        torch.cuda.manual_seed(cfg.train.seed)  # RNG of the current GPU.
        torch.cuda.manual_seed_all(cfg.train.seed)  # RNGs of all GPUs.
        # Deterministic cuDNN kernels. Benchmark mode autotunes and may pick
        # different algorithms between runs, so it must stay off here.
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:
        # No reproducibility requested: let cuDNN autotune for speed.
        cudnn.benchmark = True

# Only the decoder (reveal network) is needed for the spoofing attack.
# NOTE(review): GoogleNet_decoder's constructor arguments are not visible in
# this file; confirm default construction matches the trained decoder.
revealNet = nn.DataParallel(GoogleNet_decoder().to(device))
criterion_mse = nn.MSELoss()

ste_ckp_path = ''  # Fill in the path of encoder-decoder checkpoint.
if not ste_ckp_path:
    raise ValueError("ste_ckp_path is empty: set it to the encoder-decoder checkpoint path.")
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`.
ste_ckp = torch.load(ste_ckp_path, map_location=device)
# Pick out the decoder weights: every key containing the substring 's3'.
# NOTE(review): substring matching could also catch unrelated keys, and keys
# are loaded unchanged; confirm they match revealNet's parameter names
# (nn.DataParallel prefixes its parameters with 'module.').
decoder_state_dict = {k: v for k, v in ste_ckp['model_state_dict'].items() if 's3' in k}
for k in decoder_state_dict:
    logging.info(k)
revealNet.load_state_dict(decoder_state_dict)
logging.info("Have loaded decoder checkpoint from {}".format(ste_ckp_path))


trigger_loader = get_loader(cfg, 'trigger')
wm_loader = get_loader(cfg, 'watermark')

# Lookup tables keyed by the integer watermark label: the watermark image and
# its label tensor, both moved to `device`. `.item()` requires every batch
# from wm_loader to hold exactly one sample; a repeated label keeps the last.
all_wm, all_wm_label = {}, {}
for wm_batch, label_batch in wm_loader:
    key = label_batch.item()
    all_wm[key] = wm_batch.to(device)
    all_wm_label[key] = label_batch.to(device)
logging.info("all_wm_label:{}".format(all_wm_label))


def _gather_secret_images(trigger_label):
    """Stack the predefined watermark image for each trigger label, in batch order.

    Each label is looked up in the module-level ``all_wm`` dict, and the
    per-label tensors are concatenated along dim 0.

    Raises:
        KeyError: if a label has no watermark image. Skipping the label instead
            would leave the result shorter than the decoder's output batch.
    """
    secrets = []
    for label in trigger_label:
        key = label.item()
        if key not in all_wm:
            raise KeyError("No watermark image for trigger label {}".format(key))
        secrets.append(all_wm[key])
    return torch.cat(secrets)


def spoofing_attack():
    """Measure how well the pretrained decoder reveals watermarks from triggers.

    For each batch of ``trigger_loader``, the decoder output is compared with
    the predefined watermark image of every trigger label. The comparison uses
    MSE loss and PSNR via ``cal_psnr``.

    Running averages are logged every ``cfg.train.print_freq`` steps and on the
    final step. Each batch of original/extracted images is handed to
    ``save_rev``. The final metrics and elapsed wall-clock time are written to
    ``<run_folder>/spoofing_attack.csv`` through ``write_scalars``.
    """
    epoch_start_time = time.time()
    revealNet.eval()

    # Running averages of the metrics, keyed by metric name.
    losses_dict = defaultdict(AverageMeter)
    img_quality_dict = defaultdict(AverageMeter)  # PSNR
    num_batches = len(trigger_loader)

    for step, (trigger, trigger_label) in enumerate(trigger_loader, start=1):
        trigger = trigger.to(device)
        trigger_label = trigger_label.to(device)
        logging.info("trigger_label:{}".format(trigger_label))
        # The secret images hidden in the triggers are predefined per label.
        secret_img = _gather_secret_images(trigger_label)

        # Pure evaluation: building an autograd graph would only waste memory.
        with torch.no_grad():
            ex_secret = revealNet(trigger)
            loss_rev = criterion_mse(secret_img, ex_secret)

        losses_dict['loss_rev'].update(loss_rev.item(), trigger.size(0))

        ex_psnr = cal_psnr(secret_img, ex_secret)
        img_quality_dict['ex_psnr'].update(ex_psnr, ex_secret.size(0))

        # Save a batch of extracted images and corresponding original ones.
        save_rev(cfg, run_folder, secret_img, ex_secret)

        if step % cfg.train.print_freq == 0 or step == num_batches:
            logging.info('[{}/{}]loss_rev: {:.4f} ex_PSNR: {:.2f}'.format(step, len(trigger_loader), losses_dict['loss_rev'].avg, img_quality_dict['ex_psnr'].avg))

    epoch_duration = time.time() - epoch_start_time
    logging.info('Duration {:.2f} sec'.format(epoch_duration))

    write_scalars(1, os.path.join(run_folder, 'spoofing_attack.csv'), losses_dict, img_quality_dict, epoch_duration)


if __name__ == '__main__':
    spoofing_attack()