In [None]:
!pip install timm

In [None]:
import os
import json
from tqdm.notebook import tnrange, tqdm
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
import cv2


import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import torch.nn as nn

from timm import create_model

# Prefer the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)

In [3]:
model_name = "vit_base_patch16_224"
# create a ViT model : https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
model = create_model(model_name, pretrained=True).to(device)
# The classifier is only attacked, never trained: switch to inference mode and freeze
# the weights so backward passes compute gradients for the patch alone.
model.eval()
model.requires_grad_(False)

In [4]:
# Preprocessing pipeline: resize to the model input size, convert to a tensor,
# then normalise each channel with the mean / std below.
IMG_SIZE = (224, 224)
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

transforms = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

In [42]:
%%capture
# ImageNet Labels
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
imagenet_labels = dict(enumerate(open('ilsvrc2012_wordnet_lemmas.txt')))
short_labels = np.empty([1000], dtype=object)
for i in range(len(imagenet_labels.values())):
    print(imagenet_labels[i])
    left_text = imagenet_labels[i].partition(",")[0]
    no_newline_label = left_text.partition("\n")[0]
    short_labels[i]=no_newline_label

# Demo Image
!wget https://github.com/hirotomusiker/schwert_colab_data_storage/blob/master/images/vit_demo/santorini.png?raw=true -O santorini.png
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/Senior_Airman_Benton_Pohlman_fires_an_M4_carbine_rifle_%2834135723246%29.jpg/1200px-Senior_Airman_Benton_Pohlman_fires_an_M4_carbine_rifle_%2834135723246%29.jpg -O assault_rifle.jpg
img = PIL.Image.open('santorini.png')
img_tensor = transforms(img).unsqueeze(0).to(device)


# #quick test
# output = model(img_tensor)
# print(imagenet_labels[int(torch.argmax(output))])

In [44]:
# Per-channel normalisation constants with shape (3, 1, 1) so they broadcast over (C, H, W) tensors.
TENSOR_MEANS = torch.FloatTensor(NORMALIZE_MEAN)[:, None, None]
TENSOR_STD = torch.FloatTensor(NORMALIZE_STD)[:, None, None]

def patch_forward(patch):
    """Map unconstrained patch values into the normalised input space.

    tanh squashes each value into (-1, 1); (tanh + 1) / 2 is then a pixel value
    in (0, 1), and subtracting TENSOR_MEANS and dividing by TENSOR_STD gives
    (tanh + 1 - 2 * mean) / (2 * std), which is what is returned.
    """
    squashed = torch.tanh(patch)
    numerator = squashed + 1 - 2 * TENSOR_MEANS
    denominator = 2 * TENSOR_STD
    return numerator / denominator

def place_patch(img, patch, coordinates=None):
    """Paste the adversarial patch into every image of a batch, in place.

    Args:
        img: 4-D batch tensor (N, C, dim2, dim3). It is modified in place and
            also returned.
        patch: raw patch tensor of shape (C, p1, p2); it is passed through
            ``patch_forward`` before being written into the images.
        coordinates: optional pair (offset along dim 2, offset along dim 3)
            used for every image. When None, an independent random position
            is drawn for each image.

    Returns:
        The same ``img`` tensor with the patch written into each image.
    """
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    # The mapped patch does not depend on the image, so compute it once.
    patch_pixels = patch_forward(patch)

    for i in range(img.shape[0]):
        if coordinates is not None:
            h_offset, w_offset = coordinates[0], coordinates[1]
        else:
            # randint's upper bound is exclusive, so "+ 1" makes every position
            # where the patch still fits reachable (the last valid offset is
            # size - patch_size) and keeps the range non-empty when the patch
            # is as large as the image.
            h_offset = np.random.randint(0, img.shape[2] - patch_h + 1)
            w_offset = np.random.randint(0, img.shape[3] - patch_w + 1)

        img[i, :, h_offset:h_offset + patch_h, w_offset:w_offset + patch_w] = patch_pixels
    return img


