Skip to content

Commit

Permalink
Widen (inv)quad signature to allow any AbstractVector
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Apr 17, 2017
1 parent 4687667 commit cc84e1a
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/pdiagmat.jl
Expand Up @@ -81,8 +81,8 @@ unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) =

### quadratic forms

quad(a::PDiagMat, x::Vector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::Vector) = wsumsq(a.inv_diag, x)
quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::AbstractVector) = wsumsq(a.inv_diag, x)

quad!(r::AbstractArray, a::PDiagMat, x::Matrix) = At_mul_B!(r, @compat(abs2.(x)), a.diag)
invquad!(r::AbstractArray, a::PDiagMat, x::Matrix) = At_mul_B!(r, @compat(abs2.(x)), a.inv_diag)
Expand Down
4 changes: 2 additions & 2 deletions src/pdmat.jl
Expand Up @@ -65,8 +65,8 @@ end

### quadratic forms

quad(a::PDMat, x::StridedVector) = dot(x, a * x)
invquad(a::PDMat, x::StridedVector) = dot(x, a \ x)
quad(a::PDMat, x::AbstractVector) = dot(x, a * x)
invquad(a::PDMat, x::AbstractVector) = dot(x, a \ x)

quad!(r::AbstractArray, a::PDMat, x::StridedMatrix) = colwise_dot!(r, x, a.mat * x)
invquad!(r::AbstractArray, a::PDMat, x::StridedMatrix) = colwise_dot!(r, x, a.mat \ x)
Expand Down
4 changes: 2 additions & 2 deletions src/pdsparsemat.jl
Expand Up @@ -62,8 +62,8 @@ end

### quadratic forms

quad(a::PDSparseMat, x::StridedVector) = dot(x, a * x)
invquad(a::PDSparseMat, x::StridedVector) = dot(x, a \ x)
quad(a::PDSparseMat, x::AbstractVector) = dot(x, a * x)
invquad(a::PDSparseMat, x::AbstractVector) = dot(x, a \ x)

function quad!(r::AbstractArray, a::PDSparseMat, x::StridedMatrix)
for i in 1:size(x, 2)
Expand Down
4 changes: 2 additions & 2 deletions src/scalmat.jl
Expand Up @@ -68,8 +68,8 @@ end

### quadratic forms

quad(a::ScalMat, x::Vector) = sum(abs2, x) * a.value
invquad(a::ScalMat, x::Vector) = sum(abs2, x) * a.inv_value
quad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.value
invquad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.inv_value

quad!(r::AbstractArray, a::ScalMat, x::Matrix) = colwise_sumsq!(r, x, a.value)
invquad!(r::AbstractArray, a::ScalMat, x::Matrix) = colwise_sumsq!(r, x, a.inv_value)
Expand Down
3 changes: 3 additions & 0 deletions src/testutils.jl
Expand Up @@ -5,6 +5,7 @@
#

import Base.Test: @test
using Compat: view

## driver function
function test_pdmat(C::AbstractPDMat, Cmat::Matrix;
Expand Down Expand Up @@ -182,6 +183,7 @@ function pdtest_quad(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, ve
xi = vec(X[:,i])
r_quad[i] = dot(xi, Cmat * xi)
@test quad(C, xi) r_quad[i]
@test quad(C, view(X,:,i)) r_quad[i]
end
@test quad(C, X) r_quad

Expand All @@ -191,6 +193,7 @@ function pdtest_quad(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, ve
xi = vec(X[:,i])
r_invquad[i] = dot(xi, Imat * xi)
@test invquad(C, xi) r_invquad[i]
@test invquad(C, view(X,:,i)) r_invquad[i]
end
@test invquad(C, X) r_invquad
end
Expand Down

0 comments on commit cc84e1a

Please sign in to comment.