Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

completely type stable code cleared by JET allocates when it shouldn't #51955

Open
nsajko opened this issue Oct 31, 2023 · 2 comments
Open

completely type stable code cleared by JET allocates when it shouldn't #51955

nsajko opened this issue Oct 31, 2023 · 2 comments

Comments

@nsajko
Copy link
Contributor

nsajko commented Oct 31, 2023

A stripped-down version of my package:

EDIT: see a version that is further reduced in the comment below

# Copyright © 2023 Neven Sajko. All rights reserved.

module OptimalSortingNetworks

module TupleUtils

const NT = NTuple{n,Any} where {n}

rough_type_of(::T) where {n, T<:NT{n}} = NTuple{n,eltype(T)}

function copy_to_tuple(collection, ::Val{n}) where {n}
  f = let c = collection
    i -> c[begin + i - 1]
  end
  ntuple(f, Val(n))::NTuple{n,eltype(collection)}
end

end

module Sorted

function sorted end
function sorted! end
default_minmax_by(::Type{T}) where {T} = identity
default_minmax_less(::Type{T}) where {T} = <
new_minmax(
  ::Type{T} = Any;
  by::By = default_minmax_by(T),
  less::Less = default_minmax_less(T),
) where {T, By, Less} =
  let by = by, less = less
    (l, r) -> less(by(r), by(l)) ? (r, l) : (l, r)
  end

sorted(collection::C) where {C} = sorted(new_minmax(C), collection)
sorted!(collection::C) where {C} = sorted!(new_minmax(C), collection)

end

module SortingNetworks

import ..TupleUtils

const NT = TupleUtils.NT
const r = TupleUtils.rough_type_of

sort(::Any, t::Union{NT{0},NT{1}}) = t
sort(minmax::M, t::NT{2}) where {M} = minmax(t...)::r(t)

function sort(m::M, t::NT{3}) where {M}
  (a0, a1, a2) = t

  (b0, b2) = m(a0, a2)

  (c0, c1) = m(b0, a1)

  (d1, d2) = m(c1, b2)

  (c0, d1, d2)::r(t)
end

end

module SortedTuple

import ..TupleUtils, ..Sorted, ..SortingNetworks

const NT = TupleUtils.NT
const r = TupleUtils.rough_type_of
const sort = SortingNetworks.sort

to_nt_type(lengths) = Union{map((n -> NT{n}), lengths)...}

const jointly_optimal_lengths = (0:9..., 11)
const JointlyOptimal = to_nt_type(jointly_optimal_lengths)
sorted_impl(minmax, t::JointlyOptimal) = sort(minmax, t)::r(t)

const other_optimal_lengths = (10,)
const OtherOptimal = to_nt_type(other_optimal_lengths)
sorted_impl(minmax, t::OtherOptimal) = sort(minmax, t)::r(t)

Sorted.sorted(minmax, t::Tuple) = sorted_impl(minmax, t)::r(t)

end

module SortedVector

import ..TupleUtils, ..Sorted

const NT = TupleUtils.NT
const to_tuple = TupleUtils.copy_to_tuple

const MutableVector = Union{Vector{T},SubArray{T,1,<:Array{T}}} where {T}

struct LengthError <: Exception
  length::Int
end
throw_length_error(len::Int) = throw(LengthError(len))

to_sorted_tuple(mm, c, ::Val{n}) where {n} =
  Sorted.sorted(mm, to_tuple(c, Val(n)))::NTuple{n,eltype(c)}

function sort!(mm, v::MutableVector{T}, ::Val{n}) where {n, T}
  t = to_sorted_tuple(mm, v, Val(n))::NTuple{n,T}
  v .= t
  nothing
end

function Sorted.sorted!(minmax, v::MutableVector{T}) where {T}
  p = (minmax, v)
  len = length(v)

  if len == 0
    sort!(p..., Val(0))
  elseif len == 1
    sort!(p..., Val(1))
  elseif len == 2
    sort!(p..., Val(2))
  elseif len == 3
    sort!(p..., Val(3))
  else
    throw_length_error(len)
  end

  v
