diff --git a/Project.toml b/Project.toml index 114835e..a60c6cd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Static" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" authors = ["chriselrod", "ChrisRackauckas", "Tokazama"] -version = "0.6.2" +version = "0.6.3" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/Static.jl b/src/Static.jl index 9aae582..8597679 100644 --- a/src/Static.jl +++ b/src/Static.jl @@ -30,20 +30,17 @@ Returns the known value corresponding to a static type `T`. If `T` is not a stat See also: [`static`](@ref), [`is_static`](@ref) """ -known -@constprop :aggressive known(x) = known(typeof(x)) -known(::Type{T}) where {T} = nothing -known(::Type{StaticInt{N}}) where {N} = N::Int -known(::Type{StaticFloat64{N}}) where {N} = N::Float64 -known(::Type{StaticSymbol{S}}) where {S} = S::Symbol +known(@nospecialize(T::Type{<:StaticInt}))::Int = T.parameters[1] +known(@nospecialize(T::Type{<:StaticFloat64}))::Float64 = T.parameters[1] +known(@nospecialize(T::Type{<:StaticSymbol}))::Symbol = T.parameters[1] known(::Type{Val{V}}) where {V} = V known(::Type{True}) = true known(::Type{False}) = false -known(::Type{NDIndex{N,I}}) where {N,I} = known(I) +known(@nospecialize(T::Type{<:NDIndex})) = known(T.parameters[2]) _get_known(::Type{T}, dim::StaticInt{D}) where {T,D} = known(field_type(T, dim)) -function known(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - return eachop(_get_known, nstatic(Val(N)), T) -end +known(@nospecialize(T::Type{<:Tuple})) = eachop(_get_known, nstatic(Val(fieldcount(T))), T) +known(T::DataType) = nothing +known(@nospecialize(x)) = known(typeof(x)) """ static(x) @@ -67,8 +64,7 @@ static(:x) ``` """ -static -@constprop :aggressive static(x::X) where {X} = ifelse(is_static(X), identity, _no_static_type)(x) +static(@nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False})) = x @constprop :aggressive static(x::Int) = StaticInt(x) @constprop :aggressive static(x::Union{Int8,UInt8,Int16,UInt16}) = StaticInt(x % Int) @static if sizeof(Int) == 8 @@ -81,12 +77,9 @@ end @constprop :aggressive static(x::Bool) = StaticBool(x) @constprop :aggressive static(x::Symbol) = StaticSymbol(x) @constprop :aggressive static(x::Tuple{Vararg{Any}}) = map(static, x) -@generated static(::Val{V}) where {V} = static(V) -function _no_static_type(@nospecialize(x)) - error("There is no static alternative for type $(typeof(x)).") -end -static(x::CartesianIndex) = NDIndex(static(Tuple(x))) - +static(::Val{V}) where {V} = static(V) +static(@nospecialize(x::CartesianIndex)) = NDIndex(static(Tuple(x))) +static(x) = error("There is no static alternative for type $(typeof(x)).") """ is_static(::Type{T}) -> StaticBool @@ -96,43 +89,42 @@ Returns `True` if `T` is a static type. See also: [`static`](@ref), [`known`](@ref) """ is_static(@nospecialize(x)) = is_static(typeof(x)) -is_static(@nospecialize(x::Type{<:StaticInt})) = True() -is_static(@nospecialize(x::Type{<:StaticBool})) = True() -is_static(@nospecialize(x::Type{<:StaticSymbol})) = True() +is_static(@nospecialize(x::Type{<:Union{StaticInt,StaticSymbol,StaticFloat64,True,False}})) = True() is_static(@nospecialize(x::Type{<:Val})) = True() -is_static(@nospecialize(x::Type{<:StaticFloat64})) = True() -is_static(x::Type{T}) where {T} = False() - @constprop :aggressive _tuple_static(::Type{T}, i) where {T} = is_static(field_type(T, i)) -function is_static(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - if all(eachop(_tuple_static, nstatic(Val(N)), T)) +@inline function is_static(@nospecialize(T::Type{<:Tuple})) + if all(eachop(_tuple_static, nstatic(Val(fieldcount(T))), T)) return True() else return False() end end +is_static(T::DataType) = False() """ dynamic(x) Returns the "dynamic" or non-static form of `x`. """ -dynamic(x::X) where {X} = _dynamic(is_static(X), x) -_dynamic(::True, x::X) where {X} = known(X) -_dynamic(::False, x::X) where {X} = x -@constprop :aggressive dynamic(x::Tuple) = map(dynamic, x) -dynamic(x::NDIndex) = CartesianIndex(dynamic(Tuple(x))) - - -# Base.string usually shouldn't be given unique methods but primitive types are processed in -# unique ways and we can avoid some needless specialization by defining these -for T in [StaticInt, StaticFloat64, StaticBool, StaticSymbol] - @eval begin - function Base.string(@nospecialize(x::$(T)); kwargs...) - string("static(" * repr(known(typeof(x))) * ")"; kwargs...) - end - Base.show(io::IO, @nospecialize(x::$(T))) = show(io, MIME"text/plain"(), x) - Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::$(T))) = print(io, string(x)) +@inline dynamic(@nospecialize(x)) = ifelse(is_static(typeof(x)), known, identity)(x) +dynamic(@nospecialize(x::Tuple)) = map(dynamic, x) +dynamic(@nospecialize(x::NDIndex)) = CartesianIndex(dynamic(Tuple(x))) + +function Base.string(@nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False}); kwargs...) + string("static(" * repr(known(typeof(x))) * ")"; kwargs...) +end +Base.show(io::IO, @nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False})) = show(io, MIME"text/plain"(), x) +Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False})) = print(io, string(x)) + +# This method assumes that `f` uetrieves compile time information and `g` is the fall back +# for the corresponding dynamic method. If the `f(x)` doesn't return `nothing` that means +# the value is known and compile time and returns `static(f(x))`. +@inline function maybe_static(f::F, g::G, x) where {F,G} + L = f(x) + if L === nothing + return g(x) + else + return static(L) end end diff --git a/src/float.jl b/src/float.jl index 94ebe49..deea86f 100644 --- a/src/float.jl +++ b/src/float.jl @@ -13,27 +13,25 @@ end (::Type{T})(x::Integer) where {T<:StaticFloat64} = StaticFloat64(x) (::Type{T})(x::AbstractFloat) where {T<:StaticFloat64} = StaticFloat64(x) -@generated function Base.AbstractFloat(::StaticInt{N}) where {N} - Expr(:call, Expr(:curly, :StaticFloat64, Float64(N))) -end -StaticFloat64(x::StaticInt{N}) where {N} = float(x) +Base.AbstractFloat(::StaticInt{N}) where {N} = StaticFloat64{Float64(N)}() +StaticFloat64(@nospecialize(x::StaticInt)) = float(x) const FloatOne = StaticFloat64{one(Float64)} const FloatZero = StaticFloat64{zero(Float64)} -Base.convert(::Type{T}, ::StaticFloat64{N}) where {N,T<:AbstractFloat} = T(N) -Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{T}) where {N,T} = promote_type(T, Float64) -Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{Float64}) where {N} = Float64 -Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{Float32}) where {N} = Float32 -Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{Float16}) where {N} = Float16 +Base.convert(::Type{T}, @nospecialize(x::StaticFloat64)) where {T<:AbstractFloat} = T(known(x)) +Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{T2}) where {T2} = promote_type(T2, Float64) +Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{Float64}) = Float64 +Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{Float32}) = Float32 +Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{Float16}) = Float16 -Base.eltype(::Type{T}) where {T<:StaticFloat64} = Float64 +Base.eltype(@nospecialize(T::Type{<:StaticFloat64})) = Float64 Base.iszero(::FloatZero) = true -Base.iszero(::StaticFloat64) = false +Base.iszero(@nospecialize(x::StaticFloat64)) = false Base.isone(::FloatOne) = true -Base.isone(::StaticFloat64) = false -Base.zero(::Type{T}) where {T<:StaticFloat64} = FloatZero() -Base.one(::Type{T}) where {T<:StaticFloat64} = FloatOne() +Base.isone(@nospecialize(x::StaticFloat64)) = false +Base.zero(@nospecialize(x::Type{<:StaticFloat64})) = FloatZero +Base.one(@nospecialize(x::Type{<:StaticFloat64})) = FloatOne() function fsub(::StaticFloat64{X}, ::StaticFloat64{Y}) where {X,Y} return StaticFloat64{Base.sub_float(X, Y)::Float64}() @@ -51,46 +49,24 @@ function fmul(::StaticFloat64{X}, ::StaticFloat64{Y}) where {X,Y} return StaticFloat64{Base.mul_float(X, Y)::Float64}() end -Base.:+(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fadd(x, y) -Base.:+(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = +(x, float(y)) -Base.:+(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = +(float(x), y) -Base.:+(x::FloatZero, ::FloatZero) = x -Base.:+(x::StaticFloat64{X}, ::FloatZero) where {X} = x -Base.:+(::FloatZero, y::StaticFloat64{Y}) where {Y} = y -Base.:+(x::StaticFloat64{X}, ::Zero) where {X} = x -Base.:+(::Zero, y::StaticFloat64{Y}) where {Y} = y +Base.:+(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fadd(x, y) +@inline Base.:+(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = +(x, float(y)) +@inline Base.:+(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = +(float(x), y) Base.:-(::StaticFloat64{X}) where {X} = StaticFloat64{-X}() -Base.:-(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fsub(x, y) -Base.:-(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = -(x, float(y)) -Base.:-(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = -(float(x), y) -Base.:-(x::FloatZero, ::FloatZero) = x -Base.:-(x::StaticFloat64{X}, ::FloatZero) where {X} = x -Base.:-(x::StaticFloat64{X}, ::Zero) where {X} = x -Base.:-(::FloatZero, y::StaticFloat64{Y}) where {Y} = -y -Base.:-(::Zero, y::StaticFloat64{Y}) where {Y} = -y - -Base.:*(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fmul(x, y) -Base.:*(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = *(x, float(y)) -Base.:*(::StaticFloat64{X}, ::Zero) where {X} = FloatZero() -Base.:*(::Zero, ::StaticFloat64{Y}, ) where {Y} = FloatZero() -Base.:*(x::StaticFloat64{X}, ::One) where {X} = x -Base.:*(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = *(float(x), y) -Base.:*(::One, y::StaticFloat64{Y}) where {Y} = y -Base.:*(x::FloatZero, ::FloatZero) = x -Base.:*(::StaticFloat64{X}, y::FloatZero) where {X} = y -Base.:*(x::FloatZero, ::StaticFloat64{Y}) where {Y} = x -Base.:*(x::FloatZero, ::FloatOne) = x -Base.:*(x::FloatOne, ::FloatOne) = x -Base.:*(x::StaticFloat64{X}, ::FloatOne) where {X} = x -Base.:*(::FloatOne, y::StaticFloat64{Y}) where {Y} = y -Base.:*(::FloatOne, y::FloatZero) = y - -Base.:/(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fdiv(x, y) -Base.:/(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = /(x, float(y)) -Base.:/(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = /(float(x), y) - -@generated Base.sqrt(::StaticInt{M}) where {M} = Expr(:call, Expr(:curly, :StaticFloat64, sqrt(M))) +Base.:-(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fsub(x, y) +@inline Base.:-(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = -(x, float(y)) +@inline Base.:-(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = -(float(x), y) + +Base.:*(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fmul(x, y) +@inline Base.:*(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = *(x, float(y)) +@inline Base.:*(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = *(float(x), y) + +Base.:/(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fdiv(x, y) +Base.:/(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = /(x, float(y)) +Base.:/(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = /(float(x), y) + +Base.sqrt(@nospecialize(x::StaticInt)) = sqrt(float(x)) @generated Base.sqrt(::StaticFloat64{M}) where {M} = Expr(:call, Expr(:curly, :StaticFloat64, sqrt(M))) @generated Base.round(::StaticFloat64{M}) where {M} = Expr(:call, Expr(:curly, :StaticFloat64, round(M))) @@ -103,13 +79,7 @@ Base.:(^)(::StaticFloat64{x}, y::Float64) where {x} = exp2(log2(x) * y) Base.inv(x::StaticFloat64{N}) where {N} = fdiv(one(x), x) -# @generated function Base.exponent(::StaticFloat64{M}) where {M} -# Expr(:call, Expr(:curly, :StaticInt, exponent(M))) -# end - -@inline function Base.exponent(::StaticFloat64{M}) where {M} - static(exponent(M)) -end +@inline Base.exponent(::StaticFloat64{M}) where {M} = static(exponent(M)) for f in (:rad2deg, :deg2rad, :cbrt, :mod2pi, :rem2pi, :sinpi, :cospi, @@ -122,7 +92,6 @@ for f in (:rad2deg, :deg2rad, :cbrt, :sinh, :cosh, :tanh, :sech, :csch, :coth, :asinh, :acosh, :atanh, :asech, :acsch, :acoth, ) - @eval @generated function (Base.$f)(::StaticFloat64{M}) where {M} Expr(:call, Expr(:curly, :StaticFloat64, $f(M))) end diff --git a/src/int.jl b/src/int.jl index fd9a118..9091c61 100644 --- a/src/int.jl +++ b/src/int.jl @@ -14,125 +14,81 @@ const One = StaticInt{1} StaticInt(N::Int) = StaticInt{N}() StaticInt(N::Integer) = StaticInt(convert(Int, N)) -StaticInt(::StaticInt{N}) where {N} = StaticInt{N}() +StaticInt(@nospecialize(N::StaticInt)) = N StaticInt(::Val{N}) where {N} = StaticInt{N}() -# Base.Val(::StaticInt{N}) where {N} = Val{N}() -Base.convert(::Type{T}, ::StaticInt{N}) where {T<:Number,N} = convert(T, N) +Base.convert(::Type{T}, @nospecialize(N::StaticInt)) where {T<:Number} = convert(T, Int(N)) Base.Bool(x::StaticInt{N}) where {N} = Bool(N) -Base.BigInt(x::StaticInt{N}) where {N} = BigInt(N) -Base.Integer(x::StaticInt{N}) where {N} = x -(::Type{T})(x::StaticInt{N}) where {T<:Integer,N} = T(N) -(::Type{T})(x::Int) where {T<:StaticInt} = StaticInt(x) + +Base.BigInt(@nospecialize(x::StaticInt)) = BigInt(Int(x)) +Base.Integer(@nospecialize(x::StaticInt)) = x +(::Type{T})(@nospecialize(x::StaticInt)) where {T<:Integer} = T(known(x)) +(::Type{T})(x::Int) where {T<:StaticInt} = StaticInt{x}() Base.convert(::Type{StaticInt{N}}, ::StaticInt{N}) where {N} = StaticInt{N}() -Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T<:Number} = promote_type(Int, T) -function Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T<:AbstractIrrational} - return promote_type(Int, T) -end -# Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T <: AbstractIrrational} = promote_rule(T, Int) +Base.promote_rule(@nospecialize(T1::Type{<:StaticInt}), ::Type{T2}) where {T2<:Number} = promote_type(Int, T2) +Base.promote_rule(@nospecialize(T1::Type{<:StaticInt}), ::Type{T2}) where {T2<:AbstractIrrational} = promote_type(Int, T2) for (S, T) in [(:Complex, :Real), (:Rational, :Integer), (:(Base.TwicePrecision), :Any)] - @eval function Base.promote_rule(::Type{$S{T}}, ::Type{<:StaticInt}) where {T<:$T} - return promote_type($S{T}, Int) + @eval function Base.promote_rule(::Type{$S{T}}, @nospecialize(SI::Type{<:StaticInt})) where {T<:$T} + promote_type($S{T}, Int) end end -function Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:StaticInt}) - return Union{Nothing,Missing,Int} -end -function Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T>:Union{Missing,Nothing}} - return promote_type(T, Int) -end -Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T>:Nothing} = promote_type(T, Int) -Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T>:Missing} = promote_type(T, Int) + +Base.promote_rule(::Type{Union{Nothing,Missing}}, @nospecialize(T::Type{<:StaticInt})) = Union{Nothing,Missing,Int} +Base.promote_rule(::Type{T1}, @nospecialize(T2::Type{<:StaticInt})) where {T1>:Union{Missing,Nothing}} = promote_type(T1, Int) +Base.promote_rule(::Type{T1}, @nospecialize(T2::Type{<:StaticInt})) where {T1>:Nothing} = promote_type(T1, Int) +Base.promote_rule(::Type{T1}, @nospecialize(T2::Type{<:StaticInt})) where {T1>:Missing} = promote_type(T1, Int) for T in [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any] # let S = :Any @eval begin - function Base.promote_rule(::Type{S}, ::Type{$T}) where {S<:StaticInt} - return promote_type(Int, $T) - end - function Base.promote_rule(::Type{$T}, ::Type{S}) where {S<:StaticInt} - return promote_type($T, Int) - end + Base.promote_rule(@nospecialize(S::Type{<:StaticInt}), ::Type{$T}) = promote_type(Int, $T) + Base.promote_rule(::Type{$T}, @nospecialize(S::Type{<:StaticInt})) = promote_type($T, Int) end end -Base.promote_rule(::Type{<:StaticInt}, ::Type{<:StaticInt}) = Int -Base.:(%)(::StaticInt{N}, ::Type{Integer}) where {N} = N +Base.promote_rule(@nospecialize(T1::Type{<:StaticInt}), @nospecialize(T2::Type{<:StaticInt})) = Int -Base.eltype(::Type{T}) where {T<:StaticInt} = Int +Base.:(%)(@nospecialize(n::StaticInt), ::Type{Integer}) = Int(n) + +Base.eltype(@nospecialize(T::Type{<:StaticInt})) = Int Base.iszero(::Zero) = true -Base.iszero(::StaticInt) = false +Base.iszero(@nospecialize(x::StaticInt)) = false Base.isone(::One) = true -Base.isone(::StaticInt) = false -Base.zero(::Type{T}) where {T<:StaticInt} = Zero() -Base.one(::Type{T}) where {T<:StaticInt} = One() +Base.isone(@nospecialize(x::StaticInt)) = false +Base.zero(@nospecialize(x::Type{<:StaticInt})) = Zero() +Base.one(@nospecialize(x::Type{<:StaticInt})) = One() for T in [:Real, :Rational, :Integer] - @eval begin - @inline Base.:(+)(i::$T, ::Zero) = i - @inline Base.:(+)(i::$T, ::StaticInt{M}) where {M} = i + M - @inline Base.:(+)(::Zero, i::$T) = i - @inline Base.:(+)(::StaticInt{M}, i::$T) where {M} = M + i - @inline Base.:(-)(i::$T, ::Zero) = i - @inline Base.:(-)(i::$T, ::StaticInt{M}) where {M} = i - M - @inline Base.:(*)(i::$T, ::Zero) = Zero() - @inline Base.:(*)(i::$T, ::One) = i - @inline Base.:(*)(i::$T, ::StaticInt{M}) where {M} = i * M - @inline Base.:(*)(::Zero, i::$T) = Zero() - @inline Base.:(*)(::One, i::$T) = i - @inline Base.:(*)(::StaticInt{M}, i::$T) where {M} = M * i + for f in [:(-), :(+), :(*)] + @eval begin + Base.$(f)(x::$T, @nospecialize(y::StaticInt)) = $(f)(x, Int(y)) + Base.$(f)(@nospecialize(x::StaticInt), y::$T) = $(f)(Int(x), y) + end end end -@inline Base.:(+)(::Zero, ::Zero) = Zero() -@inline Base.:(+)(::Zero, ::StaticInt{M}) where {M} = StaticInt{M}() -@inline Base.:(+)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}() - @inline Base.:(-)(::StaticInt{M}) where {M} = StaticInt{-M}() -@inline Base.:(-)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}() -@inline Base.:(*)(::Zero, ::Zero) = Zero() -@inline Base.:(*)(::One, ::Zero) = Zero() -@inline Base.:(*)(::Zero, ::One) = Zero() -@inline Base.:(*)(::One, ::One) = One() -@inline Base.:(*)(::StaticInt{M}, ::Zero) where {M} = Zero() -@inline Base.:(*)(::Zero, ::StaticInt{M}) where {M} = Zero() -@inline Base.:(*)(::StaticInt{M}, ::One) where {M} = StaticInt{M}() -@inline Base.:(*)(::One, ::StaticInt{M}) where {M} = StaticInt{M}() -for f in [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)] - @eval @generated function Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} - return Expr(:call, Expr(:curly, :StaticInt, $f(M, N))) - end +for f in [:(+), :(-), :(*), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)] + eval(:(Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = StaticInt{$f(M,N)}())) end for f in [:(<<), :(>>), :(>>>)] @eval begin - @inline Base.$f(::StaticInt{M}, x::UInt) where {M} = $f(M, x) - @inline Base.$f(x::Integer, ::StaticInt{M}) where {M} = $f(x, M) + Base.$f(@nospecialize(x::StaticInt), y::UInt) = $f(Int(x), y) + Base.$f(x::Integer, @nospecialize(y::StaticInt)) = $f(x, Int(y)) end end for f in [:(==), :(!=), :(<), :(≤), :(>), :(≥)] @eval begin - @inline Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = $f(M, N) - @inline Base.$f(::StaticInt{M}, x::Int) where {M} = $f(M, x) - @inline Base.$f(x::Int, ::StaticInt{M}) where {M} = $f(x, M) - end -end - -@inline function maybe_static(f::F, g::G, x) where {F,G} - L = f(x) - if L === nothing - return g(x) - else - return static(L) + Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = $f(M, N) + Base.$f(@nospecialize(x::StaticInt), y::Int) = $f(Int(x), y) + Base.$f(x::Int, @nospecialize(y::StaticInt)) = $f(x, Int(y)) end end -@inline Base.widen(::StaticInt{N}) where {N} = widen(N) - -Base.UnitRange{T}(start::StaticInt, stop) where {T<:Real} = UnitRange{T}(T(start), T(stop)) -Base.UnitRange{T}(start, stop::StaticInt) where {T<:Real} = UnitRange{T}(T(start), T(stop)) -function Base.UnitRange{T}(start::StaticInt, stop::StaticInt) where {T<:Real} - return UnitRange{T}(T(start), T(stop)) -end +Base.widen(@nospecialize(x::StaticInt)) = widen(Int(x)) -Base.UnitRange(start::StaticInt, stop) = UnitRange(Int(start), stop) -Base.UnitRange(start, stop::StaticInt) = UnitRange(start, Int(stop)) -Base.UnitRange(start::StaticInt, stop::StaticInt) = UnitRange(Int(start), Int(stop)) +Base.UnitRange{T}(@nospecialize(start::StaticInt), stop) where {T<:Real} = UnitRange{T}(T(start), T(stop)) +Base.UnitRange{T}(start, @nospecialize(stop::StaticInt)) where {T<:Real} = UnitRange{T}(T(start), T(stop)) +Base.UnitRange{T}(@nospecialize(start::StaticInt), @nospecialize(stop::StaticInt)) where {T<:Real} = UnitRange{T}(T(start), T(stop)) +Base.UnitRange(@nospecialize(start::StaticInt), stop) = UnitRange(Int(start), stop) +Base.UnitRange(start, @nospecialize(stop::StaticInt)) = UnitRange(start, Int(stop)) +Base.UnitRange(@nospecialize(start::StaticInt), @nospecialize(stop::StaticInt)) = UnitRange(Int(start), Int(stop)) diff --git a/src/ndindex.jl b/src/ndindex.jl index 59292b7..ade1dee 100644 --- a/src/ndindex.jl +++ b/src/ndindex.jl @@ -55,13 +55,17 @@ _flatten(i::Base.AbstractCartesianIndex) = _flatten(Tuple(i)...) @inline function _flatten(i::Base.AbstractCartesianIndex, I...) return (_flatten(Tuple(i)...)..., _flatten(I...)...) end -Base.Tuple(index::NDIndex) = index.index +Base.Tuple(@nospecialize(x::NDIndex)) = getfield(x, :index) -Base.show(io::IO, i::NDIndex) = (print(io, "NDIndex"); show(io, Tuple(i))) +Base.show(io::IO, @nospecialize(x::NDIndex)) = show(io, MIME"text/plain"(), x) +function Base.show(io::IO, m::MIME"text/plain", @nospecialize(x::NDIndex)) + print(io, "NDIndex") + show(io, m, Tuple(x)) +end # length -Base.length(::NDIndex{N}) where {N} = N -Base.length(::Type{NDIndex{N,I}}) where {N,I} = N +Base.length(@nospecialize(x::NDIndex))::Int = length(Tuple(x)) +Base.length(@nospecialize(T::Type{<:NDIndex}))::Int = @inbounds(T.parameters[1]) # indexing @propagate_inbounds function Base.getindex(x::NDIndex{N,T}, i::Int)::Int where {N,T} @@ -77,13 +81,13 @@ end Base.setindex(x::NDIndex, i, j) = NDIndex(Base.setindex(Tuple(x), i, j)) # equality -Base.:(==)(x::NDIndex{N}, y::NDIndex{N}) where N = Tuple(x) == Tuple(y) +Base.:(==)(@nospecialize(x::NDIndex), @nospecialize(y::NDIndex)) = ==(Tuple(x), Tuple(y)) # zeros and ones -Base.zero(::NDIndex{N}) where {N} = zero(NDIndex{N}) -Base.zero(::Type{NDIndex{N}}) where {N} = NDIndex(ntuple(_ -> static(0), Val(N))) -Base.oneunit(::NDIndex{N}) where {N} = oneunit(NDIndex{N}) -Base.oneunit(::Type{NDIndex{N}}) where {N} = NDIndex(ntuple(_ -> static(1), Val(N))) +Base.zero(@nospecialize(x::NDIndex)) = zero(typeof(x)) +Base.zero(@nospecialize(T::Type{<:NDIndex})) = NDIndex(ntuple(_ -> static(0), Val(length(T)))) +Base.oneunit(@nospecialize(x::NDIndex)) = oneunit(typeof(x)) +Base.oneunit(@nospecialize(T::Type{<:NDIndex})) = NDIndex(ntuple(_ -> static(1), Val(length(T)))) @inline function Base.IteratorsMD.split(i::NDIndex, V::Val) i, j = Base.IteratorsMD.split(Tuple(i), V) @@ -91,37 +95,37 @@ Base.oneunit(::Type{NDIndex{N}}) where {N} = NDIndex(ntuple(_ -> static(1), Val( end # arithmetic, min/max -@inline Base.:(-)(i::NDIndex{N}) where {N} = NDIndex{N}(map(-, Tuple(i))) -@inline function Base.:(+)(i1::NDIndex{N}, i2::NDIndex{N}) where {N} - return NDIndex(map(+, Tuple(i1), Tuple(i2))) +@inline Base.:(-)(@nospecialize(i::NDIndex)) = NDIndex(map(-, Tuple(i))) +@inline function Base.:(+)(@nospecialize(i1::NDIndex), @nospecialize(i2::NDIndex)) + NDIndex(map(+, Tuple(i1), Tuple(i2))) end -@inline function Base.:(-)(i1::NDIndex{N}, i2::NDIndex{N}) where {N} - return NDIndex(map(-, Tuple(i1), Tuple(i2))) +@inline function Base.:(-)(@nospecialize(i1::NDIndex), @nospecialize(i2::NDIndex)) + NDIndex(map(-, Tuple(i1), Tuple(i2))) end -@inline function Base.min(i1::NDIndex{N}, i2::NDIndex{N}) where {N} - return NDIndex(map(min, Tuple(i1), Tuple(i2))) +@inline function Base.min(@nospecialize(i1::NDIndex), @nospecialize(i2::NDIndex)) + NDIndex(map(min, Tuple(i1), Tuple(i2))) end -@inline function Base.max(i1::NDIndex{N}, i2::NDIndex{N}) where {N} - return NDIndex(map(max, Tuple(i1), Tuple(i2))) +@inline function Base.max(@nospecialize(i1::NDIndex), @nospecialize(i2::NDIndex)) + NDIndex(map(max, Tuple(i1), Tuple(i2))) end -@inline Base.:(*)(a::Integer, i::NDIndex{N}) where {N} = NDIndex(map(x->a*x, Tuple(i))) -@inline Base.:(*)(i::NDIndex, a::Integer) = *(a, i) +@inline Base.:(*)(a::Integer, @nospecialize(i::NDIndex)) = NDIndex(map(x->a*x, Tuple(i))) +@inline Base.:(*)(@nospecialize(i::NDIndex), a::Integer) = *(a, i) -Base.CartesianIndex(x::NDIndex) = dynamic(x) +Base.CartesianIndex(@nospecialize(x::NDIndex)) = dynamic(x) # comparison -@inline function Base.isless(x::NDIndex{N}, y::NDIndex{N}) where {N} - return Bool(_isless(static(0), Tuple(x), Tuple(y))) +@inline function Base.isless(@nospecialize(x::NDIndex), @nospecialize(y::NDIndex)) + Bool(_isless(static(0), Tuple(x), Tuple(y))) end -lt(x::NDIndex{N}, y::NDIndex{N}) where {N} = _isless(static(0), Tuple(x), Tuple(y)) +lt(@nospecialize(x::NDIndex), @nospecialize(y::NDIndex)) = _isless(static(0), Tuple(x), Tuple(y)) _final_isless(c::Int) = c === 1 _final_isless(::StaticInt{N}) where {N} = static(false) _final_isless(::StaticInt{1}) = static(true) _isless(c::C, x::Tuple{}, y::Tuple{}) where {C} = _final_isless(c) function _isless(c::C, x::Tuple, y::Tuple) where {C} - return _isless(icmp(c, x, y), Base.front(x), Base.front(y)) + _isless(icmp(c, x, y), Base.front(x), Base.front(y)) end icmp(::StaticInt{0}, x::Tuple, y::Tuple) = icmp(last(x), last(y)) icmp(::StaticInt{N}, x::Tuple, y::Tuple) where {N} = static(N) @@ -142,7 +146,7 @@ __icmp(x::Bool) = ifelse(x, 0, -1) # In simple cases, we know that we don't need to use axes(A). Optimize those # until Julia gets smart enough to elide the call on its own: @inline function Base.to_indices(A, inds, I::Tuple{NDIndex, Vararg{Any}}) - return to_indices(A, inds, (Tuple(I[1])..., Base.tail(I)...)) + to_indices(A, inds, (Tuple(I[1])..., Base.tail(I)...)) end # But for arrays of CartesianIndex, we just skip the appropriate number of inds @inline function Base.to_indices(A, inds, I::Tuple{AbstractArray{NDIndex{N,J}}, Vararg{Any}}) where {N,J} diff --git a/src/operators.jl b/src/operators.jl index 06c04d6..41a7c00 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -5,7 +5,7 @@ Equivalent to `!=` but if `x` and `y` are both static returns a `StaticBool. """ eq(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x == y) -eq(x::X) where {X} = Fix2(eq, x) +eq(x) = Fix2(eq, x) """ ne(x, y) @@ -13,7 +13,7 @@ eq(x::X) where {X} = Fix2(eq, x) Equivalent to `!=` but if `x` and `y` are both static returns a `StaticBool. """ ne(x::X, y::Y) where {X,Y} = !eq(x, y) -ne(x::X) where {X} = Fix2(ne, x) +ne(x) = Fix2(ne, x) """ gt(x, y) @@ -21,7 +21,7 @@ ne(x::X) where {X} = Fix2(ne, x) Equivalent to `>` but if `x` and `y` are both static returns a `StaticBool. """ gt(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x > y) -gt(x::X) where {X} = Fix2(gt, x) +gt(x) = Fix2(gt, x) """ ge(x, y) @@ -29,7 +29,7 @@ gt(x::X) where {X} = Fix2(gt, x) Equivalent to `>=` but if `x` and `y` are both static returns a `StaticBool. """ ge(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x >= y) -ge(x::X) where {X} = Fix2(ge, x) +ge(x) = Fix2(ge, x) """ le(x, y) @@ -37,7 +37,7 @@ ge(x::X) where {X} = Fix2(ge, x) Equivalent to `<=` but if `x` and `y` are both static returns a `StaticBool. """ le(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x <= y) -le(x::X) where {X} = Fix2(le, x) +le(x) = Fix2(le, x) """ lt(x, y) @@ -45,7 +45,7 @@ le(x::X) where {X} = Fix2(le, x) Equivalent to `<` but if `x` and `y` are both static returns a `StaticBool. """ lt(x::X, y::Y) where {X,Y} = ifelse(is_static(X) & is_static(Y), static, identity)(x < y) -lt(x::X) where {X} = Fix2(lt, x) +lt(x) = Fix2(lt, x) """ mul(x) -> Base.Fix2(*, x) diff --git a/src/symbol.jl b/src/symbol.jl index 3a83e67..6babe37 100644 --- a/src/symbol.jl +++ b/src/symbol.jl @@ -13,13 +13,13 @@ end StaticSymbol(x, y) = StaticSymbol(Symbol(x, y)) StaticSymbol(x::StaticSymbol, y::StaticSymbol) = _cat_syms(x, y) @generated function _cat_syms(::StaticSymbol{x}, ::StaticSymbol{y}) where {x,y} - return :(StaticSymbol{$(QuoteNode(Symbol(x, y)))}()) + :(StaticSymbol{$(QuoteNode(Symbol(x, y)))}()) end StaticSymbol(x, y, z...) = StaticSymbol(StaticSymbol(x, y), z...) -Base.Symbol(::StaticSymbol{s}) where {s} = s::Symbol +Base.Symbol(@nospecialize(s::StaticSymbol)) = known(s) -Base.:(==)(::StaticSymbol{X}, ::StaticSymbol{Y}) where {X,Y} = X === Y -Base.:(==)(@nospecialize(x::StaticSymbol), y::Symbol) = dynamic(x) === y -Base.:(==)(x::Symbol, @nospecialize(y::StaticSymbol)) = x === dynamic(y) +Base.:(==)(@nospecialize(x::StaticSymbol), @nospecialize(y::StaticSymbol)) = x === y +Base.:(==)(@nospecialize(x::StaticSymbol), y::Symbol) = known(typeof(x)) === y +Base.:(==)(x::Symbol, @nospecialize(y::StaticSymbol)) = x === known(typeof(y)) diff --git a/src/tuples.jl b/src/tuples.jl index bc87f85..ba85b72 100644 --- a/src/tuples.jl +++ b/src/tuples.jl @@ -10,19 +10,17 @@ Functionally equivalent to `fieldtype(T, f)` except `f` may be a static type. @inline nstatic(::Val{N}) where {N} = ntuple(StaticInt, Val(N)) -invariant_permutation(::Any, ::Any) = False() -function invariant_permutation(x::T, y::T) where {N,T<:Tuple{Vararg{StaticInt,N}}} - if x === nstatic(Val(N)) +@inline function invariant_permutation(@nospecialize(x::Tuple), @nospecialize(y::Tuple)) + if y === x === nstatic(Val(nfields(x))) return True() else return False() end end -permute(x::Tuple, perm::Val) = permute(x, static(perm)) -permute(x::Tuple{Vararg{Any}}, perm::Tuple{Vararg{StaticInt}}) = eachop(getindex, perm, x) -function permute(x::Tuple{Vararg{Any,K}}, perm::Tuple{Vararg{StaticInt,K}}) where {K} - if invariant_permutation(perm, perm) === False() +permute(@nospecialize(x::Tuple), @nospecialize(perm::Val)) = permute(x, static(perm)) +@inline function permute(@nospecialize(x::Tuple), @nospecialize(perm::Tuple)) + if invariant_permutation(nstatic(Val(nfields(x))), perm) === False() return eachop(getindex, perm, x) else return x @@ -35,7 +33,7 @@ end Produces a tuple of `(op(args..., iterator[1]), op(args..., iterator[2]),...)`. """ @inline function eachop(op::F, itr::Tuple{T,Vararg{Any}}, args::Vararg{Any}) where {F,T} - return (op(args..., first(itr)), eachop(op, Base.tail(itr), args...)...) + (op(args..., first(itr)), eachop(op, Base.tail(itr), args...)...) end eachop(::F, ::Tuple{}, args::Vararg{Any}) where {F} = () diff --git a/test/runtests.jl b/test/runtests.jl index 97d3af7..115deb2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -359,6 +359,8 @@ using Test @test deleteat!(Union{}[], Union{}[]) == Union{}[] end + f = static(float(2)) + repr(f) @test repr(static(float(1))) == "static($(float(1)))" @test repr(static(1)) == "static(1)" @test repr(static(:x)) == "static(:x)"