# Fast Gradient Sign and Projected Gradient Descent Attacks on ImageNet Samples Against a Pretrained ResNet-18

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms

from src.data.ImageNet300 import ImageNet300Dataset
from src.utils.getimagenetclasses import get_classes
from src.attacks.attacks import FastGradientSign, ProjectedGradientDescent
from src.explainability.GradCam import GradCam
from src.utils.ImageDisplayerGradCam import ImageDisplayerGradCam

In [None]:
# Load an ImageNet-pretrained ResNet-18 and switch it to inference mode
# (eval() makes BatchNorm use its stored running statistics).
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=...`; kept here for compatibility with older torchvision versions.
resnet_pretrained = models.resnet18(pretrained = True)
resnet_pretrained.eval()

# Machine-specific data locations; edit these to point at your local copy.
root_dir = r"C:\Users\willi\Documents\in5400\mand1\prelimcode\students\data\imagenetval300imgs\imagenet300"
xmllabeldir = r"C:\Users\willi\Documents\in5400\mand1\prelimcode\students\data\imagenetval300imgs\val"
synsetfile = r"C:\Users\willi\Documents\in5400\mand1\prelimcode\students\synset_words.txt"

# Resize to a fixed 256x256, convert to a tensor, and normalise with the
# standard ImageNet per-channel mean / std.
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# The argument 300 presumably limits the number of images loaded -- confirm
# against ImageNet300Dataset.
dataset = ImageNet300Dataset(root_dir, xmllabeldir, synsetfile, 300, image_transforms)
# shuffle=False keeps the first batch (and therefore what samp_idx refers to)
# identical across runs.
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 8, shuffle = False)

classes = get_classes()

# Grab only the first batch. next(iter(...)) raises StopIteration right here on
# an empty dataset, instead of silently leaving `samp` undefined for later cells.
samp = next(iter(dataloader))

loss_fn = nn.CrossEntropyLoss()

# Fast Gradient Sign Method

In [None]:
# FGSM attacker on the pretrained ResNet-18, driven by the cross-entropy loss;
# return_logits=True makes the call also return the model outputs.
fgs = FastGradientSign(
    resnet_pretrained,
    loss_fn,
    return_logits=True,
)

## Attack


In [None]:
# Run FGSM on the first batch. Unpacked results: adversarial images, clean
# predictions, adversarial predictions, clean outputs, adversarial outputs.
(perturbed_images,
 org_pred,
 new_pred,
 outputs,
 new_outputs) = fgs(samp['image'], samp['label'])

## How many predictions the attack changed, and how many were already misclassified


In [None]:
# Boolean mask of samples whose predicted class changed under the FGSM attack,
# displayed together with the number of such samples.
new_classification = org_pred.ne(new_pred)
new_classification.sum().item(), new_classification

In [None]:
# Boolean mask of clean-image predictions that disagree with the ground-truth
# labels, displayed together with its count.
missclassified = org_pred.ne(samp['label'])
missclassified.sum().item(), missclassified

## Choose a sample

In [None]:
# Index (within the first batch of 8) of the sample inspected in the cells below.
samp_idx = 2

## Softmax probability of the most confident class (before the attack):


In [None]:
# Probability (softmax of the logits, not the logits themselves) and class index
# of the most confident prediction for sample `samp_idx` on the clean batch.
# The softmax is computed once and reused for both values.
clean_probs = F.softmax(outputs, dim = 1)[samp_idx]
clean_probs.max(), clean_probs.argmax().item()

## Softmax probability of the most confident class AFTER the attack


In [None]:
# Probability and class index of the most confident prediction for sample
# `samp_idx` after the attack; the softmax is computed once and reused.
adv_probs = F.softmax(new_outputs, dim = 1)[samp_idx]
adv_probs.max(), adv_probs.argmax().item()

## Softmax probability of the originally predicted class AFTER the attack

In [None]:
# Softmax row for the chosen sample after the attack, looked up at the class
# the model predicted before the attack.
adv_row = F.softmax(new_outputs, dim = 1)[samp_idx]
adv_row[org_pred[samp_idx]]

## Grad-CAM:

In [None]:
classes = get_classes()

# Grad-CAM target: the second convolution of the last block in layer4.
target_layer = resnet_pretrained.layer4[-1].conv2
# NOTE(review): meaning of the positional argument 10 is not visible here -- confirm in GradCam.
cam = GradCam(resnet_pretrained, target_layer, 10, multi_label = False)

# Displayer used by the cells below (the misspelled name is kept because
# later cells refer to it).
image_dispalyer = ImageDisplayerGradCam(
    resnet_pretrained,
    cam,
    classes,
    reshape = transforms.Resize((256, 256)),
    multi_label = False,
    image_dir = 'image_net_dir',
    pdf = False,
)

### Perturbed

In [None]:
# The chosen FGSM-perturbed image with its ground-truth label and file name,
# in the dict layout that display_images consumes.
perturbed_sample = {
    'image': perturbed_images[samp_idx],
    'label': samp['label'][samp_idx],
    'filename': samp['filename'][samp_idx],
}

In [None]:
# Show the perturbed sample twice: once with display_labels_or_predictions=True
# and once with False (presumably ground-truth labels vs. model predictions --
# confirm in ImageDisplayerGradCam).
image_dispalyer.display_images(perturbed_sample, display_labels_or_predictions = True)
image_dispalyer.display_images(perturbed_sample, display_labels_or_predictions = False)

### Non-perturbed:

In [None]:
# The same sample without perturbation, for side-by-side comparison.
normal_sample = {
    'image': samp['image'][samp_idx],
    'label': samp['label'][samp_idx],
    'filename': samp['filename'][samp_idx],
}

In [None]:
# Show the clean sample with display_labels_or_predictions=True and then False
# (presumably labels vs. predictions -- confirm in ImageDisplayerGradCam).
image_dispalyer.display_images(normal_sample, display_labels_or_predictions = True)
image_dispalyer.display_images(normal_sample, display_labels_or_predictions = False)

# Projected Gradient Descent Attack

In [None]:
# PGD attacker on the same model and loss; return_logits=True makes the call
# also return the model outputs.
pgd = ProjectedGradientDescent(
    resnet_pretrained,
    loss_fn,
    return_logits=True,
)

In [None]:
# Run PGD on the first batch. The PGD predictions are stored as original_preds /
# new_preds, separate from the FGSM org_pred / new_pred.
(perturbed_images,
 original_preds,
 new_preds,
 outputs,
 new_outputs) = pgd(samp['image'], samp['label'])

In [None]:
# Mask (and count) of samples whose prediction changed under the PGD attack.
# Must use the PGD results (original_preds / new_preds); org_pred / new_pred
# hold the FGSM results from the earlier section.
new_classification = (original_preds != new_preds)
torch.sum(new_classification).item(), new_classification

In [None]:
# Mask (and count) of clean-image predictions, as returned by the PGD call,
# that disagree with the ground-truth labels.
missclassified = (original_preds != samp['label'])
torch.sum(missclassified).item(), missclassified

In [None]:
# Index (within the first batch of 8) of the sample inspected in the PGD cells below.
samp_idx = 0

## Softmax probability of the most confident class (before the attack):


In [None]:
# Probability (softmax of the logits, not the logits themselves) and class index
# of the most confident prediction for sample `samp_idx` on the clean batch.
# The softmax is computed once and reused for both values.
clean_probs = F.softmax(outputs, dim = 1)[samp_idx]
clean_probs.max(), clean_probs.argmax().item()

## Softmax probability of the most confident class AFTER the attack


In [None]:
# Probability and class index of the most confident prediction for sample
# `samp_idx` after the PGD attack; the softmax is computed once and reused.
adv_probs = F.softmax(new_outputs, dim = 1)[samp_idx]
adv_probs.max(), adv_probs.argmax().item()

## Softmax probability of the originally predicted class AFTER the attack

In [None]:
# Softmax row for the chosen sample after PGD, looked up at the class the
# model predicted before the attack.
adv_row = F.softmax(new_outputs, dim = 1)[samp_idx]
adv_row[original_preds[samp_idx]]

## Grad-CAM

In [None]:
classes = get_classes()

# Grad-CAM target: the second convolution of the last block in layer4.
target_layer = resnet_pretrained.layer4[-1].conv2
# NOTE(review): meaning of the positional argument 10 is not visible here -- confirm in GradCam.
cam = GradCam(resnet_pretrained, target_layer, 10, multi_label = False)

# Displayer used by the cells below (the misspelled name is kept because
# later cells refer to it).
image_dispalyer = ImageDisplayerGradCam(
    resnet_pretrained,
    cam,
    classes,
    reshape = transforms.Resize((256, 256)),
    multi_label = False,
    image_dir = 'image_net_dir',
    pdf = False,
)

## Perturbed

In [None]:
# The chosen PGD-perturbed image with its ground-truth label and file name,
# in the dict layout that display_images consumes.
perturbed_sample = {
    'image': perturbed_images[samp_idx],
    'label': samp['label'][samp_idx],
    'filename': samp['filename'][samp_idx],
}

In [None]:
# Show the PGD-perturbed sample with display_labels_or_predictions=True and then
# False (presumably labels vs. predictions -- confirm in ImageDisplayerGradCam).
image_dispalyer.display_images(perturbed_sample, display_labels_or_predictions = True)
image_dispalyer.display_images(perturbed_sample, display_labels_or_predictions = False)

## Non-perturbed

In [None]:
# The same sample without perturbation, for side-by-side comparison.
normal_sample = {
    'image': samp['image'][samp_idx],
    'label': samp['label'][samp_idx],
    'filename': samp['filename'][samp_idx],
}

In [None]:
# Show the clean sample with display_labels_or_predictions=True and then False
# (presumably labels vs. predictions -- confirm in ImageDisplayerGradCam).
image_dispalyer.display_images(normal_sample, display_labels_or_predictions = True)
image_dispalyer.display_images(normal_sample, display_labels_or_predictions = False)