Skip to content

Commit

Permalink
refine and cleanup handling of range arithmetic
Browse files Browse the repository at this point in the history
Try to be more careful about which types we use for arguments and return
values and comparisons in intermediate computations. Not expected to
change nominal behaviors, but may improve some unusual ranges that
require some conversions or are near over/underflow.

Also use convert(T,1) rather than oneunit(T). This deliberately supports fewer
types, because we want the default step to be a unitless 1 (e.g., not Nanosecond(1)).

Replaces #43058
  • Loading branch information
vtjnash committed Dec 7, 2021
1 parent bdf9ead commit ff185b7
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 129 deletions.
121 changes: 64 additions & 57 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)

(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop-start, 1), stop)
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)

# promote start and stop, leaving step alone
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
Expand Down Expand Up @@ -164,7 +164,7 @@ _range(start::Any , step::Any , stop::Any , len::Any ) = range_error
range_length(len::Integer) = OneTo(len)

# Stop as the only argument
range_stop(stop) = range_start_stop(oneunit(stop), stop)
range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
range_stop(stop::Integer) = range_length(stop)

# Stop and length as the only argument
Expand Down Expand Up @@ -200,10 +200,17 @@ function range_start_step_length(a::T, step, len::Integer) where {T}
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
end

_rangestyle(::Ordered, ::ArithmeticWraps, a::T, step::S, len::Integer) where {T,S} =
StepRange{typeof(a+zero(step)),S}(a, step, a+step*(len-1))
_rangestyle(::Any, ::Any, a::T, step::S, len::Integer) where {T,S} =
StepRangeLen{typeof(a+zero(step)),T,S}(a, step, len)
# Method for the `Ordered` / `ArithmeticWraps` trait combination: build a
# `StepRange` whose element type is the type of `a + zero(step)`.
function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
    first_elem = a + zero(step)
    E = typeof(first_elem)
    # the last element lies `len - 1` steps past `a`, converted to the element type
    last_elem = convert(E, a + step * (len - 1))
    return StepRange{E,typeof(step)}(first_elem, step, last_elem)
end
# Fallback for every other trait combination: build a `StepRangeLen` from `a`,
# `step` and `len`. The element type is the type of `a + zero(step)`, while `a`
# itself (not the promoted sum) is passed through as the first argument.
function _rangestyle(::Any, ::Any, a, step, len::Integer)
    E = typeof(a + zero(step))
    return StepRangeLen{E,typeof(a),typeof(step)}(a, step, len)
end

range_start_step_stop(start, step, stop) = start:step:stop

Expand Down Expand Up @@ -306,19 +313,19 @@ struct StepRange{T,S} <: OrdinalRange{T,S}
stop::T

function StepRange{T,S}(start, step, stop) where {T,S}
sta = convert(T, start)
ste = convert(S, step)
sto = convert(T, stop)
new(sta, ste, steprange_last(sta,ste,sto))
start = convert(T, start)
step = convert(S, step)
stop = convert(T, stop)
return new(start, step, steprange_last(start, step, stop))
end
end

