# Fast Gradient Sign Attack on some ImageNet samples on a Trained ResNet-18

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

from src.data.ImageNet300 import ImageNet300Dataset
from src.utils.getimagenetclasses import get_classes
from src.attacks.attacks import FastGradientSign

In [None]:
# Pretrained ResNet-18, switched to inference mode so batch-norm layers use
# their running statistics.
resnet_pretrained = models.resnet18(pretrained = True)
resnet_pretrained.eval()

# NOTE(review): machine-specific absolute paths; adjust them before running
# this notebook on another machine.
root_dir = r"C:\Users\willi\Documents\in5400\mand1\prelimcode\students\data\imagenetval300imgs\imagenet300"
xmllabeldir = r"C:\Users\willi\Documents\in5400\mand1\prelimcode\students\data\imagenetval300imgs\val"
synsetfile = r"C:\Users\willi\Documents\in5400\mand1\prelimcode\students\synset_words.txt"

# Resize to a fixed 256x256, convert to a tensor, then normalise each channel
# with the mean / std values conventionally used for ImageNet-trained models.
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

dataset = ImageNet300Dataset(root_dir, xmllabeldir, synsetfile, 300, image_transforms)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 8, shuffle = False)

classes = get_classes()

# Take only the first batch. Unlike a `for ...: break` loop, this fails
# immediately (StopIteration) if the loader is empty instead of leaving
# `samp` undefined for the later cells.
samp = next(iter(dataloader))
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Build the fast-gradient-sign attack around the pretrained network and the
# cross-entropy loss. With return_logits=True the call below is unpacked into
# five values rather than three (presumably the extra two are the model
# outputs before and after the attack; confirm against src.attacks.attacks).
fgs = FastGradientSign(resnet_pretrained, loss_fn, return_logits=True)

## Attack


In [None]:
# Attack the sampled batch. The first returned value is discarded (presumably
# the perturbed images; confirm against FastGradientSign). The remaining four
# are kept as predictions and model outputs before / after the attack.
_, org_pred, new_pred, outputs, new_outputs = fgs(samp['image'], samp['label'])

## How many samples are misclassified after the attack


In [None]:
# Element-wise mask over the batch: True wherever the predicted class after
# the attack differs from the prediction before it.
missclassified = org_pred.ne(new_pred)
# Displayed as (number of changed predictions, per-sample mask).
missclassified.sum().item(), missclassified

## Choose a sample

In [None]:
# Position of the sample to inspect within the batch (the loader uses
# batch_size = 8, so a full batch has indices 0..7).
samp_idx = 4

## Softmax probability and class index of the most confident class


In [None]:
# Class probabilities of the chosen sample before the attack.
clean_probs = nn.functional.softmax(outputs, dim = 1)[samp_idx]
# Displayed as (highest probability, index of that class).
clean_probs.max(), clean_probs.argmax().item()

## Softmax probability and class index of the most confident class AFTER ATTACK


In [None]:
# Class probabilities of the chosen sample after the attack.
attacked_probs = nn.functional.softmax(new_outputs, dim = 1)[samp_idx]
# Displayed as (highest probability, index of that class).
attacked_probs.max(), attacked_probs.argmax().item()

## Softmax probability of the originally predicted class AFTER ATTACK

In [None]:
# Probability that the attacked model still assigns to the class it predicted
# for this sample before the attack.
probs_after_attack = nn.functional.softmax(new_outputs, dim = 1)[samp_idx]
previous_class = org_pred[samp_idx]
probs_after_attack[previous_class]

# Developing a Projected Gradient Descent (PGD) attack

