Commit

Revert "[nvfuser_upstream_push] nvfuser code base bump 060822 (pytorc…
pytorchmergebot committed Jun 10, 2022
1 parent bcd7a20 commit d28e9e1
Showing 60 changed files with 638 additions and 1,971 deletions.
2 changes: 1 addition & 1 deletion benchmarks/cpp/nvfuser/timm.cpp
@@ -692,7 +692,7 @@ static void nhwc_seresnet152d_transpose65(Fusion* fusion, void* null) {
auto t17 = set(t16);
auto t29 = castOp(DataType::Half, t17);
auto t18 = mul(t17, t3);
-auto t19 = permute(t18, {0, 2, 3, 1});
+auto t19 = transpose(t18, {{0, 0}, {1, 3}, {2, 1}, {3, 2}});
auto t30 = castOp(DataType::Half, t19);

fusion->addOutput(t29);
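For readers of this hunk: the bump being reverted replaced nvfuser's map-based transpose (an {old axis -> new axis} map) with permute, which takes a new-to-old axis vector. Below is a minimal standalone sketch of how the two notations relate, assuming the map covers every axis; the helper old2newToPermutation is hypothetical and not part of nvfuser.

// Illustrative sketch only, not nvfuser code: relate the restored
// {old axis -> new axis} transpose map to the new-to-old permutation
// vector that a call like permute(t18, {0, 2, 3, 1}) uses.
// old2newToPermutation is a hypothetical name; assumes the map
// covers every axis of the tensor.
#include <cassert>
#include <unordered_map>
#include <vector>

std::vector<int> old2newToPermutation(
    int ndims,
    const std::unordered_map<int, int>& old2new) {
  std::vector<int> new2old(ndims, -1);
  for (const auto& kv : old2new) {
    // old axis kv.first ends up at output position kv.second
    new2old[kv.second] = kv.first;
  }
  return new2old;
}

int main() {
  // The map from the hunk above...
  std::unordered_map<int, int> m{{0, 0}, {1, 3}, {2, 1}, {3, 2}};
  // ...corresponds to the axis vector of the reverted permute() call.
  assert((old2newToPermutation(4, m) == std::vector<int>{0, 2, 3, 1}));
  return 0;
}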
14 changes: 10 additions & 4 deletions benchmarks/cpp/nvfuser/transpose.cpp
@@ -81,8 +81,14 @@ static void setupTranspose(
FusionGuard fg(fusion);
typedef std::pair<int, int> transpose_axes;

-auto optionalTranspose = [axes](TensorView* tv, bool is_transpose) {
-  return (is_transpose) ? transpose(tv, axes.first, axes.second) : tv;
+auto getTransposeMap =
+    [](const transpose_axes& axes) -> std::unordered_map<int, int> {
+  return {{axes.first, axes.second}, {axes.second, axes.first}};
+};
+
+auto optionalTranspose = [&getTransposeMap, axes](
+                             TensorView* tv, bool is_transpose) {
+  return (is_transpose) ? transpose(tv, getTransposeMap(axes)) : tv;
};

auto input1 = makeContigTensor(num_dims);
@@ -408,8 +414,8 @@ static void Baseline_Transpose(
auto at_input1 = aten_inputs[0];
auto at_input2 = aten_inputs[1];

-auto optionalTransposeAten = [&axes](at::Tensor x, bool is_transpose) {
-  return (is_transpose) ? at::transpose(x, axes.first, axes.second) : x;
+auto optionalTransposeAten = [&axes](at::Tensor at, bool is_transpose) {
+  return (is_transpose) ? at::transpose(at, axes.first, axes.second) : at;
};

for (auto _ : benchmark_state) {
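On the restored helper above: getTransposeMap builds the symmetric {a -> b, b -> a} map, i.e. a plain two-axis swap, which is what the ATen baseline expresses as at::transpose(x, axes.first, axes.second). A small standalone sketch (hypothetical helper name, not benchmark code) showing that such a swap applied to the identity axis order moves only the two mapped axes:

// Sketch only: a symmetric {a -> b, b -> a} map applied to the
// identity axis order swaps exactly two axes, mirroring the
// at::transpose call in the baseline. applySwap is a hypothetical
// name used for illustration.
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

std::vector<int> applySwap(int ndims, std::pair<int, int> axes) {
  std::vector<int> order(ndims);
  std::iota(order.begin(), order.end(), 0); // identity: 0, 1, ..., ndims-1
  std::swap(order[axes.first], order[axes.second]);
  return order;
}

int main() {
  // Swapping axes 1 and 3 of a 4-D tensor: only those two positions move.
  assert((applySwap(4, {1, 3}) == std::vector<int>{0, 3, 2, 1}));
  return 0;
}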
3 changes: 0 additions & 3 deletions setup.py
@@ -1076,9 +1076,6 @@ def print_box(msg):
'include/torch/csrc/jit/testing/*.h',
'include/torch/csrc/jit/tensorexpr/*.h',
'include/torch/csrc/jit/tensorexpr/operators/*.h',
-'include/torch/csrc/jit/codegen/cuda/*.h',
-'include/torch/csrc/jit/codegen/cuda/ops/*.h',
-'include/torch/csrc/jit/codegen/cuda/scheduler/*.h',
'include/torch/csrc/onnx/*.h',
'include/torch/csrc/profiler/*.h',
'include/torch/csrc/utils/*.h',
2 changes: 1 addition & 1 deletion test/test_jit_cuda_fuser.py
@@ -2228,7 +2228,7 @@ def t(x, y):
self.assertEqual(o.stride(), jit_o.stride())
except Exception as e:
warnings.warn(
"permutation propagation is broken, proper support should come after nvfuser permutation scheduler update")
"permutation propagatoin is broken, proper support should come after nvfuser permutation scheduler update")
self.assertGraphContains(t_jit.graph_for(x, bias), FUSION_GUARD)

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
