Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Static"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
authors = ["chriselrod", "ChrisRackauckas", "Tokazama"]
version = "0.2.4"
version = "0.2.5"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Expand Down
20 changes: 11 additions & 9 deletions src/Static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ else
end


include("static_implementation.jl")
include("int.jl")
include("bool.jl")
include("float.jl")
include("symbol.jl")
include("tuples.jl")

"""
Expand Down Expand Up @@ -85,14 +87,14 @@ Returns `True` if `T` is a static type.

See also: [`static`](@ref), [`known`](@ref)
"""
is_static
@aggressive_constprop is_static(x) = is_static(typeof(x))
is_static(::Type{T}) where {T<:StaticInt} = True()
is_static(::Type{T}) where {T<:StaticBool} = True()
is_static(::Type{T}) where {T<:StaticSymbol} = True()
is_static(::Type{T}) where {T<:Val} = True()
is_static(::Type{T}) where {T} = False()
is_static(::Type{T}) where {T<:StaticFloat64} = True()
is_static(@nospecialize(x)) = is_static(typeof(x))
is_static(@nospecialize(x::Type{<:StaticInt})) = True()
is_static(@nospecialize(x::Type{<:StaticBool})) = True()
is_static(@nospecialize(x::Type{<:StaticSymbol})) = True()
is_static(@nospecialize(x::Type{<:Val})) = True()
is_static(@nospecialize(x::Type{<:StaticFloat64})) = True()
is_static(x::Type{T}) where {T} = False()

Comment on lines +90 to +97
Copy link
Collaborator Author

@Tokazama Tokazama May 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@timholy , If I want to avoid generating a unique method for every variant of StaticInt is this the right approach or will it kill inference somewhere down the line? I used MethodAnalysis.methodinstances and it seems to avoid make extra methods but I'm never sure with this sort of stuff

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably everything still infers and optimizes fine, because the behavior doesn't actually depend on the particular types?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the goal. I was hoping to start strategically adding these to trait functions so that we can avoid some of the pitfalls of StaticArrays.

