Remove constraint on AbstractPDMat #1552

Open

wants to merge 7 commits into base: master
5 changes: 3 additions & 2 deletions Project.toml
@@ -23,14 +23,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ChainRulesCore = "1"
DensityInterface = "0.4"
FillArrays = "0.9, 0.10, 0.11, 0.12, 0.13"
PDMats = "0.10, 0.11"
PDMats = "0.11.11"
QuadGK = "2"
SpecialFunctions = "1.2, 2"
StatsBase = "0.32, 0.33"
StatsFuns = "0.9.15, 1"
julia = "1.3"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -43,4 +44,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StableRNGs", "Calculus", "ChainRulesTestUtils", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test", "OffsetArrays"]
test = ["BlockDiagonals", "StableRNGs", "Calculus", "ChainRulesTestUtils", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test", "OffsetArrays"]
32 changes: 21 additions & 11 deletions src/multivariate/mvnormal.jl
@@ -48,11 +48,15 @@ struct MvNormal{T<:Real,Cov<:AbstractPDMat,Mean<:AbstractVector} <: AbstractMvNo
end
```

Here, the mean vector can be an instance of any `AbstractVector`. The covariance can be
of any subtype of `AbstractPDMat`. Particularly, one can use `PDMat` for full covariance,
`PDiagMat` for diagonal covariance, and `ScalMat` for the isotropic covariance -- those
in the form of ``\\sigma^2 \\mathbf{I}``. (See the Julia package
Here, the mean vector can be an instance of any `AbstractVector`.

Special handling is included if the covariance is a subtype of `AbstractPDMat`.
Particularly, one can use `PDMat` for full covariance, `PDiagMat` for diagonal covariance,
and `ScalMat` for the isotropic covariance
-- those in the form of ``\\sigma^2 \\mathbf{I}``. (See the Julia package
[PDMats](https://github.com/JuliaStats/PDMats.jl/) for details).
If you pass a dense `Matrix` for the covariance, it is automatically converted to a `PDMat`.
For other matrix types, you have to convert them yourself.
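As a quick illustration of the covariance representations described above, here is a minimal sketch with hypothetical values (assuming `Distributions` and `PDMats` are loaded):

```julia
using Distributions, PDMats

μ = [0.0, 1.0]                                    # hypothetical mean vector

d_full = MvNormal(μ, PDMat([2.0 0.5; 0.5 1.0]))   # full covariance
d_diag = MvNormal(μ, PDiagMat([2.0, 1.0]))        # diagonal covariance
d_iso  = MvNormal(μ, ScalMat(2, 0.25))            # isotropic covariance σ²I

# A dense Matrix is converted to a PDMat automatically:
d_dense = MvNormal(μ, [2.0 0.5; 0.5 1.0])
cov(d_dense) ≈ cov(d_full)                        # true
```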

We also define a set of aliases for the types using different combinations of mean vectors and covariance:

@@ -166,9 +170,14 @@ Generally, users don't have to worry about these internal details.
We provide a common constructor `MvNormal`, which will construct a distribution of
appropriate type depending on the input arguments.
"""
struct MvNormal{T<:Real,Cov<:AbstractPDMat,Mean<:AbstractVector} <: AbstractMvNormal
struct MvNormal{T<:Real,Cov<:AbstractMatrix,Mean<:AbstractVector} <: AbstractMvNormal
μ::Mean
Σ::Cov

function MvNormal{T, Cov, Mean}(μ::Mean, Σ::Cov) where {T, Mean, Cov}
size(Σ, 1) == size(Σ, 2) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent."))
return new{T, Cov, Mean}(μ, Σ)
end
end
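A small sketch (hypothetical values) of what the relocated dimension check implies: since it now lives in the inner constructor, it applies to any `AbstractMatrix` covariance, not only `AbstractPDMat` subtypes:

```julia
using Distributions, Test

# A 2-element mean paired with a 3×3 covariance is rejected by the inner constructor.
@test_throws DimensionMismatch MvNormal([0.0, 1.0], [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0])
```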

const MultivariateNormal = MvNormal # for the purpose of backward compatibility
@@ -182,14 +191,15 @@ const ZeroMeanDiagNormal{Axes} = MvNormal{Float64,PDiagMat{Float64,Vector{Float6
const ZeroMeanFullNormal{Axes} = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Zeros{Float64,1,Axes}}

### Construction
function MvNormal(μ::AbstractVector{T}, Σ::AbstractPDMat{T}) where {T<:Real}
dim(Σ) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent."))
MvNormal{T,typeof(Σ), typeof(μ)}(μ, Σ)
end

function MvNormal(μ::AbstractVector{T}, Σ::AbstractMatrix{T}) where {T<:Real}
MvNormal{T, typeof(Σ), typeof(μ)}(μ, Σ)
end
function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real})
Member:

I guess here we might want to change it to

Suggested change
function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real})
function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real})

Contributor (author):

The tests pass right now.
It took some effort to make sure the constructors cover the right cases and don't hit a stack overflow or an ambiguity error.
I could give them another pass now that it is working, but I wouldn't want to relax the ones there willy-nilly.

Member:

I think without this relaxation you can provoke a test error when you use something like MvNormal(::Vector{Float32}, ::BlockDiagonal{Float64}).

R = Base.promote_eltype(μ, Σ)
MvNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ))
μc = convert(AbstractArray{R}, μ)
Σc = convert(AbstractArray{R}, Σ)
MvNormal{R, typeof(Σc), typeof(μc)}(μc, Σc)
Comment on lines +200 to +202
Member:

It seems this change is not needed, is it?

Suggested change
μc = convert(AbstractArray{R}, μ)
Σc = convert(AbstractArray{R}, Σ)
MvNormal{R, typeof(Σc), typeof(μc)}(μc, Σc)
MvNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ))

Contributor (author):

It is so we can call the parameterized constructor.
Otherwise we get a stack overflow.

end
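For context, a minimal sketch (hypothetical values, using BlockDiagonals.jl as in the new tests) of what the relaxed struct and constructors enable: a structured `AbstractMatrix` covariance is stored as-is instead of being wrapped in a `PDMat`:

```julia
using Distributions, BlockDiagonals

μ = [0.0, 1.0, 0.0, 1.0]                 # hypothetical mean
C = [2.0 0.5; 0.5 1.0]                   # hypothetical positive-definite block
d = MvNormal(μ, BlockDiagonal([C, C]))

d.Σ isa BlockDiagonal                    # true: the covariance keeps its type
cov(d) ≈ Matrix(BlockDiagonal([C, C]))   # true, consistent with the new tests in test/mvnormal.jl
```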

# constructor with general covariance matrix
@@ -198,7 +208,7 @@ end

Construct a multivariate normal distribution with mean `μ` and covariance matrix `Σ`.
"""
MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real}) = MvNormal(μ, PDMat(Σ))
MvNormal(μ::AbstractVector{<:Real}, Σ::Matrix{<:Real}) = MvNormal(μ, PDMat(Σ))
MvNormal(μ::AbstractVector{<:Real}, Σ::Diagonal{<:Real}) = MvNormal(μ, PDiagMat(Σ.diag))
MvNormal(μ::AbstractVector{<:Real}, Σ::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormal(μ, PDiagMat(Σ.data.diag))
MvNormal(μ::AbstractVector{<:Real}, Σ::UniformScaling{<:Real}) =
34 changes: 17 additions & 17 deletions src/multivariate/mvnormalcanon.jl
@@ -19,7 +19,7 @@ which is also a subtype of `AbstractMvNormal` to represent a multivariate normal
canonical parameters. Particularly, `MvNormalCanon` is defined as:

```julia
struct MvNormalCanon{T<:Real,P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
struct MvNormalCanon{T<:Real,P<:AbstractMatrix,V<:AbstractVector} <: AbstractMvNormal
μ::V # the mean vector
h::V # potential vector, i.e. inv(Σ) * μ
J::P # precision matrix, i.e. inv(Σ)
@@ -40,10 +40,19 @@ const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64, ScalMat{Float64},

**Note:** `MvNormalCanon` share the same set of methods as `MvNormal`.
"""
struct MvNormalCanon{T<:Real,P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
struct MvNormalCanon{T<:Real,P<:AbstractMatrix,V<:AbstractVector} <: AbstractMvNormal
μ::V # the mean vector
h::V # potential vector, i.e. inv(Σ) * μ
J::P # precision matrix, i.e. inv(Σ)

function MvNormalCanon{T,P,V}(μ::V, h::AbstractVector, J::P) where {T<:Real, V<:AbstractVector{T}, P}
length(μ) == length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
if typeof(μ) === typeof(h)
return new{T,typeof(J),typeof(μ)}(μ, h, J)
else
return new{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
end
Comment on lines +50 to +54
Member:

This should be handled by an outer constructor, shouldn't it? It seems a bit weird to lie about the type of the constructed distribution and get something other than MvNormalCanon{T,P,V} here.

Contributor (author):

That does seem odd.

end
end

const FullNormalCanon = MvNormalCanon{Float64,PDMat{Float64,Matrix{Float64}},Vector{Float64}}
@@ -56,26 +65,17 @@ const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64,ScalMat{Float64},Zer


### Constructors
function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractPDMat{T}) where {T<:Real}
length(μ) == length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
if typeof(μ) === typeof(h)
return MvNormalCanon{T,typeof(J),typeof(μ)}(μ, h, J)
else
return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
end
end

function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractPDMat) where {T<:Real}
Member:

I guess we want to keep this function but change J::AbstractPDMat to J::AbstractMatrix{<:Real}.

function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::P) where {T<:Real, P}
Member (@devmotion, May 31, 2022):

Suggested change
function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::P) where {T<:Real, P}
function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractMatrix{T}) where {T<:Real}

R = promote_type(T, eltype(J))
MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
MvNormalCanon{T,P,typeof(μ)}(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
Member:

For the old function here (relaxed to AbstractMatrix, see above):

Suggested change
MvNormalCanon{T,P,typeof(μ)}(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))

end

function MvNormalCanon(μ::AbstractVector{<:Real}, h::AbstractVector{<:Real}, J::AbstractPDMat)
function MvNormalCanon(μ::AbstractVector{<:Real}, h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real})
R = Base.promote_eltype(μ, h, J)
MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
end

function MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractPDMat)
function MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real})
length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
R = Base.promote_eltype(h, J)
hh = convert(AbstractArray{R}, h)
@@ -89,7 +89,7 @@ end
Construct a multivariate normal distribution with potential vector `h` and precision matrix
`J`.
"""
MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real}) = MvNormalCanon(h, PDMat(J))
MvNormalCanon(h::AbstractVector{<:Real}, J::Matrix{<:Real}) = MvNormalCanon(h, PDMat(J))
MvNormalCanon(h::AbstractVector{<:Real}, J::Diagonal{<:Real}) = MvNormalCanon(h, PDiagMat(J.diag))
MvNormalCanon(μ::AbstractVector{<:Real}, J::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormalCanon(μ, PDiagMat(J.data.diag))
function MvNormalCanon(h::AbstractVector{<:Real}, J::UniformScaling{<:Real})
@@ -170,7 +170,7 @@ sqmahal!(r::AbstractVector, d::MvNormalCanon, x::AbstractMatrix) = quad!(r, d.J,

# Sampling (for GenericMvNormal)

unwhiten_winv!(J::AbstractPDMat, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
unwhiten_winv!(J::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
unwhiten_winv!(J::PDiagMat, x::AbstractVecOrMat) = whiten!(J, x)
unwhiten_winv!(J::ScalMat, x::AbstractVecOrMat) = whiten!(J, x)
if isdefined(PDMats, :PDSparseMat)
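To round off the `mvnormalcanon.jl` changes, a minimal sketch (hypothetical values) of the canonical-form constructor documented above: with a dense precision matrix `J`, `MvNormalCanon(h, J)` wraps `J` in a `PDMat`, and the mean is recovered as `μ = J⁻¹ h`:

```julia
using Distributions, LinearAlgebra

h = [1.0, 0.0]             # hypothetical potential vector, h = inv(Σ) * μ
J = [2.0 0.5; 0.5 1.0]     # hypothetical precision matrix, J = inv(Σ)
d = MvNormalCanon(h, J)

mean(d) ≈ J \ h            # true: the mean is J⁻¹ h
invcov(d) ≈ J              # true: the stored precision matrix
```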
7 changes: 6 additions & 1 deletion test/mvnormal.jl
@@ -9,6 +9,7 @@ using Distributions
using LinearAlgebra, Random, Test
using SparseArrays
using FillArrays
using BlockDiagonals

###### General Testing

@@ -51,7 +52,11 @@ using FillArrays
(MvNormal(mu, Diagonal(dv)), mu, Matrix(Diagonal(dv))),
(MvNormal(mu, Symmetric(Diagonal(dv))), mu, Matrix(Diagonal(dv))),
(MvNormal(mu, Hermitian(Diagonal(dv))), mu, Matrix(Diagonal(dv))),
(MvNormal(mu_r, Diagonal(dv)), mu_r, Matrix(Diagonal(dv))) ]
(MvNormal(mu_r, Diagonal(dv)), mu_r, Matrix(Diagonal(dv))),
(MvNormal([mu_r; mu_r], BlockDiagonal([C, C])), [mu_r; mu_r], Matrix(BlockDiagonal([C, C]))),
(MvNormal([mu_r; mu_r], BlockDiagonal([PDMat(C), PDMat(C)])), [mu_r; mu_r], Matrix(BlockDiagonal([C, C]))),
(MvNormalCanon([mu_r; mu_r], BlockDiagonal([C, C])), BlockDiagonal([C, C]) \ [mu_r; mu_r], inv(BlockDiagonal([C, C]))),
]

@test mean(g) ≈ μ
@test cov(g) ≈ Σ
Expand Down