-
Notifications
You must be signed in to change notification settings - Fork 12
/
Shrinkage.jl
125 lines (105 loc) · 4.52 KB
/
Shrinkage.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
export Shrinkage
# TODO: make sure we act correctly for `probabilities` and `allprobabilities`.
"""
    Shrinkage <: ProbabilitiesEstimator
    Shrinkage(; t = nothing, λ = nothing)
The `Shrinkage` estimator is used with [`probabilities`](@ref) and related functions
to estimate probabilities over the given `m`-element counting-based
[`OutcomeSpace`](@ref) using James-Stein-type shrinkage
[JamesStein1992](@cite), as presented in [Hausser2009](@citet).
## Description
The `Shrinkage` estimator estimates a cell probability ``\\theta_{k}^{\\text{Shrink}}`` as
```math
\\theta_{k}^{\\text{Shrink}} = \\lambda t_k + (1-\\lambda) \\hat{\\theta}_k^{RelativeAmount},
```
where ``\\lambda \\in [0, 1]`` is the shrinkage intensity (``\\lambda = 0`` means
no shrinkage, and ``\\lambda = 1`` means full shrinkage), and ``t_k`` is the shrinkage
target. [Hausser2009](@citet) picks ``t_k = 1/m``, i.e. the uniform
distribution.
If `t == nothing`, then ``t_k`` is set to ``1/m`` for all ``k``,
as in [Hausser2009](@citet).
If `λ == nothing` (the default), then the shrinkage intensity is optimized according
to [Hausser2009](@citet). Hence, you should probably not pick
`λ` nor `t` manually, unless you know what you are doing.
## Assumptions
The `Shrinkage` estimator assumes a fixed and known number of outcomes `m`. Thus, using
it with [`probabilities`](@ref) and [`allprobabilities`](@ref) will yield different results,
depending on whether all outcomes are observed in the input data or not.
For [`probabilities`](@ref), `m` is the number of *observed* outcomes.
For [`allprobabilities`](@ref), `m = total_outcomes(o, x)`, where `o` is the
[`OutcomeSpace`](@ref) and `x` is the input data.
!!! note
    If used with [`allprobabilities`](@ref)/[`allprobabilities_and_outcomes`](@ref), then
    outcomes which have not been observed may be assigned non-zero probabilities.
    This might affect your results if using e.g. [`missing_outcomes`](@ref).
## Examples
```julia
using ComplexityMeasures
x = cumsum(randn(100))
ps_shrink = probabilities(Shrinkage(), OrdinalPatterns(m = 3), x)
```
See also: [`RelativeAmount`](@ref), [`BayesianRegularization`](@ref).
"""
struct Shrinkage{
        T <: Union{Nothing, Real, Vector{<:Real}},
        L <: Union{Nothing, Real}
    } <: ProbabilitiesEstimator
    # Shrinkage target: `nothing`, a single value shared by every outcome, or a
    # vector holding one target per outcome.
    t::T
    # Shrinkage intensity; `nothing` means it is derived from the data.
    λ::L

    # Keyword-only constructor; both keywords default to `nothing`.
    Shrinkage(; t::T = nothing, λ::L = nothing) where {T, L} = new{T, L}(t, λ)
end
# Shrinkage-estimated probabilities over the outcomes returned by
# `probabilities_and_outcomes`: the `RelativeAmount` probabilities are computed
# first, then handed to `probs_and_outs_from_histogram` for shrinking.
function probabilities(est::Shrinkage, outcomemodel::OutcomeSpace, x)
    relative_freqs, observed = probabilities_and_outcomes(
        RelativeAmount(), outcomemodel, x
    )
    return probs_and_outs_from_histogram(est, outcomemodel, relative_freqs, observed, x)
end
# Shrinkage-estimated probabilities over everything returned by
# `allprobabilities_and_outcomes`, so the shrunk distribution has one entry per
# outcome in that (full) outcome list rather than per observed outcome only.
function allprobabilities(est::Shrinkage, outcomemodel::OutcomeSpace, x)
    full_probs, full_outcomes = allprobabilities_and_outcomes(outcomemodel, x)
    return probs_and_outs_from_histogram(
        est, outcomemodel, full_probs, full_outcomes, x
    )
