In [1]:
import torch
import torchvision
from PIL import Image
import numpy as np
import random
from tqdm import tqdm
from datasets import load_dataset
import torch.multiprocessing


In [2]:
torch.multiprocessing.set_sharing_strategy('file_system')


In [3]:
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
transforms = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1.transforms

In [4]:
with open('../imagenet1000.txt', 'r') as fopen:
    lines = fopen.readlines()

def process_classes(line: str):
    splitted = line.strip().removeprefix('{').removesuffix(',').split(':')
    return (int(splitted[0]), splitted[1].strip().strip('\''))

orig_classes = dict(map(process_classes, lines))

imagenette_classes = dict(enumerate(['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']))

for k, v in imagenette_classes.items():
    for k1, v1 in orig_classes.items():
        if v in v1:
            imagenette_classes[k] = k1

In [5]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, datasource, transforms: callable, ramming: bool = False):
        super().__init__()
        self.transforms = transforms
        self.ramming = ramming
        if ramming:
            ram_data = []
            for i in range(len(datasource)):
                data = datasource[i]
                ram_data.append({'image': data['image'], 'label': data['label']})
            self.datasource = ram_data
        else:
            self.datasource = datasource

    def __len__(self) -> int:
        return len(self.datasource)

    def __getitem__(self, index: int) -> torch.Tensor:
        data = self.datasource[index]
        image, label = data['image'], data['label']
        if image.mode != 'RGB':
            image = Image.fromarray(np.array(image)[..., None].repeat(3, -1))
        return self.transforms(image), imagenette_classes[label]

In [6]:
imagenette_train = load_dataset('frgfm/imagenette', '320px', split='train')
imagenette_valid = load_dataset('frgfm/imagenette', '320px', split='validation')

tiny_imagenet_train = load_dataset('Maysee/tiny-imagenet', split='train')
tiny_imagenet_valid = load_dataset('Maysee/tiny-imagenet', split='valid')

In [7]:
num_workers = 1
batch_size = 1

In [8]:
# trainset = Dataset(datasource=tiny_imagenet_train, transforms=transforms())
tf = transforms()
trainset = Dataset(datasource=imagenette_train, transforms=tf)
validset = Dataset(datasource=imagenette_valid, transforms=tf, ramming=True)
valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
# valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size)

In [9]:
def nbytes(model: torch.nn.Module):
    n = 0
    for p in model.parameters():
        n += p.nbytes

    return n / 1024 ** 2

In [10]:
nbytes(model)

330.2294006347656

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(inputs)

In [11]:
from torch.profiler import profile, record_function, ProfilerActivity
from itertools import product
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx
import gc
from contextlib import nullcontext
from timeit import timeit
import time
from sklearn.metrics import accuracy_score, top_k_accuracy_score

def fix_seed(worker_id=0, seed=0xBADCAFE):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed()

torch_generator = torch.Generator()
torch_generator.manual_seed(0xBADCAFFE)

<torch._C.Generator at 0x7efcf8b941b0>

In [12]:
examples_inputs = [trainset[i][0] for i in map(int, np.unique(np.random.randint(0, len(trainset), size=128)))]

In [13]:
matmul_precision = ['medium', 'high', 'highest']
quantization = ['None', 'x86', 'fbgemm']
mixed_precision = ['']
batch_sizes = [1, 4]
num_workers = 1

In [14]:
def run_epoch(model, valid_dataloader, limit=2**32):
    T = 0.0
    Y = []
    Y_hat = []
    for i, (x, y) in enumerate(valid_dataloader):
        if i >= limit:
            break
        Y.append(y.ravel())
        start = time.time()
        y_hat = model(x)
        end = time.time()
        Y_hat.append(y_hat.argmax(-1))
        T += end - start
    return accuracy_score(np.array(Y).ravel(), np.array(Y_hat).ravel()), T

