# aggregations.jl
# We document types/constructors/functors in one docstring until
# https://github.com/JuliaDocs/Documenter.jl/issues/558 is resolved
"""
    AbstractAggregation

Supertype for any aggregation operator.

See also: [`AggregationStack`](@ref), [`SegmentedSum`](@ref), [`SegmentedMax`](@ref),
[`SegmentedMean`](@ref), [`SegmentedPNorm`](@ref), [`SegmentedLSE`](@ref).
"""
# NOTE(review): `_check_agg` in this file reads `a.ψ` from any AbstractAggregation, so
# concrete subtypes are implicitly expected to carry a `ψ` field whose length equals the
# number of input rows — confirm this holds for every subtype.
abstract type AbstractAggregation end
# Normaliser of the bag `b` (a collection of column indices, presumably a range — confirm):
# - no weights    -> the number of elements of `b`,
# - weight vector -> the sum of the weights selected by `b`,
# - weight matrix -> the per-row sums over the columns selected by `b`, as a vector.
@inline function _bagnorm(::Nothing, b)
    return length(b)
end
@inline function _bagnorm(w::AbstractVector, b)
    return sum(view(w, b))
end
@inline function _bagnorm(w::AbstractMatrix, b)
    return vec(sum(view(w, :, b); dims=2))
end

# Weight of entry (i, j) for element type `T`: without weights it is `one(T)`, a weight
# vector is indexed by the column `j` only, and a weight matrix by both `i` and `j`.
@inline function _weight(::Nothing, _, _, ::Type{T}) where {T}
    return one(T)
end
@inline function _weight(w::AbstractVector, _, j, _)
    return w[j]
end
@inline function _weight(w::AbstractMatrix, i, j, _)
    return w[i, j]
end

# Normaliser for row `i`: a scalar is shared by every row, while a vector (as returned by
# the matrix method of `_bagnorm`) holds one value per row.
@inline function _weightsum(ws::Real, _)
    return ws
end
@inline function _weightsum(ws::AbstractVector, i)
    return ws[i]
end
# Custom reverse rule for `softplus` (the original note: "more stable definitions for
# r_map and p_map"). The primal is broadcast over `x` and the pullback uses the identity
# d/dx softplus(x) = σ(x).
# NOTE(review): `rrule` comes from ChainRulesCore and `softplus` presumably from NNlib; if
# neither is owned by this package, this method is type piracy — confirm.
function ChainRulesCore.rrule(::typeof(softplus), x)
    y = softplus.(x)
    softplus_pullback(Δ) = (NoTangent(), Δ .* σ.(x))
    return y, softplus_pullback
end
# `typemin` variant that also accepts the missing-aware element types used here:
# the bare `Missing` type maps to `missing`, and `Maybe{T}` falls back to `typemin(T)`.
# NOTE(review): `Maybe` is defined elsewhere — presumably `Union{T, Missing}`; confirm.
function _typemin(t::Type)
    return typemin(t)
end
function _typemin(::Type{Missing})
    return missing
end
function _typemin(::Type{Maybe{T}}) where {T}
    return typemin(T)
end
# Validate that an aggregation operator fits its input.
# `missing` input is accepted without any check. For a matrix `X`, the number of rows
# must equal `length(a.ψ)`; otherwise a `DimensionMismatch` is thrown. Returns `nothing`.
_check_agg(::AbstractAggregation, ::Missing) = nothing
function _check_agg(a::AbstractAggregation, X::AbstractMatrix)
    nrows = size(X, 1)
    nψ = length(a.ψ)
    if nrows ≠ nψ
        # report the row count that was actually compared (previously printed size(X, 2))
        throw(DimensionMismatch(
            "Different number of rows in input ($nrows) and aggregation ($nψ)"
        ))
    end
    return nothing
end
# Concrete aggregation operators, one file each. `aggregation_stack.jl` comes last —
# NOTE(review): presumably because it builds on the single operators; confirm before
# reordering these includes.
include("segmented_sum.jl")
include("segmented_mean.jl")
include("segmented_max.jl")
include("segmented_pnorm.jl")
include("segmented_lse.jl")
include("aggregation_stack.jl")
# Vertically concatenating aggregation operators wraps them in an `AggregationStack`.
# The varargs form collects its arguments into a vector and defers to the
# `reduce(vcat, ...)` method below, which converts the vector back into a tuple.
Base.vcat(as::AbstractAggregation...) = reduce(vcat, collect(as))

Base.reduce(::typeof(vcat), as::Vector{<:AbstractAggregation}) = AggregationStack((as...,))
# NOTE(review): included after the vcat/reduce methods above; check whether bagcount.jl
# depends on them before moving this line.
include("bagcount.jl")
# definitions for mixed aggregations
# For every subset of at least two entries of `names`, define (via @eval) and export a
# constructor whose name is "Segmented" followed by the selected names, e.g.
# `SegmentedSumMax`. Each constructor returns an `AggregationStack` of the corresponding
# single operators and has a `(d::Int)` method and a `(::Type{T}, d::Int)` method.
# NOTE(review): this module-level constant shadows `Base.names` inside the module; a more
# specific name would be safer (check other files for uses of `names` first).
const names = ["Sum", "Mean", "Max", "PNorm", "LSE"]
# `p` is an index vector into `names` (from `powerset`, presumably Combinatorics.jl —
# confirm). Subsets of size 0 and 1 are skipped because single operators already exist.
for p in filter(p -> length(p) > 1, collect(powerset(collect(1:length(names)))))
# name of the generated constructor, e.g. :SegmentedSumMax for p == [1, 3]
s = Symbol("Segmented", names[p]...)
@eval begin
# The docstring mixes quote interpolation (values computed in this loop, spliced in by
# @eval) with string interpolation. The jldoctest examples are only generated when every
# selected operator is Sum, Mean or Max (NOTE(review): rationale not visible here).
# NOTE(review): the plural "s" conditional is always true, since `p` has ≥ 2 elements.
"""
$($(s))([t::Type, ]d::Int)
Construct [`AggregationStack`](@ref) consisting of $($(
join("[`Segmented" .* names[p] .* "`](@ref)", ", ", " and ")
)) operator$($(length(p) > 1 ? "s" : "")).
$($(
all(in(["Sum", "Mean", "Max"]), names[p]) ? """
# Examples
```jldoctest
julia> $(s)(4)
AggregationStack:
$(join(" Segmented" .* names[p] .* "(ψ = Float32[0.0, 0.0, 0.0, 0.0])", "\n"))
julia> $(s)(Float64, 2)
AggregationStack:
$(join(" Segmented" .* names[p] .* "(ψ = [0.0, 0.0])", "\n"))
```
""" : ""
))
See also: [`AbstractAggregation`](@ref), [`AggregationStack`](@ref), [`SegmentedSum`](@ref),
[`SegmentedMax`](@ref), [`SegmentedMean`](@ref), [`SegmentedPNorm`](@ref), [`SegmentedLSE`](@ref).
"""
function $s(d::Int)
# splices one call `Segmented<Name>(d)` per selected name into the AggregationStack call
AggregationStack($((Expr(:call, Symbol("Segmented", n), :d) for n in names[p])...))
end
end
# Variant with an explicit element type `T`, forwarded to every component constructor
# as `Segmented<Name>(T, d)`.
@eval function $s(::Type{T}, d::Int) where T
AggregationStack($((Expr(:call, Symbol("Segmented", n), :T, :d) for n in names[p])...))
end
@eval export $s
end