-
Notifications
You must be signed in to change notification settings - Fork 414
/
bernoulli.jl
154 lines (116 loc) · 3.98 KB
/
bernoulli.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
    Bernoulli(p)

A *Bernoulli distribution* is parameterized by a success rate `p`: a Bernoulli random
variable takes the value 1 with probability `p` and the value 0 with probability `1-p`.

```math
P(X = k) = \\begin{cases}
1 - p & \\quad \\text{for } k = 0, \\\\
p & \\quad \\text{for } k = 1.
\\end{cases}
```

```julia
Bernoulli()    # Bernoulli distribution with p = 0.5
Bernoulli(p)   # Bernoulli distribution with success rate p

params(d)      # Get the parameters, i.e. (p,)
succprob(d)    # Get the success rate, i.e. p
failprob(d)    # Get the failure rate, i.e. 1 - p
```

External links:

* [Bernoulli distribution on Wikipedia](https://en.wikipedia.org/wiki/Bernoulli_distribution)
"""
struct Bernoulli{T<:Real} <: DiscreteUnivariateDistribution
    # Success rate P(X = 1); the inner constructor requires 0 ≤ p ≤ 1.
    p::T

    # Inner constructor: checks zero(p) <= p <= one(p) through the package's
    # `@check_args` macro (presumably throwing on failure — see its definition)
    # before storing `p` unchanged.
    function Bernoulli{T}(p::T) where T
        @check_args(Bernoulli, zero(p) <= p <= one(p))
        new{T}(p)
    end
end
# Outer constructors. The type parameter is inferred from `p`, except that
# integer rates (e.g. `Bernoulli(1)`) are first promoted to Float64.
function Bernoulli(p::T) where {T<:Real}
    return Bernoulli{T}(p)
end

function Bernoulli(p::Integer)
    return Bernoulli(Float64(p))
end

# With no argument, the success rate defaults to 0.5.
function Bernoulli()
    return Bernoulli(0.5)
end

# Register the support bounds 0 and 1 via the package's `@distr_support` macro.
@distr_support Bernoulli 0 1
#### Conversions

# Build a distribution from a bare success rate, converting the rate to `T`.
function convert(::Type{Bernoulli{T}}, p::Real) where {T<:Real}
    return Bernoulli(T(p))
end

# Re-parameterize an existing distribution with a different element type.
function convert(::Type{Bernoulli{T}}, d::Bernoulli{S}) where {T<:Real, S<:Real}
    return Bernoulli(T(d.p))
end
#### Parameters

# Probability of the outcome 1.
function succprob(d::Bernoulli)
    return d.p
end

# Probability of the outcome 0, i.e. 1 - p.
function failprob(d::Bernoulli)
    return 1 - d.p
end

# All parameters as a tuple: `(p,)`.
function params(d::Bernoulli)
    return (d.p,)
end

# Element type used to store the parameter.
@inline function partype(::Bernoulli{T}) where {T<:Real}
    return T
end
#### Properties

# Moments and mode of X ~ Bernoulli(p), writing q = 1 - p.

function mean(d::Bernoulli)
    return succprob(d)                  # E[X] = p
end

function var(d::Bernoulli)
    p, q = succprob(d), failprob(d)
    return p * q                        # Var[X] = p q
end

function skewness(d::Bernoulli)
    q = failprob(d)
    p = succprob(d)
    return (q - p) / sqrt(q * p)
end

# 1/(p q) - 6 equals (1 - 6 p q) / (p q), the excess kurtosis.
function kurtosis(d::Bernoulli)
    return 1 / var(d) - 6
end

# Most likely outcome; on the tie p == 1/2 this returns 0.
function mode(d::Bernoulli)
    return succprob(d) > 1/2 ? 1 : 0
end
# Every mode, in increasing order; both outcomes are modes when p == 1/2.
function modes(d::Bernoulli)
    prob1 = succprob(d)
    if prob1 < 1/2
        return [0]
    elseif prob1 > 1/2
        return [1]
    else
        return [0, 1]
    end
end

# Median; when p <= 1/2 (including the tie p == 1/2) this returns 0.
function median(d::Bernoulli)
    return succprob(d) <= 1/2 ? 0 : 1
end
# Shannon entropy in nats: -(q log q + p log p) with q = 1 - p.
function entropy(d::Bernoulli)
    q = failprob(d)
    # Degenerate case (p == 0 or p == 1): the formula would evaluate
    # 0 * log(0) = NaN, but the entropy is zero.
    if q == 0 || q == 1
        return zero(d.p)
    end
    p = succprob(d)
    return -(q * log(q) + p * log(p))
end
#### Evaluation

# Probability mass function. A `Bool` is read as the outcome itself
# (`true` ↔ 1, `false` ↔ 0).
function pdf(d::Bernoulli, x::Bool)
    return x ? succprob(d) : failprob(d)
end

# Integers outside {0, 1} have zero mass.
function pdf(d::Bernoulli, x::Int)
    if x == 1
        return succprob(d)
    elseif x == 0
        return failprob(d)
    end
    return zero(d.p)
end
# Cumulative distribution function: cdf(d, x) = P(X ≤ x).
# A `Bool` argument is the outcome `false ↔ 0`, `true ↔ 1`, so the `Bool`
# methods must agree with the `Int` methods: P(X ≤ 0) = 1 - p and P(X ≤ 1) = 1.
cdf(d::Bernoulli, x::Bool) = x ? one(d.p) : failprob(d)
cdf(d::Bernoulli, x::Int) = x < 0 ? zero(d.p) :
                            x < 1 ? failprob(d) : one(d.p)

# Complementary CDF (survival function): ccdf(d, x) = P(X > x) = 1 - cdf(d, x).
# For the `Bool` outcomes: P(X > 0) = p and P(X > 1) = 0.
ccdf(d::Bernoulli, x::Bool) = x ? zero(d.p) : succprob(d)
ccdf(d::Bernoulli, x::Int) = x < 0 ? one(d.p) :
                             x < 1 ? succprob(d) : zero(d.p)
# Quantile function: the smallest outcome k ∈ {0, 1} with P(X ≤ k) ≥ p.
# Probabilities outside [0, 1] (and NaN) produce `T(NaN)`.
function quantile(d::Bernoulli{T}, p::Real) where T<:Real
    in_unit_interval = 0 <= p <= 1
    in_unit_interval || return T(NaN)
    return p <= failprob(d) ? zero(T) : one(T)
end

# Complementary quantile, mathematically quantile(d, 1 - p), written directly
# in terms of the success rate.
function cquantile(d::Bernoulli{T}, p::Real) where T<:Real
    if !(0 <= p <= 1)
        return T(NaN)
    end
    return p >= succprob(d) ? zero(T) : one(T)
end
# Moment-generating function E[exp(tX)] = q + p e^t.
function mgf(d::Bernoulli, t::Real)
    q, p = failprob(d), succprob(d)
    return q + p * exp(t)
end

# Characteristic function E[exp(itX)] = q + p e^{it}.
function cf(d::Bernoulli, t::Real)
    q, p = failprob(d), succprob(d)
    return q + p * cis(t)
end
#### Sampling

# Draw one variate using the global RNG.
rand(d::Bernoulli) = rand(GLOBAL_RNG, d)

# Draw one variate (0 or 1, as `Int`) from `rng`.
# `rand(rng)` is uniform on the half-open interval [0, 1), so the strict test
# `u < p` succeeds with probability exactly `p`: `Bernoulli(0)` never yields 1
# and `Bernoulli(1)` always does. With `u <= p`, `Bernoulli(0)` could return 1
# whenever the RNG produced exactly 0.0.
rand(rng::AbstractRNG, d::Bernoulli) = Int(rand(rng) < succprob(d))
#### MLE fitting

# Sufficient statistics of a Bernoulli sample: the (possibly weighted) counts
# of observed zeros (`cnt0`) and ones (`cnt1`), always stored as Float64.
struct BernoulliStats <: SufficientStats
    cnt0::Float64
    cnt1::Float64

    function BernoulliStats(c0::Real, c1::Real)
        return new(Float64(c0), Float64(c1))
    end
end

# Maximum-likelihood estimate: the observed fraction of ones,
# p̂ = cnt1 / (cnt0 + cnt1).
function fit_mle(::Type{Bernoulli}, ss::BernoulliStats)
    total = ss.cnt0 + ss.cnt1
    return Bernoulli(ss.cnt1 / total)
end
# Count the zeros and ones in `x` and return them as `BernoulliStats(c0, c1)`.
# Throws a `DomainError` carrying the offending value if any element of `x`
# is neither 0 nor 1.
function suffstats(::Type{Bernoulli}, x::AbstractArray{T}) where T<:Integer
    c0 = c1 = 0
    # Plain iteration visits every element regardless of the array's index
    # style, so no manual indexing or `@inbounds` is needed.
    for xi in x
        if xi == 0
            c0 += 1
        elseif xi == 1
            c1 += 1
        else
            # `DomainError` has no zero-argument constructor in Julia 1.x;
            # `DomainError()` would raise a MethodError instead.
            throw(DomainError(xi, "Bernoulli observations must be 0 or 1."))
        end
    end
    return BernoulliStats(c0, c1)
end
# Weighted counts: sum the weights `w[i]` of the observations equal to 0 and
# of those equal to 1, returning `BernoulliStats(c0, c1)`.
# Throws `DimensionMismatch` if `x` and `w` differ in length, and a
# `DomainError` carrying the offending value if any element of `x` is neither
# 0 nor 1.
function suffstats(::Type{Bernoulli}, x::AbstractArray{T},
                   w::AbstractArray{Float64}) where T<:Integer
    length(w) == length(x) ||
        throw(DimensionMismatch("Inconsistent argument dimensions."))
    # Float64 accumulators: the totals are sums of Float64 weights, so starting
    # from the Int literal 0 would change the variables' type inside the loop.
    c0 = c1 = 0.0
    for (xi, wi) in zip(x, w)
        if xi == 0
            c0 += wi
        elseif xi == 1
            c1 += wi
        else
            # `DomainError` has no zero-argument constructor in Julia 1.x;
            # `DomainError()` would raise a MethodError instead.
            throw(DomainError(xi, "Bernoulli observations must be 0 or 1."))
        end
    end
    return BernoulliStats(c0, c1)
end