In [1]:

from typing import Dict, List
import torch
import csv
import argparse

from perceptual_advex.utilities import add_dataset_model_arguments, \
    get_dataset_model
from perceptual_advex.attacks import *
from robustbench import load_model


In [5]:
exp = lambda i:f"/root/hhtpro/123/perceptual-advex/data/exp/exp{i}/exp{i}.ckpt.pth"
parser = argparse.ArgumentParser(
    description='Adversarial training evaluation')
args = parser.parse_args([])
args.expnum = 5
args.arch = 'resnet50'
args.parallel = 1
args.dataset = 'cifar'
args.dataset_path = '/root/hhtpro/123/CIFAR10'
args.batch_size = 50
args.num_batches = 10
args.output = f'/root/hhtpro/123/perceptual-advex/data/exp/exp{args.expnum}/evaluation.csv'

if args.expnum in [1, 2, 3, 4]:
    args.checkpoint = exp(args.expnum)
    dataset, model = get_dataset_model(args)
    _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True)
else:
    args.checkpoint = exp(1)
    dataset, model = get_dataset_model(args)
    _, val_loader = dataset.make_loaders(1, args.batch_size, only_val=True)
    if args.expnum == 5:
        print("here")
        model = load_model(model_name="Rebuffi2021Fixing_70_16_cutmix_extra", 
        dataset='cifar10', threat_model="Linf", model_dir = '/root/hhtpro/123/models')
    elif args.expnum == 6:
        model = load_model(model_name="Rebuffi2021Fixing_70_16_cutmix_extra", 
        dataset='cifar10', threat_model="L2", model_dir = '/root/hhtpro/123/models')
        
model.cuda().eval()

=> loading checkpoint '/root/hhtpro/123/perceptual-advex/data/exp/exp1/exp1.ckpt.pth'
==> Preparing dataset cifar..
Files already downloaded and verified


DMWideResNet(
  (init_conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer): Sequential(
    (0): _BlockGroup(
      (block): Sequential(
        (0): _Block(
          (batchnorm_0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu_0): Swish()
          (conv_0): Conv2d(16, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
          (batchnorm_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu_1): Swish()
          (conv_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (shortcut): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): _Block(
          (batchnorm_0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu_0): Swish()
          (conv_0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
          (batchn

In [11]:
args.attacks = [
    "NoAttack()",
    "LinfAttack(model, dataset_name='cifar', num_iterations=100)",
    "L2Attack(model, dataset_name='cifar', num_iterations=100)",
    "JPEGLinfAttack(model, dataset_name='cifar', num_iterations=100)",
    "FogAttack(model, dataset_name='cifar', num_iterations=100)",
    "StAdvAttack(model, num_iterations=100)",
    "ReColorAdvAttack(model, num_iterations=100)",
    "LagrangePerceptualAttack(model, num_iterations=40, lpips_model='alexnet')",
    # "PerceptualPGDAttack(model, num_iterations=40, lpips_model='alexnet')"
]
attack_names: List[str] = args.attacks
attacks = [eval(attack_name) for attack_name in attack_names]

# Parallelize
if torch.cuda.is_available():
    device_ids = list(range(args.parallel))
    model = nn.DataParallel(model, device_ids)
    attacks = [nn.DataParallel(attack, device_ids) for attack in attacks]

batches_correct: Dict[str, List[torch.Tensor]] = \
    {attack_name: [] for attack_name in attack_names}

for batch_index, (inputs, labels) in enumerate(val_loader):
    print(f'BATCH {batch_index:05d}')

    if (
        args.num_batches is not None and
        batch_index >= args.num_batches
    ):
        break

    if torch.cuda.is_available():
        inputs = inputs.cuda()
        labels = labels.cuda()

    for attack_name, attack in zip(attack_names, attacks):
        adv_inputs = attack(inputs, labels)
        with torch.no_grad():
            adv_logits = model(adv_inputs)
        batch_correct = (adv_logits.argmax(1) == labels).detach()

        batch_accuracy = batch_correct.float().mean().item()
        print(f'ATTACK {attack_name}',
                f'accuracy = {batch_accuracy * 100:.1f}',
                sep='\t')
        batches_correct[attack_name].append(batch_correct)




BATCH 00000
ATTACK NoAttack()	accuracy = 96.0
ATTACK LinfAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 66.0
ATTACK L2Attack(model, dataset_name='cifar', num_iterations=100)	accuracy = 36.0
ATTACK JPEGLinfAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 2.0
ATTACK FogAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 6.0




ATTACK StAdvAttack(model, num_iterations=100)	accuracy = 4.0
ATTACK ReColorAdvAttack(model, num_iterations=100)	accuracy = 78.0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


ATTACK LagrangePerceptualAttack(model, num_iterations=40, lpips_model='alexnet')	accuracy = 0.0
BATCH 00001
ATTACK NoAttack()	accuracy = 96.0
ATTACK LinfAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 62.0
ATTACK L2Attack(model, dataset_name='cifar', num_iterations=100)	accuracy = 30.0
ATTACK JPEGLinfAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 4.0
ATTACK FogAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 14.0
ATTACK StAdvAttack(model, num_iterations=100)	accuracy = 4.0
ATTACK ReColorAdvAttack(model, num_iterations=100)	accuracy = 78.0
ATTACK LagrangePerceptualAttack(model, num_iterations=40, lpips_model='alexnet')	accuracy = 0.0
BATCH 00002
ATTACK NoAttack()	accuracy = 98.0
ATTACK LinfAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 60.0
ATTACK L2Attack(model, dataset_name='cifar', num_iterations=100)	accuracy = 26.0
ATTACK JPEGLinfAttack(model, dataset_name='cifar', num_iterations=100)	accuracy = 2.0
ATTACK

In [None]:

print('OVERALL')
accuracies = []
attacks_correct: Dict[str, torch.Tensor] = {}
for attack_name in attack_names:
    attacks_correct[attack_name] = torch.cat(batches_correct[attack_name])
    accuracy = attacks_correct[attack_name].float().mean().item()
    print(f'ATTACK {attack_name}',
            f'accuracy = {accuracy * 100:.1f}',
            sep='\t')
    accuracies.append(accuracy)

with open(args.output, 'w') as out_file:
    out_csv = csv.writer(out_file)
    out_csv.writerow(attack_names)
    if args.per_example:
        for example_correct in zip(*[
            attacks_correct[attack_name] for attack_name in attack_names
        ]):
            out_csv.writerow(
                [int(attack_correct.item()) for attack_correct
                    in example_correct])
    out_csv.writerow(accuracies)
