In [1]:
import torch
import torchvision
from PIL import Image
import numpy as np
import random
from tqdm import tqdm
from datasets import load_dataset
import torch.multiprocessing
import os
from torch import nn

print('DEFAULT:', torch.get_num_threads(), torch.get_num_interop_threads())
# os.environ['OMP_NUM_THREADS'] = '1'
# os.environ['MKL_NUM_THREADS'] = '1'
# torch.set_num_threads(1), torch.set_num_interop_threads(1)
print(torch.get_num_threads(), torch.get_num_interop_threads())
print(torch.__config__.parallel_info())

torch.multiprocessing.set_sharing_strategy('file_system')

DEFAULT: 4 4
4 4
ATen/Parallel:
	at::get_num_threads() : 4
	at::get_num_interop_threads() : 4
OpenMP 201511 (a.k.a. OpenMP 4.5)
	omp_get_max_threads() : 4
Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
	mkl_get_max_threads() : 4
Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
std::thread::hardware_concurrency() : 8
Environment variables:
	OMP_NUM_THREADS : [not set]
	MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP



In [2]:
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
transforms = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1.transforms

In [3]:
with open('../imagenet1000.txt', 'r') as fopen:
    lines = fopen.readlines()

def process_classes(line: str):
    splitted = line.strip().removeprefix('{').removesuffix(',').split(':')
    return (int(splitted[0]), splitted[1].strip().strip('\''))

orig_classes = dict(map(process_classes, lines))

imagenette_classes = dict(enumerate(['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']))

for k, v in imagenette_classes.items():
    for k1, v1 in orig_classes.items():
        if v in v1:
            imagenette_classes[k] = k1

In [4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, datasource, transforms: callable):
        super().__init__()
        self.transforms = transforms
        self.datasource = datasource

    def __len__(self) -> int:
        return len(self.datasource)

    def __getitem__(self, index: int) -> torch.Tensor:
        data = self.datasource[index]
        image, label = data['image'], data['label']
        if image.mode != 'RGB':
            image = Image.fromarray(np.array(image)[..., None].repeat(3, -1))
        return self.transforms(image), imagenette_classes[label]

In [5]:
imagenette_train = load_dataset('frgfm/imagenette', '320px', split='train')
imagenette_valid = load_dataset('frgfm/imagenette', '320px', split='validation')

In [6]:
num_workers = 4
batch_size = 1

In [7]:
tf = transforms()
trainset = Dataset(datasource=imagenette_train, transforms=tf)
validset = Dataset(datasource=imagenette_valid, transforms=tf)
valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size, shuffle=False)

In [8]:
def nbytes(model: torch.nn.Module):
    n = 0
    for p in model.parameters():
        n += p.nbytes
    return n / 1024 ** 2

In [9]:
from torch.profiler import profile, record_function, ProfilerActivity
from itertools import product
import torch.quantization
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver
from torch.ao.quantization import get_default_qconfig_mapping, get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping
import gc
from contextlib import nullcontext
from timeit import timeit
import time
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import datetime
# import torch.quantization._numeric_suite as ns
import torch.quantization._numeric_suite_fx as ns

def fix_seed(worker_id=0, seed=0xBADCAFE):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed()

torch_generator = torch.Generator()
torch_generator.manual_seed(0xBADCAFFE)

<torch._C.Generator at 0x7f74452cd8f0>

In [10]:
from tqdm.notebook import tqdm
from copy import deepcopy

In [11]:
for p in model.parameters():
    p.requires_grad = False

In [12]:
copy_model = deepcopy(model)
module_a = deepcopy(model)
module_a.heads.head = nn.Identity()
module_b = model.heads.head
# model = LoggerModule(module_a, module_b)

In [13]:
# from sklearn.metrics import accuracy_score
# gt = []
# pred = []
# # embeddings = []
# Y = []
# with torch.inference_mode():
#     for x, y in tqdm(valid_dataloader):
#         emb = module_a.forward(x)
#         # embeddings.append(emb)
#         y_hat = module_b(emb)
#         # Y.append(y_hat)
#         gt.append(y)
#         pred.append(y_hat.argmax(-1))
#     gt = torch.cat(gt).ravel().numpy()
#     pred = torch.cat(pred).ravel().numpy()

In [14]:
# gc.collect()
# accuracy_score(gt, pred)

In [16]:
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
with torch.inference_mode():
    quantized_module_a = torch.quantization.quantize_dynamic(module_a, {nn.Linear}, dtype=torch.qint8, mapping=qconfig)
    quantized_module_b = torch.quantization.quantize_dynamic(module_b, {nn.Linear}, dtype=torch.qint8, mapping=qconfig)

In [17]:
quantized_module_a

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [18]:
nbytes(quantized_module_a)

327.2958984375

In [19]:
from sklearn.metrics import accuracy_score
gt = []
pred = []
# embeddings = []
Y = []
with torch.inference_mode():
    for x, y in tqdm(valid_dataloader):
        emb = quantized_module_a.forward(x)
        # embeddings.append(emb)
        y_hat = quantized_module_b(emb)
        # Y.append(y_hat)
        gt.append(y)
        pred.append(y_hat.argmax(-1))
    gt = torch.cat(gt).ravel().numpy()
    pred = torch.cat(pred).ravel().numpy()

  0%|          | 0/3925 [00:00<?, ?it/s]

In [20]:
nbytes(model), nbytes(quantized_module_a)

(330.2294006347656, 327.2958984375)

In [21]:
accuracy_score(gt, pred)

0.914140127388535

In [None]:
gc.collect()

18

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

In [None]:
module_b

Linear(in_features=768, out_features=1000, bias=True)

In [None]:
quantized_module_a.encoder.layers[0].mlp[0]

DynamicQuantizedLinear(in_features=768, out_features=3072, dtype=torch.qint8, qscheme=torch.per_tensor_affine)

In [None]:
x = torch.rand(1, 768)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        module_a.encoder.layers[0].mlp[0](x)
print(prof.key_averages().table(sort_by="cpu_time_total"))

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
       model_inference         1.98%     161.000us       100.00%       8.117ms       8.117ms             1  
          aten::linear         0.21%      17.000us        98.02%       7.956ms       7.956ms             1  
               aten::t        91.41%       7.420ms        91.63%       7.438ms       7.438ms             1  
           aten::addmm         5.93%     481.000us         6.17%     501.000us     501.000us             1  
       aten::transpose         0.16%      13.000us         0.22%      18.000us      18.000us             1  
           aten::copy_         0.16%      13.000us         0.16%      13.000us      13.000us             1  
          aten::exp

STAGE:2023-12-10 04:28:01 6533:6533 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-12-10 04:28:01 6533:6533 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-12-10 04:28:01 6533:6533 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


In [None]:
x = torch.rand(1, 768)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        quantized_module_a.encoder.layers[0].mlp[0](x)
print(prof.key_averages().table(sort_by="cpu_time_total"))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
              model_inference         2.98%     298.000us       100.00%      10.000ms      10.000ms             1  
    quantized::linear_dynamic        92.37%       9.237ms        97.01%       9.701ms       9.701ms             1  
             aten::empty_like         4.45%     445.000us         4.51%     451.000us     451.000us             1  
                  aten::empty         0.19%      19.000us         0.19%      19.000us       9.500us             2  
                     aten::to         0.01%       1.000us         0.01%       1.000us       1.000us             1  
-----------------------------  ------------  ------------  ------------ 

STAGE:2023-12-10 04:26:55 6533:6533 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-12-10 04:26:55 6533:6533 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-12-10 04:26:55 6533:6533 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
