allow keepdim in aten::sum (pytorch#452)
Fixes pytorch#205
support static keepdim in aten::sum
jjsjann123 committed Oct 28, 2020
1 parent ceeef26 commit dbd2a37
Showing 3 changed files with 17 additions and 15 deletions.
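
As context for the diffs below, here is a minimal sketch of the case this change enables: an aten::sum whose keepdim is a compile-time constant inside a scripted module. The sketch is illustrative only; the module name, shapes, and CUDA guard are assumptions, not part of the commit.

import torch

# Hypothetical example: keepdim is a literal, so it appears as a constant
# input to aten::sum in the TorchScript graph and can now be fused.
class AddThenSum(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        o = torch.add(x, y)
        return torch.sum(o, dim=1, keepdim=True)

if torch.cuda.is_available():
    m = torch.jit.script(AddThenSum())
    x = torch.randn(8, 4, 16, device="cuda")
    y = torch.randn(8, 4, 16, device="cuda")
    out = m(x, y)  # shape [8, 1, 16] because keepdim=True
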
14 changes: 8 additions & 6 deletions test/test_jit_cuda_fuser.py
@@ -607,17 +607,18 @@ def test_binary_ops_permutation(self):
                     x = [7, 8, 12]
                     self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
 
-    def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1):
+    def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False):
         class MyReduction(torch.nn.Module):
-            __constants__ = ['reduction_axis']
+            __constants__ = ['reduction_axis', 'keepdim']
 
             def __init__(self):
                 super(MyReduction, self).__init__()
                 self.reduction_axis = reduction_axis
+                self.keepdim = keepdim
 
             def forward(self, x: torch.Tensor, y: torch.Tensor):
                 o = torch.add(x, y)
-                o = torch.sum(o, dim=self.reduction_axis)
+                o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim)
                 return o
 
         t = MyReduction()
@@ -643,9 +644,10 @@ def test_reduction(self):
             # to single element (codegen limitation at this moment)
             for num_reduce_dim in range(1, len(x)):
                 for axes in itertools.combinations(range(len(x)), num_reduce_dim):
-                    perm0 = range(len(x))
-                    perm1 = range(len(x))
-                    self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
+                    for keepdim in (True, False):
+                        perm0 = range(len(x))
+                        perm1 = range(len(x))
+                        self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
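
For a concrete picture of what the updated test_reduction sweeps, this small sketch (illustration only, not part of the commit) prints the axis/keepdim combinations generated for a rank-3 input; full-rank reductions are still skipped, per the comment in the hunk above.

import itertools

ndim = 3  # e.g. x = [7, 8, 12]
for num_reduce_dim in range(1, ndim):
    for axes in itertools.combinations(range(ndim), num_reduce_dim):
        for keepdim in (True, False):
            print(axes, keepdim)
# (0,) True, (0,) False, (1,) True, ..., (1, 2) True, (1, 2) False
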
16 changes: 8 additions & 8 deletions torch/csrc/jit/codegen/cuda/parser.cpp
@@ -461,16 +461,17 @@ class IrParser {
            auto self = value_map[node->input(0)->unique()];
            auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
            TORCH_INTERNAL_ASSERT(
-               dims_list.has_value(), "requires static reduce axes");
-           auto keepdim = constant_as<bool>(node->input(2));
+               dims_list.has_value(),
+               "aten::sum cannot be fused with dynamic axes");
            std::vector<int> dims;
            for (const auto dim : dims_list->vec()) {
              dims.emplace_back(static_cast<int>(dim));
            }
+           auto keepdim = constant_as<bool>(node->input(2));
            TORCH_INTERNAL_ASSERT(
-               keepdim.has_value() && !keepdim.value(),
-               "Keep dim in reduction is not a const false");
-           auto out = sum(self->as<TensorView>(), dims);
+               keepdim.has_value(),
+               "aten::sum cannot be fused with dynamic keepdim");
+           auto out = sum(self->as<TensorView>(), dims, keepdim.value());
            value_map.emplace(node->output()->unique(), out);
          },
          [](const Node* node) -> bool {
@@ -491,9 +492,8 @@
            if (node->inputs()[1]->node()->kind() != prim::Constant) {
              return false;
            }
-           // we don't support keepdim yet;
-           if (node->inputs()[2]->node()->kind() != prim::Constant ||
-               *constant_as<bool>(node->input(2))) {
+           // we don't support dynamic keepdim yet;
+           if (node->inputs()[2]->node()->kind() != prim::Constant) {
              return false;
            }
            return true;
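
To illustrate the restriction that the second lambda still enforces, here is a hedged Python sketch (not from the commit): a keepdim written as a literal shows up as a prim::Constant input and can now be parsed, while a keepdim passed in at runtime does not, so the check above still returns false for it.

import torch

class StaticKeepdim(torch.nn.Module):
    # keepdim is a constant in the scripted graph, so node->input(2)
    # is a prim::Constant and this aten::sum can be fused after this commit.
    def forward(self, x: torch.Tensor):
        return torch.sum(x, dim=0, keepdim=True)

class DynamicKeepdim(torch.nn.Module):
    # keepdim is a runtime argument, so input 2 is not a prim::Constant
    # and this aten::sum is still rejected for fusion.
    def forward(self, x: torch.Tensor, keepdim: bool):
        return torch.sum(x, dim=0, keepdim=keepdim)

print(torch.jit.script(StaticKeepdim()).graph)
print(torch.jit.script(DynamicKeepdim()).graph)
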
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/shape_inference.cpp
@@ -160,7 +160,7 @@ class NaiveTypePropagator {
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(2));
        TORCH_CHECK(
-           dims.has_value() && keepdim.has_value() && !keepdim.value(),
+           dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        node->output()->setType(
            unary_reduce_type(out_type, dims->vec(), keepdim.value()));
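
For reference, a quick eager-mode sketch (not part of the commit) of the shape rule the type propagator now has to reproduce for both keepdim values:

import torch

x = torch.randn(2, 3, 4)

# The reduced dimension is dropped when keepdim=False ...
print(torch.sum(x, dim=1, keepdim=False).shape)  # torch.Size([2, 4])

# ... and kept with size 1 when keepdim=True, which shape inference
# must now propagate instead of asserting that keepdim is false.
print(torch.sum(x, dim=1, keepdim=True).shape)   # torch.Size([2, 1, 4])
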
