/
reshaped.jl
163 lines (135 loc) · 5.65 KB
/
reshaped.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
ReshapedDistribution(d::Distribution{<:ArrayLikeVariate}, dims::Dims{N})
Distribution of the `N`-dimensional random variable `reshape(X, dims)` where `X` is a random
variable with distribution `d`.
It is recommended to not use `reshape` instead of the constructor of `ReshapedDistribution`
directly since `reshape` can return more optimized distributions for specific types of `d`
and number of dimensions `N`.
"""
struct ReshapedDistribution{N,S<:ValueSupport,D<:Distribution{<:ArrayLikeVariate,S}} <: Distribution{ArrayLikeVariate{N},S}
dist::D
dims::Dims{N}
function ReshapedDistribution(dist::Distribution{<:ArrayLikeVariate,S}, dims::Dims{N}) where {N,S<:ValueSupport}
_reshape_check_dims(dist, dims)
return new{N,S,typeof(dist)}(dist, dims)
end
end
# Validate that `dims` is an admissible target size for `dist`: every entry has to be
# positive and the total number of elements has to equal `length(dist)`.
# Returns `true` on success and throws an `ArgumentError` otherwise.
function _reshape_check_dims(dist::Distribution{<:ArrayLikeVariate}, dims::Dims)
    if all(>(0), dims) && length(dist) == prod(dims)
        return true
    end
    throw(ArgumentError("dimensions $(dims) do not match size of source distribution $(size(dist))"))
end
# The size of a reshaped distribution is the stored target size.
function Base.size(d::ReshapedDistribution)
    return d.dims
end
# Element type of variates: delegate to the type `D` of the wrapped distribution.
# The `<:` is required because `Type{T}` is invariant: without it the method only matches
# the `UnionAll` `ReshapedDistribution{<:Any,<:ValueSupport,D}` itself and is never selected
# for the concrete type that `eltype(d)` passes in.
Base.eltype(::Type{<:ReshapedDistribution{<:Any,<:ValueSupport,D}}) where {D} = eltype(D)
# The parameter type is inherited from the wrapped distribution.
function partype(d::ReshapedDistribution)
    return partype(d.dist)
end

# The parameters are the wrapped distribution together with the target size.
function params(d::ReshapedDistribution)
    return (d.dist, d.dims)
end
# `x` is in the support if it has the size of `d` and, reshaped back to the size of the
# wrapped distribution, lies in the support of the wrapped distribution.
function insupport(d::ReshapedDistribution{N}, x::AbstractArray{<:Real,N}) where {N}
    size(d) == size(x) || return false
    inner = d.dist
    return insupport(inner, reshape(x, size(inner)))
end
# Mean of the wrapped distribution, reshaped to the size of `d`.
function mean(d::ReshapedDistribution)
    return reshape(mean(d.dist), size(d))
end

# Variance of the wrapped distribution, reshaped to the size of `d`.
function var(d::ReshapedDistribution)
    return reshape(var(d.dist), size(d))
end

# Covariance of the wrapped distribution as a `length(d) × length(d)` matrix.
function cov(d::ReshapedDistribution)
    n = length(d)
    return reshape(cov(d.dist), n, n)
end
# Covariance of a matrix-shaped distribution as a 4-dimensional array of size
# `(n, p, n, p)`, where `(n, p) = size(d)`.
function cov(d::ReshapedDistribution{2}, ::Val{false})
    dims = size(d)
    return reshape(cov(d), dims..., dims...)
end
# Mode of the wrapped distribution, reshaped to the size of `d`.
function mode(d::ReshapedDistribution)
    return reshape(mode(d.dist), size(d))
end
# TODO: remove?
# Reported rank of a matrix-shaped distribution: the smaller of its two dimensions.
function rank(d::ReshapedDistribution{2})
    nrows, ncols = size(d)
    return min(nrows, ncols)
end
# logpdf evaluation
# Both methods forward to `__logpdf`. The matrix-specific method exists only to fix the
# method ambiguity caused by the default fallback for `MatrixDistribution`.
function _logpdf(d::ReshapedDistribution{N}, x::AbstractArray{<:Real,N}) where {N}
    return __logpdf(d, x)
end

function _logpdf(d::ReshapedDistribution{2}, x::AbstractMatrix{<:Real})
    return __logpdf(d, x)
end
# Evaluate the log density by reshaping `x` back to the size of the wrapped distribution
# and delegating to its `logpdf`.
function __logpdf(d::ReshapedDistribution{N}, x::AbstractArray{<:Real,N}) where {N}
    inner = d.dist
    return @inbounds begin
        logpdf(inner, reshape(x, size(inner)))
    end
end
# loglikelihood
# Forwarding to the wrapped distribution is useful if it defines more optimized methods.
#
# Single variate: `x` has the same number of dimensions as `d`. Unless bounds checking is
# elided, a `DimensionMismatch` is thrown if `size(x) != size(d)`.
@inline function loglikelihood(
    d::ReshapedDistribution{N}, x::AbstractArray{<:Real,N}
) where {N}
    @boundscheck if size(x) != size(d)
        throw(DimensionMismatch("inconsistent array dimensions"))
    end
    inner = d.dist
    return @inbounds loglikelihood(inner, reshape(x, size(inner)))
end
# Collection of variates: the leading `N` dimensions of `x` must match `size(d)`, the
# remaining `M - N` dimensions index the individual variates. The leading dimensions are
# reshaped to the size of the wrapped distribution while the trailing ones are kept.
@inline function loglikelihood(
    d::ReshapedDistribution{N}, x::AbstractArray{<:Real,M}
) where {N,M}
    @boundscheck begin
        if M <= N
            throw(DimensionMismatch(
                "number of dimensions of `x` ($M) must be greater than number of dimensions of `d` ($N)"
            ))
        end
        if ntuple(i -> size(x, i), Val(N)) != size(d)
            throw(DimensionMismatch("inconsistent array dimensions"))
        end
    end
    inner = d.dist
    trailing = ntuple(i -> size(x, N + i), Val(M - N))
    return @inbounds loglikelihood(inner, reshape(x, size(inner)..., trailing...))
end
# sampling
# Fill `x` in place by sampling from the wrapped distribution into `x` reshaped to the
# size of that distribution; returns `x`.
function _rand!(
    rng::AbstractRNG, d::ReshapedDistribution{N}, x::AbstractArray{<:Real,N}
) where {N}
    inner = d.dist
    @inbounds rand!(rng, inner, reshape(x, size(inner)))
    return x
end
"""
reshape(d::Distribution{<:ArrayLikeVariate}, dims::Int...)
reshape(d::Distribution{<:ArrayLikeVariate}, dims::Dims)
Return a [`Distribution`](@ref) of `reshape(X, dims)` where `X` is a random variable with
distribution `d`.
The default implementation returns a [`ReshapedDistribution`](@ref). However, it can return
more optimized distributions for specific types of distributions and numbers of dimensions.
Therefore it is recommended to use `reshape` instead of the constructor of
`ReshapedDistribution`.
# Implementation
Since `reshape(d, dims::Int...)` calls `reshape(d, dims::Dims)`, one should implement
`reshape(d, ::Dims)` for desired distributions `d`.
See also: [`vec`](@ref)
"""
function Base.reshape(dist::Distribution{<:ArrayLikeVariate}, dims::Dims)
return ReshapedDistribution(dist, dims)
end
# Collect integer dimensions into a tuple and forward to `reshape(dist, ::Dims)`.
function Base.reshape(dist::Distribution{<:ArrayLikeVariate}, dims1::Int, dims::Int...)
    target = (dims1, dims...)
    return reshape(dist, target)
end
"""
vec(d::Distribution{<:ArrayLikeVariate})
Return a [`MultivariateDistribution`](@ref) of `vec(X)` where `X` is a random variable with
distribution `d`.
The default implementation returns a [`ReshapedDistribution`](@ref). However, it can return
more optimized distributions for specific types of distributions and numbers of dimensions.
Therefore it is recommended to use `vec` instead of the constructor of
`ReshapedDistribution`.
# Implementation
Since `vec(d)` is defined as `reshape(d, length(d))` one should implement
`reshape(d, ::Tuple{Int})` rather than `vec`.
See also: [`reshape`](@ref)
"""
Base.vec(dist::Distribution{<:ArrayLikeVariate}) = reshape(dist, length(dist))
# avoid unnecessary wrappers
# Flattening a reshaped multivariate distribution yields the wrapped distribution itself,
# provided `dims` passes the size check.
function Base.reshape(
    dist::ReshapedDistribution{<:Any,<:ValueSupport,<:MultivariateDistribution},
    dims::Tuple{Int},
)
    _reshape_check_dims(dist, dims)
    inner = dist.dist
    return inner
end
# A multivariate distribution reshaped to a vector is returned unchanged once `dims` has
# passed the size check.
Base.reshape(dist::MultivariateDistribution, dims::Tuple{Int}) =
    (_reshape_check_dims(dist, dims); dist)
# specialization for flattened `MatrixNormal`
# After the size check, build an `MvNormal` from the vectorized `M` field and the
# Kronecker product of the `V` and `U` fields.
function Base.reshape(dist::MatrixNormal, dims::Tuple{Int})
    _reshape_check_dims(dist, dims)
    μ = vec(dist.M)
    Σ = kron(dist.V, dist.U)
    return MvNormal(μ, Σ)
end