In [1]:
import torch
import torchvision
from PIL import Image
import numpy as np
import random
from tqdm import tqdm
from datasets import load_dataset
import torch.multiprocessing


In [2]:
torch.multiprocessing.set_sharing_strategy('file_system')


In [3]:
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
transforms = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1.transforms

In [4]:
with open('../imagenet1000.txt', 'r') as fopen:
    lines = fopen.readlines()

def process_classes(line: str):
    splitted = line.strip().removeprefix('{').removesuffix(',').split(':')
    return (int(splitted[0]), splitted[1].strip().strip('\''))

orig_classes = dict(map(process_classes, lines))

imagenette_classes = dict(enumerate(['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']))

for k, v in imagenette_classes.items():
    for k1, v1 in orig_classes.items():
        if v in v1:
            imagenette_classes[k] = k1

In [5]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, datasource, transforms: callable, ramming: bool = False):
        super().__init__()
        self.transforms = transforms
        self.ramming = ramming
        if ramming:
            ram_data = []
            for i in range(len(datasource)):
                data = datasource[i]
                ram_data.append({'image': data['image'], 'label': data['label']})
            self.datasource = ram_data
        else:
            self.datasource = datasource

    def __len__(self) -> int:
        return len(self.datasource)

    def __getitem__(self, index: int) -> torch.Tensor:
        data = self.datasource[index]
        image, label = data['image'], data['label']
        if image.mode != 'RGB':
            image = Image.fromarray(np.array(image)[..., None].repeat(3, -1))
        return self.transforms(image), imagenette_classes[label]

In [6]:
imagenette_train = load_dataset('frgfm/imagenette', '320px', split='train')
imagenette_valid = load_dataset('frgfm/imagenette', '320px', split='validation')

tiny_imagenet_train = load_dataset('Maysee/tiny-imagenet', split='train')
tiny_imagenet_valid = load_dataset('Maysee/tiny-imagenet', split='valid')

In [7]:
num_workers = 1
batch_size = 1

In [8]:
# trainset = Dataset(datasource=tiny_imagenet_train, transforms=transforms())
tf = transforms()
trainset = Dataset(datasource=imagenette_train, transforms=tf)
validset = Dataset(datasource=imagenette_valid, transforms=tf, ramming=True)
valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
# valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, batch_size=batch_size)

In [9]:
def nbytes(model: torch.nn.Module):
    n = 0
    for p in model.parameters():
        n += p.nbytes

    return n / 1024 ** 2

In [10]:
nbytes(model)

330.2294006347656

In [11]:
from torch.profiler import profile, record_function, ProfilerActivity
from itertools import product
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx
import gc
from contextlib import nullcontext
from timeit import timeit
import time
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import datetime

def fix_seed(worker_id=0, seed=0xBADCAFE):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed()

torch_generator = torch.Generator()
torch_generator.manual_seed(0xBADCAFFE)

<torch._C.Generator at 0x7f2b90354c50>

