# Use the Transform functionality without CLI

This tutorial describes how to use the MASE transform functionality for a pre-trained model.

## Import related packages and machop

In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

In [2]:
# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

In [3]:
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph.interface import save_node_meta_param_interface_pass
from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")

INFO    Set logging level to info


## Set up the dataset 

Here we use the CIFAR-10 dataset (`cifar10`) as an example. The dataset is configured using the internal `MaseDataModule`.

In [6]:
batch_size = 256
model_name = "vgg7"
dataset_name = "cifar10"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()


Files already downloaded and verified


## Set up the model 

Here we use a previously trained `vgg7` model, loaded from a PyTorch Lightning checkpoint, as an example.

In [8]:
# 📝️ change this CHECKPOINT_PATH to the one you trained in Lab1
CHECKPOINT_PATH = "/home/lch121600/ADLSlab/mase/mase_output/vgg7_classification_cifar10_2024-02-21/test-accu-0.9332.ckpt"
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)

model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)

INFO    Loaded pytorch lightning checkpoint from /home/lch121600/ADLSlab/mase/mase_output/vgg7_classification_cifar10_2024-02-21/test-accu-0.9332.ckpt


## Get a dummy input
With the dataset module and model information, we can grab an input generator.

In [9]:
# get the input generator
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# a demonstration of how to feed an input value to the model
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)


## Generate a MaseGraph
There are two kinds of passes: transform passes and analysis passes. Both require the model to be converted into a `MaseGraph` so that it can be manipulated.

In [10]:
# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

print(type(mg))

<class 'chop.ir.graph.mase_graph.MaseGraph'>


## Running an Analysis pass
An analysis pass does **not** change the graph.

The following analysis passes are required to prepare the graph for the other passes.

In [11]:
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)
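Most of the passes used in this notebook share one calling convention: a pass takes a graph and a `pass_args` argument and returns a `(graph, info)` tuple, where an analysis pass hands back the graph unchanged and a transform pass hands back a modified one. The sketch below imitates that convention on a plain `torch.fx.GraphModule` instead of a `MaseGraph`; `TinyNet`, `count_ops_analysis_pass` and `relu_to_gelu_transform_pass` are hypothetical illustrations, not MASE code.

```python
import torch
import torch.fx


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.act = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


def count_ops_analysis_pass(gm: torch.fx.GraphModule, pass_args=None):
    """Hypothetical analysis pass: inspects the graph and leaves it unchanged."""
    counts = {}
    for node in gm.graph.nodes:
        counts[node.op] = counts.get(node.op, 0) + 1
    return gm, {"op_counts": counts}


def relu_to_gelu_transform_pass(gm: torch.fx.GraphModule, pass_args=None):
    """Hypothetical transform pass: swaps every nn.ReLU submodule for nn.GELU."""
    replaced = 0
    for node in gm.graph.nodes:
        if node.op == "call_module" and isinstance(gm.get_submodule(node.target), torch.nn.ReLU):
            # Register a GELU under the same name, so the existing node now calls it
            gm.add_submodule(node.target, torch.nn.GELU())
            replaced += 1
    gm.recompile()
    return gm, {"replaced": replaced}


gm = torch.fx.symbolic_trace(TinyNet())
gm, op_info = count_ops_analysis_pass(gm)
gm, xform_info = relu_to_gelu_transform_pass(gm)
print(op_info["op_counts"], xform_info["replaced"])
```

Returning the graph even from an analysis pass keeps call sites uniform, which is why the cells above can chain passes as `mg, _ = some_pass(mg, args)`.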


We will first run a simple graph analysis to understand the structure of the model.

In [12]:
# report graph is an analysis pass that shows you the detailed information in the graph
from chop.passes import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)

graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %feature_layers_0 : [num_users=1] = call_module[target=feature_layers.0](args = (%x,), kwargs = {})
    %feature_layers_1 : [num_users=1] = call_module[target=feature_layers.1](args = (%feature_layers_0,), kwargs = {})
    %feature_layers_2 : [num_users=1] = call_module[target=feature_layers.2](args = (%feature_layers_1,), kwargs = {})
    %feature_layers_3 : [num_users=1] = call_module[target=feature_layers.3](args = (%feature_layers_2,), kwargs = {})
    %feature_layers_4 : [num_users=1] = call_module[target=feature_layers.4](args = (%feature_layers_3,), kwargs = {})
    %feature_layers_5 : [num_users=1] = call_module[target=feature_layers.5](args = (%feature_layers_4,), kwargs = {})
    %feature_layers_6 : [num_users=1] = call_module[target=feature_layers.6](args = (%feature_layers_5,), kwargs = {})
    %feature_layers_7 : [num_users=1] = call_module[target=feature_layers.7](args = (%feature_layers_6,), kwargs 

## Running another Analysis pass: Profile statistics

The pass `profile_statistics_analysis_pass` collects statistics of parameters and activations and saves them to each node's metadata.

The supported statistics are implemented as classes in `chop.passes.analysis.statistical_profiler.stat`; refer to the `__init__` of each class to check the arguments it takes.

This analysis is more complex than the previous passes, so it requires additional arguments.

### Example: the range of weights & input activations of nodes

Say we want to collect the tensor-wise variance of the weights of `torch.nn.Linear` nodes, and the tensor-wise 97% quantile min-max range of the input activations of `torch.nn.ReLU` nodes. We can do the following:

In [13]:
pass_args = {
    "by": "type",                                                            # collect statistics by node type
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}
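A statistic of a fixed weight tensor can be computed in one shot, but a statistic over activations has to be accumulated batch by batch as the `num_samples` inputs flow through the model. We have not checked how MASE's profiler does this; the sketch below shows one standard approach for a variance, merging per-batch counts, means and sums of squared deviations with Chan et al.'s parallel update, so the result equals the variance over all elements at once (our reading of `"dims": "all"`). `RunningVariance` is a hypothetical helper.

```python
import torch


class RunningVariance:
    """Hypothetical helper: tensor-wise variance accumulated over batches."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, batch: torch.Tensor) -> None:
        x = batch.detach().double().flatten()  # reduce over every element
        n_b = x.numel()
        mean_b = x.mean().item()
        m2_b = ((x - mean_b) ** 2).sum().item()
        delta = mean_b - self.mean
        total = self.count + n_b
        # Chan et al. merge of two (count, mean, M2) summaries; uses the old count
        self.mean += delta * n_b / total
        self.m2 += m2_b + delta ** 2 * self.count * n_b / total
        self.count = total

    def variance(self) -> float:
        # population variance (divide by N)
        return self.m2 / self.count


torch.manual_seed(0)
batches = [torch.randn(32, 16) for _ in range(4)]
rv = RunningVariance()
for b in batches:
    rv.update(b)

all_x = torch.cat(batches).double().flatten()
reference = ((all_x - all_x.mean()) ** 2).mean().item()
print(rv.variance(), reference)
```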

We can use the `report_node_meta_param_analysis_pass` to inspect the collected statistics.

In [14]:
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})

Profiling weight statistics: 100%|██████████| 29/29 [00:00<00:00, 2238.20it/s]
Profiling act statistics: 100%|██████████| 1/1 [00:04<00:00,  4.58s/it]


RuntimeError: quantile() input tensor is too large
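The `RuntimeError` is raised by PyTorch's `torch.quantile`, which rejects input tensors above a fixed element-count limit. With `"dims": "all"` and a batch size of 256, the ReLU input activations of `vgg7` are most likely large enough to hit that limit, so a smaller `batch_size` is probably the simplest fix. Alternatively, quantiles can be computed with `torch.kthvalue`, which, as far as we know, does not have this check. The sketch below assumes that a 97% quantile range means the interval from the 3% to the 97% quantile and uses the nearest-index definition (similar to `interpolation="nearest"` in `torch.quantile`); `quantile_range` is a hypothetical helper, not MASE's implementation.

```python
import torch


def quantile_range(x: torch.Tensor, q: float = 0.97) -> tuple[float, float]:
    """Hypothetical helper: nearest-index (1 - q, q) quantile range over all elements."""
    flat = x.detach().flatten()
    n = flat.numel()
    # 0-based nearest positions in sorted order, shifted to kthvalue's 1-based k
    k_lo = round((1.0 - q) * (n - 1)) + 1
    k_hi = round(q * (n - 1)) + 1
    lo = torch.kthvalue(flat, k_lo).values.item()
    hi = torch.kthvalue(flat, k_hi).values.item()
    return lo, hi


x = torch.arange(0, 101, dtype=torch.float32)  # values 0, 1, ..., 100
print(quantile_range(x, 0.97))
```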

## Running a Transform pass: Quantisation

As its name suggests, a transform pass modifies the `MaseGraph`.
As with the previous analysis pass, we first need to declare the configuration for the pass.

In [17]:
pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "conv2d": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

We can then apply the transformation. In this case, we keep a copy of the original graph on purpose so that we can print a `diff`.

In [18]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph


ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})

mg, _ = quantize_transform_pass(mg, pass_args)
summarize_quantization_analysis_pass(ori_mg, mg, save_dir="quantize_summary")

INFO    Quantized graph histogram:
INFO
| Original type   | OP           |   Total |   Changed |   Unchanged |
|-----------------+--------------+---------+-----------+-------------|
| BatchNorm2d     | batch_norm2d |       6 |         0 |           6 |
| Conv2d          | conv2d       |       6 |         6 |           0 |
| Linear          | linear       |       3 |         3 |           0 |
| MaxPool2d       | max_pool2d   |       3 |         0 |           3 |
| ReLU            | relu         |       8 |         0 |           8 |
| output          | output       |       1 |         0 |           1 |
| view            | view         |       1 |         0 |           1 |
| x               | placeholder  |       1 |         0 |           1 |


In [39]:
##6.
from chop.passes.graph.analysis.quantization.calculate_avg_bits import calculate_avg_bits_mg_analysis_pass
from chop.ir.graph.mase_graph import MaseGraph

ori_mg = MaseGraph(model=model)
ori_mg, _ = init_metadata_analysis_pass(ori_mg, None)
ori_mg, _ = add_common_metadata_analysis_pass(ori_mg, {"dummy_in": dummy_in})


result_graph, analysis_results = calculate_avg_bits_mg_analysis_pass(ori_mg, {})
print("Average bits for w:", analysis_results["w_avg_bit"])

result_graph, analysis_results = calculate_avg_bits_mg_analysis_pass(mg, pass_args)
print("Average bits for w:", analysis_results["w_avg_bit"])

Average bits for w: 32.0
Average bits for w: 12.0


In [26]:
from chop.passes.graph.analysis.quantization.calculate_avg_bits import calculate_avg_bits_mg_analysis_pass
from chop.ir.graph.mase_graph import MaseGraph


def calculate_bits_mg_analysis_pass(graph, pass_args: dict):
    # Print shape, precision and values of every linear/conv weight in the graph
    for node in graph.fx_graph.nodes:
        mase_meta = node.meta["mase"].parameters
        mase_op = mase_meta["common"]["mase_op"]
        mase_type = mase_meta["common"]["mase_type"]

        if mase_type in ["module", "module_related_func"]:
            if mase_op in ["linear", "conv2d", "conv1d"]:
                w_meta = mase_meta["common"]["args"]["weight"]
                # Display the weight metadata
                print(f"Operation: {mase_op}, Node: {node.name}")
                print(f"  Weight Shape: {w_meta['shape']}")
                print(f"  Precision: {w_meta['precision']} bits")
                print(f"  Weight : {w_meta['value']}")
                print()  # blank line between nodes for readability


print("original graph")
calculate_bits_mg_analysis_pass(ori_mg, {})

print("modified graph")
calculate_bits_mg_analysis_pass(mg, pass_args)

original graph
Operation: linear, Node: seq_blocks_2
  Weight Shape: [5, 16]
  Precision: [32] bits
  Weight : Parameter containing:
tensor([[-1.3568e-01,  2.9449e-01, -1.0031e-01, -7.3900e-02,  7.2525e-03,
          1.5064e-01,  1.0131e-01,  3.3722e-01,  7.9390e-02,  2.4162e-01,
          1.0258e-01,  1.1960e-01, -6.6988e-02, -3.4505e-03,  3.3916e-02,
          1.9480e-01],
        [ 2.6927e-01, -1.6469e-02, -3.3183e-01, -2.1290e-01, -5.3244e-02,
          1.0379e-01,  9.6389e-04,  2.5864e-01,  1.1306e-02,  7.5775e-02,
          2.7047e-01, -1.0713e-01, -1.1219e-01,  5.2345e-02, -2.6532e-01,
          5.1085e-02],
        [-2.1783e-02,  4.3771e-02, -1.7906e-01, -4.0549e-01, -1.8858e-01,
          9.8225e-02,  1.2975e-01,  5.0967e-02,  2.7775e-02, -1.6300e-01,
          7.3984e-02, -2.4694e-01, -2.3154e-01, -1.8903e-01,  1.5337e-01,
         -1.6191e-02],
        [-1.1567e-01, -7.7866e-02,  7.1458e-02,  1.5914e-01,  3.0221e-02,
         -3.9674e-02,  1.4878e-01, -1.8178e-01,  2.9177e-0

In [44]:
#7.
from chop.passes.graph.analysis.flop_estimator.calculator.calc_modules import calculate_modules
from chop.passes.graph.analysis.quantization.calculate_avg_bits import calculate_avg_bits_mg_analysis_pass
from chop.ir.graph.mase_graph import MaseGraph


def FLOP_count(graph, pass_args: dict):
    # Accumulate forward + backward computations over all float module nodes
    flop = 0
    for node in graph.fx_graph.nodes:
        mase_meta = node.meta["mase"].parameters
        mase_type = mase_meta["common"]["mase_type"]
        if mase_type in ["module", "module_related_func"]:
            if mase_meta["common"]["args"]["data_in_0"]["type"] == "float":
                in_data = mase_meta["common"]["args"]["data_in_0"]["value"]
                out_data = mase_meta["common"]["results"]["data_out_0"]["value"]

                result = calculate_modules(node.meta["mase"].module, [in_data], [out_data])
                if result is not None:
                    flop += result["computations"] + result["backward_computations"]

                print(flop)  # running total after each float module

    return flop


f = FLOP_count(ori_mg, pass_args)
print(f)

38400
48000
508800.0
528000.0
2371200.0
2409600.0
4252800.0
4329600.0
4348800.0
4809600.0
4848000.0
4857600.0
4929600.0
4932600.0
4932600.0
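For comparison, the forward cost of convolution and linear layers can be derived directly from their shapes: each output element of a `Conv2d` needs `(in_channels / groups) * kH * kW` multiply-accumulates (MACs), and each output element of a `Linear` needs `in_features`. The sketch below counts these with forward hooks. It is an illustration, not how `calculate_modules` works: it ignores biases, activations, pooling and batch norm, and it does not add the backward computations that `FLOP_count` includes. `count_macs` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def count_macs(model: nn.Module, example_input: torch.Tensor) -> int:
    """Hypothetical helper: forward MACs of Conv2d and Linear layers only."""
    total = 0

    def conv_hook(module: nn.Conv2d, inputs, output):
        nonlocal total
        # each output element needs (in_channels / groups) * kH * kW MACs
        kh, kw = module.kernel_size
        total += output.numel() * (module.in_channels // module.groups) * kh * kw

    def linear_hook(module: nn.Linear, inputs, output):
        nonlocal total
        # each output element needs in_features MACs
        total += output.numel() * module.in_features

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
    with torch.no_grad():
        model(example_input)
    for h in handles:
        h.remove()  # leave the model without extra hooks
    return total


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 10),
)
print(count_macs(net, torch.randn(1, 3, 8, 8)))  # 13824 (conv) + 5120 (linear) = 18944
```

A common convention counts one MAC as two FLOPs (one multiply, one add), so totals reported by different tools can differ by about a factor of two depending on the unit they use.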


1.

`report_graph_analysis_pass` prints a sequential representation of the operations in the model's `forward` method. For this model, the graph starts with the input `placeholder` node `x` and then passes it through a chain of `call_module` nodes (`feature_layers.0`, `feature_layers.1`, ...), one per submodule call. Each node in the graph corresponds to an operation or method call in the PyTorch model, providing a clear and detailed overview of the computational steps involved.
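The node kinds in this printout come from `torch.fx`. The sketch below traces the small example module used in the `torch.fx` documentation (a learnable parameter `param`, a `linear` layer and a `clamp`) to show each opcode: `placeholder` for inputs, `get_attr` for parameters, `call_function` for free functions such as `+`, `call_module` for submodules, `call_method` for tensor methods, and `output`. It uses plain `torch.fx` rather than a `MaseGraph`.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


traced = torch.fx.symbolic_trace(MyModule())
for node in traced.graph.nodes:
    # node.op is the opcode; node.target is what is fetched or called
    print(f"{node.op:<14} {node.name:<8} {node.target}")

ops = [node.op for node in traced.graph.nodes]
```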


2. 

profile_statistics_analysis_pass:

This pass profiles the model graph (`mg`) by running data through it. Following `pass_args`, it selects the target nodes (here by type: `linear` nodes for weight statistics and `relu` nodes for activation statistics), feeds `num_samples` inputs from the `input_generator` through the model, and computes the requested statistics (`variance_precise` for weights, `range_quantile` for activations). The results are saved to each node's metadata so that they can be inspected or used by later passes.

report_node_meta_param_analysis_pass:

This pass prints a report of the metadata parameters stored on each node of `mg`. The `{"which": ("software",)}` argument selects which category of node metadata to report: here only the software metadata, where the profiled statistics are inspected, rather than, for example, the `common` metadata that the code cells above read. This makes it a convenient way to check what a previous pass has recorded on each node.

3.

`pass_args` specifies, per node type, which layers are quantized and how. The node types listed in it are quantized using integer (fixed-point) quantization; any other type falls back to `default`, whose `"name": None` leaves it unquantized. Data, weights and biases each use an 8-bit width, 4 bits of which form the fractional part. The values are therefore represented in fixed-point notation with 4 bits of precision after the binary point, i.e. a resolution of 2^-4 = 0.0625.
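The effect of an 8-bit width with 4 fractional bits can be checked numerically. The sketch below assumes a signed, saturating fixed-point format with round-to-nearest, which gives a step of 2^-4 = 0.0625 and a representable range of [-8, 7.9375]; `fixed_point_quantize` is a hypothetical helper, and MASE's integer quantizer may differ in details such as the rounding mode.

```python
import torch


def fixed_point_quantize(x: torch.Tensor, width: int = 8, frac_width: int = 4) -> torch.Tensor:
    """Hypothetical helper: signed fixed-point rounding (round-to-nearest, saturating)."""
    scale = 2 ** frac_width          # 2^4 = 16 steps per unit
    q_min = -(2 ** (width - 1))      # smallest integer code: -128
    q_max = 2 ** (width - 1) - 1     # largest integer code: 127
    codes = torch.clamp(torch.round(x * scale), q_min, q_max)
    return codes / scale             # map integer codes back to real values


x = torch.tensor([0.1, -0.33, 3.14159, 20.0, -20.0])
print(fixed_point_quantize(x))
```

With this input, 0.1 becomes 0.125, -0.33 becomes -0.3125, 3.14159 becomes 3.125, and the out-of-range values 20.0 and -20.0 saturate to 7.9375 and -8.0.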