Cut dependencies and clean up Arm backend unit tester (pytorch#2231)
Summary:
bypass-github-pytorch-ci-checks
bypass-github-export-checks

Pull Request resolved: pytorch#2231

Reviewed By: mergennachin

Differential Revision: D54640970

Pulled By: digantdesai

fbshipit-source-id: 5bab38b60cff1ceb74d1a0b06694e240af1ba9d1
freddan80 authored and facebook-github-bot committed Mar 11, 2024
1 parent 47d2737 commit 70c5be3
Showing 3 changed files with 155 additions and 142 deletions.
4 changes: 1 addition & 3 deletions backends/arm/test/ops/test_add.py
@@ -88,7 +88,7 @@ def _test_add_tosa_BI_pipeline(
.to_executorch()
)
if TOSA_REF_MODEL_INSTALLED:
tester.run_method().compare_outputs()
tester.run_method().compare_outputs(qtol=1)
else:
logger.warning(
"TOSA ref model tool not installed, skip numerical correctness tests"
@@ -118,8 +118,6 @@ def test_add_tosa_MI(self):
test_data = (torch.randn(4, 4, 4),)
self._test_add_tosa_MI_pipeline(self.Add(), test_data)

# TODO: Will this type of parametrization be supported? pytest seem
# have issue with it.
@parameterized.expand(
[
(torch.ones(5),), # test_data
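For context on the compare_outputs(qtol=1) change above: the BI profile runs the model through int8 quantization, so the TOSA reference model output may legitimately differ from the fp32 torch reference by one quantization step. Passing qtol=1 widens the comparison tolerance by that amount. The sketch below only illustrates the idea; the helper name assert_close_with_qtol, its parameters, and the exact way the tester folds qtol and the output scale into the tolerance are assumptions, not the tester's actual implementation.

import torch

def assert_close_with_qtol(
    actual: torch.Tensor,
    reference: torch.Tensor,
    scale: float,
    qtol: int = 0,
    atol: float = 1e-3,
    rtol: float = 1e-3,
) -> None:
    # Hypothetical helper: widen the absolute tolerance by qtol quantization
    # steps so that an off-by-one rounding difference in the int8 domain does
    # not fail the numerical check.
    effective_atol = atol + qtol * scale
    assert torch.allclose(actual, reference, rtol=rtol, atol=effective_atol), (
        f"Outputs differ by more than {qtol} quantization step(s) "
        f"(scale={scale}, effective atol={effective_atol})"
    )

The output scale needed for such a check is exactly what the new _get_output_param() helper in arm_tester.py (below) recovers from the exported program.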
152 changes: 111 additions & 41 deletions backends/arm/test/tester/arm_tester.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.arm.arm_backend import (
@@ -15,6 +15,7 @@
from executorch.backends.arm.arm_partitioner import ArmPartitioner

from executorch.backends.arm.test.tosautil.tosa_test_utils import (
QuantizationParams,
TosaProfile,
TosaTestUtils,
)
@@ -32,6 +33,7 @@
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import ExportedProgram


class ArmBackendSelector(Enum):
@@ -61,6 +63,7 @@ def __init__(
TosaProfile.BI or TosaProfile.MI
"""
self.tosa_test_util = None
self.is_quantized = profile == TosaProfile.BI
if backend == ArmBackendSelector.TOSA:
self.tosa_test_util = TosaTestUtils(profile=profile)
# The spec below triggers arm_backend.py to output two files:
@@ -119,54 +122,121 @@ def run_method(
), "self.tosa_test_util is not initialized, cannot use run_method()"
inputs_to_run = inputs or self.inputs

# TODO: we can't possible need to use all these stages??
export_stage = self.stages[
self.stage_name(Export)
] # this is what XNNpack use to get quant params
toedge_stage = self.stages[
self.stage_name(ToEdge)
] # this is what get_input_quantization_params use to get quant params
partition_stage = self.stages[
self.stage_name(Partition)
] # this is what tosa_ref_dump_inputs use....

# TODO: I'd prefer to use this TOSA buffer instead of output.tosa,
# generated by arm_backend.py. The issue is that we're still depending
# on desc.json, which is created from TosaSerializer class, not from
# the serialized TOSA buffer. Leave this here for review purposes.
# ts_serialized = self._get_serialized_tosa_buffer( # unused
# partition_stage.artifact
# )

# This is where the torch reference output is calculated and set
# TODO: This sets self.quantization_scale, which is duplicates
# self.tosa_test_util.quantization.output.scales (?). Fixme.
(
self.reference_output,
self.quantization_scale,
) = self._calculate_reference_output(export_stage.artifact, inputs_to_run)

# Convert the torch inputs to something TOSA ref model can use
tensor_names_and_inputs_np = self.tosa_test_util.convert_inputs_to_tosa(
partition_stage.artifact, toedge_stage.artifact, inputs_to_run
export_stage = self.stages[self.stage_name(Export)]

(input_names, qp_input) = self._get_input_params(export_stage.artifact)
(output_name, qp_output) = self._get_output_param(export_stage.artifact)

# Calculate the reference output using the original module or the quant
# module. self.quantization_scale is used by compare_outputs() to
# calculate the tolerance
self.quantization_scale = None if qp_output is None else qp_output.scale
if self.is_quantized:
module_for_ref = self.stages[self.stage_name(Quantize)].artifact
else:
module_for_ref = self.original_module
self.reference_output = self._calculate_reference_output(
module_for_ref, inputs_to_run
)

# Run the TOSA ref model to get the output tensor, which will be
# compared to the torch output in compare_outputs()
self.stage_output = self.tosa_test_util.run_tosa_ref_model(
tensor_names_and_inputs_np
params_input=(input_names, qp_input),
param_output=(output_name, qp_output),
inputs=inputs_to_run,
)

return self

def _get_serialized_tosa_buffer(self, partition_stage: Partition) -> bytes:
def _get_input_params(
self, program: ExportedProgram
) -> Tuple[List[str], Union[List[QuantizationParams], List[None]]]:
"""
This is just a prototype...
Todo:
* The "_0" indicates that there are many lowered modules. Loop it!
* There's probably a better way to get this buffer. An API? Yes,
it seems the serialize stage does this for you...
Get the names and, optionally, the quantization parameters of the inputs
to this model.
Args:
program (ExportedProgram): The program to get input parameters from
Returns:
Tuple[List[str], Union[List[QuantizationParams], List[None]]]: A tuple
containing the input node names and their quantization parameters.
"""
input_names = []
# E.g. bias and weights are 'placeholders' as well. This is used to
# get only the user inputs.
usr_inputs = program.graph_signature.user_inputs
for node in program.graph.nodes:
if node.op == "placeholder" and node.name in usr_inputs:
input_names.append(node.name)
continue

if self.is_quantized:
quant_params = []
for node in program.graph.nodes:
if (
node.target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
and node.args[0].name in input_names
):
qp = QuantizationParams(
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
)
quant_params.append(qp)
if len(quant_params) == len(
input_names
): # break early once we have quantization parameters for all inputs
break
assert len(quant_params) != 0, "Quantization parameters not found"
return (input_names, quant_params)
else:
return (input_names, len(input_names) * [None]) # return a list of None's

def _get_output_param(
self, program: ExportedProgram
) -> Tuple[str, Union[QuantizationParams, None]]:
"""
return partition_stage._edge_programs[
"forward"
]._graph_module.lowered_module_0.processed_bytes
Get the name and, optionally, the quantization parameters of the output of
this model.
Args:
program (ExportedProgram): The program to get output parameters from.
Returns:
Tuple[str, Optional[QuantizationParams]]: A tuple containing the
output node name and its quantization parameters.
"""
output_node = None
for node in program.graph.nodes:
if node.op == "output":
output_node = node
break

if self.is_quantized:
quant_params = None
for node in program.graph.nodes:
if (
node.target
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
and node == output_node.args[0][0]
):
quant_params = QuantizationParams(
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
)
break # break early, there's only one output node
assert quant_params is not None, "Quantization parameters not found"
return (output_node.name, quant_params)
else:
return (output_node.name, None)

@staticmethod
def _calculate_reference_output(
module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
) -> torch.Tensor:
"""
Note: I'd prefer to use the base class method here, but since it uses the
exported program, I can't. The partitioner stage clears the state_dict
of the exported program, which causes an issue when evaluating the
module.
"""

return module.forward(*inputs)
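
Taken together, the new _get_input_params() and _get_output_param() helpers let run_method() recover quantization data directly from the exported program, instead of reaching into the ToEdge and Partition stages as before. A rough end-to-end usage sketch under the BI profile, modeled on test_add.py, is shown below; the constructor arguments and the chained stage calls are assumptions inferred from the snippets in this diff, not copied verbatim from the test suite.

import torch

from executorch.backends.arm.test.tester.arm_tester import (
    ArmBackendSelector,
    ArmTester,
)
from executorch.backends.arm.test.tosautil.tosa_test_utils import TosaProfile


class Add(torch.nn.Module):
    # Minimal stand-in for the Add module exercised by test_add.py.
    def forward(self, x):
        return x + x


test_data = (torch.randn(4, 4, 4),)

# Hypothetical driver code: build the quantized (BI) pipeline stage by stage.
tester = (
    ArmTester(Add(), test_data, profile=TosaProfile.BI, backend=ArmBackendSelector.TOSA)
    .quantize()
    .export()
    .to_edge()
    .partition()
    .to_executorch()
)

# run_method() executes the TOSA reference model and computes the torch
# reference output; compare_outputs(qtol=1) then allows one quantization step
# of deviation between the two.
tester.run_method().compare_outputs(qtol=1)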