In [12]:
examples_inputs = [validset[i][0] for i in range(len(validset) // 100)]

In [13]:
matmul_precision = ['medium', 'high', 'highest']
quantization = ['None', 'x86', 'fbgemm']
mixed_precision = ['']
batch_sizes = [1, 4]
num_workers = 1

In [14]:
def run_epoch(model, valid_dataloader, limit=2**32):
    T = 0.0
    Y = []
    Y_hat = []
    for i, (x, y) in enumerate(valid_dataloader):
        if i >= limit:
            break
        Y.append(y.ravel())
        start = time.time()
        y_hat = model(x)
        end = time.time()
        Y_hat.append(y_hat.argmax(-1))
        T += end - start
    return accuracy_score(np.array(Y).ravel(), np.array(Y_hat).ravel()), T

In [16]:
limit = 16
T = {}
date_time = datetime.datetime.now()
accuracy = {}
with torch.no_grad():
    with open(f'profiling{date_time}.txt', 'w+') as fopen:
        for prec, quant, bs in tqdm(product(matmul_precision, quantization, batch_sizes)):
            valid_dataloader = torch.utils.data.DataLoader(validset, num_workers=num_workers, 
                                                        batch_size=batch_size, shuffle=True, 
                                                        worker_init_fn=fix_seed, generator=torch_generator)
            model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1).eval()
            torch.set_float32_matmul_precision(prec)
            if quant != 'None':
                torch.backends.quantized.engine = quant
                qconfig_mapping = get_default_qconfig_mapping(quant)
                prepared_model = prepare_fx(model, qconfig_mapping, example_inputs=examples_inputs)
                model = convert_fx(prepared_model)
            key  = '_'.join(map(str, [prec, quant, bs, round(nbytes(model))]))
            with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
                for i, (x, y) in enumerate(valid_dataloader):
                    with record_function("model_inference"):
                        if i >= limit:
                            break
                        model(x)
            fopen.write(f'{key}\n')
            fopen.write(prof.key_averages().table(sort_by="cpu_time_total"))

            # acc, t = run_epoch(model, valid_dataloader, limit)
            
            # T[key] = np.round(t / (min(limit, len(valid_dataloader)) * bs), 3)
            # accuracy[key] = np.round(acc, 3)
            gc.collect()

0it [00:00, ?it/s]STAGE:2023-10-12 19:12:40 82527:82527 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-12 19:12:43 82527:82527 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-12 19:12:43 82527:82527 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
1it [00:06,  6.92s/it]STAGE:2023-10-12 19:12:47 82527:82527 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-12 19:12:50 82527:82527 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-12 19:12:50 82527:82527 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
STAGE:2023-10-12 19:12:57 82527:82527 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-12 19:13:00 82527:82527 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-12 19:13:00 82527:82527 ActivityProfilerController.cpp:

In [None]:
w = next(model.parameters()).detach().numpy()

In [None]:
model.state_dict().keys()

odict_keys(['class_token', 'conv_proj_input_scale_0', 'conv_proj_input_zero_point_0', '_scale_2', '_zero_point_2', 'encoder_scale_0', 'encoder_zero_point_0', 'encoder_scale_1', 'encoder_zero_point_1', 'encoder_layers_encoder_layer_0_scale_0', 'encoder_layers_encoder_layer_0_zero_point_0', 'encoder_layers_encoder_layer_0_scale_1', 'encoder_layers_encoder_layer_0_zero_point_1', 'encoder_layers_encoder_layer_0_mlp_1_scale_0', 'encoder_layers_encoder_layer_0_mlp_1_zero_point_0', 'encoder_layers_encoder_layer_0_scale_2', 'encoder_layers_encoder_layer_0_zero_point_2', 'encoder_layers_encoder_layer_1_scale_0', 'encoder_layers_encoder_layer_1_zero_point_0', 'encoder_layers_encoder_layer_1_scale_1', 'encoder_layers_encoder_layer_1_zero_point_1', 'encoder_layers_encoder_layer_1_mlp_1_scale_0', 'encoder_layers_encoder_layer_1_mlp_1_zero_point_0', 'encoder_layers_encoder_layer_1_scale_2', 'encoder_layers_encoder_layer_1_zero_point_2', 'encoder_layers_encoder_layer_2_scale_0', 'encoder_layers_encod

In [None]:
model.state_dict()['encoder_layers_encoder_layer_1_scale_2']

tensor(1.)

In [None]:
model.conv_proj.weight()[-1].data.ravel()

tensor([ 0.0020, -0.0047,  0.0108, -0.0094, -0.0087,  0.0128,  0.0182,  0.0296,
         0.0182, -0.0007,  0.0148,  0.0249,  0.0081,  0.0141, -0.0061,  0.0020,
        -0.0067, -0.0040,  0.0168,  0.0094, -0.0067, -0.0067,  0.0020,  0.0222,
         0.0081, -0.0168, -0.0034,  0.0108, -0.0135, -0.0108, -0.0188, -0.0007,
        -0.0040, -0.0081,  0.0309,  0.0323,  0.0168, -0.0128, -0.0161,  0.0155,
        -0.0020, -0.0377, -0.0121,  0.0094, -0.0148, -0.0047, -0.0108,  0.0155,
        -0.0074, -0.0202,  0.0061,  0.0087, -0.0067, -0.0457, -0.0605, -0.0148,
        -0.0242, -0.0531, -0.0128,  0.0128, -0.0135, -0.0040, -0.0094,  0.0175,
         0.0155, -0.0094,  0.0020, -0.0027, -0.0087, -0.0309, -0.0336,  0.0168,
         0.0101, -0.0148,  0.0229,  0.0336, -0.0027,  0.0054, -0.0054,  0.0161,
         0.0242,  0.0034,  0.0114, -0.0040,  0.0074,  0.0175,  0.0269,  0.0437,
         0.0114, -0.0034,  0.0256,  0.0081, -0.0121,  0.0074, -0.0040,  0.0087,
         0.0040, -0.0175, -0.0202, -0.02

In [None]:
for key, t in T.items():
    print(key, 'time:', t, 'acc', accuracy[key])

In [None]:

date_time = datetime.datetime.now()
date_time
with open(f'results{date_time}.txt', 'w+') as fopen:
    for key, t in T.items():
        fopen.write(f"{key}, time:, {t}, acc, {accuracy[key]}\n")