In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import models.modeling as full_precision
from models.modeling_nbitlinear import VisionTransformer, CONFIGS
import torch
import numpy as np
import pandas as pd
import seaborn as sns
# from models.nbitlinear import NBitLinear, quant

from urllib.request import urlretrieve

import PIL
from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
!nvidia-smi

Wed Jun 12 13:09:50 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.78                 Driver Version: 550.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:01:00.0 Off |                  Off |
| 30%   28C    P8             22W /  300W |       1MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [5]:
os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
    
imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

In [None]:
# # NOTE: run to download ViT pretrained-checkpoint 
# if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
#     urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

# pretrained_path = 'attention_data/ViT-B_16-224.npz'

In [6]:
# NOTE: or set path to checkpoint appropriately
pretrained_path = '/fs/nexus-scratch/vla/ViT_pretrained_checkpoints/ViT-B_16-224.npz'

In [7]:
# transform = transforms.Compose([
#     transforms.Resize(size=256, interpolation=PIL.Image.BILINEAR),
#     transforms.CenterCrop(size=(224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# batch_size = 128

# imagenet = datasets.ImageFolder(root="/fs/vulcan-datasets/imagenet/val", transform=transform)
# imagenet_loader = DataLoader(dataset=imagenet, batch_size=batch_size, shuffle=False, num_workers=1)


In [45]:
import torchvision.models.vision_transformer as vit

transform = vit.ViT_B_16_Weights.IMAGENET1K_V1.transforms()
transform

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [46]:
batch_size = 128

imagenet = datasets.ImageFolder(root="/fs/vulcan-datasets/imagenet/val", transform=transform)
imagenet_loader = DataLoader(dataset=imagenet, batch_size=batch_size, shuffle=False, num_workers=1)

In [48]:
def eval_model(model):
    global_acc = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(imagenet_loader):
            images = images.to(device)
            logits, _ = model(images)

            for idx, image_logits in enumerate(logits):
                probs = torch.nn.Softmax(dim=-1)(image_logits)
                sorted_probs = torch.argsort(probs, dim=-1, descending=True)
                        
                y_hat_index = sorted_probs[0].item()
                y_hat = imagenet_labels[y_hat_index]
                        
                y_index = labels[idx].item()
                y = imagenet_labels[y_index]
                        
                if y_hat == y:
                    global_acc += 1

    global_acc /= len(imagenet_loader.dataset)
    print(f"acc: {global_acc}")
    
    return global_acc

In [50]:
gather = []

## Full-Precision Baseline

In [51]:
config = full_precision.CONFIGS["ViT-B_16"]

baseline_model = full_precision.VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=False).to(device)
baseline_model.load_from(np.load(pretrained_path))
baseline_model.to(device)

baseline_model

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=Tru

In [52]:
acc = eval_model(baseline_model)
print(f'Top-1 acc:{acc}')
gather.append(('Baseline', acc))
gather

acc: 0.70924
Top-1 acc:0.70924


[('Baseline', 0.70924)]

## Absmax Quantized Models (PTQ)

In [5]:
def load_absmax_model(weight_bits=8, activation_bits=8):
    config = CONFIGS["ViT-B_16"]
    config['weight_bits'] = weight_bits
    config['activation_bits'] = activation_bits

    model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=False).to(device)
    model.load_from(np.load(pretrained_path))
    model.to(device)
    
    return model

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): NBitLinear(in_features=768, out_features=3072, bias=True)
            (fc2): NBitLinear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): NBitLinear(in_features=768, out_features=768, bias=True)
            (key): NBitLinear(in_features=768, out_features=768, bias=True)
            (value): NBitLinear(in_features=768, out_features=768, bias=True)
            (out): NBitLinear(in_features=768, ou

In [None]:
# 8-bit run
model = load_absmax_model(8,8)
acc = eval_model(model)
print(f'Top-1 acc:{acc}')
gather.append(('8-bit', acc))

In [None]:
# 6-bit run
model = load_absmax_model(6,6)
acc = eval_model(model)
print(f'Top-1 acc:{acc}')
gather.append(('6-bit', acc))

In [None]:
# 4-bit run
model = load_absmax_model(4,4)
acc = eval_model(model)
print(f'Top-1 acc:{acc}')
gather.append(('4-bit', acc))

In [None]:
# 2-bit run
model = load_absmax_model(2,2)
acc = eval_model(model)
print(f'Top-1 acc:{acc}')
gather.append(('2-bit', acc))