In [1]:
import torch
import torch.nn.functional as F
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
from torchvision import transforms
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
!export CUDA_VISIBLE_DEVICES=1

In [12]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [16]:
# Load model
model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).eval()
model = model.to(device)
# load data. data downlaoded from https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000 
path_to_imagenet = "./imagenet-mini-val"
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])
dataset = datasets.ImageFolder(path_to_imagenet, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False) # batch_size changed to 1 to test pgd


In [14]:
# test classifier 

all_preds = []
all_labels = []

with torch.no_grad():
    for _, batch in enumerate(dataloader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        
        logits = model(inputs)
        _, preds = torch.max(logits, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
accuracy = accuracy_score(all_labels, all_preds)
print(accuracy)

0.7705837369360183


In [26]:
epsilon = 8/255
alpha = 2/255
max_steps = 40

all_preds = []
all_labels = []
all_targets = []
all_step = []
all_adverse_preds = []


for _, batch in enumerate(dataloader):
    inputs, labels = batch
    inputs, labels = inputs.to(device), labels.to(device)
    
    logits = model(inputs)
    preds = torch.argmax(logits)
    all_preds.append(preds.cpu().numpy())
    all_labels.append(labels.cpu().numpy()[0])

    target_tensor = torch.randint(0, 1000, labels.shape, device=device)
    all_targets.append(target_tensor.cpu().numpy()[0])
    # print(labels.cpu().numpy()[0], preds.cpu().numpy(), target_tensor.cpu().numpy()[0])
    
    adverse_image = inputs.clone().detach()
    found=False
    for step in range(max_steps):
        adverse_image.requires_grad = True

        adverse_logits = model(adverse_image)
        adverse_class = torch.argmax(adverse_logits).item()
        
        if adverse_class == target_tensor.cpu().numpy()[0]:
            all_step.append(step)
            all_adverse_preds.append(adverse_class)
            found = True
            # print(step, adverse_class, target_tensor.cpu().numpy()[0])
            break

        adverse_loss = -F.cross_entropy(adverse_logits, target_tensor)
        grad = torch.autograd.grad(adverse_loss, adverse_image, retain_graph=False, create_graph=False)[0]

        adverse_image = adverse_image + alpha * grad.sign()
        adverse_image = torch.max(torch.min(adverse_image, inputs + epsilon), inputs - epsilon)
        adverse_image = torch.clamp(adverse_image, 0, 1).detach()

    if not found:
        all_step.append(max_steps)
        all_adverse_preds.append(adverse_class)

print(accuracy_score(all_labels, all_preds))
print(accuracy_score(all_targets, all_adverse_preds))

0.7705837369360183
1.0


In [27]:
# mean steps 
import numpy as np
from collections import Counter

counts = Counter(all_step)
print(counts)
all_step = np.asarray(all_step)
np.mean(all_step), np.std(all_step), np.max(all_step), np.min(all_step)


Counter({2: 1744, 3: 1398, 4: 398, 1: 285, 5: 76, 6: 16, 0: 4, 7: 1, 8: 1})


(2.561814937547795, 0.8822837660324482, 8, 0)

In [None]:
# epsilon = 8/255
# alpha = 2/255
# Counter({2: 1744, 3: 1398, 4: 398, 1: 285, 5: 76, 6: 16, 0: 4, 7: 1, 8: 1})
# (2.561814937547795, 0.8822837660324482, 8, 0)

# epsilon = 4/255
# alpha = 1/255
# Counter({3: 1633, 2: 1325, 4: 608, 1: 171, 5: 142, 6: 22, 7: 12, 0: 4, 8: 4, 10: 1, 40: 1})
# (2.844761661993372, 1.1328366391138112, 40, 0)

# epsilon = 1/255
# alpha = 0.5/255
# Counter({3: 927, 4: 749, 2: 651, 5: 463, 6: 282, 7: 158, 40: 146, 8: 118, 1: 97, 9: 69, 10: 60, 11: 40, 14: 21, 12: 20, 17: 15, 13: 15, 16: 12, 15: 12, 18: 10, 23: 7, 19: 6, 22: 5, 30: 4, 21: 4, 28: 4, 29: 4, 24: 4, 34: 3, 20: 3, 35: 2, 25: 2, 32: 2, 0: 2, 26: 2, 33: 1, 38: 1, 31: 1, 27: 1})
# (6.043843996941116, 7.6253140214820725, 40, 0)