In [1]:
import torch
import torchvision
from PIL import Image
import numpy as np
import random
from tqdm import tqdm
from datasets import load_dataset
import torch.multiprocessing


In [2]:
torch.multiprocessing.set_sharing_strategy('file_system')


In [3]:
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
transforms = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1.transforms

In [4]:
with open('../imagenet1000.txt', 'r') as fopen:
    lines = fopen.readlines()

def process_classes(line: str):
    splitted = line.strip().removeprefix('{').removesuffix(',').split(':')
    return (int(splitted[0]), splitted[1].strip().strip('\''))

orig_classes = dict(map(process_classes, lines))

imagenette_classes = dict(enumerate(['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']))

for k, v in imagenette_classes.items():
    for k1, v1 in orig_classes.items():
        if v in v1:
            imagenette_classes[k] = k1

In [5]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, datasource, transforms: callable):
        super().__init__()
        self.transforms = transforms
        self.datasource = datasource

    def __len__(self) -> int:
        return len(self.datasource)

    def __getitem__(self, index: int) -> torch.Tensor:
        data = self.datasource[index]
        image, label = data['image'], data['label']
        if image.mode != 'RGB':
            image = Image.fromarray(np.array(image)[..., None].repeat(3, -1))
        return self.transforms(image), imagenette_classes[label]

In [6]:
imagenette_train = load_dataset('frgfm/imagenette', '320px', split='train')
imagenette_valid = load_dataset('frgfm/imagenette', '320px', split='validation')

tiny_imagenet_train = load_dataset('Maysee/tiny-imagenet', split='train')
tiny_imagenet_valid = load_dataset('Maysee/tiny-imagenet', split='valid')

In [7]:
num_workers = 1
batch_size = 1

In [8]:
# trainset = Dataset(datasource=tiny_imagenet_train, transforms=transforms())
tf = transforms()
trainset = Dataset(datasource=imagenette_train, transforms=tf)
validset = Dataset(datasource=imagenette_valid, transforms=tf)
valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
# valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size)

In [9]:
def nbytes(model: torch.nn.Module):
    n = 0
    for p in model.parameters():
        n += p.nbytes

    return n / 1024 ** 2

In [10]:
import torch
print(torch.__config__.parallel_info())

ATen/Parallel:
	at::get_num_threads() : 4
	at::get_num_interop_threads() : 4
OpenMP 201511 (a.k.a. OpenMP 4.5)
	omp_get_max_threads() : 4
Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
	mkl_get_max_threads() : 4
Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
std::thread::hardware_concurrency() : 8
Environment variables:
	OMP_NUM_THREADS : [not set]
	MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP



In [11]:
nbytes(model)

330.2294006347656

In [12]:
from torch.profiler import profile, record_function, ProfilerActivity
from itertools import product
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx
import gc
from contextlib import nullcontext
from timeit import timeit
import time
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import datetime

def fix_seed(worker_id=0, seed=0xBADCAFE):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed()

torch_generator = torch.Generator()
torch_generator.manual_seed(0xBADCAFFE)

<torch._C.Generator at 0x7fda24163650>

In [23]:
quantization = ['x86', 'fbgemm']
mixed_precision = ['']
batch_sizes = [1, 4]
num_workers = 1

In [24]:
def run_epoch(model, valid_dataloader, limit=2**32):
    T = 0.0
    Y = []
    Y_hat = []
    for i, (x, y) in enumerate(valid_dataloader):
        if i >= limit:
            break
        Y.append(y.ravel())
        start = time.time()
        y_hat = model(x)
        end = time.time()
        Y_hat.append(y_hat.argmax(-1))
        T += end - start
    return accuracy_score(np.array(Y).ravel(), np.array(Y_hat).ravel()), T

In [25]:
def calibrate(model, dataloader):
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloader):
            if i > 64:
                break
            model(x)

In [26]:
limit = 16
T = {}
date_time = datetime.datetime.now()
accuracy = {}
with torch.inference_mode():
    with open(f'profiling{date_time}.txt', 'w+') as fopen:
        for quant, bs in tqdm(product(quantization, batch_sizes)):
            valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, 
                                                        batch_size=batch_size, shuffle=True, 
                                                        worker_init_fn=fix_seed, generator=torch_generator)
            model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1).eval()
            if quant != 'None':
                torch.backends.quantized.engine = quant
                qconfig_mapping = get_default_qconfig_mapping(quant)
                prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=next(iter(valid_dataloader))[0])
                calibrate(prepared_model, valid_dataloader)
                model = convert_fx(prepared_model)
            key  = '_'.join(map(str, [quant, bs, round(nbytes(model))]))

            with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
                for i, (x, y) in enumerate(valid_dataloader):
                    with record_function("model_inference"):
                        if i >= limit:
                            break
                        model(x)
            assert False
            fopen.write(f'{key}\n')
            fopen.write(prof.key_averages().table(sort_by="cpu_time_total"))

            # acc, t = run_epoch(model, valid_dataloader, limit)
            
            # T[key] = np.round(t / (min(limit, len(valid_dataloader)) * bs), 3)
            # accuracy[key] = np.round(acc, 3)
            gc.collect()

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
STAGE:2023-10-18 02:48:21 189450:189450 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-18 02:48:24 189450:189450 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-18 02:48:24 189450:189450 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
0it [00:35, ?it/s]


AssertionError: 

In [28]:
model

GraphModule(
  (conv_proj): QuantizedConv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), scale=0.052607402205467224, zero_point=62)
  (encoder): Module(
    (dropout): QuantizedDropout(p=0.0, inplace=False)
    (layers): Module(
      (encoder_layer_0): Module(
        (ln_1): QuantizedLayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): QuantizedDropout(p=0.0, inplace=False)
        (ln_2): QuantizedLayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Module(
          (0): QuantizedLinear(in_features=768, out_features=3072, scale=0.07862671464681625, zero_point=70, qscheme=torch.per_channel_affine)
          (1): GELU(approximate='none')
          (2): QuantizedDropout(p=0.0, inplace=False)
          (3): QuantizedLinear(in_features=3072, out_features=768, scale=0.037298478186130524, zero_poin

In [None]:
for key, t in T.items():
    print(key, 'time:', t, 'acc', accuracy[key])

In [None]:

date_time = datetime.datetime.now()
date_time
with open(f'results{date_time}.txt', 'w+') as fopen:
    for key, t in T.items():
        fopen.write(f"{key}, time:, {t}, acc, {accuracy[key]}\n")