In [1]:
import os

import torch
import numpy as np

from urllib.request import urlretrieve

from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, SequentialSampler
import pickle

from models.modeling import VisionTransformer, CONFIGS

In [2]:
# Ensure the local working directory exists; exist_ok=True makes this a no-op on re-runs.
os.makedirs("attention_data", exist_ok=True)
# NOTE(review): the commented-out lines below would download an ILSVRC-2012 label list and
# a ViT-B/16 .npz weight file and build an ImageNet label lookup. They look like leftovers
# from an ImageNet variant of this notebook -- confirm they are unused and delete them.
#if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
#    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
#if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
#    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

#imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

In [3]:
# Build the CIFAR-10 test split (downloaded into attention_data/ when missing)
# and wrap it in a DataLoader for batched evaluation.
batch_size = 4

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform_test = transforms.Compose(preprocessing_steps)

testset = datasets.CIFAR10(
    root="attention_data",
    train=False,
    download=True,
    transform=transform_test,
)

# shuffle=False keeps the iteration order fixed from run to run.
test_dataloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
)

Files already downloaded and verified


In [19]:
# Load the CIFAR-10 metadata file from the dataset folder; its b'label_names'
# entry lists the class names in index order.
# NOTE: pickle.load can execute arbitrary code -- only use it on files from a trusted source.
with open("attention_data/cifar-10-batches-py/batches.meta", mode='rb') as meta_file:
    # encoding='bytes' leaves the dictionary keys and the label names as bytes objects.
    cifar10_labels = pickle.load(meta_file, encoding='bytes')

print(cifar10_labels)

{b'num_cases_per_batch': 10000, b'label_names': [b'airplane', b'automobile', b'bird', b'cat', b'deer', b'dog', b'frog', b'horse', b'ship', b'truck'], b'num_vis': 3072}


In [5]:
#urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16.npz")

In [25]:
# Prepare model: choose the compute device, build a ViT-B/16 with a 10-class
# head, and load its weights from `checkpoint_path`.
device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint_path = "output/test_checkpoint.bin"

config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=10, zero_head=False, img_size=224, vis=False).to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written during a GPU run still loads on a CPU-only machine (without it,
# torch.load tries to restore tensors to the device they were saved from).
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
# Switch to inference mode (e.g. dropout layers become no-ops). As the last
# expression of the cell, this also displays the module tree.
model.eval()

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=Tru

In [30]:
# Top-1 accuracy of `model` over the full CIFAR-10 test split.
correct_predictions = 0

# Inference only: disable autograd bookkeeping for the whole evaluation pass.
with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)

        # The model returns a pair; only the logits are needed here.
        logits, _ = model(images)

        # Softmax is monotonic, so the argmax of the raw logits is the same
        # top-1 class the softmax probabilities would give.
        predicted = logits.argmax(dim=-1)

        # Class names are unique, so comparing class indices is equivalent to
        # comparing label names. Labels are moved to the predictions' device
        # before the element-wise comparison.
        matches = predicted == labels.to(predicted.device)
        correct_predictions += matches.sum().item()

# Fraction of test images whose top-1 prediction matches the ground truth.
global_acc = correct_predictions / len(testset)
print(global_acc)

0.9798
