Skip to content

Commit

Permalink
[Relay] [PyTorch] Add aten::broadcast_tensors (#11863)
Browse files Browse the repository at this point in the history
* add aten::broadcast_tensors

* add entry

* fix test
  • Loading branch information
Yuanjing Shi committed Jun 25, 2022
1 parent 600a201 commit 98bf40f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1971,6 +1971,13 @@ def expand_as(self, inputs, input_types):
target = _op.cast(target, t0)
return _op.broadcast_to_like(inputs[0], target)

def broadcast_tensors(self, inputs, input_types):
tensor_list = inputs[0]
import torch

res_shape = list(torch.broadcast_shapes(*[self.infer_shape(t) for t in tensor_list]))
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]

def Bool(self, inputs, input_types):
assert len(inputs) == 1
return inputs[0]
Expand Down Expand Up @@ -3189,6 +3196,7 @@ def create_convert_map(self):
"aten::upsample_trilinear3d": self.make_upsample3d("linear"),
"aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
"aten::expand_as": self.expand_as,
"aten::broadcast_tensors": self.broadcast_tensors,
"aten::lt": self.make_elemwise("less"),
"aten::gt": self.make_elemwise("greater"),
"aten::le": self.make_elemwise("less_equal"),
Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,6 +1791,28 @@ def forward(self, *args):
verify_model(Expand2().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_broadcast_tensors():
torch.set_grad_enabled(False)

class BroadCastTensors1(Module):
def forward(self, x, y):
return torch.broadcast_tensors(x, y)

x = torch.arange(3).view(1, 1, 3)
y = torch.arange(2).view(1, 2, 1)
verify_model(BroadCastTensors1().float().eval(), input_data=[x, y])

class BroadCastTensors2(Module):
def forward(self, x, y, z):
return torch.broadcast_tensors(x, y, z)

x = torch.arange(3).view(1, 1, 3)
y = torch.arange(2).view(1, 2, 1)
z = torch.arange(4).view(4, 1, 1)
verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z])


@tvm.testing.uses_gpu
def test_forward_pow():
torch.set_grad_enabled(False)
Expand Down

0 comments on commit 98bf40f

Please sign in to comment.