# In [None]:
import logging
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from torchmetrics import Accuracy

# Project imports keep their original relative order so that the names
# imported after the wildcard import still take precedence over it.
from utils.AverageMeter import AverageMeter
from utils.helper import *
from model.resnet import resnet18
from data.data_loader import get_loader

# Read the experiment configuration, resolve the torch device it names and
# obtain the folder for this run's outputs from create_folder().
cfg = load_config()
device = torch.device(cfg.system.device)
run_folder = create_folder(cfg.results.run_folder)

# Emit bare messages (no timestamp/level prefix) to both
# <run_folder>/run.log and stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[
        logging.FileHandler(os.path.join(run_folder, 'run.log')),
        logging.StreamHandler(sys.stdout),
    ],
)
logging.info("Experiment Configuration:")

# Seed the Python, NumPy and torch RNGs whenever a seed is configured.
# This is done independently of CUDA availability so that CPU-only runs are
# reproducible as well.
if cfg.train.seed is not None:
    np.random.seed(cfg.train.seed)  # Numpy module.
    random.seed(cfg.train.seed)  # Python random module.
    torch.manual_seed(cfg.train.seed)  # Sets the seed for generating random numbers.

if torch.cuda.is_available():
    # Disable cuDNN autotuning; the backend is referenced through its full
    # path so it does not depend on a wildcard import providing `cudnn`.
    torch.backends.cudnn.benchmark = False
    if cfg.train.seed is not None:
        torch.cuda.manual_seed(cfg.train.seed)  # Seed for the current GPU.
        torch.cuda.manual_seed_all(cfg.train.seed)  # Seed for all GPUs.
        torch.backends.cudnn.deterministic = True


# Instantiate the ResNet-18, move it to the configured device and wrap it in
# DataParallel.
model = nn.DataParallel(resnet18().to(device))
criterion = nn.CrossEntropyLoss()

# Loader for the 'trigger' split used by the evaluation below.
trigger_loader = get_loader(cfg, 'trigger')

# Record the full configuration and the output folder in the run log.
logging.info(cfg)
logging.info(run_folder)


def robustness_attack(ckp_path='', state_key=''):
    """Evaluate a saved checkpoint on the trigger set and log loss/accuracy.

    Restores the weights stored under ``state_key`` of the checkpoint at
    ``ckp_path`` into the module-level ``model``, runs it in eval mode over
    ``trigger_loader`` without gradients, and logs the average loss tracked
    by ``AverageMeter`` together with the accuracy metric.

    Args:
        ckp_path: Path of the model checkpoint. The default is an empty
            placeholder that must be filled in (here or by the caller).
        state_key: Key of the checkpoint dict holding the model weights.
            The default is an empty placeholder as well.

    Raises:
        FileNotFoundError: If ``ckp_path`` is empty or cannot be opened.
    """
    if not ckp_path:
        # Fail early with an actionable message instead of letting
        # torch.load complain about an empty file name.
        raise FileNotFoundError(
            "No checkpoint path given: pass ckp_path or fill in its default.")

    # map_location lets a checkpoint that was saved on another device
    # (e.g. a GPU) be restored on the device configured for this run.
    checkpoint = torch.load(ckp_path, map_location=device)
    model.load_state_dict(checkpoint[state_key])  # model weights
    logging.info("Have loaded test checkpoint from '{}'".format(ckp_path))

    model.eval()
    loss_avg = AverageMeter()

    acc = Accuracy(num_classes=cfg.dataset.num_classes, average='weighted')

    with torch.no_grad():
        for trigger, trigger_label in trigger_loader:

            trigger = trigger.to(device).float()
            trigger_label = trigger_label.to(device)

            output = model(trigger)
            loss = criterion(output, trigger_label)

            # The metric object is never moved off the CPU, so it is fed
            # CPU tensors. No detach is needed under no_grad.
            trigger_preds_cpu = torch.argmax(output, dim=1).cpu()
            trigger_label_cpu = trigger_label.cpu()

            # Weight the batch loss by the batch size.
            loss_avg.update(loss.item(), trigger.size(0))
            acc.update(trigger_preds_cpu, trigger_label_cpu)

        logging.info('Loss: {:.4f}\t accuracy_socre: {:.4%}\t'.format(loss_avg.avg, acc.compute()))


# Script entry point: run the trigger-set evaluation when executed directly.
if __name__ == "__main__":
    robustness_attack()
