[converter] support new PyTorch operators #123

Open · 8 of 12 tasks
peterjc123 opened this issue Oct 9, 2022 · 9 comments
Labels: enhancement (New feature or request)

Comments

peterjc123 (Collaborator) commented Oct 9, 2022

Below are the PyTorch operators that are yet to be supported.

Unclassified (New)

N/A

Primitives (Python operators)

  • aten::len

Very easy (Constant generation or aliasing; see the sketch after this list)

  • aten::clamp_min
  • aten::clamp_max
  • aten::expand_as

Easy (Direct mapping)

Medium (Composite of multiple operators)

Hard (No mapping or the mapping is too complex)
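
To make the "Very easy" bucket concrete, here is a minimal eager-mode sketch of the identities involved. These are plain PyTorch checks, and the aliasing targets shown are assumptions about how the mapping could be expressed, not the converter's actual implementation:

```python
import torch

x = torch.randn(1, 3)
y = torch.randn(2, 3)

# aten::clamp_min / aten::clamp_max are torch.clamp with only one bound set,
# so they can alias an existing clamp translation.
assert torch.equal(x.clamp_min(0.0), torch.clamp(x, min=0.0))
assert torch.equal(x.clamp_max(0.0), torch.clamp(x, max=0.0))

# aten::expand_as is expand() with the target shape taken from another tensor,
# i.e. an ordinary broadcast once that tensor's shape is known.
assert torch.equal(x.expand_as(y), x.expand(y.shape))
```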

peterjc123 added the enhancement label on Oct 9, 2022
deephudka05 commented

The aten::len op is not supported when I try to convert my ResNet model to TFLite.

peterjc123 (Collaborator, Author) commented Nov 14, 2022

> The aten::len op is not supported when I try to convert my ResNet model to TFLite.

Could you please tell me which op in PyTorch it corresponds to (e.g. torch.mean)? It looks like you are passing a TorchScript model produced in scripting mode (via torch.jit.script), which we don't have much support for. Models produced by torch.jit.trace and torch.jit.save are better supported.
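
For reference, a minimal sketch of that tracing path, assuming a torchvision ResNet-18 as a stand-in for your model; the file names are placeholders, the TFLiteConverter call follows the usage shown later in this thread, and it is assumed the converter accepts the traced module directly:

```python
import torch
import torchvision

from tinynn.converter import TFLiteConverter

model = torchvision.models.resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)

# torch.jit.trace records the tensor ops actually executed, so Python-level
# constructs such as len() do not show up as aten::len in the graph.
traced = torch.jit.trace(model, dummy_input)
torch.jit.save(traced, 'resnet18_traced.pt')

converter = TFLiteConverter(traced, dummy_input, tflite_path='resnet18.tflite')
converter.convert()
```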

Update: aten::len added in a1b3c79

deephudka05 commented Nov 15, 2022

> > The aten::len op is not supported when I try to convert my ResNet model to TFLite.
>
> Could you please tell me which op in PyTorch it corresponds to (e.g. torch.mean)? It looks like you are passing a TorchScript model produced in scripting mode (via torch.jit.script), which we don't have much support for. Models produced by torch.jit.trace and torch.jit.save are better supported.
>
> Update: aten::len added in a1b3c79

You are correct, I am passing the model in scripting mode. Thanks for the help.

steven0129 (Contributor) commented

Any plan to support quantized::instance_norm?

peterjc123 (Collaborator, Author) commented

> Any plan to support quantized::instance_norm?

Would you please give me an example of a TFLite model with a quantized InstanceNorm layer? Once it is supported in TFLite, we will look into it.

steven0129 (Contributor) commented

> > Any plan to support quantized::instance_norm?
>
> Would you please give me an example of a TFLite model with a quantized InstanceNorm layer? Once it is supported in TFLite, we will look into it.

@peterjc123 You can use the following code as a test case. Many thanks.

```python
import torch
from torch import nn
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.tracer import model_tracer
from tinynn.converter import TFLiteConverter

class ChannelWiseConv(nn.Module):
    def __init__(self, in_channel, out_channel, padding=0, stride=1, dilation=1, bias=True, norm_layer='BN'):
        super().__init__()
        self.dwconv_quant = torch.quantization.QuantStub()
        self.dwconv_dequant = torch.quantization.DeQuantStub()
        self.pwconv_quant = torch.quantization.QuantStub()
        self.pwconv_dequant = torch.quantization.DeQuantStub()
        self.dwconv = nn.Conv2d(
            in_channel,
            in_channel,
            kernel_size=3,
            padding=padding,
            stride=stride,
            dilation=dilation,
            groups=in_channel,
            bias=bias
        )

        self.pwconv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=1,
            bias=bias
        )

        self.norm_layer = norm_layer
        self.bn = nn.BatchNorm2d(out_channel)
        self.instance_norm = nn.InstanceNorm2d(out_channel)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.dwconv_quant(x)
        x = self.dwconv(x)
        x = self.dwconv_dequant(x)
        x = self.pwconv_quant(x)
        x = self.pwconv(x)

        if self.norm_layer == 'BN':
            x = self.bn(x)
        elif self.norm_layer == 'IN':
            x = self.instance_norm(x)

        x = self.relu(x)
        x = self.pwconv_dequant(x)

        return x


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs1 = ChannelWiseConv(3, 256, norm_layer='IN')
        self.convs2 = ChannelWiseConv(256, 256, norm_layer='IN')
        self.convs3 = ChannelWiseConv(256, 1, norm_layer='IN')

    def forward(self, x):
        x = self.convs1(x)
        x = self.convs2(x)
        x = self.convs3(x)

        return  x


if __name__ == '__main__':
    model = Model()
    quantizer = QATQuantizer(
        model, 
        torch.randn(1, 3, 160, 256),
        work_dir='out',
        config={'rewrite_graph': False}
    )

    quantizer.quantize()

    model.eval()
    model.cpu()
    model = quantizer.convert(model)
    torch.backends.quantized.engine = 'qnnpack'
    converter = TFLiteConverter(model, torch.randn(1, 3, 160, 256), tflite_path='instance_norm.tflite')
    converter.convert()
```

peterjc123 (Collaborator, Author) commented Mar 31, 2023

> @peterjc123 You can use the following code as a test case. Many thanks.

What I mean is that if TFLite has no quantized InstanceNorm support, then there is not much we can do here. Your code only generates a quantized PyTorch model with InstanceNorm, which is not the thing I actually need. In short, we don't currently support it because we don't know how to implement it on the TFLite side. We can only act once the solution is clear (which operators to use and how the graph should be organized). The only thing we can do now is to wrap it with Quantize and Dequantize nodes during translation, which I guess is not what you want, right?
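
Purely as an illustration of that fallback (not something the converter emits today), the PyTorch-side equivalent of wrapping the op with Quantize and Dequantize nodes would look roughly like the hypothetical module below:

```python
import torch
from torch import nn

class DequantInstanceNormRequant(nn.Module):
    # Hypothetical fallback: hop out of the quantized domain, run InstanceNorm
    # in float32, then quantize again so the surrounding graph stays quantized.
    def __init__(self, channels):
        super().__init__()
        self.dequant = torch.quantization.DeQuantStub()   # quantized -> float
        self.instance_norm = nn.InstanceNorm2d(channels)  # float InstanceNorm
        self.quant = torch.quantization.QuantStub()       # float -> quantized

    def forward(self, x):
        x = self.dequant(x)
        x = self.instance_norm(x)
        x = self.quant(x)
        return x
```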

mjamroz commented Oct 26, 2023

Is there any chance to implement torch aten::scaled_dot_product_attention? https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html says it could be done as:

```python
# Efficient implementation equivalent to the following:
scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p)
return attn_weight @ V
```
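
For the simplest case (no mask, no dropout), that decomposition only needs matmul, transpose, mul and softmax, and it can be checked against the PyTorch op directly; sdpa_decomposed below is just an illustrative helper, not converter code:

```python
import math

import torch
import torch.nn.functional as F

def sdpa_decomposed(q, k, v):
    # Plain tensor ops only, i.e. the kind of composite a converter could emit.
    scale_factor = 1 / math.sqrt(q.size(-1))
    attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ v

q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
reference = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(sdpa_decomposed(q, k, v), reference, atol=1e-5)
```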

peterjc123 (Collaborator, Author) commented

> Is there any chance to implement torch aten::scaled_dot_product_attention?

Please create a new issue for that.
