Skip to content

Commit

Permalink
ATen scheduler for the new Matmul/LinearOp IR nodes (#2209)
Browse files Browse the repository at this point in the history
1. Adds a new scheduler -- `ExprEvalScheduler` that accepts the MatmulOp
and LinearOp (next PR) for ATen evaluation.
2. Modify the matmul input generator to test for all cases supported by
Thunder.
3. The `eagerMatmul` API is renamed and replaces the existing `matmul`
API. `fd.ops.matmul` now creates a `MatmulOp` (except in a few special
cases, e.g. a scalar dot product such as `[K] x [K]`).

Issue #2149, #2092.
  • Loading branch information
Priya2698 committed May 14, 2024
1 parent 03716fc commit dfba77a
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 139 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
${NVFUSER_SRCS_DIR}/scheduler/utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp
${NVFUSER_SRCS_DIR}/scheduler/expr_eval_sched.cpp
${NVFUSER_SRCS_DIR}/serde/polymorphic_value.cpp
${NVFUSER_SRCS_DIR}/serde/utils.cpp
${NVFUSER_SRCS_DIR}/swizzle.cpp
Expand Down
41 changes: 1 addition & 40 deletions csrc/ops/composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,42 +54,6 @@ TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
return dx;
}

// This function will add a castOp to the output of the matrix multiplication
// The implementation of linear can use this but will skip the cast (set cast
// flag as false) and add the bias.
TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input) {
  NVF_CHECK(
      a->nDims() == b->nDims(),
      "The number of dimension of A and B do not match");
  // TODO: We'll need to support nDims == 3 for bmm.
  NVF_CHECK(
      a->nDims() == 2,
      "Only 2-D Tensors are supported, in the future we'll support 3-D as well!");
  // Validate dtype compatibility up front, before any broadcast IR nodes are
  // created, so a mismatch fails fast without leaving partial IR behind.
  NVF_CHECK(
      a->getDataType().value() == b->getDataType().value(),
      "data types of inputs to matmul don't match");

  // Broadcast both operands to a common 3-D layout so the K dimensions align:
  //   A: [M, K] -> [M, K, Bcast]
  //   B: [K, N] -> [Bcast, K, N]
  std::vector<bool> bcast_dims(a->nDims() + 1, false);
  bcast_dims.at(bcast_dims.size() - 1) = true;
  auto* tv0b = broadcast(a, bcast_dims);
  bcast_dims.at(bcast_dims.size() - 1) = false;
  bcast_dims.at(bcast_dims.size() - 3) = true;
  auto* tv1b = broadcast(b, bcast_dims);

  // Contract over K, the second-to-last axis of the broadcast shape.
  auto* output = fusedMultiplySum(tv0b, tv1b, {-2});
  if (cast_output_to_input) {
    // For matmul, the output dtype should match input.
    return maybeCastOp(a->getDataType().value(), output);
  }
  return output;
}

// Convenience overload: matmul whose output is cast back to the input dtype.
TensorView* matmul(TensorView* a, TensorView* b) {
  return matmul(a, b, /*cast_output_to_input=*/true);
}

