Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Static"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
authors = ["chriselrod", "ChrisRackauckas", "Tokazama"]
version = "0.6.2"
version = "0.6.3"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Expand Down
76 changes: 34 additions & 42 deletions src/Static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,17 @@ Returns the known value corresponding to a static type `T`. If `T` is not a stat

See also: [`static`](@ref), [`is_static`](@ref)
"""
known
@constprop :aggressive known(x) = known(typeof(x))
known(::Type{T}) where {T} = nothing
known(::Type{StaticInt{N}}) where {N} = N::Int
known(::Type{StaticFloat64{N}}) where {N} = N::Float64
known(::Type{StaticSymbol{S}}) where {S} = S::Symbol
known(@nospecialize(T::Type{<:StaticInt}))::Int = T.parameters[1]
known(@nospecialize(T::Type{<:StaticFloat64}))::Float64 = T.parameters[1]
known(@nospecialize(T::Type{<:StaticSymbol}))::Symbol = T.parameters[1]
known(::Type{Val{V}}) where {V} = V
known(::Type{True}) = true
known(::Type{False}) = false
known(::Type{NDIndex{N,I}}) where {N,I} = known(I)
known(@nospecialize(T::Type{<:NDIndex})) = known(T.parameters[2])
_get_known(::Type{T}, dim::StaticInt{D}) where {T,D} = known(field_type(T, dim))
function known(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return eachop(_get_known, nstatic(Val(N)), T)
end
known(@nospecialize(T::Type{<:Tuple})) = eachop(_get_known, nstatic(Val(fieldcount(T))), T)
known(T::DataType) = nothing
known(@nospecialize(x)) = known(typeof(x))

"""
static(x)
Expand All @@ -67,8 +64,7 @@ static(:x)

