# PyTorch 前端

In [1]:
import numpy as np
import torch
from torch import nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from tqdm import tqdm
torch.manual_seed(0)


class Demo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(16, 64, 3, 1, 1, bias=False, groups=16)
        # self.prelu = nn.PReLU(64)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        # x = self.prelu(x)
        x = self.relu(x)
        return x

In [2]:
model = Demo()
shape = 1, 16, 32, 32
example_inputs = [torch.rand(*shape),]
# script_module = torch.jit.trace(model.eval(), example_inputs)
model_qat = torch.fx.symbolic_trace(model)
model_qat = torch.fx.GraphModule(model_qat, model_qat.graph)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_prepared = prepare_qat_fx(model_qat, qconfig_mapping, example_inputs).eval()
model_converted = convert_fx(model_prepared).eval()
script_module = torch.jit.trace(model_converted.eval(), example_inputs).eval()
input_infos = [("data", shape),]
default_dtype = "float32"



In [4]:
import set_env

In [5]:
from tvm.relay.frontend.pytorch import from_pytorch

input_infos = [("data", shape),]
mod, params = from_pytorch(
    script_module, input_infos,
    custom_convert_map=None,
    default_dtype='float32',
    use_parser_friendly_name=False,
    keep_quantized_weight=False
)

In [6]:
print(mod["main"])

fn (%data: Tensor[(1, 16, 32, 32), float32] /* span=aten::quantize_per_tensor_0.data:0:0 */, %conv_weight: Tensor[(64, 1, 3, 3), float32] /* span=quantized::conv2d_relu_0:0:0 */) {
  %0 = qnn.quantize(%data, 1f /* span=aten::quantize_per_tensor_0:0:0 */, 0 /* span=aten::quantize_per_tensor_0:0:0 */, out_dtype="uint8", axis=1) /* span=aten::quantize_per_tensor_0:0:0 */;
  %1 = nn.pad(%0, 0f /* span=quantized::conv2d_relu_0:0:0 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* span=quantized::conv2d_relu_0:0:0 */;
  %2 = qnn.quantize(%conv_weight, 0.00261353f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_0:0:0 */;
  %3 = qnn.conv2d(%1, %2, 0 /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, 1f /* span=quantized::conv2d_relu_0:0:0 */, 0.00261353f /* span=quantized::conv2d_relu_0:0:0 */, padding=[0, 0, 0, 0], groups=16, channels=64, kernel_size=[3, 3], out_dtype

In [7]:
graph = script_module.graph.copy()
graph_inputs = list(graph.inputs())

In [8]:
graph_inputs

[self.1 defined in (%self.1 : __torch__.torch.fx.graph_module.GraphModule, %x : Float(1, 16, 32, 32, strides=[16384, 1024, 32, 1], requires_grad=0, device=cpu) = prim::Param()
 ),
 x defined in (%self.1 : __torch__.torch.fx.graph_module.GraphModule, %x : Float(1, 16, 32, 32, strides=[16384, 1024, 32, 1], requires_grad=0, device=cpu) = prim::Param()
 )]