Skip to content

Commit

Permalink
Merge pull request #42 from jmxpearson/scalmats
Browse files Browse the repository at this point in the history
generalizing scalmat
  • Loading branch information
andreasnoack committed Mar 28, 2016
2 parents 8fa302c + 4534ad4 commit 7bceeb6
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/scalmat.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Scaling matrix

immutable ScalMat{T<:AbstractFloat} <: AbstractPDMat{T}
immutable ScalMat{T<:Real} <: AbstractPDMat{T}
dim::Int
value::T
inv_value::T
end

ScalMat(d::Int,v::AbstractFloat) = ScalMat{typeof(v)}(d, v, one(v) / v)
ScalMat(d::Int,v::Real) = ScalMat{typeof(v)}(d, v, one(v) / v)

### Basics

Expand All @@ -17,7 +17,7 @@ diag(a::ScalMat) = fill(a.value, a.dim)

### Arithmetics

function pdadd!{T<:AbstractFloat}(r::Matrix{T}, a::Matrix{T}, b::ScalMat{T}, c::T)
function pdadd!(r::Matrix, a::Matrix, b::ScalMat, c)
@check_argdims size(r) == size(a) == size(b)
if is(r, a)
_adddiag!(r, b.value * c)
Expand All @@ -27,10 +27,10 @@ function pdadd!{T<:AbstractFloat}(r::Matrix{T}, a::Matrix{T}, b::ScalMat{T}, c::
return r
end

*{T<:AbstractFloat}(a::ScalMat{T}, c::T) = ScalMat(a.dim, a.value * c)
*{T<:AbstractFloat}(a::ScalMat, c::T) = ScalMat(a.dim, a.value * c)
/{T<:AbstractFloat}(a::ScalMat{T}, c::T) = ScalMat(a.dim, a.value / c)
*{T<:AbstractFloat}(a::ScalMat{T}, x::StridedVecOrMat{T}) = a.value * x
\{T<:AbstractFloat}(a::ScalMat{T}, x::StridedVecOrMat{T}) = a.inv_value * x
*(a::ScalMat, x::StridedVecOrMat) = a.value * x
\(a::ScalMat, x::StridedVecOrMat) = a.inv_value * x


### Algebra
Expand All @@ -43,7 +43,7 @@ eigmin(a::ScalMat) = a.value

### whiten and unwhiten

function whiten!{T<:AbstractFloat}(r::StridedVecOrMat{T}, a::ScalMat{T}, x::StridedVecOrMat{T})
function whiten!(r::StridedVecOrMat, a::ScalMat, x::StridedVecOrMat)
@check_argdims dim(a) == size(x, 1)
c = sqrt(a.inv_value)
for i = 1:length(x)
Expand All @@ -52,7 +52,7 @@ function whiten!{T<:AbstractFloat}(r::StridedVecOrMat{T}, a::ScalMat{T}, x::Stri
return r
end

function unwhiten!{T<:AbstractFloat}(r::StridedVecOrMat{T}, a::ScalMat{T}, x::StridedVecOrMat{T})
function unwhiten!(r::StridedVecOrMat, a::ScalMat, x::StridedVecOrMat)
@check_argdims dim(a) == size(x, 1)
c = sqrt(a.value)
for i = 1:length(x)
Expand All @@ -64,31 +64,31 @@ end

### quadratic forms

quad{T<:AbstractFloat}(a::ScalMat, x::Vector{T}) = sumabs2(x) * a.value
invquad{T<:AbstractFloat}(a::ScalMat, x::Vector{T}) = sumabs2(x) * a.inv_value
quad(a::ScalMat, x::Vector) = sumabs2(x) * a.value
invquad(a::ScalMat, x::Vector) = sumabs2(x) * a.inv_value

quad!{T<:AbstractFloat}(r::AbstractArray{T}, a::ScalMat{T}, x::Matrix{T}) = colwise_sumsq!(r, x, a.value)
invquad!{T<:AbstractFloat}(r::AbstractArray{T}, a::ScalMat{T}, x::Matrix{T}) = colwise_sumsq!(r, x, a.inv_value)
quad!(r::AbstractArray, a::ScalMat, x::Matrix) = colwise_sumsq!(r, x, a.value)
invquad!(r::AbstractArray, a::ScalMat, x::Matrix) = colwise_sumsq!(r, x, a.inv_value)


### tri products

function X_A_Xt{T<:AbstractFloat}(a::ScalMat{T}, x::StridedMatrix{T})
function X_A_Xt(a::ScalMat, x::StridedMatrix)
@check_argdims dim(a) == size(x, 2)
gemm('N', 'T', a.value, x, x)
end

function Xt_A_X{T<:AbstractFloat}(a::ScalMat{T}, x::StridedMatrix{T})
function Xt_A_X(a::ScalMat, x::StridedMatrix)
@check_argdims dim(a) == size(x, 1)
gemm('T', 'N', a.value, x, x)
end

function X_invA_Xt{T<:AbstractFloat}(a::ScalMat{T}, x::StridedMatrix{T})
function X_invA_Xt(a::ScalMat, x::StridedMatrix)
@check_argdims dim(a) == size(x, 2)
gemm('N', 'T', a.inv_value, x, x)
end

function Xt_invA_X{T<:AbstractFloat}(a::ScalMat{T}, x::StridedMatrix{T})
function Xt_invA_X(a::ScalMat, x::StridedMatrix)
@check_argdims dim(a) == size(x, 1)
gemm('T', 'N', a.inv_value, x, x)
end

0 comments on commit 7bceeb6

Please sign in to comment.