@aggressive_constprop _tuple_static(::Type{T}, i) where {T} = is_static(_get_tuple(T, i))
function is_static(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
if all(eachop(_tuple_static, nstatic(Val(N)), T))
Expand Down
130 changes: 130 additions & 0 deletions src/bool.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@

"""
StaticBool(x::Bool) -> True/False

A statically typed `Bool`.
"""
abstract type StaticBool{bool} <: Integer end

struct True <: StaticBool{true} end

struct False <: StaticBool{false} end

StaticBool{true}() = True()
StaticBool{false}() = False()
StaticBool(x::StaticBool) = x
function StaticBool(x::Bool)
if x
return True()
else
return False()
end
end

StaticInt(x::False) = Zero()
StaticInt(x::True) = One()
Base.Bool(::True) = true
Base.Bool(::False) = false

Base.:(~)(::True) = False()
Base.:(~)(::False) = True()
Base.:(!)(::True) = False()
Base.:(!)(::False) = True()

Base.:(==)(::True, ::True) = True()
Base.:(==)(::True, ::False) = False()
Base.:(==)(::False, ::True) = False()
Base.:(==)(::False, ::False) = True()

Base.:(|)(x::StaticBool, y::StaticBool) = _or(x, y)
_or(::True, ::False) = True()
_or(::False, ::True) = True()
_or(::True, ::True) = True()
_or(::False, ::False) = False()
Base.:(|)(x::Bool, y::True) = y
Base.:(|)(x::Bool, y::False) = x
Base.:(|)(x::True, y::Bool) = x
Base.:(|)(x::False, y::Bool) = y

Base.:(&)(x::StaticBool, y::StaticBool) = _and(x, y)
_and(::True, ::False) = False()
_and(::False, ::True) = False()
_and(::True, ::True) = True()
_and(::False, ::False) = False()
Base.:(&)(x::Bool, y::True) = x
Base.:(&)(x::Bool, y::False) = y
Base.:(&)(x::True, y::Bool) = y
Base.:(&)(x::False, y::Bool) = x

Base.xor(y::StaticBool, x::StaticBool) = _xor(x, y)
_xor(::True, ::True) = False()
_xor(::True, ::False) = True()
_xor(::False, ::True) = True()
_xor(::False, ::False) = False()
Base.xor(x::Bool, y::StaticBool) = xor(x, Bool(y))
Base.xor(x::StaticBool, y::Bool) = xor(Bool(x), y)

Base.sign(x::StaticBool) = x
Base.abs(x::StaticBool) = x
Base.abs2(x::StaticBool) = x
Base.iszero(::True) = False()
Base.iszero(::False) = True()
Base.isone(::True) = True()
Base.isone(::False) = False()

Base.:(<)(x::StaticBool, y::StaticBool) = _lt(x, y)
_lt(::False, ::True) = True()
_lt(::True, ::True) = False()
_lt(::False, ::False) = False()
_lt(::True, ::False) = False()

Base.:(<=)(x::StaticBool, y::StaticBool) = _lteq(x, y)
_lteq(::False, ::True) = True()
_lteq(::True, ::True) = True()
_lteq(::False, ::False) = True()
_lteq(::True, ::False) = False()

Base.:(+)(x::True) = One()
Base.:(+)(x::False) = Zero()
Base.:(-)(x::True) = -One()
Base.:(-)(x::False) = Zero()

Base.:(+)(x::StaticBool, y::StaticBool) = StaticInt(x) + StaticInt(y)
Base.:(-)(x::StaticBool, y::StaticBool) = StaticInt(x) - StaticInt(y)
Base.:(*)(x::StaticBool, y::StaticBool) = x & y

# from `^(x::Bool, y::Bool) = x | !y`
Base.:(^)(x::StaticBool, y::False) = True()
Base.:(^)(x::StaticBool, y::True) = x
Base.:(^)(x::Integer, y::False) = one(x)
Base.:(^)(x::Integer, y::True) = x
Base.:(^)(x::BigInt, y::False) = one(x)
Base.:(^)(x::BigInt, y::True) = x

Base.div(x::StaticBool, y::False) = throw(DivideError())
Base.div(x::StaticBool, y::True) = x

Base.rem(x::StaticBool, y::False) = throw(DivideError())
Base.rem(x::StaticBool, y::True) = False()
Base.mod(x::StaticBool, y::StaticBool) = rem(x, y)

Base.promote_rule(::Type{<:StaticBool}, ::Type{<:StaticBool}) = StaticBool
Base.promote_rule(::Type{<:StaticBool}, ::Type{Bool}) = Bool
Base.promote_rule(::Type{Bool}, ::Type{<:StaticBool}) = Bool

@generated _get_tuple(::Type{T}, ::StaticInt{i}) where {T<:Tuple, i} = T.parameters[i]

Base.all(::Tuple{Vararg{True}}) = true
Base.all(::Tuple{Vararg{Union{True,False}}}) = false
Base.all(::Tuple{Vararg{False}}) = false

Base.any(::Tuple{Vararg{True}}) = true
Base.any(::Tuple{Vararg{Union{True,False}}}) = true
Base.any(::Tuple{Vararg{False}}) = false

ifelse(::True, x, y) = x

ifelse(::False, x, y) = y

Base.show(io::IO, ::StaticBool{bool}) where {bool} = print(io, "static($bool)")

150 changes: 0 additions & 150 deletions src/static_implementation.jl → src/int.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ end
@inline Base.:(+)(::Zero, ::StaticInt{M}) where {M} = StaticInt{M}()
@inline Base.:(+)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}()


@inline Base.:(-)(::StaticInt{M}) where {M} = StaticInt{-M}()
@inline Base.:(-)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}()

Expand Down Expand Up @@ -139,132 +138,6 @@ Base.UnitRange(start::StaticInt, stop) = UnitRange(Int(start), stop)
Base.UnitRange(start, stop::StaticInt) = UnitRange(start, Int(stop))
Base.UnitRange(start::StaticInt, stop::StaticInt) = UnitRange(Int(start), Int(stop))

