[Bug] Conv2DTranspose with groups not working correctly #10223

Closed
JCBrouwer opened this issue Feb 11, 2022 · 10 comments · Fixed by #10235
@JCBrouwer
Contributor

I'm trying to convert a PyTorch model that makes use of torch.nn.functional.conv_transpose2d, and I'm running into issues converting it to the corresponding tvm.relay.op.nn.conv2d_transpose operation.

I've done a little monkey patching on the PyTorchOpConverter, since the operations that torch.nn.functional.conv2d/conv_transpose2d trace to (aten::conv2d and aten::conv_transpose2d) aren't covered by default. I've added a converter for each one as a method on the PyTorchOpConverter, so that I have access to self.infer_shape(weight) inside the functions:

Converter implementation

import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import PyTorchOpConverter


class MyPyTorchOpConverter(PyTorchOpConverter):
    def __init__(self, prelude, default_dtype):
        super().__init__(prelude, default_dtype)
        self.update_convert_map(
            {"aten::conv2d": self.convert_conv2d, "aten::conv_transpose2d": self.convert_conv_transpose2d}
        )

    def convert_conv2d(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        dilation = inputs[5]
        groups = inputs[6]

        channels, input_channels, kh, kw = self.infer_shape(weight)  # OIHW

        if groups > 1 and input_channels == 1:
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        res = relay.op.nn.conv2d(
            data, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, channels=channels
        )
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)

        return res

    def convert_conv_transpose2d(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        output_padding = inputs[5]
        groups = inputs[6]
        dilation = inputs[7]

        input_channels, channels, kh, kw = list(self.infer_shape(weight))  # IOHW

        if groups > 1 and channels == 1:
            # depthwise case; derive the multiplier from the input-channel dim so the
            # reshape preserves the element count (channels // groups would be 0 here)
            channel_multiplier = input_channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        res = relay.op.nn.conv2d_transpose(
            data,
            weight,
            strides=strides,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
        )
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)

        return res

tvm.relay.frontend.pytorch.PyTorchOpConverter = MyPyTorchOpConverter

The implementations of the converters are adapted from tvm.relay.frontend.pytorch.PyTorchOpConverter.convolution(inputs, input_types), but updated to match the call signatures of torch.nn.functional.conv2d/conv_transpose2d.

The problem I'm seeing is that tvm.relay.op.nn.conv2d_transpose() doesn't seem to respect the groups argument. When I print the inputs and outputs of the first four conv(_transpose) ops in my network, the PyTorch shapes are the following:

PyTorch shapes
conv2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 4, 4)

conv2d
input (1, 1536, 4, 4)
weight (9, 512, 1, 1)
groups 3
out (1, 9, 4, 4)

conv_transpose2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 9, 9)

conv2d
data (1, 1536, 11, 11)
weight (1536, 1, 4, 4)
groups 1536
out (1, 1536, 8, 8)

While the TVM shapes are:

TVM shapes
conv2d
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 1536, 4, 4)

conv2d
input [1, 1536, 4, 4]
weight [9, 512, 1, 1]
groups 3
out (1, 9, 4, 4)

conv2d_transpose
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 512, 9, 9)

conv2d
input [1, 512, 11, 11]
weight [512, 1, 4, 4]
groups 1536
TVMError

Notice that the output shape of tvm.relay.op.nn.conv2d_transpose() does not have the correct number of channels (the output is as if groups = 1). This leads to the error in the next conv2d operation:

Error traceback
Traceback (most recent call last):
  File "/home/hans/code/stylegan3/func.py", line 217, in <module>
    Gtvm, tvm_params = relay.frontend.pytorch.from_pytorch(
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 4010, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 3385, in convert_operators
    relay_out = relay_op(
  File "/home/hans/code/stylegan3/func.py", line 101, in convert_conv2d
    print("out", self.infer_shape(res))
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 204, in infer_shape
    typ = self.infer_type(inputs, mod=mod)
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 162, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/home/hans/code/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/hans/code/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  9: TVMFuncCall
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::relay::ReshapeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/home/hans/code/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [14:35:14] /home/hans/code/tvm/src/relay/op/tensor/transform.cc:787: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: oshape_sum == data_shape_sum (24576 vs. 8192) : Input tensor shape(1536,1,4,4) and reshaped shape(512,1,4,4) are not compatible!
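
For what it's worth, the channel mismatch can be reproduced without the PyTorch frontend at all. Here's a minimal Relay-only sketch (weight shape taken from the conv_transpose2d above; strides=(2, 2) with no padding is my assumption to reproduce the 9x9 spatial output):

import tvm
from tvm import relay

# standalone type-inference check of a grouped conv2d_transpose
data = relay.var("data", shape=(1, 1536, 4, 4))
weight = relay.var("weight", shape=(1536, 512, 3, 3))  # IOHW
out = relay.nn.conv2d_transpose(data, weight, strides=(2, 2), groups=3)

mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([data, weight], out)))
print(mod["main"].ret_type)  # expected Tensor[(1, 1536, 9, 9)]; with the bug: Tensor[(1, 512, 9, 9)]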

As a workaround, I've rewritten my conv_transpose2d converter to manually split the data and weights into groups, perform each transposed conv, and then concatenate the results. This converter does seem to give the correct output shape, although I haven't yet tested the outputs for numerical correctness; I might have just gotten lucky with the shapes.

Workaround converter implementation (manual grouping)
    def convert_conv_transpose2d_workaround(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        output_padding = inputs[5]
        groups = inputs[6]
        dilation = inputs[7]

        input_channels, channels, kh, kw = list(self.infer_shape(weight))

        if groups > 1 and channels == 1:
            # depthwise case; as above, use the input-channel dim for the multiplier
            channel_multiplier = input_channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        datas = relay.op.split(data, groups, axis=1)
        weights = relay.op.split(weight, groups, axis=0)

        rs = []
        for d, w in zip(datas, weights):
            r = relay.op.nn.conv2d_transpose(
                d, w, strides=strides, padding=padding, output_padding=output_padding, dilation=dilation, groups=1
            )
            rs.append(r)
        res = relay.op.concatenate(rs, axis=1)
        # apply the full bias once over the concatenated output channels,
        # rather than adding the whole bias to every group slice
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)

        return res
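
(Note that a converter function like this can also be passed directly to relay.frontend.pytorch.from_pytorch via its custom_convert_map argument instead of patching the converter class; the reproduction script below does exactly that.)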

Expected behavior

The groups argument of tvm.relay.op.nn.conv2d_transpose should work correctly, as it does for tvm.relay.op.nn.conv2d.

Actual behavior

The transposed convolution seems to be applied to only a single group; the output channel count is computed as if groups = 1.

Environment

Ubuntu 20.04
PyTorch 1.12.0.dev20220210
TVM 0.9.dev525+g8aeb72265 (compiled from main a couple hours ago)
CUDA 11.4

Steps to reproduce

from copy import deepcopy

import torch
import tvm.relay
from torch.nn.functional import conv_transpose2d

_original_get_constant = deepcopy(tvm.relay.frontend.pytorch._get_constant)


def _my_get_constant(node):
    """Monkey patch in support for prim::Constant lists, I guess torch.jit.optimize_for_inference introduces these?"""
    if node.output().type().kind() == "ListType":
        print("WARNING: Encountered ListType in _get_constant, doing weird eval stuff to get the list value:", end=" ")
        lst = eval(node.__repr__().split("value=")[1].replace("]()", ""))
        print(lst)
        return lst
    else:
        return _original_get_constant(node)


tvm.relay.frontend.pytorch._get_constant = _my_get_constant


def convert_conv_transpose2d(inputs, input_types):
    data = inputs[0]
    weight = inputs[1]
    bias = inputs[2]
    strides = inputs[3]
    padding = inputs[4]
    output_padding = inputs[5]
    groups = inputs[6]
    dilation = inputs[7]

    res = tvm.relay.op.nn.conv2d_transpose(
        data,
        weight,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )
    if bias is not None:
        res = tvm.relay.op.nn.bias_add(res, bias)

    return res


class ModulatedConvTranspose2D(torch.nn.Module):
    def forward(self, x, w, s):
        B, C, H, W = x.shape
        I, O, KH, KW = w.shape

        # weight is different for each input in batch (this is why we want grouped conv transpose)
        w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
        w = w.reshape(B * I, O, KH, KW)

        x = x.reshape(1, B * C, H, W)

        x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B)

        # Check failed: oshape_sum == data_shape_sum (524288 vs. 131072) : Input tensor shape(4,256,16,32) and reshaped shape(1,256,16,32) are not compatible!
        x = x.reshape(B, O, H * 2, W * 2)

        return x


with torch.inference_mode():
    b, c, h, w, k = 4, 512, 8, 16, 3
    inputs = torch.rand(b, c, h, w)
    weights = torch.rand(c, c // 2, k, k)
    styles = torch.rand(b)

    torch_mod = torch.jit.optimize_for_inference(
        torch.jit.trace(ModulatedConvTranspose2D().eval(), (inputs, weights, styles))
    )

    outputs_torch = torch_mod(inputs, weights, styles)
    print("Torch output shape", outputs_torch.shape)  # torch.Size([4, 256, 16, 32])

    tvm_mod, params = tvm.relay.frontend.pytorch.from_pytorch(
        torch_mod,
        [("inputs", inputs.shape), ("weights", weights.shape), ("styles", styles.shape)],
        {"aten::conv_transpose2d": convert_conv_transpose2d},
    )
@masahi
Member

masahi commented Feb 11, 2022

Thanks! Yes, conv2d_transpose with groups was only recently fixed and supported in #9465. I don't think we've updated our PyTorch frontend to benefit from this change yet. I'll take a look.

@masahi masahi self-assigned this Feb 11, 2022
@masahi
Member

masahi commented Feb 13, 2022

Fixed in #10235. You shouldn't be using torch.jit.optimize_for_inference; it does no good for us, and it even introduces aten::conv2d etc. that we don't recognize. Of course we already support the PT conv2d op, but we expect it to be represented as aten::_convolution, which is the case after you run torch.jit.trace.
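
A quick way to check which aten op a trace contains (an illustrative sketch; exact op names depend on the PyTorch version):

import torch

def f(x, w):
    return torch.nn.functional.conv_transpose2d(x, w, stride=2, groups=4)

traced = torch.jit.trace(f, (torch.rand(1, 8, 4, 4), torch.rand(8, 2, 3, 3)))
# plain tracing records aten::_convolution; after torch.jit.optimize_for_inference
# the graph contains aten::conv_transpose2d instead
print(traced.graph)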

@JCBrouwer
Contributor Author

Great, thanks for the quick response!

I can confirm that it's converting from PyTorch correctly for my more complex model on your branch.

Now I'm getting an AssertionError on relay.build() though:

AssertionError: only support groups == 1 when targetting cuda/gpu

Are there plans to support grouped transposed convolutions on GPU?

Is there a better workaround than splitting the groups manually?

@masahi
Member

masahi commented Feb 14, 2022

Yeah, the cuda backend doesn't support groups. We should fix that, but I'm not looking to do it myself; a PR is welcome.

You can try the cpu backend to verify the result. If you are OK with using cuDNN, I can quickly enable groups support for the cuDNN conv2d_transpose.
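
A minimal sketch of that CPU check, assuming the tvm_mod / params and the torch tensors from the reproduction script above:

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm_mod, target="llvm", params=params)

m = graph_executor.GraphModule(lib["default"](tvm.cpu()))
m.run(
    inputs=tvm.nd.array(inputs.numpy()),
    weights=tvm.nd.array(weights.numpy()),
    styles=tvm.nd.array(styles.numpy()),
)
# compare against the PyTorch output computed earlier
np.testing.assert_allclose(m.get_output(0).numpy(), outputs_torch.numpy(), rtol=1e-4, atol=1e-4)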

@JCBrouwer
Contributor Author

Yes using CUDNN is fine with me, that would be great :)

@masahi
Member

masahi commented Feb 15, 2022

Are you sure you can run your PT model on cuda via cudnn (use nvprof)? I'm getting CUDNN_STATUS_BAD_PARAM when trying to run your model with cudnn. As always, the cuDNN error message is not helpful in telling me why my params are bad.

A WIP branch https://github.com/apache/tvm/compare/main...masahi:conv2d-transpose-group-cudnn?expand=1 if you want to hack on it.

@JCBrouwer
Contributor Author

JCBrouwer commented Feb 26, 2022

Hello @masahi, sorry for the slow response; I somehow missed the notification on this one. Thanks for enabling the op!

I took a look at running on your branch and was also getting BAD_PARAMs on both the test case above and my full model. After a bit of mucking around, I noticed that this change is incorrect: the argument is the conv_mode, which should be left as 1 (as on the main branch).

Changing that back I'm able to run both the test case and my larger model with grouped conv2d_transpose ops on the CUDNN backend 🎉

Sadly I'm still just a few FPS shy of my performance target, so I'll have to keep digging for speedups.

RE: support for groups in the regular cuda backend. Do you have a general idea of what kind of changes are necessary for that? I'm no expert, but I might be able to figure it out if it's just a matter of adapting similar code from grouped conv2d to work for grouped conv2d_transpose.

For good measure, here is the updated test code, which now works with the one-line change to your branch:
import torch
import tvm.relay
from torch.nn.functional import conv_transpose2d
from tvm import relay
from tvm.contrib import graph_executor


class ModulatedConvTranspose2D(torch.nn.Module):
    def forward(self, x, w, s):
        B, C, H, W = x.shape
        I, O, KH, KW = w.shape

        # weight is different for each input in batch (this is why we want grouped conv transpose)
        w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
        w = w.reshape(B * I, O, KH, KW)

        x = x.reshape(1, B * C, H, W)
        x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B)
        x = x.reshape(B, O, H * 2, W * 2)

        return x


with torch.inference_mode():
    device = "cuda"
    target = "cuda -libs=cudnn"
    dtype = torch.float16
    tvm_dtype = dtype.__repr__().split(".")[-1]

    b, c, h, w, k = 4, 512, 8, 16, 3
    inputs = torch.rand((b, c, h, w), dtype=dtype, device=device)
    weights = torch.rand((c, c // 2, k, k), dtype=dtype, device=device)
    styles = torch.rand((b), dtype=dtype, device=device)

    torch_mod = torch.jit.trace(ModulatedConvTranspose2D().eval().to(device), (inputs, weights, styles))

    outputs_torch = torch_mod(inputs, weights, styles)
    print("Torch output shape", tuple(outputs_torch.shape))  # (4, 256, 16, 32)

    tvm_mod, tvm_params = relay.frontend.pytorch.from_pytorch(
        torch_mod,
        [
            ("inputs", (tuple(inputs.shape), tvm_dtype)),
            ("weights", (tuple(weights.shape), tvm_dtype)),
            ("styles", (tuple(styles.shape), tvm_dtype)),
        ],
    )

    with tvm.transform.PassContext(opt_level=10):
        lib = relay.build(tvm_mod, target=target, params=tvm_params)
    m = graph_executor.GraphModule(lib["default"](tvm.cuda()))

    m.run(
        inputs=tvm.nd.array(inputs.cpu(), device=tvm.cuda()),
        weights=tvm.nd.array(weights.cpu(), device=tvm.cuda()),
        styles=tvm.nd.array(styles.cpu(), device=tvm.cuda()),
    )
    print("TVM output shape  ", m.get_output(0).numpy().shape)  # (4, 256, 16, 32)

@masahi
Member

masahi commented Feb 26, 2022

Oops, good find! Can you send a PR? I can merge it quickly (if I do it myself, you'll need to wait until next week).

Sadly I'm still just a few FPS shy of my performance target, so I'll have to keep digging for speedups.

How does TVM + cuDNN compare to PT? Since you are running on fp16, I'd hope that we can use tensorcore, but I've never seen grouped convolution running on tensorcore. Also, cutlass is generally faster than cuDNN, but it doesn't support grouped or depthwise convolution AFAIK.

RE: support for groups in the regular cuda backend. Do you have a general idea of what kind of changes are necessary for that

Yes, you can try adding a groups argument to

def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_padding):

and updating the op strategy similarly to how #9465 did it for x86 in python/relay/op/strategy/x86.py. I think it shouldn't be too hard.

You may try our auto-scheduler to see if it can beat cuDNN.

@JCBrouwer
Contributor Author

How does TVM + cuDNN compare to PT? Since you are running on fp16, I'd hope that we can use tensorcore, but I've never seen grouped convolution running on tensorcore. Also, cutlass is generally faster than cuDNN, but it doesn't support grouped or depthwise convolution AFAIK.

I'm not quite sure what the best way to benchmark/profile things is. I've been trying to use ncu, but it's very slow for the PyTorch models.

At the moment the PyTorch models (vanilla, traced, optimize_for_inference) reach about 9-11 fps and the TVM + CUDNN is about 15-19 fps. I'm hoping to get into the 25-30 fps range.
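
For the TVM side, one option is the graph executor's built-in timer (a sketch, reusing the GraphModule m from the script above):

ftimer = m.module.time_evaluator("run", tvm.cuda(), number=10, repeat=3)
prof = ftimer()
print("mean: %.2f ms (%.1f fps)" % (prof.mean * 1e3, 1.0 / prof.mean))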

I'm trying to get more of the computation done in fp16, but I've run into this issue: #10397.

You may try our auto-scheduler to see if it can beat cuDNN.

So far the autotvm tuner hasn't been successful for me. It took a couple of days to tune all of the ops with the default settings from the tutorial, and the result actually ended up slower than the untuned TVM + CUDNN version.

I haven't looked too deeply into the auto-scheduler yet because I couldn't find a good tutorial on applying it to a large model (I think the only tutorial is for single ops?).

I also figured it would be less effective when using CUDNN, which might reduce the flexibility of the scheduler, although I'm not sure that's actually the case.

@masahi
Member

masahi commented Feb 27, 2022

https://github.com/apache/tvm/tree/main/gallery/how_to/tune_with_autoscheduler has e2e examples of using the auto-scheduler. But yeah, I don't expect it to beat cuDNN, unless the cuDNN implementation of dgrad with groups is really poor.
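
A condensed sketch of that e2e flow, assuming the tvm_mod / tvm_params from the test script above (the log file name and trial count are illustrative):

import tvm
from tvm import auto_scheduler, relay

log_file = "tuning.json"
tasks, task_weights = auto_scheduler.extract_tasks(tvm_mod["main"], tvm_params, "cuda")

measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(auto_scheduler.TuningOptions(
    num_measure_trials=2000,  # the tutorials suggest on the order of 900 * len(tasks)
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
))

# compile with the tuned schedules applied
with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
        lib = relay.build(tvm_mod, target="cuda", params=tvm_params)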
