
nn.Conv2d and F.conv2d with groups == input_channels (DepthWise) generates PartitionedCall in tensorflow frozen_graph #30

Closed
pedrofrodenas opened this issue Mar 26, 2024 · 1 comment

@pedrofrodenas

When I define a model with depthwise convolutions (groups == input_channels), the model is converted successfully, but the TensorFlow frozen_graph of this model cannot be converted to TensorFlow.js. The problem is that keras.layers.Conv2D generates a PartitionedCall node in the frozen_graph, and that node cannot be converted by TensorFlow.js.

I provide the Python code to reproduce the problem:

import torch
import torch.nn as nn

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

import nobuco

class ExampleModel(nn.Module):
    def __init__(self, **kwargs):
        super(ExampleModel, self).__init__()
        # groups == in_channels == 16, i.e. a depthwise convolution
        self.layer1 = nn.Conv2d(16, 16, (3, 3), (1, 1), (0, 0), (1, 1), 16)
        self.layer2 = nn.ReLU()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x


model = ExampleModel()

# Put model in inference mode
model.eval()

x = torch.randn(1, 16, 113, 113, requires_grad=False)

keras_model = nobuco.pytorch_to_keras(
    model,
    args=[x], kwargs=None)

# Wrap the Keras model in a tf.function and get a concrete function for freezing
full_model = tf.function(lambda x: keras_model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))

# Convert Keras model to frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

# Print the input and output tensors
print("Frozen model inputs: ", frozen_func.inputs)
print("Frozen model outputs: ", frozen_func.outputs)

# Save frozen graph to disk
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                logdir='.',
                name='ExampleModel.pb',
                as_text=False)

Inspecting ExampleModel.pb with Netron, this is what happens:

[Screenshot from 2024-03-26: Netron shows a PartitionedCall node in the frozen graph in place of the depthwise convolution]
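The same can be verified without Netron by listing the distinct op types in the frozen graph (reusing frozen_func from the script above):

# With the depthwise layer, 'PartitionedCall' appears in this set;
# that is the node the TensorFlow.js converter rejects.
op_types = {node.op for node in frozen_func.graph.as_graph_def().node}
print(op_types)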

In order to fix this error, I made a custom nn.Conv2d converter:

import numbers

from torch import Tensor
from tensorflow import keras


@nobuco.converter(nn.Conv2d)
def converter_Conv2d(self, input: Tensor):
    weight = self.weight
    bias = self.bias
    groups = self.groups
    padding = self.padding
    stride = self.stride
    dilation = self.dilation

    # PyTorch weight layout: (out_channels, in_channels // groups, kh, kw)
    out_filters, in_filters, kh, kw = weight.shape

    weights = weight.cpu().detach().numpy()

    if groups == 1:
        # Keras Conv2D kernel layout: (kh, kw, in_channels, out_channels)
        weights = tf.transpose(weights, (2, 3, 1, 0))
    else:
        # Keras DepthwiseConv2D kernel layout: (kh, kw, channels, depth_multiplier)
        weights = tf.transpose(weights, (2, 3, 0, 1))

    if bias is not None:
        biases = bias.cpu().detach().numpy()
        params = [weights, biases]
        use_bias = True
    else:
        params = [weights]
        use_bias = False

    if isinstance(dilation, numbers.Number):
        dilation = (dilation, dilation)

    if isinstance(padding, numbers.Number):
        padding = (padding, padding)

    pad_str = 'valid'
    pad_layer = None

    if padding == 'same':
        pad_str = 'same'
    elif padding != (0, 0):
        pad_layer = keras.layers.ZeroPadding2D(padding)

    if groups == 1:
        conv = keras.layers.Conv2D(filters=out_filters,
                                   kernel_size=(kh, kw),
                                   strides=stride,
                                   padding=pad_str,
                                   dilation_rate=dilation,
                                   groups=groups,
                                   use_bias=use_bias,
                                   weights=params
                                   )
    else:
        # Note: this branch assumes groups == in_channels (pure depthwise)
        conv = keras.layers.DepthwiseConv2D(
                    kernel_size=(kh, kw),
                    strides=stride,
                    padding=pad_str,
                    use_bias=use_bias,
                    activation=None,
                    depth_multiplier=1,
                    weights=params,
                    dilation_rate=dilation,
                )

    def func(input):
        if pad_layer is not None:
            input = pad_layer(input)
        output = conv(input)
        return output
    return func
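With this converter defined (the @nobuco.converter decorator registers it globally), re-running the conversion and freezing steps from the script above should show the fix:

keras_model = nobuco.pytorch_to_keras(model, args=[x], kwargs=None)

full_model = tf.function(lambda x: keras_model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(full_model)

# The op set should now contain DepthwiseConv2dNative instead of PartitionedCall.
print({node.op for node in frozen_func.graph.as_graph_def().node})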

But I think it is probably better to fix this in the source code.

@AlexanderLutsenko (Owner)

Hey, thanks for bringing this up! I fixed what I could in v0.12.2, but there are still problems with TFJS: it only works when groups == 1 or groups == in_channels. As a last resort, you can always express a grouped convolution as a normal one, missing out on efficiency. In fact, that's how ConvTranspose1d/ConvTranspose2d are converted; TensorFlow just completely botched it: tensorflow/tensorflow#45216.
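For illustration, here is a minimal sketch of the "express grouped convolution as a normal one" idea (not nobuco's actual implementation; the helper name grouped_weight_to_dense is made up here): expand the grouped weight into a dense kernel that is zero outside each group's block, then run a plain Conv2D with groups=1.

import numpy as np

def grouped_weight_to_dense(weight_pt, groups):
    # weight_pt: PyTorch layout (out_channels, in_channels // groups, kh, kw).
    # Returns a dense Keras kernel (kh, kw, in_channels, out_channels) that is
    # zero outside the per-group blocks, so Conv2D with groups=1 computes the
    # same result as the grouped convolution.
    c_out, c_in_per_group, kh, kw = weight_pt.shape
    c_in = c_in_per_group * groups
    c_out_per_group = c_out // groups
    dense = np.zeros((kh, kw, c_in, c_out), dtype=weight_pt.dtype)
    for g in range(groups):
        block = weight_pt[g * c_out_per_group:(g + 1) * c_out_per_group]
        dense[:, :,
              g * c_in_per_group:(g + 1) * c_in_per_group,
              g * c_out_per_group:(g + 1) * c_out_per_group] = block.transpose(2, 3, 1, 0)
    return dense

The dense kernel costs groups times more multiply-adds than the grouped one, which is the efficiency hit mentioned above.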
