From 8c18701e252f279292f9ae1c926b7e919b3ebd17 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 14 May 2024 01:03:59 -0700 Subject: [PATCH] Allocation order refactor (#2168) refactored allocation order inference pass: * Instead of per operation propagation rule, we are now using IdModel mapping to map allocation domain of reference tensor to rfactor domain of target tensor. * Updated the inference API to allow specified sources and destinations for the propagation. ``` void inferenceAllocationOrder( Fusion* fusion, const std::vector& srcs, const std::vector& dsts); ``` * The propagation tried to keep the memory format of `dsts` closer to the `srcs` to simplify scheduling as well as facilitate vectorization. It works roughly as: * For each entry `dst`, among all its producers in `srcs`, we'll find the one with the most loop iter domain in its allocation domain as the reference `ref` * We try to map each iter domain in `dst`'s rfactor domain to `ref`'s allocation order domain and push those as the inner dimension in `dst`'s new allocation domain, while pushing unmapped iter domains as outer dimensions. * I have to put in a WAR for the mapping logic for now, since reduction scheduler is struggling with permuted output. See issue #2202. The WAR is simply to preserve the existing position of reduction iter domain in rfactor the same as it would be in its new allocation domain. This WAR is supposed to be removed at a later point once we fixed reduction scheduler. I kept both code path in the PR for easier future cleanup. --------- Co-authored-by: Naoya Maruyama Co-authored-by: Jingyue Wu --- .../allocation_order_inference.cpp | 649 +++++++----------- .../allocation_order_inference.h | 21 +- tests/cpp/test_allocation_order_inference.cpp | 162 ++--- tests/cpp/test_gather.cpp | 3 + tests/cpp/test_gpu_transpose.cpp | 10 +- 5 files changed, 328 insertions(+), 517 deletions(-) diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp index f6a2eb99792..df4f5415368 100644 --- a/csrc/preseg_passes/allocation_order_inference.cpp +++ b/csrc/preseg_passes/allocation_order_inference.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include #include @@ -15,442 +16,292 @@ namespace nvfuser::preseg_passes { namespace { -// performs permutation by `alloc_order` on `tv`'s rfactor_domain. -std::vector constructAllocationDomain( - TensorView* tv, - const AllocationOrder& alloc_order) { - auto rfactor_dom = tv->getMaybeRFactorDomain(); - auto rank = rfactor_dom.size(); - - std::vector allocation_domain(rank, nullptr); - // specify allocation domain with dimension per allocation order. - for (auto i : c10::irange(rank)) { - allocation_domain[i] = rfactor_dom.at(alloc_order.at(i)); - } - - return allocation_domain; +// counting the number of non-broadcast & non-reduction iter domains in tv's +// allocation domain. +int64_t countNonTrivialIterDomains(const TensorView* tv) { + return std::count_if( + tv->getMaybeAllocationDomain().begin(), + tv->getMaybeAllocationDomain().end(), + [&](auto* ptr_id) { + return !ptr_id->isBroadcast() && !ptr_id->isReduction(); + }); } -// NOTE: [Allocation Order Inference] +// Note [ Allocation Order Mapping ] // -// AllocationOrderInferencer ctor takes a map of allocation order for inputs as -// `unordered_map`. It propagates -// AllocationOrder on a fusion and updates the the map with allocation order for -// other TensorView in the fusion. +// Map allocation domain from ref to target's rfactor domain to construct a new +// allocation domain for target. The objective is to have target in a similar +// memory format as with ref. // -// e.g. -// std::unordered_map alloc_order_map; -// // ... update alloc_order_map with AllocationOrder for tensors -// // (i.e. usually inputs) +// The propagation rule explained in an example, given inputs: +// ref's allocation domain +// {iS0[i0], ir1[i1], iS2[i2]} +// target's rfactor domain +// {iS3[i3], iS4[i4], ir5[i1], iS6[i5], iS7[i2], ir8[1]} // -// // create AllocationOrderInferencer -// AllocationOrderInferencer infer(alloc_order_map); -// // propagates AllocationOrder from entries already in alloc_order_map -// infer.traverse(fusion); -// // all tensor that's propagated successfully will have their allocation -// // order in alloc_order_map +// 1. we project iter domains from targets' rfactor domain which has an exact +// map to ref's allocation domain. (sharp-edge 0: we exclude mapping from +// iteration id on ref to reduction id on target to avoid unnecessary +// re-ordering which exposes #2202). +// mapped_ids {ir5[i1], iS7[i2]} +// 2. remove all projected ids and reduction iter domains from target's rfactor +// domain: +// unmapped_ids {iS3[i3], iS4[i4], iS6[i5]} +// 3. iterating through unmodified target's rfactor domain to construct target +// allocation domain: +// (sharp-edge 1: if target_rfactor_domain[i] is a reduction and is not +// mapped, we keep the reduction iter domain in the original position.) Push +// the front of unmapped_id_vec to the end of target allocation domain, if +// unmapped_id_vec isn't empty yet; Otherwise, push the frnot of mapped_ids at +// the end of target allocation domain. // -// The protocol for AllocationOrder in alloc_order_map_ has three states. For -// each `tv`, its corresponding allocation order `alloc_order_map_[tv]`: -// 1. The allocation order has the same size as the `tv`'s rfactor domain; -// This means it has a preferred allocation order and the entry should -// participate in propagation. -// 2. The allocation order is an empty array; -// This means it's a wild card and shouldn't dictate output allocation -// order. But it marks that propagation is successful for `tv`. -// i.e. This currently happens for TensorViews that's created by factory -// methods and its consumers. -// 3. alloc_order_map_ does not have an entry for `tv`. -// This is the case where propagation has not reach the `tv`, likely due to -// lack of allocation order on inputs or certain operation not yet supported -// by propagation rule. -// -// Identify the difference between case 2. and 3. above allows us to better -// handle `resolveAllocationOrder` among multiple candidates. -// i. We do not want to ignore candidates where propagation has failed and -// aggressively propagates allocatoin order through unresolved candidates. So we -// would want to identify case 3. ii. Tensors created by factory methods should -// carry a wild-card and should not actively participate propagation. Because -// those tensors are not going to affect vectorization. Hence we need to -// identify case 2. -class AllocationOrderInferencer : public IterVisitor { - public: - // Note: alloc_order_map_ is a reference to the ground truth of - // alloc_order_map. The pass here tries to propagate the allocation order from - // the ground truth. - AllocationOrderInferencer( - std::unordered_map& alloc_order_map) - : alloc_order_map_(alloc_order_map) {} - - protected: - using IterVisitor::handle; - - void handle(FullOp*) override; - void handle(UnaryOp*) override; - void handle(BroadcastOp*) override; - void handle(BinaryOp*) override; - void handle(TernaryOp*) override; - void handle(PadOp*) override; - void handle(ReductionOp*) override; - // TODO: Add more propagation rules - // void handle(LoadStoreOp*) override; - // void handle(SqueezeOp*) override; - // void handle(ExpandOp*) override; - - private: - // mapping allocation domain from producer to consumer without reduction - // - // e.g. - // producer rfactor dom [r0', i0', i1', i2'] @ allocation order {0, 1, 3, 2} - // | alloc dom [r0', i0', i2', i1'] - // | - // Operation - // | - // v - // consumer rfactor dom [..., i0, ..., i1, ..., i2, ...] - // - // we construct allocation domain on producer, filtering out reduction, apply - // root domain map from producer to consumer. - // [r0', i0', i2', i1'] -> [i0', i2', i1'] -> [i0, i2, i1] - // so the function would return [i0, i2, i1] - std::vector propagateAllocationDomain( - TensorView* producer, - TensorView* consumer) { - // constructing alloc_domain for producer from its root domain, while - // filtering out reduction because they won't appear in consumer's domain. - std::vector alloc_domain = TensorDomain::noReductions( - constructAllocationDomain(producer, alloc_order_map_.at(producer))); - // creating producer to consumer root domain map - std::unordered_map p2c_map = - PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer(); - // map alloc_domain to consumer - std::transform( - alloc_domain.cbegin(), - alloc_domain.cend(), - alloc_domain.begin(), - [&p2c_map](IterDomain* id) { return p2c_map.at(id); }); - return alloc_domain; +// Note: we could be using a simplified logic below, +// See issue https://github.com/NVIDIA/Fuser/issues/2202 +// 1. we project iter domains from targets' rfactor domain which has an exact +// map to ref's allocation domain. +// mapped_ids {ir5[i1], iS7[i2]} +// 2. remove all projected iter domains from target's rfactor +// domain: +// unmapped_ids {iS3[i3], iS4[i4], iS6[i5], ir8[1]} +// 3. append mapped_ids at the end of unmapped_id_vec. +// target_alloc_domain +// {iS3[i3], iS4[i4], iS6[i5], ir8[1], ir5[i1], iS7[i2]} +void mapAllocationDomain( + const IdModel& id_model, + const TensorView* ref, + TensorView* target) { + const ValGraph& val_graph = id_model.idGraph(IdMappingMode::EXACT); + + std::vector ref_alloc_domain = ref->getMaybeAllocationDomain(); + const std::vector& target_rfactor_domain = + target->getMaybeRFactorDomain(); + + // map target rfactor domain into ref's allocation domain + nvfuser::VectorOfUniqueEntries mapped_ids; + + std::unordered_map vg_id_map; + for (auto* id : target_rfactor_domain) { + if (val_graph.hasGroup(id)) { + vg_id_map[val_graph.toGroup(id)] = id; + } } - // Propagate allocation order from producer to consumer via: - // 1. Constructs producer allocation_domain with its allocation order; - // 2. Mapping it to consumer's root domain to create alloc_domain; - // 3. Compute allocation order of consumer as the permutation between - // alloc_domain and `permutation_ref`. - // - // Returns true when producer has a recorded allocation order, false - // otherwise. This function assumes that all root domain in consumer can be - // mapped to producer. - bool propagateAllocationOrder( - TensorView* producer, - TensorView* consumer, - const std::vector& permutation_ref) { - auto iter = alloc_order_map_.find(producer); - // early return is producer doesn't have an entry in alloc_order_map_ - if (iter == alloc_order_map_.end()) { - return false; + // logic to preserve reduction iter domain in target to WAR #2202 +#if true + // mapping id between ref's allocation domain to target's rfactor domain + for (auto* ref_id : ref_alloc_domain) { + // skip when no ValGroup for ref_id to map. + if (!val_graph.hasGroup(ref_id)) { + continue; } - - // early termination to propagate empty allocation order - if (iter->second.empty()) { - alloc_order_map_[consumer] = {}; - return true; + const ValGroup& vg = val_graph.toGroup(ref_id); + // skip when no mapping ValGroup found in target_rfactor_domain. + if (vg_id_map.count(vg) == 0) { + continue; } - - std::vector alloc_domain = - propagateAllocationDomain(producer, consumer); - // compute allocation order - std::optional permutation = - ir_utils::computePermutation(permutation_ref, alloc_domain); - - NVF_ERROR( - permutation.has_value(), - "allocation order propagation from ", - producer->toString(0), - " to ", - consumer->toString(0), - " failed!"); - alloc_order_map_[consumer] = permutation.value(); - return true; - } - - // Propagate allocation order from producer to consumer's rfactor_domain - bool propagateAllocationOrder(TensorView* producer, TensorView* consumer) { - return propagateAllocationOrder( - producer, consumer, consumer->getMaybeRFactorDomain()); + IterDomain* id = vg_id_map[vg]; + // sharp-edges 0 + // avoid mapping a reduced dimension. + if (!ref_id->isReduction() && id->isReduction()) { + continue; + } + mapped_ids.pushBack(id); } - // Returns the candidate operand that dominates the allocation order. - // - // It scans through each candidate to find the first one that: - // 1. is a TensorView - // 2. has the most non_broadcast IterDomains - // - // The function returns a nullptr when it encounters a TensorView that does - // not have an entry in alloc_order_map_, since this means we failed to - // propagate memory format for an entry, we do NOT want to aggressively insert - // output memory format. - // - // The function is used to resolve allocation order propagation for operator - // with multiple operands. The operand with the most number of - // non-broadcast IterDomain will be dominating the output allocation order. - // The motivation behind it to avoid breaking allocation order propagation - // from operands produced by broadcast. e.g. When a binary operator could take - // in a channels_last 4d tensor and an unsqueezed bias vector. We'll want to - // propagate the channels_last allocation order to output. - // - // Pre-condition: `candidates` must be the input operands of the same Expr. - TensorView* resolveAllocationOrder(const std::vector& candidates); - - // alloc_order_map_ records the allocation order of each TensorView. - // Since it only handles permutation from a rfactor domain to allocation - // domain, it can be interpreted as: - // - // e.g. TV0 rfactor domain [i0, i1, i2] - // alloc domain [i0, i2, i1] - // allocation order 0, 2, 1 - std::unordered_map& alloc_order_map_; -}; - -TensorView* AllocationOrderInferencer::resolveAllocationOrder( - const std::vector& candidates) { - TensorView* src = nullptr; - size_t non_bc_high_water_mark = 0; - - // helper utils to count the number of non broadcast / non reduction - // iterdomain - auto countLoopIterDomains = [](const TensorView* tv) -> size_t { - return std::count_if( - tv->getMaybeRFactorDomain().begin(), - tv->getMaybeRFactorDomain().end(), - [&](auto ptr_id) { - return !ptr_id->isBroadcast() && !ptr_id->isReduction(); - }); - }; - - for (auto* val : candidates) { - auto* tv = dynamic_cast(val); - // skip non TensorView entry - if (tv == nullptr) { + // removing mapped ids and reduction ids to create unmapped_ids. + // This means for the rest of ids in target_rfactor_domain that's not in + // mapped_ids, they are either 1. a reduction domain, or; 2. in + // [unmapped_ids.begin(), unmapped_ids_vec_end) This ensures that sharp-edges + // 1's loop would reconstruct a permutation of the target_rfactor_domain, + // hence a valid allocation domain for target. + std::vector unmapped_ids = target_rfactor_domain; + auto unmapped_ids_vec_end = std::remove_if( + unmapped_ids.begin(), unmapped_ids.end(), [&mapped_ids](IterDomain* it) { + return mapped_ids.has(it) || it->isReduction(); + }); + + auto mapped_id_iter = mapped_ids.begin(); + auto unmapped_id_iter = unmapped_ids.begin(); + // initialize new target allocation domain with nullptr + std::vector target_alloc_domain( + target_rfactor_domain.size(), nullptr); + for (auto i : c10::irange(target_rfactor_domain.size())) { + // sharp-edges 1 + // preserves non-mapped reduction id in its original position + if (target_rfactor_domain[i]->isReduction() && + !mapped_ids.has(target_rfactor_domain[i])) { + target_alloc_domain[i] = target_rfactor_domain[i]; continue; } - - auto iter = alloc_order_map_.find(tv); - // stopping propagation when we encounter an entry that does not have an - // allocation order. See NOTE: [Allocation Order Inference] - if (iter == alloc_order_map_.end()) { - return nullptr; + // push unmapped ids to outer dimension until it's fully consumed + if (unmapped_id_iter != unmapped_ids_vec_end) { + target_alloc_domain[i] = *unmapped_id_iter++; + } else { + // push mapped ids to inner dimension + target_alloc_domain[i] = *mapped_id_iter++; } - - // skip entry that has an empty allocation order - if (iter->second.empty()) { - // We still want to ensure that we propagate empty allocation order if - // there's no candidate with a non-empty allocation order - if (src == nullptr) { - src = tv; - } - - // skip if unspecified + } +#else + // mapping id between ref's allocation domain to target's rfactor domain + for (auto* ref_id : ref_alloc_domain) { + // skip when no ValGroup for ref_id to map. + if (!val_graph.hasGroup(ref_id)) { continue; } - - // check if current entry sets new record for num of non broadcast / non - // reduction iterdomain - if (size_t non_bc_count = countLoopIterDomains(tv); - non_bc_count > non_bc_high_water_mark) { - non_bc_high_water_mark = non_bc_count; - src = tv; + const ValGroup& vg = val_graph.toGroup(ref_id); + // skip when no mapping ValGroup found in target_rfactor_domain. + if (vg_id_map.count(vg) == 0) { + continue; } + IterDomain* id = vg_id_map[vg]; + mapped_ids.pushBack(id); } - - return src; -} - -// FullOp set empty allocation order to output -void AllocationOrderInferencer::handle(FullOp* op) { - auto* out = static_cast(op->output(0)); - alloc_order_map_[out] = {}; -} - -// UnaryOp propagation forward allocation order from input to output -void AllocationOrderInferencer::handle(UnaryOp* op) { - auto* out = dynamic_cast(op->out()); - if (out == nullptr) { - return; + std::vector target_alloc_domain = target_rfactor_domain; + // removing mapped ids. + auto unmapped_ids_vec_end = std::remove_if( + target_alloc_domain.begin(), + target_alloc_domain.end(), + [&mapped_ids](IterDomain* it) { return mapped_ids.has(it); }); + // appending mapped ids at the end of target_alloc_domain. + std::copy(mapped_ids.begin(), mapped_ids.end(), unmapped_ids_vec_end); +#endif + + // skip trivial allocation domain + if (target_alloc_domain != target_rfactor_domain) { + target->setAllocationDomain(target_alloc_domain, true); } - auto* in = op->in()->as(); - propagateAllocationOrder(in, out); } -// BroadcastOp propagation: -// 1. preserves all allocation order of input iterdomain; -// 2. stacks all added broadcast iter domain on outputs as outer dimensions in -// their natural position -// -// e.g. -// TV0 rfactor dom [i0', i1', i2'] @ allocation order {0, 2, 1} -// | alloc dom [i0', i2', i1'] -// | -// | -// BroadcastOp -// | -// v -// TV1 rfactor dom [i0, b3, i1, i2, b4] -// -// step 0: -// scan through all iterdomain in output TV1's rfactor domain -// insert all broadcast domain to alloc_domain[b3, b4]; +} // namespace + +// Note [ Allocation Order Propagation ] // -// step 1: -// computing iterdomain mapping from input to output; -// [i0', i2', i1'] -> [i0, i2, i1] +// The propagation tries to populate allocation domain from srcs to dsts. // -// step 2: -// follow allocation order on input, insert the mapped iter domain on -// output to alloc_domain[b3, b4, i0, i2, i1]; +// For each TensorView in dsts, it iterate through all TensorView in srcs +// looking for a reference TensorView to propagate its allocation domain. +// 1. It only propagate to TensorView in dsts when it's safe to manipulate its +// allocation domain: +// 1.1 It doesn't have an allocation domain set; +// 1.2 It is not an aliase to another TensorView; +// 1.3 It does not have self mapping; +// 2. Among all entries in srcs, we pick reference that: +// 2.1 It has a dependency towards dst; +// 2.2 It has the highest no. of non-trivial (non-broadcast/non-reduction) +// iter domains in allocation domain. +// Note0: The reason to count behind this is that, we could have binary +// operation on a full-sized tensor with a broadcast vector tensor. In +// which case, we would want to propagate the layout of the full-sized +// tensor to the output, even though both candidates have the same rank. +// Note1: when we have multiple candidates with the same count of +// non-trivial iter domains, we require there's no ambiguity by +// checking both candidates having the same iter domain mapping. +// Otherwise we'll stop the propagation by leaving ref as nullptr. +// 2.3 It does not have self mapping; +// 3. Propagate memory format from selected reference in `srcs` to its +// corresponding target in `dsts`. // -// step 3: -// compute permutation from alloc_domain to TV1's rfactor domain; -// so output TV1 will have allocation order {1, 4, 0, 3, 2} -void AllocationOrderInferencer::handle(BroadcastOp* op) { - auto* out = dynamic_cast(op->out()); - if (out == nullptr) { - return; - } - auto* in = op->in()->as(); - - auto iter = alloc_order_map_.find(in); - // early return when there's no recorded allocation order for `in` - if (iter == alloc_order_map_.end()) { - return; - } - - // propagate empty allocation order; - if (iter->second.empty()) { - alloc_order_map_[out] = {}; - return; - } - - size_t out_rank = out->nDims(); - std::vector alloc_domain; - alloc_domain.reserve(out_rank); - - // step 0: insert all broadcast iterdomain in output - for (auto i : c10::irange(out_rank)) { - if (op->isBroadcastDim(i)) { - alloc_domain.push_back(out->getMaybeRFactorDomain()[i]); +// propagation rule: +// Given a reference TensorView `ref` and a target TensorView `target`, we try +// to map iter domain in `ref->getMaybeAllocationDomain()` to +// `target->getMaybeRFactorDomain()`, which would gives `target` to a similar +// memory layout as `ref`. For details on the propagation rule see Note [ +// Allocation Order Mapping ] +void inferenceAllocationOrder( + Fusion* fusion, + const std::vector& srcs, + const std::vector& dsts) { + // build IdModel, setting allow_self_mapping to avoid assert + // even though we do NOT populate allocation order where self_mapping is + // present + auto id_model = + IdModel(fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const DisjointSets& val_sets = exact_graph.disjointValSets(); + + // populate the number of non-trivial iter domains on srcs + std::unordered_map non_trivial_iter_count; + for (auto* tv : srcs) { + // skip entry with self mapping. + if (!hasSelfMapping(tv, exact_graph).has_value()) { + non_trivial_iter_count[tv] = countNonTrivialIterDomains(tv); } } - // step 1: computing iterdomain mapping from input to output - std::vector mapped_alloc_dom = - propagateAllocationDomain(in, out); - - // step 2: push each mapped iterdomain - std::copy( - mapped_alloc_dom.begin(), - mapped_alloc_dom.end(), - std::back_inserter(alloc_domain)); - - // step 3: compute permutation - std::optional permutation = - ir_utils::computePermutation(out->getMaybeRFactorDomain(), alloc_domain); - - NVF_ERROR( - permutation.has_value(), - "allocation order propagation on broadcast op failed to compute valid permutation"); - alloc_order_map_[out] = permutation.value(); -} - -void AllocationOrderInferencer::handle(BinaryOp* op) { - auto* out = dynamic_cast(op->out()); - if (out == nullptr) { - return; - } - propagateAllocationOrder(resolveAllocationOrder(op->inputs()), out); -} - -void AllocationOrderInferencer::handle(TernaryOp* op) { - auto* out = dynamic_cast(op->out()); - if (out == nullptr) { - return; - } - propagateAllocationOrder(resolveAllocationOrder(op->inputs()), out); -} - -void AllocationOrderInferencer::handle(PadOp* op) { - auto* out = dynamic_cast(op->out()); - auto* in = dynamic_cast(op->in()); - // Note: `out` from pad has rfactor domain that cannot be mapped back to - // `in`'s root domain. Hence we use `out`'s root domain to match permutation. - propagateAllocationOrder(in, out, out->getRootDomain()); -} + // propagate new allocation domain on dsts + for (TensorView* dst : dsts) { + // safe check when allocation domain on the entry cannot be safely mutated. + if (dst == nullptr || dst->hasAllocation() || + fusion->getOutputAlias(dst).type != AllocationType::New) { + continue; + } -void AllocationOrderInferencer::handle(ReductionOp* op) { - auto* out = dynamic_cast(op->out()); - auto* in = dynamic_cast(op->in()); - propagateAllocationOrder(in, out); -} + // skip entry with self mapping. + if (hasSelfMapping(dst, exact_graph).has_value()) { + continue; + } -} // namespace + // find a ref among srcs to be propagated to given dst + TensorView* ref = nullptr; -// Note [ Allocation Order Propagation ] -// -// The propagation tries to propagate allocation order from inputs to the entire -// fusion: -// 1. Iterates through all inputs, looking for TensorView with allocation -// domain that's a permutation of its corresponding rfactor domain and record -// it as the allocation order of the tensor; -// 2. Traverse the fusion IR, propagate allocation order and record results in -// alloc_order_map. -std::unordered_map inferenceAllocationOrder( - Fusion* fusion) { - std::unordered_map alloc_order_map; + // high water mark for candidate of ref. + int64_t non_bc_high_water_mark = 0; + for (auto* tv : srcs) { + // skip when non-trivial iter domain count is missing. + if (non_trivial_iter_count.count(tv) == 0) { + continue; + } + // discard srcs for propagation which dst has no dependency on. + if (!DependencyCheck::isDependencyOf(tv, dst)) { + continue; + } + // discard srcs with lower iterdomain count than ref. + if (non_trivial_iter_count[tv] < non_bc_high_water_mark) { + continue; + } + // new candidate found, update ref and high water mark. + if (non_trivial_iter_count[tv] > non_bc_high_water_mark) { + ref = tv; + non_bc_high_water_mark = non_trivial_iter_count[tv]; + continue; + } + // found multiple candidate with the same iterdomain count + if (non_trivial_iter_count[tv] == non_bc_high_water_mark && + ref != nullptr) { + // ensure that there's no ambiguity on permutation mapping from multiple + // references. we need both ref candidates to have the same mapping on + // allocation domain + for (auto i : c10::irange(ref->nDims())) { + if (!val_sets.permissiveAreMapped( + ref->getMaybeAllocationDomain()[i], + tv->getMaybeAllocationDomain()[i])) { + // reset ref to nullptr, while keeping the iterdomain count high + // water mark. No propagation will occur unless we found another ref + // candidate with a higher iterdomain count. + ref = nullptr; + break; + } + } + continue; + } + } - // Note: we only consider simple permutation of allocation domain to rfactor - // domain. - for (auto tv : ir_utils::filterByType(fusion->inputs())) { - std::optional permutation = ir_utils::computePermutation( - TensorDomain::noReductions(tv->getMaybeRFactorDomain()), - TensorDomain::noReductions(tv->getMaybeAllocationDomain())); - if (permutation.has_value()) { - alloc_order_map[tv] = permutation.value(); + // propagate allocation domain if we still have a candidate. + if (ref) { + mapAllocationDomain(id_model, ref, dst); } } - - // Initialize AllocationOrderInferencer with allocation order of input tensor - // views - AllocationOrderInferencer infer(alloc_order_map); - infer.traverse(fusion); - - // return the propagated map - return alloc_order_map; } void AllocationDomainPass::runPass(Fusion* fusion) { - std::unordered_map stride_mapping = - inferenceAllocationOrder(fusion); - - for (Val* out_val : fusion->outputs()) { - auto* out_tv = dynamic_cast(out_val); - // skip: - // 1. non-tensor output; - // 2. tensor output with allocation specified, assuming everything is - // semantical - // 3. tensor output that's aliasing (Does aliased src matter?) - if (out_tv == nullptr || out_tv->hasAllocation() || - fusion->getOutputAlias(out_val).type != AllocationType::New) { - continue; - } - - auto mapped_entry = stride_mapping.find(out_tv); - if (mapped_entry == stride_mapping.end() || mapped_entry->second.empty()) { - continue; - } - - out_tv->setAllocationDomain( - constructAllocationDomain(out_tv, mapped_entry->second), true); - } + // mark input TensorViews as propagation sources + auto input_tvs = ir_utils::filterByType(fusion->inputs()); + std::vector srcs(input_tvs.begin(), input_tvs.end()); + // mark output TensorViews as propagation destinations + auto output_tvs = ir_utils::filterByType(fusion->outputs()); + std::vector dsts(output_tvs.begin(), output_tvs.end()); + // propagate allocation domain from sources to destinations + inferenceAllocationOrder(fusion, srcs, dsts); } } // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/allocation_order_inference.h b/csrc/preseg_passes/allocation_order_inference.h index 1eadc6facbb..9650e750f74 100644 --- a/csrc/preseg_passes/allocation_order_inference.h +++ b/csrc/preseg_passes/allocation_order_inference.h @@ -12,23 +12,14 @@ namespace nvfuser::preseg_passes { -// allocation order is the permutation to apply on a tensor view's rfactor -// domain to its allocation domain. -// -// i.e. For a channels last 4d tensor, we mark it as (0, 2, 3, 1). This is -// trying to present it more consistently with how we construct it with c++ API. -// std::vector tv0_nhwc = { -// tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; -// tv0->setAllocationDomain(tv0_nhwc, true); -using AllocationOrder = std::vector; - -// Propagate allocation order from input to the entire fusion. It does NOT -// modify any fusion IR, but instead stores the propagated allocation order as -// an unordered_map from TensorView to permutation. +// Propagate allocation domain from srcs to dsts. +// The pass update allocation domain on dsts tensor views. // // See details in Note [ Allocation Order Propagation ] -std::unordered_map inferenceAllocationOrder( - Fusion* fusion); +void inferenceAllocationOrder( + Fusion* fusion, + const std::vector& srcs, + const std::vector& dsts); // Realize allocation order propagation on fusion inputs to optimize allocation // domain of output tensor. This optimization pass currently only applies to diff --git a/tests/cpp/test_allocation_order_inference.cpp b/tests/cpp/test_allocation_order_inference.cpp index 62b93f5107b..d86e2410a7e 100644 --- a/tests/cpp/test_allocation_order_inference.cpp +++ b/tests/cpp/test_allocation_order_inference.cpp @@ -25,6 +25,13 @@ using testing::ElementsAre; using AllocationOrderInferenceTest = NVFuserTest; +std::vector getAllocationDomainPermutation(TensorView* tv) { + std::optional> permutation = + ir_utils::computePermutation( + tv->getMaybeRFactorDomain(), tv->getMaybeAllocationDomain()); + return permutation.value(); +} + TEST_F(AllocationOrderInferenceTest, BroadcastOpPropagation) { auto fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); @@ -44,9 +51,10 @@ TEST_F(AllocationOrderInferenceTest, BroadcastOpPropagation) { tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; tv0->setAllocationDomain(tv0_nhwc, true); - auto updated_layout = preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(updated_layout[tv2], ElementsAre(0, 3, 5, 7, 1, 4, 6, 2)); - EXPECT_THAT(updated_layout[tv3], ElementsAre(0, 2, 3, 1)); + preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2, tv3}); + EXPECT_THAT( + getAllocationDomainPermutation(tv2), ElementsAre(0, 3, 5, 7, 1, 4, 6, 2)); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1)); } TEST_F(AllocationOrderInferenceTest, UnaryOpPropagation) { @@ -63,8 +71,8 @@ TEST_F(AllocationOrderInferenceTest, UnaryOpPropagation) { tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; tv0->setAllocationDomain(tv0_nhwc, true); - const auto inferred_layout = preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv1), ElementsAre(0, 2, 3, 1)); + preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv1}); + EXPECT_THAT(getAllocationDomainPermutation(tv1), ElementsAre(0, 2, 3, 1)); } TEST_F(AllocationOrderInferenceTest, BinaryOpPropagation) { @@ -94,12 +102,12 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagation) { tv0->axis(0), tv0->axis(2), tv0->axis(3), tv0->axis(1)}; tv0->setAllocationDomain(tv0_nhwc, true); - const auto inferred_layout = - preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv2), ElementsAre(0, 2, 3, 1)); - EXPECT_THAT(inferred_layout.at(tv3), ElementsAre(0, 2, 3, 1)); - EXPECT_THAT(inferred_layout.at(tv6), ElementsAre(0, 2, 3, 1)); - EXPECT_THAT(inferred_layout.at(tv7), ElementsAre(0, 2, 3, 1)); + preseg_passes::inferenceAllocationOrder( + &fusion, {tv0}, {tv2, tv3, tv6, tv7}); + EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(0, 2, 3, 1)); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1)); + EXPECT_THAT(getAllocationDomainPermutation(tv6), ElementsAre(0, 2, 3, 1)); + EXPECT_THAT(getAllocationDomainPermutation(tv7), ElementsAre(0, 2, 3, 1)); } { auto fusion_ptr = std::make_unique(); @@ -124,82 +132,14 @@ TEST_F(AllocationOrderInferenceTest, BinaryOpPropagation) { tv1->axis(1), tv1->axis(0), tv1->axis(2), tv1->axis(3)}; tv1->setAllocationDomain(tv1_format, true); - const auto inferred_layout = - preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv2), ElementsAre(1, 0, 2, 3)); - EXPECT_THAT(inferred_layout.at(tv3), ElementsAre(1, 0, 2, 3)); - } - { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - // Testing propagation between two tensors - // tv0 and tv1 has the same number of non-broadcast iter domains, so lhs - // operand would propagate its allocation order. - auto tv0 = makeSymbolicTensor({-1, -1, 1, 1}); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor({-1, -1, 1, 1}); - fusion.addInput(tv1); - // tv2 should have allocation order from tv0 - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - // tv3 should have allocation order from tv1 - auto tv3 = add(tv1, tv0); - fusion.addOutput(tv3); - - std::vector tv0_format = { - tv0->axis(0), tv0->axis(2), tv0->axis(1), tv0->axis(3)}; - tv0->setAllocationDomain(tv0_format, true); - std::vector tv1_format = { - tv1->axis(1), tv1->axis(0), tv1->axis(2), tv1->axis(3)}; - tv1->setAllocationDomain(tv1_format, true); - - const auto inferred_layout = - preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv2), ElementsAre(0, 2, 1, 3)); - EXPECT_THAT(inferred_layout.at(tv3), ElementsAre(1, 0, 2, 3)); - } - { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - // Testing propagation between two tensors - // tv0 and tv1 has the same number of non-broadcast iter domains, so lhs - // operand would propagate its allocation order. - auto tv0 = makeSymbolicTensor({-1, -1, 1, 1}); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor({-1, -1, 1, 1}); - fusion.addInput(tv1); - // tv2 should have allocation order from tv0 - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - // reshape propagation is not supported yet - auto tv3 = reshape( - tv1, - { - tv0->axis(0)->extent(), - tv0->axis(1)->extent(), - tv0->axis(2)->extent(), - tv0->axis(3)->extent(), - }); - auto tv4 = add(tv0, tv3); - fusion.addOutput(tv4); - - std::vector tv0_format = { - tv0->axis(0), tv0->axis(2), tv0->axis(1), tv0->axis(3)}; - tv0->setAllocationDomain(tv0_format, true); - std::vector tv1_format = { - tv1->axis(1), tv1->axis(0), tv1->axis(2), tv1->axis(3)}; - tv1->setAllocationDomain(tv1_format, true); - - const auto inferred_layout = - preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv2), ElementsAre(0, 2, 1, 3)); - EXPECT_TRUE(inferred_layout.count(tv3) == 0); - EXPECT_TRUE(inferred_layout.count(tv4) == 0); + preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1}, {tv2, tv3}); + // tv1 dominates output allocation order, which has a permutation {1, 0, 2, + // 3}. But since tv1->axis(3) is a broadcast dimension, it did not map to + // tv2->axis(3)/tv3->axis(3). Propagated permutation would push the unmapped + // axis(3) first in the allocation domain while keeping mapped ids in its + // original order {1, 0, 2} as inner entries in its allocation domain. + EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(3, 1, 0, 2)); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(3, 1, 0, 2)); } } @@ -228,9 +168,9 @@ TEST_F(AllocationOrderInferenceTest, TensorFactoryBinaryOpPropagation) { std::vector tv1_c_last = {tv1->axis(0), tv1->axis(1)}; tv1->setAllocationDomain(tv1_c_last, true); - const auto inferred_layout = preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv2), ElementsAre(1, 0)); - EXPECT_THAT(inferred_layout.at(tv3), ElementsAre(1, 0)); + preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv2, tv3}); + EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(1, 0)); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(1, 0)); } TEST_F(AllocationOrderInferenceTest, TensorEmptyAllocationOrderPropagation) { @@ -256,8 +196,8 @@ TEST_F(AllocationOrderInferenceTest, TensorEmptyAllocationOrderPropagation) { std::vector tv0_c_last = {tv0->axis(1), tv0->axis(0)}; tv0->setAllocationDomain(tv0_c_last, true); - const auto inferred_layout = preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv4), ElementsAre(1, 0)); + preseg_passes::inferenceAllocationOrder(&fusion, {tv0}, {tv4}); + EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(1, 0)); } TEST_F(AllocationOrderInferenceTest, TernaryOpPropagation) { @@ -272,6 +212,7 @@ TEST_F(AllocationOrderInferenceTest, TernaryOpPropagation) { auto tv2 = makeSymbolicTensor({-1, -1, -1, -1}); fusion.addInput(tv2); auto tv3 = gt(tv0, IrBuilder::create(0.0)); + fusion.addOutput(tv3); auto tv4 = where(tv3, tv1, tv2); fusion.addOutput(tv4); @@ -285,9 +226,9 @@ TEST_F(AllocationOrderInferenceTest, TernaryOpPropagation) { tv2->axis(0), tv2->axis(2), tv2->axis(3), tv2->axis(1)}; tv2->setAllocationDomain(tv2_nhwc, true); - const auto inferred_layout = preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv3), ElementsAre(0, 2, 3, 1)); - EXPECT_THAT(inferred_layout.at(tv4), ElementsAre(0, 2, 3, 1)); + preseg_passes::inferenceAllocationOrder(&fusion, {tv0, tv1, tv2}, {tv3, tv4}); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(0, 2, 3, 1)); + EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(0, 2, 3, 1)); } TEST_F(AllocationOrderInferenceTest, ReductionOpPropagation) { @@ -302,8 +243,16 @@ TEST_F(AllocationOrderInferenceTest, ReductionOpPropagation) { fusion.addInput(tv0); auto tv1 = makeSymbolicTensor({-1, 1}); // stride order: {0, 1} fusion.addInput(tv1); - auto tv2 = sum(tv0, {1}); // stride order: {1, 2, 3, 0} - auto tv3 = sum(tv2, {1}); // stride order: {1, 2, 0} + // Instead of propagating stride order: {1, 2, 3, 0} + // The end result is {2, 1, 3, 0} because we skip mapping from Iteration id to + // reduction id. See Note [ Allocation Order Mapping ] sharp-edge 0 for + // details. + // TODO: restore behavior after issue: + // https://github.com/NVIDIA/Fuser/issues/2202 + auto tv2 = sum(tv0, {1}); + fusion.addOutput(tv2); + // ditto. stride order here is {2, 1, 0} instead of {1, 2, 0} + auto tv3 = sum(tv2, {1}); fusion.addOutput(tv3); // tv3 dominates the propagation since it has more non-broadcast dimension auto tv4 = add(tv1, tv3); // stride order: {1, 0} @@ -314,11 +263,20 @@ TEST_F(AllocationOrderInferenceTest, ReductionOpPropagation) { auto tv5 = broadcast(tv3, {true, false, false, true}); fusion.addOutput(tv5); - const auto inferred_layout = preseg_passes::inferenceAllocationOrder(&fusion); - EXPECT_THAT(inferred_layout.at(tv2), ElementsAre(1, 2, 3, 0)); - EXPECT_THAT(inferred_layout.at(tv3), ElementsAre(1, 2, 0)); - EXPECT_THAT(inferred_layout.at(tv4), ElementsAre(1, 0)); - EXPECT_THAT(inferred_layout.at(tv5), ElementsAre(0, 3, 2, 1)); + preseg_passes::inferenceAllocationOrder( + &fusion, {tv0, tv1}, {tv2, tv3, tv4, tv5}); +#if true + // permutation here is strange because in propagation we are preserving + // reduction iter domain in its position in rfactor domain See issue: + // https://github.com/NVIDIA/Fuser/issues/2202 + EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(2, 1, 3, 0)); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(2, 1, 0)); +#else + EXPECT_THAT(getAllocationDomainPermutation(tv2), ElementsAre(1, 2, 3, 0)); + EXPECT_THAT(getAllocationDomainPermutation(tv3), ElementsAre(1, 2, 0)); +#endif + EXPECT_THAT(getAllocationDomainPermutation(tv4), ElementsAre(1, 0)); + EXPECT_THAT(getAllocationDomainPermutation(tv5), ElementsAre(0, 3, 2, 1)); } TEST_F(AllocationOrderInferenceTest, EnableInRuntime) { diff --git a/tests/cpp/test_gather.cpp b/tests/cpp/test_gather.cpp index 910d2122461..3bd5893ea90 100644 --- a/tests/cpp/test_gather.cpp +++ b/tests/cpp/test_gather.cpp @@ -1035,6 +1035,9 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorTranspose1_CUDA) { auto tv4 = take_along_axis(tv2, tv3, 0); auto tv5 = transpose(tv4, 1, 2); fusion.addOutput(tv5); + // specify output allocation domain to avoid allocation order pass changing + // this to a pointwise kernel + tv5->setAllocationDomain(tv5->getMaybeRFactorDomain(), true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); diff --git a/tests/cpp/test_gpu_transpose.cpp b/tests/cpp/test_gpu_transpose.cpp index 8e0d2ac594d..92fb5e27a76 100644 --- a/tests/cpp/test_gpu_transpose.cpp +++ b/tests/cpp/test_gpu_transpose.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -46,11 +47,18 @@ class TransposeTest : public NVFuserTest { // For convenience, disable MarkAliasesPreparePass. Many tests in this file // run a fusion that consists of `transpose` only. MarkAliasesPreparePass // would turn those fusions into a no-op, skipping the transpose scheduler. - TransposeTest() : optimization_guard_(false) {} + // + // Disable AllocationDomainPass. Fusion with permutation would otherwise run + // through pointwise scheduler with allocation order pass trying to match + // output with the same layout as with its inputs. + TransposeTest() + : optimization_guard_(false), allocation_order_guard_(false) {} private: preseg_passes::OptimizationPassGuard optimization_guard_; + preseg_passes::OptimizationPassGuard + allocation_order_guard_; }; // x->sin->transpose->cos->y