"""
StaticBool(x::Bool) -> True/False

A statically typed `Bool`.
"""
abstract type StaticBool <: Integer end

StaticBool(x::StaticBool) = x

struct True <: StaticBool end

struct False <: StaticBool end

function StaticBool(x::Bool)
if x
return True()
else
return False()
end
end

StaticInt(x::False) = Zero()
StaticInt(x::True) = One()
Base.Bool(::True) = true
Base.Bool(::False) = false

Base.:(~)(::True) = False()
Base.:(~)(::False) = True()
Base.:(!)(::True) = False()
Base.:(!)(::False) = True()

Base.:(==)(::True, ::True) = True()
Base.:(==)(::True, ::False) = False()
Base.:(==)(::False, ::True) = False()
Base.:(==)(::False, ::False) = True()

Base.:(|)(x::StaticBool, y::StaticBool) = _or(x, y)
_or(::True, ::False) = True()
_or(::False, ::True) = True()
_or(::True, ::True) = True()
_or(::False, ::False) = False()
Base.:(|)(x::Bool, y::True) = y
Base.:(|)(x::Bool, y::False) = x
Base.:(|)(x::True, y::Bool) = x
Base.:(|)(x::False, y::Bool) = y

Base.:(&)(x::StaticBool, y::StaticBool) = _and(x, y)
_and(::True, ::False) = False()
_and(::False, ::True) = False()
_and(::True, ::True) = True()
_and(::False, ::False) = False()
Base.:(&)(x::Bool, y::True) = x
Base.:(&)(x::Bool, y::False) = y
Base.:(&)(x::True, y::Bool) = y
Base.:(&)(x::False, y::Bool) = x

Base.xor(y::StaticBool, x::StaticBool) = _xor(x, y)
_xor(::True, ::True) = False()
_xor(::True, ::False) = True()
_xor(::False, ::True) = True()
_xor(::False, ::False) = False()
Base.xor(x::Bool, y::StaticBool) = xor(x, Bool(y))
Base.xor(x::StaticBool, y::Bool) = xor(Bool(x), y)

Base.sign(x::StaticBool) = x
Base.abs(x::StaticBool) = x
Base.abs2(x::StaticBool) = x
Base.iszero(::True) = False()
Base.iszero(::False) = True()
Base.isone(::True) = True()
Base.isone(::False) = False()

Base.:(<)(x::StaticBool, y::StaticBool) = _lt(x, y)
_lt(::False, ::True) = True()
_lt(::True, ::True) = False()
_lt(::False, ::False) = False()
_lt(::True, ::False) = False()

Base.:(<=)(x::StaticBool, y::StaticBool) = _lteq(x, y)
_lteq(::False, ::True) = True()
_lteq(::True, ::True) = True()
_lteq(::False, ::False) = True()
_lteq(::True, ::False) = False()

Base.:(+)(x::True) = One()
Base.:(+)(x::False) = Zero()
Base.:(-)(x::True) = -One()
Base.:(-)(x::False) = Zero()

Base.:(+)(x::StaticBool, y::StaticBool) = StaticInt(x) + StaticInt(y)
Base.:(-)(x::StaticBool, y::StaticBool) = StaticInt(x) - StaticInt(y)
Base.:(*)(x::StaticBool, y::StaticBool) = x & y

# from `^(x::Bool, y::Bool) = x | !y`
Base.:(^)(x::StaticBool, y::False) = True()
Base.:(^)(x::StaticBool, y::True) = x
Base.:(^)(x::Integer, y::False) = one(x)
Base.:(^)(x::Integer, y::True) = x
Base.:(^)(x::BigInt, y::False) = one(x)
Base.:(^)(x::BigInt, y::True) = x

Base.div(x::StaticBool, y::False) = throw(DivideError())
Base.div(x::StaticBool, y::True) = x

