Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
ScopedSettings = "6ffd3f19-5aa5-475d-9277-f0318686a530"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
Expand Down Expand Up @@ -49,6 +50,7 @@ YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources]
ReactantCore = {path = "lib/ReactantCore"}
ScopedSettings = {url = "https://github.com/avik-pal/ScopedSettings.jl", rev = "ap/union_types"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
Expand Down Expand Up @@ -101,6 +103,7 @@ Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.15"
Reactant_jll = "0.0.237"
ScopedSettings = "0.1.1"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
5 changes: 0 additions & 5 deletions docs/src/api/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ CollapsedDocStrings = true

## Scoped Values

!!! warning

Currently options are scattered in the form of global variables and scoped values. We
are in the process of migrating all of them into scoped values.

```@docs
Reactant.with_config
```
Expand Down
54 changes: 36 additions & 18 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,42 @@ import Reactant: OptimizeCommunicationOptions, ShardyPropagationOptions, Compile

import ..ReactantCore: correct_maybe_bcast_call

const DEBUG_PRINT_CODEGEN = Ref(false)
const DEBUG_DISABLE_RESHARDING = Ref(false)
const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false)
using ScopedSettings: ScopedSetting, GetPreference

const DEBUG_PRINT_CODEGEN = ScopedSetting(
GetPreference(Reactant, "debug_print_codegen", false)
)
const DEBUG_DISABLE_RESHARDING = ScopedSetting(
GetPreference(Reactant, "debug_disable_resharding", false)
)
const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = ScopedSetting(
GetPreference(Reactant, "debug_aliased_buffer_assignment_error", false)
)
const DEBUG_KERNEL = ScopedSetting(GetPreference(Reactant, "debug_kernel", false))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these compile time attributes really don't make sense to be scoped, since they're set once per the entire compilation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just searched through and replaced them 😓. will go through all of them once

const DUMP_LLVMIR = ScopedSetting(GetPreference(Reactant, "debug_dump_llvmir", false))
const DUMP_FAILED_LOCKSTEP = ScopedSetting(
GetPreference(Reactant, "debug_dump_failed_lockstep", false)
)
const SROA_ATTRIBUTOR = ScopedSetting(GetPreference(Reactant, "sroa_attributor", false))

const WHILE_CONCAT = ScopedSetting(GetPreference(Reactant, "while_concat_passes", false))
const DUS_TO_CONCAT = ScopedSetting(GetPreference(Reactant, "dus_to_concat_passes", false))
const SUM_TO_REDUCEWINDOW = ScopedSetting(
GetPreference(Reactant, "sum_to_reducewindow_passes", false)
)
const SUM_TO_CONV = ScopedSetting(GetPreference(Reactant, "sum_to_conv_passes", false))
const AGGRESSIVE_SUM_TO_CONV = ScopedSetting(
GetPreference(Reactant, "aggressive_sum_to_conv_passes", false)
)
const AGGRESSIVE_PROPAGATION = ScopedSetting(
GetPreference(Reactant, "aggressive_propagation_passes", false)
)
const DUS_SLICE_SIMPLIFY = ScopedSetting(
GetPreference(Reactant, "dus_slice_simplify_passes", true)
)
const CONCATS_TO_DUS = ScopedSetting(GetPreference(Reactant, "concats_to_dus_passes", true))

const OpenMP = ScopedSetting(GetPreference(Reactant, "lower_jit_to_openmp", true))

const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict()