end
# Shrink `probs_observed` (one probability per outcome in `Ω_observed`) towards the
# target stored in `est`, and return the result as a `Probabilities` whose outcomes
# are `Ω_observed`.
#
# Arguments:
# - `est`: the `Shrinkage` estimator; supplies the target `est.t` and intensity `est.λ`.
# - `outcomemodel`: the outcome space; must pass `verify_counting_based`.
# - `probs_observed`: probabilities to shrink, indexable by `1:length(Ω_observed)`.
# - `Ω_observed`: the outcomes matching `probs_observed`.
# - `x`: the input data, used only to obtain the sample size `n`.
#
# Throws `DimensionMismatch` if `est.t` is a vector whose length differs from the
# number of outcomes.
function probs_and_outs_from_histogram(est::Shrinkage, outcomemodel::OutcomeSpace,
        probs_observed, Ω_observed, x)
    verify_counting_based(outcomemodel, "Shrinkage")

    t = est.t
    n = encoded_space_cardinality(outcomemodel, x) # Normalize based on *encoded* data.
    m = length(Ω_observed)

    # A per-outcome target needs exactly one entry per outcome, because it is
    # indexed by outcome position below.
    if t isa Vector{<:Real}
        length(t) == m || throw(DimensionMismatch("If `t` is a vector, `length(t)` must equal the number of elements in the outcome space (got $m outcomes, but length(t)=$(length(t)))."))
    end

    λ = get_λ(est, n, probs_observed, t, m)
    @assert 0 ≤ λ ≤ 1

    # The k-th shrunk probability belongs to the k-th entry of `Ω_observed`, so the
    # result can be filled positionally (no lookup of the outcome is needed).
    probs = zeros(m)
    for k in eachindex(probs)
        probs[k] = θₖ_shrink(probs_observed[k], λ, get_tₖ(t, k, m))
    end
    @assert sum(probs) ≈ 1.0

    return Probabilities(probs, (x1 = Ω_observed,))
end
# Return the shrinkage intensity `λ`, truncated to the interval `[0, 1]`.
#
# Arguments:
# - `est`: object with a field `λ`; `nothing` requests the data-driven optimum,
#   anything else is used as the (user-picked) intensity.
# - `n`: sample size.
# - `probs_observed`: the unshrunk probabilities, indexable by `1:m`.
# - `t`: shrinkage target in the form accepted by `get_tₖ`.
# - `m`: number of outcomes.
function get_λ(est, n, probs_observed, t, m)
    if est.λ === nothing
        # Optimal shrinkage intensity (eq. 5 in Hausser and Strimmer, 2009):
        #   λ* = (1 - Σₖ θ̂ₖ²) / ((n - 1) * Σₖ (tₖ - θ̂ₖ)²)
        densum = 0.0
        for k in 1:m
            densum += abs2(get_tₖ(t, k, m) - probs_observed[k])
        end
        # Note the parentheses: `a / (n - 1) * densum` would multiply by `densum`.
        denominator = (n - 1) * densum
        if iszero(denominator)
            # Degenerate case (estimate already equals the target, or `n == 1`):
            # the ratio is undefined, so fall back to full shrinkage instead of
            # propagating NaN/Inf.
            λ = 1.0
        else
            λ = (1 - sum(abs2, probs_observed)) / denominator
        end
    else
        # User-picked shrinkage intensity.
        λ = est.λ
    end

    # Truncate, so that 0 ≤ λ ≤ 1.
    return clamp(λ, 0.0, 1.0)
end
# Shrinkage target for the `k`-th of `m` outcomes, selected by the type of `t`:
# - a `Real` is used unchanged for every outcome,
# - `nothing` gives the uniform target `1/m`,
# - anything else is indexed at position `k`.
get_tₖ(t::Real, k::Int, m::Int) = t
get_tₖ(::Nothing, k::Int, m::Int) = 1 / m
get_tₖ(t, k::Int, m::Int) = t[k]
# Shrunk probability for one outcome: the weighted mix `λ*tₖ + (1 - λ)*θₖML` of the
# target `tₖ` and the unshrunk estimate `θₖML`.
function θₖ_shrink(θₖML, λ, tₖ)
    target_part = λ * tₖ
    data_part = (1 - λ) * θₖML
    return target_part + data_part
end