-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ATen scheduler for the new
Matmul/LinearOp
IR nodes (#2209)
1. Adds a new scheduler, `ExprEvalScheduler`, that accepts the MatmulOp and LinearOp (next PR) for ATen evaluation. 2. Modifies the matmul input generator to test all cases supported by Thunder. 3. The `eagerMatmul` API is renamed and replaces the existing `matmul` API. `fd.ops.matmul` now creates a `MatmulOp` (except in a few special cases such as the scalar dot product, e.g. `[K] x [K]`). Issues #2149, #2092.
- Loading branch information
Showing
15 changed files
with
211 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
|
||
#include <ir/utils.h> | ||
#include <scheduler/debug_utils.h> | ||
#include <scheduler/expr_eval_sched.h> | ||
#include <scheduler/registry_utils.h> | ||
|
||
namespace nvfuser { | ||
|
||
// Check if the fusion has a single MatmulOp node | ||
bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
  // Accept only when the fusion's expression list has exactly one entry and
  // that entry is a MatmulOp.
  const auto fusion_exprs = fusion->exprs();
  const bool is_single_matmul =
      fusion_exprs.size() == 1 && fusion_exprs.front()->isA<MatmulOp>();

  if (!is_single_matmul) {
    // Record the rejection reason for this heuristic before declining.
    scheduler_debug_utils::canScheduleRejectReason(
        heuristicType(),
        "Fusion must contain a single expression of type MatmulOp");
    return false;
  }

  return true;
}
|
||
// Marks the fusion's first output with AllocationType::Evaluate and no aliased
// input (nullptr). Only outputs()[0] is touched; the fusion is expected to
// have at least one output.
void ExprEvalScheduler::schedule(Fusion* fusion) {
  auto* evaluated_output = fusion->outputs()[0];
  fusion->aliasOutputToInput(
      evaluated_output, /*input=*/nullptr, AllocationType::Evaluate);
}
|
||
} // namespace nvfuser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
#pragma once | ||
|
||
#include <scheduler/heuristic.h> | ||
#include <scheduler/registry.h> | ||
|
||
namespace nvfuser { | ||
|
||
class Fusion; | ||
class SchedulerRuntimeInfo; | ||
class HeuristicSummary; | ||
|
||
// ExprEval scheduler represents the case where we allocate outputs directly | ||
// using EE. No code is generated. | ||
class ExprEvalScheduler : public SchedulerEntry { | ||
public: | ||
explicit ExprEvalScheduler( | ||
Fusion* fusion, | ||
SchedulerRuntimeInfo& runtime_info, | ||
HeuristicSummary* data_cache = nullptr) | ||
: SchedulerEntry(heuristicType()) { | ||
params_ = | ||
std::make_shared<HeuristicParams>("", runtime_info.getIndexType()); | ||
} | ||
|
||
// This scheduler only accepts MatmulOp. | ||
static bool canScheduleCompileTime(Fusion* fusion); | ||
|
||
static bool canScheduleRunTime( | ||
Fusion* fusion, | ||
SchedulerRuntimeInfo& runtime_info, | ||
HeuristicSummary* data_cache) { | ||
return true; | ||
} | ||
|
||
constexpr static ScheduleHeuristic heuristicType() { | ||
return ScheduleHeuristic::ExprEval; | ||
} | ||
|
||
void schedule(Fusion* fusion) override; | ||
}; | ||
|
||
} // namespace nvfuser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.