Base.rem(x::StaticBool, y::False) = throw(DivideError())
Base.rem(x::StaticBool, y::True) = False()
Base.mod(x::StaticBool, y::StaticBool) = rem(x, y)

Base.promote_rule(::Type{<:StaticBool}, ::Type{<:StaticBool}) = StaticBool
Base.promote_rule(::Type{<:StaticBool}, ::Type{Bool}) = Bool
Base.promote_rule(::Type{Bool}, ::Type{<:StaticBool}) = Bool

@generated _get_tuple(::Type{T}, ::StaticInt{i}) where {T<:Tuple, i} = T.parameters[i]

Base.all(::Tuple{Vararg{True}}) = true
Base.all(::Tuple{Vararg{Union{True,False}}}) = false
Base.all(::Tuple{Vararg{False}}) = false

Base.any(::Tuple{Vararg{True}}) = true
Base.any(::Tuple{Vararg{Union{True,False}}}) = true
Base.any(::Tuple{Vararg{False}}) = false

ifelse(::True, x, y) = x

ifelse(::False, x, y) = y

"""
eq(x, y)

Expand Down Expand Up @@ -313,26 +186,3 @@ Equivalent to `<` but if `x` and `y` are both static returns a `StaticBool.
lt(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x < y)
lt(x::X) where {X} = Base.Fix2(lt, x)

"""
StaticSymbol

A statically typed `Symbol`.
"""
struct StaticSymbol{s}
StaticSymbol{s}() where {s} = new{s::Symbol}()
StaticSymbol(s::Symbol) = new{s}()
StaticSymbol(x::StaticSymbol) = x
StaticSymbol(x) = StaticSymbol(Symbol(x))
end
StaticSymbol(x, y) = StaticSymbol(Symbol(x, y))
StaticSymbol(x::StaticSymbol, y::StaticSymbol) = _cat_syms(x, y)
@generated function _cat_syms(::StaticSymbol{x}, ::StaticSymbol{y}) where {x,y}
return :(StaticSymbol{$(QuoteNode(Symbol(x, y)))}())
end
StaticSymbol(x, y, z...) = StaticSymbol(StaticSymbol(x, y), z...)

Base.Symbol(::StaticSymbol{s}) where {s} = s::Symbol

Base.show(io::IO, ::StaticSymbol{s}) where {s} = print(io, "static(:$s)")


23 changes: 23 additions & 0 deletions src/symbol.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

"""
StaticSymbol

A statically typed `Symbol`.
"""
struct StaticSymbol{s}
StaticSymbol{s}() where {s} = new{s::Symbol}()
StaticSymbol(s::Symbol) = new{s}()
StaticSymbol(x::StaticSymbol) = x
StaticSymbol(x) = StaticSymbol(Symbol(x))
end
StaticSymbol(x, y) = StaticSymbol(Symbol(x, y))
StaticSymbol(x::StaticSymbol, y::StaticSymbol) = _cat_syms(x, y)
@generated function _cat_syms(::StaticSymbol{x}, ::StaticSymbol{y}) where {x,y}
return :(StaticSymbol{$(QuoteNode(Symbol(x, y)))}())
end
StaticSymbol(x, y, z...) = StaticSymbol(StaticSymbol(x, y), z...)

Base.Symbol(::StaticSymbol{s}) where {s} = s::Symbol

Base.show(io::IO, ::StaticSymbol{s}) where {s} = print(io, "static(:$s)")

4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ using Test
t = static(static(true))
f = StaticBool(static(false))

@test StaticBool{true}() === t
@test StaticBool{false}() === f

@test @inferred(StaticInt(t)) === StaticInt(1)
@test @inferred(StaticInt(f)) === StaticInt(0)

Expand Down Expand Up @@ -274,6 +277,7 @@ using Test
@test repr(static(float(1))) == "static($(float(1)))"
@test repr(static(1)) == "static(1)"
@test repr(static(:x)) == "static(:x)"
@test repr(static(true)) == "static(true)"
end

# for some reason this can't be inferred when in the "Static.jl" test set
Expand Down