In [None]:
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ViTImageProcessor, ViTForImageClassification

from custom_dataset import ImageNet1K
from dataset.classes import IMAGENET2012_CLASSES

In [None]:
# Both the preprocessing config and the classifier weights are read from the
# same local checkpoint directory.
checkpoint_dir = './vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(checkpoint_dir)

# Load the classifier, move it to the GPU and switch it to evaluation mode
# (both .to() and .eval() return the module itself, so they can be chained).
model = ViTForImageClassification.from_pretrained(checkpoint_dir).to('cuda').eval()
print('Model loaded!')

In [None]:
# Total number of scalar elements across every parameter tensor of the model
# (no distinction is made between trainable and frozen parameters).
param = sum(map(lambda tensor: tensor.numel(), model.parameters()))
print(f'Param: {param}')

In [4]:
# Validation dataset: the project's ImageNet1K class receives the image folder,
# the class mapping and the ViT processor, which it is handed as `transform`.
val_dataset = ImageNet1K(
    image_path='./dataset/val_data/',
    labels=IMAGENET2012_CLASSES,
    transform=processor,
)

# Fixed iteration order (no shuffling), 32 samples per batch, 4 loader workers.
loader_options = dict(batch_size=32, shuffle=False, num_workers=4)
val_loader = DataLoader(val_dataset, **loader_options)

In [None]:
# Fetch a single batch to sanity-check what the loader yields before running
# the full evaluation. `batch` is assigned explicitly (rather than leaking out
# of a for/break loop) because the next cell feeds it to the model.
# next() raises StopIteration right here if the loader is empty, instead of
# silently leaving `batch` undefined.
batch = next(iter(val_loader))
print(batch[0]['pixel_values'].shape)
print(batch[1])

In [None]:
# Run the sample batch through the model once to inspect the raw logits.
# no_grad() keeps autograd from recording a graph that inference never uses,
# so intermediate activations are not held on the GPU.
with torch.no_grad():
    # squeeze(1) drops the size-1 axis at dim 1 -- presumably added per image
    # by the processor used as the dataset transform; TODO confirm.
    sample_logits = model(batch[0]['pixel_values'].squeeze(1).to('cuda')).logits

# Bare last expression so the notebook still displays the logits.
sample_logits

In [None]:
# Measure the top-1 accuracy of the vision transformer over the validation loader.
accurate = 0  # correct predictions so far, kept as a plain Python int
count = 0     # images evaluated so far

# Inference only: no_grad() stops autograd from recording a graph per batch,
# which would otherwise waste GPU memory for the whole validation pass.
with torch.no_grad():
    for i, batch in enumerate(val_loader):
        # squeeze(1) drops the size-1 axis at dim 1 -- presumably added per
        # image by the processor used as the dataset transform; TODO confirm.
        image = batch[0]['pixel_values'].squeeze(1).to('cuda')
        # Assumes labels are integer class indices comparable with the argmax
        # over the logits -- TODO confirm against the dataset class.
        label = batch[1].to('cuda')
        pred = model(image).logits.argmax(dim=1)

        # .item() turns the 0-dim CUDA tensor into a Python int, so the running
        # accuracy prints as a number instead of `tensor(..., device='cuda:0')`.
        accurate += (pred == label).sum().item()
        count += image.shape[0]

        # Progress report every 20 steps (step 0 is skipped).
        if i and i % 20 == 0:
            print(f'step {i}/ {len(val_loader)}, accuracy: {accurate / count:.4f}')

# The periodic print above only fires on multiples of 20, so the result over
# the complete validation set has to be reported explicitly after the loop.
if count:
    print(f'Final accuracy: {accurate / count:.4f} ({accurate}/{count})')
else:
    print('No samples were evaluated.')