Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine and cleanup handling of range arithmetic #43360

Merged
merged 2 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 64 additions & 57 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)

(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop-start, 1), stop)
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)

# promote start and stop, leaving step alone
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
Expand Down Expand Up @@ -164,7 +164,7 @@ _range(start::Any , step::Any , stop::Any , len::Any ) = range_error
range_length(len::Integer) = OneTo(len)

# Stop as the only argument
range_stop(stop) = range_start_stop(oneunit(stop), stop)
range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
range_stop(stop::Integer) = range_length(stop)

# Stop and length as the only argument
Expand Down Expand Up @@ -200,10 +200,17 @@ function range_start_step_length(a::T, step, len::Integer) where {T}
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
end

_rangestyle(::Ordered, ::ArithmeticWraps, a::T, step::S, len::Integer) where {T,S} =
StepRange{typeof(a+zero(step)),S}(a, step, a+step*(len-1))
_rangestyle(::Any, ::Any, a::T, step::S, len::Integer) where {T,S} =
StepRangeLen{typeof(a+zero(step)),T,S}(a, step, len)
function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
start = a + zero(step)
stop = a + step * (len - 1)
T = typeof(start)
return StepRange{T,typeof(step)}(start, step, convert(T, stop))
end
function _rangestyle(::Any, ::Any, a, step, len::Integer)
start = a + zero(step)
T = typeof(a)
return StepRangeLen{typeof(start),T,typeof(step)}(a, step, len)
end

range_start_step_stop(start, step, stop) = start:step:stop

Expand Down Expand Up @@ -306,19 +313,19 @@ struct StepRange{T,S} <: OrdinalRange{T,S}
stop::T

function StepRange{T,S}(start, step, stop) where {T,S}
sta = convert(T, start)
ste = convert(S, step)
sto = convert(T, stop)
new(sta, ste, steprange_last(sta,ste,sto))
start = convert(T, start)
step = convert(S, step)
stop = convert(T, stop)
return new(start, step, steprange_last(start, step, stop))
end
end

