In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import models.modeling as full_precision
from models.modeling_nbitlinear_dynamic import VisionTransformer, CONFIGS
import torch
import numpy as np
import pandas as pd
import seaborn as sns
# from models.nbitlinear import NBitLinear, quant

from urllib.request import urlretrieve

import PIL
from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
!nvidia-smi

Fri Jun 14 00:16:53 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.78                 Driver Version: 550.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:C1:00.0 Off |                  Off |
| 30%   25C    P8             17W /  300W |       1MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [5]:
os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
    
imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

In [6]:
# # NOTE: run to download ViT pretrained-checkpoint 
# if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
#     urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

# pretrained_path = 'attention_data/ViT-B_16-224.npz'

In [7]:
# NOTE: or set path to checkpoint appropriately
pretrained_path = '/fs/nexus-scratch/vla/ViT_pretrained_checkpoints/ViT-B_16-224.npz'

In [8]:
# transform = transforms.Compose([
#     transforms.Resize(size=256, interpolation=PIL.Image.BILINEAR),
#     transforms.CenterCrop(size=(224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# batch_size = 128

# imagenet = datasets.ImageFolder(root="/fs/vulcan-datasets/imagenet/val", transform=transform)
# imagenet_loader = DataLoader(dataset=imagenet, batch_size=batch_size, shuffle=False, num_workers=1)


In [9]:
import torchvision.models.vision_transformer as vit

transform = vit.ViT_B_16_Weights.IMAGENET1K_V1.transforms()
transform

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [10]:
batch_size = 512

imagenet = datasets.ImageFolder(root="/fs/vulcan-datasets/imagenet/val", transform=transform)
imagenet_loader = DataLoader(dataset=imagenet, batch_size=batch_size, shuffle=False, num_workers=1)

In [11]:
def eval_model(model):
    
    model.eval()
    
    global_acc = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(imagenet_loader):
            images = images.to(device)
            logits, _ = model(images)
            # logits = model(images)
            

            for idx, image_logits in enumerate(logits):
                probs = torch.nn.Softmax(dim=-1)(image_logits)
                sorted_probs = torch.argsort(probs, dim=-1, descending=True)
                        
                y_hat_index = sorted_probs[0].item()
                y_hat = imagenet_labels[y_hat_index]
                
                y_index = labels[idx].item()
                y = imagenet_labels[y_index]
                        
                if y_hat == y:
                    global_acc += 1
                
    global_acc /= len(imagenet_loader.dataset)
    print(f"acc: {global_acc}")
    torch.cuda.empty_cache()
    
    return global_acc

In [12]:
gather = []

## Full-Precision Baseline

In [13]:
config = full_precision.CONFIGS["ViT-B_16"]

baseline_model = full_precision.VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=False).to(device)
baseline_model.load_from(np.load(pretrained_path))
baseline_model.to(device)

baseline_model

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=Tru

In [45]:
# NOTE: vit-b_16-224 run
acc = eval_model(baseline_model)
gather.append(('Baseline', acc))
gather

acc: 0.75664


[('Baseline', 0.75664)]

In [14]:
pretrained_path = '/fs/nexus-scratch/vla/ViT_pretrained_checkpoints/ViT-B_16.npz'
config = full_precision.CONFIGS["ViT-B_16"]

baseline_model = full_precision.VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=False).to(device)
baseline_model.load_from(np.load(pretrained_path))
baseline_model.to(device)

# NOTE: vit-b_16 run
acc = eval_model(baseline_model)
gather.append(('Baseline-vit-b_16', acc))
gather

load_pretrained: grid-size from 24 to 14
acc: 0.69462


[('Baseline-vit-b_16', 0.69462)]

In [None]:
from torchvision.models import vit_b_16, ViT_B_16_Weights
pytorch_vit = vit_b_16(ViT_B_16_Weights)
pytorch_vit.to(device)
pytorch_vit

In [34]:
# NOTE: pytorch vit run
acc = eval_model(pytorch_vit)
gather.append(('pytorch-vit', acc))
gather

acc: 0.81068


[('pytorch-vit', 0.81068)]

## Quantized Models (PTQ)
- NOTE: using minmax quantization with custom nn.Linear replacement [NBitLinearDynamic](./models/nbitlineardynamic.py)

In [13]:
pretrained_path = '/fs/nexus-scratch/vla/ViT_pretrained_checkpoints/ViT-B_16-224.npz'

def load_quant_model(weight_bits=8, activation_bits=8):
    config = CONFIGS["ViT-B_16"]
    config['weight_bits'] = weight_bits
    config['activation_bits'] = activation_bits

    model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=False).to(device)
    model.load_from(np.load(pretrained_path))
    model.to(device)
    
    return model

