Skip to content

Commit

Permalink
Fixes to make simplest conv working (hidet-org#22)
Browse files Browse the repository at this point in the history
Simple model with one conv2d failed. 
- fix signature for conv* ops to corresponds torch.nn.functional]
- add missed padding normalization

After that the model works
  • Loading branch information
vadiklyutiy authored and hjjq committed Feb 27, 2024
1 parent 06cf139 commit e3417e2
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
26 changes: 20 additions & 6 deletions python/hidet/graph/frontend/torch/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


@register_function(torch.nn.functional.conv1d)
def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride=1, padding=0, dilation=1, groups=1):
x = ops.conv_pad(x, padding)
y = ops.conv1d(x, weight, stride=stride, dilations=dilation, groups=groups)
if bias is not None:
Expand All @@ -40,7 +40,14 @@ def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d

@register_function(torch.nn.functional.conv_transpose1d)
def conv1d_transpose(
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride=1,
padding=0,
output_padding=0,
groups=1,
dilation=1,
):
if dilation != 1 and not same_list(dilation, [1]):
raise NotImplementedError("dilation != 1")
Expand All @@ -51,7 +58,7 @@ def conv1d_transpose(


@register_function(torch.nn.functional.conv2d)
def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride=1, padding=0, dilation=1, groups=1):
y = ops.conv2d(x, weight, stride, dilation, groups, padding=padding)
if bias is not None:
y = y + ops.unsqueeze(bias, [0, 2, 3])
Expand All @@ -60,7 +67,7 @@ def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d

@register_function(torch.nn.functional.conv_transpose2d)
def conv2d_transpose(
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride=1, padding=0, output_padding=0, groups=1, dilation=1
):
if dilation != 1 and not same_list(dilation, [1, 1]):
raise NotImplementedError("dilation != 1")
Expand All @@ -71,7 +78,7 @@ def conv2d_transpose(


@register_function(torch.nn.functional.conv3d)
def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride=1, padding=0, dilation=1, groups=1):
x = ops.conv_pad(x, padding)
y = ops.conv3d(x, weight, stride, dilation, groups)
if bias is not None:
Expand All @@ -81,7 +88,14 @@ def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d

@register_function(torch.nn.functional.conv_transpose3d)
def conv3d_transpose(
x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride=1,
padding=0,
output_padding=0,
groups=1,
dilation=1,
):
if dilation != 1 and not same_list(dilation, [1, 1, 1]):
raise NotImplementedError("dilation != 1")
Expand Down
10 changes: 9 additions & 1 deletion python/hidet/graph/ops/conv2d/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
from typing import List, Union, Sequence
from hidet import ir
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import compute, input_like, normalize_stride, normalize_dilations, reduce
from hidet.graph.ops.utils import (
compute,
input_like,
normalize_stride,
normalize_dilations,
normalize_conv_padding,
reduce,
)
from hidet.utils.py import cdiv


Expand Down Expand Up @@ -147,6 +154,7 @@ def __init__(
):
stride = normalize_stride(stride)
dilations = normalize_dilations(dilations)
padding = normalize_conv_padding(padding, 2)
super().__init__(
inputs=[x, w],
attributes={'padding': padding, 'stride': stride, 'groups': groups, 'dilations': dilations},
Expand Down
9 changes: 9 additions & 0 deletions python/hidet/graph/ops/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def normalize_padding(padding: Union[Int, Sequence[Int]], dim=2) -> List[Int]:
)


def normalize_conv_padding(padding: Union[Int, Sequence[Int]], dim) -> List[Int]:
if isinstance(padding, int):
return [padding for _ in range(dim)]
elif isinstance(padding, (list, tuple)):
assert len(padding) == dim
return padding
raise ValueError('Incorrect conv padding: {}; dim is {}'.format(padding, dim))


def normalize_dim(dim: Optional[Union[Int, Sequence[Int]]], rank: int) -> Union[Int, List[Int]]:
"""
normalize a dim from [-rank, rank] or None to [0, rank].
Expand Down

0 comments on commit e3417e2

Please sign in to comment.