Skip to content

Commit

Permalink
add mvnormalcanon
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed May 31, 2022
1 parent 7883de3 commit e416cb6
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ which is also a subtype of `AbstractMvNormal` to represent a multivariate normal
canonical parameters. Particularly, `MvNormalCanon` is defined as:
```julia
struct MvNormalCanon{T<:Real,P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
struct MvNormalCanon{T<:Real,P,V<:AbstractVector} <: AbstractMvNormal
μ::V # the mean vector
h::V # potential vector, i.e. inv(Σ) * μ
J::P # precision matrix, i.e. inv(Σ)
Expand All @@ -40,10 +40,19 @@ const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64, ScalMat{Float64},
**Note:** `MvNormalCanon` share the same set of methods as `MvNormal`.
"""
struct MvNormalCanon{T<:Real,P<:AbstractPDMat,V<:AbstractVector} <: AbstractMvNormal
struct MvNormalCanon{T<:Real,P,V<:AbstractVector} <: AbstractMvNormal
μ::V # the mean vector
h::V # potential vector, i.e. inv(Σ) * μ
J::P # precision matrix, i.e. inv(Σ)

function MvNormalCanon{T,P,V}::V, h::AbstractVector, J::P) where {T<:Real, V<:AbstractVector{T}, P}
length(μ) == length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
if typeof(μ) === typeof(h)
return new{T,typeof(J),typeof(μ)}(μ, h, J)
else
return new{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
end
end
end

const FullNormalCanon = MvNormalCanon{Float64,PDMat{Float64,Matrix{Float64}},Vector{Float64}}
Expand All @@ -56,26 +65,17 @@ const ZeroMeanIsoNormalCanon{Axes} = MvNormalCanon{Float64,ScalMat{Float64},Zer


### Constructors
function MvNormalCanon::AbstractVector{T}, h::AbstractVector{T}, J::AbstractPDMat{T}) where {T<:Real}
length(μ) == length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
if typeof(μ) === typeof(h)
return MvNormalCanon{T,typeof(J),typeof(μ)}(μ, h, J)
else
return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
end
end

function MvNormalCanon::AbstractVector{T}, h::AbstractVector{T}, J::AbstractPDMat) where {T<:Real}
function MvNormalCanon::AbstractVector{T}, h::AbstractVector{T}, J::P) where {T<:Real, P}
R = promote_type(T, eltype(J))
MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
MvNormalCanon{T,P,typeof(μ)}(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
end

function MvNormalCanon::AbstractVector{<:Real}, h::AbstractVector{<:Real}, J::AbstractPDMat)
function MvNormalCanon::AbstractVector{<:Real}, h::AbstractVector{<:Real}, J)
R = Base.promote_eltype(μ, h, J)
MvNormalCanon(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, h), convert(AbstractArray{R}, J))
end

function MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractPDMat)
function MvNormalCanon(h::AbstractVector{<:Real}, J)
length(h) == dim(J) || throw(DimensionMismatch("Inconsistent argument dimensions"))
R = Base.promote_eltype(h, J)
hh = convert(AbstractArray{R}, h)
Expand All @@ -89,7 +89,7 @@ end
Construct a multivariate normal distribution with potential vector `h` and precision matrix
`J`.
"""
MvNormalCanon(h::AbstractVector{<:Real}, J::AbstractMatrix{<:Real}) = MvNormalCanon(h, PDMat(J))
MvNormalCanon(h::AbstractVector{<:Real}, J::Matrix{<:Real}) = MvNormalCanon(h, PDMat(J))
MvNormalCanon(h::AbstractVector{<:Real}, J::Diagonal{<:Real}) = MvNormalCanon(h, PDiagMat(J.diag))
MvNormalCanon::AbstractVector{<:Real}, J::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormalCanon(μ, PDiagMat(J.data.diag))
function MvNormalCanon(h::AbstractVector{<:Real}, J::UniformScaling{<:Real})
Expand Down Expand Up @@ -170,7 +170,7 @@ sqmahal!(r::AbstractVector, d::MvNormalCanon, x::AbstractMatrix) = quad!(r, d.J,

# Sampling (for GenericMvNormal)

unwhiten_winv!(J::AbstractPDMat, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
unwhiten_winv!(J::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(inv(J), x)
unwhiten_winv!(J::PDiagMat, x::AbstractVecOrMat) = whiten!(J, x)
unwhiten_winv!(J::ScalMat, x::AbstractVecOrMat) = whiten!(J, x)
if isdefined(PDMats, :PDSparseMat)
Expand Down

0 comments on commit e416cb6

Please sign in to comment.