In [34]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset, TensorDataset, DataLoader
import torchattacks 
from torchvision.models import resnet18, ResNet18_Weights
import torch.backends.cudnn as cudnn
import torchvision.transforms.functional as TF
from tqdm import tqdm
import json
import glob
from torchvision import transforms
import random
from PIL import Image
import pickle
from autoattack import AutoAttack

In [35]:
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0")

In [36]:
class RIVAL10(Dataset):    
    def __init__(self, train=True, return_masks=False, data_root=".."):

        self.train = train
        self.return_masks = return_masks
        self.instance_types = ['ordinary']

        root_data = data_root + "/{}/"
        self.data_root = root_data.format('train' if self.train else 'test')

        root_mask = data_root + "/{}/entire_object_masks/"
        self.mask_root = root_mask.format('train' if self.train else 'test')

        with open(data_root + "/meta/label_mappings.json", 'r') as f:
            self.label_mappings = json.load(f)
        with open(data_root + "/meta/wnid_to_class.json", 'r') as f:
            self.wnid_to_class = json.load(f)

        self.collect_instances()
        self.collect_images()

    def get_rival10_og_class(self, img_url):
        wnid = img_url.replace('\\', '/').split('/')[-1].split('_')[0]
        inet_class_name = self.wnid_to_class[wnid]
        classname, class_label = self.label_mappings[inet_class_name]
        return classname, class_label

    def collect_instances(self):
        self.instances_by_type = dict()
        self.all_instances = []
        for subdir in self.instance_types:
            instances = []
            dir_path = self.data_root + subdir
            for f in tqdm(glob.glob(dir_path + '/*')):
                if '.JPEG' in f and 'merged_mask' not in f:
                    img_url = f
                    label_path = f[:-5] + '_attr_labels.npy'
                    merged_mask_path = f[:-5] + '_merged_mask.JPEG'
                    mask_dict_path = f[:-5] + '_attr_dict.pkl'
                    instances.append((img_url, label_path, merged_mask_path, mask_dict_path))
            self.instances_by_type[subdir] = instances.copy()
            self.all_instances.extend(self.instances_by_type[subdir])

    def transform(self, imgs):
        transformed_imgs = []
        resize = transforms.Resize((224, 224))
        i, j, h, w = transforms.RandomResizedCrop.get_params(imgs[0], scale=(0.8, 1.0), ratio=(0.75, 1.25))
        coin_flip = (random.random() < 0.5)
        for ind, img in enumerate(imgs):
            if self.train:
                img = TF.crop(img, i, j, h, w)

                if coin_flip:
                    img = TF.hflip(img)

            img = TF.to_tensor(resize(img))

            if img.shape[0] == 1:
                img = torch.cat([img, img, img], axis=0)

            transformed_imgs.append(img)

        return transformed_imgs

    def __len__(self):
        return len(self.all_instances)

    def collect_images(self):

        self.all_images = []

        for img_url, label_path, merged_mask_path, mask_dict_path in tqdm(self.all_instances):

            class_name, class_label = self.get_rival10_og_class(img_url)

            img = Image.open(img_url)
            if img.mode == 'L':
                img = img.convert("RGB")

            imgs = [img]

            if self.return_masks:
                merged_mask_img = Image.open(merged_mask_path)
                imgs = [img, merged_mask_img]

            imgs = self.transform(imgs)

            if self.return_masks:
                self.all_images.append([imgs[0], imgs[1], class_label])
            else:
                self.all_images.append([imgs[0], class_label, img_url])

    def __getitem__(self, i):

        return self.all_images[i]

In [37]:
def load_dataset(num_samples=300, batch_size=128, rival_mask=False):
    
    trainset = RIVAL10(train=True, return_masks=rival_mask)
       
    subset_indices = list(range(num_samples))
    rival_subset = Subset(trainset, subset_indices)
    
    trainloader = torch.utils.data.DataLoader(rival_subset, batch_size=batch_size, shuffle=True, num_workers=2)

    num_classes = 10

    return trainloader, num_classes

In [38]:
def load_saved_images(org_path, tl_path):
    true_labels = []
    original_imgs = []
    with open(tl_path, 'rb') as file:
        true_labels = pickle.load(file)
    
    with open(org_path, 'rb') as file:
        original_imgs = pickle.load(file)
    
    return true_labels, original_imgs

In [39]:
def create_autoattack_adversarial_images(model, original_imgs, true_labels, batch_size=128):
    model.train()  

    adversarial_images = []
    

    auto_attack = AutoAttack(model, norm='Linf', eps=2/255, version='standard', device=device)
    
    for batch_idx, images in enumerate(original_imgs):
        labels = true_labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        images = images.to(device)
        labels = labels.to(device)

    
        adv_imgs = auto_attack.run_standard_evaluation(images, labels, bs=batch_size)
        adversarial_images.append(adv_imgs.cpu())  

    return adversarial_images

In [40]:
def create_pgd_adversarial_images(model, original_imgs, true_labels, batch_size=128, epsilon=2/255, alpha=2/255, steps=10):
    model.train()
    
    adversarial_images = []
    

    attack = torchattacks.PGD(model, eps=epsilon, alpha=alpha, steps=steps)

    for batch_idx, images in enumerate(original_imgs):
        
        labels = true_labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        images = images.to(device)
        labels = labels.to(device)
        

        adv_imgs = attack(images, labels)
        adversarial_images.append(adv_imgs.cpu()) 
    
    return adversarial_images

