diff --git a/Project.toml b/Project.toml index d5e4eee..d4f9c20 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Static" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" authors = ["chriselrod", "ChrisRackauckas", "Tokazama"] -version = "0.2.4" +version = "0.2.5" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/Static.jl b/src/Static.jl index f6de921..653c9fc 100644 --- a/src/Static.jl +++ b/src/Static.jl @@ -14,8 +14,10 @@ else end -include("static_implementation.jl") +include("int.jl") +include("bool.jl") include("float.jl") +include("symbol.jl") include("tuples.jl") """ @@ -85,14 +87,14 @@ Returns `True` if `T` is a static type. See also: [`static`](@ref), [`known`](@ref) """ -is_static -@aggressive_constprop is_static(x) = is_static(typeof(x)) -is_static(::Type{T}) where {T<:StaticInt} = True() -is_static(::Type{T}) where {T<:StaticBool} = True() -is_static(::Type{T}) where {T<:StaticSymbol} = True() -is_static(::Type{T}) where {T<:Val} = True() -is_static(::Type{T}) where {T} = False() -is_static(::Type{T}) where {T<:StaticFloat64} = True() +is_static(@nospecialize(x)) = is_static(typeof(x)) +is_static(@nospecialize(x::Type{<:StaticInt})) = True() +is_static(@nospecialize(x::Type{<:StaticBool})) = True() +is_static(@nospecialize(x::Type{<:StaticSymbol})) = True() +is_static(@nospecialize(x::Type{<:Val})) = True() +is_static(@nospecialize(x::Type{<:StaticFloat64})) = True() +is_static(x::Type{T}) where {T} = False() + @aggressive_constprop _tuple_static(::Type{T}, i) where {T} = is_static(_get_tuple(T, i)) function is_static(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} if all(eachop(_tuple_static, nstatic(Val(N)), T)) diff --git a/src/bool.jl b/src/bool.jl new file mode 100644 index 0000000..b578b1e --- /dev/null +++ b/src/bool.jl @@ -0,0 +1,130 @@ + +""" + StaticBool(x::Bool) -> True/False + +A statically typed `Bool`. +""" +abstract type StaticBool{bool} <: Integer end + +struct True <: StaticBool{true} end + +struct False <: StaticBool{false} end + +StaticBool{true}() = True() +StaticBool{false}() = False() +StaticBool(x::StaticBool) = x +function StaticBool(x::Bool) + if x + return True() + else + return False() + end +end + +StaticInt(x::False) = Zero() +StaticInt(x::True) = One() +Base.Bool(::True) = true +Base.Bool(::False) = false + +Base.:(~)(::True) = False() +Base.:(~)(::False) = True() +Base.:(!)(::True) = False() +Base.:(!)(::False) = True() + +Base.:(==)(::True, ::True) = True() +Base.:(==)(::True, ::False) = False() +Base.:(==)(::False, ::True) = False() +Base.:(==)(::False, ::False) = True() + +Base.:(|)(x::StaticBool, y::StaticBool) = _or(x, y) +_or(::True, ::False) = True() +_or(::False, ::True) = True() +_or(::True, ::True) = True() +_or(::False, ::False) = False() +Base.:(|)(x::Bool, y::True) = y +Base.:(|)(x::Bool, y::False) = x +Base.:(|)(x::True, y::Bool) = x +Base.:(|)(x::False, y::Bool) = y + +Base.:(&)(x::StaticBool, y::StaticBool) = _and(x, y) +_and(::True, ::False) = False() +_and(::False, ::True) = False() +_and(::True, ::True) = True() +_and(::False, ::False) = False() +Base.:(&)(x::Bool, y::True) = x +Base.:(&)(x::Bool, y::False) = y +Base.:(&)(x::True, y::Bool) = y +Base.:(&)(x::False, y::Bool) = x + +Base.xor(y::StaticBool, x::StaticBool) = _xor(x, y) +_xor(::True, ::True) = False() +_xor(::True, ::False) = True() +_xor(::False, ::True) = True() +_xor(::False, ::False) = False() +Base.xor(x::Bool, y::StaticBool) = xor(x, Bool(y)) +Base.xor(x::StaticBool, y::Bool) = xor(Bool(x), y) + +Base.sign(x::StaticBool) = x +Base.abs(x::StaticBool) = x +Base.abs2(x::StaticBool) = x +Base.iszero(::True) = False() +Base.iszero(::False) = True() +Base.isone(::True) = True() +Base.isone(::False) = False() + +Base.:(<)(x::StaticBool, y::StaticBool) = _lt(x, y) +_lt(::False, ::True) = True() +_lt(::True, ::True) = False() +_lt(::False, ::False) = False() +_lt(::True, ::False) = False() + +Base.:(<=)(x::StaticBool, y::StaticBool) = _lteq(x, y) +_lteq(::False, ::True) = True() +_lteq(::True, ::True) = True() +_lteq(::False, ::False) = True() +_lteq(::True, ::False) = False() + +Base.:(+)(x::True) = One() +Base.:(+)(x::False) = Zero() +Base.:(-)(x::True) = -One() +Base.:(-)(x::False) = Zero() + +Base.:(+)(x::StaticBool, y::StaticBool) = StaticInt(x) + StaticInt(y) +Base.:(-)(x::StaticBool, y::StaticBool) = StaticInt(x) - StaticInt(y) +Base.:(*)(x::StaticBool, y::StaticBool) = x & y + +# from `^(x::Bool, y::Bool) = x | !y` +Base.:(^)(x::StaticBool, y::False) = True() +Base.:(^)(x::StaticBool, y::True) = x +Base.:(^)(x::Integer, y::False) = one(x) +Base.:(^)(x::Integer, y::True) = x +Base.:(^)(x::BigInt, y::False) = one(x) +Base.:(^)(x::BigInt, y::True) = x + +Base.div(x::StaticBool, y::False) = throw(DivideError()) +Base.div(x::StaticBool, y::True) = x + +Base.rem(x::StaticBool, y::False) = throw(DivideError()) +Base.rem(x::StaticBool, y::True) = False() +Base.mod(x::StaticBool, y::StaticBool) = rem(x, y) + +Base.promote_rule(::Type{<:StaticBool}, ::Type{<:StaticBool}) = StaticBool +Base.promote_rule(::Type{<:StaticBool}, ::Type{Bool}) = Bool +Base.promote_rule(::Type{Bool}, ::Type{<:StaticBool}) = Bool + +@generated _get_tuple(::Type{T}, ::StaticInt{i}) where {T<:Tuple, i} = T.parameters[i] + +Base.all(::Tuple{Vararg{True}}) = true +Base.all(::Tuple{Vararg{Union{True,False}}}) = false +Base.all(::Tuple{Vararg{False}}) = false + +Base.any(::Tuple{Vararg{True}}) = true +Base.any(::Tuple{Vararg{Union{True,False}}}) = true +Base.any(::Tuple{Vararg{False}}) = false + +ifelse(::True, x, y) = x + +ifelse(::False, x, y) = y + +Base.show(io::IO, ::StaticBool{bool}) where {bool} = print(io, "static($bool)") + diff --git a/src/static_implementation.jl b/src/int.jl similarity index 62% rename from src/static_implementation.jl rename to src/int.jl index df08f6a..33d5b39 100644 --- a/src/static_implementation.jl +++ b/src/int.jl @@ -87,7 +87,6 @@ end @inline Base.:(+)(::Zero, ::StaticInt{M}) where {M} = StaticInt{M}() @inline Base.:(+)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}() - @inline Base.:(-)(::StaticInt{M}) where {M} = StaticInt{-M}() @inline Base.:(-)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}() @@ -139,132 +138,6 @@ Base.UnitRange(start::StaticInt, stop) = UnitRange(Int(start), stop) Base.UnitRange(start, stop::StaticInt) = UnitRange(start, Int(stop)) Base.UnitRange(start::StaticInt, stop::StaticInt) = UnitRange(Int(start), Int(stop)) -""" - StaticBool(x::Bool) -> True/False - -A statically typed `Bool`. -""" -abstract type StaticBool <: Integer end - -StaticBool(x::StaticBool) = x - -struct True <: StaticBool end - -struct False <: StaticBool end - -function StaticBool(x::Bool) - if x - return True() - else - return False() - end -end - -StaticInt(x::False) = Zero() -StaticInt(x::True) = One() -Base.Bool(::True) = true -Base.Bool(::False) = false - -Base.:(~)(::True) = False() -Base.:(~)(::False) = True() -Base.:(!)(::True) = False() -Base.:(!)(::False) = True() - -Base.:(==)(::True, ::True) = True() -Base.:(==)(::True, ::False) = False() -Base.:(==)(::False, ::True) = False() -Base.:(==)(::False, ::False) = True() - -Base.:(|)(x::StaticBool, y::StaticBool) = _or(x, y) -_or(::True, ::False) = True() -_or(::False, ::True) = True() -_or(::True, ::True) = True() -_or(::False, ::False) = False() -Base.:(|)(x::Bool, y::True) = y -Base.:(|)(x::Bool, y::False) = x -Base.:(|)(x::True, y::Bool) = x -Base.:(|)(x::False, y::Bool) = y - -Base.:(&)(x::StaticBool, y::StaticBool) = _and(x, y) -_and(::True, ::False) = False() -_and(::False, ::True) = False() -_and(::True, ::True) = True() -_and(::False, ::False) = False() -Base.:(&)(x::Bool, y::True) = x -Base.:(&)(x::Bool, y::False) = y -Base.:(&)(x::True, y::Bool) = y -Base.:(&)(x::False, y::Bool) = x - -Base.xor(y::StaticBool, x::StaticBool) = _xor(x, y) -_xor(::True, ::True) = False() -_xor(::True, ::False) = True() -_xor(::False, ::True) = True() -_xor(::False, ::False) = False() -Base.xor(x::Bool, y::StaticBool) = xor(x, Bool(y)) -Base.xor(x::StaticBool, y::Bool) = xor(Bool(x), y) - -Base.sign(x::StaticBool) = x -Base.abs(x::StaticBool) = x -Base.abs2(x::StaticBool) = x -Base.iszero(::True) = False() -Base.iszero(::False) = True() -Base.isone(::True) = True() -Base.isone(::False) = False() - -Base.:(<)(x::StaticBool, y::StaticBool) = _lt(x, y) -_lt(::False, ::True) = True() -_lt(::True, ::True) = False() -_lt(::False, ::False) = False() -_lt(::True, ::False) = False() - -Base.:(<=)(x::StaticBool, y::StaticBool) = _lteq(x, y) -_lteq(::False, ::True) = True() -_lteq(::True, ::True) = True() -_lteq(::False, ::False) = True() -_lteq(::True, ::False) = False() - -Base.:(+)(x::True) = One() -Base.:(+)(x::False) = Zero() -Base.:(-)(x::True) = -One() -Base.:(-)(x::False) = Zero() - -Base.:(+)(x::StaticBool, y::StaticBool) = StaticInt(x) + StaticInt(y) -Base.:(-)(x::StaticBool, y::StaticBool) = StaticInt(x) - StaticInt(y) -Base.:(*)(x::StaticBool, y::StaticBool) = x & y - -# from `^(x::Bool, y::Bool) = x | !y` -Base.:(^)(x::StaticBool, y::False) = True() -Base.:(^)(x::StaticBool, y::True) = x -Base.:(^)(x::Integer, y::False) = one(x) -Base.:(^)(x::Integer, y::True) = x -Base.:(^)(x::BigInt, y::False) = one(x) -Base.:(^)(x::BigInt, y::True) = x - -Base.div(x::StaticBool, y::False) = throw(DivideError()) -Base.div(x::StaticBool, y::True) = x - -Base.rem(x::StaticBool, y::False) = throw(DivideError()) -Base.rem(x::StaticBool, y::True) = False() -Base.mod(x::StaticBool, y::StaticBool) = rem(x, y) - -Base.promote_rule(::Type{<:StaticBool}, ::Type{<:StaticBool}) = StaticBool -Base.promote_rule(::Type{<:StaticBool}, ::Type{Bool}) = Bool -Base.promote_rule(::Type{Bool}, ::Type{<:StaticBool}) = Bool - -@generated _get_tuple(::Type{T}, ::StaticInt{i}) where {T<:Tuple, i} = T.parameters[i] - -Base.all(::Tuple{Vararg{True}}) = true -Base.all(::Tuple{Vararg{Union{True,False}}}) = false -Base.all(::Tuple{Vararg{False}}) = false - -Base.any(::Tuple{Vararg{True}}) = true -Base.any(::Tuple{Vararg{Union{True,False}}}) = true -Base.any(::Tuple{Vararg{False}}) = false - -ifelse(::True, x, y) = x - -ifelse(::False, x, y) = y - """ eq(x, y) @@ -313,26 +186,3 @@ Equivalent to `<` but if `x` and `y` are both static returns a `StaticBool. lt(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x < y) lt(x::X) where {X} = Base.Fix2(lt, x) -""" - StaticSymbol - -A statically typed `Symbol`. -""" -struct StaticSymbol{s} - StaticSymbol{s}() where {s} = new{s::Symbol}() - StaticSymbol(s::Symbol) = new{s}() - StaticSymbol(x::StaticSymbol) = x - StaticSymbol(x) = StaticSymbol(Symbol(x)) -end -StaticSymbol(x, y) = StaticSymbol(Symbol(x, y)) -StaticSymbol(x::StaticSymbol, y::StaticSymbol) = _cat_syms(x, y) -@generated function _cat_syms(::StaticSymbol{x}, ::StaticSymbol{y}) where {x,y} - return :(StaticSymbol{$(QuoteNode(Symbol(x, y)))}()) -end -StaticSymbol(x, y, z...) = StaticSymbol(StaticSymbol(x, y), z...) - -Base.Symbol(::StaticSymbol{s}) where {s} = s::Symbol - -Base.show(io::IO, ::StaticSymbol{s}) where {s} = print(io, "static(:$s)") - - diff --git a/src/symbol.jl b/src/symbol.jl new file mode 100644 index 0000000..05f8dd4 --- /dev/null +++ b/src/symbol.jl @@ -0,0 +1,23 @@ + +""" + StaticSymbol + +A statically typed `Symbol`. +""" +struct StaticSymbol{s} + StaticSymbol{s}() where {s} = new{s::Symbol}() + StaticSymbol(s::Symbol) = new{s}() + StaticSymbol(x::StaticSymbol) = x + StaticSymbol(x) = StaticSymbol(Symbol(x)) +end +StaticSymbol(x, y) = StaticSymbol(Symbol(x, y)) +StaticSymbol(x::StaticSymbol, y::StaticSymbol) = _cat_syms(x, y) +@generated function _cat_syms(::StaticSymbol{x}, ::StaticSymbol{y}) where {x,y} + return :(StaticSymbol{$(QuoteNode(Symbol(x, y)))}()) +end +StaticSymbol(x, y, z...) = StaticSymbol(StaticSymbol(x, y), z...) + +Base.Symbol(::StaticSymbol{s}) where {s} = s::Symbol + +Base.show(io::IO, ::StaticSymbol{s}) where {s} = print(io, "static(:$s)") + diff --git a/test/runtests.jl b/test/runtests.jl index 4ab2bd0..4c34fbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -86,6 +86,9 @@ using Test t = static(static(true)) f = StaticBool(static(false)) + @test StaticBool{true}() === t + @test StaticBool{false}() === f + @test @inferred(StaticInt(t)) === StaticInt(1) @test @inferred(StaticInt(f)) === StaticInt(0) @@ -274,6 +277,7 @@ using Test @test repr(static(float(1))) == "static($(float(1)))" @test repr(static(1)) == "static(1)" @test repr(static(:x)) == "static(:x)" + @test repr(static(true)) == "static(true)" end # for some reason this can't be inferred when in the "Static.jl" test set