# PyTorch 前端

In [1]:
import os

import numpy as np
import torch
from torch import nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from tqdm import tqdm
# Fix PyTorch's global RNG seed so random weights/inputs are repeatable across runs.
torch.manual_seed(0)


class Demo(nn.Module):
    """Tiny network: one grouped 3x3 convolution followed by ReLU.

    The convolution maps 16 input channels to 64 output channels using
    16 groups, stride 1, padding 1 and no bias term.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=16,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=16,
            bias=False,
        )
        # A per-channel PReLU (nn.PReLU(64)) was tried here earlier and is disabled.
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Run the convolution and then ReLU on ``x`` and return the result."""
        return self.relu(self.conv(x))

In [2]:
# Build the float model and one random example input of the given shape.
model = Demo()
shape = 1, 16, 32, 32
example_inputs = [torch.rand(*shape),]
# script_module = torch.jit.trace(model.eval(), example_inputs)
# Symbolically trace the model with torch.fx, then re-wrap the traced module
# and its graph in a fresh GraphModule.
model_qat = torch.fx.symbolic_trace(model)
model_qat = torch.fx.GraphModule(model_qat, model_qat.graph)
# QAT flow with the "qnnpack" default qconfig mapping: prepare -> convert,
# switching each intermediate result to eval mode.
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_prepared = prepare_qat_fx(model_qat, qconfig_mapping, example_inputs).eval()
model_converted = convert_fx(model_prepared).eval()
# Trace the converted model into TorchScript; this is what the cells below consume.
script_module = torch.jit.trace(model_converted.eval(), example_inputs).eval()
# (name, shape) description of the single input, and the dtype name used by later cells.
input_infos = [("data", shape),]
default_dtype = "float32"



In [3]:
from tvm.relay.frontend.pytorch import from_pytorch

# Convert the traced, quantized demo module with TVM's PyTorch frontend.
# The input is named "data" and uses the same `shape` as the example input.
input_infos = [("data", shape),]
# All optional arguments are spelled out explicitly so they are easy to tweak.
mod, params = from_pytorch(
    script_module, input_infos,
    custom_convert_map=None,
    default_dtype='float32',
    use_parser_friendly_name=False,
    keep_quantized_weight=False
)

In [4]:
# Take a copy of the TorchScript graph and collect its input values for inspection.
graph = script_module.graph.copy()
graph_inputs = list(graph.inputs())

In [5]:
# Display the graph inputs (the module itself plus the tensor argument, per the recorded output).
graph_inputs

[self.1 defined in (%self.1 : __torch__.torch.fx.graph_module.GraphModule, %x : Float(1, 16, 32, 32, strides=[16384, 1024, 32, 1], requires_grad=0, device=cpu) = prim::Param()
 ),
 x defined in (%self.1 : __torch__.torch.fx.graph_module.GraphModule, %x : Float(1, 16, 32, 32, strides=[16384, 1024, 32, 1], requires_grad=0, device=cpu) = prim::Param()
 )]

In [6]:
# Directory that holds the exported ResNet-18 QAT artifacts.
# The original notebook hard-coded a machine-specific absolute path; it is kept as
# the default, but can now be overridden by setting the RESNET18_QAT_ROOT
# environment variable so the notebook can run on other machines.
ROOT = os.environ.get(
    "RESNET18_QAT_ROOT",
    "/media/pc/data/lxw/ai/tvm/xinetzone/tvm-book/doc/tutorials/relay/frontend/draft/resnet18_cifar10_relu_qat",
)
# Load the saved TorchScript module (stored with an .h5 file name) and put it in eval mode.
script_module = torch.jit.load(f"{ROOT}/weight/resnet18_cifar10_relu_qat.h5").eval()

In [7]:
# Convert the loaded TorchScript module with the same frontend call as for the demo model.
# NOTE(review): the file name suggests a CIFAR-10 model; confirm that a
# (1, 3, 224, 224) input shape is what this module expects.
input_infos = [("data", (1, 3, 224, 224)),]
# The recorded output of this cell is a KeyError for 'conv1_input_zero_point_0'.
mod, params = from_pytorch(
    script_module, input_infos,
    custom_convert_map=None,
    default_dtype='float32',
    use_parser_friendly_name=False,
    keep_quantized_weight=False
)

KeyError: 'conv1_input_zero_point_0'

