# Description

In this notebook, I load and learn the architecture/weight of ViT model

In [1]:
import os 
import json
import time 
from PIL import Image
import torch
from torchvision import transforms

from pytorch_pretrained_vit.transformer_quantization import CustomLinear
from pytorch_pretrained_vit import ViT_Quantized
from pytorch_pretrained_vit.transformer_quantization import MultiHeadedSelfAttention

In [2]:
# open txt file containing imagenet classes
with open('imagenet_1k_name.txt') as f:
    imagenet_classes = [line.strip() for line in f.readlines()]

In [3]:
PATH_MODEL = "/scratch/tnguyen10/L_16_imagenet1k.pth"
MODEL_NAME = "L_16_imagenet1k"

model = ViT_Quantized(name=MODEL_NAME, weights_path=PATH_MODEL, pretrained=True)
model = model.cuda()
model.eval()

print(f"[INFO] Done loading model.")

  state_dict = torch.load(weights_path)


Loaded pretrained weights.
[INFO] Done loading model.


In [4]:
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open("dog.png").convert("RGB"))
img = img.cuda()
img = img.unsqueeze(0)

with torch.no_grad():
    outputs = model(img)
outputs = outputs.squeeze(0)
predicted_class = imagenet_classes[outputs.argmax(0)]
print("Predicted class:", predicted_class)

predicted_prob = torch.nn.functional.softmax(outputs, dim=0)
print("Predicted probability:", predicted_prob.max().item())

Predicted class: golden retriever
Predicted probability: 0.9262540936470032


In [5]:
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open("car.png").convert("RGB"))
img = img.unsqueeze(0)
img = img.cuda()

# Measure inference time
# Warm-up
for _ in range(5):
    with torch.no_grad():
        _ = model(img)

n_iter = 50
start_time = time.time()
for _ in range(n_iter):
    with torch.no_grad():
        outputs = model(img)
    outputs = outputs.squeeze(0)
    predicted_class = imagenet_classes[outputs.argmax(0)]
end_time = time.time()
avg_inference_time = (end_time - start_time) / n_iter * 1000  # in milliseconds
print(f"Average inference time over {n_iter} iterations: {avg_inference_time:.2f} ms")
print("Predicted class:", predicted_class)

predicted_prob = torch.nn.functional.softmax(outputs, dim=0)
print("Predicted probability:", predicted_prob.max().item())

Average inference time over 50 iterations: 26.25 ms
Predicted class: sports car
Predicted probability: 0.9769683480262756


# 1. Quantize model

In [6]:
for name, module in model.named_modules():
    if isinstance(module, CustomLinear):
        print(f"\nQuantize weights module: {name} ...")
        module.quantize_weights()
        print(f"Quantizate done : {name}.")


Quantize weights module: transformer.blocks.0.attn.proj_q ...
[INFO] Done quantize linear layer weights - shape torch.Size([1024, 1024])
Quantizate done : transformer.blocks.0.attn.proj_q.

Quantize weights module: transformer.blocks.0.attn.proj_k ...
[INFO] Done quantize linear layer weights - shape torch.Size([1024, 1024])
Quantizate done : transformer.blocks.0.attn.proj_k.

Quantize weights module: transformer.blocks.0.attn.proj_v ...
[INFO] Done quantize linear layer weights - shape torch.Size([1024, 1024])
Quantizate done : transformer.blocks.0.attn.proj_v.

Quantize weights module: transformer.blocks.0.pwff.fc1 ...
[INFO] Done quantize linear layer weights - shape torch.Size([4096, 1024])
Quantizate done : transformer.blocks.0.pwff.fc1.

Quantize weights module: transformer.blocks.0.pwff.fc2 ...
[INFO] Done quantize linear layer weights - shape torch.Size([1024, 4096])
Quantizate done : transformer.blocks.0.pwff.fc2.

Quantize weights module: transformer.blocks.1.attn.proj_q ...

In [7]:
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open("dog.png").convert("RGB"))
img = img.unsqueeze(0)
img = img.cuda()

with torch.no_grad():
    outputs = model(img)
outputs = outputs.squeeze(0)
predicted_class = imagenet_classes[outputs.argmax(0)]
print("Predicted class:", predicted_class)

predicted_prob = torch.nn.functional.softmax(outputs, dim=0)
print("Predicted probability:", predicted_prob.max().item())

Predicted class: golden retriever
Predicted probability: 0.9246541857719421


In [8]:
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open("car.png").convert("RGB"))
img = img.unsqueeze(0)
img = img.cuda()

# Mesure inference time
# Warm-up
for _ in range(5):
    with torch.no_grad():
        _ = model(img)

n_iter = 50
start_time = time.time()
for _ in range(n_iter):
    with torch.no_grad():
        outputs = model(img)
    outputs = outputs.squeeze(0)
    predicted_class = imagenet_classes[outputs.argmax(0)]
end_time = time.time()
avg_inference_time = (end_time - start_time) / n_iter * 1000  # in milliseconds
print(f"Average inference time over {n_iter} iterations: {avg_inference_time:.2f} ms")
print("Predicted class:", predicted_class)

predicted_prob = torch.nn.functional.softmax(outputs, dim=0)
print("Predicted probability:", predicted_prob.max().item())

Average inference time over 50 iterations: 12.69 ms
Predicted class: sports car
Predicted probability: 0.9715847969055176


# 2. Evaluate Image net

In [9]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
img_size = 384  # or 384 for your ViT
bs = 32

model = model.to(device).eval()

tfm = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),   # <— ensure 3 channels
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])
def batch_iter(stream, batch_size):
    imgs, labels = [], []
    for row in stream:
        img = row["image"]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        imgs.append(tfm(img))
        labels.append(row["label"])
        if len(imgs) == batch_size:
            yield torch.stack(imgs), torch.tensor(labels, dtype=torch.long)
            imgs, labels = [], []
    if imgs:  # tail
        yield torch.stack(imgs), torch.tensor(labels, dtype=torch.long)

# Stream the validation split
ds_val = load_dataset("ILSVRC/imagenet-1k", split="validation", streaming=True)  # requires HF login + license
top1_hits = 0.0
top5_hits = 0.0
total = 0

idx = 0
max_idx = 20
with torch.no_grad():
    for idx, (xb, yb) in enumerate(batch_iter(ds_val, bs)):
        if idx > max_idx:
            break
        
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)

        # Top-1/Top-5
        maxk = 5
        _, pred = logits.topk(maxk, 1, True, True)  # [B,5]
        correct = pred.eq(yb.view(-1,1).expand_as(pred))
        top1_hits += correct[:, :1].reshape(-1).float().sum().item()
        top5_hits += correct[:, :5].reshape(-1).float().sum().item()
        total += yb.size(0)

print(f"✅ ImageNet-1k (streamed) | Top-1: {100*top1_hits/total:.2f}% | Top-5: {100*top5_hits/total:.2f}%")

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

✅ ImageNet-1k (streamed) | Top-1: 80.51% | Top-5: 95.24%
