In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import get_logger

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model




logger = get_logger("chop")
logger.setLevel(logging.INFO)

batch_size = 8
model_name = "jsc-tiny"
dataset_name = "jsc"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
    # custom_dataset_cache_path="../../chop/dataset"
)
data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint = None)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

  from .autonotebook import tqdm as notebook_tqdm


In [25]:
#Define a Search Space


pass_args = {
"by": "type",
"default": {"config": {"name": None}},
"linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
},}

import copy
# build a search space
data_in_frac_widths = [(16, 8), (8, 6), (8, 4), (4, 2)]
w_in_frac_widths = [(16, 8), (8, 6), (8, 4), (4, 2)]
search_spaces = []
for d_config in data_in_frac_widths:
    for w_config in w_in_frac_widths:
        pass_args['linear']['config']['data_in_width'] = d_config[0]
        pass_args['linear']['config']['data_in_frac_width'] = d_config[1]
        pass_args['linear']['config']['weight_width'] = w_config[0]
        pass_args['linear']['config']['weight_frac_width'] = w_config[1]
        # dict.copy() and dict(dict) only perform shallow copies
        # in fact, only primitive data types in python are doing implicit copy when a = b happens
        search_spaces.append(copy.deepcopy(pass_args))

In [26]:
#Defining a search strategy and a runner
# grid search


import torch
import time
from torchmetrics.classification import MulticlassAccuracy
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

metric = MulticlassAccuracy(num_classes=5)
num_batchs = 5
# This first loop is basically our search strategy,
# in this case, it is a simple brute force search

recorded_accs = []
for i, config in enumerate(search_spaces):
    mg, _ = quantize_transform_pass(mg, config)
    j = 0
    

    # Measure model size after quantization
    temp_model_path = "temp_model.pth"
    torch.save(mg.model.state_dict(), temp_model_path)
    model_size = os.path.getsize(temp_model_path) / (1024 * 1024)  # Size in MB
    os.remove(temp_model_path)  # Clean up the temporary file

    # this is the inner loop, where we also call it as a runner.
    acc_avg, loss_avg = 0, 0
    accs, losses = [], []
    latencies = []
    for inputs in data_module.train_dataloader():
        xs, ys = inputs

        #latency
        start_time = time.time()
        preds = mg.model(xs)
        end_time = time.time()
        latencies.append(end_time - start_time)
        #accuracy; loss
        loss = torch.nn.functional.cross_entropy(preds, ys)
        acc = metric(preds, ys)
        accs.append(acc)
        losses.append(loss)
        if j > num_batchs:
            break
        j += 1

        
    acc_avg = sum(accs) / len(accs)
    loss_avg = sum(losses) / len(losses)
    latency_avg = sum(latencies) / len(latencies)
    recorded_accs.append(acc_avg)
    
    #print(acc_avg, loss_avg)
    #print(loss_avg)
    #print((end_time - start_time))

    print(f"Accuracy: {acc_avg:.4g}, Loss: {loss_avg:.4g}, Latency: {latency_avg} seconds, Model Size: {model_size} MB")



Accuracy: 0.3004, Loss: 1.566, Latency: 0.00023865699768066406 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.2512, Loss: 1.555, Latency: 0.00022520337785993303 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.1976, Loss: 1.626, Latency: 0.00021123886108398438 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.07143, Loss: 1.622, Latency: 0.00021542821611676897 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.2024, Loss: 1.554, Latency: 0.00021682466779436384 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.306, Loss: 1.524, Latency: 0.00022833687918526785 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.1143, Loss: 1.633, Latency: 0.00021420206342424666 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.2333, Loss: 1.551, Latency: 0.00020994458879743303 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.2315, Loss: 1.559, Latency: 0.0002005781446184431 seconds, Model Size: 0.0034532546997070312 MB
Accuracy: 0.2738, Lo