Please support Softmax for QAT #38

Closed · steven0129 opened this issue Feb 24, 2022 · 3 comments
Labels: bug

@steven0129 (Contributor) commented:

Execute a Python file like the one below:

import torch
from torch import nn
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.tracer import model_tracer
from tinynn.util.train_util import DLContext, get_device, train

class DummyNet(nn.Module):
    def __init__(self, num_classes=4):
        super(DummyNet, self).__init__()
        
        self.input_channel = 1
        self.base_channel = 4

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        self.model = nn.Sequential(
            conv_bn(self.input_channel, self.base_channel, 2), 
            conv_dw(self.base_channel,  self.base_channel * 2, 1),
            conv_dw(self.base_channel * 2, self.base_channel * 4, 2),
            conv_dw(self.base_channel * 4, self.base_channel * 8, 2),
            conv_dw(self.base_channel * 8, self.base_channel * 16, 2),
            conv_dw(self.base_channel * 16, self.base_channel * 16, 1),
            conv_dw(self.base_channel * 16, self.base_channel * 16, 1),
            conv_dw(self.base_channel * 16, self.base_channel * 32, 2),
            conv_dw(self.base_channel * 32, self.base_channel * 32, 1),
            nn.AvgPool2d(kernel_size=(3, 6)),
            nn.Flatten(),
            nn.Linear(self.base_channel * 32, num_classes)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.model(x)
        x = self.softmax(x)

        return x

    def predict(self, x):
        x = self.forward(x)
        x = torch.argmax(x, dim=1)
        
        return x

if __name__ == '__main__':
    with model_tracer():
        model = DummyNet()
        model.eval()

        dummy_input = torch.rand((1, 1, 135, 240))
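        # Rewrite the model for QAT; TinyNN writes the rewritten module to out/dummynet_qat.py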
        quantizer = QATQuantizer(model, dummy_input, work_dir='out')
        qat_model = quantizer.quantize()
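        # In a real flow, QAT fine-tuning would run here before conversion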

        device = get_device()
        qat_model.to(device=device)

        with torch.no_grad():
            qat_model.eval()
            qat_model.cpu()
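            # Swap the fake-quantize modules for real quantized ops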
            qat_model = torch.quantization.convert(qat_model)
            torch.backends.quantized.engine = 'qnnpack'
            converter = TFLiteConverter(qat_model, dummy_input, tflite_path='out/dummy_qat.tflite')
            converter.convert()

And then I got the error below:

  File "/root/miniconda3/lib/python3.7/site-packages/torch/jit/_trace.py", line 744, in trace
    _module_class,
  File "/root/miniconda3/lib/python3.7/site-packages/torch/jit/_trace.py", line 959, in trace_module
    argument_names,
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "out/dummynet_qat.py", line 96, in forward
    softmax = self.softmax(model_11)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 1256, in forward
    return F.softmax(input, self.dim, _stacklevel=5)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1679, in softmax
    ret = input.softmax(dim)
NotImplementedError: Could not run 'aten::_softmax' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_softmax' is only available for these backends: [CPU, CUDA, MkldnnCPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/build/aten/src/ATen/RegisterCPU.cpp:16286 [kernel]
CUDA: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/build/aten/src/ATen/RegisterCUDA.cpp:20674 [kernel]
MkldnnCPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/build/aten/src/ATen/RegisterMkldnnCPU.cpp:563 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/core/VariableFallbackKernel.cpp:60 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradMLC: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/TraceType_0.cpp:9750 [kernel]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/autocast_mode.cpp:255 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/BatchingRegistrations.cpp:1019 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
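
For reference, the same NotImplementedError can be reproduced on a quantized tensor directly, without the converter (a minimal sketch; PyTorch builds that ship a quantized softmax kernel would not raise here):

import torch

# softmax on a quantized tensor dispatches aten::_softmax to the QuantizedCPU backend
x = torch.quantize_per_tensor(torch.rand(1, 4), scale=1 / 256, zero_point=0, dtype=torch.quint8)
x.softmax(dim=1)  # raises NotImplementedError on this PyTorch build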

It looks like Softmax is not implemented for QAT. If a quantized Softmax is not supported, falling back to a floating-point Softmax is good enough for me.
Could you implement it?
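
In case a fix takes a while, one eager-mode workaround (a minimal, untested sketch; FloatSoftmax is a hypothetical helper, and it assumes TinyNN's rewriter preserves torch.quantization.DeQuantStub) is to dequantize right before the softmax so the float kernel runs:

import torch
from torch import nn

class FloatSoftmax(nn.Module):
    """Run softmax in float: DeQuantStub becomes a real dequantize after convert()."""

    def __init__(self, dim=1):
        super().__init__()
        self.dequant = torch.quantization.DeQuantStub()
        self.softmax = nn.Softmax(dim=dim)

    def forward(self, x):
        x = self.dequant(x)  # quantized -> float once torch.quantization.convert has run
        return self.softmax(x)

# e.g. in DummyNet.__init__: self.softmax = FloatSoftmax(dim=1)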

Many Thanks

@peterjc123 (Collaborator) commented:

@steven0129 Having a look now. Please wait a moment.

peterjc123 added the bug label on Feb 24, 2022
@peterjc123 (Collaborator) commented:

@steven0129 Should be fixed by 911791d. Would you please try again?

@steven0129 (Contributor, Author) commented:

@peterjc123 The bug is fixed. This issue can be closed.
Many thanks!
