In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

In [2]:
import os

# The rest of the notebook uses paths relative to the machop directory
# (e.g. "../mase_output/..." checkpoints), so the kernel must run from there.
# The location can be overridden with the MASE_MACHOP_DIR environment variable;
# when it is unset, the original author's path is used, so existing behaviour
# is unchanged.
MACHOP_DIR = os.environ.get("MASE_MACHOP_DIR", '/home/honghaoyang/mase_hhy/machop')
os.chdir(MACHOP_DIR)

In [3]:
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model

# Set the logging level to "info"; later cells print INFO messages
# (e.g. the checkpoint-load confirmation) that depend on this.
set_logging_verbosity("info")

  from .autonotebook import tqdm as notebook_tqdm
2024-01-30 22:24:54.325610: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[32mINFO    [0m [34mSet logging level to info[0m


In [4]:
# Experiment configuration: which dataset / model pair to load and the batch size.
dataset_name = "jsc"
model_name = "jsc-tiny"
batch_size = 8

# Build the data module for that pair, then run its two preparation
# stages in order (prepare_data first, setup second).
data_module = MaseDataModule(
    model_name=model_name,
    name=dataset_name,
    num_workers=0,
    batch_size=batch_size,
)
data_module.prepare_data()
data_module.setup()

In [5]:
# Checkpoint produced by an earlier training run; the path is relative to the
# current working directory.
CHECKPOINT_PATH = "../mase_output/jsc-tiny_classification_jsc_2024-01-29/software/training_ckpts/best.ckpt"

model_info = get_model_info(model_name)

# Step 1: instantiate the classification model without pretrained weights.
model = get_model(
    model_name,
    task="cls",
    pretrained=False,
    dataset_info=data_module.dataset_info,
)

# Step 2: load the checkpoint into it (load_type "pl" = pytorch lightning checkpoint).
model = load_model(model=model, load_type="pl", load_name=CHECKPOINT_PATH)

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from ../mase_output/jsc-tiny_classification_jsc_2024-01-29/software/training_ckpts/best.ckpt[0m


In [6]:
# Input generator over the training dataloader of the data module.
input_generator = InputGenerator(
    model_info=model_info,
    data_module=data_module,
    which_dataloader="train",
    task="cls",
)

# Take one item from the generator and run a single forward pass as a smoke
# test. dummy_in is unpacked with ** below, so it is a mapping of keyword
# arguments for the model; the forward result itself is discarded.
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

In [7]:
# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

pass_args = {
    "by": "type",                                                            # collect statistics by node name
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})
mg

Profiling weight statistics: 100%|██████████| 6/6 [00:00<00:00, 8450.58it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 660.75it/s]
[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+--------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------+
| Node name    | Fx Node op   | Mase type           | Mase op      | Software Param                                                                          |
| x            | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                               |
+--------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------+
| seq_blocks_0 | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},        

<chop.ir.graph.mase_graph.MaseGraph at 0x7f1c59b03640>

In [8]:
# NOTE(review): these imports sit in the middle of the notebook; consider moving
# them to the top import cell. summarize_quantization_analysis_pass is imported
# but not used in the cells visible here.
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
# NOTE(review): MaseGraph is already imported from chop.ir in an earlier cell.
from chop.ir.graph.mase_graph import MaseGraph

# Configuration for the quantisation transform. This rebinds the name
# `pass_args`, which previously held the profiling configuration.
pass_args = {
    "by": "type",
    # presumably: node types without their own entry get no quantiser — TODO confirm
    "default": {"config": {"name": None}},
    # Linear layers use the "integer" quantiser with width 8 / frac width 4 for
    # data, weight and bias (the recorded metadata shows 'precision': [8, 4]).
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

# Separate, un-transformed graph built from the same model, kept for comparison
# with the quantised one (software metadata is not added to it).
ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

# Rebinds `mg` to the transformed graph.
# NOTE(review): re-running this cell alone feeds the already-transformed mg back into the pass.
mg, _ = quantize_transform_pass(mg, pass_args)

In [23]:
### Walk the quantised graph and print the "common" metadata of one node.
for node in mg.fx_graph.nodes:
    # Skip everything except seq_blocks_2 (a linear layer, per its printed mase_op).
    if node.name != "seq_blocks_2":
        continue
    print(node.meta["mase"].parameters["common"])

{'mase_type': 'module_related_func', 'mase_op': 'linear', 'args': {'data_in_0': {'shape': [8, 16], 'torch_dtype': torch.float32, 'type': 'integer', 'precision': [8, 4], 'value': tensor([[2.9449, 0.3909, 0.0518, 0.0000, 0.2384, 0.0000, 2.0124, 0.0000, 1.8840,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7085, 0.0000],
        [0.2402, 1.3345, 0.9587, 0.4929, 0.3547, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1844, 0.0000],
        [0.0000, 1.8296, 2.7684, 3.4358, 2.1908, 1.8690, 0.3045, 0.1490, 0.3121,
         0.4894, 0.7254, 0.0000, 0.0000, 0.0729, 2.4858, 2.2498],
        [0.0000, 1.6255, 0.0000, 0.0000, 0.0000, 0.0000, 1.1004, 1.5476, 1.0446,
         1.4316, 1.6664, 0.5808, 0.5228, 0.7410, 0.0000, 0.0000],
        [0.0000, 1.1303, 1.5073, 1.0451, 0.3286, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1829, 1.8319],
        [0.0000, 0.4226, 2.2974, 2.2002, 2.4752, 3.1566, 1.5582, 2.2048, 1.4660,
      

In [24]:
### Build a fresh, un-quantised graph from the same model and dummy input.
ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

### Walk it and print the "common" metadata of the same node inspected on mg.
for node in ori_mg.fx_graph.nodes:
    # Skip everything except seq_blocks_2 (a linear layer, per its printed mase_op).
    if node.name != "seq_blocks_2":
        continue
    print(node.meta["mase"].parameters["common"])

{'mase_type': 'module_related_func', 'mase_op': 'linear', 'args': {'data_in_0': {'shape': [8, 16], 'torch_dtype': torch.float32, 'type': 'float', 'precision': [32], 'value': tensor([[2.9449, 0.3909, 0.0518, 0.0000, 0.2384, 0.0000, 2.0124, 0.0000, 1.8840,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7085, 0.0000],
        [0.2402, 1.3345, 0.9587, 0.4929, 0.3547, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1844, 0.0000],
        [0.0000, 1.8296, 2.7684, 3.4358, 2.1908, 1.8690, 0.3045, 0.1490, 0.3121,
         0.4894, 0.7254, 0.0000, 0.0000, 0.0729, 2.4858, 2.2498],
        [0.0000, 1.6255, 0.0000, 0.0000, 0.0000, 0.0000, 1.1004, 1.5476, 1.0446,
         1.4316, 1.6664, 0.5808, 0.5228, 0.7410, 0.0000, 0.0000],
        [0.0000, 1.1303, 1.5073, 1.0451, 0.3286, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1829, 1.8319],
        [0.0000, 0.4226, 2.2974, 2.2002, 2.4752, 3.1566, 1.5582, 2.2048, 1.4660,
         1

In [25]:
### Quantise the test-hhy jsc model (a bigger jsc network): data and model setup.

# Same dataset, different model. This rebinds batch_size, model_name,
# dataset_name, data_module, CHECKPOINT_PATH and model_info from the earlier
# jsc-tiny cells, so cells below this point see the test-hhy values.
dataset_name = "jsc"
model_name = "test-hhy"
batch_size = 8

data_module = MaseDataModule(
    model_name=model_name,
    name=dataset_name,
    num_workers=0,
    batch_size=batch_size,
)
data_module.prepare_data()
data_module.setup()

# Checkpoint path is relative to the current working directory.
CHECKPOINT_PATH = "../mase_output/test-hhy_classification_jsc_2024-01-30/software/training_ckpts/best.ckpt"
model_info = get_model_info(model_name)

# Instantiate without pretrained weights, then load the pytorch lightning checkpoint.
model_test = get_model(
    model_name,
    task="cls",
    pretrained=False,
    dataset_info=data_module.dataset_info,
)
model_test = load_model(model=model_test, load_type="pl", load_name=CHECKPOINT_PATH)

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from ../mase_output/test-hhy_classification_jsc_2024-01-30/software/training_ckpts/best.ckpt[0m


In [26]:
# Input generator over the training dataloader of the (test-hhy) data module.
input_generator = InputGenerator(
    model_info=model_info,
    data_module=data_module,
    which_dataloader="train",
    task="cls",
)

# Take one item from the generator and run a single forward pass through
# model_test as a smoke test; the forward result itself is discarded.
dummy_in = next(iter(input_generator))
_ = model_test(**dummy_in)

In [39]:
# Generate the mase graph from model_test and initialize node metadata.
# This rebinds `mg`, replacing the graph built from the earlier model.
mg = MaseGraph(model=model_test)
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# Configuration for the statistics-profiling pass.
pass_args = {
    "by": "type",                                                            # select target nodes by type (the value is "type", not "name")
    "target_weight_nodes": ["linear"],                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

# Profile the statistics, then print the "software" parameter table for every node.
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})
# Last expression of the cell: displays the MaseGraph repr.
mg

Profiling weight statistics: 100%|██████████| 16/16 [00:00<00:00, 30671.33it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 135.37it/s]
[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+---------------+--------------+---------------------+--------------+--------------------------------------------------------------------------------------+
| Node name     | Fx Node op   | Mase type           | Mase op      | Software Param                                                                       |
| x             | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                            |
+---------------+--------------+---------------------+--------------+--------------------------------------------------------------------------------------+
| seq_blocks_0  | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},            

<chop.ir.graph.mase_graph.MaseGraph at 0x7f1c50380160>

In [40]:
# Configuration for the quantisation transform. This rebinds the name
# `pass_args`, which previously held the profiling configuration.
pass_args = {
    "by": "type",
    # presumably: node types without their own entry get no quantiser — TODO confirm
    "default": {"config": {"name": None}},
    # Linear layers use the "integer" quantiser with width 8 / frac width 4 for
    # data, weight and bias.
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

# Separate, un-transformed graph built from model_test, kept for comparison
# with the quantised one (software metadata is not added to it).
ori_mg = MaseGraph(model=model_test)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

# Rebinds `mg` to the transformed graph.
# NOTE(review): re-running this cell alone feeds the already-transformed mg back into the pass.
mg, _ = quantize_transform_pass(mg, pass_args)

In [41]:
### traverse mg (the quantised test-hhy graph)
for node in mg.fx_graph.nodes:
    ### print the "common" metadata of every node — there is no filter on op or
    ### name here, so placeholder and other non-linear nodes are printed too.
    ### NOTE(review): this prints full tensor values, so the output is very long.
    print(node.meta["mase"].parameters["common"])

{'mase_type': 'placeholder', 'mase_op': 'placeholder', 'args': {}, 'results': {'data_out_0': {'type': 'float', 'precision': [32], 'shape': [8, 16], 'torch_dtype': torch.float32, 'value': tensor([[ 1.2441e+00, -1.1295e+00, -1.2436e+00, -8.4918e-01, -1.1469e+00,
         -6.2866e-01, -4.7122e-01, -5.7831e-01, -4.7122e-01, -3.0468e-01,
         -2.5359e-01, -2.2359e-01, -4.7123e-01, -4.9353e-01, -1.3168e+00,
         -9.3992e-01],
        [ 1.0926e+00, -1.0636e+00, -4.8878e-01, -4.7233e-01, -5.5337e-02,
         -7.1181e-02,  1.2746e+00,  1.2222e+00,  1.2746e+00, -9.4655e-01,
         -1.3010e+00, -1.0520e+00, -6.1425e-01, -5.6063e-01,  1.2413e-02,
          2.1765e-01],
        [-6.6305e-01,  6.6685e-01,  2.1318e+00,  2.1168e+00,  1.7596e+00,
          1.7749e+00, -3.0629e-01, -1.1556e-03, -3.0629e-01,  7.6382e-01,
          1.0673e+00,  1.6796e+00,  8.4089e-01,  1.4240e+00,  1.7242e+00,
          3.6234e-01],
        [-5.7262e-02,  2.5935e-01,  5.9235e-02, -1.1852e-01, -9.7287e-02,
    