# Use Existing Statistical Profiler without CLI

This tutorial describes how to use the statistical profiler to collect
statistics of parameters and activations of a model.

## Import related packages and machop

In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

In [2]:
# machop is expected five directory levels above the notebook's working directory
machop_path = Path(".").resolve().parent.parent.parent.parent.parent / "machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))
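
The chained `.parent` calls above assume the notebook sits exactly five levels below the directory that contains `machop`. As a hedged alternative, the sketch below walks up the directory tree until it finds a folder with the requested name. `find_upwards` is a hypothetical helper written for this illustration, not part of machop, and the temporary directory layout exists only to demonstrate it.

```python
import tempfile
from pathlib import Path


def find_upwards(start, dirname):
    """Return the first directory called `dirname` found in `start` or any of its parents."""
    start = Path(start).resolve()
    for root in (start, *start.parents):
        candidate = root / dirname
        if candidate.is_dir():
            return candidate
    return None


# Demonstration on a throwaway directory tree (layout is made up for the example).
with tempfile.TemporaryDirectory() as tmp:
    repo_root = Path(tmp).resolve()
    (repo_root / "machop").mkdir()
    notebook_dir = repo_root / "docs" / "tutorials" / "actions"
    notebook_dir.mkdir(parents=True)

    found = find_upwards(notebook_dir, "machop")
    found_ok = found == repo_root / "machop"
    missing = find_upwards(notebook_dir, "no_such_directory_for_this_demo")

print(found_ok, missing)
```

In the notebook, the returned path could be appended to `sys.path` in the same way as `machop_path`.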

In [3]:
from chop.models import get_resnet18
from chop.passes.graph.mase_graph import MaseGraph
from chop.tools.get_input import InputGenerator
from chop.passes import (
    add_common_metadata_analysis_pass,
    add_mase_ops_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.passes.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.transforms.interface import save_node_meta_param_transform_pass
from chop.dataset import MyDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

set_logging_verbosity("info")

## Set up dataset and model

Here we use a pretrained ResNet18 model and the CIFAR-10 dataset as an example. The model checkpoint was trained on ImageNet rather than CIFAR-10, but that does not matter here because we only use it to demonstrate the statistical profiler.

In [4]:
batch_size = 8
model_name = "resnet18"
dataset_name = "cifar10"

datamodule = MyDataModule(
    model_name=model_name,
    batch_size=batch_size,
    dataset_name=dataset_name,
    workers=os.cpu_count(),
    tokenizer=None,
    max_token_len=None,
)
datamodule.prepare_data()
datamodule.setup()

input_generator = InputGenerator(
    datamodule=datamodule,
    task="cls",
    is_nlp_model=False,
    which_dataloader="train",
)

info = get_dataset_info(dataset_name)
model = get_resnet18(info, pretrained=True)



Files already downloaded and verified


[2023-07-27 20:35:14][chop.models.vision.resnet.resnet][INFO] Pretrained weights loaded into Resnet


## Generate MaseGraph
The statistical profiler is an analysis pass that operates on a MaseGraph, so we
first need to generate the MaseGraph of ResNet18.

In [5]:
# Test if the sample from dataloader can be passed to the model
dummy_in = {"x": next(iter(datamodule.train_dataloader()))[0]}
_ = model(**dummy_in)

# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)
mg = init_metadata_analysis_pass(mg, None)
mg = add_mase_ops_analysis_pass(mg, None)
mg = add_common_metadata_analysis_pass(mg, dummy_in)
mg = add_software_metadata_analysis_pass(mg, None)

## Statistic Class

Here is a list of all the supported statistics. Refer to the `__init__` of each statistic class in `chop.passes.analysis.statistical_profiler.stat` for the arguments it takes.

In [6]:
from chop.passes.analysis.statistical_profiler.stat import STAT_NAME_TO_CLS
pp(list(STAT_NAME_TO_CLS.keys()))

['record',
 'variance_online',
 'variance_precise',
 'range_n_sigma',
 'range_min_max',
 'range_quantile']
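
The names suggest two ways of estimating variance: accumulating it batch by batch (`variance_online`) or computing it from all recorded values (`variance_precise`). The sketch below illustrates that distinction with Welford's algorithm, a standard online method. Whether chop's `variance_online` uses this exact update rule is an assumption, and `OnlineVariance` is a hypothetical class written only for illustration.

```python
import statistics


class OnlineVariance:
    """Running population variance via Welford's algorithm (illustrative only)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, batch):
        # Fold each new value into the running mean and squared-deviation sum.
        for x in batch:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    def variance(self):
        return self.m2 / self.count


batches = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]

# Online: only the running state is kept between batches.
online = OnlineVariance()
for batch in batches:
    online.update(batch)

# "Precise": keep every value, then compute the variance in one go.
all_values = [x for batch in batches for x in batch]
precise = statistics.pvariance(all_values)

print(online.variance(), precise)
```

The online form needs constant memory, while the one-shot form must store every profiled value, which is presumably why both options are offered.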


## Profile statistics

The pass `profile_statistics_analysis_pass` collects statistics of parameters and activations and saves them to each node's metadata.

### Example: the range of weights & input activations of a `Conv2d` node

Say we want to collect the tensor-wise min-max range of the weights and bias of the first `torch.nn.Conv2d` node, and the channel-wise 97% quantile range of the same node's input activations. We can do the following:

> **Note**
> - Tensor-wise min-max of weights: `Conv2d` has a weight tensor of shape `(out_channels, in_channels, kernel_size, kernel_size)`, so we need to reduce all dimensions of the weight tensor.
> - Channel-wise 97% quantile min-max of activations: `Conv2d` has an input activation tensor of shape `(batch_size, in_channels, height, width)`, so we need to reduce dimensions `0`, `2`, and `3` of the activation tensor. A 97% quantile min-max means we sort the input activation values in ascending order, then take the value at the 3% position as the min and the value at the 97% position as the max.
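
The two reductions described in the note can be sketched in plain Python on nested lists. This is a minimal illustration under stated assumptions, not the profiler's implementation: `tensorwise_min_max` and `channelwise_quantile_range` are hypothetical helpers, and picking the sorted element at a rounded index is an assumed quantile rule (the library may interpolate differently).

```python
def tensorwise_min_max(tensor):
    """Reduce every dimension of a nested-list tensor to one (min, max) pair."""
    flat, stack = [], [tensor]
    while stack:
        item = stack.pop()
        if isinstance(item, list):
            stack.extend(item)
        else:
            flat.append(item)
    return min(flat), max(flat)


def channelwise_quantile_range(acts, quantile):
    """acts has shape (N, C, H, W); reduce dims 0, 2, 3 and keep one range per channel."""
    lows, highs = [], []
    for c in range(len(acts[0])):
        values = sorted(v for sample in acts for row in sample[c] for v in row)
        last = len(values) - 1
        lows.append(values[round((1 - quantile) * last)])  # value near the 3% position
        highs.append(values[round(quantile * last)])       # value near the 97% position
    return lows, highs


# A weight of shape (1, 1, 2, 2), reduced over all dimensions.
weight = [[[[-0.5, 0.25], [0.75, -1.0]]]]
w_min, w_max = tensorwise_min_max(weight)

# Activations of shape (2, 3, 5, 10): channel c holds (c + 1) * {0, 1, ..., 99}.
n_samples, n_channels, height, width = 2, 3, 5, 10
acts = [
    [
        [
            [(c + 1) * (n * height * width + h * width + w) for w in range(width)]
            for h in range(height)
        ]
        for c in range(n_channels)
    ]
    for n in range(n_samples)
]
lows, highs = channelwise_quantile_range(acts, quantile=0.97)
print((w_min, w_max), lows, highs)
```

Note that the weight reduction yields a single pair, while the activation reduction yields one pair per channel.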

In [7]:
pass_args = {
    "by": "name", # collect statistics by node name
    "target_weight_nodes": ["conv1"], # the 1st conv2d node name is "conv1"
    "target_activation_nodes": ["conv1"],
    "weight_statistics": {
        # collect the min-max range of the weight tensor
        "range_min_max": {
            "dims": "all", # reduce all dimensions
            "abs": False, # do not take the absolute value before min max reduction
        }
    },
    "activation_statistics": {
        "range_quantile": {
            "dims": [0, 2, 3], # reduced dim = 0, 2, 3. The min-max is a 1D tensor of shape (C_in,)
            "abs": False,
            "quantile": 0.97, # take the 97% quantile
        }
    },
    "input_generator": input_generator, # the input generator for feeding data to the model
    "num_samples": 32, # feed 32 samples to the model
}

After running the profiler pass, we can use `report_node_meta_param_analysis_pass` to inspect the collected statistics.

Since the input activation of the 1st `Conv2d` node has shape `(1, 3, 32, 32)`, we can see the quantile-based range has 3 elements, corresponding to the 3 channels.
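
The relationship between the `dims` argument and the size of the resulting statistic can be checked with a small shape calculation. `reduced_shape` is a hypothetical helper for illustration only, and treating `"all"` as a reduction to a scalar is an assumption about how the profiler interprets that value.

```python
def reduced_shape(shape, dims):
    """Return the shape that remains after reducing `dims` of a tensor with `shape`."""
    if dims == "all":
        return ()  # every dimension is reduced, leaving a scalar
    reduced = {d % len(shape) for d in dims}  # also accept negative dims
    return tuple(size for axis, size in enumerate(shape) if axis not in reduced)


act_shape = (1, 3, 32, 32)  # (batch, channels, height, width)
channelwise = reduced_shape(act_shape, [0, 2, 3])
tensorwise = reduced_shape(act_shape, "all")
print(channelwise, tensorwise)
```

Only the channel dimension survives the `[0, 2, 3]` reduction, which matches the 3-element range reported for `conv1`.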

In [8]:
mg = profile_statistics_analysis_pass(mg, pass_args)
mg = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})

Profiling weight statistics: 100%|██████████| 71/71 [00:00<00:00, 11938.09it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 37.55it/s]
[2023-07-27 20:35:15][chop.passes.analysis.report.report_node][INFO] Inspecting graph [add_common_meta_param_analysis_pass]
[2023-07-27 20:35:15][chop.passes.analysis.report.report_node][INFO]
+-----------------------+---------------+---------------------+---------------------+--------------------------------------------------------------------------------------+
| Node name             | Fx Node op    | Mase type           | Mase op             | Software Param                                                                       |
| x                     | placeholder   | placeholder         | placeholder         | {'results': {'data_out_0': {'stat': {}}}}                                            |
+-----------------------+---------------+---------------------+---------------------+----------------------------

### A More Complex Example

Here is an example of collecting various statistics of a ResNet18 model by node type.

In [9]:
pass_args = {
    "by": "type",
    "target_weight_nodes": [
        "linear",
        "conv2d",
        "batch_norm2d",
        "adaptive_avg_pool2d",
        "relu",
    ],
    "target_activation_nodes": [
        "linear",
        "conv2d",
        "batch_norm2d",
        "adaptive_avg_pool2d",
        "relu",
    ],
    "weight_statistics": {
        # "record": {"device": "cuda"},
        "variance_online": {"device": "cuda", "dims": "all"},
        "variance_precise": {"device": "cuda", "dims": "all"},
        "range_n_sigma": {
            "device": "cuda",
            "dims": "all",
            "abs": False,
            "var_mode": "precise",
            "num_sigma": 3,
        },
        "range_min_max": {"device": "cuda", "dims": "all", "abs": False},
        "range_quantile": {
            "device": "cuda",
            "dims": "all",
            "abs": False,
            "quantile": 0.97,
        },
    },
    "activation_statistics": {
        # "record": {"device": "cuda"},
        "variance_online": {"device": "cuda", "dims": "all"},
        "variance_precise": {"device": "cuda", "dims": "all"},
        "range_n_sigma": {
            "device": "cuda",
            "dims": "all",
            "abs": False,
            "var_mode": "precise",
            "num_sigma": 3,
        },
        "range_min_max": {"device": "cuda", "dims": "all", "abs": False},
        "range_quantile": {
            "device": "cuda",
            "dims": "all",
            "abs": False,
            "quantile": 0.97,
        },
    },
    "input_generator": input_generator,
    "num_samples": 32,
}
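
To make the `range_n_sigma` arguments concrete, the sketch below assumes the statistic reports `mean ± num_sigma * sigma` over the reduced values, optionally after taking absolute values. That reading, the use of population variance, and the helper name `range_n_sigma` as a standalone function are all assumptions made for illustration. Check the statistic class in `chop.passes.analysis.statistical_profiler.stat` for the actual behaviour.

```python
import math


def range_n_sigma(values, num_sigma, use_abs=False):
    """Return (mean - n*sigma, mean + n*sigma) for a flat list of values."""
    if use_abs:
        values = [abs(v) for v in values]
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)  # population variance
    sigma = math.sqrt(variance)
    return mean - num_sigma * sigma, mean + num_sigma * sigma


samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
low, high = range_n_sigma(samples, num_sigma=3)
print(low, high)
```

Unlike a min-max range, this estimate depends only on the mean and variance, so a single outlier moves it less than it would move the raw minimum or maximum.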


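First, clear the metadata of the MaseGraph by re-running the metadata passes, so that the statistics from the previous example do not carry over.

Then we can collect the statistics of the model.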

In [10]:
# clear the metadata
mg = init_metadata_analysis_pass(mg, None)
mg = add_mase_ops_analysis_pass(mg, None)
mg = add_common_metadata_analysis_pass(mg, dummy_in)
mg = add_software_metadata_analysis_pass(mg, None)

After profiling, we can save the collected statistics to a TOML file with `save_node_meta_param_transform_pass`.

In [11]:
mg = profile_statistics_analysis_pass(mg, pass_args)
mg = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})
mg = save_node_meta_param_transform_pass(
    mg,
    "./node_software_meta_param_no_CLI.toml",
)

Profiling weight statistics: 100%|██████████| 71/71 [00:02<00:00, 34.78it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:01<00:00,  2.45it/s]
[2023-07-27 20:35:19][chop.passes.analysis.report.report_node][INFO] Inspecting graph [add_common_meta_param_analysis_pass]
[2023-07-27 20:35:19][chop.passes.analysis.report.report_node][INFO]
+-----------------------+---------------+---------------------+---------------------+-------------------------------------------------------------------------------------------------+
| Node name             | Fx Node op    | Mase type           | Mase op             | Software Param                                                                                  |
| x                     | placeholder   | placeholder         | placeholder         | {'results': {'data_out_0': {'stat': {}}}}                                                       |
+-----------------------+---------------+---------------------+--------------------