In [41]:
def create_cw_adversarial_images(model, original_imgs, true_labels, batch_size=128, c=1e-4, kappa=0, lr=0.01, steps=1000):
    model.train()
    
    adversarial_images = []
    

    attack = torchattacks.CW(model, c=c, kappa=kappa, steps=steps, lr=lr)
    
    for batch_idx, images in enumerate(original_imgs):
        labels = true_labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        images = images.to(device)
        labels = labels.to(device)
        

        adv_imgs = attack(images, labels)
        adversarial_images.append(adv_imgs.cpu())  
    
    return adversarial_images

In [42]:
def create_fgsm_adversarial_images(model, original_imgs, true_labels, epsilon=0.007, batch_size=128):
    model.train()  

    adversarial_images = []


    attack = torchattacks.FGSM(model, eps=epsilon)

    for batch_idx, images in enumerate(original_imgs):
        labels = true_labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        images = images.to(device)
        labels = labels.to(device)

        adv_imgs = attack(images, labels)
        adversarial_images.append(adv_imgs.cpu())  # Store on CPU to save memory

    return adversarial_images

In [43]:
def calculate_asr(model, adversarial_images, true_labels, batch_size=128):
    incorrect = 0
    total = len(true_labels)

    with torch.no_grad():
        for batch_idx, images in enumerate(adversarial_images):
            labels = true_labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            images = images.to(device)
            labels = labels.to(device)
            model.eval()
     
            with torch.no_grad():
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
            
            # print('Predicted:',predicted)
            # print('True Labels:',labels)
            # print('-------------------')
            incorrect += (predicted != labels).sum().item()
        
        asr = incorrect / total
    return asr

In [44]:
def calculate_acc(model,original_images, true_labels, batch_size=128):
    correct = 0
    total = len(true_labels)
    
    with torch.no_grad():
        for batch_idx, images in enumerate(original_images):
            labels = true_labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            images = images.to(device)
            labels = labels.to(device)
            model.eval()
     
            with torch.no_grad():
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
            
            # print('Predicted:',predicted)
            # print('True Labels:',labels)
            # print('-------------------')
            correct += (predicted == labels).sum().item()
        
        acc = correct / total
    return acc

In [45]:
torch.manual_seed(42)
cudnn.deterministic = True
cudnn.benchmark = False

In [46]:
tl_path = '..'
org_path = '..'
true_labels, original_imgs = load_saved_images(org_path, tl_path)

In [47]:
pretrained_path = '..'
model_res18 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model_res18.fc.in_features
model_res18.fc = nn.Linear(num_ftrs, 10)
model_res18.load_state_dict(torch.load(pretrained_path))
model_res18 = model_res18.to(device)

print(f'Accuracy on Clean: {calculate_acc(model_res18, original_imgs,true_labels) * 100}')

Accuracy on Clean: 97.6


In [48]:

attacks = [
        ('PGD', create_pgd_adversarial_images),
        ('FGSM', create_fgsm_adversarial_images),
        ('AutoAttack', create_autoattack_adversarial_images),
        ('CW', create_cw_adversarial_images)
    ]

attack_success_rates = []


for attack_name, attack_func in attacks:
        print(f"Running {attack_name} attack...")
        adversarial_images = attack_func(model_res18, original_imgs, true_labels)
        attack_success_rate = calculate_asr(model_res18, adversarial_images,true_labels)
        attack_success_rates.append((attack_name, attack_success_rate))
        print(f'{attack_name} Attack Success Rate = {attack_success_rate * 100:.2f}%')
        print('--------------------')


for attack_name, asr in attack_success_rates:
        print(f'{attack_name}: Attack Success Rate = {asr * 100:.2f}%')
    

Running PGD attack...


PGD Attack Success Rate = 98.00%
--------------------
Running FGSM attack...
FGSM Attack Success Rate = 75.80%
--------------------
Running AutoAttack attack...
setting parameters for standard version
using standard version including apgd-ce, apgd-t, fab-t, square.
initial accuracy: 95.31%
apgd-ce - 1/1 - 119 out of 122 successfully perturbed
robust accuracy after APGD-CE: 2.34% (total time 35.7 s)
apgd-t - 1/1 - 3 out of 3 successfully perturbed
robust accuracy after APGD-T: 0.00% (total time 37.6 s)
max Linf perturbation: 0.00784, nan in tensor: 0, max: 1.00000, min: 0.00000
robust accuracy: 0.00%
using standard version including apgd-ce, apgd-t, fab-t, square.
initial accuracy: 97.66%
apgd-ce - 1/1 - 112 out of 125 successfully perturbed
robust accuracy after APGD-CE: 10.16% (total time 36.6 s)
apgd-t - 1/1 - 12 out of 13 successfully perturbed
robust accuracy after APGD-T: 0.78% (total time 42.9 s)
fab-t - 1/1 - 1 out of 1 successfully perturbed
robust accuracy after FAB-T: 0.00% (