In [10]:
import tvm
from tvm.relay.frontend import qnn_torch
from tvm.relay.frontend.pytorch import (
    _run_jit_passes,
    Prelude, PyTorchOpConverter,
    get_all_op_names,
    _get_relay_input_vars,
    _debug_rename,
    convert_params,
    _get_output_name,
    get_attr_chains,
    _getattr_full_name,
    _get_users,
    getattr_attr_name,
    _get_tensor_and_var
)


In [11]:
# Naming flag handed to the TVM frontend helpers; same value as passed to from_pytorch.
use_parser_friendly_name = False

In [12]:
# Step-by-step replay of the conversion using the frontend's building blocks, so
# that intermediate values can be inspected when the one-shot call fails.
# NOTE(review): this appears to mirror what from_pytorch does internally —
# confirm against the installed TVM version.
mod = tvm.IRModule()
prelude = Prelude(mod)
enable_lower_all_tuples = True

converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
# Work on a copy of the TorchScript graph rather than on script_module.graph itself.
graph = script_module.graph.copy()
graph_inputs = list(graph.inputs())
_run_jit_passes(graph, enable_lower_all_tuples)
# Collect the op names present in the graph and let the converter report any it cannot handle.
op_names = get_all_op_names(graph)
converter.report_missing_conversion(op_names)
# Use the module's state_dict when script_module is a ScriptModule, else an empty dict.
is_module = isinstance(script_module, torch.jit.ScriptModule)
params = script_module.state_dict() if is_module else {}
# `outputs` starts with the relay variables for the graph inputs ...
outputs = _get_relay_input_vars(
    graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
)
source_map = _debug_rename(graph, use_parser_friendly_name)
param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
    graph, params, source_map, use_parser_friendly_name
)
# Wrap every converted tensor as a TVM NDArray, keyed by the same name as in `tensors`.
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
# ... and is then extended with the parameter variables.
outputs.update(param_vars)
# The quantization-specific steps run only when the graph contains one of these ops.
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
if len(quantized_ops.intersection(set(op_names))) > 0:
    weight_quant_params = qnn_torch.get_weight_quant_params(
        script_module, packed_param_map.values()
    )
    # NOTE(review): the KeyError recorded as this cell's output is presumably raised
    # inside this call — confirm from the full traceback.
    qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map)

KeyError: 'conv1_input_zero_point_0'

In [None]:
# state_dict = params
# getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
# params = {}
# param_tensors = {}
# packed_param_map = {}
# param_debug_name_map = {}
# vars_by_name = {}
# seen = set()
# attr_name_sep = "_" if use_parser_friendly_name else "."

# for node in getattr_nodes:
#     if _get_output_name(node) in seen:
#         continue

#     for getattrs in get_attr_chains(node):
#         seen.update(map(_get_output_name, getattrs))

#         full_attr = _getattr_full_name(getattrs, attr_name_sep)
#         full_attr_node_name = _get_output_name(getattrs[-1])
#         print(full_attr, full_attr_node_name)
#         # set variable name by concatenating first consumer's name with full attribute
#         # e.g. "aten::batch_norm_5.running_mean"
#         var_name = attr_name_sep.join(
#             [source_map[_get_users(getattrs[-1])[0]], full_attr.split(attr_name_sep)[-1]]
#         )

#         if full_attr.endswith("_packed_params"):  # for quantized models
#             packed_param_map[full_attr_node_name] = full_attr
#         elif full_attr in state_dict:
#             if var_name in vars_by_name:
#                 var = vars_by_name[var_name]
#             else:
#                 torch_tensor = state_dict[full_attr]
#                 tensor, var = _get_tensor_and_var(torch_tensor, var_name)
#                 param_tensors[var_name] = tensor
#                 # for quantized parameters to be correctly located
#                 param_debug_name_map[full_attr_node_name] = var_name
#                 vars_by_name[var_name] = var
#             params[full_attr_node_name] = var

In [13]:
def get_full_attr_name(current):
    """Return the dotted attribute path for a ``prim::GetAttr`` node.

    The node's own attribute name is obtained via ``getattr_attr_name``.
    When the node has exactly one input:

    * if that input is produced by another ``prim::GetAttr`` node, the
      producer's full name is computed recursively and joined to this
      node's name with ``"."``;
    * if it is produced by a ``prim::Param`` node, ``".1"`` is appended
      to this node's name.

    In every other case the bare attribute name is returned.
    """
    own_name = getattr_attr_name(current)
    operands = list(current.inputs())

    # Only single-input nodes take part in attribute chaining.
    if len(operands) != 1:
        return own_name

    producer = operands[0].node()
    producer_kind = producer.kind()

    if producer_kind == "prim::GetAttr":
        return ".".join((get_full_attr_name(producer), own_name))
    if producer_kind == "prim::Param":
        return own_name + ".1"
    return own_name

