# Installing MASE (again)

Run the block below to install MASE in the current Colab runtime

In [None]:
git_token = ""
short_code = "jaredjoss"

# Check the current python version (It should be using Python 3.10) and update pip to the latest version.
!python --version
!python -m pip install --user --upgrade pip

# Clone MASE from your branch (the branch must already exist)
!git clone -b lab1_{short_code} https://{git_token}@github.com/DeepWok/mase.git

# Install requirements
!python -m pip install -r ./mase/machop/requirements.txt

# Change working directory to machop
%cd ./mase/machop/

In [1]:
%cd ../../machop/

/Users/jared/Documents/Personal/ICL/courses/ADLS/mase/machop


In [2]:
!./ch --help

usage: ch [--config PATH] [--task TASK] [--load PATH] [--load-type]
          [--batch-size NUM] [--debug] [--log-level]
          [--report-to {wandb,tensorboard}] [--seed NUM] [--quant-config TOML]
          [--training-optimizer TYPE] [--trainer-precision TYPE]
          [--learning-rate NUM] [--weight-decay NUM] [--max-epochs NUM]
          [--max-steps NUM] [--accumulate-grad-batches NUM]
          [--log-every-n-steps NUM] [--cpu NUM] [--gpu NUM] [--nodes NUM]
          [--accelerator TYPE] [--strategy TYPE] [--auto-requeue]
          [--github-ci] [--disable-dataset-cache] [--target STR]
          [--num-targets NUM] [--pretrained] [--max-token-len NUM]
          [--project-dir DIR] [--project NAME] [-h] [-V] [--info [TYPE]]
          action [model] [dataset]

Chop is a simple utility, part of the MASE tookit, to train, test and
transform (i.e. prune or quantise) a supported model.