TensorView* linear(TensorView* a, TensorView* b, TensorView* bias) {
// TODO: Support 1+ dimensional A.
NVF_CHECK(
Expand Down Expand Up @@ -348,10 +312,7 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) {

} // namespace

// TODO (Priya): This will be renamed to matmul once we are ready to modify the
// python API backend. Keeping separate for now, to avoid breaking tests in
// Thunder.
TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b) {
TensorView* matmul(TensorView* tv_a, TensorView* tv_b) {
NVF_CHECK(
tv_a->nDims() > 0 && tv_b->nDims() > 0,
"Expected inputs to be atleast 1D, got: ",
Expand Down
15 changes: 4 additions & 11 deletions csrc/ops/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,6 @@ NVF_API LstmResult lstm(
TensorView* cell_x,
TensorView* out_x);

// Matmul function which takes in tensors with the shapes
// A[M,K] B[K,N], but the tensors may have different layouts
// via strides. All restrictions from the matmul APIs also
// apply here.
TensorView* matmul(TensorView* a, TensorView* b);
// This second matmul function is not exposed via
// the Python interface, but it does the guts of the work and
// can be used to create matmuls without a cast operation following it.
TensorView* matmul(TensorView* a, TensorView* b, bool cast_output_to_input);

// Linear functions which takes in two tensors of shapes A[M,K] and
// B[N,K]. Takes in a options bias of shape [N] and performs
// out = A * B_Transpose + bias. The output dtype matches the dtype
Expand All @@ -81,6 +71,9 @@ TensorView* leaky_relu(TensorView* x, Val* negative_slope);

NVF_API TensorView* view_as_real(TensorView* x);

TensorView* eagerMatmul(TensorView* tv_a, TensorView* tv_b);
// Matmul function which takes in tensors with the shapes
// A[*, M, K] / A[K] and B[*, K, N] / B[K], but the tensors may have different
// layouts via strides. This has the same functionality as torch.matmul
TensorView* matmul(TensorView* tv_a, TensorView* tv_b);

} // namespace nvfuser
103 changes: 53 additions & 50 deletions csrc/root_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,55 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
TensorDomain::noReductions(producer->maybeRFactor());
const auto& consumer_root = consumer->root();

// Add key-value iterdomain pair to the map.
auto updatePairwiseRootDomainMap =
[&root_dims_to_map, producer_to_consumer, &dom_map](
IterDomain* map_key_id, IterDomain* map_value_id) {
if (!producer_to_consumer) {
std::swap(map_key_id, map_value_id);
}
if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
dom_map.insert(std::make_pair(map_key_id, map_value_id));
}
};
// Check following conditions and add key-value iterdomain pair to domain map:
// 1. Do not map broadcast ID to non-broadcast ID unless map_broadcast_ =
// true.
// 2. Do not map Symbolic ID if the extents are not identical unless
// map_symbolic_ = true.
auto updatePairwiseRootDomainMap = [&](IterDomain* producer_id,
IterDomain* consumer_id) {
if (!map_broadcast_ &&
producer_id->isBroadcast() != consumer_id->isBroadcast()) {
return;
}

// Condition: At least one ID is symbolic.
//
// If map_symbolic_ is true:
// Map these IDs regardless of other considerations.
//
// If map_symbolic_ is false (default):
// Map these only if their extents are identical. IterType::Symbolic
// reflects that the extent might evaluate to 1 for some inputs, in which
// case it may be valid to use those domains in a broadcast op. If the
// extents are exactly the same between two aligned IterDomains, the
// Symbolic one will be concretized to the same IterType as the other, so
// they should be mapped with one another.
if (!map_symbolic_ &&
(producer_id->isSymbolic() || consumer_id->isSymbolic()) &&
(!producer_id->extent()->sameAs(consumer_id->extent()))) {
return;
}

IterDomain* map_key_id = producer_id;
IterDomain* map_value_id = consumer_id;

if (!producer_to_consumer) {
std::swap(map_key_id, map_value_id);
}

if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
dom_map.insert(std::make_pair(map_key_id, map_value_id));
}
};

// For MatmulOp, use the corresponding mapped input iterdomains.
if (MatmulOp* op = dynamic_cast<MatmulOp*>(consumer_tv_->definition())) {
// Check if the producer is lhs/rhs input
MatmulRole input_role =
producer->sameAs(op->inA()) ? MatmulRole::INPUT_A : MatmulRole::INPUT_B;
producer->sameAs(op->inA()->as<TensorView>()->domain())
? MatmulRole::INPUT_A
: MatmulRole::INPUT_B;
auto out_size = consumer_root.size();

// For MatmulOp, the input iterdomains at a given index do not necessarily
Expand All @@ -150,14 +182,18 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
// input and output for index=2
// 2. `B, M, K] x [K, N] -> [B, M, N]`: For input B, the second iterdomain
// maps to the third output iterdomain.
const std::vector<IterDomain*>& aligned_producer_id =
const std::vector<IterDomain*>& aligned_producer_ids =
ops::mapMatmulOpIterDomains(producer_root, input_role, out_size);

for (auto inx : c10::irange(out_size)) {
IterDomain* map_key_id = aligned_producer_id.at(inx);
IterDomain* map_value_id = consumer_root.at(inx);
updatePairwiseRootDomainMap(map_key_id, map_value_id);
IterDomain* producer_id = aligned_producer_ids.at(inx);
IterDomain* consumer_id = consumer_root.at(inx);
if (producer_id == nullptr) {
continue;
}
updatePairwiseRootDomainMap(producer_id, consumer_id);
}

return dom_map;
}

Expand All @@ -171,8 +207,6 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
// 2. IDs that may have different extents (e.g., non indexed
// domains of torch_gather)
// 3. Squeeze and unsqueeze
// 4. Broadcast and non broadcast
// 5. Symbolic ID with different extent from other ID

// Condition 1: when the producer ID is the dim of a select-like op
if (producer_id == indexed_producer_id) {
Expand Down Expand Up @@ -217,38 +251,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
continue;
}

// Condition 4
if (!map_broadcast_ &&
producer_id->isBroadcast() != consumer_id->isBroadcast()) {
itc++;
itp++;
continue;
}

// Condition 5
// At least one ID is symbolic.
//
// If map_symbolic_ is true:
// Map these IDs regardless of other considerations.
//
// If map_symbolic_ is false (default):
// Map these only if their extents are identical. IterType::Symbolic
// reflects that the extent might evaluate to 1 for some inputs, in which
// case it may be valid to use those domains in a broadcast op. If the
// extents are exactly the same between two aligned IterDomains, the
// Symbolic one will be concretized to the same IterType as the other, so
// they should be mapped with one another.
if (!map_symbolic_ &&
(producer_id->isSymbolic() || consumer_id->isSymbolic()) &&
(!producer_id->extent()->sameAs(consumer_id->extent()))) {
itc++;
itp++;
continue;
}

