In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

In [2]:
# Work from the machop package root: the checkpoint paths used later in this
# notebook are relative ("../mase_output/...") and resolve against this directory.
# Set the MASE_MACHOP_DIR environment variable to point at a different checkout;
# when it is unset the original author's location is used unchanged.
os.chdir(os.environ.get("MASE_MACHOP_DIR", '/home/honghaoyang/mase/machop'))

In [3]:
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model

# Configure the chop logger at "info" level for the rest of the notebook.
set_logging_verbosity("info")

  from .autonotebook import tqdm as notebook_tqdm
[32mINFO    [0m [34mSet logging level to info[0m


In [4]:
# Experiment settings: batch size plus the model / dataset identifiers.
batch_size, model_name, dataset_name = 8, "jsc-tiny", "jsc"

# Build the data module for the selected dataset/model pair, then run its
# prepare and setup hooks so the dataloaders can be used below.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=0
)
data_module.prepare_data()
data_module.setup()

In [5]:
# Checkpoint from an earlier training run; the path is relative to the
# current working directory.
CHECKPOINT_PATH = "../mase_output/jsc-tiny_classification_jsc_2024-01-29/software/training_ckpts/best.ckpt"
model_info = get_model_info(model_name)

# Create the classifier without pretrained weights ...
model = get_model(
    model_name, task="cls", dataset_info=data_module.dataset_info, pretrained=False
)
# ... then restore its weights from the PyTorch Lightning ("pl") checkpoint.
model = load_model(model=model, load_name=CHECKPOINT_PATH, load_type="pl")

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from ../mase_output/jsc-tiny_classification_jsc_2024-01-29/software/training_ckpts/best.ckpt[0m


In [6]:
# Input generator over the "train" dataloader of the data module; each item
# it yields is unpacked as keyword arguments for the model.
input_generator = InputGenerator(
    data_module=data_module, model_info=model_info, task="cls", which_dataloader="train"
)

# Take one batch and run a forward pass as a smoke test; the output is discarded.
# `dummy_in` is reused below when graph metadata is populated.
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

In [7]:
# Generate the mase graph from the loaded model and initialize node metadata.
# The three passes run in sequence, each receiving the graph returned by the
# previous one; the second element of every returned pair is discarded.
mg = MaseGraph(model=model)
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# Arguments for the statistics-profiling pass.
pass_args = {
    "by": "type",                                                            # select target nodes by node type (not by node name)
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

# Profile the statistics, then report only the "software" metadata of each node.
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})
# Bare expression: the cell output shows the MaseGraph repr.
mg

Profiling weight statistics: 100%|██████████| 6/6 [00:00<00:00, 11310.48it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 578.48it/s]
[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+--------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------+
| Node name    | Fx Node op   | Mase type           | Mase op      | Software Param                                                                          |
| x            | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                               |
+--------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------+
| seq_blocks_0 | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},       

<chop.ir.graph.mase_graph.MaseGraph at 0x7f7a87e06200>

In [8]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph

# Config for nodes of type "linear": "integer" quantization with a width of 8
# and a fractional width of 4 for the input data, the weight and the bias.
linear_quant_config = {
    "name": "integer",
    "data_in_width": 8,
    "data_in_frac_width": 4,
    "weight_width": 8,
    "weight_frac_width": 4,
    "bias_width": 8,
    "bias_frac_width": 4,
}

# Quantization is selected per node type; every other type falls back to the
# "default" entry, whose config name is None.
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {"config": linear_quant_config},
}

# Separate graph traced from the same model, with common metadata attached,
# kept for comparison; it is not passed to the quantize pass.
ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

# Apply the quantization transform to the profiled graph.
mg, _ = quantize_transform_pass(mg, pass_args)

In [9]:
### Walk the transformed graph and print the "common" metadata of the node
### named seq_blocks_2 (a linear op); every other node is skipped.
for node in mg.fx_graph.nodes:
    if node.name != "seq_blocks_2":
        continue
    print(node.meta["mase"].parameters["common"])


{'mase_type': 'module_related_func', 'mase_op': 'linear', 'args': {'data_in_0': {'shape': [8, 16], 'torch_dtype': torch.float32, 'type': 'integer', 'precision': [8, 4], 'value': tensor([[0.3986, 0.3627, 2.6615, 3.5014, 2.4787, 1.8169, 0.0690, 0.0000, 0.0953,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.6859, 0.7919],
        [1.4392, 0.4052, 0.0000, 0.0000, 0.0000, 0.0000, 0.2687, 0.2495, 0.2791,
         1.3249, 1.9455, 1.2491, 0.2827, 0.7280, 0.0000, 0.0000],
        [0.0000, 0.6823, 0.0000, 0.0000, 0.0000, 0.0000, 2.0834, 2.2368, 1.9493,
         1.2822, 1.2004, 0.0262, 0.4339, 0.5414, 0.0000, 2.8631],
        [0.3218, 1.8367, 1.1289, 0.6413, 2.0010, 2.9984, 1.5374, 1.8358, 1.4468,
         0.0000, 0.0000, 0.0000, 0.0000, 0.4484, 1.4463, 0.0000],
        [0.0000, 0.8694, 0.0000, 0.0000, 0.0000, 0.0000, 0.6381, 0.5723, 0.6191,
         1.7220, 2.0773, 1.4594, 0.6798, 1.3900, 0.0000, 0.2469],
        [0.0000, 1.7612, 2.5714, 2.4181, 1.1607, 0.7127, 0.0000, 0.0000, 0.0000,
      

In [10]:
### Build the original graph again from `model` and attach common metadata.
ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

### Walk the original graph; for the seq_blocks_2 node (a linear op) keep its
### weight value in `ori_mg_weights` and print its "common" metadata.
for node in ori_mg.fx_graph.nodes:
    if node.name != "seq_blocks_2":
        continue
    common_meta = node.meta["mase"].parameters["common"]
    ori_mg_weights = common_meta["args"]["weight"]["value"]
    print(common_meta)

{'mase_type': 'module_related_func', 'mase_op': 'linear', 'args': {'data_in_0': {'shape': [8, 16], 'torch_dtype': torch.float32, 'type': 'float', 'precision': [32], 'value': tensor([[0.3986, 0.3627, 2.6615, 3.5014, 2.4787, 1.8169, 0.0690, 0.0000, 0.0953,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.6859, 0.7919],
        [1.4392, 0.4052, 0.0000, 0.0000, 0.0000, 0.0000, 0.2687, 0.2495, 0.2791,
         1.3249, 1.9455, 1.2491, 0.2827, 0.7280, 0.0000, 0.0000],
        [0.0000, 0.6823, 0.0000, 0.0000, 0.0000, 0.0000, 2.0834, 2.2368, 1.9493,
         1.2822, 1.2004, 0.0262, 0.4339, 0.5414, 0.0000, 2.8631],
        [0.3218, 1.8367, 1.1289, 0.6413, 2.0010, 2.9984, 1.5374, 1.8358, 1.4468,
         0.0000, 0.0000, 0.0000, 0.0000, 0.4484, 1.4463, 0.0000],
        [0.0000, 0.8694, 0.0000, 0.0000, 0.0000, 0.0000, 0.6381, 0.5723, 0.6191,
         1.7220, 2.0773, 1.4594, 0.6798, 1.3900, 0.0000, 0.2469],
        [0.0000, 1.7612, 2.5714, 2.4181, 1.1607, 0.7127, 0.0000, 0.0000, 0.0000,
         0

In [11]:
### Optional task: count FLOPs on `mg`, the graph returned by
### quantize_transform_pass earlier in this notebook.
from chop.passes.graph.analysis.hhy_lab_pass import flop_count

# count_flops returns a pair; the second element is the summary printed below.
mg_flop, total_flop  = flop_count.count_flops(mg)

print(total_flop)

{'total_flops': 680}


In [12]:
# Count bit-level statistics on the quantized graph `mg`.
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'Node' object has no attribute 'mase'", while the same call
# on `ori_mg` further down succeeds. The cause is inside flop_count.count_bitops,
# which is not visible here; confirm how it reads node metadata for quantized nodes.
mg_bit, total_bit = flop_count.count_bitops(mg)
print(total_bit)

AttributeError: 'Node' object has no attribute 'mase'

In [None]:
# Same FLOP count, this time on the original graph `ori_mg`, for comparison
# with the figure printed for `mg`.
ori_mg_flop, ori_total_flop  = flop_count.count_flops(ori_mg)

print(ori_total_flop)

{'total_flops': 1320.0}


In [None]:
# Bit-level statistics of the original graph `ori_mg`; the printed summary
# reports average and overall bit counts for data and weights.
ori_mg_bit, ori_total_bit = flop_count.count_bitops(ori_mg)
print(ori_total_bit)

{'data_avg_bit': 32.0, 'w_avg_bit': 32.0, 'data_overall_bit': 4096, 'w_overall_bit': 2560}


In [12]:
### Perform quantisation on the test-hhy jsc model (bigger jsc network).
### NOTE: this cell rebinds batch_size, model_name, dataset_name, data_module,
### CHECKPOINT_PATH and model_info that were set for jsc-tiny earlier.
batch_size, model_name, dataset_name = 8, "test-hhy", "jsc"

# Data module for the new model name, followed by its prepare and setup hooks.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=0
)
data_module.prepare_data()
data_module.setup()

# Checkpoint from an earlier training run; relative to the working directory.
CHECKPOINT_PATH = "../mase_output/test-hhy_classification_jsc_2024-01-30/software/training_ckpts/best.ckpt"
model_info = get_model_info(model_name)

# Create the classifier without pretrained weights, then restore its weights
# from the PyTorch Lightning ("pl") checkpoint.
model_test = get_model(
    model_name, task="cls", dataset_info=data_module.dataset_info, pretrained=False
)
model_test = load_model(model=model_test, load_name=CHECKPOINT_PATH, load_type="pl")

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from ../mase_output/test-hhy_classification_jsc_2024-01-30/software/training_ckpts/best.ckpt[0m


In [13]:
# Input generator over the "train" dataloader of the current data module; each
# item it yields is unpacked as keyword arguments for the model.
input_generator = InputGenerator(
    data_module=data_module, model_info=model_info, task="cls", which_dataloader="train"
)

# Take one batch and run a forward pass through model_test as a smoke test;
# the output is discarded.
dummy_in = next(iter(input_generator))
_ = model_test(**dummy_in)

In [14]:
# Generate the mase graph from model_test and initialize node metadata.
# The three passes run in sequence, each receiving the graph returned by the
# previous one; the second element of every returned pair is discarded.
mg = MaseGraph(model=model_test)
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# Arguments for the statistics-profiling pass.
pass_args = {
    "by": "type",                                                            # select target nodes by node type (not by node name)
    "target_weight_nodes": ["linear"],                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

# Profile the statistics, then report only the "software" metadata of each node.
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})
# Bare expression: the cell output shows the MaseGraph repr.
mg

Profiling weight statistics: 100%|██████████| 16/16 [00:00<00:00, 16476.52it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 274.91it/s]
[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+---------------+--------------+---------------------+--------------+------------------------------------------------------------------------------------------+
| Node name     | Fx Node op   | Mase type           | Mase op      | Software Param                                                                           |
| x             | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                                |
+---------------+--------------+---------------------+--------------+------------------------------------------------------------------------------------------+
| seq_blocks_0  | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': 

<chop.ir.graph.mase_graph.MaseGraph at 0x7f4fad1f7f10>

In [15]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)

# Config for nodes of type "linear": "integer" quantization with a width of 8
# and a fractional width of 4 for the input data, the weight and the bias.
linear_quant_config = {
    "name": "integer",
    "data_in_width": 8,
    "data_in_frac_width": 4,
    "weight_width": 8,
    "weight_frac_width": 4,
    "bias_width": 8,
    "bias_frac_width": 4,
}

# Quantization is selected per node type; every other type falls back to the
# "default" entry, whose config name is None.
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {"config": linear_quant_config},
}

# Separate graph traced from model_test, with common metadata attached, kept
# for comparison; it is not passed to the quantize pass.
ori_mg = MaseGraph(model=model_test)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

# Apply the quantization transform to the profiled graph.
mg, _ = quantize_transform_pass(mg, pass_args)

In [43]:
### test quantization

# Take one (inputs, labels) batch from the validation dataloader; `xs` is
# reused by a later cell.
test_x = iter(data_module.val_dataloader())
xs, ys = next(test_x)


# Print the module attached to EVERY node of the quantized graph (not only the
# linear ones), so replaced layers can be spotted in the listing. Nodes with no
# attached module print None.
for node in mg.fx_graph.nodes:
    print(node.meta["mase"].module)

None
BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
LinearInteger(in_features=16, out_features=32, bias=True)
BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
LinearInteger(in_features=32, out_features=16, bias=True)
BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
LinearInteger(in_features=16, out_features=8, bias=True)
BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
LinearInteger(in_features=8, out_features=5, bias=True)
BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
None


In [17]:
# Run the model object reachable from a node's mase metadata on the validation
# batch `xs` and print the result. Only the first node is used: the loop breaks
# after one iteration.
# NOTE(review): `xs` comes from the "test quantization" cell, whose execution
# count is out of sequence; confirm this still works on Restart & Run All.
for node in ori_mg.fx_graph.nodes:
    print(node.meta["mase"].model(xs))
    break

tensor([[0.1210, 0.0000, 0.0000, 0.0000, 3.6378],
        [0.0000, 0.0000, 2.1936, 1.7970, 0.0664],
        [0.2963, 1.8251, 0.0000, 0.0000, 0.0000],
        [0.9156, 0.1657, 0.0000, 1.4720, 0.8020],
        [3.5517, 1.0728, 0.0000, 0.0000, 0.0000],
        [0.4814, 1.9985, 0.4557, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4554, 2.5201, 2.1980],
        [0.0000, 0.0000, 3.3599, 1.5023, 0.0000]], grad_fn=<ReluBackward0>)


In [None]:
### Optional task