In [45]:
def patch_attack(input_image_path, input_torch_transforms, model, target_class, patch_size, num_epochs, learning_rate, pretrained_patch_path, batch_size=8):
    """Optimise a square adversarial patch that pushes ``model`` towards ``target_class``.

    Each "epoch" is a single SGD step on ``batch_size`` copies of the input
    image, with the patch pasted at an independent random position in every copy.

    Args:
        input_image_path: path of the image the patch is optimised on.
        input_torch_transforms: callable turning a PIL image into a (C, H, W) tensor.
        model: classifier under attack.
        target_class: index of the class the patch should trigger.
        patch_size: side length of the patch; ignored when a pretrained patch is loaded.
        num_epochs: number of optimisation steps.
        learning_rate: initial SGD learning rate.
        pretrained_patch_path: file written by ``torch.save`` to warm-start from,
            or None to start from an all-zero patch.
        batch_size: number of randomly patched copies per step (default 8).

    Returns:
        The optimised patch as an ``nn.Parameter`` of raw (pre-tanh) values.
    """
    # Create parameter and optimizer
    if pretrained_patch_path is None:
        patch = nn.Parameter(torch.zeros(3, patch_size, patch_size), requires_grad=True)
    else:
        # torch.load unpickles the file: only load patches you saved yourself.
        patch_load = torch.load(pretrained_patch_path, map_location="cpu")
        # Copy the stored values instead of torch.tensor(tensor), which warns
        # about copy-constructing from a tensor.
        patch = nn.Parameter(torch.as_tensor(patch_load).detach().clone(), requires_grad=True)

    optimizer = torch.optim.SGD([patch], lr=learning_rate, momentum=0.8)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[300, 400, 450, 500, 575], gamma=0.4)
    loss_module = nn.CrossEntropyLoss()
    losses = []

    # Run on whatever device the model lives on rather than relying on a global.
    model_device = next(model.parameters()).device

    # convert("RGB") keeps grayscale / RGBA inputs compatible with the 3-channel patch.
    input_pil_image = Image.open(input_image_path).convert("RGB")
    input_tensor = input_torch_transforms(input_pil_image)
    # Swap the two spatial axes; the final display undoes this with permute(2, 1, 0).
    input_tensor = torch.transpose(input_tensor, 2, 1)

    pred = None  # stays None when num_epochs == 0

    # Training loop
    for epoch in tnrange(num_epochs):
        # A fresh batch every step: place_patch writes into it in place.
        batch = input_tensor.unsqueeze(0).repeat_interleave(batch_size, dim=0)
        patched_batch = place_patch(batch, patch)
        pred = model(patched_batch.to(model_device))
        labels = torch.full((patched_batch.shape[0],), target_class, dtype=torch.long, device=pred.device)
        loss = loss_module(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Without this call the MultiStepLR milestones never take effect.
        scheduler.step()
        losses.append(loss.item())
        plt.plot(losses)
        plt.xlabel("epochs")
        plt.ylabel("loss")
        plt.show()
        clear_output(wait=True)

    # Show the patch as an image: tanh squashing mapped to [0, 1] pixel values.
    patch_temp = patch.detach().cpu().permute(2, 1, 0)
    patch_temp = ((torch.tanh(patch_temp) + 1) / 2).clamp(0.0, 1.0)
    plt.imshow(patch_temp.numpy())
    plt.grid(False)
    plt.show()
    if pred is not None:
        print("op/ on target class <beta>", pred[0][target_class])
    return patch

In [None]:
plt.rcParams["figure.figsize"] = (5,5)
# Warm-start from an earlier run's patch only when that checkpoint exists; on a
# fresh runtime the file has not been written yet, so start from scratch.
pretrained_patch_file = "patch_class605_size40.pt"
final_patch_tensor = patch_attack(input_image_path='assault_rifle.jpg',
                                  input_torch_transforms = transforms,
                                  model = model, 
                                  target_class = 605,
                                  patch_size = 40,
                                  num_epochs = 5,
                                  learning_rate=2e-1,
                                  pretrained_patch_path = pretrained_patch_file if os.path.exists(pretrained_patch_file) else None)

In [None]:
# Persist the optimised patch under the file name the attack cell loads from.
torch.save(obj=final_patch_tensor, f="patch_class605_size40.pt")

In [47]:
def NormalizeData(data):
    """Min-max scale ``data`` so its smallest value becomes 0 and its largest 1.

    Args:
        data: array-like of numbers (must be non-empty).

    Returns:
        A NumPy array with the same shape as ``data``. Constant input, where
        max == min, yields all zeros instead of NaNs from a division by zero.
    """
    data = np.asarray(data)
    low = np.min(data)
    span = np.max(data) - low
    if span == 0:
        return np.zeros_like(data, dtype=float)
    return (data - low) / span

def attack_image(input_tensor, patch_tensor, coordinates=None):
    """Return one image with the patch pasted in.

    The image is wrapped into a batch of size one, handed to ``place_patch``
    together with ``coordinates``, and the single patched image is returned.
    """
    # repeat_interleave builds a new tensor, so place_patch's in-place write
    # does not touch the caller's input_tensor.
    single_image_batch = input_tensor.unsqueeze(0).repeat_interleave(1, dim=0)
    return place_patch(single_image_batch, patch_tensor, coordinates)[0]

In [48]:
def top_5_predictions(input_image_path, resize_input_image, patch_tensor, patch_coordinates, model, label_names, plot_save_path):
    """Plot an image and its patched version next to the model's top-5 predictions.

    Args:
        input_image_path: path of the image to classify.
        resize_input_image: when true, resize the image to 224x224 first.
        patch_tensor: raw adversarial patch handed to ``attack_image``.
        patch_coordinates: pair of offsets forwarded to ``attack_image``, or
            None to let the patch be placed at a random position.
        model: classifier used for both predictions.
        label_names: NumPy array mapping class index to display name.
        plot_save_path: where to save the figure; falsy to skip saving.

    Returns:
        None. A 2x2 figure is created (and saved when ``plot_save_path`` is set).
    """
    # convert("RGB") keeps grayscale / RGBA inputs compatible with the 3-channel patch.
    input_pil_image = Image.open(input_image_path).convert("RGB")
    if resize_input_image:
        # Image.CUBIC was removed in Pillow 10; BICUBIC names the same filter.
        input_pil_image = input_pil_image.resize((224,224), Image.BICUBIC)
    # NOTE(review): the image is fed to the model as plain ToTensor output in [0, 1],
    # while the patch is written in normalised space by patch_forward and patch_attack
    # is called with transforms that include T.Normalize. Confirm whether the input
    # should be normalised here as well.
    input_tensor = T.ToTensor()(input_pil_image)
    # Swap the two spatial axes; permute(2, 1, 0) below undoes this for display.
    input_tensor = torch.transpose(input_tensor, 2, 1)

    fig, ax = plt.subplots(nrows=2, ncols=2)
    ax[0][0].imshow(input_tensor.permute(2,1,0).numpy())
    ax[0][0].set_xlabel('Input', size=13)
    ax[0][0].axis('on')

    # detach() so plotting also works when the caller passes a patch that requires grad.
    attacked = attack_image(input_tensor, patch_tensor, patch_coordinates).detach()

    if patch_coordinates is None:
        # A random placement has no fixed coordinates to report.
        patch_position = "random position"
    else:
        patch_position = "x="+str(patch_coordinates[0])+" y="+str(patch_coordinates[1])
    ax[1][0].imshow(attacked.permute(2,1,0).numpy())
    ax[1][0].set_xlabel('Adversarial Input'+" [patch ("+patch_position+")]", size=13)
    ax[1][0].axis('on')

    # Use the model's own device so this also runs on CPU-only machines
    # (torch.cuda.FloatTensor fails without CUDA).
    model_device = next(model.parameters()).device
    input_1 = input_tensor.unsqueeze(0).float().to(model_device)
    input_2 = attacked.unsqueeze(0).float().to(model_device)

    with torch.no_grad():
        results_1 = torch.softmax(model(input_1), dim=-1).cpu()[0]
        results_2 = torch.softmax(model(input_2), dim=-1).cpu()[0]
        values_1, indices_1 = results_1.topk(5)
        values_2, indices_2 = results_2.topk(5)

    names_1, names_2 = label_names[indices_1.numpy()], label_names[indices_2.numpy()]

    # Both bar charts share one layout; only the data and the colour of the
    # top-1 bar (drawn last, hence at the top) differ.
    panels = (
        (ax[0][1], values_1, names_1, 'green'),
        (ax[1][1], values_2, names_2, 'red'),
    )
    for axis, values, names, top_color in panels:
        axis.barh(list(range(5)), torch.flip(values*100, [0]).numpy(), color=['navy', 'navy', 'navy', 'navy', top_color])
        axis.set_xlim([0, 100])
        axis.set_yticks(list(range(5)))
        axis.set_yticklabels(names[::-1], rotation=0, size=13)
        axis.set_xlabel('Confidence', size=13)
        axis.grid()

    if plot_save_path:
        plt.savefig(plot_save_path, bbox_inches='tight')


In [None]:
# Wide figure: image panels on the left, confidence bars on the right.
plt.rcParams.update({"figure.figsize": (19,10)})
top_5_predictions(
    input_image_path='assault_rifle.jpg',
    resize_input_image=True,
    patch_tensor=final_patch_tensor.detach(),
    patch_coordinates=(184-32, 184-32),
    model=model,
    label_names=short_labels,
    plot_save_path="foo2.png",
)

In [50]:
!mkdir test11

In [52]:
%%capture
n=0
for y in range(16, 224-32, 16):
    clear_output(wait=True)
    for x in range(16, 224-32, 16):
        top_5_predictions(input_image_path = 'assault_rifle.jpg', 
                  resize_input_image = True, 
                  patch_tensor = final_patch_tensor.detach(), 
                  patch_coordinates = (x, y), # set None for random coordinates
                  model = model, 
                  label_names = short_labels,
                  plot_save_path="/content/test11/"+str(n)+".png"
        )
        n+=1

In [None]:
# Recursively pack everything under /content/test11 into /content/animation9.zip.
!zip -r /content/animation9.zip /content/test11