# Certification of Vision Transformers

In this notebook we will go over how to use the PyTorchSmoothedViT tool to certify vision transformers against patch attacks!

### Overview

This approach was introduced in Certified Patch Robustness via Smoothed Vision Transformers (https://arxiv.org/abs/2110.07719). The core technique is *image ablation*: the image is blanked out except for certain regions. By ablating the input in a different way each time, we obtain many predictions for a single input. Because large parts of the image are ablated, the attacker's patch is also removed from many of these predictions. Depending on the size of the adversarial patch and the size of the retained part of the image, the attacker can therefore influence only a limited number of predictions. Specifically, if the attacker has an $m \times m$ patch and the retained part of the image is a column of width $s$ (with one ablation per column position), the maximum number of predictions that can be affected is:

$$
\Delta = m + s - 1
$$
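To see where the $m + s - 1$ bound comes from, the sketch below brute-forces it: it places an $m$-pixel-wide patch in an image, enumerates one column ablation per starting column (wrapping around the right edge, an assumption about how the ablations are laid out), and counts how many retained bands overlap the patch. The function names are hypothetical, and the example values (a 32-pixel-wide image and a retained width of 4) mirror the `input_shape` and `ablation_size` used later in this notebook, assuming `ablation_size` plays the role of $s$.

```python
def affected_ablations(image_width: int, patch_size: int, column_width: int, patch_start: int) -> int:
    """Count column ablations whose retained band overlaps a patch_size-wide patch.

    Hypothetical helper: one ablation per starting column, bands wrap around the edge.
    """
    assert 0 <= patch_start and patch_start + patch_size <= image_width
    patch_cols = set(range(patch_start, patch_start + patch_size))
    count = 0
    for start in range(image_width):
        # Pixel columns kept by the ablation that starts at `start` (wrapping around)
        band = {(start + j) % image_width for j in range(column_width)}
        if band & patch_cols:
            count += 1
    return count


def max_affected(patch_size: int, column_width: int) -> int:
    """Closed-form bound on the number of affected predictions: m + s - 1."""
    return patch_size + column_width - 1


# 32-pixel-wide image, 5 x 5 patch, retained column of width 4: every patch position hits 8 ablations
for p in range(32 - 5 + 1):
    assert affected_ablations(32, 5, 4, p) == max_affected(5, 4)
print(max_affected(5, 4))  # 8
```

The brute-force count matches the closed form for every patch position as long as $m + s - 1$ does not exceed the image width.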

From this bound we can derive a simple but effective certification criterion. Suppose that, across all the ablated predictions for an image, the most frequently predicted class $c_t$ was predicted $k_t$ times and the second most frequently predicted class $c_{t-1}$ was predicted $k_{t-1}$ times. Then the prediction is certified if:

$$
k_t > k_{t-1} + 2(m + s - 1)
$$

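Intuitively, each of the at most $m + s - 1$ predictions that the patch can influence can, in the worst case, take one vote away from $c_t$ and give it to $c_{t-1}$, closing the gap between them by 2. So even if all of these predictions were changed adversarially, the model would *still* predict class $c_t$.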
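The criterion can be checked directly from the list of ablated predictions. The following is a minimal sketch with a hypothetical `certify` helper (not the PyTorchSmoothedViT API): it takes the majority vote, finds the runner-up count (0 if only one class was ever predicted), and applies the inequality above with a strict `>`, which conservatively sidesteps tie-breaking.

```python
from collections import Counter


def certify(predictions, patch_size: int, column_width: int):
    """Return (majority class, certified?) for a list of per-ablation class labels.

    Hypothetical helper illustrating k_t > k_{t-1} + 2(m + s - 1).
    """
    max_changed = patch_size + column_width - 1  # m + s - 1
    ranked = Counter(predictions).most_common()
    top_class, k_top = ranked[0]
    k_runner_up = ranked[1][1] if len(ranked) > 1 else 0
    # Each changed prediction can remove one vote from the top class and add one to the runner-up
    return top_class, k_top > k_runner_up + 2 * max_changed


# 32 ablated predictions (one per column position), m = 5, s = 4, so m + s - 1 = 8
print(certify([3] * 28 + [7] * 4, patch_size=5, column_width=4))   # (3, True):  28 > 4 + 16
print(certify([3] * 22 + [7] * 10, patch_size=5, column_width=4))  # (3, False): 22 <= 10 + 16
```

In the second case the majority vote is still class 3, but the margin is too small to rule out an attacker flipping it.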

### What's special about Vision Transformers?

The formulation above is very generic and can be applied to any neural network model; in fact, the original paper that proposed it () considered convolutional neural networks.

However, Vision Transformers (ViTs) are well suited to making predictions from image ablations, for two key reasons:

+ First, ViTs tokenize the input into a series of image regions, which are embedded and then processed by the neural network. By treating the input as a set of tokens, we can drop the tokens that correspond to fully masked (i.e. ablated) regions, significantly reducing compute costs.

+ Second, the ViT's self-attention layers share information globally at every layer, whereas convolutional neural networks build up their receptive field gradually over a series of layers. Hence, ViTs can be more effective at classifying an image from its small unablated regions.
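The compute saving from dropping tokens is easy to quantify for column ablation. The sketch below uses a hypothetical `kept_tokens` helper (not ART's internal implementation) that assumes a square image split into non-overlapping `patch_size x patch_size` tokens and a column ablation that keeps `column_width` pixel columns from `start`, wrapping around the right edge. A token is kept if any of its pixels is retained; only fully masked tokens are dropped. The retained width of 19 pixels is purely illustrative.

```python
def kept_tokens(image_size: int, patch_size: int, start: int, column_width: int):
    """Return (row, col) grid indices of ViT tokens that still contain retained pixels."""
    grid = image_size // patch_size
    kept_cols = {(start + j) % image_size for j in range(column_width)}
    # A column of tokens survives if any of its pixel columns is retained
    token_cols = sorted({c // patch_size for c in kept_cols})
    return [(r, c) for r in range(grid) for c in token_cols]


# vit_small_patch16_224: a 224 x 224 input split into 16 x 16 patches gives 14 x 14 = 196 tokens
tokens = kept_tokens(image_size=224, patch_size=16, start=0, column_width=19)
print(len(tokens), "of", (224 // 16) ** 2)  # 28 of 196
```

Here only two of the fourteen token columns contain any unablated pixels, so the transformer processes 28 tokens instead of 196.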

Let's see how to use these tools!

In [None]:
# The core tool is PyTorchSmoothedViT which can be imported as follows:
from art.estimators.certification.smoothed_vision_transformers import PyTorchSmoothedViT

In [None]:
# There are a few ways to interface with it.
# The most direct way to get set up is by specifying the name of a supported transformer.
# Behind the scenes we use the timm library (link: ), so any ViT supported by that library will work.
import torch

art_model = PyTorchSmoothedViT(model='vit_small_patch16_224',  # name of the model architecture to load
                               loss=torch.nn.CrossEntropyLoss(),  # loss function to use
                               optimizer=torch.optim.SGD,  # the optimizer class: note! it is not initialised here, we only supply the class
                               optimizer_params={"lr": 0.01},  # keyword arguments used to initialise the optimizer
                               input_shape=(3, 32, 32),  # the input shape of the data: Note! ...
                               nb_classes=10,  # number of output classes
                               ablation_size=4,  # size of the retained (unablated) region
                               load_pretrained=True)  # load pretrained weights for the architecture
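The optimizer is passed as a class because a PyTorch optimizer must be constructed from the model's parameters, which do not exist until the ViT has been built. A minimal sketch of this deferred construction in plain PyTorch follows; `build_optimizer` is a hypothetical helper and `nn.Linear` is a stand-in for the ViT, so this illustrates the pattern rather than ART's internal code.

```python
import torch


def build_optimizer(model: torch.nn.Module, optimizer_cls, optimizer_params: dict) -> torch.optim.Optimizer:
    """Hypothetical helper: instantiate the optimizer once the model and its parameters exist."""
    return optimizer_cls(model.parameters(), **optimizer_params)


# Stand-in model; in the notebook this would be the ViT created from the architecture name
model = torch.nn.Linear(8, 10)
optimizer = build_optimizer(model, torch.optim.SGD, {"lr": 0.01})
print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])  # SGD 0.01
```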

Creating a PyTorchSmoothedViT instance with the above code follows the general ART patterns, with two caveats:
+ Normally, an already initialised optimizer is passed to the estimator together with a PyTorch model. Here, however, the model has not been created yet; we only supply the name of its architecture. We therefore pass the optimizer class to PyTorchSmoothedViT, along with the keyword arguments (in optimizer_params) that you would normally use to initialise it.
+ The input shape...