# PyTorch ViT

Taken from the Hugging Face `timm` collection. It is reported to be a faithful translation of the JAX ViT implementation into PyTorch.

TODO: Verify the parity of the implementation

## Setup

Once per machine, we need to clone the appropriate repository and install the necessary dependencies.

In this case, we clone the entire `pytorch-image-models` repository.

In [35]:
# experiment params
num_train_trials = 5
num_train_warmups = 1
num_jit_trials = 10
num_jit_warmups = 2
num_inference_trials = 100
num_inference_warmups = 10
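Every timing cell in this notebook follows the same pattern: run `num_*_trials` trials, discard the first `num_*_warmups`, and average the rest. The helpers below are a hypothetical sketch (not part of `timm` or the original notebook) that makes this rule explicit; they use `time.perf_counter`, which is intended for measuring short durations.

```python
import time
from typing import Callable


def mean_after_warmup(times: list[float], num_warmups: int) -> float:
    """Average the timings that remain after discarding the first num_warmups."""
    measured = times[num_warmups:]
    if not measured:
        raise ValueError("num_trials must be larger than num_warmups")
    return sum(measured) / len(measured)


def time_trials(fn: Callable[[], object], num_trials: int, num_warmups: int) -> float:
    """Call fn num_trials times and return the mean wall-clock time (seconds)
    of the calls that follow the warm-ups."""
    times = []
    for _ in range(num_trials):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return mean_after_warmup(times, num_warmups)


print(mean_after_warmup([5.0, 1.0, 2.0, 3.0], num_warmups=1))  # 2.0
```

The slow first value (5.0) plays the role of a warm-up run and is excluded from the average.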

In [1]:
![ -d pytorch-image-models ] || git clone https://github.com/huggingface/pytorch-image-models.git

Cloning into 'pytorch-image-models'...
remote: Enumerating objects: 17961, done.
remote: Counting objects: 100% (1914/1914), done.
remote: Compressing objects: 100% (764/764), done.
remote: Total 17961 (delta 1366), reused 1420 (delta 1132), pack-reused 16047 (from 1)
Receiving objects: 100% (17961/17961), 26.53 MiB | 5.31 MiB/s, done.
Resolving deltas: 100% (13165/13165), done.


In [6]:
!pip install timm torch torchvision pytest matplotlib tqdm



## Example Usage

Using the `timm` package, we can load a pretrained ViT model and run inference with it.

In [30]:
from urllib.request import urlopen
from PIL import Image
import timm
import torch

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('vit_base_patch16_clip_384.laion2b_ft_in12k_in1k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

print(top5_probabilities)

tensor([[38.2581, 17.1601,  4.3014,  3.6252,  3.1480]],
       grad_fn=<TopkBackward0>)
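To make the post-processing step above concrete, here is the same softmax-then-top-k pattern applied to a small hand-made logits tensor (the values are made up for illustration and do not come from the model):

```python
import torch

# Hypothetical logits for one image over 4 classes (batch dimension first).
logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])

# softmax over the class dimension turns logits into probabilities;
# multiplying by 100 expresses them as percentages.
probs_percent = logits.softmax(dim=1) * 100

# topk returns the k largest values (sorted, descending) and their class indices.
top2_probabilities, top2_class_indices = torch.topk(probs_percent, k=2)
print(top2_class_indices)  # tensor([[1, 2]])
```

With the real model the indices refer to the 1000 ImageNet-1k classes, and the percentages are what `top5_probabilities` prints above.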


## Set up our environment

Because we need to time the JIT compilation of the `VisionTransformer` model itself, we cannot rely only on the high-level Hugging Face API; instead, we import the model class directly from the cloned repository to get accurate timing information.

In [12]:
import sys

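# Put the cloned repo root (not the inner timm/ package directory) first on the
# path so that `import timm` resolves to the clone. If `timm` was already imported
# in this kernel (e.g. by the cell above), restart the kernel first; otherwise the
# cached pip-installed module is used.
sys.path.insert(0, "./pytorch-image-models")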

# check that we can actually import the ViT class
from timm.models.vision_transformer import VisionTransformer
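Python resolves `import timm` by walking `sys.path` in order, and a module that has already been imported is served from the `sys.modules` cache. The stdlib-only sketch below (the package name `demo_pkg` and the temporary directory are purely illustrative stand-ins for the cloned repo) shows both effects: a directory inserted at the front of `sys.path` wins for a package that has not been imported yet, but an already-imported module such as `json` keeps resolving to its cached copy.

```python
import importlib
import importlib.util
import json  # imported up front so it is already cached in sys.modules
import sys
import tempfile
from pathlib import Path


def origin_with_path_prepended(module_name: str, search_dir: str) -> str:
    """Return the file that `import module_name` would resolve to while
    search_dir is temporarily placed at the front of sys.path."""
    sys.path.insert(0, search_dir)
    try:
        importlib.invalidate_caches()
        spec = importlib.util.find_spec(module_name)  # checks sys.modules first
        return spec.origin if spec is not None else ""
    finally:
        sys.path.remove(search_dir)


with tempfile.TemporaryDirectory() as repo_root:
    # A stand-in "cloned repo" containing a fresh package and a module named json.py.
    pkg = Path(repo_root) / "demo_pkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    (Path(repo_root) / "json.py").write_text("")

    demo_origin = origin_with_path_prepended("demo_pkg", repo_root)
    json_origin = origin_with_path_prepended("json", repo_root)

print(demo_origin.startswith(repo_root))  # found in the prepended directory
print(json_origin.startswith(repo_root))  # the cached stdlib json wins instead
```

The same reasoning applies to `timm`: checking `timm.__file__` after the import is a quick way to confirm which copy is in use.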

## Timing the ViT compilation

We perform the same measurement as for the JAX ViT: compile the model `num_jit_trials` times, discard the first `num_jit_warmups` runs as warm-ups, and average the remaining runs.

In [None]:
import tqdm
import time

# A randomly initialized ViT with timm's default config (224x224 input, 16x16 patches).
vit = VisionTransformer().eval()
dummy_input = torch.randn(1, 3, 224, 224)

def time_jit(num_jit_runs: int, num_jit_warmups: int) -> float:
    times = []
    for i in tqdm.trange(1, num_jit_runs + 1):
        # Clear torch.compile's in-memory caches so every trial recompiles from scratch.
        # (On-disk Inductor caches may still make later trials faster.)
        torch.compiler.reset()
        start_jit_time = time.perf_counter()

        # torch.compile is lazy: it only wraps the module, and compilation happens
        # on the first forward call, so we time compile + one forward pass.
        jitted = torch.compile(vit)
        with torch.no_grad():
            jitted(dummy_input)

        end_jit_time = time.perf_counter()
        # Discard the first num_jit_warmups trials.
        if i > num_jit_warmups:
            times.append(end_jit_time - start_jit_time)

    return sum(times) / len(times)

average_jit_time = time_jit(num_jit_trials, num_jit_warmups)

print(average_jit_time)





## Timing ViT Inference

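We perform the same measurement as for the JAX ViT: classify `num_inference_trials` random 384x384 images from picsum.photos one at a time, discard the first `num_inference_warmups` runs, and average the rest.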

In [38]:
import os
import PIL.Image

!mkdir -p inference

def filecount(path: str) -> int:
    """Count the regular files directly inside `path`."""
    file_count = 0
    for entry in os.scandir(path):
        if entry.is_file():
            file_count += 1
    return file_count

# Download num_inference_trials random images, unless they are already present.
if filecount("inference") < num_inference_trials:
    resolution = 384  # parity with the JAX version tested
    for index in range(1, num_inference_trials + 1):
        output = f"picsum{index}.jpg"
        !wget https://picsum.photos/$resolution -O inference/$output

--2024-09-14 02:12:36--  https://picsum.photos/384
Resolving picsum.photos (picsum.photos)... 172.67.74.163, 104.26.5.30, 104.26.4.30, ...
Connecting to picsum.photos (picsum.photos)|172.67.74.163|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://fastly.picsum.photos/id/1016/384/384.jpg?hmac=EwGRdv9asR_VdNErHYmmGv-Tin7XvxLhtej3E6Ht5x8 [following]
--2024-09-14 02:12:37--  https://fastly.picsum.photos/id/1016/384/384.jpg?hmac=EwGRdv9asR_VdNErHYmmGv-Tin7XvxLhtej3E6Ht5x8
Resolving fastly.picsum.photos (fastly.picsum.photos)... 151.101.193.91, 151.101.65.91, 151.101.1.91, ...
Connecting to fastly.picsum.photos (fastly.picsum.photos)|151.101.193.91|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 28927 (28K) [image/jpeg]
Saving to: ‘inference/picsum1.jpg’


2024-09-14 02:12:38 (4.59 MB/s) - ‘inference/picsum1.jpg’ saved [28927/28927]

--2024-09-14 02:12:38--  https://picsum.photos/384
Resolving picsum.photos (picsum.photos)... 172

In [41]:
inference_times = []
for i in tqdm.trange(1, num_inference_trials + 1):
    img = PIL.Image.open(f"inference/picsum{i}.jpg")
    start_inference = time.time()

    # Note: this runs on the CPU with a batch size of 1 and autograd enabled, and the
    # timing includes preprocessing (`transforms`); moving the model to a GPU and
    # using larger batches would likely give much higher throughput.
    output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

    top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

    end_inference = time.time()
    # Discard the first num_inference_warmups trials.
    if i > num_inference_warmups:
        inference_times.append(end_inference - start_inference)

average_inference_time = sum(inference_times) / len(inference_times)
# print(inference_times)
print(f"Average inference time: {average_inference_time}")

100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [00:15<00:00,  6.56it/s]

Average inference time: 0.15167298052046035
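The loop above times preprocessing and the forward pass together, with autograd enabled. A hypothetical variant that measures only the model call is sketched below: it assumes the inputs have already been preprocessed, disables autograd with `torch.inference_mode()`, and synchronizes before stopping the clock when the input lives on a GPU. Its numbers would not be directly comparable to the average above.

```python
import time

import torch


def time_forward_passes(model: torch.nn.Module, batches: list[torch.Tensor], num_warmups: int) -> float:
    """Mean seconds per forward pass, excluding the first num_warmups batches.

    Preprocessing is assumed to have happened already, so only the model call
    is timed; autograd bookkeeping is disabled with torch.inference_mode().
    """
    model.eval()
    times = []
    with torch.inference_mode():
        for x in batches:
            start = time.perf_counter()
            model(x)
            if x.is_cuda:
                torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
            times.append(time.perf_counter() - start)
    measured = times[num_warmups:]
    return sum(measured) / len(measured)


# Usage with a tiny stand-in model; in the notebook this would be `model` and
# the output of `transforms(img).unsqueeze(0)` for each downloaded image.
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
toy_batches = [torch.randn(1, 3, 8, 8) for _ in range(20)]
avg_forward = time_forward_passes(toy_model, toy_batches, num_warmups=5)
print(avg_forward)
```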





## Timing the fine-tuning

We will perform the same measurement as in the JAX version.

In [None]:
#TODO
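The fine-tuning measurement is not implemented yet. As one possible starting point, here is a hedged sketch of timing training steps with the same warm-up rule used elsewhere in this notebook; the optimizer (`AdamW`), the loss (cross-entropy), and the tiny stand-in model are assumptions, and the real setup should mirror whatever the JAX fine-tuning benchmark used.

```python
import time

import torch
import torch.nn.functional as F


def time_train_steps(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                     batches: list[tuple[torch.Tensor, torch.Tensor]], num_warmups: int) -> float:
    """Mean seconds per optimizer step (forward + backward + update),
    excluding the first num_warmups steps."""
    model.train()
    times = []
    for images, labels in batches:
        start = time.perf_counter()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(images), labels)
        loss.backward()
        optimizer.step()
        if images.is_cuda:
            torch.cuda.synchronize()  # make sure the GPU work is finished before timing
        times.append(time.perf_counter() - start)
    measured = times[num_warmups:]
    return sum(measured) / len(measured)


# Hypothetical stand-ins; mirrors num_train_trials = 5 and num_train_warmups = 1.
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
batches = [(torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))) for _ in range(5)]
avg_step = time_train_steps(toy_model, optimizer, batches, num_warmups=1)
print(avg_step)
```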