/
bernoulli.jl
158 lines (119 loc) · 4.29 KB
/
bernoulli.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
    Bernoulli(p)

A *Bernoulli distribution* with success rate `p` takes the value 1 with probability `p`
and the value 0 with probability `1-p`.

```math
P(X = k) = \\begin{cases}
1 - p & \\quad \\text{for } k = 0, \\\\
p & \\quad \\text{for } k = 1.
\\end{cases}
```

```julia
Bernoulli()    # Bernoulli distribution with p = 0.5
Bernoulli(p)   # Bernoulli distribution with success rate p

params(d)      # Get the parameters, i.e. (p,)
succprob(d)    # Get the success rate, i.e. p
failprob(d)    # Get the failure rate, i.e. 1 - p
```

External links:

* [Bernoulli distribution on Wikipedia](https://en.wikipedia.org/wiki/Bernoulli_distribution)
"""
struct Bernoulli{T<:Real} <: DiscreteUnivariateDistribution
    p::T  # success probability P(X = 1)

    # Unchecked inner constructor: `new` converts `p` to the field type `T`. No range
    # validation happens here; the keyword-argument outer constructor performs it.
    function Bernoulli{T}(p::Real) where {T<:Real}
        return new{T}(p)
    end
end
# Validating constructor: requires 0 <= p <= 1 unless `check_args=false`.
# The parameter type of the result is the type of `p` itself.
function Bernoulli(p::Real; check_args::Bool=true)
    @check_args Bernoulli (p, zero(p) <= p <= one(p))
    PT = typeof(p)
    return Bernoulli{PT}(p)
end

# Integer success rates (e.g. `Bernoulli(1)`) are converted with `float` before
# validation, so the stored parameter is floating point.
function Bernoulli(p::Integer; check_args::Bool=true)
    return Bernoulli(float(p); check_args=check_args)
end

# Default constructor: p = 0.5 stored as Float64 (bypasses validation; 0.5 is in range).
function Bernoulli()
    return Bernoulli{Float64}(0.5)
end
# Declare the support bounds: lowest value `false` (0), highest value `true` (1).
# NOTE(review): `@distr_support` is defined elsewhere in the package; presumably it
# generates `minimum`/`maximum` methods for `Bernoulli` — confirm.
@distr_support Bernoulli false true
# The element (sample) type is `Bool` for every parameter type `T`.
Base.eltype(::Type{<:Bernoulli}) = Bool
#### Conversions
# All three methods extend `Base.convert` explicitly. The first one previously used a
# bare `convert`, which only extends Base if `convert` has been imported into the
# enclosing module.

# Convert a bare success rate into a `Bernoulli`. This goes through the validating
# outer constructor, so out-of-range values are rejected.
Base.convert(::Type{Bernoulli{T}}, p::Real) where {T<:Real} = Bernoulli(T(p))
# Change the parameter type of an existing distribution (unchecked inner constructor).
Base.convert(::Type{Bernoulli{T}}, d::Bernoulli) where {T<:Real} = Bernoulli{T}(T(d.p))
# Already the requested type: return the same object.
Base.convert(::Type{Bernoulli{T}}, d::Bernoulli{T}) where {T<:Real} = d
#### Parameters

# Probability of the outcome 1 (`true`).
succprob(d::Bernoulli) = d.p

# Probability of the outcome 0 (`false`), computed as 1 - p.
failprob(d::Bernoulli) = 1 - succprob(d)

# Parameter tuple `(p,)`.
params(d::Bernoulli) = (succprob(d),)

# The type parameter `T` of `Bernoulli{T}`.
function partype(::Bernoulli{T}) where {T}
    return T
end
#### Properties

# E[X] = p
mean(d::Bernoulli) = d.p

# Var[X] = p * (1 - p)
function var(d::Bernoulli)
    p = succprob(d)
    q = failprob(d)
    return p * q
end

# Skewness (q - p) / sqrt(q * p); infinite when p is 0 or 1.
function skewness(d::Bernoulli)
    q = failprob(d)
    p = succprob(d)
    return (q - p) / sqrt(q * p)
end

# Excess kurtosis: (1 - 6pq) / (pq) simplifies to 1 / (pq) - 6.
kurtosis(d::Bernoulli) = 1 / var(d) - 6
# Most likely outcome; a tie (p == 1/2) resolves to 0.
mode(d::Bernoulli) = succprob(d) > 1/2 ? 1 : 0

# All most likely outcomes; both 0 and 1 when neither p < 1/2 nor p > 1/2 holds.
function modes(d::Bernoulli)
    p = succprob(d)
    if p > 1/2
        return [1]
    elseif p < 1/2
        return [0]
    else
        return [0, 1]
    end
end

# Median; for p <= 1/2 the value 0 is reported.
median(d::Bernoulli) = succprob(d) <= 1/2 ? 0 : 1
# Shannon entropy in nats: -(q log q + p log p).
function entropy(d::Bernoulli)
    q = failprob(d)
    p = succprob(d)
    # Degenerate case (q is 0 or 1): return zero directly instead of evaluating
    # 0 * log(0), which would be NaN.
    if q == 0 || q == 1
        return zero(d.p)
    end
    return -(q * log(q) + p * log(p))
end
#### Evaluation

# Probability mass at a Bool outcome: p for `true`, 1 - p for `false`.
function pdf(d::Bernoulli, x::Bool)
    if x
        return succprob(d)
    else
        return failprob(d)
    end
end

# Probability mass at a numeric outcome; values other than 0 and 1 have zero mass.
function pdf(d::Bernoulli, x::Real)
    if x == 1
        return succprob(d)
    elseif x == 0
        return failprob(d)
    else
        return zero(d.p)
    end
end

# Log of the mass; for floating-point p this is -Inf outside {0, 1}.
function logpdf(d::Bernoulli, x::Real)
    return log(pdf(d, x))
end
# P(X <= x) for a Bool outcome: 1 at `true`, 1 - p at `false`.
function cdf(d::Bernoulli, x::Bool)
    return x ? one(d.p) : failprob(d)
end

# P(X <= x) for an Int outcome: 0 below the support, 1 - p at 0, 1 at or above 1.
function cdf(d::Bernoulli, x::Int)
    if x < 0
        return zero(d.p)
    elseif x < 1
        return failprob(d)
    else
        return one(d.p)
    end
end

# P(X > x) for a Bool outcome: 0 at `true`, p at `false`.
function ccdf(d::Bernoulli, x::Bool)
    return x ? zero(d.p) : succprob(d)
end

# P(X > x) for an Int outcome: 1 below the support, p at 0, 0 at or above 1.
function ccdf(d::Bernoulli, x::Int)
    if x < 0
        return one(d.p)
    elseif x < 1
        return succprob(d)
    else
        return zero(d.p)
    end
end
# Quantile function. Returns `T(NaN)` for probabilities outside [0, 1] (including NaN);
# otherwise 0 when `p` does not exceed P(X = 0) = failprob(d), else 1.
function quantile(d::Bernoulli{T}, p::Real) where T<:Real
    (0 <= p <= 1) || return T(NaN)
    return p <= failprob(d) ? zero(T) : one(T)
end

# Complementary quantile (quantile at 1 - p). Returns `T(NaN)` outside [0, 1];
# otherwise 0 when `p` is at least succprob(d), else 1.
function cquantile(d::Bernoulli{T}, p::Real) where T<:Real
    (0 <= p <= 1) || return T(NaN)
    return p >= succprob(d) ? zero(T) : one(T)
end
# Moment generating function E[exp(tX)] = (1 - p) + p * exp(t).
function mgf(d::Bernoulli, t::Real)
    q, p = failprob(d), succprob(d)
    return q + p * exp(t)
end

# Cumulant generating function log(1 - p + p * exp(t)), evaluated in log space as
# logaddexp(log(1 - p), t + log(p)).
function cgf(d::Bernoulli, t)
    p = succprob(d)
    return logaddexp(log1p(-p), t + log(p))
end

# Characteristic function E[exp(itX)] = (1 - p) + p * cis(t).
function cf(d::Bernoulli, t::Real)
    q, p = failprob(d), succprob(d)
    return q + p * cis(t)
end
#### Sampling

# Draw a `Bool` that is `true` with probability p.
# `rand(rng)` is uniform on [0, 1), so `u < p` holds with probability exactly p.
# The strict comparison matters: with `<=`, a draw of exactly 0.0 would make
# `Bernoulli(0)` return `true`.
rand(rng::AbstractRNG, d::Bernoulli) = rand(rng) < succprob(d)
#### MLE fitting

# Sufficient statistics of a Bernoulli sample: total count (or weight) of zeros and ones.
struct BernoulliStats{C<:Real} <: SufficientStats
    cnt0::C  # count or summed weight of observations equal to 0
    cnt1::C  # count or summed weight of observations equal to 1
end

# Accept counts of different numeric types by promoting them to a common type first.
function BernoulliStats(c0::Real, c1::Real)
    n0, n1 = promote(c0, c1)
    return BernoulliStats(n0, n1)
end
# Maximum-likelihood estimate of p: the share of (weighted) observations equal to 1.
# NOTE(review): with no observations this evaluates 0 / 0 (NaN); when `T` is the
# validating `Bernoulli` constructor that presumably raises — confirm intended behavior.
function fit_mle(::Type{T}, ss::BernoulliStats) where {T<:Bernoulli}
    total = ss.cnt0 + ss.cnt1
    return T(ss.cnt1 / total)
end
# Count zeros and ones in `x`.
# Throws `DomainError` at the first element that is neither 0 nor 1.
function suffstats(::Type{<:Bernoulli}, x::AbstractArray{<:Integer})
    n0 = n1 = 0
    for v in x
        if v == 1
            n1 += 1
        elseif v == 0
            n0 += 1
        else
            throw(DomainError(v, "samples must be 0 or 1"))
        end
    end
    return BernoulliStats(n0, n1)
end
# Sum the weights `w` of the zeros and of the ones in `x`, pairing elements in iteration
# order. Throws `DimensionMismatch` if the lengths differ, and `DomainError` at the first
# element of `x` that is neither 0 nor 1.
function suffstats(
    ::Type{<:Bernoulli}, x::AbstractArray{<:Integer}, w::AbstractArray{<:Real}
)
    if length(x) != length(w)
        throw(DimensionMismatch("inconsistent argument dimensions"))
    end
    # Accumulators start at `zero + zero`, not `zero`, so they already have the type
    # that addition of weights produces (e.g. `false + false` is an `Int` for Bool weights).
    zw = zero(eltype(w))
    w0 = w1 = zw + zw
    for (v, wt) in zip(x, w)
        if v == 1
            w1 += wt
        elseif v == 0
            w0 += wt
        else
            throw(DomainError(v, "samples must be 0 or 1"))
        end
    end
    return BernoulliStats(w0, w1)
end