In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

In [2]:
%pwd

'/mnt/d/imperial/second_term/adls/rs1923/mase_real'

In [3]:
# Locate the `machop` directory (the folder that holds the `chop` package imported
# in the next cell) and make it importable.
notebook_dir = Path(".").resolve()
machop_candidates = [
    notebook_dir.parent.parent / "rs1923/mase_real/machop",  # original hard-coded checkout layout
    notebook_dir / "machop",  # notebook started from the repository root
    notebook_dir,  # cwd is already machop (e.g. after `%cd ./machop`)
]
# Use the first candidate that really contains `chop`; otherwise fall back to the
# original location so the failure message below is unchanged.
machop_path = next(
    (path for path in machop_candidates if (path / "chop").is_dir()),
    machop_candidates[0],
)
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
# Re-running this cell must not keep appending the same entry to sys.path.
if str(machop_path) not in sys.path:
    sys.path.append(str(machop_path))

In [4]:
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("debug") # shift to "debug" mode

  from .autonotebook import tqdm as notebook_tqdm


Total number of JSC_1923 parameters: 3285
Total number of JSC_Tiny parameters: 117


[32mINFO    [0m [34mSet logging level to debug[0m


In [5]:
# Dataset / model selection used by the rest of the notebook.
dataset_name = "jsc"
model_name = "jsc-tiny"  # the other model used in this lab is "jsc-rs1923"
batch_size = 8

# Build the data module, then run its prepare_data() and setup() steps.
data_module = MaseDataModule(
    model_name=model_name,
    name=dataset_name,
    num_workers=0,
    batch_size=batch_size,
)
data_module.prepare_data()
data_module.setup()

In [6]:
# set up the model

model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint=None,
)

# Checkpoints we trained for 10 epochs on colab, keyed by model name so that the
# checkpoint always follows `model_name` instead of being switched by hand.
CHECKPOINT_PATHS = {
    "jsc-tiny": "./mase_output/jsc-tiny_classification_jsc_2024-02-03/software/training_ckpts/best.ckpt",
    "jsc-rs1923": "./mase_output/jsc-rs1923_classification_jsc_2024-02-05/software/training_ckpts/best.ckpt",
}
assert model_name in CHECKPOINT_PATHS, "No trained checkpoint registered for model: {}".format(model_name)
CHECKPOINT_PATH = CHECKPOINT_PATHS[model_name]

# NOTE(review): the path is relative, so it only resolves while the working
# directory is the one that contains `mase_output`.
own_model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from ./mase_output/jsc-tiny_classification_jsc_2024-02-03/software/training_ckpts/best.ckpt[0m


In [7]:
# Wrap the checkpoint-loaded model in a MaseGraph; the metadata passes are applied to it in a later cell.
mg = MaseGraph(model=own_model)

In [8]:
# Input generator configured for the "train" dataloader of the data module.
input_generator = InputGenerator(
    model_info=model_info,
    data_module=data_module,
    which_dataloader="train",
    task="cls",
)

# Pull one batch and push it through the model once, to show how inputs are fed;
# the batch is kept as `dummy_in` because later passes take it as an argument.
batch_iterator = iter(input_generator)
dummy_in = next(batch_iterator)
_ = model(**dummy_in)

In [9]:
dummy_in['x'].shape

torch.Size([8, 1, 16])

In [10]:
# Apply the three metadata analysis passes in order; each returns the updated
# graph plus a second value that is not used here.
for analysis_pass, pass_config in (
    (init_metadata_analysis_pass, None),
    (add_common_metadata_analysis_pass, {"dummy_in": dummy_in}),
    (add_software_metadata_analysis_pass, None),
):
    mg, _ = analysis_pass(mg, pass_config)

[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%conv1,), kwargs = {inplace: False})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.nn.functional.relu](args = (%conv2,), kwargs = {inplace: False})
    %block_conv1 : [num_users=1] = call_module[target=block.conv1](args = (%relu_1,), kwargs = {})
    %relu_2 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%block_conv1,), kwargs = {inplace: False})
    %block_conv2 : [num_users=1] = call_module[target=block.conv2](args = (%relu_2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%block_conv2, %relu_1), kwargs = {})
    %relu_3 : [num_users=1] = call_function[target=torch.nn.functional.relu](ar

In [11]:
# report_graph_analysis_pass is an analysis pass that prints detailed information about the graph.
# Its return value is bound to `_` because only the printed report is wanted here.
from chop.passes import report_graph_analysis_pass
_ = report_graph_analysis_pass(mg)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%conv1,), kwargs = {inplace: False})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.nn.functional.relu](args = (%conv2,), kwargs = {inplace: False})
    %block_conv1 : [num_users=1] = call_module[target=block.conv1](args = (%relu_1,), kwargs = {})
    %relu_2 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%block_conv1,), kwargs = {inplace: False})
    %block_conv2 : [num_users=1] = call_module[target=block.conv2](args = (%relu_2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%block_conv2, %relu_1), kwargs = {})
    %relu_3 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%add,), kwargs = 

