/
bagcount.jl
74 lines (58 loc) · 1.86 KB
/
bagcount.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
    BagCount{T <: AbstractAggregation}

A wrapper type that, when called, applies the [`AbstractAggregation`](@ref) stored in it
and appends one more element containing the bag size after the ``x ↦ \\log(x + 1)``
transformation to the result, i.e. the output gets one extra row with
`log(1 + length(bag))` for every bag.

Used as a functor:

    (bc::BagCount)(x, bags[, w])

where `x` is either `AbstractMatrix` or `missing`, `bags` is an [`AbstractBags`](@ref) structure
and optionally `w` is an `AbstractVector` of weights.

# Examples
```jldoctest
julia> x = Float32[0 1 2; 3 4 5]
2×3 Matrix{Float32}:
 0.0  1.0  2.0
 3.0  4.0  5.0

julia> b = bags([1:1, 2:3])
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:3])

julia> a = vcat(SegmentedMean(2), SegmentedMax(2))
AggregationStack:
 SegmentedMean(ψ = Float32[0.0, 0.0])
 SegmentedMax(ψ = Float32[0.0, 0.0])

julia> a(x, b)
4×2 Matrix{Float32}:
 0.0  1.5
 3.0  4.5
 0.0  2.0
 3.0  5.0

julia> BagCount(a)(x, b)
5×2 Matrix{Float32}:
 0.0       1.5
 3.0       4.5
 0.0       2.0
 3.0       5.0
 0.693147  1.09861
```

See also: [`AbstractAggregation`](@ref), [`AggregationStack`](@ref), [`SegmentedSum`](@ref),
[`SegmentedMax`](@ref), [`SegmentedMean`](@ref), [`SegmentedPNorm`](@ref), [`SegmentedLSE`](@ref).
"""
struct BagCount{T<:AbstractAggregation}
    # the wrapped aggregation applied before the bag-count row is appended
    a::T
end

# register with Flux so the parameters of the wrapped aggregation are trainable/movable
Flux.@functor BagCount
# Build the bag-count feature row: `log(1 + length(bag))` for each bag, laid out as a
# 1×nbags matrix so it can be `vcat`-ed under an aggregation output.
# `one(T)` makes the result promote with `T` (e.g. it stays `Float32` when `T == Float32`).
function _bagcount(T, bags)
    sizes = length.(bags)
    logsizes = log.(one(T) .+ sizes)
    return permutedims(logsizes)
end
# bag sizes do not depend on the (differentiable) inputs, so no gradient flows here
ChainRulesCore.@non_differentiable _bagcount(T, bags)
# Functor call: run the wrapped aggregation, then stack a `log(1 + bag size)` row under
# its output. `args...` (e.g. optional weights) are forwarded to the aggregation only.
# NOTE(review): `Maybe{AbstractArray}` presumably means `Union{Missing, AbstractArray}`.
function (bc::BagCount)(x::Maybe{AbstractArray}, bags::AbstractBags, args...)
    aggregated = bc.a(x, bags, args...)
    # build the count row from the aggregation's eltype so both parts share precision
    countrow = _bagcount(eltype(aggregated), bags)
    return vcat(aggregated, countrow)
end
# Multi-line (REPL) display: `BagCount(<wrapped aggregation>)`.
function Base.show(io::IO, m::MIME"text/plain", @nospecialize(bc::BagCount{T})) where {T}
    print(io, "BagCount(")
    # Write straight into `io` rather than going through `repr`: `repr` renders into a fresh
    # buffer, which drops IOContext properties (:color, :limit, :displaysize, :compact)
    # and allocates an intermediate String.
    show(io, m, bc.a)
    # an AggregationStack prints over several lines, so close the paren on its own line
    print(io, T <: AggregationStack ? "\n)" : ")")
    return nothing
end
# Single-line display: `BagCount(<wrapped aggregation>)`, or just `BagCount` when the
# stream requests compact output (e.g. when printed as an element of a container).
function Base.show(io::IO, @nospecialize(bc::BagCount))
    print(io, "BagCount")
    get(io, :compact, false) && return nothing
    print(io, "(", bc.a, ")")
    return nothing
end