In [15]:
limit = 16
T = {}
accuracy = {}
with torch.no_grad():
    for prec, quant, bs in tqdm(product(matmul_precision, quantization, batch_sizes)):
        valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, 
                                                       batch_size=batch_size, shuffle=True, 
                                                       worker_init_fn=fix_seed, generator=torch_generator)
        model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1).eval()
        torch.set_float32_matmul_precision(prec)
        if quant != 'None':
            torch.backends.quantized.engine = quant
            qconfig_mapping = get_default_qconfig_mapping(quant)
            prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=examples_inputs)
            model = convert_fx(prepared_model)
        key  = '_'.join(map(str, [prec, quant, bs, round(nbytes(model))]))
        acc, t = run_epoch(model, valid_dataloader, limit)
        
        T[key] = np.round(t / (min(limit, len(valid_dataloader)) * bs), 3)
        accuracy[key] = np.round(acc, 3)
        gc.collect()

0it [00:00, ?it/s]

tensor([574])
tensor([569])
tensor([497])
tensor([482])
tensor([491])
tensor([0])
tensor([0])
tensor([571])


1it [00:04,  4.45s/it]

tensor([566])
tensor([574])
tensor([482])
tensor([701])
tensor([571])
tensor([0])
tensor([491])
tensor([482])


  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


tensor([566])
tensor([497])
tensor([569])
tensor([571])
tensor([574])
tensor([566])
tensor([701])
tensor([217])




tensor([217])
tensor([482])
tensor([701])
tensor([566])
tensor([566])
tensor([574])
tensor([0])
tensor([0])




tensor([566])
tensor([574])
tensor([569])
tensor([574])
tensor([571])
tensor([569])
tensor([497])
tensor([566])




tensor([574])
tensor([571])
tensor([491])
tensor([491])
tensor([0])
tensor([217])
tensor([491])
tensor([701])


6it [00:33,  6.07s/it]

tensor([491])
tensor([571])
tensor([569])
tensor([566])
tensor([482])
tensor([217])
tensor([217])
tensor([566])


7it [00:37,  5.22s/it]

tensor([701])
tensor([217])
tensor([571])
tensor([491])
tensor([217])
tensor([497])
tensor([497])
tensor([497])




tensor([0])
tensor([497])
tensor([569])
tensor([566])
tensor([566])
tensor([566])
tensor([0])
tensor([482])




tensor([0])
tensor([574])
tensor([569])
tensor([497])
tensor([217])
tensor([566])
tensor([482])
tensor([569])




tensor([574])
tensor([569])
tensor([482])
tensor([217])
tensor([571])
tensor([497])
tensor([482])
tensor([701])




tensor([217])
tensor([491])
tensor([569])
tensor([701])
tensor([491])
tensor([0])
tensor([566])
tensor([497])


12it [01:05,  5.71s/it]

tensor([217])
tensor([497])
tensor([217])
tensor([701])
tensor([566])
tensor([571])
tensor([491])
tensor([701])


13it [01:09,  5.19s/it]

tensor([497])
tensor([491])
tensor([574])
tensor([569])
tensor([217])
tensor([497])
tensor([217])
tensor([574])


14it [01:16,  5.46s/it]


KeyboardInterrupt: 

In [None]:
for key, t in T.items():
    print(key, 'time:', t, 'acc', accuracy[key])

medium_None_1_330 time: 0.334 acc 0.875
medium_None_4_330 time: 0.087 acc 1.0
medium_x86_1_109 time: 0.35 acc 0.25
medium_x86_4_109 time: 0.063 acc 0.0
medium_fbgemm_1_109 time: 0.244 acc 0.0
medium_fbgemm_4_109 time: 0.066 acc 0.0
high_None_1_330 time: 0.275 acc 1.0
high_None_4_330 time: 0.098 acc 1.0
high_x86_1_109 time: 0.27 acc 0.0
high_x86_4_109 time: 0.069 acc 0.0
high_fbgemm_1_109 time: 0.218 acc 0.0
high_fbgemm_4_109 time: 0.066 acc 0.125
highest_None_1_330 time: 0.289 acc 1.0
highest_None_4_330 time: 0.066 acc 1.0
highest_x86_1_109 time: 0.207 acc 0.125
highest_x86_4_109 time: 0.054 acc 0.0
highest_fbgemm_1_109 time: 0.236 acc 0.375
highest_fbgemm_4_109 time: 0.059 acc 0.125


In [None]:
import datetime
date_time = datetime.datetime.now()
date_time
with open(f'results{date_time}.txt', 'w+') as fopen:
    for key, t in T.items():
        fopen.write(f"{key}, time:, {t}, acc, {accuracy[key]}\n")