diff --git a/benchmarks/cpp/nvfuser/timm.cpp b/benchmarks/cpp/nvfuser/timm.cpp
index 013b609be602..e7e9d22e8c95 100644
--- a/benchmarks/cpp/nvfuser/timm.cpp
+++ b/benchmarks/cpp/nvfuser/timm.cpp
@@ -692,7 +692,7 @@ static void nhwc_seresnet152d_transpose65(Fusion* fusion, void* null) {
   auto t17 = set(t16);
   auto t29 = castOp(DataType::Half, t17);
   auto t18 = mul(t17, t3);
-  auto t19 = permute(t18, {0, 2, 3, 1});
+  auto t19 = transpose(t18, {{0, 0}, {1, 3}, {2, 1}, {3, 2}});
   auto t30 = castOp(DataType::Half, t19);
 
   fusion->addOutput(t29);
diff --git a/benchmarks/cpp/nvfuser/transpose.cpp b/benchmarks/cpp/nvfuser/transpose.cpp
index 27218b7e1c47..39ee0452c160 100644
--- a/benchmarks/cpp/nvfuser/transpose.cpp
+++ b/benchmarks/cpp/nvfuser/transpose.cpp
@@ -81,8 +81,14 @@ static void setupTranspose(
   FusionGuard fg(fusion);
   typedef std::pair<int, int> transpose_axes;
 
-  auto optionalTranspose = [axes](TensorView* tv, bool is_transpose) {
-    return (is_transpose) ? transpose(tv, axes.first, axes.second) : tv;
+  auto getTransposeMap =
+      [](const transpose_axes& axes) -> std::unordered_map<int, int> {
+    return {{axes.first, axes.second}, {axes.second, axes.first}};
+  };
+
+  auto optionalTranspose = [&getTransposeMap, axes](
+                               TensorView* tv, bool is_transpose) {
+    return (is_transpose) ? transpose(tv, getTransposeMap(axes)) : tv;
   };
 
   auto input1 = makeContigTensor(num_dims);
@@ -408,8 +414,8 @@ static void Baseline_Transpose(
   auto at_input1 = aten_inputs[0];
   auto at_input2 = aten_inputs[1];
 
-  auto optionalTransposeAten = [&axes](at::Tensor x, bool is_transpose) {
-    return (is_transpose) ? at::transpose(x, axes.first, axes.second) : x;
+  auto optionalTransposeAten = [&axes](at::Tensor at, bool is_transpose) {
+    return (is_transpose) ? at::transpose(at, axes.first, axes.second) : at;
   };
 
   for (auto _ : benchmark_state) {
diff --git a/setup.py b/setup.py
index a96a46539c8b..4223cc92051c 100644
--- a/setup.py
+++ b/setup.py
@@ -1076,9 +1076,6 @@ def print_box(msg):
                 'include/torch/csrc/jit/testing/*.h',
                 'include/torch/csrc/jit/tensorexpr/*.h',
                 'include/torch/csrc/jit/tensorexpr/operators/*.h',
-                'include/torch/csrc/jit/codegen/cuda/*.h',
-                'include/torch/csrc/jit/codegen/cuda/ops/*.h',
-                'include/torch/csrc/jit/codegen/cuda/scheduler/*.h',
                 'include/torch/csrc/onnx/*.h',
                 'include/torch/csrc/profiler/*.h',
                 'include/torch/csrc/utils/*.h',
diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 12783e8acd42..05c3520da0e3 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -2228,7 +2228,7 @@ def t(x, y):
             self.assertEqual(o.stride(), jit_o.stride())
         except Exception as e:
             warnings.warn(
-                "permutation propagation is broken, proper support should come after nvfuser permutation scheduler update")
+                "permutation propagatoin is broken, proper support should come after nvfuser permutation scheduler update")
         self.assertGraphContains(t_jit.graph_for(x, bias), FUSION_GUARD)
 
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 3bab78b566d4..d8a9fc9751b9 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -19,48 +19,6 @@ namespace cuda {
 
 namespace {
 
-Val* simplifiedInt(Val* val) {
-  TORCH_INTERNAL_ASSERT(
-      val->isConstInt(), "Expecting Const Int's only in this routine.");
-  if (val->as<Int>()->value().has_value()) {
-    return val;
-  }
-  return IrBuilder::create<Int>(val->evaluateInt());
-}
-
-// If one size is nullptr, return the other. If both symbolic just return v1. If
-// one's concrete, prefer that one (simplified). If both concrete make sure
-// they're the same size.
-Val* promoteSize(Val* v1, Val* v2) {
-  if (v1 == nullptr) {
-    TORCH_INTERNAL_ASSERT(
-        v2 == nullptr || v2->isAnInt(),
-        "Expecting Int's only in this routine.");
-    return v2;
-  }
-  if (v2 == nullptr) {
-    return v1;
-  }
-  TORCH_INTERNAL_ASSERT(
-      v1->isAnInt() && v2->isAnInt(), "Expecting Int's only in this routine.");
-
-  if (!v1->isConstInt() && !v2->isConstInt()) {
-    return v1;
-  } else if (v1->isConstInt() && v2->isConstInt()) {
-    TORCH_INTERNAL_ASSERT(
-        v1->evaluateInt() == v2->evaluateInt(),
-        "Expected sizes to match but found ",
-        v1->evaluateInt(),
-        " and ",
-        v2->evaluateInt(),
-        ".");
-    return simplifiedInt(v1);
-  } else if (v1->isConstInt()) {
-    return simplifiedInt(v1);
-  }
-  return simplifiedInt(v2);
-}
-
 // Will return a new value of type val with the DataType dtype.
 Val* newScalar(ValType vtype, DataType dtype) {
   switch (vtype) {
@@ -98,11 +56,10 @@ Val* newScalar(ValType vtype, DataType dtype) {
 TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
   std::vector<TensorView*> tvs;
 
-  for (auto val : vals) {
-    if (val->getValType() == ValType::TensorView) {
+  for (auto val : vals)
+    if (val->getValType() == ValType::TensorView)
       tvs.push_back(val->as<TensorView>());
-    }
-  }
+
   TORCH_CHECK(
       !tvs.empty(),
       "Tried to create new output TensorView but received empty list.");
@@ -119,72 +76,63 @@ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
   std::vector<int64_t> start_offsets(out_domain.size(), 0);
   std::vector<int64_t> stop_offsets(out_domain.size(), 0);
   std::vector<Val*> extent_vals(out_domain.size(), nullptr);
-  std::vector<Val*> expanded_extent_vals(out_domain.size(), nullptr);
-  std::vector<c10::optional<IterType>> iter_types(
-      out_domain.size(), c10::nullopt);
+  std::vector<IterType> iter_types(out_domain.size(), IterType::Iteration);
 
   for (auto tv : tvs) {
     auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
     TORCH_INTERNAL_ASSERT(
         dom.size() == out_domain.size(),
-        "Invalid tensor view found while producing an output, it has ",
+        "Invalid tensor view found while producing and output, it has ",
         dom.size(),
         " dimensions but expected ",
         out_domain.size());
     for (const auto i : c10::irange(dom.size())) {
       if (dom[i]->isBroadcast()) {
-        if (dom[i]->hasExpandedExtent()) {
-          expanded_extent_vals[i] =
-              promoteSize(expanded_extent_vals[i], dom[i]->expandedExtent());
-        }
         continue;
       }
-      extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
-      if (iter_types[i].has_value()) {
-        // TODO: Enable, see conv tests and gather promotion/gather broadcast
-        // behavior.
- // - // TORCH_INTERNAL_ASSERT( - // iter_types[i].value() == dom[i]->getIterType(), - // "Invalid iter type promotion in newOutputTv for expression."); - } else { + if (extent_vals[i] == nullptr) { + extent_vals[i] = dom[i]->extent(); iter_types[i] = dom[i]->getIterType(); } - auto start_offset = dom[i]->start()->as(); auto stop_offset = dom[i]->stopOffset()->as(); // Currently, start is always constant TORCH_INTERNAL_ASSERT( - start_offset->isConstInt(), - "Invalid IterDomain start: ", - start_offset); + start_offset->isConst(), "Invalid IterDomain start: ", start_offset); TORCH_INTERNAL_ASSERT( - stop_offset->isConstInt(), + stop_offset->isConst(), "Invalid IterDomain stop offset: ", stop_offset); start_offsets[i] = - std::max(start_offsets[i], start_offset->evaluateInt()); - stop_offsets[i] = std::max(stop_offsets[i], stop_offset->evaluateInt()); + std::max(start_offsets[i], start_offset->value().value()); + stop_offsets[i] = std::max(stop_offsets[i], stop_offset->value().value()); } } for (const auto dim_i : c10::irange(out_domain.size())) { if (extent_vals[dim_i] != nullptr) { - TORCH_INTERNAL_ASSERT( - iter_types[dim_i].has_value(), - "Could not deduce iter type for new tensor view."); - out_domain[dim_i] = - IterDomainBuilder( - IrBuilder::create(start_offsets[dim_i]), extent_vals[dim_i]) - .stop_offset(IrBuilder::create(stop_offsets[dim_i])) - .iter_type(iter_types[dim_i].value()) - .build(); + out_domain[dim_i] = IrBuilder::create( + IrBuilder::create(start_offsets[dim_i]), + extent_vals[dim_i], + IrBuilder::create(stop_offsets[dim_i]), + ParallelType::Serial, + iter_types[dim_i]); } else { - out_domain[dim_i] = IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - FusionGuard::getCurFusion()->oneVal()) - .expanded_extent(expanded_extent_vals[dim_i]) - .iter_type(IterType::Broadcast) - .build(); + IterType itype = IterType::BroadcastWithoutStride; + for (const auto tv : tvs) { + auto dim = + TensorDomain::noReductions(tv->getMaybeRFactorDomain())[dim_i]; + // If there's an unresolved bcast dim and it came from a strided dim, + // assume output of it should be strided too + if (dim->getIterType() == IterType::BroadcastWithStride) { + itype = IterType::BroadcastWithStride; + break; + } + } + out_domain[dim_i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), + ParallelType::Serial, + itype); } } @@ -891,11 +839,12 @@ static TensorView* newForReduction( " of tensor ", tv); - new_domain.push_back( - IterDomainBuilder(id) - .resetSchedulingParams() - .iter_type(isReduction ? IterType::Reduction : id->getIterType()) - .build()); + new_domain.push_back(IrBuilder::create( + id->start(), + id->extent(), + id->stopOffset(), + ParallelType::Serial, + isReduction ? 
IterType::Reduction : id->getIterType())); } TensorDomain* td = IrBuilder::create( @@ -1056,14 +1005,18 @@ TensorView* broadcast( size_t iinp = 0, ibdim = 0; while (ibdim < is_broadcast_dim.size()) { if (is_broadcast_dim[ibdim]) { - out_domain.push_back(IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - FusionGuard::getCurFusion()->oneVal()) - .iter_type(IterType::Broadcast) - .build()); + out_domain.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), + ParallelType::Serial, + IterType::BroadcastWithoutStride)); } else { - out_domain.push_back( - IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build()); + out_domain.push_back(IrBuilder::create( + inp_domain[iinp]->start(), + inp_domain[iinp]->extent(), + inp_domain[iinp]->stopOffset(), + inp_domain[iinp]->getParallelType(), + inp_domain[iinp]->getIterType())); iinp++; } ibdim++; @@ -1077,134 +1030,6 @@ TensorView* broadcast( return out_tensor; } -TensorView* expand(TensorView* inp, const std::vector& expanded_sizes) { - auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); - - TORCH_CHECK( - expanded_sizes.size() == inp_domain.size(), - "Invalid expand, number of sizes provided is expected to be ", - inp_domain.size(), - " but received ", - expanded_sizes.size()); - - std::vector maybe_expanded_sizes; - maybe_expanded_sizes.resize(inp_domain.size(), nullptr); - - // Did a dimension actually get expanded - bool expanded = false; - - std::vector out_domain; - for (auto i : c10::irange(inp_domain.size())) { - auto inp_id = inp_domain[i]; - auto out_id_builder = IterDomainBuilder(inp_id); - maybe_expanded_sizes[i] = inp_domain[i]->extent(); - - auto expanded_size_int = expanded_sizes[i]->getInt(); - - // If the expanded size is -1, let the input extent be propagated - // as is - if (expanded_size_int == -1) { - // This is just done for clarity. It isn't necessary as it's - // already done when constructing out_id_builder. - out_id_builder.extent(inp_id->extent()); - } else if (inp_id->isBroadcast()) { - // When input id is a broadcast, expand the extent to the given - // size, which can be concrete or symbolic. - expanded = true; - out_id_builder.expanded_extent(expanded_sizes[i]); - maybe_expanded_sizes[i] = expanded_sizes[i]; - } else if (!inp_id->extent()->isConstInt()) { - // Input id is non-broadcast and its extent is symbolic. Promote - // the extent to the given expanded size. - // Note that expansion to 1 just means its extent becomes 1 and - // does not mean the ID becomes a broadcast. - out_id_builder.extent(expanded_sizes[i]); - } else { - // Input id is non-broadcast and its extent is concrete. Nothing - // to expand, but the input and expanded sizes should match if - // the expanded size is also concrete. 
- auto inp_id_size_int = inp_id->extent()->getInt(); - if (expanded_size_int.has_value()) { - TORCH_CHECK( - inp_id_size_int == expanded_size_int, - "Invalid expand size, ", - expanded_sizes[i]->toString(), - ", for ", - inp_id->toString()); - } - } - out_domain.push_back(out_id_builder.build()); - } - - TensorView* out_tensor = IrBuilder::create( - IrBuilder::create( - out_domain, std::vector(out_domain.size(), true)), - inp->getDataType().value()); - if (!expanded) { - IrBuilder::create(UnaryOpType::Set, out_tensor, inp); - } else { - IrBuilder::create(out_tensor, inp, maybe_expanded_sizes); - } - return out_tensor; -} - -TensorView* expand_as(TensorView* inp, TensorView* other) { - auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); - auto other_domain = - TensorDomain::noReductions(other->getMaybeRFactorDomain()); - - TORCH_CHECK( - inp_domain.size() == other_domain.size(), - "Invalid expand_as, dimensions of inp don't match dimensions of other, expected other to be ", - inp_domain.size(), - " but received ", - other_domain.size()); - - std::vector out_domain; - std::vector maybe_expanded_sizes; - bool expanded = false; - for (auto i : c10::irange(inp_domain.size())) { - auto inp_id = inp_domain[i]; - auto other_id = other_domain[i]; - - auto out_id_builder = IterDomainBuilder(inp_id); - Val* maybe_expanded_size = inp_id->extent(); - - if (!inp_id->isBroadcast()) { - TORCH_INTERNAL_ASSERT( - !other_id->isBroadcast(), - "Cannot expand as a tensor if other has broadcast dimensions that don't map to broadcast dimensions in the input."); - if (!inp_id->isConstInt() && other_id->isConstInt()) { - out_id_builder.extent( - promoteSize(inp_id->extent(), other_id->extent())); - } - } else { - if (!other_id->isBroadcast()) { - expanded = true; - out_id_builder.expanded_extent(other_id->extent()); - maybe_expanded_size = other_id->extent(); - } else if (other_id->isBroadcast() && other_id->hasExpandedExtent()) { - expanded = true; - out_id_builder.expanded_extent(other_id->expandedExtent()); - maybe_expanded_size = other_id->expandedExtent(); - } - } - out_domain.push_back(out_id_builder.build()); - maybe_expanded_sizes.push_back(maybe_expanded_size); - } - - TensorView* out_tensor = IrBuilder::create( - IrBuilder::create( - out_domain, std::vector(out_domain.size(), true)), - inp->getDataType().value()); - if (!expanded) { - IrBuilder::create(UnaryOpType::Set, out_tensor, inp); - } else { - IrBuilder::create(out_tensor, inp, maybe_expanded_sizes); - } - return out_tensor; -} - WelfordResult Welford( TensorView* tv, const std::vector& axes, @@ -1299,6 +1124,27 @@ WelfordResult WelfordResult::rFactor(const std::vector& axes) { return WelfordResult{rf_tvs.at(0), rf_tvs.at(1), rf_tvs.at(2)}; } +TensorView* transpose( + TensorView* inp, + const std::unordered_map& old2new) { + auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); + std::vector out_domain(inp_domain.size()); + + auto new2old = ir_utils::normalizeOld2New(old2new, inp_domain.size()); + + for (const auto i : c10::irange(out_domain.size())) { + auto in_id = inp_domain[new2old[i]]; + out_domain[i] = in_id->cloneWithoutRFactor(); + } + + TensorView* out_tensor = IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), + inp->getDataType().value()); + IrBuilder::create(out_tensor, inp, new2old); + return out_tensor; +} + // COMPOUND OPERATIONS // add_alpha @@ -1736,12 +1582,12 @@ TensorView* shift( "."); } - out_dom.push_back( - IterDomainBuilder( - 
IrBuilder::create(out_start_offset), inp_axis->extent()) - .stop_offset(IrBuilder::create(out_stop_offset)) - .iter_type(inp_axis->getIterType()) - .build()); + out_dom.push_back(IrBuilder::create( + IrBuilder::create(out_start_offset), + inp_axis->extent(), + IrBuilder::create(out_stop_offset), + ParallelType::Serial, + inp_axis->getIterType())); } out = IrBuilder::create( @@ -1858,11 +1704,11 @@ TensorView* gather( const auto pad_right = pad_width[i][1]; // This may be over-conservative TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt()); + const auto inp_stop_offset = inp_axis->stopOffset()->getInt(); TORCH_INTERNAL_ASSERT( - inp_axis->stopOffset()->isConstInt(), + inp_stop_offset.has_value(), "Dynamic stop offset not supported: ", inp_axis); - const auto inp_stop_offset = inp_axis->stopOffset()->evaluateInt(); const auto extent_adjustment = window_dim - 1 - pad_left - pad_right; TORCH_CHECK( extent_adjustment >= 0, @@ -1873,19 +1719,19 @@ TensorView* gather( pad_left, ". Padding right: ", pad_right); - const auto out_stop_offset = inp_stop_offset + extent_adjustment; - out_root_domains.push_back( - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), inp_axis->extent()) - .stop_offset(IrBuilder::create(out_stop_offset)) - .iter_type(inp_axis->getIterType()) - .build()); + const auto out_stop_offset = inp_stop_offset.value() + extent_adjustment; + out_root_domains.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + inp_axis->extent(), + IrBuilder::create(out_stop_offset), + ParallelType::Serial, + inp_axis->getIterType())); // create a new axis for the gathered domain - out_gather_dom.push_back(IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - IrBuilder::create(window_dim)) - .iter_type(IterType::Gather) - .build()); + out_gather_dom.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create(window_dim), + ParallelType::Serial, + IterType::Gather)); } out_root_domains.insert( @@ -1926,11 +1772,13 @@ TORCH_CUDA_CU_API TensorView* viewAsScalar(TensorView* inp) { out_domain.push_back(d->cloneWithoutRFactor()); } - IterDomain* id = IterDomainBuilder( - inp_domain[0]->container()->zeroVal(), - IrBuilder::create(vec_size)) - .iter_type(IterType::VectorComponent) - .build(); + IterDomain* id = IrBuilder::create( + inp_domain[0]->container(), + inp_domain[0]->container()->zeroVal(), + IrBuilder::create(vec_size), + ParallelType::Serial, + IterType::VectorComponent, + false); out_domain.push_back(id); auto out = IrBuilder::create( @@ -1997,11 +1845,12 @@ static TensorView* newForMma( "and", tv_b); - new_domain.push_back( - IterDomainBuilder(id->start(), id->extent()) - .stop_offset(id->stopOffset()) - .iter_type(isReduction ? IterType::Reduction : id->getIterType()) - .build()); + new_domain.push_back(IrBuilder::create( + id->start(), + id->extent(), + id->stopOffset(), + ParallelType::Serial, + isReduction ? IterType::Reduction : id->getIterType())); } TensorDomain* td = IrBuilder::create( diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 3b20777f5e04..53efba8f7301 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -248,7 +248,7 @@ TORCH_CUDA_CU_API TensorView* isposinf(TensorView*); TORCH_CUDA_CU_API Val* isreal(Val*); TORCH_CUDA_CU_API TensorView* isreal(TensorView*); -// Broadcasts inp based on bool vector. Size of broadcast bool vector should be +// Broadcasts v1 based on bool vector. 
Size of broadcast bool vector should be // the number of dims desired in the broadcasted tensor. This vector should be // true if output dim should be a broadcasted dim, and false if it is not a // broadcasted dim. Number of false entires must match the number of input dims. @@ -256,21 +256,17 @@ TORCH_CUDA_CU_API TensorView* broadcast( TensorView* inp, const std::vector& is_broadcast_dim); -// Expands input based on provided sizes. expand_sizes should be the same size -// as the input's root domain (really rfactor), and should be -1 for any -// dimension that should remain a symbolic size. For dimensions that remain -// broadcast after the expand should be set to 1, any dimension being expanded -// must be marked as a braodcast in the input and will be expanded to the -// provided constant size. Any dimension that's symbolic in the input but -// specified as a non -1 value will be set to that constant value. -TORCH_CUDA_CU_API TensorView* expand( +//! Transpose a tensor as specified by axis mappings. +//! +//! The transposition mapping is specified with a list of pairs from +//! old to new positions. Positions are relative to the noReduction +//! domain. +//! +//! \param inp Tensor to transpose +//! \param old2new Pairs of mapping from old to new positions. +TORCH_CUDA_CU_API TensorView* transpose( TensorView* inp, - const std::vector& expanded_sizes); - -// Expands input based on other. For dimensions in inp that are broadcast with a -// matching entry in other that's either a broadcast with expanded extent or a -// non broadcasted iter domain, inp will be expanded to other's size. -TORCH_CUDA_CU_API TensorView* expand_as(TensorView* inp, TensorView* other); + const std::unordered_map& old2new); // BINARY OPERATIONS // add diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index f0102f992abb..9aee0cc4c526 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -207,7 +207,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { const auto nDims = std::count_if( maybe_rfactor_domain.begin(), maybe_rfactor_domain.end(), - [](const IterDomain* id) { return !id->isReduction(); }); + [](const IterDomain* id) { + return !id->isReduction() && + id->getIterType() != IterType::BroadcastWithoutStride; + }); code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> " << varName(tv); } @@ -1587,9 +1590,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { func_args.arg(read_pred); } - func_args.arg(genInline(grouped_grop->entrance_index())); - func_args.arg(genInline(grouped_grop->entrances())); - addProfileArguments(func_args, grouped_grop); indent() << "reduction::gridReduceGroup<" << template_args << ">(\n"; diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index eb18da1f909c..4557e8efb8de 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -129,9 +129,6 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; - case ExprType::ExpandOp: - ptr(handler)->handle(expr->as()); - return; case ExprType::ShiftOp: ptr(handler)->handle(expr->as()); return; @@ -288,9 +285,6 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; - case ExprType::ExpandOp: - ptr(handler)->handle(expr->as()); - return; case ExprType::ShiftOp: ptr(handler)->handle(expr->as()); return; @@ 
-455,9 +449,6 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::TransposeOp: ptr(mutator)->mutate(expr->as()); return; - case ExprType::ExpandOp: - ptr(mutator)->mutate(expr->as()); - return; case ExprType::ShiftOp: ptr(mutator)->mutate(expr->as()); return; @@ -687,9 +678,6 @@ void OptOutConstDispatch::handle(const Merge* stmt) { void OptOutConstDispatch::handle(const TransposeOp* stmt) { unhandled(stmt); } -void OptOutConstDispatch::handle(const ExpandOp* stmt) { - unhandled(stmt); -} void OptOutConstDispatch::handle(const ShiftOp* stmt) { unhandled(stmt); } @@ -816,9 +804,6 @@ void OptOutDispatch::handle(Merge* stmt) { void OptOutDispatch::handle(TransposeOp* stmt) { unhandled(stmt); } -void OptOutDispatch::handle(ExpandOp* stmt) { - unhandled(stmt); -} void OptOutDispatch::handle(ShiftOp* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index d83f1b28bd10..a27cbed211f9 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -78,7 +78,6 @@ class LoadStoreOp; class MmaOp; class BroadcastOp; class TransposeOp; -class ExpandOp; class ShiftOp; class GatherOp; class ViewAsScalar; @@ -146,7 +145,6 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const Split* stmt); virtual void handle(const Merge* stmt); virtual void handle(const TransposeOp* stmt); - virtual void handle(const ExpandOp* stmt); virtual void handle(const ShiftOp* stmt); virtual void handle(const GatherOp* stmt); virtual void handle(const ViewAsScalar* stmt); @@ -204,7 +202,6 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(Split* stmt); virtual void handle(Merge* stmt); virtual void handle(TransposeOp* stmt); - virtual void handle(ExpandOp* stmt); virtual void handle(ShiftOp* stmt); virtual void handle(GatherOp* stmt); virtual void handle(ViewAsScalar* stmt); @@ -303,7 +300,6 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(Split*); virtual void mutate(Merge*); virtual void mutate(TransposeOp*); - virtual void mutate(ExpandOp*); virtual void mutate(ShiftOp*); virtual void mutate(GatherOp*); virtual void mutate(ViewAsScalar*); diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/README.md b/torch/csrc/jit/codegen/cuda/examples/sinh_extension/README.md deleted file mode 100644 index 14752a41a81b..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Build - -``` -python setup.py install -``` - -# Test - -``` -python test.py -``` diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/main.cpp b/torch/csrc/jit/codegen/cuda/examples/sinh_extension/main.cpp deleted file mode 100644 index 85f393581814..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/main.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include -#include -#include -#include - -#include - -using namespace torch::jit::fuser::cuda; - -at::Tensor sinh_nvfuser(const at::Tensor& input) { - Fusion fusion; - FusionGuard fg(&fusion); - - int dim = input.dim(); - auto dtype = input.scalar_type(); - auto x = - TensorViewBuilder().ndims(dim).dtype(aten_to_data_type(dtype)).build(); - fusion.addInput(x); - - // Using equation sinh(x) = [ exp(x) - exp(-1) ] / 2 - auto output = div(sub(exp(x), exp(neg(x))), IrBuilder::create(2.0)); - fusion.addOutput(output); - - std::cout << "Create fusion:" << std::endl; - fusion.print(); - - auto 
lparams = schedulePointwise(&fusion, {input}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}, lparams); - auto outputs = fe.runFusion({input}, lparams); - - return outputs[0]; -} - -TORCH_LIBRARY(myop, m) { - m.def("sinh_nvfuser", sinh_nvfuser); -} - -TORCH_LIBRARY_IMPL(myop, CUDA, m) { - m.impl("sinh_nvfuser", sinh_nvfuser); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/setup.py b/torch/csrc/jit/codegen/cuda/examples/sinh_extension/setup.py deleted file mode 100644 index b7369aab8e9c..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/setup.py +++ /dev/null @@ -1,14 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name='nvfuser_extension', - ext_modules=[ - CUDAExtension( - name='nvfuser_extension', - pkg='nvfuser_extension', - sources=['main.cpp']) - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/test.py b/torch/csrc/jit/codegen/cuda/examples/sinh_extension/test.py deleted file mode 100644 index 125db59b89ce..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_extension/test.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -import nvfuser_extension # noqa: F401 - -t = torch.randn((5, 5), device='cuda') -expected = torch.sinh(t) -output = torch.ops.myop.sinh_nvfuser(t) - -print("Expected:", expected) -print("Output:", output) - -assert torch.allclose(output, expected) -print("They match!") diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/CMakeLists.txt b/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/CMakeLists.txt deleted file mode 100644 index 5ccb49816b3b..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -cmake_minimum_required(VERSION 3.10 FATAL_ERROR) -project(sinh_example LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 14) - -find_package(Torch REQUIRED) - -add_executable(sinh_example main.cpp) -target_link_libraries(sinh_example ${TORCH_LIBRARIES}) diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/README.md b/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/README.md deleted file mode 100644 index 4345e88b3196..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Build - -``` -mkdir build -cd build -cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. 
-make -j -``` - -# Test - -``` -./sinh_example -``` diff --git a/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/main.cpp b/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/main.cpp deleted file mode 100644 index 3a28ce3a3f84..000000000000 --- a/torch/csrc/jit/codegen/cuda/examples/sinh_libtorch/main.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include -#include -#include -#include - -using namespace torch::jit::fuser::cuda; - -at::Tensor sinh_nvfuser(const at::Tensor& input) { - Fusion fusion; - FusionGuard fg(&fusion); - - int dim = input.dim(); - auto dtype = input.scalar_type(); - auto x = - TensorViewBuilder().ndims(dim).dtype(aten_to_data_type(dtype)).build(); - fusion.addInput(x); - - // Using equation sinh(x) = [ exp(x) - exp(-1) ] / 2 - auto output = div(sub(exp(x), exp(neg(x))), IrBuilder::create(2.0)); - fusion.addOutput(output); - - std::cout << "Create fusion:" << std::endl; - fusion.print(); - - auto lparams = schedulePointwise(&fusion, {input}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}, lparams); - auto outputs = fe.runFusion({input}, lparams); - - return outputs[0]; -} - -int main() { - auto t = at::randn({5, 5}, at::kCUDA); - auto expected = at::sinh(t); - auto output = sinh_nvfuser(t); - std::cout << "Expected:" << std::endl << expected << std::endl; - std::cout << "Output:" << std::endl << output << std::endl; - TORCH_CHECK(at::allclose(expected, output)); - std::cout << "They match!" << std::endl; -} diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 2f589213c051..3fa017bb2c53 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -274,20 +274,13 @@ at::Tensor inferAndAlloc( const TensorView* tv, const std::vector& sizes, kir::ExpressionEvaluator& expr_eval, - // Map from dim -> expanded size of TV if any expanded broadcast dimensions - // exist - std::unordered_map expanded_map, const CompileOptions& options, bool zero_init = false) { FUSER_PERF_SCOPE("inferAndAlloc"); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - // Going to infer all the sizes of the TensorView std::vector inferred_sizes; - // Expanded sizes is at maximum the same size of inferred_sizes, as you could - // have a fully broadcasted tensor that's being expanded - std::vector expanded_sizes; - bool expanded_dim = false; + for (const auto size : sizes) { const auto inferred_val = expr_eval.evaluate(size); TORCH_INTERNAL_ASSERT( @@ -299,29 +292,6 @@ at::Tensor inferAndAlloc( ") for the buffer ", tv->toString()); inferred_sizes.push_back(inferred_val.value()); - if (expanded_map.count(expanded_sizes.size())) { - auto expanded_size = expanded_map.at(expanded_sizes.size()); - const auto inferred_expanded_size = expr_eval.evaluate(expanded_size); - TORCH_INTERNAL_ASSERT( - inferred_expanded_size.has_value(), - "Could not launch kernel as program could not infer the expanded extent ", - expanded_size->toString(), - "(", - expanded_size->name(), - ") for the buffer ", - tv->toString()); - if (inferred_val.value() != 1) { - TORCH_INTERNAL_ASSERT( - inferred_val.value() == inferred_expanded_size.value(), - "Attempted an expand on a non-broadcasted dimension,", - " but the expand doesn't match the dimensions size."); - } else { - expanded_dim = true; - } - expanded_sizes.push_back(inferred_expanded_size.value()); - } else { - expanded_sizes.push_back(inferred_val.value()); - } } const auto at_type = data_type_to_aten(tv->dtype()); @@ -330,21 +300,13 @@ at::Tensor inferAndAlloc( const auto 
tensor_options = at::TensorOptions().dtype(at_type).device(options.device); c10::IntArrayRef isizes(inferred_sizes); - auto zeros = at::zeros(isizes, tensor_options); - if (expanded_dim) { - return zeros.expand(expanded_sizes); - } - return zeros; + return at::zeros(isizes, tensor_options); } else { c10::IntArrayRef isizes(inferred_sizes); // Non Variable type guard for empty_cuda call at::AutoDispatchBelowADInplaceOrView non_variable_type_mode; - auto empty = at::native::empty_cuda( + return at::native::empty_cuda( isizes, at_type, c10::nullopt, options.device, c10::nullopt); - if (expanded_dim) { - return empty.expand(expanded_sizes); - } - return empty; } } @@ -359,18 +321,16 @@ at::Tensor inferAndAllocOutput( : domain->getRootDomain(); std::vector sizes; - std::unordered_map expand_map; for (const auto id : maybe_rfactor_domain) { - if (id->isReduction() || id->isStride()) { + if (id->isReduction() || id->isStride() || + id->getIterType() == IterType::BroadcastWithoutStride) { continue; } sizes.push_back(id->extent()); - if (id->isBroadcast() && id->hasExpandedExtent()) { - expand_map[sizes.size() - 1] = id->expandedExtent(); - } } - return inferAndAlloc(tv, sizes, expr_eval, expand_map, options, zero_init); + + return inferAndAlloc(tv, sizes, expr_eval, options, zero_init); } } // namespace @@ -636,11 +596,11 @@ FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals( } if (alloc->zeroInit()) { global_buffers.buffers.push_back( - inferAndAlloc(tv, alloc->shape(), expr_eval, {}, options_, true)); + inferAndAlloc(tv, alloc->shape(), expr_eval, options_, true)); global_buffers.zero_init.push_back(true); } else { global_buffers.buffers.push_back( - inferAndAlloc(tv, alloc->shape(), expr_eval, {}, options_, false)); + inferAndAlloc(tv, alloc->shape(), expr_eval, options_, false)); global_buffers.zero_init.push_back(false); } // Remember the tensor buffer used for storing kernel profile @@ -662,7 +622,7 @@ std::vector FusionExecutor::allocOutputs( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector outputs; for (const auto out_i : c10::irange(kernel->outputs().size())) { - // If the output is just trivially the input, just "copy" it over. + // Dummy output. if (kernel->outputs()[out_i]->isFusionInput()) { for (auto inp_i : c10::irange(kernel->inputs().size())) { if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) { @@ -681,14 +641,13 @@ std::vector FusionExecutor::allocOutputs( kernel->outputs()[out_i]->isA(), "Cannot allocate outputs that are not tensors."); auto output = kernel->outputs()[out_i]->as(); - if (alias_indices.count(out_i) != 0) { - // aliasing to inputs, no need to allocate real output + if (alias_indices.count(out_i) == 0) { outputs.push_back( - inferAndAlloc(output, {}, expr_eval, {}, options_, false)); + inferAndAllocOutput(output, expr_eval, options_, false)); } else { - // Allocate a real output + // aliasing to inputs, no need to allocate real output outputs.push_back( - inferAndAllocOutput(output, expr_eval, options_, false)); + inferAndAlloc(output, {}, expr_eval, options_, false)); } } } @@ -904,7 +863,7 @@ std::vector FusionExecutor::runFusion( allocated_outputs[entry.first] = inputs[entry.second].toTensor(); } } else { - // TODO: Update for aliasing, validate the outputs are the right sizes. 
+ // TODO: Update this as well; executor_utils::validateKernelOutputs( fusion_, allocated_outputs, options_.device); } diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 38f5ba31a63e..b48c7dfc3773 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -582,13 +582,11 @@ void validateAlignedVectorizedFusionInputOutput( bool still_rightmost = true; for (auto i = aten_tensor.ndimension() - 1; i >= 0; --i) { const auto stride = aten_tensor.strides().at(i); - const auto size = aten_tensor.sizes().at(i); - // If this domain is contiguous or size == 1, then not necessary to check - // the stride. Otherwise, stride must be 1 if it's rightmost or - // divisible by word_size + // If this domain is contiguous, then not necessary to check the + // stride. Otherwise, stride must be 1 if it's rightmost or + // divisible by word_size. TORCH_INTERNAL_ASSERT( - stride == cur_contig_stride || size == 1 || - (still_rightmost && stride == 1) || + stride == cur_contig_stride || (still_rightmost && stride == 1) || (!still_rightmost && stride % word_size == 0), "Vectorization of ", tv->toString(), @@ -601,12 +599,9 @@ void validateAlignedVectorizedFusionInputOutput( stride) // If the domain is size-1, the next domain is still considered // rightmost. + const auto size = aten_tensor.sizes().at(i); still_rightmost = still_rightmost && size == 1; - // We do not update cur_contig_stride for size==1 dimensions, - // since we have specialized vectorization stride check for them - if (size != 1) { - cur_contig_stride = stride * size; - } + cur_contig_stride = stride * size; } } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index a0c67ad4ef27..a000dca87a15 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1313,7 +1313,8 @@ std::vector Index::getGlobalProducerStridedIndices( { int stride_i = 0; for (const auto i : c10::irange(root_dom.size())) { - if (root_dom[i]->isReduction()) { + if (root_dom[i]->isReduction() || + root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) { strides[i] = GpuLower::current()->kernel()->oneVal(); continue; } @@ -1332,12 +1333,15 @@ std::vector Index::getGlobalProducerStridedIndices( if (root_dom[dim]->isReduction()) { continue; } + if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) { + continue; + } Val* root_ind = nullptr; if (producer_indexing.indexMap().find(root_dom[dim]) != producer_indexing.indexMap().end()) { root_ind = producer_indexing.indexMap().at(root_dom[dim]); - } else if (root_dom[dim]->isBroadcast()) { + } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) { root_ind = GpuLower::current()->kernel()->zeroVal(); } @@ -1381,7 +1385,9 @@ std::vector Index::getGlobalProducerStridedIndices( for (const auto i : c10::irange(root_dom.size())) { // If the domain is derived from a trivial reduction, no indexing // to create. 
- if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || + if (root_dom[i]->isReduction() || + root_dom[i]->getIterType() == IterType::BroadcastWithoutStride || + root_dom[i]->getIterType() == IterType::BroadcastWithStride || gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) { continue; } @@ -1859,7 +1865,9 @@ std::vector Index::getGlobalConsumerStridedIndices( { int stride_i = 0; for (const auto i : c10::irange(root_dom.size())) { - if (root_dom[i]->isReduction() || root_dom[i]->isStride()) { + if (root_dom[i]->isReduction() || + root_dom[i]->getIterType() == IterType::BroadcastWithoutStride || + root_dom[i]->isStride()) { strides[i] = GpuLower::current()->kernel()->oneVal(); continue; } @@ -1878,12 +1886,15 @@ std::vector Index::getGlobalConsumerStridedIndices( if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) { continue; } + if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) { + continue; + } Val* root_ind = nullptr; if (consumer_indexing.indexMap().find(root_dom[dim]) != consumer_indexing.indexMap().end()) { root_ind = consumer_indexing.indexMap().at(root_dom[dim]); - } else if (root_dom[dim]->isBroadcast()) { + } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) { root_ind = GpuLower::current()->kernel()->zeroVal(); } @@ -1925,7 +1936,9 @@ std::vector Index::getGlobalConsumerStridedIndices( root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { // See a comment in indexing to root domains in getGlobalProducerIndex. - if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || + if (root_dom[i]->isReduction() || + root_dom[i]->getIterType() == IterType::BroadcastWithoutStride || + root_dom[i]->getIterType() == IterType::BroadcastWithStride || gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) || root_dom[i]->isStride()) { continue; diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp index 63b6f793d8b7..bdb334ab044a 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp @@ -41,9 +41,8 @@ IterDomain* IndexReferenceReplay::idCopy(IterDomain* id) { // reduction. All we care about are the transformations, and trying to make // sure we track correctly a replaying with consistent reduction/broadcast // domains is challenging and unnecessary. - auto copied_id = IterDomainBuilder(id->start(), id->extent()) - .parallel_type(id->getParallelType()) - .build(); + auto copied_id = SimplifyingIrBuilder::create( + id->container(), id->start(), id->extent(), id->getParallelType()); replayed_ids_.emplace_back(copied_id); return copied_id; } diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index dcf0b054ca35..0d67f780886b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -130,11 +129,6 @@ class ConstCheck : private OptOutConstDispatch { private: bool is_const_ = true; - // Returns true if all Val's in the hisotry of provided Val is an Int. Since - // our expression evaluator doesn't support any type besides int, it's - // important to check it is one. 
- bool is_int_ = true; - void handle(const Bool* b) final { is_const_ = is_const_ && b->isConst(); } @@ -158,10 +152,6 @@ class ConstCheck : private OptOutConstDispatch { } void handle(const Val* val) final { - if (!val->isAnInt()) { - is_int_ = false; - } - if (val->definition() != nullptr) { handle(val->definition()); } else { @@ -175,12 +165,6 @@ class ConstCheck : private OptOutConstDispatch { cc.handle(val); return cc.is_const_; } - - static bool isConstInt(const Val* val) { - ConstCheck cc; - cc.handle(val); - return cc.is_const_ && cc.is_int_; - } }; } // namespace @@ -192,30 +176,6 @@ bool Val::isConstScalar() const { return ConstCheck::isConst(this); } -bool Val::isConstInt() const { - if (!isAnInt()) { - return false; - } - return ConstCheck::isConstInt(this); -} - -int64_t Val::evaluateInt() { - TORCH_INTERNAL_ASSERT( - ConstCheck::isConstInt(this), - "Cannot get Int of not const integers through IR nodes, must use runtime ExpressionEvaluator."); - - if (this->as()->value().has_value()) { - return this->as()->value().value(); - } - - ExpressionEvaluator ee(fusion()); - auto evaluated_val = ee.evaluate(this); - TORCH_INTERNAL_ASSERT( - evaluated_val.has_value(), - "Detected a const integer but failed to infer its value."); - return evaluated_val.value(); -} - c10::optional Val::getInt() const { if (isConstScalar() && isAnInt()) { if (this->getValType() == ValType::Scalar) { diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 2f57141432bb..70f0b8f80fe5 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -243,24 +243,12 @@ class TORCH_CUDA_CU_API Val : public Statement { // Returns if all dependencies are constant scalars bool isConstScalar() const; - // Returns if all dependencies are constant integers - bool isConstInt() const; - bool isAnInt() const { return isScalar() && dtype_ == DataType::Int; } - // If this Val is an integer with a direct constant value associated with it, - // will return the value of that constant integer. If this integer has - // defining expressions it will return a c10::nullopt. Those values should be - // infered using evaluateInt. c10::optional getInt() const; - // If this Val is a constant integer, and its history is comprised only of - // constant integers, will return the value of that constant integer. Cannot - // make constant as expression evaluator takes non-constant Vals. - int64_t evaluateInt(); - // Returns if no dependencies and is a constant scalar. 
virtual bool isConst() const { return false; diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index 29843676b2b6..be9f4c63b97d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -54,7 +54,6 @@ IR_BUILDER_INSTANTIATE(NamedScalar) IR_BUILDER_INSTANTIATE(Split) IR_BUILDER_INSTANTIATE(Merge) IR_BUILDER_INSTANTIATE(TransposeOp) -IR_BUILDER_INSTANTIATE(ExpandOp) IR_BUILDER_INSTANTIATE(ShiftOp) IR_BUILDER_INSTANTIATE(GatherOp) IR_BUILDER_INSTANTIATE(ViewAsScalar) diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 1da8b1e23a5a..01c65b9dea13 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -128,10 +128,6 @@ void IrCloner::handle(const TransposeOp* op) { clone_ = IrBuilder::clone(op, this); } -void IrCloner::handle(const ExpandOp* op) { - clone_ = IrBuilder::clone(op, this); -} - void IrCloner::handle(const ShiftOp* op) { clone_ = IrBuilder::clone(op, this); } diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 349fd6842223..05c156f31b07 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -78,7 +78,6 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { void handle(const LoadStoreOp*) override; void handle(const MmaOp*) override; void handle(const TransposeOp*) override; - void handle(const ExpandOp*) override; void handle(const ShiftOp*) override; void handle(const GatherOp*) override; void handle(const ViewAsScalar*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 2da0c8de6d4a..b83b2b794605 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -393,7 +393,7 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { IrBuilderPasskey, TensorView* out, TensorView* in, - std::vector new2old); + std::vector new2old); TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); @@ -405,44 +405,14 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { return in_; } - const std::vector& new2old() const { + const std::vector& new2old() const { return new2old_; } - std::vector old2new() const; - private: TensorView* const out_ = nullptr; TensorView* const in_ = nullptr; - const std::vector new2old_; -}; - -class TORCH_CUDA_CU_API ExpandOp : public Expr { - public: - ExpandOp( - IrBuilderPasskey, - TensorView* out, - TensorView* in, - std::vector _expanded_extents); - - ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); - - TensorView* out() const { - return out_; - } - - TensorView* in() const { - return in_; - } - - const std::vector& expanded_extents() const { - return expanded_extents_; - } - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; - std::vector expanded_extents_; + const std::vector new2old_; }; class TORCH_CUDA_CU_API TernaryOp : public Expr { @@ -671,54 +641,6 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr { Val* const in_ = nullptr; }; -// Convenience utility to initialize IterDomain's without having to sort through -// all the default values. 
Intended to be used with -// IterDomain::IterDomain(IrBuilderPasskey IterDomainBuildArgs) -class TORCH_CUDA_CU_API IterDomainBuilder { - public: - // Match legacy constructor - IterDomainBuilder(Val* _start, Val* _extent); - - // Grab all the parameters from id to set the IterDomainBuilder - IterDomainBuilder(const IterDomain* id); - - // Resets defaults for rfactor, is padded dim, padded to size, and is mma - // swizzle which should only be set during scheduling. - IterDomainBuilder& resetSchedulingParams(); - - // Resets is_rfactor_domain - IterDomainBuilder& resetRfactor(); - - IterDomainBuilder& start(Val* _start); - IterDomainBuilder& extent(Val* _extent); - IterDomainBuilder& expanded_extent(Val* _expanded_extent); - IterDomainBuilder& stop_offset(Val* _stop_offset); - IterDomainBuilder& parallel_type(ParallelType _parallel_type); - IterDomainBuilder& iter_type(IterType _iter_type); - IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); - IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); - IterDomainBuilder& padded_to_size(c10::optional _padded_to_size); - IterDomainBuilder& is_mma_swizzled(bool _is_mma_swizzled); - - IterDomain* build() const; - - // Must have start and extent at least - IterDomainBuilder() = delete; - - Val* start_ = nullptr; - Val* extent_ = nullptr; - Val* expanded_extent_ = nullptr; - Val* stop_offset_ = nullptr; - ParallelType parallel_type_ = ParallelType::Serial; - IterType iter_type_ = IterType::Iteration; - - // Only relevant at scheduling time or compile time. - bool is_rfactor_domain_ = false; - bool is_padded_dimension_ = false; - c10::optional padded_to_size_ = c10::nullopt; - bool is_mma_swizzled_ = false; -}; - // Friends for direct access to split class TensorDomain; class ReplayTransformations; @@ -729,22 +651,29 @@ class IndexReferenceReplay; //! on IterDomains. class TORCH_CUDA_CU_API IterDomain : public Val { public: - IterDomain(IrBuilderPasskey, const IterDomainBuilder& args); - - // Legacy constructor, TODO: should start moving to use IterDomainBuildArgs - // constructor Same as the above but can set the offset of the stop point IterDomain( IrBuilderPasskey, Val* start, Val* extent, - Val* expanded_extent, + ParallelType parallel_type = ParallelType::Serial, + IterType iter_type = IterType::Iteration, + bool is_rfactor_domain = false, + bool is_padded_dimension = false, + c10::optional padded_to_size_ = c10::nullopt, + bool is_mma_swizzled = false); + + // Same as the above but can set the offset of the stop point + IterDomain( + IrBuilderPasskey, + Val* start, + Val* extent, Val* stop_offset, - ParallelType parallel_type, - IterType iter_type, - bool is_rfactor_domain, - bool is_padded_dimension, - c10::optional padded_to_size_, - bool is_mma_swizzled); + ParallelType parallel_type = ParallelType::Serial, + IterType iter_type = IterType::Iteration, + bool is_rfactor_domain = false, + bool is_padded_dimension = false, + c10::optional padded_to_size_ = c10::nullopt, + bool is_mma_swizzled = false); IterDomain(const IterDomain* src, IrCloner* ir_cloner); @@ -791,7 +720,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val { } bool isBroadcast() const { - return getIterType() == IterType::Broadcast; + return getIterType() == IterType::BroadcastWithStride || + getIterType() == IterType::BroadcastWithoutStride; } bool isGather() const { @@ -825,6 +755,25 @@ class TORCH_CUDA_CU_API IterDomain : public Val { return (isBlockDim() || isThreadDim()); } + //! 
Convert to strided broadcast, used for supporting broadcast on output + void toStridedBroadcast() { + TORCH_INTERNAL_ASSERT( + isBroadcast(), + "toStridedBroadCast: converting an non-broadcast iterdomain", + this); + iter_type_ = IterType::BroadcastWithStride; + } + + // Convert a serial iterdomain to broadcast, used for implicit broadcast + void convertToBroadcast() { + TORCH_INTERNAL_ASSERT( + !isBroadcast() && !isReduction(), + "convertToBroadcast: converting an non-serial iterdomain", + this); + + iter_type_ = IterType::BroadcastWithStride; + } + void parallelize(ParallelType t); ParallelType getParallelType() const { @@ -848,22 +797,6 @@ class TORCH_CUDA_CU_API IterDomain : public Val { return extent_; } - bool hasExpandedExtent() const { - TORCH_INTERNAL_ASSERT( - expanded_extent_ == nullptr || isBroadcast(), - "Expanded extent is only relevant for strided broadcast dimensions", - " yet found an expanded extent without a strided broadcast iter type."); - return expanded_extent_ != nullptr; - } - - // Returns the expanded extent of a strided broadcast entry. - Val* expandedExtent() const { - TORCH_INTERNAL_ASSERT( - isBroadcast(), - "Expanded extent is only relevant for strided broadcast dimensions."); - return expanded_extent_; - } - //! Dimension padding interface: //! 2 modes are currently supported: //! @@ -988,20 +921,6 @@ class TORCH_CUDA_CU_API IterDomain : public Val { //! Valid range is defined as [start:-stop_offset] Val* const start_ = nullptr; Val* const extent_ = nullptr; - - // Broadcast dimensions are assumed to be size 1 for the sake of code - // generation. If a user though calls `expand` on a tensor that dimension is - // still considered a broadcast dimension. However if we ever output that - // dimension it should be a size dictated by the `expand` operation, and have - // a stride of zero. Since this extent is important to track, but not - // necessarily generate code for (still want loops on broadcast to be of size - // 0), we simply store it separately from extent_. Having an expanded_extent_ - // is only allowed with broadcasted dimsneions. Only in this instance does it - // make sense to have an expanded_extent_, because it's used when users are - // expecting return tensors to have a physical domain. If a user simply - // "broadcasts" an operation - Val* const expanded_extent_ = nullptr; - //! 
Distance of stop from the end Val* const stop_offset_ = nullptr; ParallelType parallel_type_ = ParallelType::Serial; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 0e8290fffb47..9fafe0ef628b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -122,11 +122,7 @@ void IrPrinter::handle(const IterDomain* id) { print_inline(id->stop()); os_ << " : "; } - if (id->isBroadcast() && id->hasExpandedExtent()) { - print_inline(id->expandedExtent()); - } else { - print_inline(id->extent()); - } + print_inline(id->extent()); os_ << "}"; if (id->isRFactorProduct()) os_ << "rf"; @@ -476,18 +472,6 @@ void IrPrinter::handle(const TransposeOp* top) { indent() << top->out() << " = transpose( " << top->in() << " )\n"; } -void IrPrinter::handle(const ExpandOp* eop) { - indent() << eop->out() << " = expand( " << eop->in() << ", {"; - std::stringstream ss; - for (auto expanded_extent : eop->expanded_extents()) { - if (ss.tellp()) { - ss << ", "; - } - ss << expanded_extent; - } - os_ << ss.str() << "} )\n"; -} - void IrPrinter::handle(const ShiftOp* sop) { indent() << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets() << "}, {" << sop->padWidth() << "} )\n"; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index c960d3975965..4eda32572a31 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -92,7 +92,6 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; void handle(const TransposeOp*) final; - void handle(const ExpandOp*) final; void handle(const ShiftOp*) final; void handle(const GatherOp*) final; void handle(const ViewAsScalar*) final; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index af27c93a124a..f2c366e24c50 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -682,7 +682,7 @@ TransposeOp::TransposeOp( IrBuilderPasskey passkey, TensorView* out, TensorView* in, - std::vector new2old) + std::vector new2old) : Expr(passkey, ExprType::TransposeOp), out_(out), in_(in), @@ -691,14 +691,17 @@ TransposeOp::TransposeOp( // should be checked at function transpose. TORCH_INTERNAL_ASSERT( - TensorDomain::noReductions(in->getMaybeRFactorDomain()).size() == - out->getMaybeRFactorDomain().size()); + !in->hasRFactor(), "Transposing rFactor tensors is not supported."); - TORCH_INTERNAL_ASSERT(new2old_.size() == out->getMaybeRFactorDomain().size()); + TORCH_INTERNAL_ASSERT( + TensorDomain::noReductions(in->getRootDomain()).size() == + out->getRootDomain().size()); + + TORCH_INTERNAL_ASSERT(new2old_.size() == out->getRootDomain().size()); // Make sure the entries of new2old are unique and range from 0 to // N-1, where N == new2old.size(). - std::set old_positions(new2old_.begin(), new2old_.end()); + std::set old_positions(new2old_.begin(), new2old_.end()); TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size()); // old_positions is sorted, so the first entry must be 0. 
TORCH_INTERNAL_ASSERT( @@ -722,45 +725,6 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), new2old_(src->new2old_) {} -std::vector TransposeOp::old2new() const { - std::vector old2new(new2old_.size()); - for (auto new_axis : c10::irange(new2old_.size())) { - auto old_axis = new2old_.at(new_axis); - old2new[old_axis] = new_axis; - } - return old2new; -} - -ExpandOp::ExpandOp( - IrBuilderPasskey passkey, - TensorView* out, - TensorView* in, - std::vector _expanded_extents) - : Expr(passkey, ExprType::ExpandOp), - out_(out), - in_(in), - expanded_extents_(std::move(_expanded_extents)) { - addOutput(out); - addInput(in); - for (auto expanded_extent : expanded_extents_) { - TORCH_INTERNAL_ASSERT(expanded_extent != nullptr); - TORCH_INTERNAL_ASSERT( - expanded_extent->dtype() == DataType::Int, - "Expanded extents must be of Int type."); - addInput(expanded_extent); - } -} - -ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) { - expanded_extents_.reserve(src->expanded_extents_.size()); - for (const auto expanded_extent : src->expanded_extents_) { - expanded_extents_.push_back(ir_cloner->clone(expanded_extent)); - } -} - ShiftOp::ShiftOp( IrBuilderPasskey passkey, Val* out, @@ -946,105 +910,32 @@ LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} -IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) - : start_(_start), extent_(_extent) { - TORCH_INTERNAL_ASSERT( - start_ != nullptr && extent_ != nullptr, - "Start and extent are required to build an iter domain."); -} - -IterDomainBuilder::IterDomainBuilder(const IterDomain* id) - : start_(id->start()), - extent_(id->extent()), - expanded_extent_( - id->hasExpandedExtent() ? 
id->expandedExtent() : nullptr), - stop_offset_(id->stopOffset()), - parallel_type_(id->getParallelType()), - iter_type_(id->getIterType()), - is_rfactor_domain_(id->isRFactorProduct()), - is_padded_dimension_(id->hasPaddingToMultipleOfWarp()), - padded_to_size_(id->getMaybeSizeAfterPadding()), - is_mma_swizzled_(id->isMmaSwizzled()) {} - -IterDomainBuilder& IterDomainBuilder::resetSchedulingParams() { - parallel_type_ = ParallelType::Serial; - is_rfactor_domain_ = false; - is_padded_dimension_ = false; - padded_to_size_ = c10::nullopt; - is_mma_swizzled_ = false; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::resetRfactor() { - return is_rfactor_domain(false); -} - -IterDomainBuilder& IterDomainBuilder::start(Val* _start) { - start_ = _start; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) { - extent_ = _extent; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) { - expanded_extent_ = _expanded_extent; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) { - stop_offset_ = _stop_offset; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::parallel_type( - ParallelType _parallel_type) { - parallel_type_ = _parallel_type; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) { - iter_type_ = _iter_type; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::is_rfactor_domain( - bool _is_rfactor_domain) { - is_rfactor_domain_ = _is_rfactor_domain; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::is_padded_dimension( - bool _is_padded_dimension) { - is_padded_dimension_ = _is_padded_dimension; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::padded_to_size( - c10::optional _padded_to_size) { - padded_to_size_ = _padded_to_size; - return *this; -} - -IterDomainBuilder& IterDomainBuilder::is_mma_swizzled(bool _is_mma_swizzled) { - is_mma_swizzled_ = _is_mma_swizzled; - return *this; -} - -IterDomain* IterDomainBuilder::build() const { - TORCH_INTERNAL_ASSERT( - start_ != nullptr && extent_ != nullptr, - "Start and extent are required to build an iter domain."); - return IrBuilder::create(start_->container(), *this); -} +IterDomain::IterDomain( + IrBuilderPasskey passkey, + Val* start, + Val* extent, + ParallelType parallel_type, + IterType iter_type, + bool is_rfactor_domain, + bool is_padded_dimension, + c10::optional padded_to_size, + bool is_mma_swizzled) + : IterDomain( + passkey, + start, + extent, + nullptr, + parallel_type, + iter_type, + is_rfactor_domain, + is_padded_dimension, + padded_to_size, + is_mma_swizzled) {} IterDomain::IterDomain( IrBuilderPasskey passkey, Val* start, Val* extent, - Val* expanded_extent, Val* stop_offset, ParallelType parallel_type, IterType iter_type, @@ -1055,7 +946,6 @@ IterDomain::IterDomain( : Val(passkey, ValType::IterDomain, DataType::Int), start_(start), extent_(extent), - expanded_extent_(expanded_extent), stop_offset_( stop_offset == nullptr ? 
passkey.ir_container_->zeroVal() : stop_offset), @@ -1082,28 +972,10 @@ IterDomain::IterDomain( " ."); } -IterDomain::IterDomain(IrBuilderPasskey passkey, const IterDomainBuilder& args) - - : IterDomain( - passkey, - args.start_, - args.extent_, - args.expanded_extent_, - args.stop_offset_, - args.parallel_type_, - args.iter_type_, - args.is_rfactor_domain_, - args.is_padded_dimension_, - args.padded_to_size_, - args.is_mma_swizzled_) {} - IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) : Val(src, ir_cloner), start_(ir_cloner->clone(src->start_)), extent_(ir_cloner->clone(src->extent_)), - expanded_extent_( - src->hasExpandedExtent() ? ir_cloner->clone(src->expandedExtent()) - : nullptr), stop_offset_(ir_cloner->clone(src->stop_offset_)), parallel_type_(src->parallel_type_), iter_type_(src->iter_type_), @@ -1137,7 +1009,17 @@ bool IterDomain::sameAs(const Statement* other) const { // Returns a new IterDomain matching properties of this except for // is_rfactor_domain_ IterDomain* IterDomain::cloneWithoutRFactor() const { - auto cloned = IterDomainBuilder(this).resetRfactor().build(); + auto cloned = IrBuilder::create( + ir_container_, + start(), + extent(), + stopOffset(), + getParallelType(), + getIterType(), + false, + is_padded_dimension_, + padded_to_size_, + is_mma_swizzled_); return cloned; } @@ -1177,7 +1059,12 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { IterType itype = outer->getIterType(); if (outer->isBroadcast() && inner->isBroadcast()) { - itype = IterType::Broadcast; + if (outer->getIterType() == IterType::BroadcastWithStride || + inner->getIterType() == IterType::BroadcastWithStride) { + itype = IterType::BroadcastWithStride; + } else { + itype = IterType::BroadcastWithoutStride; + } } else if (outer->isBroadcast() || inner->isBroadcast()) { itype = IterType::Iteration; } @@ -1189,12 +1076,12 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { itype = IterType::Iteration; } - IterDomain* merged_id = - IterDomainBuilder( - outer->container()->zeroVal(), merged_id_size->as()) - .parallel_type(outer->getParallelType()) - .iter_type(itype) - .build(); + IterDomain* merged_id = IrBuilder::create( + outer->container(), + outer->container()->zeroVal(), + merged_id_size->as(), + outer->getParallelType(), + itype); IrBuilder::create(outer->container(), merged_id, outer, inner); @@ -1242,20 +1129,20 @@ std::pair IterDomain::split( "Partial split is only allowed with root domains"); } // outer loop IterDomain - IterDomain* ido = IterDomainBuilder( - in->container()->zeroVal(), - inner_split ? remainder->as() : factor) - .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .build(); + IterDomain* ido = IrBuilder::create( + in->container(), + in->container()->zeroVal(), + inner_split ? remainder->as() : factor, + in->getParallelType(), + in->getIterType()); // inner loop IterDomain - IterDomain* idi = IterDomainBuilder( - in->container()->zeroVal(), - inner_split ? factor : remainder->as()) - .parallel_type(in->getParallelType()) - .iter_type(in->getIterType()) - .build(); + IterDomain* idi = IrBuilder::create( + in->container(), + in->container()->zeroVal(), + inner_split ? 
factor : remainder->as(), + in->getParallelType(), + in->getIterType()); IrBuilder::create( in->container(), @@ -1848,12 +1735,13 @@ TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) { IterDomain* merged_id = new_root_domain[start_dim]; for (auto i : c10::irange(start_dim + 1, end_dim + 1)) { - IterDomain* new_merged_id = - IterDomainBuilder( - merged_id->container()->zeroVal(), - mul(merged_id->extent(), new_root_domain[i]->extent())) - .is_rfactor_domain(true) - .build(); + IterDomain* new_merged_id = IrBuilder::create( + merged_id->container(), + merged_id->container()->zeroVal(), + mul(merged_id->extent(), new_root_domain[i]->extent()), + ParallelType::Serial, + IterType::Iteration, + true); IrBuilder::create(new_merged_id, merged_id, new_root_domain[i]); merged_id = new_merged_id; } diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 98912c425c5a..12055873e88d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -12,49 +12,6 @@ namespace fuser { namespace cuda { namespace ir_utils { -std::vector normalizeNew2Old( - const std::vector& new2old_in, - size_t ndims) { - TORCH_CHECK( - new2old_in.size() == ndims, - "There must be a transpose mapping for each dimension in domain"); - - // Canonicalize dimensions by wrapping each dim for the given ndims - std::vector new2old; - std::transform( - new2old_in.begin(), - new2old_in.end(), - std::inserter(new2old, new2old.begin()), - [ndims](int64_t entry) { return entry < 0 ? entry + ndims : entry; }); - - // Check if any adjusted values are < 0, or >= nDims, which are invalid - TORCH_CHECK( - std::none_of( - new2old.begin(), - new2old.end(), - [ndims](int64_t entry) { - return entry < 0 || (unsigned int)entry >= ndims; - }), - "New2Old axes are not within the number of dimensions of the provided domain.\t", - new2old); - - // Going to use sets, to see if any duplicate values are in the map. - std::set old_pos_set; - std::transform( - new2old.begin(), - new2old.end(), - std::inserter(old_pos_set, old_pos_set.begin()), - [](int64_t entry) { return entry; }); - - // Error out if duplicate values are found. - TORCH_CHECK( - new2old.size() == ndims && old_pos_set.size() == new2old.size(), - "Duplicate entries in transformation map."); - - // END VALIDATION CHECKS - return new2old; -} - std::vector normalizeOld2New( const std::unordered_map& old2new_in, size_t ndims) { @@ -299,26 +256,6 @@ struct SubstituteInExpr : public OptInDispatch { transpose_expr->container(), out, in, transpose_expr->new2old()); } - void handle(ExpandOp* expand_expr) final { - auto out = reference_->sameAs(expand_expr->out()) - ? substitute_->as() - : expand_expr->out(); - auto in = reference_->sameAs(expand_expr->in()) - ? substitute_->as() - : expand_expr->in(); - - auto expanded_extents = expand_expr->expanded_extents(); - if (substitute_->isA()) { - for (auto i : c10::irange(expanded_extents.size())) { - if (!expanded_extents[i]->sameAs(substitute_)) { - expanded_extents[i] = substitute_; - } - } - } - expr_ = IrBuilder::create( - expand_expr->container(), out, in, expanded_extents); - } - void handle(ShiftOp* shift_expr) final { auto out = reference_->sameAs(shift_expr->out()) ? 
substitute_ : shift_expr->out(); diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index b4c96ae14787..0b05b6fb5e86 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -124,14 +124,6 @@ auto filterByType(const ContainerType& inputs) { return filterByType(inputs.cbegin(), inputs.cend()); } -//! Returns a list of new-to-old mappings. -//! -//! This function canonicalizes the dimensions and validates that multiple old -//! dimensions are not mapped to the same new dimension. -std::vector normalizeNew2Old( - const std::vector& new2old_in, - size_t ndims); - //! Returns a list of new-to-old mappings. //! //! The input map does not need to be complete. Missing axes are diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index 322aa61614b6..5ffcf7c4b3de 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -132,6 +132,8 @@ class KernelIrScanner : private IrVisitor { void handle(GroupedGridReduction* grid_reduction) final { summary_.has_grid_reductions = true; + const auto dom = ir_utils::getTvOutput(grid_reduction)->domain(); + updateGridReductionInLoop(dom); if (grid_reduction->isAllreduce()) { summary_.has_cooperative_grid_reduction = true; } @@ -158,6 +160,18 @@ class KernelIrScanner : private IrVisitor { private: size_t max_smem_type_size_ = 0; KernelSummary summary_; + + private: + void updateGridReductionInLoop(TensorDomain* dom) { + for (const auto i : c10::irange(dom->nDims())) { + const auto id = GpuLower::current()->caMap()->getConcreteMappedID( + dom->domain()[i], IdMappingMode::LOOP); + + summary_.has_cooperative_grid_reduction = + summary_.has_cooperative_grid_reduction || + !(id->isThread() || id->extent()->isOneInt()); + } + } }; //!
Make sure tensors have valid allocations even when parallelized diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 5ced5a75d541..969fd53d0426 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -470,21 +470,17 @@ GroupedGridReduction::GroupedGridReduction( std::vector inputs, std::vector reduction_buffers, Allocate* sync_buffer, - Val* entrance_index, - Val* entrances, - bool is_allreduce) + bool is_fused) : GroupedReductionOp( passkey, std::move(reduction_op_types), std::move(init_vals), std::move(outputs), std::move(inputs), - is_allreduce, + is_fused, ExprType::GroupedGridReduction), reduction_buffers_(std::move(reduction_buffers)), - sync_buffer_(sync_buffer), - entrance_index_(entrance_index), - entrances_(entrances) { + sync_buffer_(sync_buffer) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index a04debdd4912..d2525965d2e4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -524,7 +524,7 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Allocate* sync_buffer, Val* entrance_index, Val* entrances, - bool is_allreduce = false); + bool is_fused = false); Allocate* reduction_buffer() const { return reduction_buffer_; @@ -573,8 +573,6 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { std::vector in, std::vector reduction_buffers, Allocate* sync_buffer, - Val* entrance_index, - Val* entrances, bool is_allreduce = false); const std::vector& reduction_buffers() const { @@ -589,16 +587,6 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { return sync_buffer_; } - // Which instance of entering this grid reduction is this iteration? - Val* entrance_index() const { - return entrance_index_; - } - - // How many times will this grid reduction be entered - Val* entrances() const { - return entrances_; - } - const ParallelTypeBitmap& threadPredicate() const { return thread_predicate_; } @@ -614,8 +602,6 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { // use them, the thread predicate is held here separately from // Expr::predicate_. ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; }; //! 
Grid broadcast operation diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index afd386ce4130..85d09e4ca080 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -160,7 +160,8 @@ class AllocationInserter : public kir::ExprMutator { std::vector alloc_dims; for (const auto id : maybe_rfactor_domain) { - if (id->isReduction() || id->isStride() || id->isBroadcast()) { + if (id->isReduction() || id->isStride() || + id->getIterType() == IterType::BroadcastWithoutStride) { continue; } auto extent = id->extent(); diff --git a/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp index a82ef0ae52f6..b3e9b1776acf 100644 --- a/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp @@ -105,14 +105,6 @@ class UnaryOpInserter : private kir::ExprMutator { top, IrBuilder::create(container, UnaryOpType::Set, out, in)); } - void handle(ExpandOp* eop) final { - auto out = eop->out(); - auto in = eop->in(); - auto container = out->container(); - registerReplace( - eop, IrBuilder::create(container, UnaryOpType::Set, out, in)); - } - void handle(ShiftOp* sop) final { auto out = sop->out(); auto in = sop->in(); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 66bbb54dbb21..f3163135e5dd 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -335,8 +335,7 @@ void IndexLowering::handleGridReduction( "Found a reduction stage that has both a non-parallelized ", "reduction and a grid reduction. This is not supported, ", "please use rfactor to do the serialized reduction first, ", - "then the grid reduction. ", - rop->toString()); + "then the grid reduction."); // When using the fused reduction in a loop, the global work buffer // is double buffered to save global synchronizations. @@ -503,8 +502,6 @@ void IndexLowering::handleGridReduction( out_domain->domain().end(), [](IterDomain* id) { return !isTrivialIterDomain(id); }); - const bool privatize_buffer = !grouped_rop->isAllreduce(); - std::vector reduce_buffers; std::transform( outputs.begin(), @@ -514,25 +511,14 @@ void IndexLowering::handleGridReduction( return ir_utils::allocGlobalBufferForGridComm( getGridCommWorkBufferSize( out_domain, - privatize_buffer ? for_loops_ : std::vector(), + for_loops_, (grouped_rop->isAllreduce() && is_within_a_loop ? 2 : 1)), output->dtype(), false); }); const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize( - out_domain, - privatize_buffer ? for_loops_ : std::vector()), - DataType::Int, - true); - - const auto entrance_ind = privatize_buffer - ? getEntranceLinIndGridReduce(for_loops_) - : GpuLower::current()->kernel()->zeroVal(); - const auto n_entrances = privatize_buffer - ? getEntranceCountGridReduce(for_loops_) - : GpuLower::current()->kernel()->oneVal(); + getGridSyncBufferSize(out_domain, for_loops_), DataType::Int, true); // The thread predicate for GridReduction needs to be set // separately from the main predicate. 
Do not combine them like @@ -547,8 +533,6 @@ void IndexLowering::handleGridReduction( inputs, reduce_buffers, sync_buffer, - entrance_ind, - n_entrances, grouped_rop->isAllreduce()); grid_reduction->setThreadPredicate(thread_pred); diff --git a/torch/csrc/jit/codegen/cuda/lower_instrument.cpp b/torch/csrc/jit/codegen/cuda/lower_instrument.cpp index 56fb8cda783e..894de470c09c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_instrument.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_instrument.cpp @@ -65,13 +65,12 @@ class Instrumentor : private kir::IrVisitor { // Allocate two integers for each entry. One is used for accumulating // cycles, and another for couting the number of hits const std::vector new_buffer_ids = { - IterDomainBuilder( + IrBuilder::create( GpuLower::current()->kernel()->zeroVal(), - IrBuilder::create(num_profile_entries)) - .build(), - IterDomainBuilder( - GpuLower::current()->kernel()->zeroVal(), IrBuilder::create(2)) - .build()}; + IrBuilder::create(num_profile_entries)), + IrBuilder::create( + GpuLower::current()->kernel()->zeroVal(), + IrBuilder::create(2))}; const auto buffer_domain = IrBuilder::create(new_buffer_ids); diff --git a/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp index 02b2e9a70edc..beec550e537f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -181,11 +180,17 @@ void replaceSymbolicSizes(Fusion* fusion) { size_t dim = 0; for (auto id : root_td) { Val* orig_size = id->extent(); + // Output sizes could have reduction axes, which isn't what gets output. // NOLINTNEXTLINE(bugprone-branch-clone) - if (id->isReduction()) { + if (id->isReduction() || + (id->getIterType() == IterType::BroadcastWithoutStride)) { continue; - } else if (orig_size->isConstScalar()) { + } else if ( + id->isRFactorProduct() || + // NOLINTNEXTLINE(bugprone-branch-clone) + (id->getIterType() == IterType::BroadcastWithStride) || + orig_size->isConstScalar()) { dim++; continue; } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index fae6a5dfb1f7..daabe6e28f15 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -98,7 +98,6 @@ bool isTvOp(const Expr* expr) { expr->getExprType().value() == ExprType::MmaOp || expr->getExprType().value() == ExprType::BroadcastOp || expr->getExprType().value() == ExprType::TransposeOp || - expr->getExprType().value() == ExprType::ExpandOp || expr->getExprType().value() == ExprType::ShiftOp || expr->getExprType().value() == ExprType::GatherOp || expr->getExprType().value() == ExprType::ViewAsScalar || @@ -264,8 +263,8 @@ c10::optional getMaybeWarpReductionDim( return c10::optional(reduction_on_xdim); } - if (reduction_on_xdim->extent()->isConstInt()) { - auto extent_value = reduction_on_xdim->extent()->evaluateInt(); + if (reduction_on_xdim->extent()->isConst()) { + auto extent_value = reduction_on_xdim->extent()->getInt().value(); if (extent_value % at::cuda::warp_size() == 0) { return c10::optional(reduction_on_xdim); } @@ -366,8 +365,8 @@ kir::Allocate* allocGlobalBufferForGridComm( DataType dtype, bool zero_init) { const std::vector new_buffer_ids = { - IrBuilder::create(IterDomainBuilder( - GpuLower::current()->kernel()->zeroVal(), buffer_size))}; + IrBuilder::create( + GpuLower::current()->kernel()->zeroVal(), buffer_size)}; const auto 
buffer_domain = IrBuilder::create(new_buffer_ids); const auto buffer_tv = IrBuilder::create(buffer_domain, dtype, MemoryType::Global); diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 6f1c266c385a..bf282fec0753 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -177,6 +177,24 @@ void validateIr(Fusion* fusion) { fusion->validateInputs(); + // Convert all input broadcast iterdomains to strided + for (auto tv : ir_utils::filterByType(fusion->inputs())) { + for (auto id : tv->getMaybeRFactorDomain()) { + if (id->isBroadcast()) { + id->toStridedBroadcast(); + } + } + } + + // Convert all output broadcast iterdomains to strided + for (auto tv : ir_utils::filterByType(fusion->outputs())) { + for (auto id : tv->getMaybeRFactorDomain()) { + if (id->isBroadcast()) { + id->toStridedBroadcast(); + } + } + } + // Validate Parallelization ValidateSiblings::validate(fusion); @@ -888,8 +906,8 @@ void validateMmaTensors(MmaOp* mma) { GpuLower::current()->parallelDimensionMap(); TORCH_INTERNAL_ASSERT( paralel_dim_map.isExact(ptype) && - paralel_dim_map.get(ptype)->isConstInt() && - paralel_dim_map.get(ptype)->evaluateInt() == + paralel_dim_map.get(ptype)->getInt().has_value() && + paralel_dim_map.get(ptype)->getInt().value() == at::cuda::warp_size(), "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp"); tidx_validated = true; diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index e828f2fa866d..37ee063d2121 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -63,13 +63,20 @@ void OptOutMutator::mutate(IterDomain* id) { stop_offset->sameAs(id->stopOffset())) { return; } - registerMutation( - id, - IterDomainBuilder(id) - .start(start) - .extent(extent) - .stop_offset(stop_offset) - .build()); + + Val* mutated_val = IrBuilder::create( + id->container(), + start, + extent, + stop_offset, + id->getParallelType(), + id->getIterType(), + id->isRFactorProduct()); + if (id->hasPaddingToMultipleOfWarp()) { + mutated_val->as()->padToMultipleOfWarp( + id->getMaybeSizeAfterPadding()); + } + registerMutation(id, mutated_val); } void OptOutMutator::mutate(TensorDomain* td) { @@ -325,32 +332,6 @@ void OptOutMutator::mutate(TransposeOp* top) { IrBuilder::create(container, out, in, new2old); } -void OptOutMutator::mutate(ExpandOp* eop) { - bool is_same = true; - - TensorView* out = maybeMutated(eop->out())->as(); - is_same = is_same && out->sameAs(eop->out()); - TensorView* in = maybeMutated(eop->in())->as(); - is_same = is_same && in->sameAs(eop->in()); - - std::vector expanded_extents; - expanded_extents.reserve(eop->expanded_extents().size()); - for (auto expanded_extent : eop->expanded_extents()) { - expanded_extents.push_back(maybeMutated(expanded_extent)); - if (!expanded_extents.back()->sameAs(expanded_extent)) { - is_same = false; - } - } - - if (is_same) { - return; - } - - auto container = eop->container(); - container->removeExpr(eop); - IrBuilder::create(container, out, in, expanded_extents); -} - void OptOutMutator::mutate(ShiftOp* sop) { Val* out = maybeMutated(sop->out())->asVal(); Val* in = maybeMutated(sop->in())->asVal(); diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.cpp b/torch/csrc/jit/codegen/cuda/ops/alias.cpp index f6e12533fa6d..d5bbd4878828 100644 --- a/torch/csrc/jit/codegen/cuda/ops/alias.cpp +++ 
b/torch/csrc/jit/codegen/cuda/ops/alias.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -194,73 +193,6 @@ TensorView* unsqueeze(TensorView* x, int dim) { return broadcast(x, broadcast_axes); } -TensorView* permute(TensorView* x, const std::vector& new2old) { - auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain()); - std::vector out_domain(inp_domain.size()); - - auto normalized_new2old = - ir_utils::normalizeNew2Old(new2old, inp_domain.size()); - - for (const auto i : c10::irange(out_domain.size())) { - auto in_id = inp_domain[new2old[i]]; - out_domain[i] = in_id->cloneWithoutRFactor(); - } - - TensorView* out_tensor = IrBuilder::create( - IrBuilder::create( - out_domain, std::vector(out_domain.size(), true)), - x->getDataType().value()); - IrBuilder::create(out_tensor, x, normalized_new2old); - return out_tensor; -} - -TensorView* transpose(TensorView* x, int64_t dim0, int64_t dim1) { - const auto ndims = static_cast(x->domain()->noReductions().size()); - - if (dim0 < 0) { - dim0 = ndims + dim0; - } - - if (dim1 < 0) { - dim1 = ndims + dim1; - } - - TORCH_CHECK( - dim0 >= 0 && dim0 <= ndims, "Invalid transpose dimension 0: ", dim0); - - TORCH_CHECK( - dim1 >= 0 && dim1 <= ndims, "Invalid transpose dimension 1: ", dim1); - - std::vector new2old(ndims); - for (const auto i : c10::irange(ndims)) { - if (i == dim0) { - new2old[i] = dim1; - } else if (i == dim1) { - new2old[i] = dim0; - } else { - new2old[i] = i; - } - } - return permute(x, new2old); -} - -TensorView* transpose(TensorView* x) { - const auto ndims = static_cast(x->domain()->noReductions().size()); - - TORCH_CHECK( - ndims <= 2, - "Expected a tensor with <= 2 dimensions, but it has ", - ndims, - "D."); - - // short-circuit: return original tensorview if less than 2 dimensions - if (ndims < 2) { - return x; - } - - return transpose(x, 0, 1); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.h b/torch/csrc/jit/codegen/cuda/ops/alias.h index f363f01bb409..f33a5a745a89 100644 --- a/torch/csrc/jit/codegen/cuda/ops/alias.h +++ b/torch/csrc/jit/codegen/cuda/ops/alias.h @@ -39,27 +39,6 @@ TORCH_CUDA_CU_API TensorView* squeeze( TORCH_CUDA_CU_API TensorView* unsqueeze(TensorView* x, int dim); -//! Permute a tensor as specified by axis mappings. -//! -//! The transposition mapping is specified with a list of pairs from -//! new to old positions. Positions are relative to the noReduction -//! domain. -//! -//! \param inp Tensor to transpose -//! \param new2old vector mapping from new to old positions. -TORCH_CUDA_CU_API TensorView* permute( - TensorView* x, - const std::vector& new2old); - -//! Transpose a tensor by swapping the two dimensions. -TORCH_CUDA_CU_API TensorView* transpose( - TensorView* x, - int64_t dim0, - int64_t dim1); - -//! Transpose a 2D tensor. 
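[Illustrative sketch, not part of the patch] The permute()/transpose() helpers being removed from this header took a new-to-old vector, while the call sites later in this diff switch to the map-based transpose that takes an old-to-new axis map (e.g. transpose(tv0, {{0, 1}}) in the tests). Assuming that relationship, a hypothetical helper newToOldToMap shows how one form translates into the other:

#include <cassert>
#include <unordered_map>
#include <vector>

// Convert a new-to-old permutation vector (the argument style of the removed
// permute()) into an old-to-new map, skipping axes that stay in place.
std::unordered_map<int, int> newToOldToMap(const std::vector<int>& new2old) {
  std::unordered_map<int, int> old2new;
  for (int new_axis = 0; new_axis < static_cast<int>(new2old.size()); ++new_axis) {
    if (new2old[new_axis] != new_axis) {
      old2new[new2old[new_axis]] = new_axis;
    }
  }
  return old2new;
}

int main() {
  // permute(tv, {3, 0, 1, 2}) corresponds to transpose(tv, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}),
  // matching the test updates later in this diff.
  auto m = newToOldToMap({3, 0, 1, 2});
  assert(m.at(0) == 1 && m.at(1) == 2 && m.at(2) == 3 && m.at(3) == 0);
  return 0;
}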
-TORCH_CUDA_CU_API TensorView* transpose(TensorView* x); - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 8d37c2186b24..169d41bb875f 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -324,8 +324,38 @@ struct MemoryFormat { } // returns transpose map to achieve permutation on non-permuted tensor - // note: used for aten::permute API and codegen tranpose API - std::vector apply() const { + // note: used for codegen transpose API + std::unordered_map apply() const { + std::unordered_map permute; + if (hasPermutation()) { + int rank = permuted_order_.size(); + for (const auto i : c10::irange(rank)) { + if (permuted_order_[i] != rank - 1 - i) { + permute[permuted_order_[i]] = rank - 1 - i; + } + } + } + return permute; + } + + // returns transpose map to restore back to non-permuted tensor + // note: used for codegen transpose API + std::unordered_map restore() const { + std::unordered_map permute; + if (hasPermutation()) { + int rank = permuted_order_.size(); + for (const auto i : c10::irange(rank)) { + if (permuted_order_[i] != rank - 1 - i) { + permute[rank - 1 - i] = permuted_order_[i]; + } + } + } + return permute; + } + + // returns transpose map to achieve permutation on non-permuted tensor + // note: used for aten::permute API + std::vector apply_vec() const { std::vector ret; if (hasPermutation()) { ret.resize(permuted_order_.size()); @@ -335,8 +365,8 @@ struct MemoryFormat { } // returns transpose map to restore back to non-permuted tensor - // note: used for aten::permute API and codegen transpose API - std::vector restore() const { + // note: used for aten::permute API + std::vector restore_vec() const { std::vector ret; if (hasPermutation()) { int rank = permuted_order_.size(); @@ -477,11 +507,11 @@ class ValueHolder { // restore source permutation if (format_s.hasPermutation()) { - tv = permute(tv, format_s.restore()); + tv = transpose(tv, format_s.restore()); } // apply destination permutation if (format_d.hasPermutation()) { - tv = permute(tv, format_d.apply()); + tv = transpose(tv, format_d.apply()); } return tv; } @@ -755,13 +785,13 @@ class IrParser { for (const auto& i : c10::irange(fusion->inputs().size())) { const auto& entry = permuted_tensors.find(fusion->inputs()[i]); if (entry != permuted_tensors.end()) { - fusion->setPermutationOnInput(i, entry->second.apply()); + fusion->setPermutationOnInput(i, entry->second.apply_vec()); } } for (const auto& i : c10::irange(fusion->outputs().size())) { const auto& entry = permuted_tensors.find(fusion->outputs()[i]); if (entry != permuted_tensors.end()) { - fusion->setPermutationOnOutput(i, entry->second.restore()); + fusion->setPermutationOnOutput(i, entry->second.restore_vec()); } } return fusion; @@ -3339,9 +3369,8 @@ class IrParser { std::vector s_vec = opt_s_vec.value(); // apply permutation auto permutation = format.apply(); - for (auto new_axis : c10::irange(permutation.size())) { - auto old_axis = permutation.at(new_axis); - s_vec[new_axis] = opt_s_vec.value()[old_axis]; + for (const auto& p : permutation) { + s_vec[p.second] = opt_s_vec.value()[p.first]; } // copying stride properties because we need to permute it diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index 08fa1d4e0411..c619b557fa12 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp 
+++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -193,11 +193,8 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("dtype") = torch::jit::fuser::cuda::DataType::Float, py::return_value_policy::reference) .def( - // TODO: Should the inernals of this function live more explicitly in - // TensorViewBuilder? "define_tensor", [](FusionDefinitionContextManager& self, - // TODO: This should come in as int64_t not int std::vector sizes, std::vector strides, torch::jit::fuser::cuda::DataType dtype = @@ -208,23 +205,17 @@ void initNvFuserPythonBindings(PyObject* module) { sizes.size(), strides.size()); - // TensorViewBuilder assumes any dim with a compile time constant - // size == 1 is a "maybe broadcast" axis, symbolic sizes are - // identified by -1, and size == 0 is not supported. - - // Translate to TensorViewBuilder's view of the world. - std::vector maybe_symbolic_sizes; - maybe_symbolic_sizes.reserve(sizes.size()); + std::vector domain_sizes; for (const auto i : c10::irange(sizes.size())) { - TORCH_INTERNAL_ASSERT( - sizes[i] > 0, - "Size of ", - sizes[i], - " is not supported in nvFuser. Expected size > 0."); if (sizes[i] == 1) { - maybe_symbolic_sizes.push_back(1); + domain_sizes.push_back(IrBuilder::create( + self.fusionPtr()->zeroVal(), + self.fusionPtr()->oneVal(), + ParallelType::Serial, + IterType::BroadcastWithStride)); } else { - maybe_symbolic_sizes.push_back(-1); + domain_sizes.push_back(IrBuilder::create( + self.fusionPtr()->zeroVal(), IrBuilder::create())); } } @@ -238,12 +229,9 @@ void initNvFuserPythonBindings(PyObject* module) { } } - return TensorViewBuilder() - .ndims(maybe_symbolic_sizes.size()) - .contiguity(contig_info) - .shape(maybe_symbolic_sizes) - .dtype(dtype) - .build(); + return IrBuilder::create( + IrBuilder::create(domain_sizes, contig_info), + dtype); }, py::arg("sizes"), py::arg("strides"), diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index a4a3b5a440e2..49644cb93f36 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -433,10 +433,6 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder void handle(TransposeOp* op) override; - void handle(ExpandOp* op) override { - mapPointwiseOrReductionOp(op); - } - void handle(GatherOp* op) override; void handle(TensorView* tv) override; diff --git a/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu index 749941a846ce..3cc6f586b265 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu @@ -258,13 +258,10 @@ class ParallelReduce { index_utils:: maskedSize( gridDim) * - index_utils:: - maskedSize( - blockDim) * grid_red_size; global_work_buffer += global_buffer_size; } - flip = !flip; + flip = ~flip; // How many grid reductions have to be performed, in the grid dimension const auto num_block_iters = index_utils:: @@ -453,7 +450,9 @@ class ParallelReduce { // need of an additional grid_sync. Since we flip back and forth between // sections of the buffer, the one grid sync protects the other part of // the buffer. 
+ } else { + // Forward protect the smem used in this reduction if (grid_reduce_participate) { if (last_block && has_block_result && block_reduce_participate && write_pred) { @@ -461,9 +460,8 @@ class ParallelReduce { out, shared_buf, block_reduction_idx * block_reduction_size_2); } } + block_sync::sync(); } - // Forward protect the smem used in this reduction - block_sync::sync(); } } @@ -632,14 +630,11 @@ class ParallelReduce { index_utils:: maskedSize( gridDim) * - index_utils:: - maskedSize( - blockDim) * grid_red_size; global_work_buffer1 += global_buffer_size; global_work_buffer2 += global_buffer_size; } - flip = !flip; + flip = ~flip; // Per-block partial reduction to global work buffer { @@ -658,7 +653,6 @@ class ParallelReduce { copyTuple(global_work_buffer1, work_buf_offset, block_result); } } - block_sync::sync(); } { const auto block_result = reduceBlock( @@ -1091,7 +1085,9 @@ class ParallelReduce { // need of an additional grid_sync. Since we flip back and forth between // sections of the buffer, the one grid sync protects the other part of // the buffer. + } else { + // Forward protect the smem used in this reduction if (grid_reduce_participate) { if (last_block && has_block_result && block_reduce_participate && write_pred) { @@ -1099,9 +1095,8 @@ class ParallelReduce { out, shared_buf, block_reduction_idx * block_reduction_size_2); } } + block_sync::sync(); } - // Forward protect the smem used in this reduction - block_sync::sync(); } } diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_broadcast.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_broadcast.cu index 7772f159bcf6..8de1d7c32e0d 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_broadcast.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_broadcast.cu @@ -58,6 +58,7 @@ __device__ void broadcast( __threadfence(); } + bool null = false; grid_sync::sync( sync_flags[grid_seg_idx], grid_seg_size); diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu index e44671d4f5fe..70e6b8b675aa 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu @@ -204,6 +204,10 @@ __device__ void gridReduce( const nvfuser_index_t n_entrances) { T block_reduction_val = init_val; + // entrance index only matters for non-persistent re-entrant grid reductions. + const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind; + const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances; + // Do block reduction when required if (X_THREAD || Y_THREAD || Z_THREAD) { blockReduce( @@ -263,7 +267,7 @@ __device__ void gridReduce( } else { // Use a different sync flag for each call grid_sync::sync( - sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment], + sync_flags[entrance_ind_ * grid_segment_size + idx_in_grid_segment], grid_reduction_segment_size); } @@ -434,9 +438,7 @@ __device__ void gridReduceGroup( int64_t* sync_flags, void* shared_buf, bool read_pred, - bool write_pred, - const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + bool write_pred) { // Number of values to reduce in the reduction segment const auto grid_reduction_segment_size = index_utils::maskedSize(gridDim); @@ -452,18 +454,13 @@ __device__ void gridReduceGroup( const auto block_reduction_segment_size = index_utils::maskedSize(blockDim); - // Number of reductions in the grid - const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION - ? 
1 - : index_utils::maskedSize(gridDim); - // advance to the offset for this segment // index of reduction * size of the reduction * size of threads - work_buf1 += (entrance_ind * grid_segment_size + idx_in_grid_segment) * - grid_reduction_segment_size * block_reduction_segment_size; + work_buf1 += idx_in_grid_segment * grid_reduction_segment_size * + block_reduction_segment_size; - work_buf2 += (entrance_ind * grid_segment_size + idx_in_grid_segment) * - grid_reduction_segment_size * block_reduction_segment_size; + work_buf2 += idx_in_grid_segment * grid_reduction_segment_size * + block_reduction_segment_size; gridReduce2PartialReduction< X_BLOCK, @@ -499,14 +496,8 @@ __device__ void gridReduceGroup( idx_in_grid_segment, block_reduction_segment_size); - if (PERSISTENT_REDUCTION) { - grid_sync::sync( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); - } else { - grid_sync::sync( - sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); - } + grid_sync::sync( + sync_flags[idx_in_grid_segment], grid_reduction_segment_size); bool last_block = index_utils::maskedIsLast(blockIdx, gridDim); @@ -569,8 +560,6 @@ __device__ void gridReduceGroup( void* shared_buf, bool read_pred, bool write_pred, - const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances, int64_t& cycles, int64_t& count) { int64_t start_counter = 0; @@ -605,9 +594,7 @@ __device__ void gridReduceGroup( sync_flags, shared_buf, read_pred, - write_pred, - entrance_ind, - n_entrances); + write_pred); if (index_utils::maskedIsLast(blockIdx, gridDim) && index_utils::maskedIsZero(threadIdx)) { diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 71983b1f162c..d4d13a6a1fd7 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -35,7 +35,18 @@ TensorView::TensorView( MemoryType mtype) : Val(passkey, ValType::TensorView, dtype), domain_(domain), - memory_type_(mtype) {} + memory_type_(mtype) { + // Don't do this after transforms + if (domain_->domain() == domain_->getRootDomain()) { + // Mark the size-1 axes as broadcast to support implicit broadcast semantic + for (auto* id : domain_->domain()) { + if (!id->isBroadcast() && !id->isReduction() && !id->isGather() && + id->extent()->isOneInt()) { + id->convertToBroadcast(); + } + } + } +} TensorView::TensorView( IrBuilderPasskey passkey, @@ -55,16 +66,14 @@ TensorView::TensorView( if (tensor_type->sizes()[i].has_value() && tensor_type->sizes()[i].value() == 1) { // If size is known to be 1, assuem it needs to be broadcasted. 
- sizes.push_back( - IterDomainBuilder( - passkey.ir_container_->zeroVal(), passkey.ir_container_->oneVal()) - .iter_type(IterType::Broadcast) - .build()); + sizes.push_back(IrBuilder::create( + passkey.ir_container_->zeroVal(), + passkey.ir_container_->oneVal(), + ParallelType::Serial, + IterType::BroadcastWithStride)); } else { - sizes.push_back( - IterDomainBuilder( - passkey.ir_container_->zeroVal(), IrBuilder::create()) - .build()); + sizes.push_back(IrBuilder::create( + passkey.ir_container_->zeroVal(), IrBuilder::create())); } } // [ Note -- stride_properties in tensor type ] @@ -164,10 +173,13 @@ void TensorView::convertRfactorToRootDomain() { getMaybeRFactorDomain().size()); for (const auto& id : getMaybeRFactorDomain()) { if (replacement_extents[idx] != nullptr) { - new_root_domain[idx] = IterDomainBuilder(id) - .extent(replacement_extents[idx]) - .resetSchedulingParams() - .build(); + new_root_domain[idx] = IrBuilder::create( + container(), + id->start(), + replacement_extents[idx], + id->stopOffset(), + id->getParallelType(), + id->getIterType()); ++idx; } else { TORCH_INTERNAL_ASSERT(!id->isRFactorProduct()); @@ -1093,10 +1105,8 @@ TensorView* TensorViewBuilder::build() const { std::vector domain(ndims_, nullptr); for (const auto i : c10::irange(ndims_)) { if (shape_.empty() || shape_[i] == -1) { - domain[i] = - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create()) - .build(); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create()); } else { TORCH_CHECK( shape_[i] >= 0, @@ -1104,16 +1114,15 @@ TensorView* TensorViewBuilder::build() const { "For a tensor representing a single scalar use ndims = 0 with no sizes set."); if (shape_[i] == 1) { // If size is known to be 1, assume it needs to be broadcasted. 
- domain[i] = IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - FusionGuard::getCurFusion()->oneVal()) - .iter_type(IterType::Broadcast) - .build(); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), + ParallelType::Serial, + IterType::BroadcastWithStride); } else { - domain[i] = IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - IrBuilder::create(shape_[i])) - .build(); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create(shape_[i])); } } } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 4fdb04e5760f..9c77d1b4d397 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -5158,9 +5158,19 @@ TEST_F(NVFuserTest, FusionSimpleBCast3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Set up input tensor views + // Set up your input tensor views + std::vector dom; + dom.push_back(IrBuilder::create( + IrBuilder::create(0), IrBuilder::create())); + dom.push_back(IrBuilder::create( + IrBuilder::create(0), + IrBuilder::create(1), + ParallelType::Serial, + IterType::BroadcastWithStride)); + // tv0[I1, B{1}] - TensorView* tv0 = makeConcreteTensor({-1, 1}); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create(dom), DataType::Float); fusion.addInput(tv0); // tv1[I0, I1, I2] @@ -5203,7 +5213,16 @@ TEST_F(NVFuserTest, FusionSimpleBCast4_CUDA) { FusionGuard fg(&fusion); // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({1, -1}); + std::vector dom; + dom.push_back(IrBuilder::create( + IrBuilder::create(0), + IrBuilder::create(1), + ParallelType::Serial, + IterType::BroadcastWithStride)); + dom.push_back(IrBuilder::create( + IrBuilder::create(0), IrBuilder::create())); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create(dom), DataType::Float); TensorView* tv1 = makeSymbolicTensor(3); fusion.addInput(tv0); @@ -5251,8 +5270,23 @@ TEST_F(NVFuserTest, FusionSimpleBCast5_CUDA) { FusionGuard fg(&fusion); constexpr int m = 2, k = 3, n = 4; - auto tv0 = makeConcreteTensor({m, k}); - auto tv1 = makeConcreteTensor({k, n}); + + auto zero = IrBuilder::create(0); + auto M = IrBuilder::create(zero, IrBuilder::create(m)); + auto K = IrBuilder::create(zero, IrBuilder::create(k)); + auto N = IrBuilder::create(zero, IrBuilder::create(n)); + + // Set up your input tensor views + TensorView* tv0 = IrBuilder::create( + IrBuilder::create( + std::vector({M, K}), std::vector({true, true})), + DataType::Float); + // Note: IterDomain must not be reused, so K needs to be cloned. 
+ TensorView* tv1 = IrBuilder::create( + IrBuilder::create( + std::vector({K->cloneWithoutRFactor(), N}), + std::vector({true, true})), + DataType::Float); fusion.addInput(tv0); fusion.addInput(tv1); @@ -13017,7 +13051,7 @@ TEST_F(NVFuserTest, FusionTranspose1_CUDA) { constexpr int N = 20; auto tv0 = makeSymbolicTensor(2); - auto tv1 = transpose(tv0); + auto tv1 = transpose(tv0, {{0, 1}}); fusion.addInput(tv0); fusion.addOutput(tv1); @@ -13047,7 +13081,7 @@ TEST_F(NVFuserTest, FusionTranspose2_CUDA) { constexpr int N = 20; auto tv0 = makeSymbolicTensor(2); - auto tv1 = transpose(tv0); + auto tv1 = transpose(tv0, {{0, 1}}); fusion.addInput(tv0); fusion.addOutput(tv1); @@ -13083,8 +13117,8 @@ TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv0_t = transpose(tv0); - TensorView* tv1_t = transpose(tv1); + TensorView* tv0_t = transpose(tv0, {{0, 1}}); + TensorView* tv1_t = transpose(tv1, {{0, 1}}); TensorView* tv2 = broadcast(tv0_t, {false, false, true}); // tv2[I0, I1, B] = tv0[I0, I1] @@ -13170,7 +13204,7 @@ TEST_F(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { TensorView* input_tv0 = makeSymbolicTensor(3); fusion.addInput(input_tv0); - TensorView* input_t = transpose(input_tv0, 1, 2); + TensorView* input_t = transpose(input_tv0, {{1, 2}}); TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_t); TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); @@ -13178,7 +13212,7 @@ TEST_F(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be // computed at sum_exp_rf_tv8. - TensorView* input_t_copy = transpose(input_tv0, 1, 2); + TensorView* input_t_copy = transpose(input_tv0, {{1, 2}}); TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_t_copy); TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); @@ -13235,7 +13269,7 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - tv0 = transpose(tv0); + tv0 = transpose(tv0, {{0, 1}}); TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); @@ -13317,7 +13351,7 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - tv0 = transpose(tv0); + tv0 = transpose(tv0, {{0, 1}}); TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); @@ -13379,12 +13413,12 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { TensorView* tv0 = makeSymbolicTensor(4); fusion.addInput(tv0); - tv0 = permute(tv0, {3, 0, 1, 2}); + tv0 = transpose(tv0, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); TensorView* tv1 = makeSymbolicTensor(4); fusion.addInput(tv1); - tv1 = permute(tv1, {3, 0, 1, 2}); + tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); @@ -13442,22 +13476,22 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { TensorView* tv0 = makeSymbolicTensor(4); fusion.addInput(tv0); - tv0 = permute(tv0, {3, 0, 1, 2}); + tv0 = transpose(tv0, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); TensorView* tv1 = makeSymbolicTensor(4); fusion.addInput(tv1); - tv1 = permute(tv1, {3, 0, 1, 2}); + tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); TensorView* tv2 = makeSymbolicTensor(4); fusion.addInput(tv2); - tv2 = permute(tv2, {3, 0, 1, 2}); + tv2 = transpose(tv2, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); TensorView* tv3 = 
makeSymbolicTensor(4); fusion.addInput(tv3); - tv3 = permute(tv3, {3, 0, 1, 2}); + tv3 = transpose(tv3, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); TensorView* tv4 = sub(tv2, tv3); TensorView* tv5 = add(tv1, tv4); @@ -13522,10 +13556,10 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - tv0 = transpose(tv0); + tv0 = transpose(tv0, {{0, 1}}); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - tv1 = transpose(tv1); + tv1 = transpose(tv1, {{0, 1}}); TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -13561,10 +13595,10 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - tv0 = transpose(tv0); + tv0 = transpose(tv0, {{0, 1}}); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - tv1 = transpose(tv1); + tv1 = transpose(tv1, {{0, 1}}); TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -13955,7 +13989,7 @@ TEST_F(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = transpose(tv0); + auto tv1 = transpose(tv0, {{0, 1}}); fusion.addOutput(tv1); // tv0: [I0, I1] @@ -14017,7 +14051,7 @@ TEST_F(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = transpose(tv0); + auto tv1 = transpose(tv0, {{0, 1}}); fusion.addOutput(tv1); // tv0: [I0, I1] @@ -22877,7 +22911,7 @@ TEST_F(NVFuserTest, FusionExactRootDomainMap_CUDA) { fusion.addInput(tv1); auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = transpose(tv2); + auto tv3 = transpose(tv2, {{0, 1}}); auto tv4 = add(tv2, tv1); auto tv5 = add(tv2, tv3); auto tv6 = add(tv3, tv1); @@ -23360,220 +23394,6 @@ TEST_F(NVFuserTest, FusionRepro1713_CUDA) { __FILE__); } -TEST_F(NVFuserTest, FusionExpand_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto w = 2, x = 3, y = 4, z = 5; - - // Test - // a simple expand - // Expand that's propagated - // expand_as - // symbolic expand - - // x - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = expand(tv1, {tv0->axis(0)->extent(), IrBuilder::create(y)}); - - // x - auto tv3 = makeSymbolicTensor(1); - fusion->addInput(tv3); - auto tv4 = broadcast(tv3, {false, true}); - auto tv5 = add(tv4, tv2); - // [x, e_y] - - // [x, y, z] - auto tv6 = makeSymbolicTensor(3); - fusion->addInput(tv6); - - // Disjoint set op will cause a segmentation for just this op. 
- auto tmp_7 = set(tv6); - fusion->addOutput(tmp_7); - - auto tv7 = broadcast(tv5, {false, false, true}); - - auto tv8 = expand_as(tv7, tv6); - // [x, e_y, e_z] - - auto w_symbolic = IrBuilder::create(); - fusion->addInput(w_symbolic); - - auto tv9 = broadcast(tv8, {true, false, false, false}); - //[1, x, e_y, e_z] - - auto tv10 = expand( - tv9, - {w_symbolic, - tv9->axis(1)->extent(), - tv9->axis(2)->expandedExtent(), - tv9->axis(3)->expandedExtent()}); - - fusion->addOutput(tv10); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x}, options); - at::Tensor t3 = at::randn({x}, options); - at::Tensor t6 = at::randn({x, y, z}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3, t6, w}); - auto cg_out = cg_outputs[1]; - - TORCH_INTERNAL_ASSERT(cg_out.size(0) == w); - TORCH_INTERNAL_ASSERT(cg_out.size(1) == x); - TORCH_INTERNAL_ASSERT(cg_out.size(2) == y); - TORCH_INTERNAL_ASSERT(cg_out.size(3) == z); - TORCH_INTERNAL_ASSERT(cg_out.stride(0) == 0); - TORCH_INTERNAL_ASSERT(cg_out.stride(1) == 1); - TORCH_INTERNAL_ASSERT(cg_out.stride(2) == 0); - TORCH_INTERNAL_ASSERT(cg_out.stride(3) == 0); - - auto t10 = t0.unsqueeze(-1) - .expand({x, y}) - .add(t3.unsqueeze(-1)) - .unsqueeze(-1) - .expand_as(t6) - .unsqueeze(0) - .expand({w, x, y, z}); - - testValidate( - executor_cache.fusion(), - cg_outputs, - {t0, t3, t6, w}, - {t6, t10}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionExpandIssue1751_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto x = 3, y = 4, z = 5; - - // y, z - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tv1 = broadcast(tv0, {true, false, false}); - - // Two ways to propagate extents as is: use -1 or explicitly pass - // the extent vals. 
- - auto tv2 = expand( - tv1, - {IrBuilder::create(x), - IrBuilder::create(-1), - IrBuilder::create(-1)}); - - auto tv3 = expand( - tv1, - {IrBuilder::create(x), - tv0->axis(0)->extent(), - tv0->axis(1)->extent()}); - - fusion->addOutput(tv2); - fusion->addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({y, z}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - - for (const auto& cg_out : cg_outputs) { - TORCH_INTERNAL_ASSERT(cg_out.size(0) == x); - TORCH_INTERNAL_ASSERT(cg_out.size(1) == y); - TORCH_INTERNAL_ASSERT(cg_out.size(2) == z); - } - - auto t2 = t0.expand({x, y, z}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t2, t2}, __LINE__, __FILE__); -} - -// TODO: Make sure the kernel uses the expanded concrete size instead -// of the symbolic size -TEST_F(NVFuserTest, FusionExpandToConcrete_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto x = 3, y = 4; - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv1 = broadcast(tv0, {true, false}); - - auto tv2 = - expand(tv1, {IrBuilder::create(x), IrBuilder::create(y)}); - - fusion->addOutput(tv2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({y}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - - for (const auto& cg_out : cg_outputs) { - TORCH_INTERNAL_ASSERT(cg_out.size(0) == x); - TORCH_INTERNAL_ASSERT(cg_out.size(1) == y); - } - - auto t2 = t0.expand({x, y}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t2}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({4, 32, 16, 112, 112}, options).transpose(-1, -2); - at::Tensor t1 = at::randn({32, 1, 112, 1}, options).transpose(-1, -2); - - auto tv0 = TensorViewBuilder() - .ndims(5) - .contiguity({true, true, false, false, false}) // ttfff - .shape({-1, -1, -1, -1, -1}) - .dtype(DataType::Half) - .build(); - auto tv1 = TensorViewBuilder() - .ndims(4) - .contiguity({true, false, false, true}) // tfft - .shape({-1, 1, 1, -1}) - .dtype(DataType::Half) - .build(); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = add(tv0, tv1); - - fusion->addOutput(tv2); - - std::vector aten_inputs({t0, t1}); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto t2 = t0 + t1; - - testValidate( - executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); -} - } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index 154815927a55..87e665526001 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -114,7 +114,7 @@ void validateNoParallelBroadcastExist(kir::Kernel* kernel) { } // namespace -TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { +TEST_F(NVFuserTest, FusionReduceAndBroadcast1_CUDA) { const int nx = 999; const int tidx = 128; @@ -157,7 +157,7 @@ 
TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionGridAllreduce2_CUDA) { +TEST_F(NVFuserTest, FusionReduceAndBroadcast2_CUDA) { const int nx = 99; const int tidx = 32; @@ -208,7 +208,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce2_CUDA) { // Grid reduction with serial non-reduction axis. The global work // buffer is double buffered. -TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) { +TEST_F(NVFuserTest, FusionReduceAndBroadcast3_CUDA) { const int nx = 100; const int ny = 5000; const int tidx = 128; @@ -255,7 +255,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) { } // Indirect reduction and broadcast -TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) { +TEST_F(NVFuserTest, FusionReduceAndBroadcast4_CUDA) { const int nx = 999; const int tidx = 128; @@ -300,7 +300,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) { } // Unused block dimension in the kernel -TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) { +TEST_F(NVFuserTest, FusionReduceAndBroadcast5_CUDA) { const int nx = 999; const int tidx = 128; const int iter = 2; @@ -361,60 +361,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) { testValidate(&fusion, cg_outputs, {t0, t5}, {ref, t5}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector shape({99, 200}); - - const int vec = 4; - const int tidx = 32; - const int tidy = 8; - const int bdimx = ceilDiv(shape[1], vec * tidx); - const int bdimy = ceilDiv(shape[0], tidy); - - if (bdimx * bdimy > deviceSMCount()) { - GTEST_SKIP() << "Not enough SMs to run this test"; - } - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = sum(tv1, {0}); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv0, tv3); - fusion.addOutput(tv4); - - tv1->split(1, vec); - tv1->split(1, tidx); - tv1->split(0, tidy); - TransformPropagator::from(tv1); - - tv1->axis(0)->parallelize(ParallelType::BIDy); - tv1->axis(1)->parallelize(ParallelType::TIDy); - tv1->axis(2)->parallelize(ParallelType::BIDx); - tv1->axis(3)->parallelize(ParallelType::TIDx); - - scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion)); - - tv1->axis(4)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn(shape, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - auto t0_double = t0.to(at::kDouble); - auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); - - testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) { +TEST_F(NVFuserTest, FusionWelfordAndBroadcast1_CUDA) { const int nx = 999; const int tidx = 128; @@ -462,7 +409,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) { // Grid welford reduction with serial non-reduction axis. The global // work buffer is double buffered. 
-TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) { +TEST_F(NVFuserTest, FusionWelfordAndBroadcast2_CUDA) { const int nx = 100; const int ny = 5000; const int tidx = 128; @@ -1284,8 +1231,6 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) { reduction_axes.end(), std::back_inserter(at_reduction_axes)); - // MSVC bug on lambda non-capture of const integral type - // https://developercommunity.visualstudio.com/t/lambda-fails-to-implicitly-capture-constexpr-value/610504 auto at_bcast = [=](const auto& tensor) { if (channels_last) { tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0); @@ -1327,320 +1272,6 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) { fe.kernel(), outputs, aten_inputs, {at_grad_input}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {0}); - - auto tv3 = add(tv0, IrBuilder::create(2)); - auto tv4 = sum(tv3, {0}); - - auto tv5 = add(tv2, tv4); - fusion.addOutput(tv5); - - groupReductions({tv2, tv4}); - - auto tv0_cache = tv0->cacheAfter(); - - const int vec = 2; - const int tidx = 64; - const int tidy = 8; - - tv2->split(1, vec); - tv2->split(1, tidx); - - tv2->split(0, tidy); - TransformPropagator::from(tv2); - - tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); - - tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); - - tv2->axis(0)->parallelize(ParallelType::BIDy); - tv2->axis(1)->parallelize(ParallelType::TIDy); - tv2->axis(2)->parallelize(ParallelType::BIDx); - tv2->axis(3)->parallelize(ParallelType::TIDx); - - scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); - - std::vector shape({99, 999}); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - - auto t0 = at::randn(shape, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - auto t0_double = t0.to(at::kDouble); - auto ref = (t0_double + 1).sum({0}) + (t0_double + 2).sum({0}); - - testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Channels-last batch norm with vectorization. 
Relies on re-entrant -// GroupedGridReduction -TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const std::vector shape({64, 14, 14, 32}); - - auto tv0 = makeContigTensor(4, DataType::Half); - fusion.addInput(tv0); - auto tv1 = makeContigTensor(4, DataType::Half); - fusion.addInput(tv1); - auto tv2 = makeContigTensor(1); - fusion.addInput(tv2); - - std::vector reduction_axes({0, 1, 2}); - std::vector broadcast_mask({true, true, true, false}); - - auto tv3 = castOp(DataType::Float, tv0); - auto tv4 = castOp(DataType::Float, tv1); - - auto tv5 = sum(tv3, reduction_axes); - - auto tv6 = broadcast(tv2, broadcast_mask); - auto tv7 = sub(tv4, tv6); - auto tv8 = mul(tv3, tv7); - auto tv9 = sum(tv8, reduction_axes); - - auto tv10 = castOp(DataType::Half, tv5); - auto tv11 = castOp(DataType::Half, tv9); - - fusion.addOutput(tv10); - fusion.addOutput(tv11); - - groupReductions({tv5, tv9}); - - // Applies the outer-reduction schedule - const int64_t num_channels = shape.back(); - const int64_t vector = 2; - TORCH_CHECK(num_channels % vector == 0); - // Use at most 32 TIDx threads - const int64_t tidx = std::min(32l, num_channels / vector); - const auto bidx = ceilDiv(num_channels, tidx * vector); - - const int64_t tidy = 8; - const auto bidy = ceilDiv( - at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 4, bidx); - - auto tv0_cache = tv0->cacheAfter(); - auto tv1_cache = tv1->cacheAfter(); - - auto ref = tv5; - - // Move the reduction domains inner positions - ref->reorder({{0, 1}, {1, 2}, {2, 3}, {3, 0}}); - - // Parallelizing the reduction domains - ref->merge(2, 3); - ref->merge(1, 2); - ref->split(1, tidy); - ref->split(1, bidy, false); - - // Parallelizing the iteration domains - ref->split(0, vector); - ref->split(0, tidx); - - // Move the vector axis to the innermost position - ref->reorder({{2, 5}, {3, 2}, {4, 3}, {5, 4}}); - // Move the serial reduction to the right of the vector axis - ref->reorder({{3, 4}, {4, 3}}); - - TransformPropagator::from(ref); - - auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); - auto tv5_rf = rf_tvs.at(0); - auto tv9_rf = rf_tvs.at(1); - - tv0->computeAt(tv5_rf, -2, ComputeAtMode::BestEffort); - tv1->computeAt(tv9_rf, -2, ComputeAtMode::BestEffort); - tv3->computeAt(tv5_rf, -1, ComputeAtMode::BestEffort); - tv4->computeAt(tv9_rf, -1, ComputeAtMode::BestEffort); - - ref = tv5_rf; - - ref->axis(0)->parallelize(ParallelType::BIDx); - ref->axis(1)->parallelize(ParallelType::TIDx); - ref->axis(2)->parallelize(ParallelType::BIDy); - ref->axis(3)->parallelize(ParallelType::TIDy); - ref->axis(4)->parallelize(ParallelType::Serial); - ref->axis(5)->parallelize(ParallelType::Serial); - - scheduler_utils::parallelizeAllLike(ref, ir_utils::allTvs(&fusion)); - - tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); - tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto options_float = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options_half); - auto t1 = at::randn(shape, options_half); - auto t2 = at::randn({shape.back()}, options_float); - std::vector aten_inputs({t0, t1, t2}); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto t0_double = t0.to(at::kDouble); - auto t1_double = t1.to(at::kDouble); - auto t2_double = t2.to(at::kDouble); - - std::vector at_reduction_axes( - 
{reduction_axes.begin(), reduction_axes.end()}); - auto t5 = t0_double.sum(at_reduction_axes); - auto t8 = t0_double * - (t1_double - t2_double.unsqueeze(0).unsqueeze(0).unsqueeze(0)); - auto t9 = t8.sum(at_reduction_axes); - - testValidate(fe.kernel(), outputs, aten_inputs, {t5, t9}, __LINE__, __FILE__); -} - -// Test the grouped grid allreduce with BN-like outer reductions -TEST_F( - NVFuserTest, - FusionGroupedReductionPersistentChannelsLastBatchNormLike_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const std::vector shape({64, 14, 14, 32}); - - auto tv0 = makeContigTensor(4, DataType::Half); - fusion.addInput(tv0); - auto tv1 = makeContigTensor(4, DataType::Half); - fusion.addInput(tv1); - auto tv2 = makeContigTensor(1); - fusion.addInput(tv2); - - std::vector reduction_axes({0, 1, 2}); - std::vector broadcast_mask({true, true, true, false}); - - auto tv3 = castOp(DataType::Float, tv0); - auto tv4 = castOp(DataType::Float, tv1); - - auto tv5 = sum(tv3, reduction_axes); - - auto tv6 = broadcast(tv2, broadcast_mask); - auto tv7 = sub(tv4, tv6); - auto tv8 = mul(tv3, tv7); - auto tv9 = sum(tv8, reduction_axes); - - auto tv10 = broadcast(tv5, broadcast_mask); - auto tv11 = add(tv3, tv10); - - auto tv12 = broadcast(tv9, broadcast_mask); - auto tv13 = add(tv4, tv12); - - auto tv14 = castOp(DataType::Half, tv11); - auto tv15 = castOp(DataType::Half, tv13); - - fusion.addOutput(tv14); - fusion.addOutput(tv15); - - groupReductions({tv5, tv9}); - - // Applies the outer-reduction schedule - const int64_t num_channels = shape.back(); - const int64_t vector = 2; - TORCH_CHECK(num_channels % vector == 0); - // Use at most 32 TIDx threads - const int64_t tidx = std::min(32l, num_channels / vector); - const auto bidx = ceilDiv(num_channels, tidx * vector); - - const int64_t tidy = 8; - const int64_t reduction_work_per_thread = 8; - - auto tv0_cache = tv0->cacheAfter(); - auto tv1_cache = tv1->cacheAfter(); - - auto ref = tv5; - - // Move the reduction domains inner positions - ref->reorder({{0, 1}, {1, 2}, {2, 3}, {3, 0}}); - - // Parallelizing the reduction domains - ref->merge(2, 3); - ref->merge(1, 2); - ref->split(1, tidy); - ref->split(1, reduction_work_per_thread); - - // Parallelizing the iteration domains - ref->split(0, vector); - ref->split(0, tidx); - - // Move the vector axis to the innermost position - ref->reorder({{2, 5}, {3, 2}, {4, 3}, {5, 4}}); - // Move the serial reduction to the right of the vector axis - ref->reorder({{3, 4}, {4, 3}}); - - TransformPropagator::from(ref); - - auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); - auto tv5_rf = rf_tvs.at(0); - auto tv9_rf = rf_tvs.at(1); - - tv0->computeAt(tv5_rf, -2, ComputeAtMode::BestEffort); - tv1->computeAt(tv9_rf, -2, ComputeAtMode::BestEffort); - tv3->computeAt(tv5_rf, -1, ComputeAtMode::BestEffort); - tv4->computeAt(tv9_rf, -1, ComputeAtMode::BestEffort); - - ref = tv5_rf; - - ref->axis(0)->parallelize(ParallelType::BIDx); - ref->axis(1)->parallelize(ParallelType::TIDx); - ref->axis(2)->parallelize(ParallelType::BIDy); - ref->axis(3)->parallelize(ParallelType::TIDy); - ref->axis(4)->parallelize(ParallelType::Serial); - ref->axis(5)->parallelize(ParallelType::Serial); - - scheduler_utils::parallelizeAllLike(ref, ir_utils::allTvs(&fusion)); - - tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); - tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto options_float = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 
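Both channels-last batch-norm-like tests above validate against the same pair of grouped reductions over the N/H/W axes of an NHWC tensor. A self-contained ATen sketch of that reference math; the function name and signature are mine, and the formulas mirror t5/t8/t9 in the tests.

#include <ATen/ATen.h>
#include <utility>
#include <vector>

// x, y: [N, H, W, C] activations; mean: [C]. Returns the two horizontally
// grouped reductions: sum(x) and sum(x * (y - mean)) over {N, H, W}.
std::pair<at::Tensor, at::Tensor> grouped_bn_reductions(
    const at::Tensor& x,
    const at::Tensor& y,
    const at::Tensor& mean) {
  std::vector<int64_t> rdims = {0, 1, 2};
  auto sum_x = x.sum(rdims);                                        // t5
  auto centered = y - mean.unsqueeze(0).unsqueeze(0).unsqueeze(0);  // t7
  auto sum_x_centered = (x * centered).sum(rdims);                  // t9
  return {sum_x, sum_x_centered};
}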
0); - auto t0 = at::randn(shape, options_half); - auto t1 = at::randn(shape, options_half); - auto t2 = at::randn({shape.back()}, options_float); - std::vector aten_inputs({t0, t1, t2}); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto t0_double = t0.to(at::kDouble); - auto t1_double = t1.to(at::kDouble); - auto t2_double = t2.to(at::kDouble); - - std::vector at_reduction_axes( - {reduction_axes.begin(), reduction_axes.end()}); - auto t5 = t0_double.sum(at_reduction_axes); - auto t8 = t0_double * - (t1_double - t2_double.unsqueeze(0).unsqueeze(0).unsqueeze(0)); - auto t9 = t8.sum(at_reduction_axes); - - auto t10 = t5.unsqueeze(0).unsqueeze(0).unsqueeze(0); - auto t11 = t0_double + t10; - auto t12 = t9.unsqueeze(0).unsqueeze(0).unsqueeze(0); - auto t13 = t1_double + t12; - - testValidate( - fe.kernel(), outputs, aten_inputs, {t11, t13}, __LINE__, __FILE__); -} - } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp index d4ebe51c893e..6e9225d29237 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -3722,7 +3721,7 @@ TEST_F(NVFuserTest, FusionIm2Col_CUDA) { auto inp_tile = gather(inp, {1, 1, 3, 3}, {{0, 0}, {0, 0}, {1, 1}, {1, 1}}); // inp_tile: [N, C, H, W, 1, 1, 3, 3] - auto inp_col = permute(inp_tile, {0, 2, 3, 1, 4, 5, 6, 7}); + auto inp_col = transpose(inp_tile, {{1, 3}, {2, 1}, {3, 2}}); // inp_col: [N, H, W, C, 1, 1, 3, 3] fusion.addOutput(inp_col); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h index 0247c33c8a72..6708248bf730 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h @@ -5,8 +5,6 @@ #include #include -#include - #include namespace torch { @@ -454,17 +452,6 @@ inline void testValidate( } } -inline void clearL2Cache() { - torch::NoGradGuard no_grad; - auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize; - auto options = - torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0); - - auto l2_elems = l2_cache_size / 4; - torch::Tensor t0 = torch::empty(l2_elems, options); - torch::Tensor t1 = torch::clone(t0); -}; - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index d71efd00ed0f..6ab0df7b47cb 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -51,18 +51,28 @@ class ReplaySelf : public ReplayTransformations { // Manually replay the split, following the output of the operations. // This is so rfactor ops are replayed correctly. - IterDomain* ido = - IterDomainBuilder(s->outer()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? remainder->as() : s->factor()) - .build(); + IterDomain* ido = IrBuilder::create( + s->container(), + s->container()->zeroVal(), + s->innerSplit() ? 
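The FusionIm2Col_CUDA hunk above rewrites permute(inp_tile, {0, 2, 3, 1, 4, 5, 6, 7}) as a transpose whose axis map reads {1 -> 3, 2 -> 1, 3 -> 2}, i.e. old axis to new position, consistent with the [N, C, H, W, ...] -> [N, H, W, C, ...] comment in that test. A quick ATen-only check (mine, not part of the patch) that the two notations describe the same reordering:

#include <ATen/ATen.h>
#include <numeric>
#include <vector>

int main() {
  auto t = at::randn({2, 3, 4, 5, 1, 1, 3, 3});
  // permute order: new dimension i is old dimension perm[i].
  auto a = t.permute({0, 2, 3, 1, 4, 5, 6, 7});
  // Same reordering written as an old -> new map, then converted back into a
  // permute order: perm[new_position] = old_axis.
  std::vector<int64_t> perm(t.dim());
  std::iota(perm.begin(), perm.end(), 0);
  perm[3] = 1;  // old axis 1 ends up at new position 3
  perm[1] = 2;  // old axis 2 ends up at new position 1
  perm[2] = 3;  // old axis 3 ends up at new position 2
  auto b = t.permute(perm);
  TORCH_CHECK(a.equal(b));
  return 0;
}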
remainder->as() : s->factor(), + s->outer()->getParallelType(), + s->outer()->getIterType(), + s->outer()->isRFactorProduct(), + s->outer()->hasPaddingToMultipleOfWarp(), + s->outer()->getMaybeSizeAfterPadding(), + s->outer()->isMmaSwizzled()); // inner IterDomain - IterDomain* idi = - IterDomainBuilder(s->inner()) - .start(s->container()->zeroVal()) - .extent(s->innerSplit() ? s->factor() : remainder->as()) - .build(); + IterDomain* idi = IrBuilder::create( + s->container(), + s->container()->zeroVal(), + s->innerSplit() ? s->factor() : remainder->as(), + s->inner()->getParallelType(), + s->inner()->getIterType(), + s->inner()->isRFactorProduct(), + s->outer()->hasPaddingToMultipleOfWarp(), + s->outer()->getMaybeSizeAfterPadding(), + s->outer()->isMmaSwizzled()); // Generate the split node IrBuilder::create( @@ -113,10 +123,16 @@ class ReplaySelf : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - IterDomain* merged_id = IterDomainBuilder(m->out()) - .start(m->container()->zeroVal()) - .extent(merged_id_size->as()) - .build(); + IterDomain* merged_id = IrBuilder::create( + m->container(), + m->container()->zeroVal(), + merged_id_size->as(), + m->out()->getParallelType(), + m->outer()->getIterType(), + m->out()->isRFactorProduct(), + m->out()->hasPaddingToMultipleOfWarp(), + m->out()->getMaybeSizeAfterPadding(), + m->out()->isMmaSwizzled()); IrBuilder::create( m->container(), merged_id, id_outer_mapped, id_inner_mapped); diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index dc5973c0ecd6..51478ae909c2 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -117,26 +117,24 @@ class ReplayRFactor : public ReplayTransformations { // Manually replay the split, making reduction = false and rfactor = true // outer IterDomain - IterDomain* ido = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? remainder->as() : s->factor()) - .iter_type( - rfactor_axes_.count(s->outer()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_rfactor_outputs) - .build(); + IterDomain* ido = IrBuilder::create( + s->container(), + IrBuilder::create(s->container(), 0), + s->innerSplit() ? remainder->as() : s->factor(), + ParallelType::Serial, + rfactor_axes_.count(s->outer()) ? IterType::Reduction + : IterType::Iteration, + static_rfactor_outputs); // inner IterDomain - IterDomain* idi = - IterDomainBuilder( - s->container()->zeroVal(), - s->innerSplit() ? s->factor() : remainder->as()) - .iter_type( - rfactor_axes_.count(s->inner()) ? IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_rfactor_outputs) - .build(); + IterDomain* idi = IrBuilder::create( + s->container(), + IrBuilder::create(s->container(), 0), + s->innerSplit() ? s->factor() : remainder->as(), + ParallelType::Serial, + rfactor_axes_.count(s->inner()) ? IterType::Reduction + : IterType::Iteration, + static_rfactor_outputs); // Generate the split node IrBuilder::create( @@ -181,13 +179,14 @@ class ReplayRFactor : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - IterDomain* merged_id = - IterDomainBuilder(m->container()->zeroVal(), merged_id_size->as()) - .iter_type( - rfactor_axes_.count(m->out()) ? 
IterType::Reduction - : IterType::Iteration) - .is_rfactor_domain(static_rfactor_ids_.count(m->out())) - .build(); + IterDomain* merged_id = IrBuilder::create( + m->container(), + IrBuilder::create(m->container(), 0), + merged_id_size->as(), + ParallelType::Serial, + rfactor_axes_.count(m->out()) ? IterType::Reduction + : IterType::Iteration, + static_rfactor_ids_.count(m->out())); IrBuilder::create( m->container(), merged_id, id_outer_mapped, id_inner_mapped); @@ -330,17 +329,25 @@ std::pair TransformRFactor::runReplay( auto id = original_td_root[i]; // If this is an rfactor root, it will be a reduction in this stage if (rfactor_root_axes.find(id) != rfactor_root_axes.end()) { - new_producer_root[i] = IterDomainBuilder(id->start(), id->extent()) - .stop_offset(id->stopOffset()) - .iter_type(IterType::Reduction) - .is_rfactor_domain(true) - .build(); + new_producer_root[i] = IrBuilder::create( + id->container(), + id->start(), + id->extent(), + id->stopOffset(), + ParallelType::Serial, + IterType::Reduction, + true); // If this is not an rfactor root, but a reduction root, it should be // turned into an iteration domain } else if (id->isReduction()) { - new_producer_root[i] = IterDomainBuilder(id->start(), id->extent()) - .stop_offset(id->stopOffset()) - .build(); + new_producer_root[i] = IrBuilder::create( + id->container(), + id->start(), + id->extent(), + id->stopOffset(), + ParallelType::Serial, + IterType::Iteration, + false); } else { new_producer_root[i] = id->cloneWithoutRFactor(); } @@ -432,11 +439,13 @@ std::pair TransformRFactor::runReplay( p2o_it != producer_to_original_map.end(), "Missing mapping from original tensor domain to producer tensor domain."); auto original_id = p2o_it->second; - auto new_consumer_root = - IterDomainBuilder(original_id->start(), original_id->extent()) - .stop_offset(original_id->stopOffset()) - .iter_type(original_id->getIterType()) - .build(); + auto new_consumer_root = IrBuilder::create( + original_id->container(), + original_id->start(), + original_id->extent(), + original_id->stopOffset(), + ParallelType::Serial, + original_id->getIterType()); new_consumer_root_domain.push_back(new_consumer_root); original_to_consumer_root_map[original_id] = new_consumer_root; } diff --git a/torch/csrc/jit/codegen/cuda/transform_view.cpp b/torch/csrc/jit/codegen/cuda/transform_view.cpp index e6c7f381a1a8..290875c9c6d9 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_view.cpp @@ -133,10 +133,12 @@ class MergeTransform final : public ViewTransform { auto merged_extent = mul(merged_id->extent(), new_root_domain[index_ + 1]->extent()); - auto new_merged_id = - IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), merged_extent) - .is_rfactor_domain(true) - .build(); + auto new_merged_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + merged_extent, + ParallelType::Serial, + IterType::Iteration, + true); IrBuilder::create( new_merged_id, merged_id, new_root_domain[index_ + 1]); @@ -192,19 +194,20 @@ class SplitTransform final : public ViewTransform { Val* remainder = ceilDiv(id->extent(), factor); // outer loop IterDomain - IterDomain* factor_id = - IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), factor) - .parallel_type(id->getParallelType()) - .iter_type(id->getIterType()) - .is_rfactor_domain(true) - .build(); + IterDomain* factor_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + factor, + id->getParallelType(), + id->getIterType(), + true); // inner loop 
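The transform_replay.cpp and transform_rfactor.cpp hunks above swap the IterDomainBuilder form for direct IrBuilder::create calls, which is why every IterDomain property now has to be spelled out positionally. The toy below uses hypothetical Domain/DomainBuilder types, not nvfuser's real classes, purely to illustrate the difference between the two construction styles.

#include <string>

// Toy stand-ins for IterDomain / IterDomainBuilder; field names are illustrative.
struct Domain {
  int start = 0;
  int extent = 1;
  std::string parallel = "Serial";
  bool rfactor = false;
};

class DomainBuilder {
 public:
  explicit DomainBuilder(const Domain& proto) : d_(proto) {}  // copies every field
  DomainBuilder& extent(int e) {
    d_.extent = e;
    return *this;
  }
  Domain build() const {
    return d_;
  }

 private:
  Domain d_;
};

// Builder style: properties not mentioned are inherited from the prototype.
Domain replay_with_builder(const Domain& proto, int new_extent) {
  return DomainBuilder(proto).extent(new_extent).build();
}

// Positional style (what the hunks above use): each property is forwarded
// explicitly, in constructor-argument order.
Domain replay_positional(const Domain& proto, int new_extent) {
  return Domain{0, new_extent, proto.parallel, proto.rfactor};
}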
IterDomain - IterDomain* remainder_id = - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), remainder->as()) - .is_rfactor_domain(true) - .build(); + IterDomain* remainder_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + remainder->as(), + ParallelType::Serial, + IterType::Iteration, + true); IrBuilder::create(factor_id, remainder_id, id, factor, false); diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 5c06287f90cb..c05a630871eb 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -310,8 +310,6 @@ static const char* expr_type2string(ExprType t) { return "MmaOp"; case ExprType::TransposeOp: return "TransposeOp"; - case ExprType::ExpandOp: - return "ExpandOp"; case ExprType::ShiftOp: return "ShiftOp"; case ExprType::GatherOp: @@ -719,7 +717,9 @@ static const char* iter_type2string(IterType t) { return "i"; case IterType::Reduction: return "r"; - case IterType::Broadcast: + case IterType::BroadcastWithStride: + return "sb"; + case IterType::BroadcastWithoutStride: return "b"; case IterType::Gather: return "g"; diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 38f22c308ce2..8a72f48cd9eb 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -111,7 +111,6 @@ enum class ExprType { WelfordOp, MmaOp, TransposeOp, - ExpandOp, ShiftOp, GatherOp, ViewOp, @@ -287,7 +286,8 @@ enum class MemoryType { Local, Shared, Global }; enum class IterType { Iteration, Reduction, - Broadcast, + BroadcastWithStride, + BroadcastWithoutStride, Gather, Stride, VectorComponent
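With IterType::Broadcast split back into BroadcastWithStride and BroadcastWithoutStride (printed as "sb" and "b" by iter_type2string above), call sites that only care about broadcast-ness have to test both enumerators. A hedged sketch of such a helper: the predicate itself is hypothetical, only the header path and enum names come from the hunks.

#include <torch/csrc/jit/codegen/cuda/type.h>

using namespace torch::jit::fuser::cuda;

// Hypothetical convenience predicate, not part of the patch.
inline bool isBroadcastIterType(IterType t) {
  return t == IterType::BroadcastWithStride ||
      t == IterType::BroadcastWithoutStride;
}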