end

end

using .Sorted: sorted!

end

JET reports no errors when calling sorted! on a three-element-long Vector{Int}, but Julia seems to allocate something on the heap anyway for some reason.

There are 64 bytes allocated on all of v1.9, v1.10 (beta) and v1.11 (nightly):

julia> include("/tmp/OSN.jl")
Main.OptimalSortingNetworks

julia> const sorted! = OptimalSortingNetworks.sorted!
sorted! (generic function with 2 methods)

julia> const v = [10, 3, 7]
3-element Vector{Int64}:
 10
  3
  7

julia> sorted!(v)
3-element Vector{Int64}:
  3
  7
 10

julia> @allocated sorted!(v)
64

julia> versioninfo()
Julia Version 1.11.0-DEV.777
Commit 96147bbe334 (2023-10-30 18:40 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × AMD Ryzen 3 5300U with Radeon Graphics
  WORD_SIZE: 64
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
  Threads: 1 on 8 virtual cores
Environment:
  JULIA_NUM_PRECOMPILE_TASKS = 3
  JULIA_PKG_PRECOMPILE_AUTO = 0

@profview_allocs sorted!(v) sample_rate=1 indicates that the allocation happens in sorted_impl, in line 76, even though I think everything should be on the stack or even just in registers.

@nsajko
Copy link
Contributor Author

nsajko commented Oct 31, 2023

More minimal:

module OptimalSortingNetworks

module TupleUtils
  const NT = NTuple{n,Any} where {n}
  rough_type_of(::T) where {n, T<:NT{n}} = NTuple{n,eltype(T)}
  function copy_to_tuple(collection, ::Val{n}) where {n}
    f = let c = collection
      i -> c[begin + i - 1]
    end
    ntuple(f, Val(n))::NTuple{n,eltype(collection)}
  end
end

module Sorted
  function sorted end
  function sorted! end
end

module SortedTuple
  import ..TupleUtils, ..Sorted
  const NT = TupleUtils.NT
  const r = TupleUtils.rough_type_of
  sort(::Any, t::Union{NT{0},NT{1}}) = t
  sort(minmax::M, t::NT{2}) where {M} = minmax(t...)::r(t)
  Sorted.sorted(minmax, t::T) where {T<:Tuple} = sort(minmax, t)::r(t)
end

module SortedVector
  import ..TupleUtils, ..Sorted
  const NT = TupleUtils.NT
  const to_tuple = TupleUtils.copy_to_tuple
  const MutableVector = Union{Vector{T},SubArray{T,1,<:Array{T}}} where {T}
  function to_sorted_tuple(mm, c, ::Val{n}) where {n}
    C = NTuple{n,eltype(c)}
    Sorted.sorted(mm, to_tuple(c, Val(n))::C)::C
  end
  function sort!(mm, v::MutableVector{T}, ::Val{n}) where {n, T}
    t = to_sorted_tuple(mm, v, Val(n))::NTuple{n,T}
    v .= t
    nothing
  end
end

end

const v = [10, 7]
sorted!(v) = OptimalSortingNetworks.SortedVector.sort!(minmax, v, Val(2))
sorted!(v)
@allocated sorted!(v)
@allocated sorted!(v)

@LilithHafner
Copy link
Member

I love a MWE that I can paste directly into a REPL; thanks!

Adding a type parameter for the type of minmax gives a hint to the compiler to specialize on that type. This diff applied to your second example removes the allocations:

- function Sorted.sorted!(minmax, v::MutableVector{T}) where {T}
+ function Sorted.sorted!(minmax::M, v::MutableVector{T}) where {M, T}

https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing

It would be lovely if JET had a way of opting in to detecting this sort of thing, but that's a question for the JET folks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants