/
distribution_wrappers.jl
71 lines (60 loc) · 2.42 KB
/
distribution_wrappers.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
using Distributions: Distributions
using Bijectors: Bijectors
using Distributions: Univariate, Multivariate, Matrixvariate
"""
A named distribution that carries the name of the random variable with it.
"""
struct NamedDist{V,S,D<:Distribution{V,S},N<:VarName} <: Distribution{V,S}
    # The wrapped distribution; its variate form `V` and value support `S` are
    # also those of the `NamedDist` itself.
    dist::D
    # The `VarName` carried alongside the distribution.
    name::N
end
# Convenience constructor: lift the symbol `name` into a `VarName` and wrap `dist` with it.
function NamedDist(dist::Distribution, name::Symbol)
    varname = VarName{name}()
    return NamedDist(dist, varname)
end
# The length of a `NamedDist` is the length of the distribution it wraps.
function Base.length(dist::NamedDist)
    return Base.length(dist.dist)
end
# The size of a `NamedDist` is the size of the distribution it wraps.
function Base.size(dist::NamedDist)
    return Base.size(dist.dist)
end
# Log-density at a scalar `x`: evaluated on the wrapped distribution; the name is not used.
function Distributions.logpdf(dist::NamedDist, x::Real)
    return Distributions.logpdf(dist.dist, x)
end
# Log-density at an array `x`: evaluated on the wrapped distribution; the name is not used.
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
    wrapped = dist.dist
    return Distributions.logpdf(wrapped, x)
end
# Log-likelihood of a scalar `x`: delegated to the wrapped distribution.
function Distributions.loglikelihood(dist::NamedDist, x::Real)
    wrapped = dist.dist
    return Distributions.loglikelihood(wrapped, x)
end
# Log-likelihood of an array `x`: delegated to the wrapped distribution.
function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
    wrapped = dist.dist
    return Distributions.loglikelihood(wrapped, x)
end
# A `NamedDist` uses the bijector of the distribution it wraps.
function Bijectors.bijector(d::NamedDist)
    return Bijectors.bijector(d.dist)
end
"""
    NoDist(dist)

Wrapper around a distribution `dist` that has the same variate form and value
support as `dist`.
"""
struct NoDist{V,S,D<:Distribution{V,S}} <: Distribution{V,S}
    # The wrapped distribution.
    dist::D
end
# Wrapping a `NamedDist` keeps the name outermost: the inner distribution is wrapped
# in `NoDist`, and the result is wrapped again in a `NamedDist` with the same name.
function NoDist(dist::NamedDist)
    inner = NoDist(dist.dist)
    return NamedDist(inner, dist.name)
end
# Wrap a single distribution in `NoDist`.
function nodist(dist::Distribution)
    return NoDist(dist)
end
# Apply `nodist` element-wise to an array.
function nodist(dists::AbstractArray)
    return broadcast(nodist, dists)
end
# The length of a `NoDist` is the length of the distribution it wraps.
function Base.length(dist::NoDist)
    return Base.length(dist.dist)
end
# The size of a `NoDist` is the size of the distribution it wraps.
function Base.size(dist::NoDist)
    return Base.size(dist.dist)
end
# Sampling is delegated to the wrapped distribution, using the supplied `rng`.
function Distributions.rand(rng::Random.AbstractRNG, d::NoDist)
    return rand(rng, d.dist)
end
# Univariate `NoDist`: the log-density is the integer `0` for every real input.
function Distributions.logpdf(d::NoDist{<:Univariate}, ::Real)
    return 0
end
# Multivariate `NoDist`, single vector input: the log-density is the integer `0`.
function Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real})
    return 0
end
# Multivariate `NoDist`, matrix input: returns a vector of integer zeros with one
# entry per column of `x`.
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
    ncols = size(x, 2)
    return zeros(Int, ncols)
end
# Matrix-variate `NoDist`: the log-density is the integer `0` for every matrix input.
function Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real})
    return 0
end
# The minimum of a `NoDist` is the minimum of the wrapped distribution.
function Distributions.minimum(d::NoDist)
    return minimum(d.dist)
end
# The maximum of a `NoDist` is the maximum of the wrapped distribution.
function Distributions.maximum(d::NoDist)
    return maximum(d.dist)
end
# Univariate `NoDist`: returns the integer `0`; the value and the `Bool` argument
# are both ignored.
function Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool)
    return 0
end
# Multivariate `NoDist`, single vector input: returns the integer `0`; the value and
# the `Bool` argument are both ignored.
Bijectors.logpdf_with_trans(::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool) = 0
# Multivariate `NoDist`, matrix input: returns a vector of integer zeros with one
# entry per column of `x`; the `Bool` argument is ignored.
function Bijectors.logpdf_with_trans(
    d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
)
    ncols = size(x, 2)
    return zeros(Int, ncols)
end
# Matrix-variate `NoDist`: returns the integer `0`; the value and the `Bool` argument
# are both ignored.
Bijectors.logpdf_with_trans(::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool) = 0
# A `NoDist` uses the bijector of the distribution it wraps.
function Bijectors.bijector(d::NoDist)
    return Bijectors.bijector(d.dist)
end