IterDomain* map_key_id = producer_id;
IterDomain* map_value_id = consumer_id;
updatePairwiseRootDomainMap(map_key_id, map_value_id);
updatePairwiseRootDomainMap(producer_id, consumer_id);

itc++;
itp++;
Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/all_schedulers.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
// clang-format on
#pragma once
#include <scheduler/expr_eval_sched.h>
#include <scheduler/matmul.h>
#include <scheduler/no_op.h>
#include <scheduler/normalization_inner.h>
Expand Down
33 changes: 33 additions & 0 deletions csrc/scheduler/expr_eval_sched.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <ir/utils.h>
#include <scheduler/debug_utils.h>
#include <scheduler/expr_eval_sched.h>
#include <scheduler/registry_utils.h>

namespace nvfuser {

// Accept only fusions consisting of exactly one expression of type MatmulOp.
bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
  const auto exprs = fusion->exprs();
  if (exprs.size() != 1 && exprs.front()->isA<MatmulOp>()) {
    return true;
  }
  scheduler_debug_utils::canScheduleRejectReason(
      heuristicType(),
      "Fusion must contain a single expression of type MatmulOp");
  return false;
}

void ExprEvalScheduler::schedule(Fusion* fusion) {
  // Mark the (single) output for expression evaluation: it is not backed by
  // any fusion input, hence the nullptr, and no kernel code is generated.
  auto* out = fusion->outputs()[0];
  fusion->aliasOutputToInput(out, /*input=*/nullptr, AllocationType::Evaluate);
}

} // namespace nvfuser
49 changes: 49 additions & 0 deletions csrc/scheduler/expr_eval_sched.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <scheduler/heuristic.h>
#include <scheduler/registry.h>

namespace nvfuser {

class Fusion;
class SchedulerRuntimeInfo;
class HeuristicSummary;

// ExprEval scheduler represents the case where we allocate outputs directly
// using EE. No code is generated.
class ExprEvalScheduler : public SchedulerEntry {
 public:
  // NOTE: `fusion` and `data_cache` are accepted only to match the common
  // scheduler constructor shape; this constructor does not use them.
  explicit ExprEvalScheduler(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr)
      : SchedulerEntry(heuristicType()) {
    // Base HeuristicParams with an empty tag; only the index type matters.
    params_ =
        std::make_shared<HeuristicParams>("", runtime_info.getIndexType());
  }

  // This scheduler only accepts MatmulOp.
  static bool canScheduleCompileTime(Fusion* fusion);

  // No runtime restrictions: anything accepted at compile time can run.
  static bool canScheduleRunTime(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache) {
    return true;
  }

  constexpr static ScheduleHeuristic heuristicType() {
    return ScheduleHeuristic::ExprEval;
  }

  // Flags the fusion output for direct evaluation instead of code generation
  // (see expr_eval_sched.cpp).
  void schedule(Fusion* fusion) override;
};

} // namespace nvfuser
17 changes: 13 additions & 4 deletions csrc/scheduler/heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@ class HeuristicParams : public PolymorphicBase {
return "Undefined Heuristic Params";
}

virtual size_t hash() const = 0;

virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const = 0;
virtual size_t hash() const {
return 0;
};

virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const {
if (!other->isStrictlyA<HeuristicParams>()) {
return false;
}
return other->cparams == cparams;
}

virtual std::shared_ptr<HeuristicParams> clone() const = 0;
virtual std::shared_ptr<HeuristicParams> clone() const {
return std::make_shared<HeuristicParams>();
}

HeuristicParams() = default;
HeuristicParams(std::string tag, PrimDataType index_type)
Expand Down
2 changes: 2 additions & 0 deletions csrc/scheduler/heuristic_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ std::string toString(ScheduleHeuristic sh) {
return "transpose";
case ScheduleHeuristic::Matmul:
return "matmul";
case ScheduleHeuristic::ExprEval:
return "expr_eval";
case ScheduleHeuristic::None:
return "none";
default:
Expand Down
6 changes: 4 additions & 2 deletions csrc/scheduler/heuristic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ enum class ScheduleHeuristic {
InnerPersistent,
InnerOuterPersistent,
OuterPersistent,
Transpose
Transpose,
ExprEval
};

//! Define a schedule table to loop over all the heuristics in priority order.
constexpr std::array<ScheduleHeuristic, 8> all_heuristics_in_priority_order = {
constexpr std::array<ScheduleHeuristic, 9> all_heuristics_in_priority_order = {
ScheduleHeuristic::ExprEval,
ScheduleHeuristic::NoOp,
ScheduleHeuristic::Matmul,
ScheduleHeuristic::Reduction,
Expand Down
Loading

0 comments on commit dfba77a

Please sign in to comment.