# General introduction

In this lab, you will learn how to use the MASE software stack. There are 7 tasks to complete, plus 1 optional task.

# Turning your network into a graph

One specific feature of MASE is its capability to transform DL models to a computation graph using the [torch.fx](<https://pytorch.org/docs/stable/fx.html>) framework.


## Use the Transform functionality without CLI

This tutorial describes how to use the MASE transform functionality for a pre-trained model.

## Import related packages and machop

In [99]:
pwd

'/home/laurie2905/mase/machop'

In [100]:
%cd ../machop/

/home/laurie2905/mase/machop


In [101]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")

INFO     Set logging level to info


## Set up the dataset

Here we create a `MaseDataModule` using the `jsc` dataset from lab1. Note that the `MaseDataModule` also requires the name of the model you plan to use the data module with. In this case it is `jsc-tiny`.

In [102]:
# Why do we set batchsize to 8 here
batch_size = 8
model_name = "jsc-tiny"
dataset_name = "jsc"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

## Set up the model 

Here we use the previously trained `jsc-tiny` model in lab 1 as an example.

In [164]:
# Build the jsc-tiny model; its trained weights are loaded from the checkpoint below
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)

# Load the model from the checkpoint.
# load_name must be a plain file path: prefixing it with "cd " raises FileNotFoundError.
model = load_model(load_name="../mase_output/batch_32/jsc-tiny_classification_jsc_2024-01-25/software/training_ckpts/best.ckpt", load_type="pl", model=model)

## Get a dummy input
With the dataset module and model information, we can grab an input generator.

In [108]:
# get the input generator
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# a demonstration of how to feed an input value to the model
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)


## Generate a MaseGraph
There are two forms of passes: transform passes and analysis passes. Both require the model to be converted into a `MaseGraph` so that it can be manipulated.

In [109]:
# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

## Running an Analysis pass
An analysis pass does NOT change the graph.

The following analysis passes are essential to prepare the graph for other passes.

In [110]:
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

We will first run a simple graph analysis to understand the structure of the model.

In [111]:
# report graph is an analysis pass that shows you the detailed information in the graph
from chop.passes.graph import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    return seq_blocks_3
Network overview:
{'placeholder': 1, 'get_attr': 0, 'call_function': 0, 'call_method': 0, 'call_module': 4, 'output': 1}
Layer types:
[BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Linear(in_features=16, out_features=5, bias=True), ReLU(inplace=True)]


## Running another Analysis pass: Profile statistics

The pass `profile_statistics_analysis_pass` collects statistics of parameters and activations, and saves them to each node's metadata.

The supported statistics are implemented as classes in `chop.passes.analysis.statistical_profiler.stat`. Refer to the `__init__` of each statistic class to check the args it takes.

This is a more complex analysis than the previous pass, so it requires additional arguments.

### Example: the range of weights & input activations of nodes

Say we want to collect the tensor-wise min-max range of the first `torch.nn.Linear` node's weights and bias, and the channel-wise 97% quantile min-max of the same node's input activations. The `pass_args` used in this tutorial apply the same mechanism with a different selection: they collect the precise variance of the weights of `linear` nodes and the 97% quantile range of the activations of `relu` nodes.

Tensor-wise min-max of weights and biases: This analyses the first linear layer (`torch.nn.Linear`) in the model and collects the minimum and maximum values of its weights and biases. "Tensor-wise" means that each tensor is treated as a whole: for the weight tensor and for the bias tensor, a single minimum and a single maximum are recorded. This shows the range of values the weights and biases take, which matters for tasks such as model quantization or normalization.

Channel-wise 97% quantile min-max of input activations: This analyses the inputs to the first linear layer of the model. "Channel-wise" means the analysis is done separately for each channel of the input tensor; for a `torch.nn.Linear` layer, each channel can be thought of as one feature of the input vector. "97% quantile min-max" means finding, for each channel, the range within which 97% of the activation values fall. This captures the typical range of activation values while excluding extreme outliers, which is useful when optimizing and scaling neural networks, especially for robust quantization.
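To make "tensor-wise min-max" and "channel-wise quantile range" concrete, here is a minimal pure-Python sketch on made-up data. It is not MASE's implementation: the function names are hypothetical, and it assumes the 97% range is the central interval that drops 1.5% of the values at each tail, which may differ from how `range_quantile` defines it.

```python
# Pure-Python sketch: tensor-wise min-max vs. channel-wise quantile range.
# Assumption: the "97% quantile range" is the central interval that drops
# 1.5% of the values at each tail; MASE's own definition may differ.


def quantile(sorted_vals, p):
    """Linearly interpolated p-quantile (0 <= p <= 1) of a sorted list."""
    pos = p * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def tensor_min_max(rows):
    """One (min, max) pair over every element of a 2-D list."""
    flat = [v for row in rows for v in row]
    return min(flat), max(flat)


def channel_quantile_range(samples, q=0.97):
    """One (low, high) pair per channel; samples is a list of feature vectors."""
    tail = (1 - q) / 2
    ranges = []
    for channel in zip(*samples):  # transpose: iterate over features
        vals = sorted(channel)
        ranges.append((quantile(vals, tail), quantile(vals, 1 - tail)))
    return ranges


# Hypothetical activations: 101 samples, 2 channels.
# Channel 0 ramps from 0 to 100; channel 1 is constant except for one outlier.
activations = [[float(i), 1.0] for i in range(101)]
activations[100][1] = 1000.0

overall = tensor_min_max(activations)
per_channel = channel_quantile_range(activations, q=0.97)
print(overall)
print(per_channel)
```

The tensor-wise range is stretched by the single outlier, while the per-channel quantile range for channel 1 ignores it. This is why quantile ranges are often preferred when choosing a quantization range for activations.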

In [112]:

pass_args = {
    "by": "type",                                                            # collect statistics by node type
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

We can use the `report_node_meta_param_analysis_pass` to inspect the collected statistics.

In [113]:
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})

Profiling weight statistics: 100%|██████████| 6/6 [00:00<00:00, 9365.77it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 312.82it/s]
INFO     Inspecting graph [add_common_meta_param_analysis_pass]
INFO
+--------------+--------------+---------------------+--------------+----------------------------------------------------------------------------------------+
| Node name    | Fx Node op   | Mase type           | Mase op      | Software Param                                                                         |
| x            | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                              |
+--------------+--------------+---------------------+--------------+----------------------------------------------------------------------------------------+
| seq_blocks_0 | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},            

## Running a Transform pass: Quantisation

As its name suggests, a transform pass modifies the `MaseGraph`.
As with the previous analysis pass example, we first need to declare the configuration for the pass.

Quantization, in the context of machine learning and neural networks, is a technique that reduces the precision of the numbers used in a model's calculations. It is like using a simpler, rougher scale to measure something instead of a highly precise one. Quantization typically converts floating-point numbers (which can represent a very wide range of values with high precision) into integers (which represent a more limited range of values with less precision). For instance, a model that originally uses 32-bit floating-point numbers could use 8-bit integers after quantization. Its main benefits are:

1. Reduced Model Size: Using lower precision numbers means each number takes up less memory, so the entire model is smaller.

2. Faster Performance: Calculations with integers are generally faster than with floating-point numbers, especially on certain hardware.

3. Lower Power Consumption: This is especially important for running models on mobile devices or other hardware with limited power resources.

In [114]:
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}
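The `*_width` and `*_frac_width` entries in the config above describe a signed fixed-point format: `width` total bits, of which `frac_width` are fractional. The sketch below shows what such a format does to a single value. It is an illustration under stated assumptions (round to nearest, then saturate), not MASE's quantiser, and `quantize_fixed_point` is a hypothetical helper name.

```python
# Sketch of signed fixed-point ("integer") quantisation.
# Assumption: round to the nearest multiple of 2**-frac_width, then clamp to
# the range of a signed `width`-bit integer. MASE's quantiser may differ in
# details such as the rounding mode.


def quantize_fixed_point(x, width=8, frac_width=4):
    scale = 2 ** frac_width          # 16 steps per unit for frac_width=4
    q_min = -(2 ** (width - 1))      # -128 for width=8
    q_max = 2 ** (width - 1) - 1     # 127 for width=8
    q = round(x * scale)             # nearest integer code
    q = max(q_min, min(q_max, q))    # saturate instead of wrapping around
    return q / scale


for value in (0.8399, -3.3438, 10.0):
    print(value, "->", quantize_fixed_point(value))
```

With `width=8` and `frac_width=4`, the representable values are multiples of 1/16 between -8.0 and 7.9375, so small weights lose some precision and anything outside that range saturates.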

We can then apply the transformation. In this case, we keep the original graph on purpose so that we can print a `diff`.

In [115]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph


ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

mg, _ = quantize_transform_pass(mg, pass_args)
summarize_quantization_analysis_pass(ori_mg, mg, save_dir="quantize_summary")

INFO     Quantized graph histogram:
INFO
| Original type   | OP           |   Total |   Changed |   Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm1d     | batch_norm1d |       1 |         0 |           1 |
| Linear          | linear       |       1 |         1 |           0 |
| ReLU            | relu         |       2 |         0 |           2 |
| output          | output       |       1 |         0 |           1 |
| x               | placeholder  |       1 |         0 |           1 |




# Exercises:

We have now seen how to:
1. Set up a dataset
2. Set up a model
3. Generate a `MaseGraph` from the model
4. Run Analysis and Transform passes on the `MaseGraph`

Now consider the following problems:

1. Explain the functionality of `report_graph_analysis_pass` and the terms it prints, such as `placeholder`, `get_attr` ... You might find the doc of [torch.fx](https://pytorch.org/docs/stable/fx.html) useful.

`report_graph_analysis_pass` produces a report of a `MaseGraph`. It takes a `MaseGraph` as input, counts the nodes of each operation type and collects the module types in the graph, then returns a tuple of the `MaseGraph` and an empty dictionary.

The operations are defined as follows:

1. placeholder: Represents an input to the model. 'name' is the name assigned to the input value, 'target' is the name of the argument, 'args' is either empty or holds the default value of the input, and 'kwargs' is unused.

2. get_attr: Fetches a parameter, such as the weights of a layer, from the module hierarchy. 'name' labels the result, 'target' identifies the parameter's location in the hierarchy, and 'args' and 'kwargs' are unused.

3. call_function: Applies a standalone function (like `torch.add`) to values. 'name' labels the result, 'target' is the function, and 'args' and 'kwargs' are the function's arguments, following Python's calling convention.

4. call_module: Calls a specific module (a layer in your neural network). 'name' labels the result, 'target' is the module's location in the hierarchy, and 'args' and 'kwargs' are the arguments excluding 'self'.

5. call_method: Similar to call_function, but for methods that belong to an object (like `tensor.view()`). 'name' labels the result, 'target' is the method's name, and 'args' and 'kwargs' hold all method arguments, including 'self'.

6. output: Corresponds to the return value of the function, i.e. the final output of your model.

The function then appends the network overview (the node count per operation type) and the layer types to the report buffer, which is what appears after the graph in the printout.
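The "Network overview" line is just a tally of node ops. The sketch below reproduces that tally on stand-in data copied from the printed graph above, so it runs without MASE or torch. The `nodes` list and `network_overview` function are hypothetical; in a real session the op of each node would be read while iterating over `mg.fx_graph.nodes`.

```python
from collections import Counter

# Stand-in for the traced jsc-tiny graph: (node name, op) pairs copied from the
# printed graph. A real run would read these from the nodes of the fx graph.
nodes = [
    ("x", "placeholder"),
    ("seq_blocks_0", "call_module"),
    ("seq_blocks_1", "call_module"),
    ("seq_blocks_2", "call_module"),
    ("seq_blocks_3", "call_module"),
    ("output", "output"),
]

# The six op kinds a torch.fx node can have
OP_KINDS = (
    "placeholder",
    "get_attr",
    "call_function",
    "call_method",
    "call_module",
    "output",
)


def network_overview(node_list):
    """Count nodes per op kind, reporting 0 for kinds that do not occur."""
    counts = Counter(op for _, op in node_list)
    return {kind: counts.get(kind, 0) for kind in OP_KINDS}


overview = network_overview(nodes)
print(overview)
```

For jsc-tiny this gives one `placeholder`, four `call_module` nodes and one `output`, matching the overview in the report.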

In [116]:

# report graph is an analysis pass that shows you the detailed information in the graph
from chop.passes.graph import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)


graph():
    %x : [num_users=1] = placeholder[target=x]
    %seq_blocks_0 : [num_users=1] = call_module[target=seq_blocks.0](args = (%x,), kwargs = {})
    %seq_blocks_1 : [num_users=1] = call_module[target=seq_blocks.1](args = (%seq_blocks_0,), kwargs = {})
    %seq_blocks_2 : [num_users=1] = call_module[target=seq_blocks.2](args = (%seq_blocks_1,), kwargs = {})
    %seq_blocks_3 : [num_users=1] = call_module[target=seq_blocks.3](args = (%seq_blocks_2,), kwargs = {})
    return seq_blocks_3
Network overview:
{'placeholder': 1, 'get_attr': 0, 'call_function': 0, 'call_method': 0, 'call_module': 4, 'output': 1}
Layer types:
[BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), LinearInteger(in_features=16, out_features=5, bias=True), ReLU(inplace=True)]


2. What are the functionalities of `profile_statistics_analysis_pass` and `report_node_meta_param_analysis_pass` respectively?

### profile_statistics_analysis_pass
This pass feeds samples through the graph, computes the requested statistics (see below) for the weights and activations of the targeted nodes, and stores the results in the nodes' metadata.

#### Args
by: Selects how nodes are targeted. With `"name"`, nodes whose names match entries in `target_weight_nodes` or `target_activation_nodes` are profiled. With `"type"`, nodes are matched on a common characteristic, their `mase_op`, which covers operation types such as linear, convolution or pooling.

target_weight_nodes: Specifies the nodes whose weights should be recorded for statistical analysis.

target_activation_nodes: Specifies the nodes whose activations should be recorded.

weight_statistics: Determines which statistics are collected for weight nodes, together with the arguments of each statistic.

activation_statistics: Determines which statistics are collected for activation nodes, together with their arguments such as dimensions, quantile and device.

input_generator: The input generator that feeds data to the model.

num_samples: The number of samples to feed to the model.

#### Statistics
Record: Keeps a record of all samples passed to it. It allows for samples to be moved to a specific device and adds a new dimension before concatenation if required.

VarianceOnline: Calculates the running variance and mean using Welford's online algorithm, which is more memory-efficient as it does not require storing all samples.

VariancePrecise: Computes the variance and mean by concatenating samples and using torch.var and torch.mean. It is more precise but uses more memory, which can be significant for large datasets.

RangeNSigma: Determines the range of samples within n standard deviations (sigma) from the mean. It assumes a normal distribution and can operate in either 'precise' or 'online' mode for variance calculation.

RangeMinMax: Calculates the range of samples based on the minimum and maximum values. It can also take the absolute value of samples before calculation.

RangeQuantile: Computes the range based on quantiles. It can take the absolute values of samples and reduce along specified dimensions.

AbsMean: Implements an online algorithm to compute the mean of the absolute values of the samples.
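
Welford's online algorithm, mentioned under VarianceOnline, can be sketched in a few lines. This is a minimal pure-Python illustration of the technique, not MASE's implementation; the class and method names are made up for the example.

```python
class OnlineVariance:
    """Running mean and variance via Welford's algorithm (one pass, O(1) memory)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, x):
        """Fold one new sample into the running statistics."""
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)  # uses the already-updated mean

    def variance(self, unbiased=True):
        """Sample variance (n - 1) by default; needs at least two samples."""
        denom = self.count - 1 if unbiased else self.count
        return self.m2 / denom


stat = OnlineVariance()
for sample in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stat.update(sample)
print(stat.mean, stat.variance(unbiased=False))
```

Only three numbers are kept between updates, which is why the online variant avoids the memory cost of concatenating every sample as VariancePrecise does.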

### report_node_meta_param_analysis_pass

This pass generates a report: it builds a table whose headers depend on the selected parameter categories. The table always includes basic information: node name, operation type (Fx Node op), Mase type and Mase op.

"which": Specifies which categories of parameters to include in the report (options: "all", "common", "hardware", "software").

"save_path": Defines a file path where the analysis report will be saved.

## MASE OPs and MASE Types

MASE is designed to be a very high-level intermediate representation (IR). This is very different from the classic [LLVM IR](https://llvm.org/docs/LangRef.html) that you might be familiar with.

The following MASE Types are available:


## A deeper dive into the quantisation transform

3. Explain why only 1 OP is changed after the `quantize_transform_pass`.

In [137]:
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    
    },
}

from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph


ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

mg, _ = quantize_transform_pass(mg, pass_args)
summarize_quantization_analysis_pass(ori_mg, mg, save_dir="quantize_summary")

INFO     Quantized graph histogram:
INFO
| Original type   | OP           |   Total |   Changed |   Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm1d     | batch_norm1d |       1 |         0 |           1 |
| Linear          | linear       |       1 |         1 |           0 |
| ReLU            | relu         |       2 |         0 |           2 |
| output          | output       |       1 |         0 |           1 |
| x               | placeholder  |       1 |         0 |           1 |


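Only one OP changes because `pass_args` specifies quantisation for a single type, `linear`; the `default` entry sets the config name to `None`, so every other node type is left as it is. The jsc-tiny model contains exactly one linear operator, so exactly one node changes. If `relu` had been chosen for quantisation instead, there would be 2 changes.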

4. Write some code to traverse both `mg` and `ori_mg`, check and comment on the nodes in these two graphs. You might find the source code for the implementation of `summarize_quantization_analysis_pass` useful.

In [162]:
from chop.ir.graph.mase_graph import MaseGraph
from chop.passes.graph.utils import get_mase_op, get_mase_type, get_node_actual_target
import torch

# Iterate over pairs of nodes from the original and modified graphs
for ori_n, n in zip(ori_mg.fx_graph.nodes, mg.fx_graph.nodes):
    # Check if the node's target module has changed type after modification
    if type(get_node_actual_target(n)) != type(get_node_actual_target(ori_n)):
        # Retrieve the original and quantized modules from the nodes
        ori_module = get_node_actual_target(ori_n)
        quant_module = get_node_actual_target(n)

        # Print the difference information
        print(f'Difference found at name: {n.name}, '
              f'MASE type: {get_mase_type(n)}, MASE operation: {get_mase_op(n)}\n'
              f'Original module: {type(ori_module)} --> '
              f'New module: {type(quant_module)}')

        # Print the weights of the original and quantized modules
        print(f'Weight of original module: {ori_module.weight}')
        print(f'Weight of quantized module: {quant_module.weight}')
        # print(f'Bias of original module: {ori_module.bias}')
        # print(f'Bias of quantized module: {quant_module.x_quantizer}')

        # Generate a random input tensor based on the input feature size of the quantized module
        test_input = torch.randn(quant_module.in_features)
        print(f'Random generated test input: {test_input}')
        # Apply the original and quantized modules to the test input and print the outputs
        print(f'Output for original module: {ori_module(test_input)}')
        print(f'Output for quantized module: {quant_module(test_input)}')


Difference found at name: seq_blocks_2, MASE type: module_related_func, MASE operation: linear
Original module: <class 'torch.nn.modules.linear.Linear'> --> New module: <class 'chop.passes.graph.transforms.quantize.quantized_modules.linear.LinearInteger'>
Weight of original module: Parameter containing:
tensor([[-0.0592,  0.0724, -0.2058, -0.1840, -0.1549,  0.0057, -0.0741,  0.1262,
         -0.0881,  0.0222, -0.1073, -0.0547, -0.2582, -0.2130, -0.1583, -0.0296],
        [ 0.8399,  0.8483,  0.3889,  0.2798,  0.7101,  0.4522, -0.1635, -0.2955,
          0.2293, -0.4835,  0.9006, -0.4432, -0.0585,  0.2956, -1.9666, -1.1026],
        [ 0.1946,  0.4658,  1.6506, -3.3438,  1.0279,  0.3980,  0.3584, -0.9193,
          0.0906, -0.2203,  0.2473, -0.1677, -0.4138,  0.5857,  0.0652, -1.9793],
        [ 0.1955,  0.2623,  1.3494, -1.7947,  0.8402, -0.8200,  0.1645, -0.3052,
          0.0743, -0.4425,  0.1486, -0.0167, -0.3560,  0.4405,  0.5642, -1.5844],
        [-0.3326, -0.2061,  1.2023, -0.2903

5. Perform the same quantisation flow on the bigger JSC network that you trained in lab1. Be aware that the `pass_args` for your custom network might be different if it uses layers other than `Linear`.

6. Write code to show and verify that the weights of these layers are indeed quantised. You might need to go through the source code of the quantisation pass and of the [Quantized Layers](../../machop/chop/passes/transforms/quantize/quantized_modules/linear.py).
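
One way to approach the verification is to test whether every value lies exactly on the fixed-point grid implied by the config. The sketch below does this for plain Python floats. It is a hedged illustration: `is_fixed_point` and `all_quantised` are hypothetical helpers, the sample values are made up, and where the quantised values live in the quantised module (stored directly or produced on the fly by a quantiser) is something to confirm in the source code.

```python
def is_fixed_point(value, width=8, frac_width=4):
    """True if value is exactly representable in signed fixed point."""
    code = value * (2 ** frac_width)  # integer code if value is on the grid
    in_range = -(2 ** (width - 1)) <= code <= 2 ** (width - 1) - 1
    return in_range and code == int(code)


def all_quantised(values, width=8, frac_width=4):
    """Check a flat list of floats against the fixed-point grid."""
    return all(is_fixed_point(v, width, frac_width) for v in values)


# Hypothetical weight values, flattened to plain floats
# (for a torch tensor, something like tensor.flatten().tolist()).
original = [0.8399, -3.3438, 0.1946]
quantised = [0.8125, -3.375, 0.1875]
print(all_quantised(original), all_quantised(quantised))
```

A check like this fails on the original floating-point weights and passes only when every value is a multiple of 1/16 within the 8-bit range.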

## The command line interface

The same flow can also be executed on the command line through the `transform` action.

```bash
# make sure you have the same printout
pwd
# it should show
# your_dir/mase-tools/machop

# enter the following command
./ch transform --config configs/examples/jsc_toy_by_type.toml --task cls --cpu=0
```

7. Load your own pre-trained JSC network, and perform the quantisation using the command line interface.

## \[Optional] Write your own pass

Many examples of existing passes are in the [source code](../..//machop/chop/passes/__init__.py). The [test files](../../machop/test/passes) for these passes also contain useful information on how the passes are used.

Implement a pass to count the number of FLOPs (floating-point operations) and BitOPs (bit-wise operations).
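
As a starting point, the per-layer arithmetic for a linear layer can be sketched on its own before wiring it into a pass. The conventions below are assumptions, since definitions vary between papers: one multiply-accumulate (MAC) counts as 2 FLOPs, the bias adds one FLOP per output element, and BitOPs are MACs multiplied by the two operand bit widths. `linear_cost` is a hypothetical helper, not part of MASE.

```python
def linear_cost(in_features, out_features, batch_size=1, bias=True,
                data_width=32, weight_width=32):
    """Estimate the cost of y = x @ W.T + b for one linear layer."""
    macs = batch_size * in_features * out_features
    flops = 2 * macs                        # one multiply + one add per MAC
    if bias:
        flops += batch_size * out_features  # one add per output element
    bitops = macs * data_width * weight_width
    return {"macs": macs, "flops": flops, "bitops": bitops}


# The Linear(16 -> 5) layer of jsc-tiny with batch size 8 and 8-bit operands
cost = linear_cost(16, 5, batch_size=8, data_width=8, weight_width=8)
print(cost)
```

A full pass would visit every node in the graph, apply a rule like this per operation type, and accumulate the totals in a dictionary returned alongside the graph.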