
[Relay][Frontend][QNN] fix access param_debug_name_map to node output name in fx-quantized graph node replacement #16217

Merged

Conversation


@PineApple777 PineApple777 commented Dec 8, 2023

This PR fixes a KeyError raised when inline_input_quant_params_for_fx replaces prim::GetAttr nodes with prim::Constant: the lookup into param_debug_name_map is done with the "name" attribute of prim::GetAttr, which is not always a key of that map. The failure occurs when each sub-model is quantized separately with quantize_fx, because the zero-point and scale variable names then differ from the prim::GetAttr "name" attribute. The fix is to look up param_debug_name_map with node.output().debugName() instead: the output debug name carries a .<number> suffix that distinguishes attributes sharing the same name, and it is exactly the key stored in param_debug_name_map.
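A rough sketch of the change in inline_input_quant_params_for_fx (python/tvm/relay/frontend/qnn_torch.py) is shown below; this is an illustrative reconstruction based on the traceback further down, not the verbatim diff:

# sketch: inside inline_input_quant_params_for_fx, for each prim::GetAttr node
for node in graph.findAllNodes("prim::GetAttr", recurse=True):
    out_name = node.output().debugName()
    if "_scale" in out_name or "_zero_point" in out_name:
        # before: key built by walking the attribute chain, e.g.
        #   full_attr = param_debug_name_map[get_full_attr_name(node)]
        # after: key taken from the node output's debug name, which matches the map keys
        full_attr = param_debug_name_map[out_name]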

  • simple example
import torch
import tvm.relay as relay
from torch.ao.quantization import quantize_fx, get_default_qconfig_mapping, get_default_qconfig

class SimpleExample(torch.nn.Module):
    def __init__(self, in_feature, out_feature):
        super(SimpleExample, self).__init__()
        self.simple_dense_1 = torch.nn.Linear(in_feature, out_feature)
        self.simple_dense_2 = torch.nn.Linear(out_feature, out_feature)
        
    def forward(self, x):
        x = self.simple_dense_1(x)
        x = self.simple_dense_2(x)
        return x
    
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(SimpleExample(128, 128))
        self.layer.append(SimpleExample(128, 128))
        self.layer.append(SimpleExample(128, 128))

    def forward(self, x):
        x = self.layer[0](x)
        x = self.layer[1](x)
        x = self.layer[2](x)
        return x
    
model = SimpleModel()
random_sample = torch.randn([128, 128])

default_qconfig_mapping = get_default_qconfig_mapping()
qconfig = get_default_qconfig()
default_qconfig_mapping.set_global(qconfig)

# quantize per layer
for child_names, child_mods in model.named_children():
    for child_name, child_mod in child_mods.named_children():
        mod_int8 = quantize_fx.prepare_fx(child_mod, qconfig_mapping=default_qconfig_mapping, example_inputs=random_sample)
        qmod = quantize_fx.convert_fx(mod_int8, qconfig_mapping=default_qconfig_mapping)
        model.layer[int(child_name)] = qmod

# run random data through the converted model (in a full flow, calibration runs between prepare_fx and convert_fx)
for i in range(100):
    calib_sample = torch.randn([128, 128])
    model(calib_sample)

scripted_model = torch.jit.trace(model, example_inputs=random_sample)

input_infos = [
    ("x", ([128, 128], "float32"))
]
tvm_model = relay.frontend.from_pytorch(scripted_model, input_infos=input_infos, keep_quantized_weight=True)

When quantization is performed for each sub-model, different zero-point and scale values are calibrated for each layer.

  • simple example error log
Traceback (most recent call last):
  File "example.py", line 67, in <module>
    tvm_model = relay.frontend.from_pytorch(scripted_model, input_infos=input_infos, keep_quantized_weight=True)
  File "/home/sunwook/tvm/python/tvm/relay/frontend/pytorch.py", line 5385, in from_pytorch
    qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map)
  File "/home/sunwook/tvm/python/tvm/relay/frontend/qnn_torch.py", line 555, in inline_input_quant_params_for_fx
    full_attr = param_debug_name_map[get_full_attr_name(node)]
KeyError: 'layer.0.simple_dense_1_input_zero_point_0'

simple_dense_1_input_zero_point_0.7 can be found in param_debug_name_map, but the original code recursively builds the full attribute name through the child modules, e.g. layer.0.simple_dense_1_input_zero_point_0, which is not a key in the map.

  • converted torchscript
...
// simple_dense_1_input_zero_point_0.7 != layer.0.simple_dense_1_input_zero_point_0
%simple_dense_1_input_zero_point_0.7 : Tensor = prim::GetAttr[name="simple_dense_1_input_zero_point_0"](%_0)
%simple_dense_1_input_scale_0.7 : Tensor = prim::GetAttr[name="simple_dense_1_input_scale_0"](%_0)
%18 : QUInt8(128, 128, strides=[128, 1], requires_grad=0, device=cpu) = aten::quantize_per_tensor(%x.1, %simple_dense_1_input_scale_0.7, %simple_dense_1_input_zero_point_0.7, %13)
...
// simple_dense_1_input_zero_point_0.9 != layer.1.simple_dense_1_input_zero_point_0
%simple_dense_1_input_zero_point_0.9 : Tensor = prim::GetAttr[name="simple_dense_1_input_zero_point_0"](%_1)
%simple_dense_1_input_scale_0.9 : Tensor = prim::GetAttr[name="simple_dense_1_input_scale_0"](%_1)
%33 : QUInt8(128, 128, strides=[128, 1], requires_grad=0, device=cpu) = aten::quantize_per_tensor(%x.3, %simple_dense_1_input_scale_0.9, %simple_dense_1_input_zero_point_0.9, %28)
...
  • keys in param_debug_name_map
dict_keys(['simple_dense_1_input_zero_point_0', 'simple_dense_1_input_scale_0', 'simple_dense_1_input_zero_point_0.9', 'simple_dense_1_input_scale_0.9', 'simple_dense_1_input_zero_point_0.7', 'simple_dense_1_input_scale_0.7'])
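For reference, the mismatch can be inspected directly on the traced module with a short snippet like the one below (illustrative only, not part of the PR; it reuses scripted_model from the example above and assumes the traced module exposes the inlined_graph property):

# compare each GetAttr node's attribute name with its output debug name
for node in scripted_model.inlined_graph.findAllNodes("prim::GetAttr", recurse=True):
    out_name = node.output().debugName()
    if "_zero_point" in out_name or "_scale" in out_name:
        # node.s("name") -> e.g. "simple_dense_1_input_zero_point_0"
        # out_name       -> e.g. "simple_dense_1_input_zero_point_0.7" (a key of param_debug_name_map)
        print(node.s("name"), "->", out_name)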

@PineApple777 PineApple777 changed the title from "[Relay][Frontend][QNN] fix getting full node attribute name in fx-based quantized graphs" to "[Relay][Frontend][QNN] Fix getting full node attribute name in fx-based quantized graphs" Dec 8, 2023
@PineApple777 PineApple777 changed the title from "[Relay][Frontend][QNN] Fix getting full node attribute name in fx-based quantized graphs" to "[Relay][Frontend][QNN] fix getting full node attribute name in quantized fx-graph" Dec 23, 2023
@PineApple777 PineApple777 changed the title from "[Relay][Frontend][QNN] fix getting full node attribute name in quantized fx-graph" to "[Relay][Frontend][QNN] fix access param_debug_name_map to node output name in fx-quantized graph node replacement" Dec 24, 2023
@PineApple777 PineApple777 marked this pull request as ready for review December 24, 2023 12:02
@PineApple777
Contributor Author

This PR is a fairly simple bugfix for a mismatch when accessing parameter keys. Could you please review this PR @masahi

@masahi masahi merged commit 506eff2 into apache:main Dec 27, 2023
25 checks passed