main arguments:
  action                action to perform. One of
                        (train|

# General introduction

In this lab, you will learn how to use the software stack of MASE. There are in total 7 tasks you would need to finish, and 1 optional task.

# Turning you network to a graph

One specific feature of MASE is its capability to transform DL models to a computation graph using the [torch.fx](<https://pytorch.org/docs/stable/fx.html>) framework.


## Use the Transform functionality without CLI

This tutorial describes how to use the MASE transform functionality for a pre-trained model.

## Import related packages and machop

In [22]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

In [23]:
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("debug")

[32mINFO    [0m [34mSet logging level to debug[0m


## Set up the dataset

Here we create a `MaseDataModule` using the `jsc` dataset from lab1. Note the `MaseDataModule` also requires the name of the model you plan to use data module with. In this case it is `jsc-tiny`.

In [24]:
batch_size = 8
model_name = "jsc-tiny"
dataset_name = "jsc"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()


## Set up the model 

Here we use the previously trained `jsc-tiny` model in lab 1 as an example.

In [25]:
# # If you stored your model checkpoint on Google Drive, remember to mount the drive to the current runtime in order to access it
# from google.colab import drive
# drive.mount('/content/drive')

In [26]:
# 📝️ change this CHECKPOINT_PATH to the one you trained in Lab1
CHECKPOINT_PATH = "/Users/jared/Documents/Personal/ICL/courses/ADLS/mase/mase_output/jsc-tiny_classification_jsc_2024-01-23/software/training_ckpts/best-v5.ckpt"
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)

model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from /Users/jared/Documents/Personal/ICL/courses/ADLS/mase/mase_output/jsc-tiny_classification_jsc_2024-01-23/software/training_ckpts/best-v5.ckpt[0m


# Get a dummy data in
With the dataset module and model information, we can grab an input generator.

In [27]:
# get the input generator
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# a demonstration of how to feed an input value to the model
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)


## Generate a MaseGraph
We have two forms of passes: transform passes and analysis passes, both of them would require the model to be transferred into a MaseGraph to allow manipulation.

In [28]:
# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

## Running an Analysis pass
Analysis pass DOES NOT change the graph

The following analysis passes are essential to prepare the graph for other passes

In [29]:
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    return seq_blocks_3[0m


In [30]:
mg

<chop.ir.graph.mase_graph.MaseGraph at 0x2a70f48e0>

We will first run a simple graph analysis to understand the structure of the model.

In [31]:
# report graph is an analysis pass that shows you the detailed information in the graph
from chop.passes.graph import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    return seq_blocks_3
Network overview:
{'placeholder': 1, 'get_attr': 0, 'call_function': 0, 'call_method': 0, 'call_module': 4, 'output': 1}
Layer types:
[BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Linear(in_features=16, out_features=5, bias=True), ReLU(inplace=True)]


## Running another Analysis pass: Profile statistics

The pass `profile_statistics_analysis_pass` collects statistics of parameters and activations, and save them to node's metadata.

Here is a list of all the supported statistics. Refer to the `__init__` of statistic classes in `chop.passes.analysis.statistical_profiler.stat` to check the args each stat class takes.

This is a more complex analysis than the previous pass, and thus it would require you to pass in additional arguments for this pass.

### Example: the range of weights & input activations of nodes

Say we want to collect the tensor-wise min-max range of the 1st `torch.nn.Linear` nodes' weights & bias, and the channel-wise 97% quantile min-max of the 1st `torch.nn.Linear` nodes' input activations. We can do the following:

In [32]:
pass_args = {
    "by": "type",                                                            # collect statistics by node name
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

We can use the `report_node_meta_param_analysis_pass` to inspect the collected statistics.

In [33]:
mg, _ = profile_statistics_analysis_pass(mg, pass_args)

Profiling weight statistics: 100%|██████████| 6/6 [00:00<00:00, 12169.16it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 705.93it/s]


In [34]:
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})

[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+--------------+--------------+---------------------+--------------+----------------------------------------------------------------------------------------+
| Node name    | Fx Node op   | Mase type           | Mase op      | Software Param                                                                         |
| x            | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                              |
+--------------+--------------+---------------------+--------------+----------------------------------------------------------------------------------------+
| seq_blocks_0 | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},                                                        |
|              |              |                     |              |           'data_in_0': {'stat': {}}

## Running a Transform pass: Quantisation

As its name suggests, the transform pass would modify the `MaseGraph`.
Similar to the previous analysis pass example, we would need to first declare the configuration for the pass.

In [35]:
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

We can then proceed to apply the transformation, in this case, we kept the original graph on purpose, so that we can print a `diff`.

In [36]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph


ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

mg, _ = quantize_transform_pass(mg, pass_args)
summarize_quantization_analysis_pass(ori_mg, mg, save_dir="quantize_summary")

[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    return seq_blocks_3[0m
[36mDEBUG   [0m [34mCompare nodes:[0m
[36mDEBUG   [0m [34m
| Ori name     | New name     | MASE_TYPE           | Mase_OP      | Original type   | Quantized type   | Changed   |
|--------------+--------------+---------------------+--------------+-----------------+------------------+-----------|
| x            | x            | placeholder         | placeholder  | x               | x                | False     |
| seq_blocks_0 | seq_blocks_0 | module              | ba

In [37]:
from chop.passes.graph.transforms.quantize.summary import graph_iterator_compare_nodes
graph_iterator_compare_nodes(ori_mg, mg)

[36mDEBUG   [0m [34mCompare nodes:[0m
[36mDEBUG   [0m [34m
| Ori name     | New name     | MASE_TYPE           | Mase_OP      | Original type   | Quantized type   | Changed   |
|--------------+--------------+---------------------+--------------+-----------------+------------------+-----------|
| x            | x            | placeholder         | placeholder  | x               | x                | False     |
| seq_blocks_0 | seq_blocks_0 | module              | batch_norm1d | BatchNorm1d     | BatchNorm1d      | False     |
| seq_blocks_1 | seq_blocks_1 | module_related_func | relu         | ReLU            | ReLU             | False     |
| seq_blocks_2 | seq_blocks_2 | module_related_func | linear       | Linear          | LinearInteger    | True      |
| seq_blocks_3 | seq_blocks_3 | module_related_func | relu         | ReLU            | ReLU             | False     |
| output       | output       | output              | output       | output          | output           | Fa

Unnamed: 0,Ori name,New name,MASE_TYPE,Mase_OP,Original type,Quantized type,Changed
0,x,x,placeholder,placeholder,x,x,False
1,seq_blocks_0,seq_blocks_0,module,batch_norm1d,BatchNorm1d,BatchNorm1d,False
2,seq_blocks_1,seq_blocks_1,module_related_func,relu,ReLU,ReLU,False
3,seq_blocks_2,seq_blocks_2,module_related_func,linear,Linear,LinearInteger,True
4,seq_blocks_3,seq_blocks_3,module_related_func,relu,ReLU,ReLU,False
5,output,output,output,output,output,output,False


In [61]:
from chop.ir.graph.mase_graph import MaseGraph
from chop.passes.graph.utils import get_mase_op, get_mase_type, get_node_actual_target
import torch

# iterate over original and modified graphs
for ori_n, n in zip(ori_mg.fx_graph.nodes, mg.fx_graph.nodes):
    # check if the original node and the modified node are the same
    # if they arent, then it means that that node has been quantized
    if type(get_node_actual_target(n)) != type(get_node_actual_target(ori_n)):
        # retrieve the original module from the node
        ori_module = get_node_actual_target(ori_n)
        # retrieve the quantized module from the node
        quant_module = get_node_actual_target(n)

        print(type(quant_module))

        print(f'Difference found at name: {n.name}, '
              f'MASE type: {get_mase_type(n)}, MASE operation: {get_mase_op(n)}\n'
              f'Original module: {type(ori_module)} --> '
              f'New module: {type(quant_module)}')
        
        # Get the precision and types of the weights of the nodes from their metadata
        mg_precision = n.meta["mase"].parameters["common"]["args"]["weight"]["precision"]
        ori_mg_precision = ori_n.meta["mase"].parameters["common"]["args"]["weight"]["precision"]

        mg_type = n.meta["mase"].parameters["common"]["args"]["weight"]["type"]
        ori_mg_type = ori_n.meta["mase"].parameters["common"]["args"]["weight"]["type"]

        print(f'Precision of original module: {ori_mg_precision}')
        print(f'Precision of modified module: {mg_precision}')

        # print the weights of the original and quantized modules
        print(f'Weight of original module: {ori_module.weight}')
        quantized_weights = quant_module.w_quantizer(ori_module.weight)
        print(f'Weight of quantized module: {quantized_weights}')

        # generate a test input tensor based on the input feature size of the quantized module
        test_input = torch.randn(quant_module.in_features)
        print(f'Random generated test input: {test_input}')

        # apply the original and quantized modules to the test input and print the outputs
        print(f'Output for original module: {ori_module(test_input)}')
        print(f'Output for quantized module: {quant_module(test_input)}')


<class 'chop.passes.graph.transforms.quantize.quantized_modules.linear.LinearInteger'>
Difference found at name: seq_blocks_2, MASE type: module_related_func, MASE operation: linear
Original module: <class 'torch.nn.modules.linear.Linear'> --> New module: <class 'chop.passes.graph.transforms.quantize.quantized_modules.linear.LinearInteger'>
Precision of original module: [32]
Precision of modified module: [8, 4]
Weight of original module: Parameter containing:
tensor([[-2.0059e-01,  7.0937e-02, -9.0220e-01, -3.3770e-01, -2.1713e-01,
         -4.4070e-01, -1.0396e-02, -2.6671e-02,  1.2821e-01,  5.0434e-01,
          6.0154e-02, -9.6376e-02, -1.2634e-01, -5.2315e-02,  1.1104e-01,
         -1.4309e-02],
        [-5.8988e-02,  7.6669e-02,  1.1914e-01,  1.6312e-01, -3.9841e-01,
          1.9455e-01,  1.9242e-02,  2.9598e-01,  1.4485e-01, -1.7565e-01,
          1.7875e-01, -2.9895e-01,  6.2455e-02, -2.7527e-01, -9.8624e-02,
          6.7748e-02],
        [-6.6721e-02,  5.2548e-02, -4.9106e-01



# Exercises:

We have now seen how to:
1. Set up a dataset
2. Set up a model
3. Generate a `MaseGraph` from the model
4. Run Analysis and Transform passes on the `MaseGraph`

Now consider the following problems:

1. Explain the functionality of `report_graph_analysis_pass` and its printed jargons such as `placeholder`, `get_attr` ... You might find the doc of [torch.fx](https://pytorch.org/docs/stable/fx.html) useful.

2. What are the functionalities of `profile_statistics_analysis_pass` and `report_node_meta_param_analysis_pass` respectively?

## MASE OPs and MASE Types

MASE is designed to be a very high-level intermediate representation (IR), this is very different from the classic [LLVM IR](https://llvm.org/docs/LangRef.html) that you might be familiar with.

The following MASE Types are available:
(Note from Aaron: do we have a page somewhere that have summarized this?)


## A deeper dive into the quantisation transform

3. Explain why only 1 OP is changed after the `quantize_transform_pass` .

4. Write some code to traverse both `mg` and `ori_mg`, check and comment on the nodes in these two graphs. You might find the source code for the implementation of `summarize_quantization_analysis_pass` useful.

5. Perform the same quantisation flow to the bigger JSC network that you have trained in lab1. You must be aware that now the `pass_args` for your custom network might be different if you have used more than the `Linear` layer in your network.

6. Write code to show and verify that the weights of these layers are indeed quantised. You might need to go through the source code of the implementation of the quantisation pass and also the implementation of the [Quantized Layers](../../machop/chop/passes/transforms/quantize/quantized_modules/linear.py) .

## The command line interface

The same flow can also be executed on the command line throw the `transform` action.

```bash
# make sure you have the same printout
pwd
# it should show
# your_dir/mase-tools/machop

# enter the following command
./ch transform --config configs/examples/jsc_toy_by_type.toml --task cls --cpu=0
```
7. Load your own pre-trained JSC network, and perform perform the quantisation using the command line interface.

## \[Optional] Write your own pass

Many examples of existing passes are in the [source code](../..//machop/chop/passes/__init__.py), the [test files](../../machop/test/passes) for these passes also contain useful information on helping you to understand how these passes are used.

Implement a pass to count the number of FLOPs (floating-point operations) and BitOPs (bit-wise operations).

In [None]:
# 5.

In [49]:
batch_size = 8
model_name = "jsc-medium"
dataset_name = "jsc"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

In [50]:
# 📝️ change this CHECKPOINT_PATH to the one you trained in Lab1
CHECKPOINT_PATH = "/Users/jared/Documents/Personal/ICL/courses/ADLS/mase/mase_output/jsc-medium_classification_jsc_2024-01-24/software/training_ckpts/best-v3.ckpt"
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)

model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from /Users/jared/Documents/Personal/ICL/courses/ADLS/mase/mase_output/jsc-medium_classification_jsc_2024-01-24/software/training_ckpts/best-v3.ckpt[0m


In [51]:
# get the input generator
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

In [52]:
dummy_in = next(iter(input_generator))

In [53]:
# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})
    %seq_blocks_6 : [num_users=1] = call_module[target=seq_blocks.6](args = (%seq_blocks_5,), kwargs = {})
    %seq_blocks_7 : [num_users=1] = call_module[target=seq_blocks.7](args = (%seq_blocks_6,), kwargs = {})
    %seq_blocks_8 : [num_users=1] = call_module[target=seq_blocks.8](args = 

In [54]:
# report graph is an analysis pass that shows you the detailed information in the graph
from chop.passes.graph import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})
    %seq_blocks_6 : [num_users=1] = call_module[target=seq_blocks.6](args = (%seq_blocks_5,), kwargs = {})
    %seq_blocks_7 : [num_users=1] = call_module[target=seq_blocks.7](args = (%seq_blocks_6,), kwargs = {})
    %seq_blocks_8 : [num_users=1] = call_module[target=seq_blocks.8](args = (%seq_blocks_7,), kwarg

In [55]:
pass_args = {
    "by": "type",                                                            # collect statistics by node name
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

In [56]:
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})

Profiling weight statistics: 100%|██████████| 16/16 [00:00<00:00, 17957.95it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 249.64it/s]
[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+---------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------+
| Node name     | Fx Node op   | Mase type           | Mase op      | Software Param                                                                          |
| x             | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                               |
+---------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------+
| seq_blocks_0  | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},

In [57]:
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

In [58]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph

ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

mg, _ = quantize_transform_pass(mg, pass_args)
summarize_quantization_analysis_pass(ori_mg, mg, save_dir="quantize_summary")

[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    %seq_blocks_4 : [num_users=1] = call_module[target=seq_blocks.4](args = (%seq_blocks_3,), kwargs = {})
    %seq_blocks_5 : [num_users=1] = call_module[target=seq_blocks.5](args = (%seq_blocks_4,), kwargs = {})
    %seq_blocks_6 : [num_users=1] = call_module[target=seq_blocks.6](args = (%seq_blocks_5,), kwargs = {})
    %seq_blocks_7 : [num_users=1] = call_module[target=seq_blocks.7](args = (%seq_blocks_6,), kwargs = {})
    %seq_blocks_8 : [num_users=1] = call_module[target=seq_blocks.8](args = 

# 6.

In [60]:
# iterate over original and modified graphs
for ori_n, n in zip(ori_mg.fx_graph.nodes, mg.fx_graph.nodes):
    # check if the original node and the modified node are the same
    # if they arent, then it means that that node has been quantized
    if type(get_node_actual_target(n)) != type(get_node_actual_target(ori_n)):
        # retrieve the original module from the node
        ori_module = get_node_actual_target(ori_n)
        # retrieve the quantized module from the node
        quant_module = get_node_actual_target(n)

        print(f'Difference found at name: {n.name}, '
              f'MASE type: {get_mase_type(n)}, MASE operation: {get_mase_op(n)}\n'
              f'Original module: {type(ori_module)} --> '
              f'New module: {type(quant_module)}')
        
        # Get the precision and types of the weights of the nodes from their metadata
        mg_typemg_precision = n.meta["mase"].parameters["common"]["args"]["weight"]["precision"]
        ori_mg_precision = ori_n.meta["mase"].parameters["common"]["args"]["weight"]["precision"]

        mg_type = n.meta["mase"].parameters["common"]["args"]["weight"]["type"]
        ori_mg_type = ori_n.meta["mase"].parameters["common"]["args"]["weight"]["type"]

        print(f'Precision of original module: {ori_mg_precision}')
        print(f'Precision of modified module: {mg_precision}')

        print(f'Type of original module: {ori_mg_type}')
        print(f'Tyoe of modified module: {mg_type}')

        # print the weights of the original and quantized modules
        print(f'Weight of original module: {ori_module.weight}')
        quantized_weights = quant_module.w_quantizer(ori_module.weight)
        print(f'Weight of quantized module: {quantized_weights}')

        # generate a test input tensor based on the input feature size of the quantized module
        test_input = torch.randn(quant_module.in_features)

        # apply the original and quantized modules to the test input and print the outputs
        print(f'Output for original module: {ori_module(test_input)}')
        print(f'Output for quantized module: {quant_module(test_input)}')


Difference found at name: seq_blocks_2, MASE type: module_related_func, MASE operation: linear
Original module: <class 'torch.nn.modules.linear.Linear'> --> New module: <class 'chop.passes.graph.transforms.quantize.quantized_modules.linear.LinearInteger'>
Precision of original module: [32]
Precision of modified module: [8, 4]
Type of original module: float
Tyoe of modified module: integer
Weight of original module: Parameter containing:
tensor([[-2.0059e-01,  7.0937e-02, -9.0220e-01, -3.3770e-01, -2.1713e-01,
         -4.4070e-01, -1.0396e-02, -2.6671e-02,  1.2821e-01,  5.0434e-01,
          6.0154e-02, -9.6376e-02, -1.2634e-01, -5.2315e-02,  1.1104e-01,
         -1.4309e-02],
        [-5.8988e-02,  7.6669e-02,  1.1914e-01,  1.6312e-01, -3.9841e-01,
          1.9455e-01,  1.9242e-02,  2.9598e-01,  1.4485e-01, -1.7565e-01,
          1.7875e-01, -2.9895e-01,  6.2455e-02, -2.7527e-01, -9.8624e-02,
          6.7748e-02],
        [-6.6721e-02,  5.2548e-02, -4.9106e-01, -3.0251e-01,  3.0298e