Skip to content

Commit

Permalink
Matrix Gaussian Distribution (#1091)
Browse files Browse the repository at this point in the history
* Create matrixgaussian.jl

Implements the matrix Gaussian distribution as detailed in https://arxiv.org/pdf/1804.11010.pdf

* Added MatrixGaussian to exports

* Added MatrixGaussian

* Added matrixgaussian.jl to includes

* Created tests for matrixgaussian.jl

* Updated runtests.jl with matrixgaussian.jl tests

* Fixed typo in filename

* Incorporated johnczito's suggestion- wrap MvNormal

* Updated matrigaussian tests for MvNormal wrapping

* Fixed typo in constructor definition

* Fixed typo in constructor

* Incorporated johnczito's suggestion- wrap MvNormal

* Updated matrigaussian tests for MvNormal wrapping

* Fixed typo in constructor definition

* Fixed typo in constructor

* Added constructor for two matrices

* Fixed typo in constructor

* Fixed typo in constructor

* Check dimensions for insupport

* Check if matrix is real for insupport

* Added MatrixReshaped distribution; removed MatrixGaussian distribution.

* Removed individual MatrixVariates tests

* Update src/Distributions.jl to resolve merge errors

Co-Authored-By: John Zito <johnczito@users.noreply.github.com>

* Update docs/src/matrix.md to resolve merge errors

Co-Authored-By: John Zito <johnczito@users.noreply.github.com>

Co-authored-by: John Zito <johnczito@users.noreply.github.com>
  • Loading branch information
san-soucie and johnczito committed May 27, 2020
1 parent 4cbeba8 commit 5a6fc3c
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/src/matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ vec(d::MatrixDistribution)
MatrixNormal
Wishart
InverseWishart
MatrixReshaped
MatrixTDist
MatrixBeta
MatrixFDist
Expand Down
6 changes: 4 additions & 2 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ export
MatrixBeta,
MatrixFDist,
MatrixNormal,
MatrixReshaped,
MatrixTDist,
MixtureModel,
Multinomial,
Expand Down Expand Up @@ -316,8 +317,9 @@ Supported distributions:
GeneralizedExtremeValue, Geometric, Gumbel, Hypergeometric,
InverseWishart, InverseGamma, InverseGaussian, IsoNormal,
IsoNormalCanon, Kolmogorov, KSDist, KSOneSided, Laplace, Levy, LKJ,
Logistic, LogNormal, MatrixBeta, MatrixFDist, MatrixNormal, MatrixTDist, MixtureModel,
Multinomial, MultivariateNormal, MvLogNormal, MvNormal, MvNormalCanon,
Logistic, LogNormal, MatrixBeta, MatrixFDist, MatrixNormal,
MatrixReshaped, MatrixTDist, MixtureModel, Multinomial,
MultivariateNormal, MvLogNormal, MvNormal, MvNormalCanon,
MvNormalKnownCov, MvTDist, NegativeBinomial, NoncentralBeta, NoncentralChisq,
NoncentralF, NoncentralHypergeometric, NoncentralT, Normal, NormalCanon,
NormalInverseGaussian, Pareto, PGeneralizedGaussian, Poisson, PoissonBinomial,
Expand Down
75 changes: 75 additions & 0 deletions src/matrix/matrixreshaped.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
MatrixReshaped(D, n, p)
```julia
D::MultivariateDistribution base distribution
n::Integer number of rows
p::Integer number of columns
```
Reshapes a multivariate distribution into a matrix distribution with n rows and
p columns.
"""
struct MatrixReshaped{S<:ValueSupport,D<:MultivariateDistribution{S}} <:
MatrixDistribution{S}
d::D
num_rows::Int
num_cols::Int
function MatrixReshaped(
d::D,
n::N,
p::N,
) where {
D<:MultivariateDistribution{
S,
},
} where {S<:ValueSupport} where {N<:Integer}
(n > 0 && p > 0) || throw(ArgumentError("n and p should be positive"))
n * p == length(d) ||
throw(ArgumentError("Dimensions provided ($n x $p) do not match source distribution of length $(length(d))"))
return new{S,D}(d, n, p)
end
end

MatrixReshaped(D::MultivariateDistribution, n::Integer) =
MatrixReshaped(D, n, n)

show(io::IO, d::MatrixReshaped) =
show_multline(io, d, [(:num_rows, d.num_rows), (:num_cols, d.num_cols)])


# -----------------------------------------------------------------------------
# Properties
# -----------------------------------------------------------------------------

size(d::MatrixReshaped) = (d.num_rows, d.num_cols)

length(d::MatrixReshaped) = length(d.d)

rank(d::MatrixReshaped) = minimum(size(d))

function insupport(d::MatrixReshaped, X::AbstractMatrix)
return isreal(X) && size(d) == size(X) && insupport(d.d, vec(X))
end

mean(d::MatrixReshaped) = reshape(mean(d.d), size(d))
mode(d::MatrixReshaped) = reshape(mode(d.d), size(d))
cov(d::MatrixReshaped, ::Val{true} = Val(true)) =
reshape(cov(d.d), prod(size(d)), prod(size(d)))
cov(d::MatrixReshaped, ::Val{false}) =
((n, p) = size(d); reshape(cov(d), n, p, n, p))
var(d::MatrixReshaped) = reshape(var(d.d), size(d))

params(d::MatrixReshaped) = (d.d, d.num_rows, d.num_cols)

@inline partype(
d::MatrixReshaped{S,<:MultivariateDistribution{S}},
) where {S<:Real} = S

_logpdf(d::MatrixReshaped, X::AbstractMatrix) = logpdf(d.d, vec(X))

function _rand!(rng::AbstractRNG, d::MatrixReshaped, Y::AbstractMatrix)
rand!(rng, d.d, view(Y, :))
return Y
end

vec(d::MatrixReshaped) = d.d
4 changes: 2 additions & 2 deletions src/matrixvariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ check_univariate(d::MatrixDistribution) = is_univariate(d) || throw(ArgumentErro
##### Specific distributions #####

for fname in ["wishart.jl", "inversewishart.jl", "matrixnormal.jl",
"matrixtdist.jl", "matrixbeta.jl", "matrixfdist.jl",
"lkj.jl"]
"matrixreshaped.jl", "matrixtdist.jl", "matrixbeta.jl",
"matrixfdist.jl", "lkj.jl"]
include(joinpath("matrix", fname))
end
188 changes: 188 additions & 0 deletions test/matrixreshaped.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
using Distributions, Test, Random, LinearAlgebra
using Distributions: MatrixReshaped

rng = MersenneTwister(123456)

σ = rand(rng, 16, 16)
μ = rand(rng, 16)
d1 = MvNormal(μ, σ * σ')
x1 = rand(rng, d1)

sizes = [(4, 4), (8, 2), (2, 8), (1, 16), (16, 1), (4,)]
ranks = [4, 2, 2, 1, 1, 4]

d1s = [MatrixReshaped(d1, s...) for s in sizes]


@testset "MatrixReshaped MvNormal tests" begin
@testset "MatrixReshaped constructor" begin
for d in d1s
@test d isa MatrixReshaped
end
end
@testset "MatrixReshaped constructor errors" begin
@test_throws ArgumentError MatrixReshaped(d1, 4, 3)
@test_throws ArgumentError MatrixReshaped(d1, 3)
@test_throws ArgumentError MatrixReshaped(d1, -4, -4)
end
@testset "MatrixReshaped size" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test size(d) == s
end
end
@testset "MatrixReshaped length" begin
for d in d1s
@test length(d) == length(μ)
end
end
@testset "MatrixReshaped rank" begin
for (d, r) in zip(d1s, ranks)
@test rank(d) == r
end
end
@testset "MatrixReshaped insupport" begin
for (i, d) in enumerate(d1s[1:end-1])
for (j, s) in enumerate(sizes[1:end-1])
@test (i == j) !insupport(d, reshape(x1, s))
end
end
end
@testset "MatrixReshaped mean" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test mean(d) == reshape(μ, s)
end
end
@testset "MatrixReshaped mode" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test mode(d) == reshape(mode(d1), s)
end
end
@testset "MatrixReshaped covariance" begin
for (d, (n, p)) in zip(d1s[1:end-1], sizes[1:end-1])
@test cov(d) == σ * σ'
@test cov(d, Val(false)) == reshape* σ', n, p, n, p)
end
end
@testset "MatrixReshaped variance" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test var(d) == reshape(var(d1), s)
end
end
@testset "MatrixReshaped params" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test params(d) == (d1, s...)
end
end
@testset "MatrixReshaped partype" begin
for d in d1s
@test partype(d) == Float64
end
end
@testset "MatrixReshaped logpdf" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
x = reshape(x1, s)
@test logpdf(d, x) == logpdf(d1, x1)
end
end
@testset "MatrixReshaped rand" begin
for d in d1s
x = rand(rng, d)
@test insupport(d, x)
@test insupport(d1, vec(x))
@test logpdf(d, x) == logpdf(d1, vec(x))
end
end
@testset "MatrixReshaped vec" begin
for d in d1s
@test vec(d) == d1
end
end
end

α = rand(rng, 36)
d1 = Dirichlet(α)
x1 = rand(rng, d1)

sizes = [(6, 6), (4, 9), (9, 4), (3, 12), (12, 3), (1, 36), (36, 1), (6,)]
ranks = [6, 4, 4, 3, 3, 1, 1, 6]

d1s = [MatrixReshaped(d1, s...) for s in sizes]

@testset "MatrixReshaped Dirichlet tests" begin
@testset "MatrixReshaped constructor" begin
for d in d1s
@test d isa MatrixReshaped
end
end
@testset "MatrixReshaped constructor errors" begin
@test_throws ArgumentError MatrixReshaped(d1, 4, 3)
@test_throws ArgumentError MatrixReshaped(d1, 3)
end
@testset "MatrixReshaped size" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test size(d) == s
end
end
@testset "MatrixReshaped length" begin
for d in d1s
@test length(d) == length(α)
end
end
@testset "MatrixReshaped rank" begin
for (d, r) in zip(d1s, ranks)
@test rank(d) == r
end
end
@testset "MatrixReshaped insupport" begin
for (i, d) in enumerate(d1s[1:end-1])
for (j, s) in enumerate(sizes[1:end-1])
@test (i == j) !insupport(d, reshape(x1, s))
end
end
end
@testset "MatrixReshaped mean" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test mean(d) == reshape(mean(d1), s)
end
end
@testset "MatrixReshaped covariance" begin
for (d, (n, p)) in zip(d1s[1:end-1], sizes[1:end-1])
@test cov(d) == cov(d1)
@test cov(d, Val(false)) == reshape(cov(d1), n, p, n, p)
end
end
@testset "MatrixReshaped variance" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test var(d) == reshape(var(d1), s)
end
end
@testset "MatrixReshaped params" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
@test params(d) == (d1, s...)
end
end
@testset "MatrixReshaped partype" begin
for d in d1s
@test partype(d) == Float64
end
end
@testset "MatrixReshaped logpdf" begin
for (d, s) in zip(d1s[1:end-1], sizes[1:end-1])
x = reshape(x1, s)
@test logpdf(d, x) == logpdf(d1, x1)
end
end
@testset "MatrixReshaped rand" begin
for d in d1s
x = rand(rng, d)
@test insupport(d, x)
@test insupport(d1, vec(x))
@test logpdf(d, x) == logpdf(d1, vec(x))
end
end
@testset "MatrixReshaped vec" begin
for d in d1s
@test vec(d) == d1
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const tests = [
"mvtdist",
"kolmogorov",
"edgeworth",
"matrixreshaped",
"matrixvariates",
"vonmisesfisher",
"conversion",
Expand Down

0 comments on commit 5a6fc3c

Please sign in to comment.