Skip to content

Commit

Permalink
Merge pull request #103 from johnczito/more_kron
Browse files Browse the repository at this point in the history
more kron
  • Loading branch information
andreasnoack committed Oct 15, 2019
2 parents 83b4f8f + 257835e commit 9f97462
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pdadd(a::Matrix{T}, b::AbstractPDMat{S}) where {T<:Real, S<:Real} = pdadd!(simil
*(a::AbstractPDMat, c::T) where {T<:Real} = a * c
*(c::T, a::AbstractPDMat) where {T<:Real} = a * c
/(a::AbstractPDMat, c::T) where {T<:Real} = a * inv(c)
Base.kron(A::AbstractPDMat, B::AbstractPDMat) = PDMat(kron(Matrix(A), Matrix(B)))


## whiten and unwhiten
Expand Down Expand Up @@ -58,7 +59,7 @@ PDMat{Float64,Array{Float64,2}}(2, [4.0 10.0; 10.0 30.0], Cholesky{Float64,Array
julia> W = whiten(a, X)
2×4 Array{Float64,2}:
0.5 0.5 0.5 0.5
0.5 0.5 0.5 0.5
-0.67082 -0.223607 0.223607 0.67082
julia> W * W'
Expand Down
2 changes: 1 addition & 1 deletion src/pdiagmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
*(a::PDiagMat, c::T) where {T<:Real} = PDiagMat(a.diag * c)
*(a::PDiagMat, x::StridedVecOrMat) = a.diag .* x
\(a::PDiagMat, x::StridedVecOrMat) = a.inv_diag .* x

Base.kron(A::PDiagMat, B::PDiagMat) = PDiagMat( vcat([A.diag[i] * B.diag for i in 1:dim(A)]...) )

### Algebra

Expand Down
2 changes: 1 addition & 1 deletion src/scalmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end
/(a::ScalMat{T}, c::T) where {T<:Real} = ScalMat(a.dim, a.value / c)
*(a::ScalMat, x::StridedVecOrMat) = a.value * x
\(a::ScalMat, x::StridedVecOrMat) = a.inv_value * x

Base.kron(A::ScalMat, B::ScalMat) = ScalMat( dim(A) * dim(B), A.value * B.value )

### Algebra

Expand Down
42 changes: 27 additions & 15 deletions test/kron.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
using PDMats
using Test

_randPDMat(T, n) = (X = randn(T, n, n); PDMat(X * X'))
_randPDiagMat(T, n) = PDiagMat(rand(T, n))
_randScalMat(T, n) = ScalMat(n, rand(T))
function _pd_compare(A::AbstractPDMat, B::AbstractPDMat)
@test dim(A) == dim(B)
@test Matrix(A) Matrix(B)
@test cholesky(A).L cholesky(B).L
@test cholesky(A).U cholesky(B).U
nothing
end
function _pd_kron_compare(A::AbstractPDMat, B::AbstractPDMat)
PDAkB1 = kron(A, B)
PDAkB2 = PDMat( kron( Matrix(A), Matrix(B) ) )
_pd_compare(PDAkB1, PDAkB2)
nothing
end

n = 4
m = 7

for T in [Float64, Float32]
X = randn(T, n, n)
Y = randn(T, m, m)
A = X * X'
B = Y * Y'
AkB = kron(A, B)
PDA = PDMat(A)
PDB = PDMat(B)
PDAkB1 = PDMat(AkB)
PDAkB2 = kron(PDA, PDB)
@test PDAkB1.dim == PDAkB2.dim
@test PDAkB1.mat PDAkB2.mat
@test PDAkB1.chol.L PDAkB2.chol.L
@test PDAkB1.chol.U PDAkB2.chol.U
@test Matrix(PDAkB2.chol) PDAkB2.mat
for T in [Float32, Float64]
_pd_kron_compare( _randPDMat(T, n), _randPDMat(T, m) )
_pd_kron_compare( _randPDiagMat(T, n), _randPDiagMat(T, m) )
_pd_kron_compare( _randScalMat(T, n), _randScalMat(T, m) )
_pd_kron_compare( _randPDMat(T, n), _randPDiagMat(T, m) )
_pd_kron_compare( _randPDiagMat(T, m), _randPDMat(T, n) )
_pd_kron_compare( _randPDMat(T, n), _randScalMat(T, m) )
_pd_kron_compare( _randScalMat(T, m), _randPDMat(T, n) )
_pd_kron_compare( _randPDiagMat(T, n), _randScalMat(T, m) )
_pd_kron_compare( _randScalMat(T, m), _randPDiagMat(T, n) )
end

0 comments on commit 9f97462

Please sign in to comment.