Skip to content

Commit

Permalink
Merge pull request #62 from JuliaStats/aa/strided
Browse files Browse the repository at this point in the history
Widen (inv)quad(!) signatures to allow StridedArrays
  • Loading branch information
andreasnoack committed Apr 19, 2017
2 parents 4687667 + 0b04e3b commit a7871cb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/pdiagmat.jl
Expand Up @@ -81,11 +81,11 @@ unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) =

### quadratic forms

quad(a::PDiagMat, x::Vector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::Vector) = wsumsq(a.inv_diag, x)
quad(a::PDiagMat, x::StridedVector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::StridedVector) = wsumsq(a.inv_diag, x)

quad!(r::AbstractArray, a::PDiagMat, x::Matrix) = At_mul_B!(r, @compat(abs2.(x)), a.diag)
invquad!(r::AbstractArray, a::PDiagMat, x::Matrix) = At_mul_B!(r, @compat(abs2.(x)), a.inv_diag)
quad!(r::AbstractArray, a::PDiagMat, x::StridedMatrix) = At_mul_B!(r, @compat(abs2.(x)), a.diag)
invquad!(r::AbstractArray, a::PDiagMat, x::StridedMatrix) = At_mul_B!(r, @compat(abs2.(x)), a.inv_diag)


### tri products
Expand Down
8 changes: 4 additions & 4 deletions src/scalmat.jl
Expand Up @@ -68,11 +68,11 @@ end

### quadratic forms

quad(a::ScalMat, x::Vector) = sum(abs2, x) * a.value
invquad(a::ScalMat, x::Vector) = sum(abs2, x) * a.inv_value
quad(a::ScalMat, x::StridedVector) = sum(abs2, x) * a.value
invquad(a::ScalMat, x::StridedVector) = sum(abs2, x) * a.inv_value

quad!(r::AbstractArray, a::ScalMat, x::Matrix) = colwise_sumsq!(r, x, a.value)
invquad!(r::AbstractArray, a::ScalMat, x::Matrix) = colwise_sumsq!(r, x, a.inv_value)
quad!(r::AbstractArray, a::ScalMat, x::StridedMatrix) = colwise_sumsq!(r, x, a.value)
invquad!(r::AbstractArray, a::ScalMat, x::StridedMatrix) = colwise_sumsq!(r, x, a.inv_value)


### tri products
Expand Down
3 changes: 3 additions & 0 deletions src/testutils.jl
Expand Up @@ -5,6 +5,7 @@
#

import Base.Test: @test
using Compat: view

## driver function
function test_pdmat(C::AbstractPDMat, Cmat::Matrix;
Expand Down Expand Up @@ -182,6 +183,7 @@ function pdtest_quad(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, ve
xi = vec(X[:,i])
r_quad[i] = dot(xi, Cmat * xi)
@test quad(C, xi) r_quad[i]
@test quad(C, view(X,:,i)) r_quad[i]
end
@test quad(C, X) r_quad

Expand All @@ -191,6 +193,7 @@ function pdtest_quad(C::AbstractPDMat, Cmat::Matrix, Imat::Matrix, X::Matrix, ve
xi = vec(X[:,i])
r_invquad[i] = dot(xi, Imat * xi)
@test invquad(C, xi) r_invquad[i]
@test invquad(C, view(X,:,i)) r_invquad[i]
end
@test invquad(C, X) r_invquad
end
Expand Down

0 comments on commit a7871cb

Please sign in to comment.