In [1]:
import os

import torch
import numpy as np

from urllib.request import urlretrieve

from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, SequentialSampler
import pickle

import models.modeling as original
import models.modeling_bitlinear as binary


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Working folder for the dataset and any downloaded assets; reused if it already exists.
os.makedirs(name="attention_data", exist_ok=True)

# Optional ImageNet assets, not needed for the CIFAR-10 evaluation in the cells shown.
# Uncomment to fetch them:
#if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
#    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
#if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
#    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

#imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

In [8]:
# Build the CIFAR-10 test pipeline; torchvision downloads the data only if it is missing.
batch_size = 4

# Resize to the 224x224 input size the ViT models are built with, convert to a
# tensor, then normalize each channel with mean 0.5 / std 0.5.
resize_step = transforms.Resize((224, 224))
tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_test = transforms.Compose([resize_step, tensor_step, normalize_step])

testset = datasets.CIFAR10(
    root="attention_data",
    train=False,
    download=True,
    transform=transform_test,
)

# No shuffling: the test set is always visited in the same order.
test_dataloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [7]:
# Read the CIFAR-10 class-name metadata stored next to the extracted test batches.
# encoding="bytes" keeps the keys and names as byte strings (e.g. b'label_names').
with open("attention_data/cifar-10-batches-py/batches.meta", mode="rb") as meta_fh:
    cifar10_labels = pickle.load(meta_fh, encoding="bytes")

print(cifar10_labels)

{b'num_cases_per_batch': 10000, b'label_names': [b'airplane', b'automobile', b'bird', b'cat', b'deer', b'dog', b'frog', b'horse', b'ship', b'truck'], b'num_vis': 3072}


In [5]:
# Base model used as starting point
#urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16.npz")

## Evaluating Test Accuracy of Trained Models

+ Model weights and training logs can be found [here](https://drive.google.com/drive/u/2/folders/1o6YwUx-lfn0kfxaYC3NJm3I4hEBYuy2n)

In [9]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [10]:
def eval_model(model, dataloader=None, device=None):
    """Compute, print and return the top-1 accuracy of ``model``.

    Args:
        model: a ``torch.nn.Module`` whose forward pass returns a tuple whose
            first element is the logits tensor; the other elements are ignored.
            Assumes logits are (batch, num_classes) -- TODO confirm for new models.
        dataloader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``test_dataloader`` (CIFAR-10 test split).
        device: device used for inference. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        float: fraction of correctly classified samples (0.0 if the loader is empty).
    """
    if dataloader is None:
        dataloader = test_dataloader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    # A single no_grad context for the whole pass instead of re-entering it per batch.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            logits, _ = model(images)

            # Softmax is monotonic, so the argmax of the raw logits is the same
            # prediction as the top entry of the sorted probabilities.
            predictions = logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.numel()

    # Divide by the number of samples actually evaluated (not len(testset)) so
    # that a custom or partial dataloader still yields a correct accuracy.
    accuracy = correct / total if total else 0.0
    print(accuracy)
    return accuracy

In [8]:
# Baseline fine-tune run (standard Linear layers from models.modeling), 10 CIFAR-10 classes.
checkpoint_path = "output/baseline_checkpoint.bin"
config = original.CONFIGS["ViT-B_16"]
baseline_model = original.VisionTransformer(config, num_classes=10, zero_head=False, img_size=224, vis=False).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# weights_only restricts unpickling to tensors/containers, because the checkpoint
# files are fetched from an external Drive folder.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
baseline_model.load_state_dict(checkpoint)

baseline_model

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=Tru

In [9]:
# Top-1 CIFAR-10 test accuracy of the baseline (non-binarized) model.
eval_model(model=baseline_model);

0.9904


In [11]:
# Binarized run, 10,000 iterations (Linear layers replaced by BitLinear from models.modeling_bitlinear).
checkpoint_path = "output/config1_checkpoint.bin"
config = binary.CONFIGS["ViT-B_16"]
bit_model1 = binary.VisionTransformer(config, num_classes=10, zero_head=False, img_size=224, vis=False).to(device)
# map_location lets a GPU-saved checkpoint load on CPU; weights_only restricts
# unpickling to tensors/containers for externally downloaded checkpoint files.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
bit_model1.load_state_dict(checkpoint)

bit_model1

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): BitLinear(in_features=768, out_features=3072, bias=True)
            (fc2): BitLinear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): BitLinear(in_features=768, out_features=768, bias=True)
            (key): BitLinear(in_features=768, out_features=768, bias=True)
            (value): BitLinear(in_features=768, out_features=768, bias=True)
            (out): BitLinear(in_features=768, out_feat

In [12]:
# Top-1 CIFAR-10 test accuracy of the binarized model trained for 10,000 iterations.
eval_model(model=bit_model1);

0.3636


In [11]:
# Binarized run, 50,000 iterations (same BitLinear architecture as the 10,000-iteration run).
checkpoint_path = "output/config2_checkpoint.bin"
config = binary.CONFIGS["ViT-B_16"]
bit_model2 = binary.VisionTransformer(config, num_classes=10, zero_head=False, img_size=224, vis=False).to(device)
# map_location lets a GPU-saved checkpoint load on CPU; weights_only restricts
# unpickling to tensors/containers for externally downloaded checkpoint files.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
bit_model2.load_state_dict(checkpoint)

# CIFAR-10 Accuracy for Binarized Run 50,000 iterations
eval_model(bit_model2);

0.4362
