In [31]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# # figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch_tensorrt

from torch.utils.tensorboard import SummaryWriter

import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from tqdm import tqdm

print(pytorch_quantization.__version__)

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.tools import get_cf_args, get_dummy_input, load_config
from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
    quantize_tensorrt_transform_pass,
    test_quantize_tensorrt_transform_pass,
    fake_quantize_transform_pass,
    graph_calibration_pass,
    evaluate_fake_quantize_pass,
    fake_quantize_to_trt_pass
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model, get_tokenizer

set_logging_verbosity("info")


[32mINFO    [0m [34mSet logging level to info[0m
I0318 07:16:54.181084 140265329104704 logger.py:44] Set logging level to info


2.1.3


In [32]:
batch_size = 8
model_name = "vgg7"
dataset_name = "cifar10"

# batch_size = 1
# model_name = "facebook/opt-125m:patched"
# dataset_name = "cola"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [33]:
# 📝️ change this CHECKPOINT_PATH to the one you trained in Lab1
CHECKPOINT_PATH = "/home/qizhu/Desktop/Work/mase/mase_output/test-accu-0.9332.ckpt"
# CHECKPOINT_PATH = "/home/qizhu/Desktop/Work/mase/mase_output/opt125.ckpt"

model_info = get_model_info(model_name)
# quant_modules.initialize()
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)

model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)

[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from /home/qizhu/Desktop/Work/mase/mase_output/test-accu-0.9332.ckpt[0m
I0318 07:17:02.502797 140265329104704 checkpoint_load.py:85] Loaded pytorch lightning checkpoint from /home/qizhu/Desktop/Work/mase/mase_output/test-accu-0.9332.ckpt


In [34]:
mg = MaseGraph(model=model)
ori_mg = MaseGraph(model=model)

# get the input generator
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# a demonstration of how to feed an input value to the model
dummy_in = next(iter(input_generator))
# _ = model(**dummy_in)


mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -0.4397, -0.7993, -0.5938],
          [-2.1179, -2.1179, -2.1179,  ..., -0.6109, -0.7993, -0.7308],
          [-2.1179, -2.1179, -2.1179,  ..., -0.8164, -1.1075, -1.1075],
          ...,
          [-2.1179, -2.1179, -2.1179,  ..., -0.1486, -0.4226, -0.4568],
          [-2.1179, -2.1179, -2.1179,  ..., -0.2856, -0.5596, -0.6452],
          [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],

         [[-2.0357, -2.0357, -2.0357,  ..., -0.3901, -0.8102, -0.6176],
          [-2.0357, -2.0357, -2.0357,  ..., -0.5826, -0.7927, -0.7227],
          [-2.0357, -2.0357, -2.0357,  ..., -0.7402, -1.0203, -1.0378],
          ...,
          [-2.0357, -2.0357, -2.0357,  ...,  0.0826, -0.2150, -0.2325],
          [-2.0357, -2.0357, -2.0357,  ..., -0.0224, -0.3025, -0.3725],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],

         [[-1.8044, -1.8044, -1.8044,  ..., -1.1421, -1.3687, -1.0201],
          [-1.8044, -1.8044, -

In [5]:
pass_args = {
    "by": "name",
    "default": {"config": {"name": None}},
}

mg = fake_quantize_transform_pass(mg, pass_args)
pass_args_eval = {
    "data_module": data_module,
}

mg = evaluate_fake_quantize_pass(mg, pass_args_eval)

Average execute time for one batch: 1.96ms
Total accuracy: 90.20%


In [35]:
pass_args = {
    "by": "name",
    "default": {"config": {"name": None}},
    "feature_layers_0": {
        "config": {
            "name": "int",
            "input": {
                "precesion": 8,
                "calibrator": "max",
                "quantize_axis": None,
            },
            "weight": {
                "calibrator": "histogram",
                "quantize_axis": None,
            },
        }
    },    
    "classifier_0": {
        "config": {
            "name": "fp16",
            "input": {
                "precesion": 8,
                "calibrator": "histogram",
                "quantize_axis": None,
            },
            "weight": {
                "calibrator": "max",
                "quantize_axis": None,
            },
        }
    },
    "classifier_1": {
        "config": {
            "name": "fp16",
            "input": {
                "precesion": 8,
                "calibrator": "histogram",
                "quantize_axis": None,
            },
            "weight": {
                "calibrator": "max",
                "quantize_axis": None,
            },
        }
    },
    "classifier_2": {
        "config": {
            "name": "fp16",
            "input": {
                "precesion": 8,
                "calibrator": "histogram",
                "quantize_axis": None,
            },
            "weight": {
                "calibrator": "max",
                "quantize_axis": None,
            },
        }
    },
    "classifier_3": {
        "config": {
            "name": "fp16",
            "input": {
                "precesion": 8,
                "calibrator": "histogram",
                "quantize_axis": None,
            },
            "weight": {
                "calibrator": "max",
                "quantize_axis": None,
            },
        }
    },
}

mg = fake_quantize_transform_pass(mg, pass_args)


In [36]:
pass_args_calibrate = {
    "calibrator": "percentile",
    "percentiles": [99],
    "data_module": data_module,
    "num_batches": 100,
}

In [12]:
for node in mg.fx_graph.nodes:
    print(node.type)

<class 'torch.Tensor'>
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
<class 'torch.Tensor'>


In [37]:
graph_calibration_pass(mg,  pass_args_calibrate)
pass_args_eval = {
    "data_module": data_module,
}

mg = evaluate_fake_quantize_pass(mg, pass_args_eval)

RuntimeError: mat1 and mat2 must have the same dtype

In [10]:
mg.model

GraphModule(
  (feature_layers): Module(
    (0): QuantConv2d(
      3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=4.9597 calibrator=MaxCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake per-tensor amax=0.2797 calibrator=HistogramCalibrator scale=1.0 quant)
    )
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_

In [11]:
pass_args = {
    "onnxFile": "onnx_test.onnx",
    "engineFile": "engine_test.plan",
    "dataloader": data_module.test_dataloader,
}
mg = fake_quantize_to_trt_pass(mg, pass_args)

  if min_amax < 0:
  max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)
  if min_amax <= epsilon:  # Treat amax smaller than minimum representable of fp16 0
  if min_amax <= epsilon:


verbose: False, log level: Level.ERROR

Succeeded finding ONNX file!
Succeeded parsing .onnx file!
Succeeded building engine!
[ 0]Input -> DataType.FLOAT (-1, 3, 32, 32) (8, 3, 32, 32) input
[ 1]Output-> DataType.FLOAT (-1, 10) (8, 10) output
Succeeded running model in TensorRT!
Average execute time for one batch: 0.09ms
Total accuracy: 86.37%


In [10]:
pass_args_eval = {
    "data_module": data_module,
}

mg = evaluate_fake_quantize_pass(mg, pass_args_eval)

Average execute time for one batch: 1.13ms
Total accuracy: 90.11%


In [11]:
for node in mg.fx_graph.nodes:
    print(node.name)

x
feature_layers_0
feature_layers_1
feature_layers_2
feature_layers_3
feature_layers_4
feature_layers_5
feature_layers_6
feature_layers_7
feature_layers_8
feature_layers_9
feature_layers_10
feature_layers_11
feature_layers_12
feature_layers_13
feature_layers_14
feature_layers_15
feature_layers_16
feature_layers_17
feature_layers_18
feature_layers_19
feature_layers_20
view
classifier_0
classifier_1
classifier_2
classifier_3
last_layer
output


# OPT-125M


In [2]:
wikitext_info = get_dataset_info("wikitext2")
opt = get_model(
    "facebook/opt-125m:patched",
    task="lm",
    dataset_info=wikitext_info,
    pretrained=True,
)
opt_tokenizer = get_tokenizer("facebook/opt-125m:patched")

print(f"prepare data")
# Get data module for dummy inputs
data_module = MaseDataModule(
    name="wikitext2",
    batch_size=2,
    num_workers=os.cpu_count(),
    max_token_len=128,
    tokenizer=opt_tokenizer,
    load_from_cache_file=True,
    model_name="facebook/opt-125m@patched",
)
data_module.prepare_data()
data_module.setup()

  return self.fget.__get__(instance, owner)()


prepare data


In [26]:
model_info = get_model_info("facebook/opt-125m:patched")
cf_args = get_cf_args(model_info=model_info, task="lm", model=opt)

mg = MaseGraph(model=opt, cf_args=cf_args)

# dummy_in = get_dummy_input(model_info, data_module=data_module, task="lm")

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="lm",
    which_dataloader="train",
)

# a demonstration of how to feed an input value to the model
dummy_in = next(iter(input_generator))
if len(mg.model.additional_inputs) > 0:
    dummy_in = dummy_in | mg.model.additional_inputs

# Generate graph and initialize metadata
# print(f"init metadata")
mg, _ = init_metadata_analysis_pass(mg, pass_args=None)
# mg, _ = add_common_metadata_analysis_pass(mg, pass_args={"dummy_in": dummy_in})
# mg, _ = add_software_metadata_analysis_pass(mg, None)

# for node in mg.fx_graph.nodes:
#     print(node.meta['mase'].module)


In [30]:
# print(dummy_in)
mg.model.graph.print_tabular()
mg.model.recompile()

opcode         name                                           target                                                                      args                                                                                                  kwargs
-------------  ---------------------------------------------  --------------------------------------------------------------------------  ----------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
placeholder    input_ids                                      input_ids                                                                   (None,)                                                                                               {}
placeholder    attention_mask                                 at

PythonCode(src='\ntorch.fx._symbolic_trace.wrap("chop_models_patched_opt_patched_utils_opt_patched_opt_patched_shape_assertion_4")\ntorch.fx._symbolic_trace.wrap("chop_models_patched_opt_patched_utils_opt_patched_opt_patched_fn_prepare_decoder_attention_mask")\ntorch.fx._symbolic_trace.wrap("chop_models_patched_opt_patched_utils_opt_patched_opt_patched_fn_calculate_causal_lm_loss")\n\ndef forward(self, input_ids : torch.LongTensor = None, attention_mask : typing_Union[torch.Tensor,NoneType] = None, head_mask_1 = None, past_key_values_1 = None, inputs_embeds_1 = None, labels : typing_Union[torch.LongTensor,NoneType] = None, use_cache_1 = None, output_attentions_1 = None, output_hidden_states_1 = None, return_dict_1 = None) -> typing_Union[typing_Tuple,transformers_modeling_outputs_CausalLMOutputWithPast]:\n    _assert_is_none = torch.fx._symbolic_trace._assert_is_none(head_mask_1, \'head_mask has been specialized to have value None but got another value\');  head_mask_1 = None\n    _ass

In [29]:
mg.model(**dummy_in)

{'loss': 0,
 'logits': tensor([[[ -4.1079,  -4.1031,   5.2360,  ...,  -4.1251,  -4.1582,  -4.2522],
          [ -7.0824,  -7.0917,   2.2125,  ...,  -6.9990,  -6.9708,  -6.8121],
          [ -6.8083,  -6.8242,   2.1220,  ...,  -6.8255,  -6.6913,  -7.0274],
          ...,
          [ -8.3876,  -8.3888,   0.4175,  ...,  -8.3798,  -8.4188,  -8.5476],
          [ -7.1151,  -7.1061,  -1.0842,  ...,  -7.1819,  -7.1834,  -7.0982],
          [ -8.3191,  -8.3353,   0.6130,  ...,  -8.3036,  -8.3008,  -8.3952]],
 
         [[ -6.9732,  -6.9660,   1.8757,  ...,  -6.9888,  -6.9423,  -7.0070],
          [ -9.4373,  -9.4374,   2.1275,  ...,  -9.5177,  -9.3029,  -9.7325],
          [ -9.5896,  -9.5981,  -1.2567,  ...,  -9.6837,  -9.4658,  -9.6430],
          ...,
          [ -9.0750,  -9.0574,   2.4102,  ...,  -9.0348,  -9.0470,  -9.0953],
          [ -9.0326,  -9.0244,   5.6410,  ...,  -9.0118,  -9.0628,  -9.1411],
          [-10.8525, -10.8510,   7.0259,  ..., -10.7477, -10.8022, -10.7488]]],
       

In [7]:
for node in mg.fx_graph.nodes:
    args, kwargs = None, None
    if node.op == "placeholder":
        result = dummy_in[node.name]
        print(result)

tensor(..., device='meta', size=(1, 128), dtype=torch.int64)
tensor(..., device='meta', size=(1, 128), dtype=torch.int64)
None
None
None
tensor(..., device='meta', size=(1, 128), dtype=torch.int64)
False
False
False
True


In [5]:
for node in mg.fx_graph.nodes:
    # if node.op == "call_module":
    print(node.op)

placeholder
placeholder
placeholder
call_function
placeholder
call_function
placeholder
call_function
placeholder
placeholder
call_function
call_function
placeholder
call_function
call_function
placeholder
call_function
call_function
placeholder
call_function
call_function
call_function
call_function
call_method
call_module
call_function
call_function
call_function
call_function
call_function
call_module
call_function
call_module
call_module
call_function
call_function
call_function
call_function
call_function
call_method
call_method
call_module
call_module
call_module
call_module
call_function
call_function
call_function
call_method
call_module
call_module
call_function
call_function
call_function
call_function
call_function
call_method
call_method
call_module
call_module
call_module
call_module
call_function
call_function
call_function
call_method
call_module
call_module
call_function
call_function
call_function
call_function
call_function
call_method
call_method
call_module
call_mod

In [25]:
import time
for i, batch in enumerate(data_module.val_dataloader()):
    # print(i, batch.keys())
    if len(mg.model.additional_inputs) > 0:
        batch = batch | mg.model.additional_inputs
    curr_time = time.time()
    outputs = mg.model(**batch)
    print(f"Batch {i} took {time.time() - curr_time} seconds")
    # print(outputs)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [11]:
for k, v in dummy_in.items():
    print(k, v)

input_ids tensor(..., device='meta', size=(1, 128), dtype=torch.int64)
attention_mask tensor(..., device='meta', size=(1, 128), dtype=torch.int64)
labels tensor(..., device='meta', size=(1, 128), dtype=torch.int64)


In [28]:
dummy_in = {
    k: v.cuda() if isinstance(v, torch.Tensor) else v
    for k, v in dummy_in.items()
}

In [18]:
pass_args = {
    "precision": 'int8',                                                     # collect weight statistics for linear layers
    "nCalibration": 10,                                                # collect activation statistics for relu layers
    "dummy_in": dummy_in,
    "onnxFile": 'model.onnx',
    "cacheFile": 'model.INT8Cache',  
    "engineFile": 'model.plan'
}
engine = quantize_tensorrt_transform_pass(mg, pass_args)

  assert condition, message
  assert (
  if input_shape[-1] > 1:
  assert attn_weights.size() == (
  assert attention_mask.size() == (
  return torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
  assert attn_output.size() == (


verbose: False, log level: Level.ERROR



RuntimeError: output 1 (0
[ CPULongType{} ]) of traced region did not have observable data dependence with trace inputs; this probably indicates your program cannot be understood by the tracer.