# In [None]:
import imageio
import torch
import sys
import numpy as np
import torchvision.transforms as transforms

from glob import glob
from algorithm.YeNet import YeNet
from load_data.pair_loader import *
from utils.helper import *

# Evaluate a trained YeNet checkpoint on the trigger/stego test set and log the
# detection accuracy. All settings are read from the object returned by
# load_config().

cfg = load_config()
device = torch.device(cfg.device)
run_folder = create_folder(cfg.run_folder)

# Reproducibility. Seed every RNG whenever a seed is configured (not only on
# CUDA machines). When seeding, keep cuDNN autotuning off: benchmark mode may
# pick non-deterministic convolution algorithms, which would defeat the seed.
seed = getattr(cfg, 'seed', None)
if seed is not None:
    np.random.seed(seed)  # NumPy RNG.
    random.seed(seed)  # Python `random` RNG.
    torch.manual_seed(seed)  # PyTorch default generator.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # Every visible GPU.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = seed is None
    if seed is not None:
        torch.backends.cudnn.deterministic = True

# Log both to <run_folder>/run.log and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[
        logging.FileHandler(os.path.join(run_folder, 'run.log')),
        logging.StreamHandler(sys.stdout),
    ],
)
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES：{}".format(os.getenv('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))

# NOTE(review): TEST_BATCH_SIZE is not used; the loader below reads
# cfg.test_batch_size instead — confirm which setting is intended.
TEST_BATCH_SIZE = cfg.batch_size
COVER_PATH = cfg.cover_path
STEGO_PATH = cfg.stego_path
CHKPT = cfg.ckp

model = YeNet().to(device)

# map_location remaps the stored tensors onto the configured device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only host.
ckpt = torch.load(CHKPT, map_location=device)
# NOTE(review): the '' key looks like a placeholder for the key the training
# script stores the weights under — confirm. If there is no '' entry, the
# checkpoint is assumed to be a bare state_dict and is loaded directly.
if isinstance(ckpt, dict) and '' in ckpt:
    state_dict = ckpt['']
else:
    state_dict = ckpt
model.load_state_dict(state_dict)  # model weights
# Switch to inference behaviour (e.g. running BatchNorm statistics, no
# dropout), in case the network contains such layers.
model.eval()

test_accuracy = []  # per-batch accuracies, in percent

test_data = Dataset_Load_trigger(STEGO_PATH, cfg.test_size,
                                 transform=transforms.Compose([
                                     ToPILImage(),
                                     ToGrayImage(),
                                     ToTensor(cfg.resize)]))

# The overall accuracy does not depend on sample order, so evaluate in a fixed
# order rather than shuffling.
test_loader = DataLoader(test_data, batch_size=cfg.test_batch_size, shuffle=False)

n_correct = 0
n_total = 0
# No gradients are needed for evaluation; skipping autograd saves memory/time.
with torch.no_grad():
    for test_batch in test_loader:
        trigger = test_batch['stego'].to(device, dtype=torch.float)
        # NOTE(review): assumes 'label' is a sequence whose first entry is the
        # per-sample label tensor — confirm against Dataset_Load_trigger.
        trigger_label = test_batch['label'][0].to(device, dtype=torch.long)

        outputs = model(trigger)
        prediction = outputs.argmax(dim=1)

        batch_correct = prediction.eq(trigger_label).sum().item()
        batch_count = trigger_label.size(0)
        n_correct += batch_correct
        n_total += batch_count
        test_accuracy.append(batch_correct * 100.0 / batch_count)

# Weight by sample count: averaging per-batch accuracies would over-weight a
# smaller final batch.
if n_total == 0:
    logging.warning("no test samples were loaded; detection accuracy is undefined")
else:
    logging.info("detection accuracy = %.2f" % (n_correct * 100.0 / n_total))