```
"""
static
@constprop :aggressive static(x::X) where {X} = ifelse(is_static(X), identity, _no_static_type)(x)
static(@nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False})) = x
@constprop :aggressive static(x::Int) = StaticInt(x)
@constprop :aggressive static(x::Union{Int8,UInt8,Int16,UInt16}) = StaticInt(x % Int)
@static if sizeof(Int) == 8
Expand All @@ -81,12 +77,9 @@ end
@constprop :aggressive static(x::Bool) = StaticBool(x)
@constprop :aggressive static(x::Symbol) = StaticSymbol(x)
@constprop :aggressive static(x::Tuple{Vararg{Any}}) = map(static, x)
@generated static(::Val{V}) where {V} = static(V)
function _no_static_type(@nospecialize(x))
error("There is no static alternative for type $(typeof(x)).")
end
static(x::CartesianIndex) = NDIndex(static(Tuple(x)))

static(::Val{V}) where {V} = static(V)
static(@nospecialize(x::CartesianIndex)) = NDIndex(static(Tuple(x)))
static(x) = error("There is no static alternative for type $(typeof(x)).")

"""
is_static(::Type{T}) -> StaticBool
Expand All @@ -96,43 +89,42 @@ Returns `True` if `T` is a static type.
See also: [`static`](@ref), [`known`](@ref)
"""
is_static(@nospecialize(x)) = is_static(typeof(x))
is_static(@nospecialize(x::Type{<:StaticInt})) = True()
is_static(@nospecialize(x::Type{<:StaticBool})) = True()
is_static(@nospecialize(x::Type{<:StaticSymbol})) = True()
is_static(@nospecialize(x::Type{<:Union{StaticInt,StaticSymbol,StaticFloat64,True,False}})) = True()
is_static(@nospecialize(x::Type{<:Val})) = True()
is_static(@nospecialize(x::Type{<:StaticFloat64})) = True()
is_static(x::Type{T}) where {T} = False()

@constprop :aggressive _tuple_static(::Type{T}, i) where {T} = is_static(field_type(T, i))
function is_static(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
if all(eachop(_tuple_static, nstatic(Val(N)), T))
@inline function is_static(@nospecialize(T::Type{<:Tuple}))
if all(eachop(_tuple_static, nstatic(Val(fieldcount(T))), T))
return True()
else
return False()
end
end
is_static(T::DataType) = False()

"""
dynamic(x)

Returns the "dynamic" or non-static form of `x`.
"""
dynamic(x::X) where {X} = _dynamic(is_static(X), x)
_dynamic(::True, x::X) where {X} = known(X)
_dynamic(::False, x::X) where {X} = x
@constprop :aggressive dynamic(x::Tuple) = map(dynamic, x)
dynamic(x::NDIndex) = CartesianIndex(dynamic(Tuple(x)))


# Base.string usually shouldn't be given unique methods but primitive types are processed in
# unique ways and we can avoid some needless specialization by defining these
for T in [StaticInt, StaticFloat64, StaticBool, StaticSymbol]
@eval begin
function Base.string(@nospecialize(x::$(T)); kwargs...)
string("static(" * repr(known(typeof(x))) * ")"; kwargs...)
end
Base.show(io::IO, @nospecialize(x::$(T))) = show(io, MIME"text/plain"(), x)
Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::$(T))) = print(io, string(x))
@inline dynamic(@nospecialize(x)) = ifelse(is_static(typeof(x)), known, identity)(x)
dynamic(@nospecialize(x::Tuple)) = map(dynamic, x)
dynamic(@nospecialize(x::NDIndex)) = CartesianIndex(dynamic(Tuple(x)))

function Base.string(@nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False}); kwargs...)
string("static(" * repr(known(typeof(x))) * ")"; kwargs...)
end
Base.show(io::IO, @nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False})) = show(io, MIME"text/plain"(), x)
Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::Union{StaticInt,StaticSymbol,StaticFloat64,True,False})) = print(io, string(x))

# This method assumes that `f` uetrieves compile time information and `g` is the fall back
# for the corresponding dynamic method. If the `f(x)` doesn't return `nothing` that means
# the value is known and compile time and returns `static(f(x))`.
@inline function maybe_static(f::F, g::G, x) where {F,G}
L = f(x)
if L === nothing
return g(x)
else
return static(L)
end
end

Expand Down
89 changes: 29 additions & 60 deletions src/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,25 @@ end

(::Type{T})(x::Integer) where {T<:StaticFloat64} = StaticFloat64(x)
(::Type{T})(x::AbstractFloat) where {T<:StaticFloat64} = StaticFloat64(x)
@generated function Base.AbstractFloat(::StaticInt{N}) where {N}
Expr(:call, Expr(:curly, :StaticFloat64, Float64(N)))
end
StaticFloat64(x::StaticInt{N}) where {N} = float(x)
Base.AbstractFloat(::StaticInt{N}) where {N} = StaticFloat64{Float64(N)}()
StaticFloat64(@nospecialize(x::StaticInt)) = float(x)

const FloatOne = StaticFloat64{one(Float64)}
const FloatZero = StaticFloat64{zero(Float64)}

Base.convert(::Type{T}, ::StaticFloat64{N}) where {N,T<:AbstractFloat} = T(N)
Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{T}) where {N,T} = promote_type(T, Float64)
Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{Float64}) where {N} = Float64
Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{Float32}) where {N} = Float32
Base.promote_rule(::Type{StaticFloat64{N}}, ::Type{Float16}) where {N} = Float16
Base.convert(::Type{T}, @nospecialize(x::StaticFloat64)) where {T<:AbstractFloat} = T(known(x))
Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{T2}) where {T2} = promote_type(T2, Float64)
Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{Float64}) = Float64
Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{Float32}) = Float32
Base.promote_rule(@nospecialize(T1::Type{<:StaticFloat64}), ::Type{Float16}) = Float16

Base.eltype(::Type{T}) where {T<:StaticFloat64} = Float64
Base.eltype(@nospecialize(T::Type{<:StaticFloat64})) = Float64
Base.iszero(::FloatZero) = true
Base.iszero(::StaticFloat64) = false
Base.iszero(@nospecialize(x::StaticFloat64)) = false
Base.isone(::FloatOne) = true
Base.isone(::StaticFloat64) = false
Base.zero(::Type{T}) where {T<:StaticFloat64} = FloatZero()
Base.one(::Type{T}) where {T<:StaticFloat64} = FloatOne()
Base.isone(@nospecialize(x::StaticFloat64)) = false
Base.zero(@nospecialize(x::Type{<:StaticFloat64})) = FloatZero
Base.one(@nospecialize(x::Type{<:StaticFloat64})) = FloatOne()

function fsub(::StaticFloat64{X}, ::StaticFloat64{Y}) where {X,Y}
return StaticFloat64{Base.sub_float(X, Y)::Float64}()
Expand All @@ -51,46 +49,24 @@ function fmul(::StaticFloat64{X}, ::StaticFloat64{Y}) where {X,Y}
return StaticFloat64{Base.mul_float(X, Y)::Float64}()
end

Base.:+(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fadd(x, y)
Base.:+(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = +(x, float(y))
Base.:+(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = +(float(x), y)
Base.:+(x::FloatZero, ::FloatZero) = x
Base.:+(x::StaticFloat64{X}, ::FloatZero) where {X} = x
Base.:+(::FloatZero, y::StaticFloat64{Y}) where {Y} = y
Base.:+(x::StaticFloat64{X}, ::Zero) where {X} = x
Base.:+(::Zero, y::StaticFloat64{Y}) where {Y} = y
Base.:+(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fadd(x, y)
@inline Base.:+(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = +(x, float(y))
@inline Base.:+(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = +(float(x), y)

Base.:-(::StaticFloat64{X}) where {X} = StaticFloat64{-X}()
Base.:-(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fsub(x, y)
Base.:-(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = -(x, float(y))
Base.:-(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = -(float(x), y)
Base.:-(x::FloatZero, ::FloatZero) = x
Base.:-(x::StaticFloat64{X}, ::FloatZero) where {X} = x
Base.:-(x::StaticFloat64{X}, ::Zero) where {X} = x
Base.:-(::FloatZero, y::StaticFloat64{Y}) where {Y} = -y
Base.:-(::Zero, y::StaticFloat64{Y}) where {Y} = -y

Base.:*(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fmul(x, y)
Base.:*(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = *(x, float(y))
Base.:*(::StaticFloat64{X}, ::Zero) where {X} = FloatZero()
Base.:*(::Zero, ::StaticFloat64{Y}, ) where {Y} = FloatZero()
Base.:*(x::StaticFloat64{X}, ::One) where {X} = x
Base.:*(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = *(float(x), y)
Base.:*(::One, y::StaticFloat64{Y}) where {Y} = y
Base.:*(x::FloatZero, ::FloatZero) = x
Base.:*(::StaticFloat64{X}, y::FloatZero) where {X} = y
Base.:*(x::FloatZero, ::StaticFloat64{Y}) where {Y} = x
Base.:*(x::FloatZero, ::FloatOne) = x
Base.:*(x::FloatOne, ::FloatOne) = x
Base.:*(x::StaticFloat64{X}, ::FloatOne) where {X} = x
Base.:*(::FloatOne, y::StaticFloat64{Y}) where {Y} = y
Base.:*(::FloatOne, y::FloatZero) = y

Base.:/(x::StaticFloat64{X}, y::StaticFloat64{Y}) where {X,Y} = fdiv(x, y)
Base.:/(x::StaticFloat64{X}, y::StaticInt{Y}) where {X,Y} = /(x, float(y))
Base.:/(x::StaticInt{X}, y::StaticFloat64{Y}) where {X,Y} = /(float(x), y)

@generated Base.sqrt(::StaticInt{M}) where {M} = Expr(:call, Expr(:curly, :StaticFloat64, sqrt(M)))
Base.:-(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fsub(x, y)
@inline Base.:-(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = -(x, float(y))
@inline Base.:-(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = -(float(x), y)

Base.:*(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fmul(x, y)
@inline Base.:*(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = *(x, float(y))
@inline Base.:*(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = *(float(x), y)

Base.:/(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticFloat64)) = fdiv(x, y)
Base.:/(@nospecialize(x::StaticFloat64), @nospecialize(y::StaticInt)) = /(x, float(y))
Base.:/(@nospecialize(x::StaticInt), @nospecialize(y::StaticFloat64)) = /(float(x), y)

Base.sqrt(@nospecialize(x::StaticInt)) = sqrt(float(x))
@generated Base.sqrt(::StaticFloat64{M}) where {M} = Expr(:call, Expr(:curly, :StaticFloat64, sqrt(M)))

@generated Base.round(::StaticFloat64{M}) where {M} = Expr(:call, Expr(:curly, :StaticFloat64, round(M)))
Expand All @@ -103,13 +79,7 @@ Base.:(^)(::StaticFloat64{x}, y::Float64) where {x} = exp2(log2(x) * y)

Base.inv(x::StaticFloat64{N}) where {N} = fdiv(one(x), x)

# @generated function Base.exponent(::StaticFloat64{M}) where {M}
# Expr(:call, Expr(:curly, :StaticInt, exponent(M)))
# end

@inline function Base.exponent(::StaticFloat64{M}) where {M}
static(exponent(M))
end
@inline Base.exponent(::StaticFloat64{M}) where {M} = static(exponent(M))

for f in (:rad2deg, :deg2rad, :cbrt,
:mod2pi, :rem2pi, :sinpi, :cospi,
Expand All @@ -122,7 +92,6 @@ for f in (:rad2deg, :deg2rad, :cbrt,
:sinh, :cosh, :tanh, :sech, :csch, :coth,
:asinh, :acosh, :atanh, :asech, :acsch, :acoth,
)

@eval @generated function (Base.$f)(::StaticFloat64{M}) where {M}
Expr(:call, Expr(:curly, :StaticFloat64, $f(M)))
end
Expand Down
Loading