-
Notifications
You must be signed in to change notification settings - Fork 6
/
sparse_cat.jl
115 lines (96 loc) · 3.08 KB
/
sparse_cat.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
SparseCat(values, probabilities)
Create a sparse categorical distribution.
`values` is an iterable object containing the possible values (can be of any type) in the distribution that have nonzero probability. `probabilities` is an iterable object that contains the associated probabilities.
This is optimized for value iteration with a fast implementation of `weighted_iterator`. Both `pdf` and `rand` are order n.
"""
struct SparseCat{V, P}
vals::V
probs::P
end
# Draw one value from `d`, with probability proportional to its weight in `d.probs`.
#
# The weights do not have to be normalized: a uniform draw from `rng` is scaled by their
# sum, and the cumulative weight is walked until it exceeds the scaled draw.
#
# Throws an `ErrorException` when the weights sum to a non-positive number, or when no
# value can be selected at all (for example when the weights contain `NaN`).
function rand(rng::AbstractRNG, d::SparseCat)
    total = sum(d.probs)
    r = total*rand(rng)
    tot = zero(eltype(d.probs))
    for (v, p) in d
        tot += p
        if r < tot
            return v
        end
    end
    if total <= 0.0
        error("""
              Tried to sample from a SparseCat distribution with probabilities that sum to $(total).
              vals = $(d.vals)
              probs = $(d.probs)
              """)
    end
    # `sum` may add the weights in a different order than the loop above, so the running
    # total can finish slightly below `total` (most noticeable with Float32 weights). A
    # draw that lands in that gap is still a valid sample of the last outcome that carries
    # positive weight, so return that outcome instead of failing on a valid distribution.
    # The `isapprox` guard keeps the error below for genuinely inconsistent inputs.
    if isapprox(tot, total)
        fallback = nothing
        for (v, p) in d
            if p > zero(p)
                fallback = Some(v)  # `Some` keeps a stored value of `nothing` distinguishable
            end
        end
        if fallback !== nothing
            return something(fallback)
        end
    end
    error("Error sampling from SparseCat distribution with vals $(d.vals) and probs $(d.probs)") # try to help with type stability
end
# Probability of `s` under `d`.
# This is a linear scan over the value => probability pairs of `d`; the probability of the
# first value that compares equal to `s` is returned. When no value matches, the result is
# a zero of the probability element type.
function pdf(d::SparseCat, s)
    for pair in d
        first(pair) == s && return last(pair)
    end
    return zero(eltype(d.probs))
end
# Probability of `s` under `d` when both the values and the probabilities are arrays.
#
# The two arrays are walked in lockstep and the probability paired with the first value
# equal to `s` is returned; a zero of the probability element type is returned when `s`
# is not among the values.
#
# `zip` is used rather than `enumerate` plus indexing: `enumerate` always counts from 1,
# which does not match arrays with other index ranges, and indexing `d.probs` could run
# past its end when it is shorter than `d.vals`. Stopping at the shorter of the two agrees
# with `length(d)` and with iteration over `d`.
function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray}
    for (v, p) in zip(d.vals, d.probs)
        if v == s
            return p
        end
    end
    return zero(eltype(d.probs))
end
# The support of `d` is its stored collection of values, returned as-is (not copied).
function support(d::SparseCat)
    return d.vals
end
# A SparseCat is its own weighted iterator: iterating `d` yields value => probability pairs.
function weighted_iterator(d::SparseCat)
    return d
end
# iterator for general SparseCat
# this has some type stability problems
#
# Start iterating `d`: advance `d.vals` and `d.probs` together and return the first
# `value => probability` pair along with the pair of underlying iteration states.
# Returns `nothing` when either collection is empty, so that iterating an empty
# distribution simply ends instead of failing while destructuring `nothing`.
function Base.iterate(d::SparseCat)
    vnext = iterate(d.vals)
    pnext = iterate(d.probs)
    if vnext === nothing || pnext === nothing
        return nothing
    end
    val, vstate = vnext
    prob, pstate = pnext
    return ((val=>prob), (vstate, pstate))
end
# Continue iterating a general SparseCat.
#
# `dstate` is the `(vstate, pstate)` tuple produced by the previous step, holding the
# iteration states of `d.vals` and `d.probs`. Both collections are advanced together and
# iteration ends (returns `nothing`) as soon as either one is exhausted; otherwise the next
# `value => probability` pair and the updated state tuple are returned.
function Base.iterate(d::SparseCat, dstate::Tuple)
    vstate, pstate = dstate
    vnext = iterate(d.vals, vstate)
    pnext = iterate(d.probs, pstate)
    # identity comparison is the idiomatic (and cheapest) test for the `nothing` sentinel
    if vnext === nothing || pnext === nothing
        return nothing
    end
    val, vstate_next = vnext
    prob, pstate_next = pnext
    return ((val=>prob), (vstate_next, pstate_next))
end
# iterator for SparseCat with indexed members
# `Indexed` names the container kinds that get the index-based iteration method.
const Indexed = Union{Tuple, NamedTuple, AbstractArray}
# Index-based iteration for a SparseCat whose values and probabilities are both `Indexed`.
# `state` is the position of the next pair to produce (1 on the first call); iteration
# stops once it passes `length(d)`.
function Base.iterate(d::SparseCat{V,P}, state::Integer=1) where {V<:Indexed, P<:Indexed}
    state <= length(d) || return nothing
    item = d.vals[state] => d.probs[state]
    return (item, state + 1)
end
# Number of value/probability pairs in `d`: the shorter of the two stored collections.
function Base.length(d::SparseCat)
    return min(length(d.vals), length(d.probs))
end
# Iterating a SparseCat yields `value => probability` pairs, so its element type is the
# `Pair` of the element types of the two stored collections.
function Base.eltype(D::Type{SparseCat{V,P}}) where {V, P}
    return Pair{eltype(V), eltype(P)}
end
# The sample type of a SparseCat is the element type of its value collection.
function sampletype(D::Type{SparseCat{V,P}}) where {V, P}
    return eltype(V)
end
# Report the element type of the value collection as the type generated for a SparseCat.
function Random.gentype(D::Type{SparseCat{V,P}}) where {V, P}
    return eltype(V)
end
# Probability-weighted average of the values of `d`.
# The weighted sum is divided by the sum of the probabilities, so the probabilities do not
# need to be normalized.
function mean(d::SparseCat)
    acc = zero(eltype(d.vals))
    for pair in d
        acc += first(pair)*last(pair)
    end
    total = sum(d.probs)
    return acc/total
end
# Value of `d` with the largest probability.
# Ties go to the value that appears last, because a candidate replaces the current best
# whenever its probability is greater than or equal to it. The search starts from the
# first stored value and a zero probability.
function mode(d::SparseCat)
    topp = zero(eltype(d.probs))
    topv = first(d.vals)
    for (v, p) in d
        p >= topp || continue
        topp = p
        topv = v
    end
    return topv
end
# Plain-text display of a SparseCat, delegated to `showdistribution` with a fixed title.
function Base.show(io::IO, m::MIME"text/plain", d::SparseCat)
    return showdistribution(io, m, d, title="SparseCat distribution")
end