In [1]:
import os
import time
from shutil import copyfile
import json
from PIL import Image
import torch
from torchvision import transforms

from pytorch_pretrained_vit import ViT

In [2]:
# open txt file containing imagenet classes
with open('imagenet_1k_name.txt') as f:
    imagenet_classes = [line.strip() for line in f.readlines()]

In [3]:
PATH_MODEL = "/scratch/tnguyen10/L_16_imagenet1k.pth"
MODEL_NAME = "L_16_imagenet1k"

model = ViT(name=MODEL_NAME, weights_path=PATH_MODEL, pretrained=True)
model = model.cuda()
model.eval()

  state_dict = torch.load(weights_path)


Loaded pretrained weights.


ViT(
  (patch_embedding): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
  (positional_embedding): PositionalEmbedding1D()
  (transformer): Transformer(
    (blocks): ModuleList(
      (0-23): 24 x Block(
        (attn): MultiHeadedSelfAttention(
          (proj_q): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_k): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_v): Linear(in_features=1024, out_features=1024, bias=True)
          (drop): Dropout(p=0.1, inplace=False)
        )
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (pwff): PositionWiseFeedForward(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        )
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (drop): Dropout(p=0.1, inplace=False)
      

In [4]:
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open("dog.png").convert("RGB"))
img = img.unsqueeze(0)
img = img.cuda()
print(img.shape) # torch.Size([1, 3, 384, 384])

with torch.no_grad():
    outputs = model(img)
outputs = outputs.squeeze(0)
predicted_class = imagenet_classes[outputs.argmax(0)]
print("Predicted class:", predicted_class)

predicted_prob = torch.nn.functional.softmax(outputs, dim=0)
print("Predicted probability:", predicted_prob.max().item())

torch.Size([1, 3, 384, 384])
Predicted class: golden retriever
Predicted probability: 0.9262539744377136


In [5]:
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open("car.png").convert("RGB"))
img = img.unsqueeze(0)
img = img.cuda()

# Mesure inference time
# Warm-up
for _ in range(5):
    with torch.no_grad():
        _ = model(img)

n_iter = 50
start_time = time.time()
for _ in range(n_iter):
    with torch.no_grad():
        outputs = model(img)
    outputs = outputs.squeeze(0)
    predicted_class = imagenet_classes[outputs.argmax(0)]
end_time = time.time()
avg_inference_time = (end_time - start_time) / n_iter * 1000  # in milliseconds
print(f"Average inference time over {n_iter} iterations: {avg_inference_time:.2f} ms")
print("Predicted class:", predicted_class)

predicted_prob = torch.nn.functional.softmax(outputs, dim=0)
print("Predicted probability:", predicted_prob.max().item())

Average inference time over 50 iterations: 22.59 ms
Predicted class: sports car
Predicted probability: 0.9769683480262756


# 2. Evaluation on ImageNet

In [6]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
img_size = 384  # or 384 for your ViT
bs = 32

# Your ViT (1000-class head)
# model = ...
model = model.to(device).eval()

tfm = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),   # <— ensure 3 channels
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])
def batch_iter(stream, batch_size):
    imgs, labels = [], []
    for row in stream:
        img = row["image"]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        imgs.append(tfm(img))
        labels.append(row["label"])
        if len(imgs) == batch_size:
            yield torch.stack(imgs), torch.tensor(labels, dtype=torch.long)
            imgs, labels = [], []
    if imgs:  # tail
        yield torch.stack(imgs), torch.tensor(labels, dtype=torch.long)

# Stream the validation split
ds_val = load_dataset("ILSVRC/imagenet-1k", split="validation", streaming=True)  # requires HF login + license
top1_hits = 0.0
top5_hits = 0.0
total = 0

idx = 0
max_idx = 20
with torch.no_grad():
    for idx, (xb, yb) in enumerate(batch_iter(ds_val, bs)):
        if idx > max_idx:
            break
        
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)

        # Top-1/Top-5
        maxk = 5
        _, pred = logits.topk(maxk, 1, True, True)  # [B,5]
        correct = pred.eq(yb.view(-1,1).expand_as(pred))
        top1_hits += correct[:, :1].reshape(-1).float().sum().item()
        top5_hits += correct[:, :5].reshape(-1).float().sum().item()
        total += yb.size(0)

print(f"✅ ImageNet-1k (streamed) | Top-1: {100*top1_hits/total:.2f}% | Top-5: {100*top5_hits/total:.2f}%")

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

✅ ImageNet-1k (streamed) | Top-1: 80.21% | Top-5: 94.94%