In [15]:
# At present, designed only for JSC-Tiny.
# Configuration for profile_statistics_analysis_pass (run in the next cell).

pass_args = {
    "by": "type",                                                            # select the target nodes by type (not by node name)
    "target_weight_nodes": ["linear"],                                       # collect weight statistics for linear layers
    "target_activation_nodes": ["relu"],                                     # collect activation statistics for relu layers
    "weight_statistics": {
        "variance_precise": {"device": "cpu", "dims": "all"},                # collect precise variance of the weight
    },
    "activation_statistics": {
        "range_quantile": {"device": "cpu", "dims": "all", "quantile": 0.97} # collect 97% quantile of the activation range
    },
    "input_generator": input_generator,                                      # the input generator for feeding data to the model
    "num_samples": 32,                                                       # feed 32 samples to the model
}

In [16]:
# At present, designed only for JSC-Tiny.
# NOTE: `pass_args` must still be the profiling configuration here; later cells
# rebind the same name to the quantisation configuration.

# Profile the weight / activation statistics requested in `pass_args`.
mg, _ = profile_statistics_analysis_pass(mg, pass_args)
# Print the per-node "software" metadata.
mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("software",)})

Profiling weight statistics: 100%|██████████| 6/6 [00:00<00:00, 4629.47it/s]
Profiling act statistics: 100%|██████████| 4/4 [00:00<00:00, 143.95it/s]
[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+--------------+--------------+---------------------+--------------+------------------------------------------------------------------------------------------+
| Node name    | Fx Node op   | Mase type           | Mase op      | Software Param                                                                           |
| x            | placeholder  | placeholder         | placeholder  | {'results': {'data_out_0': {'stat': {}}}}                                                |
+--------------+--------------+---------------------+--------------+------------------------------------------------------------------------------------------+
| seq_blocks_0 | call_module  | module              | batch_norm1d | {'args': {'bias': {'stat': {}},    

In [17]:
# Same report as above, but for the "common" metadata instead of "software";
# the common table carries more information per node.

mg, _ = report_node_meta_param_analysis_pass(mg, {"which": ("common",)})

[32mINFO    [0m [34mInspecting graph [add_common_meta_param_analysis_pass][0m
[32mINFO    [0m [34m
+--------------+--------------+---------------------+--------------+-----------------------------------------------------------------------------------------------------------------------+
| Node name    | Fx Node op   | Mase type           | Mase op      | Common Param                                                                                                          |
| x            | placeholder  | placeholder         | placeholder  | {'args': {},                                                                                                          |
|              |              |                     |              |  'mase_op': 'placeholder',                                                                                            |
|              |              |                     |              |  'mase_type': 'placeholder',                                         

In [18]:
# transformation pass for JSC-Tiny

# "integer" quantiser settings for linear layers: width 8 with frac width 4 for
# the input data, the weight and the bias.
linear_quant_config = {
    "name": "integer",
    # data
    "data_in_width": 8,
    "data_in_frac_width": 4,
    # weight
    "weight_width": 8,
    "weight_frac_width": 4,
    # bias
    "bias_width": 8,
    "bias_frac_width": 4,
}

pass_args = {
    "by": "type",
    # every type without its own entry gets a config whose name is None
    "default": {"config": {"name": None}},
    "linear": {"config": linear_quant_config},
}

In [13]:
# transformation pass for JSC-rs1923

# Both op types receive the same "integer" settings; the comprehension evaluates
# the literal once per op, so each op still owns its own config dict.
quantized_ops = ("conv1d", "linear")

pass_args = {
    "by": "type",
    # every type without its own entry gets a config whose name is None
    "default": {"config": {"name": None}},
    **{
        op_name: {
            "config": {
                "name": "integer",
                # data
                "data_in_width": 8,
                "data_in_frac_width": 4,
                # weight
                "weight_width": 8,
                "weight_frac_width": 4,
                # bias
                "bias_width": 8,
                "bias_frac_width": 4,
            }
        }
        for op_name in quantized_ops
    },
}

In [14]:
from chop.passes.graph.transforms import (
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
)
from chop.ir.graph.mase_graph import MaseGraph


def _build_annotated_graph(model, dummy_in):
    """Wrap `model` in a new MaseGraph and run the three metadata passes on it.

    Args:
        model: the model handed to `MaseGraph(model=...)`.
        dummy_in: sample input forwarded to `add_common_metadata_analysis_pass`.

    Returns:
        The MaseGraph after the init, common and software metadata passes.
    """
    graph = MaseGraph(model=model)
    graph, _ = init_metadata_analysis_pass(graph, None)
    graph, _ = add_common_metadata_analysis_pass(graph, {"dummy_in": dummy_in})
    graph, _ = add_software_metadata_analysis_pass(graph, None)
    return graph


# Reference graph that is left untouched, used as the "before" side of the summary.
ori_mg = _build_annotated_graph(own_model, dummy_in)

# Second, independently built graph that is then quantised.
# NOTE: `pass_args` must be the quantisation config matching the loaded model
# (JSC-Tiny or JSC-rs1923); the name is rebound by several cells.
mg = _build_annotated_graph(own_model, dummy_in)
mg, _ = quantize_transform_pass(mg, pass_args)

summarize_quantization_analysis_pass(ori_mg, mg, save_dir="quantize_summary")

[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%conv1,), kwargs = {inplace: False})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.nn.functional.relu](args = (%conv2,), kwargs = {inplace: False})
    %block_conv1 : [num_users=1] = call_module[target=block.conv1](args = (%relu_1,), kwargs = {})
    %relu_2 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%block_conv1,), kwargs = {inplace: False})
    %block_conv2 : [num_users=1] = call_module[target=block.conv2](args = (%relu_2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%block_conv2, %relu_1), kwargs = {})
    %relu_3 : [num_users=1] = call_function[target=torch.nn.functional.relu](ar

In [15]:
# Exercises 1, 2 and 3 require no code; their analysis is in the report.

In [21]:
# Exercise 4: traverse both mg and ori_mg

from chop.passes.graph.utils import get_mase_op, get_mase_type, get_node_actual_target

def get_type_str(node):
    """Return a printable type name for an fx graph node.

    Args:
        node: an object with `op` and `target` attributes (an fx graph node).

    Returns:
        str: for `call_module` nodes, the class name of the module that
        `get_node_actual_target` resolves; for every other node, `node.target`
        itself when it is already a string, otherwise the target's `__name__`
        (falling back to `str(target)`), so the result is always a string even
        when the target is a callable.
    """
    if node.op == "call_module":
        return type(get_node_actual_target(node)).__name__
    target = node.target
    if isinstance(target, str):
        return target
    # e.g. call_function nodes store the callable itself as their target
    return getattr(target, "__name__", str(target))

In [22]:
import logging
import os

import numpy as np
import pandas as pd
from tabulate import tabulate

logger = logging.getLogger(__name__)

# Column titles shared by the logged table and the DataFrame.
headers = [
    "Ori name", "New name",
    "Node_OP", "MASE_TYPE", "Mase_OP",
    "Original type", "Quantized type",
    "Changed",
]

# One row per (original node, quantised node) pair, walking both graphs in lockstep.
rows = [
    [
        ori_n.name,
        n.name,
        n.op,
        get_mase_type(n),
        get_mase_op(n),
        get_type_str(ori_n),
        get_type_str(n),
        # "Changed": the resolved target of the node has a different type
        type(get_node_actual_target(n)) != type(get_node_actual_target(ori_n)),
    ]
    for ori_n, n in zip(ori_mg.fx_graph.nodes, mg.fx_graph.nodes)
]

comparison_table = tabulate(rows, headers=headers, tablefmt="orgtbl")
logger.debug("Compare nodes:")
logger.debug("\n" + comparison_table)

df = pd.DataFrame(rows, columns=headers)


In [23]:
df

Unnamed: 0,Ori name,New name,Node_OP,MASE_TYPE,Mase_OP,Original type,Quantized type,Changed
0,x,x,placeholder,placeholder,placeholder,x,x,False
1,seq_blocks_0,seq_blocks_0,call_module,module,batch_norm1d,BatchNorm1d,BatchNorm1d,False
2,seq_blocks_1,seq_blocks_1,call_module,module_related_func,relu,ReLU,ReLU,False
3,seq_blocks_2,seq_blocks_2,call_module,module_related_func,linear,Linear,LinearInteger,True
4,seq_blocks_3,seq_blocks_3,call_module,module_related_func,relu,ReLU,ReLU,False
5,output,output,output,output,output,output,output,False


In [37]:
# df.to_csv("/mnt/d/imperial/thanks.csv")

In [23]:
# 6: Write code to show and verify that the weights of these layers are indeed quantised.

# Import through `chop`, like the rest of the notebook: the machop directory is
# what was put on sys.path, whereas `machop.chop...` only resolves from one
# particular working directory and would load a second copy of the same module.
from chop.passes.graph.utils import get_node_actual_target
from chop.passes.graph.utils import get_mase_op
from chop.passes.graph.utils import get_mase_type
import torch

for ori_n, n in zip(ori_mg.fx_graph.nodes, mg.fx_graph.nodes):
    if isinstance(get_node_actual_target(ori_n), torch.nn.modules.Linear): # Linear
        # original weight, followed by the same layer's weight passed through its weight quantizer
        print(ori_n.meta["mase"].module.weight)
        print(n.meta['mase'].module.w_quantizer(n.meta['mase'].module.weight).detach())

# The printed values show that the weights have been quantized.

Parameter containing:
tensor([[ 0.0752, -0.0539,  0.0252,  ...,  0.0037, -0.1073,  0.0500],
        [-0.0059,  0.0253,  0.0007,  ...,  0.0487,  0.0635,  0.0750],
        [ 0.0975,  0.0499, -0.0471,  ...,  0.0235,  0.0838,  0.1013],
        [ 0.0061,  0.0469, -0.0265,  ...,  0.0446,  0.0610, -0.0798],
        [ 0.0059, -0.0739,  0.0598,  ...,  0.0549, -0.0014, -0.0054]],
       requires_grad=True)
tensor([[ 0.0625, -0.0625,  0.0000,  ...,  0.0000, -0.1250,  0.0625],
        [-0.0000,  0.0000,  0.0000,  ...,  0.0625,  0.0625,  0.0625],
        [ 0.1250,  0.0625, -0.0625,  ...,  0.0000,  0.0625,  0.1250],
        [ 0.0000,  0.0625, -0.0000,  ...,  0.0625,  0.0625, -0.0625],
        [ 0.0000, -0.0625,  0.0625,  ...,  0.0625, -0.0000, -0.0000]])


In [24]:
# Import through `chop`, like the rest of the notebook: the machop directory is
# what was put on sys.path, whereas `machop.chop...` only resolves from one
# particular working directory and would load a second copy of the same module.
from chop.passes.graph.utils import get_node_actual_target
from chop.passes.graph.utils import get_mase_op
from chop.passes.graph.utils import get_mase_type
import torch


def _print_quantization_report(n, ori_module):
    """Print the original and quantised weight / bias of one layer.

    Args:
        n: node of the quantised graph; its resolved target must provide
            `weight`, `bias`, `w_quantizer` and `b_quantizer`.
        ori_module: the module resolved from the matching original-graph node.

    Returns:
        The module resolved from `n`, so the caller can reuse it.
    """
    new_module = get_node_actual_target(n)
    print(f"There is quantization at {n.name}, mase_op: {get_mase_op(n)}")
    print(f"original module: {type(ori_module)}, new_module: {type(new_module)}")
    print(f"original weight: {ori_module.weight}")
    print(f"quantized weight: {new_module.w_quantizer(new_module.weight)}")
    print(f"original bias: {ori_module.bias}")
    if new_module.bias is None:
        # layers created with bias=False have no bias tensor to quantise
        print("quantized bias: None")
    else:
        print(f"quantized bias: {new_module.b_quantizer(new_module.bias)}")
    return new_module


for ori_n, n in zip(ori_mg.fx_graph.nodes, mg.fx_graph.nodes):
    ori_module = get_node_actual_target(ori_n)
    # As we've seen, the convolution and linear modules have changed
    if isinstance(ori_module, torch.nn.modules.Linear): # Linear
        new_module = _print_quantization_report(n, ori_module)

        # generate a random input to a quantized layer for quantisation verification
        random_input = torch.randn(new_module.in_features)
        print(f'output for original module: {ori_module(random_input)}')
        print(f'output for quantized module: {new_module(random_input)}')
    elif isinstance(ori_module, torch.nn.modules.conv.Conv1d): # Conv1d
        _print_quantization_report(n, ori_module)
    

There is quantization at conv1, mase_op: conv1d
original module: <class 'torch.nn.modules.conv.Conv1d'>, new_module: <class 'chop.passes.graph.transforms.quantize.quantized_modules.conv1d.Conv1dInteger'>
original weight: Parameter containing:
tensor([[[ 0.0139,  0.4526, -0.4716]],

        [[-0.5095, -0.3401,  0.0407]],

        [[-0.0469,  0.5830, -0.0303]],

        [[ 0.2598, -0.2470, -0.2621]],

        [[-0.5880, -0.5082, -0.3283]],

        [[-0.0397,  0.3216,  0.4720]],

        [[-0.4444, -0.2566,  0.3409]],

        [[ 0.5061, -0.0790,  0.5585]]], requires_grad=True)
quantized weight: tensor([[[ 0.0000,  0.4375, -0.5000]],

        [[-0.5000, -0.3125,  0.0625]],

        [[-0.0625,  0.5625, -0.0000]],

        [[ 0.2500, -0.2500, -0.2500]],

        [[-0.5625, -0.5000, -0.3125]],

        [[-0.0625,  0.3125,  0.5000]],

        [[-0.4375, -0.2500,  0.3125]],

        [[ 0.5000, -0.0625,  0.5625]]], grad_fn=<IntegerQuantizeBackward>)
original bias: Parameter containing:
tensor(

In [26]:
# 7. Load your own pre-trained JSC network, and perform the quantisation using the command line interface.

%cd ./machop
!./ch transform --config configs/examples/jsc_rs1923_by_type.toml --task cls --cpu=0

# By default: quantize is by "type" and search is by "name"

/mnt/d/imperial/second_term/adls/rs1923/mase_real/machop
Total number of JSC_1923 parameters: 3285
Total number of JSC_Tiny parameters: 117
Seed set to 0
+-------------------------+--------------------------+--------------------------+-----------------+--------------------------+
| Name                    |         Default          |       Config. File       | Manual Override |        Effective         |
+-------------------------+--------------------------+--------------------------+-----------------+--------------------------+
| task                    |      [38;5;8mclassification[0m      |           [38;5;8mcls[0m            |       cls       |           cls            |
| load_name               |           [38;5;8mNone[0m           | ../mase_output/jsc-rs192 |                 | ../mase_output/jsc-rs192 |
|                         |                          | 3_classification_jsc_202 |                 | 3_classification_jsc_202 |
|                         |                  

In [14]:
# optional task

%pwd

'/mnt/d/imperial/second_term/adls/rs1923/mase_real/machop'

In [15]:
#%cd ./machop
# Optional task: run the FLOPs / bit-ops analysis pass.
from chop.passes.graph.analysis import add_flops_bitops_analysis_pass
# `mg` is rebound to a new graph built from `own_model`; the graph that was passed
# through quantize_transform_pass earlier is no longer reachable under this name.
mg = MaseGraph(model=own_model)
mg, _ = init_metadata_analysis_pass(mg, None)
# The pass is unpacked into four values here: the graph, two totals (presumably
# FLOPs and bit-operations, going by the names) and a fourth item that is discarded.
mg, total_flops, total_bitops, _ = add_flops_bitops_analysis_pass(mg)

In [20]:
total_bitops

5099520