# to make StepRange constructor inlineable, so optimizer can see `step` value
function steprange_last(start::T, step, stop) where T
if isa(start,AbstractFloat) || isa(step,AbstractFloat)
function steprange_last(start, step, stop)::typeof(stop)
if isa(start, AbstractFloat) || isa(step, AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
if isa(start,Integer) && !isinteger(start + step)
if isa(start, Integer) && !isinteger(start + step)
throw(ArgumentError("StepRange{<:Integer} cannot have non-integer step"))
end
z = zero(step)
Expand All @@ -335,30 +342,28 @@ function steprange_last(start::T, step, stop) where T
absdiff, absstep = stop > start ? (stop - start, step) : (start - stop, -step)

# Compute remainder as a nonnegative number:
if T <: Signed && absdiff < zero(absdiff)
# handle signed overflow with unsigned rem
remain = convert(T, unsigned(absdiff) % absstep)
if absdiff isa Signed && absdiff < zero(absdiff)
# unlikely, but handle the signed overflow case with unsigned rem
remain = convert(typeof(absdiff), unsigned(absdiff) % absstep)
else
remain = absdiff % absstep
remain = convert(typeof(absdiff), absdiff % absstep)
end
# Move `stop` closer to `start` if there is a remainder:
last = stop > start ? stop - remain : stop + remain
end
end
last
return last
end

function steprange_last_empty(start::Integer, step, stop)
# empty range has a special representation where stop = start-1
# this is needed to avoid the wrap-around that can happen computing
# start - step, which leads to a range that looks very large instead
# of empty.
function steprange_last_empty(start::Integer, step, stop)::typeof(stop)
# empty range has a special representation where stop = start-1,
# which simplifies arithmetic for Signed numbers
if step > zero(step)
last = start - oneunit(stop-start)
last = start - oneunit(step)
else
last = start + oneunit(stop-start)
last = start + oneunit(step)
end
last
return last
end
# For types where x+oneunit(x) may not be well-defined use the user-given value for stop
steprange_last_empty(start, step, stop) = stop
Expand Down Expand Up @@ -388,18 +393,21 @@ UnitRange{Int64}
struct UnitRange{T<:Real} <: AbstractUnitRange{T}
start::T
stop::T
UnitRange{T}(start, stop) where {T<:Real} = new(start, unitrange_last(start,stop))
UnitRange{T}(start::T, stop::T) where {T<:Real} = new(start, unitrange_last(start, stop))
end
UnitRange{T}(start, stop) where {T<:Real} = UnitRange{T}(convert(T, start), convert(T, stop))
UnitRange(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)
UnitRange(start, stop) = UnitRange(promote(start, stop)...)

unitrange_last(::Bool, stop::Bool) = stop
unitrange_last(start::T, stop::T) where {T<:Integer} =
stop >= start ? stop : convert(T,start-oneunit(start-stop))
unitrange_last(start::T, stop::T) where {T} =
stop >= start ? convert(T,start+floor(stop-start)) :
convert(T,start-oneunit(stop-start))
# if stop and start are integral, we know that their difference is a multiple of 1
unitrange_last(start::Integer, stop::Integer) =
stop >= start ? stop : convert(typeof(stop), start - oneunit(start - stop))
# otherwise, use `floor` as a more efficient way to compute modulus with step=1
unitrange_last(start, stop) =
stop >= start ? convert(typeof(stop), start + floor(stop - start)) :
convert(typeof(stop), start - oneunit(start - stop))

unitrange(x) = UnitRange(x)
unitrange(x::AbstractUnitRange) = UnitRange(x) # convenience conversion for promoting the range type

if isdefined(Main, :Base)
# Constant-fold-able indexing into tuples to functionally expose Base.tail and Base.front
Expand Down Expand Up @@ -556,7 +564,7 @@ function LinRange{T}(start, stop, len::Integer) where T
end

function LinRange(start, stop, len::Integer)
T = typeof((stop-start)/len)
T = typeof((zero(stop) - zero(start)) / oneunit(len))
LinRange{T}(start, stop, len)
end

Expand Down Expand Up @@ -642,7 +650,7 @@ length(r::AbstractRange) = error("length implementation missing") # catch mistak
size(r::AbstractRange) = (length(r),)

isempty(r::StepRange) =
# steprange_last_empty(r.start, r.step, r.stop) == r.stop
# steprange_last(r.start, r.step, r.stop) == r.stop
(r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
isempty(r::AbstractUnitRange) = first(r) > last(r)
isempty(r::StepRangeLen) = length(r) == 0
Expand Down Expand Up @@ -689,9 +697,8 @@ firstindex(::LinRange) = 1
# defined between the relevant types
function checked_length(r::OrdinalRange{T}) where T
s = step(r)
# s != 0, by construction, but avoids the division error later
start = first(r)
if s == zero(s) || isempty(r)
if isempty(r)
return Integer(div(start - start, oneunit(s)))
end
stop = last(r)
Expand All @@ -716,9 +723,8 @@ end

function length(r::OrdinalRange{T}) where T
s = step(r)
# s != 0, by construction, but avoids the division error later
start = first(r)
if s == zero(s) || isempty(r)
if isempty(r)
return Integer(div(start - start, oneunit(s)))
end
stop = last(r)
Expand Down Expand Up @@ -756,7 +762,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
# (near typemax) for types with known `unsigned` functions
function length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
isempty(r) && return zero(T)
diff = last(r) - first(r)
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
Expand All @@ -773,7 +778,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
end
function checked_length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
isempty(r) && return zero(T)
stop, start = last(r), first(r)
# n.b. !(s isa T)
Expand All @@ -800,7 +804,6 @@ let smallints = (Int === Int64 ?
# n.b. !(step isa T)
function length(r::OrdinalRange{<:smallints})
s = step(r)
s == zero(s) && return 0 # unreachable, by construction, but avoids the error case here later
isempty(r) && return 0
return div(Int(last(r)) - Int(first(r)), s) + 1
end
Expand Down Expand Up @@ -962,29 +965,30 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
return range(first(s) ? first(r) : last(r), length = last(s))
else
f = first(r)
st = oftype(f, f + first(s)-firstindex(r))
return range(st, length=length(s))
start = oftype(f, f + first(s)-firstindex(r))
return range(start, length=length(s))
end
end

function getindex(r::OneTo{T}, s::OneTo) where T
@inline
@boundscheck checkbounds(r, s)
OneTo(T(s.stop))
return OneTo(T(s.stop))
end

function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@inline
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
else
st = oftype(first(r), first(r) + s.start-firstindex(r))
return range(st, step=step(s), length=length(s))
f = first(r)
start = oftype(f, f + s.start-firstindex(r))
return range(start, step=step(s), length=length(s))
end
end

Expand All @@ -994,19 +998,22 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}

if T === Bool
if length(s) == 0
return range(first(r), step=step(r), length=0)
start, len = first(r), 0
elseif length(s) == 1
if first(s)
return range(first(r), step=step(r), length=1)
start, len = first(r), 1
else
return range(first(r), step=step(r), length=0)
start, len = first(r), 0
end
else # length(s) == 2
return range(last(r), step=step(r), length=1)
start, len = last(r), 1
end
return range(start, step=step(r); length=len)
else
st = oftype(r.start, r.start + (first(s)-1)*step(r))
return range(st, step=step(r)*step(s), length=length(s))
f = r.start
st = r.step
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
return range(start; step=st*step(s), length=length(s))
end
end

Expand Down Expand Up @@ -1235,7 +1242,7 @@ end
issubset(r::OneTo, s::OneTo) = r.stop <= s.stop

issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
isempty(r) || first(r) >= first(s) && last(r) <= last(s)
isempty(r) || (first(r) >= first(s) && last(r) <= last(s))

## linear operations on ranges ##

Expand Down
2 changes: 1 addition & 1 deletion stdlib/Dates/src/Dates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ for more information.
"""
module Dates

import Base: ==, div, fld, mod, rem, gcd, lcm, +, -, *, /, %, broadcast
import Base: ==, isless, div, fld, mod, rem, gcd, lcm, +, -, *, /, %, broadcast
using Printf: @sprintf

using Base.Iterators
Expand Down
7 changes: 5 additions & 2 deletions stdlib/Dates/src/periods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ default(p::Union{T,Type{T}}) where {T<:TimePeriod} = T(0)

(-)(x::P) where {P<:Period} = P(-value(x))
==(x::P, y::P) where {P<:Period} = value(x) == value(y)
==(x::Period, y::Period) = (==)(promote(x, y)...)
Base.isless(x::P, y::P) where {P<:Period} = isless(value(x), value(y))
Base.isless(x::Period, y::Period) = isless(promote(x, y)...)

# Period Arithmetic, grouped by dimensionality:
for op in (:+, :-, :lcm, :gcd)
Expand All @@ -97,6 +95,11 @@ end
(*)(A::Period, B::AbstractArray) = Broadcast.broadcast_preserving_zero_d(*, A, B)
(*)(A::AbstractArray, B::Period) = Broadcast.broadcast_preserving_zero_d(*, A, B)

for op in (:(==), :isless, :/, :rem, :mod, :lcm, :gcd)
@eval ($op)(x::Period, y::Period) = ($op)(promote(x, y)...)
end
div(x::Period, y::Period, r::RoundingMode) = div(promote(x, y)..., r)

# intfuncs
Base.gcdx(a::T, b::T) where {T<:Period} = ((g, x, y) = gcdx(value(a), value(b)); return T(g), x, y)
Base.abs(a::T) where {T<:Period} = T(abs(value(a)))
Expand Down
7 changes: 4 additions & 3 deletions stdlib/Dates/src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ end
Base.length(r::StepRange{<:TimeType}) = isempty(r) ? Int64(0) : len(r.start, r.stop, r.step) + 1
# Period ranges hook into Int64 overflow detection
Base.length(r::StepRange{<:Period}) = length(StepRange(value(r.start), value(r.step), value(r.stop)))
Base.checked_length(r::StepRange{<:Period}) = Base.checked_length(StepRange(value(r.start), value(r.step), value(r.stop)))

# Overload Base.steprange_last because `rem` is not overloaded for `TimeType`s
# Overload Base.steprange_last because `step::Period` may be a variable amount of time (e.g. for Month and Year)
function Base.steprange_last(start::T, step, stop) where T<:TimeType
if isa(step,AbstractFloat)
if isa(step, AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
z = zero(step)
Expand All @@ -47,7 +48,7 @@ function Base.steprange_last(start::T, step, stop) where T<:TimeType
last = stop - remain
end
end
last
return last
end

import Base.in
Expand Down