In [14]:
import os
from models.modeling_nbitlinear import VisionTransformer, CONFIGS
import torch
import numpy as np
from models.nbitlinear import NBitLinear, quant

from urllib.request import urlretrieve

import PIL
from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [9]:
!nvidia-smi

Tue Jun 11 20:23:45 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.78                 Driver Version: 550.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A4000               Off |   00000000:41:00.0 Off |                  Off |
| 41%   36C    P2             36W /  140W |     546MiB /  16376MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

In [12]:
pretrained_path = 'attention_data/ViT-B_16-224.npz'

config = CONFIGS["ViT-B_16"]
config['weight_bits'] = 2
config['activation_bits'] = 2

model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=False).to(device)
model.load_from(np.load(pretrained_path))
model.to(device)

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): NBitLinear(in_features=768, out_features=3072, bias=True)
            (fc2): NBitLinear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): NBitLinear(in_features=768, out_features=768, bias=True)
            (key): NBitLinear(in_features=768, out_features=768, bias=True)
            (value): NBitLinear(in_features=768, out_features=768, bias=True)
            (out): NBitLinear(in_features=768, ou

In [18]:
transform = transforms.Compose([
    transforms.Resize(size=256, interpolation=PIL.Image.BILINEAR),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

batch_size = 64

imagenet = datasets.ImageFolder(root="/fs/vulcan-datasets/imagenet/val", transform=transform)
imagenet_loader = DataLoader(dataset=imagenet, batch_size=batch_size, shuffle=False, num_workers=1)

In [27]:
global_acc = 0
with torch.no_grad():
    for i, (images, labels) in enumerate(imagenet_loader):
        images = images.to(device)
        logits, _ = model(images)

        for idx, image_logits in enumerate(logits):
            probs = torch.nn.Softmax(dim=-1)(image_logits)
            sorted_probs = torch.argsort(probs, dim=-1, descending=True)
                    
            y_hat_index = sorted_probs[0].item()
            y_hat = imagenet_labels[y_hat_index]
                    
            y_index = labels[idx].item()
            y = imagenet_labels[y_index]
                    
            if y_hat == y:
                global_acc += 1

global_acc /= len(imagenet_loader.dataset)
print(f"acc: {global_acc}")

acc: 0.001