In [31]:
model = load_quant_model(16,16)
model

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): NBitLinearDynamic(in_features=768, out_features=3072, bias=True)
            (fc2): NBitLinearDynamic(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): NBitLinearDynamic(in_features=768, out_features=768, bias=True)
            (key): NBitLinearDynamic(in_features=768, out_features=768, bias=True)
            (value): NBitLinearDynamic(in_features=768, out_features=768, bias=True)
            (o

In [32]:
# 16-bit run
acc = eval_model(model)
gather.append(('16-bit', acc))

acc: 0.55744


In [29]:
# 8-bit run
model = load_quant_model(8,8)
acc = eval_model(model)
gather.append(('8-bit', acc))

acc: 0.55758


In [20]:
# 6-bit run
model = load_quant_model(6,6)
acc = eval_model(model) 
gather.append(('6-bit', acc))

acc: 0.51352


In [21]:
# 4-bit run
model = load_quant_model(4,4)
acc = eval_model(model)
gather.append(('4-bit', acc))

acc: 0.0715


In [22]:
# 2-bit run
model = load_quant_model(2,2)
acc = eval_model(model)
gather.append(('2-bit', acc))

acc: 0.00108


In [None]:
# TODO: plots

## Pytorch PTQ

In [69]:
baseline_head = baseline_model.head
baseline_head

Linear(in_features=768, out_features=1000, bias=True)

In [11]:
example_inputs = (next(iter(imagenet_loader))[0])

In [12]:
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import default_dynamic_qconfig, float_qparams_weight_only_qconfig, QConfigMapping
import torch.nn as nn

# Full docs for supported qconfig for floating point modules/ops can be found in `quantization docs <https://pytorch.org/docs/stable/quantization.html#module-torch.quantization>`_
# Full docs for `QConfigMapping <https://pytorch.org/docs/stable/generated/torch.ao.quantization.qconfig_mapping.QConfigMapping.html#torch.ao.quantization.qconfig_mapping.QConfigMapping>`_
qconfig_mapping = (QConfigMapping()
    # .set_object_type(nn.Embedding, float_qparams_weight_only_qconfig)
    # .set_object_type(nn.LSTM, default_dynamic_qconfig)
    .set_object_type(nn.Linear, default_dynamic_qconfig)
)

prepared_model = prepare_fx(baseline_model, qconfig_mapping, example_inputs)
print("prepared model:", prepared_model)
quantized_model = convert_fx(baseline_model)
print("quantized model", quantized_model)


TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

In [None]:
torch.quantization.quantize

In [73]:
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (MB):', size/1e6)
    os.remove('temp.p')
    return size

In [75]:
_ = print_size_of_model(baseline_model)

model:    	 Size (MB): 346.335422


In [77]:
_ = print_size_of_model(quantized_model)

model:    	 Size (MB): 91.586734


In [78]:
# NPTE: not supported on GPU...
acc = eval_model(quantized_model)

NotImplementedError: Could not run 'quantized::linear_dynamic' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear_dynamic' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at ../aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp:662 [kernel]
Meta: registered at ../aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at ../torch/csrc/autograd/TraceTypeManual.cpp:297 [backend fallback]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:720 [backend fallback]
BatchedNestedTensor: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:746 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:158 [backend fallback]


In [17]:
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver
C, L = 3, 4
normal = torch.distributions.normal.Normal(0,1)
inputs = [normal.sample((C, L)), normal.sample((C, L))]
print(inputs)

# >>>>>
# [tensor([[-0.0590,  1.1674,  0.7119, -1.1270],
#          [-1.3974,  0.5077, -0.5601,  0.0683],
#          [-0.0929,  0.9473,  0.7159, -0.4574]]]),

# tensor([[-0.0236, -0.7599,  1.0290,  0.8914],
#          [-1.1727, -1.2556, -0.2271,  0.9568],
#          [-0.2500,  1.4579,  1.4707,  0.4043]])]

observers = [MinMaxObserver(), MovingAverageMinMaxObserver(), HistogramObserver()]
for obs in observers:
  for x in inputs: obs(x) 
  print(obs.__class__.__name__, obs.calculate_qparams())

[tensor([[ 0.3652,  0.8875, -0.8823,  0.5704],
        [-1.1190,  0.2417,  0.9542, -0.5831],
        [ 0.1154, -1.1137,  2.1861,  0.7512]]), tensor([[ 0.1743,  0.5928,  0.5895,  1.5303],
        [ 0.0349,  0.0213,  0.8716, -0.6756],
        [ 0.6009, -0.0383,  0.3956, -0.3443]])]
MinMaxObserver (tensor([0.0130]), tensor([86], dtype=torch.int32))
MovingAverageMinMaxObserver (tensor([0.0129]), tensor([86], dtype=torch.int32))
HistogramObserver (tensor([0.0130]), tensor([86], dtype=torch.int32))


In [22]:
from models.nbitlineardynamic import quant
quant(x, 8, observers[0])

tensor([[ 0.1685,  0.5962,  0.5833,  1.5294],
        [ 0.0389,  0.0259,  0.8684, -0.6740],
        [ 0.5962, -0.0389,  0.4018, -0.3500]])