In [14]:
# Replay, for the first matching node only, the step that turns a quantization
# scale / zero-point attribute into a constant node.
for node in graph.findAllNodes("prim::GetAttr", recurse=True):
    out_name = node.output().debugName()
    if "_scale" in out_name or "_zero_point" in out_name:
        # The looked-up value is a relay variable name (e.g.
        # 'aten::quantize_per_tensor_0.conv1_input_zero_point_0'), so it must be
        # resolved in the converted parameter table, not in `params`, which was
        # rebound to the TorchScript state_dict in the replay cell. `tvm_params`
        # is built from `tensors` with the same keys and holds TVM NDArrays.
        full_attr = param_debug_name_map[get_full_attr_name(node)]
        assert full_attr in tvm_params, f"{full_attr} not found in param dict."
        param_np = tvm_params[full_attr].asnumpy()
        # Create an (as yet valueless) constant node right before the GetAttr node.
        new_const_node = graph.create("prim::Constant")
        new_const_node.insertBefore(node)
        # Stop after the first hit so `node` / `full_attr` can be inspected below.
        break

In [15]:
# Display the name that the lookup in the loop above produced.
full_attr 

'aten::quantize_per_tensor_0.conv1_input_zero_point_0'

In [None]:
# Inspect the attribute name of the node the loop stopped at.
current = node
getattr_attr_name(current)

In [None]:
# Manually step through the first lines of get_full_attr_name for this node:
# its attribute name, its inputs, and the node that produces the first input.
current = node
current_attr = getattr_attr_name(current)
inputs = list(current.inputs())
input_node = inputs[0].node()

In [None]:
# Kind of the producing node; get_full_attr_name branches on this value.
input_node.kind()

In [None]:
# Grab only the first item yielded by get_attr_chains for the producing node;
# `getattrs` stays bound to it after the loop exits.
for getattrs in get_attr_chains(input_node):
    break

In [None]:
# Display the chain captured in the previous cell.
getattrs

In [None]:
# Name that get_full_attr_name derives for the node (used as the lookup key in the loop cell).
get_full_attr_name(node)

In [None]:
# Debug name of the node's output, for comparison with the derived attribute name.
node.output().debugName()

In [None]:
# Look the node up by the debug name of its output.
# The original line had an empty subscript (`param_debug_name_map[]`), which is a
# syntax error; the key used here is the expression inspected in the preceding cell.
# A KeyError here means the map has no entry under that debug name.
full_attr = param_debug_name_map[node.output().debugName()]

In [None]:
# Start from two separate, empty bookkeeping dictionaries for the renaming pass.
source_map = {}
op_type_dict = {}
# Node kinds listed as carrying nested blocks.
prim_with_blocks = ["prim::If", "prim::Loop"]

In [None]:
# `_rename_outputs` is not among the names imported earlier in this notebook, so
# bring it into scope here.
# NOTE(review): assumes it is a module-level helper of tvm.relay.frontend.pytorch — confirm.
from tvm.relay.frontend.pytorch import _rename_outputs


def _traverse_graph(nodes):
    """Rename the outputs of every node in ``nodes``.

    Nodes without outputs are skipped. For nodes whose kind is listed in
    ``prim_with_blocks`` the nested blocks are processed first (recursively),
    then the node itself is renamed. Results accumulate in the module-level
    ``source_map`` and ``op_type_dict``.
    """
    for inner in nodes:
        if inner.outputsSize() == 0:
            continue
        if inner.kind() in prim_with_blocks:
            for block in inner.blocks():
                _traverse_graph(block.nodes())
        _rename_outputs(inner, source_map, op_type_dict, use_parser_friendly_name)


# Visit the top-level nodes one at a time so that `node` stays bound to the last
# visited node for the inspection cell that follows.
for node in graph.nodes():
    _traverse_graph([node])

In [None]:
# Number of outputs of the node currently bound to `node`.
node.outputsSize()