Expand Down Expand Up @@ -684,15 +717,6 @@ function create_result(
return Meta.quot(tocopy)
end

const WHILE_CONCAT = Ref(false)
const DUS_TO_CONCAT = Ref(false)
const SUM_TO_REDUCEWINDOW = Ref(false)
const SUM_TO_CONV = Ref(false)
const AGGRESSIVE_SUM_TO_CONV = Ref(false)
const AGGRESSIVE_PROPAGATION = Ref(false)
const DUS_SLICE_SIMPLIFY = Ref(true)
const CONCATS_TO_DUS = Ref(false)

# Optimization passes via transform dialect
function optimization_passes(
compile_options::CompileOptions;
Expand Down Expand Up @@ -1436,12 +1460,6 @@ function cubinFeatures()
return "+ptx$ptx"
end

const DEBUG_KERNEL = Ref{Bool}(false)
const DUMP_LLVMIR = Ref{Bool}(false)
const DUMP_FAILED_LOCKSTEP = Ref{Bool}(false)
const OpenMP = Ref{Bool}(true)
const SROA_ATTRIBUTOR = Ref{Bool}(true)

function activate_raising!(is_raising::Bool)
stack = get!(task_local_storage(), :reactant_is_raising) do
Bool[]
Expand Down
16 changes: 10 additions & 6 deletions src/Configuration.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ScopedValues: ScopedValues, ScopedValue
using ScopedValues: ScopedValues

export with_config
export DotGeneralAlgorithmPreset, PrecisionConfig, DotGeneralAlgorithm
Expand Down Expand Up @@ -63,8 +63,12 @@ function with_config(
end

# Lower to ApproxTopK
const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedValue(false)
const FALLBACK_APPROX_TOP_K_LOWERING = ScopedValue(true)
const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedSetting(
GetPreference(Reactant, "lower_partialsort_to_approx_top_k", false)
)
const FALLBACK_APPROX_TOP_K_LOWERING = ScopedSetting(
GetPreference(Reactant, "fallback_approx_top_k_lowering", true)
)

# DotGeneral Attributes Configuration
"""
Expand All @@ -88,13 +92,13 @@ end

Base.@deprecate_binding DotGeneralPrecision PrecisionConfig

const DOT_GENERAL_PRECISION = ScopedValue{
const DOT_GENERAL_PRECISION = ScopedSetting{
Union{PrecisionConfig.T,Nothing,Tuple{PrecisionConfig.T,PrecisionConfig.T}}
}(
PrecisionConfig.DEFAULT
)

const CONVOLUTION_PRECISION = ScopedValue{
const CONVOLUTION_PRECISION = ScopedSetting{
Union{PrecisionConfig.T,Nothing,Tuple{PrecisionConfig.T,PrecisionConfig.T}}
}(
PrecisionConfig.DEFAULT
Expand Down Expand Up @@ -224,7 +228,7 @@ The following functions are available:
TF32_TF32_F32_X3
end

const DOT_GENERAL_ALGORITHM = ScopedValue{
const DOT_GENERAL_ALGORITHM = ScopedSetting{
Union{DotGeneralAlgorithmPreset.T,Nothing,DotGeneralAlgorithm}
}(
DotGeneralAlgorithmPreset.DEFAULT
Expand Down
14 changes: 10 additions & 4 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,16 @@ function mlir_type(::Type{<:MissingTracedValue})::MLIR.IR.Type
return MLIR.IR.TensorType(Int[], MLIR.IR.Type(Bool))
end

const DEBUG_MODE::Ref{Bool} = Ref(false)
const LARGE_CONSTANT_THRESHOLD = Ref(100 << 20) # 100 MiB
const LARGE_CONSTANT_RAISE_ERROR = Ref(true)
const GATHER_GETINDEX_DISABLED = Ref(false)
const DEBUG_MODE = ScopedSetting(false)
const LARGE_CONSTANT_THRESHOLD = ScopedSetting(
GetPreference(Reactant, "large_constant_threshold", 100 << 20) # 100 MiB
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
GetPreference(Reactant, "large_constant_threshold", 100 << 20) # 100 MiB
GetPreference(Reactant, "large_constant_threshold", 100 << 20), # 100 MiB

)
const LARGE_CONSTANT_RAISE_ERROR = ScopedSetting(
GetPreference(Reactant, "large_constant_raise_error", true)
)
const GATHER_GETINDEX_DISABLED = ScopedSetting(
GetPreference(Reactant, "gather_getindex_disabled", false)
)

function with_debug(f)
old = DEBUG_MODE[]
Expand Down
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using Functors: Functors, @leaf
using Adapt: Adapt, WrappedArray
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

using ScopedSettings: ScopedSetting, GetPreference

export @allowscalar # re-exported from GPUArraysCore

is_extension_loaded(::Val) = false
Expand Down
2 changes: 2 additions & 0 deletions src/mlir/IR/IR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ export @affinemap

using Random: randstring

using ScopedSettings: ScopedSetting, GetPreference

function mlirIsNull(val)
return val.ptr == C_NULL
end
Expand Down
6 changes: 4 additions & 2 deletions src/mlir/IR/Pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ function enable_verifier!(pm, enable=true)
end

# Where to dump the MLIR modules
const DUMP_MLIR_DIR = Ref{Union{Nothing,String}}(nothing)
const DUMP_MLIR_DIR = ScopedSetting{Union{Nothing,String}}(
GetPreference(Reactant, "dump_mlir_dir", nothing)
)
# Whether to always dump MLIR, regardless of failure
const DUMP_MLIR_ALWAYS = Ref{Bool}(false)
const DUMP_MLIR_ALWAYS = ScopedSetting(GetPreference(Reactant, "dump_mlir_always", false))
# Counter for dumping MLIR modules
const MLIR_DUMP_COUNTER = Threads.Atomic{Int}(0)

Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ function safe_print(name, x)
return ccall(:jl_, Cvoid, (Any,), name * " " * string(x))
end

const DEBUG_INTERP = Ref(false)
const DEBUG_INTERP = ScopedSetting(GetPreference(Reactant, "debug_interpreter", false))

# Rewrite type unstable calls to recurse into call_with_reactant to ensure
# they continue to use our interpreter. Reset the derived return type
Expand Down
Loading