Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/cpp/heuristic_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ static void NvFuserScheduler_LayerNormBackward_HeuristicCache(
auto runtime = getLayerBackwardNormRuntime(
std::move(fusion_ptr), executor_cache, args, shape, norm_shape);

NVF_ERROR(runtime->getMaybeHeuristicsFor(args).has_value());
NVF_ERROR(runtime->getMaybeHeuristicsFor(args) != nullptr);

for (auto _ : benchmark_state) {
// Setup (not included in the measurement)
Expand All @@ -62,7 +62,7 @@ static void NvFuserScheduler_LayerNormForward_HeuristicCache(
auto runtime = getLayerForwardNormRuntime(
std::move(fusion_ptr), executor_cache, args, shape, norm_shape);

NVF_ERROR(runtime->getMaybeHeuristicsFor(args).has_value());
NVF_ERROR(runtime->getMaybeHeuristicsFor(args) != nullptr);

for (auto _ : benchmark_state) {
// Setup (not included in the measurement)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/heuristic_lookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ static void NvFuserScheduler_LayerNormBackward_HeuristicLookup(
auto runtime = getLayerBackwardNormRuntime(
std::move(fusion_ptr), executor_cache, args, shape, norm_shape);

NVF_ERROR(runtime->getMaybeHeuristicsFor(args).has_value());
NVF_ERROR(runtime->getMaybeHeuristicsFor(args) != nullptr);

for (auto _ : benchmark_state) {
// Setup (not included in the measurement)
Expand All @@ -62,7 +62,7 @@ static void NvFuserScheduler_LayerNormForward_HeuristicLookup(
auto runtime = getLayerForwardNormRuntime(
std::move(fusion_ptr), executor_cache, args, shape, norm_shape);

NVF_ERROR(runtime->getMaybeHeuristicsFor(args).has_value());
NVF_ERROR(runtime->getMaybeHeuristicsFor(args) != nullptr);

for (auto _ : benchmark_state) {
// Setup (not included in the measurement)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void LayerNormBackward_ShapeInference_Base(
auto runtime = getLayerBackwardNormRuntime(
std::move(fusion_ptr), executor_cache, args, shape, norm_shape);

NVF_ERROR(runtime->getMaybeHeuristicsFor(args).has_value());
NVF_ERROR(runtime->getMaybeHeuristicsFor(args) != nullptr);

executor_cache->profile(true);
executor_cache->disableKernelLaunch();
Expand Down Expand Up @@ -81,7 +81,7 @@ void LayerNormForward_ShapeInferenceBase(
auto runtime = getLayerForwardNormRuntime(
std::move(fusion_ptr), executor_cache, args, shape, norm_shape);

NVF_ERROR(runtime->getMaybeHeuristicsFor(args).has_value());
NVF_ERROR(runtime->getMaybeHeuristicsFor(args) != nullptr);

executor_cache->profile(true);
executor_cache->disableKernelLaunch();
Expand Down
23 changes: 22 additions & 1 deletion csrc/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include <C++23/utility>

//! IR header hierarchy
//! 1. ** utils.h ** - PolymorphicBase and NonCopyable
//! 1. ** base.h ** - PolymorphicBase and NonCopyable
//! 2. ir/base_nodes.h - Statement, Expr, and Val
//! 3. ir/internal_base_nodes.h - IterDomain and TensorDomain
//! 4. ir/interface_nodes.h - TensorView and Scalar
Expand Down Expand Up @@ -113,6 +113,27 @@ constexpr int64_t alignSharedMemoryBytes(int64_t unaligned_bytes) {
return (unaligned_bytes + (alignment - 1)) & (~(alignment - 1));
}

//! Returns the value of an optional, or throws via NVF_ERROR if nullopt. This
//! is to satisfy clang-tidy bugprone-unchecked-optional-access. Use this when
//! you have already ensured that the optional is engaged.
//!
//! Three overloads are provided so that the result follows the value category
//! of the argument:
//!  - const lvalue optional -> const reference to the contained value
//!  - mutable lvalue optional -> mutable reference to the contained value
//!  - rvalue (temporary) optional -> the contained value, moved out by value
//!
//! The two reference-returning overloads hand back a reference into `opt`
//! itself, so the result must not outlive the optional it was taken from.

//! Const lvalue overload: checks that `opt` is engaged and returns a read-only
//! reference to the value stored inside it. No copy is made.
template <typename T>
const T& valueOrError(const std::optional<T>& opt) {
NVF_ERROR(opt.has_value());
return *opt;
}
//! Mutable lvalue overload: checks that `opt` is engaged and returns a
//! reference through which the stored value can be modified in place.
template <typename T>
T& valueOrError(std::optional<T>& opt) {
NVF_ERROR(opt.has_value());
return *opt;
}
//! Rvalue overload: checks that `opt` is engaged and returns the stored value
//! by value, move-constructed out of the optional. Returning by value (rather
//! than by reference) keeps the result valid after a temporary `opt` is
//! destroyed, e.g. when the argument is the result of a function call.
template <typename T>
T valueOrError(std::optional<T>&& opt) {
NVF_ERROR(opt.has_value());
// Function arguments are lvalues, so `*opt` is an lvalue and we need to
// std::move it.
return std::move(*opt);
}

//! Simple mixin for suppressing copy & move operations, ex:
//!
//! class Foo : public NonCopyable {
Expand Down
7 changes: 2 additions & 5 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
const std::optional<bool> contig_flag = contiguity.at(dim);
// Broadcast doesn't have contig flag but it must have been
// already filtered out
NVF_ERROR(contig_flag.has_value());

if (contig_flag.value()) {
if (valueOrError(contig_flag)) {
strides[dim] = cur_contig_stride;
cur_contig_stride = SimplifyingIrBuilder::mulExpr(
cur_contig_stride, promotion_domain->extent());
Expand Down Expand Up @@ -472,8 +470,7 @@ class AllocationDomainSetup : private kir::IrVisitor {
actual_allocation_domains.push_back(promotion_domain);
actual_strides.push_back(stride);
auto contig = contiguity.at(i);
NVF_ERROR(contig.has_value());
actual_contiguity.push_back(contig.value());
actual_contiguity.push_back(valueOrError(contig));
}

NVF_ERROR(actual_allocation_domains.size() == actual_strides.size());
Expand Down
8 changes: 4 additions & 4 deletions csrc/device_lower/pass/grid_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,18 @@ class GridSerializationSyncInserter : kir::ExprMutator {

void insertSyncs() {
NVF_ERROR(cur_top_level_expr_ != nullptr);
NVF_ERROR(cur_expr_sync_pattern_.has_value());
auto sync_pattern = valueOrError(cur_expr_sync_pattern_);
kir::Allocate* alloc = lower_utils::allocGlobalBufferForGridComm(
lower_utils::getGridSyncBufferSize(cur_expr_sync_pattern_.value()),
lower_utils::getGridSyncBufferSize(sync_pattern),
DataType::Int,
/*zero_init=*/true,
/*resets_to_zero=*/true);
auto wait = IrBuilder::create<kir::BlockSerializeWait>(
cur_expr_sync_pattern_.value(), alloc->buffer());
sync_pattern, alloc->buffer());
registerInsertBefore(cur_top_level_expr_, alloc);
registerInsertBefore(cur_top_level_expr_, wait);
auto release = IrBuilder::create<kir::BlockSerializeRelease>(
cur_expr_sync_pattern_.value(), alloc->buffer());
sync_pattern, alloc->buffer());
registerInsertAfter(cur_top_level_expr_, release);
}

Expand Down
6 changes: 3 additions & 3 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2628,8 +2628,8 @@ std::vector<Expr*> SegmentedGroup::stablyOrderedExprs() const {
return ordered_exprs;
}

std::optional<std::unique_ptr<HeuristicParams>> SegmentedGroup::
getMaybeHeuristicParams(SchedulerRuntimeInfo& runtime_info) {
std::unique_ptr<HeuristicParams> SegmentedGroup::getMaybeHeuristicParams(
SchedulerRuntimeInfo& runtime_info) {
FUSER_PERF_SCOPE("SegmentedFusion::getMaybeHeuristicParams");
auto heuristic_data_cache =
segmented_fusion_->getCachedHeuristicDataFor(this);
Expand All @@ -2639,7 +2639,7 @@ std::optional<std::unique_ptr<HeuristicParams>> SegmentedGroup::
runtime_info,
heuristic_data_cache,
/*skip_compile_time_checks=*/true)) {
return std::nullopt;
return nullptr;
}
return SchedulerEntry::makeSchedulerInstance(schedulerType())
->computeHeuristics(
Expand Down
4 changes: 2 additions & 2 deletions csrc/fusion_segmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ class SegmentedGroup {
//! Returns a new scheduler with the same heuristics
//! for this group if possible.
//! Note that the schedule params can be different.
//! Returns a nullopt if this group cannot be scheduled
//! Returns nullptr if this group cannot be scheduled
//! with the same heuristics.
std::optional<std::unique_ptr<HeuristicParams>> getMaybeHeuristicParams(
std::unique_ptr<HeuristicParams> getMaybeHeuristicParams(
SchedulerRuntimeInfo& runtime_info);

//! Get the SegmentedFusion this group belongs to
Expand Down
4 changes: 2 additions & 2 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ KernelArgumentHolder HostIrEvaluator::runWithInputs(
FUSER_PERF_SCOPE("HostIrEvaluator::runWithInputs");
expr_evaluator_ = ExpressionEvaluator();
expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
NVF_ERROR(args.getCacheId().has_value());
expr_evaluator_.bind("cacheId", static_cast<int64_t>(*args.getCacheId()));
expr_evaluator_.bind(
"cacheId", static_cast<int64_t>(valueOrError(args.getCacheId())));

NVF_ERROR_EQ(std::ssize(container_->inputs()), args.size());
for (auto&& [in_val, arg] : zip(container_->inputs(), args)) {
Expand Down
7 changes: 2 additions & 5 deletions csrc/host_ir/jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ KernelArgumentHolder HostIrJitImpl::runWithInputs(
const KernelArgumentHolder& args) {
FUSER_PERF_SCOPE("HostIrJitImpl::runWithInputs");
// Bind cache id to llvm global variable or align with main function inputs
NVF_ERROR(args.getCacheId().has_value(), "Cache ID is not set");
size_t cache_id = valueOrError(args.getCacheId());
NVF_ERROR_EQ(std::ssize(container_->inputs()), args.size());

std::unordered_set<const at::Tensor*> preserved_tensors;
Expand All @@ -959,10 +959,7 @@ KernelArgumentHolder HostIrJitImpl::runWithInputs(

// Run the main function
std::vector<void*> output_aten_tensors(container_->outputs().size());
main_func_(
args.getCacheId().value(),
input_aten_tensors.data(),
output_aten_tensors.data());
main_func_(cache_id, input_aten_tensors.data(), output_aten_tensors.data());

// Collect the outputs
KernelArgumentHolder outputs;
Expand Down
10 changes: 3 additions & 7 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,7 @@ Layout getCommunicationLayout(
TensorView* tv,
const CommunicationType type,
IterDomain* sharded_id) {
std::optional<Layout> canonical_layout = canonicalizeLayout(tv);
NVF_ERROR(canonical_layout.has_value());
Layout layout = canonical_layout->contiguous();
Layout layout = valueOrError(canonicalizeLayout(tv)).contiguous();
// For the following communication types, the sharded_id does not have to be
// outermost in allocation domain. Nonetheless, `tv` still needs to be
// contiguous and therefore .contiguous() at the beginning of this function.
Expand Down Expand Up @@ -548,9 +546,8 @@ bool isCommunicationLayoutCompliant(Expr* e) {

auto* producer = e->inputs().at(0)->as<TensorView>();
std::optional<Layout> p_layout = canonicalizeLayout(producer);
NVF_ERROR(p_layout.has_value());
if (!isCompliantWith(
*p_layout,
valueOrError(p_layout),
getCommunicationLayout(
producer,
communication_info.type,
Expand All @@ -560,9 +557,8 @@ bool isCommunicationLayoutCompliant(Expr* e) {

auto* consumer = e->outputs().at(0)->as<TensorView>();
std::optional<Layout> c_layout = canonicalizeLayout(consumer);
NVF_ERROR(c_layout.has_value());
if (!isCompliantWith(
*c_layout,
valueOrError(c_layout),
getCommunicationLayout(
consumer,
communication_info.type,
Expand Down
5 changes: 2 additions & 3 deletions csrc/ir/internal_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1053,12 +1053,11 @@ ArrayConstruct::ArrayConstruct(
for (auto in : inputs) {
addInput(in);
auto in_dtype_opt = in->getDataType();
NVF_ERROR(in_dtype_opt.has_value());
if (input_dtype == DataType::Null) {
input_dtype = *in_dtype_opt;
input_dtype = valueOrError(in_dtype_opt);
} else {
NVF_CHECK(
input_dtype == *in_dtype_opt,
input_dtype == valueOrError(in_dtype_opt),
"All inputs to ArrayConstruct must have the same data type");
}
}
Expand Down
6 changes: 0 additions & 6 deletions csrc/ops/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@

namespace nvfuser {

// Returns the value held by `opt`, or fails via NVF_CHECK when `opt` is
// empty. `opt` is taken by value and the contained value is copied into the
// return value, so the caller always receives an independent T.
template <typename T>
T valueOrError(std::optional<T> opt) {
// Reject a disengaged optional before dereferencing it.
NVF_CHECK(opt.has_value());
return *opt;
}

Val* castOp(DataType dtype, Val* v1) {
auto orig_dtype = valueOrError(v1->getDataType());
if (dtype == orig_dtype) {
Expand Down
3 changes: 1 addition & 2 deletions csrc/parallel_dimension_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
return true;
}
// Currently only support one warp specialized axis
NVF_ERROR(warp_specialized_parallel_type_.has_value());
ParallelType ws_pt = warp_specialized_parallel_type_.value();
ParallelType ws_pt = valueOrError(warp_specialized_parallel_type_);

// Check that BlockDim.x >= 32 active threads in AsyncWarp
if (ws_pt != ParallelType::TIDx) {
Expand Down
3 changes: 1 addition & 2 deletions csrc/runtime/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,11 @@ void KernelExecutor::compile(
if (!args.empty()) {
auto expr_eval =
executor_utils::bindInputs(args, compiled_kernel_->lowered()->kernel());
NVF_ERROR(compile_params.index_type.has_value());
launch_params = computeLaunchParams(
launch_constraints,
expr_eval,
warp_size_,
compile_params.index_type.value());
valueOrError(compile_params.index_type));
block_size = launch_params.nThreads();
dynamic_smem = launch_params.smem();
NVF_ERROR_GT(*block_size, 0);
Expand Down
8 changes: 2 additions & 6 deletions csrc/runtime/fusion_executor_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,13 +617,9 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
kernel_runtimes.begin(),
kernel_runtimes.end(),
[&args, &new_heuristics, &forced_index_type](auto& kernel_runtime) {
auto maybe_heuristics =
new_heuristics =
kernel_runtime->getMaybeHeuristicsFor(args, forced_index_type);
if (!maybe_heuristics.has_value()) {
return false;
}
new_heuristics = std::move(maybe_heuristics.value());
return true;
return new_heuristics != nullptr;
});
if (runtime_it != kernel_runtimes.end()) {
kernel_runtime = runtime_it->get();
Expand Down
28 changes: 12 additions & 16 deletions csrc/runtime/fusion_kernel_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,8 @@ FusionKernelRuntime::FusionKernelRuntime(
is_segmented_ = segmented_fusion_->groups().size() > 1;

// Create Initial Heuristics for Segmented Fusion
auto maybe_heuristics = getMaybeHeuristicsFor(args, forced_index_type);
NVF_CHECK(maybe_heuristics.has_value());
heuristics_ = std::move(maybe_heuristics.value());
heuristics_ = getMaybeHeuristicsFor(args, forced_index_type);
NVF_ERROR(heuristics_ != nullptr);
}

void FusionKernelRuntime::evictCache(size_t input_id) {
Expand Down Expand Up @@ -277,8 +276,7 @@ PrimDataType FusionKernelRuntime::getIndexType() const {
return PrimDataType::Int;
}
auto index_type = schedulers().at(0).get()->cparams.index_type;
NVF_ERROR(index_type.has_value());
return index_type.value();
return valueOrError(index_type);
}

KernelArgumentHolder FusionKernelRuntime::runWithInputs(
Expand Down Expand Up @@ -535,10 +533,9 @@ const ExecutorLog& FusionKernelRuntime::getMostRecentExecutorLog() const {
return most_recent_executor_log_;
}

std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
getMaybeHeuristicsFor(
const KernelArgumentHolder& args,
std::optional<PrimDataType> forced_index_type) {
std::unique_ptr<HeuristicParamsList> FusionKernelRuntime::getMaybeHeuristicsFor(
const KernelArgumentHolder& args,
std::optional<PrimDataType> forced_index_type) {
FUSER_PERF_SCOPE("FusionKernelRuntime::getMaybeHeuristicsFor");

// The runtime group run order is different from the segmented_fusion group
Expand Down Expand Up @@ -606,18 +603,17 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
// canScheduleRuntime, but it is safe to skip canScheduleCompileTime. We
// skip it here to avoid performing expensive fusion traversals on the
// dynamic shape path.
auto maybe_heuristic_params =
auto heuristic_params =
group_to_run->getMaybeHeuristicParams(fusion_to_run_info);
// If unavailable, then return std::nullopt
if (!maybe_heuristic_params.has_value()) {
return std::nullopt;
// If unavailable, then return nullptr
if (!heuristic_params) {
return nullptr;
}
// Check if this scheduler entry matches the previous entry for this
// segmented group. If no match, then return std::nullptr
auto heuristic_params = std::move(maybe_heuristic_params.value());
// segmented group. If no match, then return nullptr
if (!heuristic_params->sameAs(
heuristics_->at(group_to_run->groupId()).get())) {
return std::nullopt;
return nullptr;
}
// Add new scheduler entry for this segmented group
heuristics->at(group_to_run->groupId()) = std::move(heuristic_params);
Expand Down
6 changes: 3 additions & 3 deletions csrc/runtime/fusion_kernel_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ class FusionKernelRuntime {
const ExecutorLog& getMostRecentExecutorLog() const;

// Try to compute heuristics based on the SegmentedFusion managed
// in this kernel runtime, and will return a nullopt if either
// any segment cannot be scheduled or the parameters don't match
// in this kernel runtime, and will return nullptr if either
// any segment cannot be scheduled or the parameters don't match.
//
// Heuristics must use the index type of forced_index_type if given.
std::optional<std::unique_ptr<HeuristicParamsList>> getMaybeHeuristicsFor(
std::unique_ptr<HeuristicParamsList> getMaybeHeuristicsFor(
const KernelArgumentHolder& args,
std::optional<PrimDataType> forced_index_type = std::nullopt);

Expand Down
6 changes: 2 additions & 4 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,7 @@ class VectorizationCalculator {
}
// Record contiguity of concrete dimensions
std::optional<bool> contig_opt = tv->getContiguity().at(i);
NVF_ERROR(contig_opt.has_value());
concrete_contig.push_back(contig_opt.value());
concrete_contig.push_back(valueOrError(contig_opt));

PolymorphicValue ext =
runtime_info_.expressionEvaluator().evaluate(id->extent());
Expand Down Expand Up @@ -897,8 +896,7 @@ class VectorizationCalculator {
remaining_inner_dims.pop_back();

std::optional<bool> c = tv->getContiguity().at(i);
NVF_ERROR(c.has_value());
if (!c.value()) {
if (!valueOrError(c)) {
// axis is marked discontiguous; can't vectorize
break;
} else {
Expand Down
Loading
Loading