In [None]:
class ProjectedGradientDescent(nn.Module):
    """Projected gradient descent (PGD) adversarial attack.

    Starting from the clean inputs (or a random point of the epsilon-ball
    around them), the attack repeatedly takes a step that increases
    ``loss_fn(model(x), target)`` and then projects the result back onto the
    epsilon-ball centred on the clean inputs.

    No clamping to a pixel range is applied, because the inputs may already
    be normalised; callers that need a valid image range must clamp
    themselves.

    args:
        model: network under attack (an ``nn.Module``).
        loss_fn: callable ``loss_fn(outputs, target)`` returning a scalar.
        iterations: default number of gradient steps per call.
        device: device to run on; defaults to CUDA when available, else CPU.
        epsilon: radius of the ball the perturbation is confined to.
        return_logits: if True, ``forward`` also returns the model outputs
            for the clean and the perturbed inputs.
        norm: "l2" or "inf" (case-insensitive), the norm defining the ball.
        step_size: size of each gradient step. Defaults to
            ``2.5 * epsilon / iterations``, a common heuristic that lets the
            iterate reach the boundary of the ball.
        random_init: if True, start from ``random_start(inputs)`` instead of
            the clean inputs.

    raises:
        ValueError: if ``norm`` is neither "l2" nor "inf".
    """

    def __init__(self, model, loss_fn, iterations = 100, device = None, epsilon = 0.25, return_logits = False, norm = 'inf', step_size = None, random_init = False):
        super().__init__()

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.norm = norm.lower()
        if self.norm not in ('l2', 'inf'):
            raise ValueError(f"norm must be 'l2' or 'inf', got {norm!r}")

        self.model = model
        self.loss_fn = loss_fn
        self.epsilon = epsilon
        self.return_logits = return_logits
        self.iterations = iterations
        self.step_size = step_size
        self.random_init = random_init

    @staticmethod
    def _per_sample(values, like):
        """Reshape a (batch,) tensor so it broadcasts against ``like``."""
        return values.view(like.size(0), *([1] * (like.dim() - 1)))

    def random_start(self, ball_center):
        """Return a random point inside the epsilon-ball around ``ball_center``.

        The first dimension of ``ball_center`` is treated as the batch
        dimension; every sample gets its own independent perturbation.
        """
        if self.norm == 'l2':
            batch_size = ball_center.size(0)
            # Random direction per sample, scaled to unit L2 length.
            direction = torch.randn_like(ball_center)
            direction = nn.functional.normalize(direction.flatten(1), dim = 1).reshape(ball_center.shape)
            # Radius distributed as U**(1/d) with d the per-sample dimension,
            # which gives a uniform draw over the volume of the ball.
            elements_per_sample = max(ball_center.numel() // max(batch_size, 1), 1)
            radius = torch.rand(batch_size, device = ball_center.device, dtype = ball_center.dtype)
            radius = radius ** (1.0 / elements_per_sample) * self.epsilon
            return ball_center + self._per_sample(radius, ball_center) * direction
        elif self.norm == 'inf':
            # Uniform in [-epsilon, epsilon) for every element.
            move_away = torch.rand_like(ball_center) * self.epsilon * 2 - self.epsilon
            return ball_center + move_away
        raise ValueError(f"norm must be 'l2' or 'inf', got {self.norm!r}")

    def _ascent_direction(self, grad):
        """Steepest-ascent direction of unit size under the configured norm."""
        if self.norm == 'inf':
            return grad.sign()
        # L2: normalise each sample's gradient; the floor avoids division by 0.
        grad_norm = grad.flatten(1).norm(dim = 1).clamp_min(1e-12)
        return grad / self._per_sample(grad_norm, grad)

    def _project(self, perturbed, ball_center):
        """Project ``perturbed`` onto the epsilon-ball around ``ball_center``."""
        delta = perturbed - ball_center
        if self.norm == 'inf':
            delta = delta.clamp(-self.epsilon, self.epsilon)
        else:
            # Shrink only the samples whose perturbation is longer than epsilon.
            delta_norm = delta.flatten(1).norm(dim = 1).clamp_min(1e-12)
            factor = (self.epsilon / delta_norm).clamp(max = 1.0)
            delta = delta * self._per_sample(factor, delta)
        return ball_center + delta

    def forward(self, inputs, target, iterations = None):
        """Run the attack on a batch.

        args:
            inputs: batch of inputs; the first dimension is the batch. The
                caller's tensor is not modified.
            target: labels passed to ``loss_fn`` together with the outputs.
            iterations: optional override of the number of steps for this
                call (``0`` returns the starting point unchanged).

        returns:
            ``(perturbed_images, original_preds, new_preds)``, followed by
            ``(outputs, new_outputs)`` when ``return_logits`` is True. All
            returned tensors live on ``self.device`` and are detached.
        """
        self.model.train(False)
        # Model and data must share a device; a no-op if already there.
        self.model.to(self.device)

        # Work on a detached copy so the caller's tensor is left untouched.
        clean_images = inputs.detach().to(self.device)
        target = target.to(self.device)

        num_iterations = self.iterations if iterations is None else iterations
        if self.step_size is not None:
            step = self.step_size
        else:
            step = 2.5 * self.epsilon / max(num_iterations, 1)

        # Original prediction
        with torch.no_grad():
            outputs = self.model(clean_images)
            original_preds = outputs.argmax(1)

        if self.random_init:
            perturbed_images = self.random_start(clean_images).detach()
        else:
            perturbed_images = clean_images.clone()

        for _ in range(num_iterations):
            perturbed_images.requires_grad_(True)
            loss = self.loss_fn(self.model(perturbed_images), target)
            # autograd.grad returns only the input gradient, so nothing is
            # accumulated into the model parameters' .grad fields.
            (input_grad,) = torch.autograd.grad(loss, perturbed_images)
            stepped = perturbed_images.detach() + step * self._ascent_direction(input_grad)
            perturbed_images = self._project(stepped, clean_images).detach()

        # New prediction
        with torch.no_grad():
            new_outputs = self.model(perturbed_images)
            new_preds = new_outputs.argmax(1)

        if self.return_logits:
            return perturbed_images, original_preds, new_preds, outputs, new_outputs

        return perturbed_images, original_preds, new_preds

In [None]:
# Instantiate the PGD attack on the pretrained network with the class defaults
# (100 iterations, epsilon = 0.25, 'inf' norm); return_logits=True makes the
# call return the model outputs in addition to images and predictions.
pgd = ProjectedGradientDescent(resnet_pretrained, loss_fn, return_logits=True)

In [None]:
# Run the PGD attack on the sampled batch; as the last expression of the cell,
# the returned tuple is shown as the cell output.
pgd(samp['image'], samp['label'])