# to make StepRange constructor inlineable, so optimizer can see `step` value
function steprange_last(start::T, step, stop) where T
if isa(start,AbstractFloat) || isa(step,AbstractFloat)
function steprange_last(start, step, stop)::typeof(stop)
if isa(start, AbstractFloat) || isa(step, AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
if isa(start,Integer) && !isinteger(start + step)
if isa(start, Integer) && !isinteger(start + step)
throw(ArgumentError("StepRange{<:Integer} cannot have non-integer step"))
end
z = zero(step)
Expand All @@ -335,30 +342,28 @@ function steprange_last(start::T, step, stop) where T
absdiff, absstep = stop > start ? (stop - start, step) : (start - stop, -step)

# Compute remainder as a nonnegative number:
if T <: Signed && absdiff < zero(absdiff)
# handle signed overflow with unsigned rem
remain = convert(T, unsigned(absdiff) % absstep)
if absdiff isa Signed && absdiff < zero(absdiff)
# unlikely, but handle the signed overflow case with unsigned rem
remain = convert(typeof(absdiff), unsigned(absdiff) % absstep)
else
remain = absdiff % absstep
remain = convert(typeof(absdiff), absdiff % absstep)
end
# Move `stop` closer to `start` if there is a remainder:
last = stop > start ? stop - remain : stop + remain
end
end
last
return last
end

function steprange_last_empty(start::Integer, step, stop)
# empty range has a special representation where stop = start-1
# this is needed to avoid the wrap-around that can happen computing
# start - step, which leads to a range that looks very large instead
# of empty.
function steprange_last_empty(start::Integer, step, stop)::typeof(stop)
# empty range has a special representation where stop = start-1,
# which simplifies arithmetic for Signed numbers
if step > zero(step)
last = start - oneunit(stop-start)
last = start - oneunit(step)
else
last = start + oneunit(stop-start)
last = start + oneunit(step)
end
last
return last
end
# For types where x+oneunit(x) may not be well-defined use the user-given value for stop
steprange_last_empty(start, step, stop) = stop
Expand Down Expand Up @@ -388,18 +393,21 @@ UnitRange{Int64}
struct UnitRange{T<:Real} <: AbstractUnitRange{T}
start::T
stop::T
UnitRange{T}(start, stop) where {T<:Real} = new(start, unitrange_last(start,stop))
UnitRange{T}(start::T, stop::T) where {T<:Real} = new(start, unitrange_last(start, stop))
end
# Outer constructor: convert both endpoints to `T`, then construct from the
# converted values.
function UnitRange{T}(start, stop) where {T<:Real}
    lo = convert(T, start)
    hi = convert(T, stop)
    return UnitRange{T}(lo, hi)
end
# Endpoints that already share a real type `T` select `UnitRange{T}` directly.
function UnitRange(start::T, stop::T) where {T<:Real}
    return UnitRange{T}(start, stop)
end
# Any other pair of endpoints is promoted first and then dispatched again.
function UnitRange(start, stop)
    return UnitRange(promote(start, stop)...)
end

unitrange_last(::Bool, stop::Bool) = stop
unitrange_last(start::T, stop::T) where {T<:Integer} =
stop >= start ? stop : convert(T,start-oneunit(start-stop))
unitrange_last(start::T, stop::T) where {T} =
stop >= start ? convert(T,start+floor(stop-start)) :
convert(T,start-oneunit(stop-start))
# Integral endpoints always differ by a multiple of 1, so when `stop >= start`
# the last element is `stop` itself; otherwise the result is one unit below
# `start`, converted to the type of `stop`.
function unitrange_last(start::Integer, stop::Integer)
    stop >= start && return stop
    return convert(typeof(stop), start - oneunit(start - stop))
end
# For other endpoints, `floor` of the span is used as a more efficient way to
# compute the modulus with step=1.
function unitrange_last(start, stop)
    S = typeof(stop)
    if stop >= start
        return convert(S, start + floor(stop - start))
    end
    return convert(S, start - oneunit(start - stop))
end

# Convenience conversion for promoting the range type: both methods forward
# their argument to the `UnitRange` constructor.
function unitrange(x)
    return UnitRange(x)
end
function unitrange(x::AbstractUnitRange)
    return UnitRange(x)
end

if isdefined(Main, :Base)
# Constant-fold-able indexing into tuples to functionally expose Base.tail and Base.front
Expand Down Expand Up @@ -556,7 +564,7 @@ function LinRange{T}(start, stop, len::Integer) where T
end

function LinRange(start, stop, len::Integer)
T = typeof((stop-start)/len)
T = typeof((zero(stop) - zero(start)) / oneunit(len))
LinRange{T}(start, stop, len)
end

Expand Down Expand Up @@ -642,7 +650,7 @@ length(r::AbstractRange) = error("length implementation missing") # catch mistak
size(r::AbstractRange) = (length(r),)

isempty(r::StepRange) =
# steprange_last_empty(r.start, r.step, r.stop) == r.stop
# steprange_last(r.start, r.step, r.stop) == r.stop
(r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
isempty(r::AbstractUnitRange) = first(r) > last(r)
isempty(r::StepRangeLen) = length(r) == 0
Expand Down Expand Up @@ -689,9 +697,8 @@ firstindex(::LinRange) = 1
# defined between the relevant types
function checked_length(r::OrdinalRange{T}) where T
s = step(r)
# s != 0, by construction, but avoids the division error later
start = first(r)
if s == zero(s) || isempty(r)
if isempty(r)
return Integer(div(start - start, oneunit(s)))
end
stop = last(r)
Expand All @@ -716,9 +723,8 @@ end

function length(r::OrdinalRange{T}) where T
s = step(r)
# s != 0, by construction, but avoids the division error later
start = first(r)
if s == zero(s) || isempty(r)
if isempty(r)
return Integer(div(start - start, oneunit(s)))
end
stop = last(r)
Expand Down Expand Up @@ -756,7 +762,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
# (near typemax) for types with known `unsigned` functions
function length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
isempty(r) && return zero(T)
diff = last(r) - first(r)
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
Expand All @@ -773,7 +778,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
end
function checked_length(r::OrdinalRange{T}) where T<:bigints
s = step(r)
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
isempty(r) && return zero(T)
stop, start = last(r), first(r)
# n.b. !(s isa T)
Expand All @@ -800,7 +804,6 @@ let smallints = (Int === Int64 ?
# n.b. !(step isa T)
function length(r::OrdinalRange{<:smallints})
s = step(r)
s == zero(s) && return 0 # unreachable, by construction, but avoids the error case here later
isempty(r) && return 0
return div(Int(last(r)) - Int(first(r)), s) + 1
end
Expand Down Expand Up @@ -962,29 +965,30 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
return range(first(s) ? first(r) : last(r), length = last(s))
else
f = first(r)
st = oftype(f, f + first(s)-firstindex(r))
return range(st, length=length(s))
start = oftype(f, f + first(s)-firstindex(r))
return range(start, length=length(s))
end
end

function getindex(r::OneTo{T}, s::OneTo) where T
@inline
@boundscheck checkbounds(r, s)
OneTo(T(s.stop))
return OneTo(T(s.stop))
end

function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@inline
@boundscheck checkbounds(r, s)

if T === Bool
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
else
st = oftype(first(r), first(r) + s.start-firstindex(r))
return range(st, step=step(s), length=length(s))
f = first(r)
start = oftype(f, f + s.start-firstindex(r))
return range(start, step=step(s), length=length(s))
end
end

Expand All @@ -994,19 +998,22 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}

if T === Bool
if length(s) == 0
return range(first(r), step=step(r), length=0)
start, len = first(r), 0
elseif length(s) == 1
if first(s)
return range(first(r), step=step(r), length=1)
start, len = first(r), 1
else
return range(first(r), step=step(r), length=0)
start, len = first(r), 0
end
else # length(s) == 2
return range(last(r), step=step(r), length=1)
start, len = last(r), 1
end
return range(start, step=step(r); length=len)
else
st = oftype(r.start, r.start + (first(s)-1)*step(r))
return range(st, step=step(r)*step(s), length=length(s))
f = r.start
st = r.step
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
return range(start; step=st*step(s), length=length(s))
end
end

Expand Down Expand Up @@ -1235,7 +1242,7 @@ end
issubset(r::OneTo, s::OneTo) = r.stop <= s.stop

issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
isempty(r) || first(r) >= first(s) && last(r) <= last(s)
isempty(r) || (first(r) >= first(s) && last(r) <= last(s))

## linear operations on ranges ##

Expand Down
7 changes: 4 additions & 3 deletions stdlib/Dates/src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ end
Base.length(r::StepRange{<:TimeType}) = isempty(r) ? Int64(0) : len(r.start, r.stop, r.step) + 1
# Period ranges hook into Int64 overflow detection
Base.length(r::StepRange{<:Period}) = length(StepRange(value(r.start), value(r.step), value(r.stop)))
Base.checked_length(r::StepRange{<:Period}) = Base.checked_length(StepRange(value(r.start), value(r.step), value(r.stop)))

# Overload Base.steprange_last because `rem` is not overloaded for `TimeType`s
# Overload Base.steprange_last because `step::Period` may be a variable amount of time (e.g. for Month and Year)
function Base.steprange_last(start::T, step, stop) where T<:TimeType
if isa(step,AbstractFloat)
if isa(step, AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
z = zero(step)
Expand All @@ -47,7 +48,7 @@ function Base.steprange_last(start::T, step, stop) where T<:TimeType
last = stop - remain
end
end
last
return last
end

import Base.in
Expand Down

0 comments on commit ff185b7

Please sign in to comment.