Permalink
Browse files

Add symmetric and hermitian matrix mutiplication

  • Loading branch information...
1 parent 8bf9bad commit 02fa0b3391988d49855fc10552e79f515d0477ff @Shimuuar committed Aug 30, 2012
Showing with 38 additions and 0 deletions.
  1. +38 −0 Numeric/BLAS/Expression.hs
View
@@ -172,6 +172,12 @@ data Expr m a where
-> Trans -> Expr MatD.Matrix a
-> Trans -> Expr MatD.Matrix a
-> Expr MatD.Matrix a
+ -- Matrix-matrix multiplication for symmetric and dense matrix
+ MultSymMM :: BLAS3 a
+ => () -> Side -> Expr MatS.Symmetric a -> Expr MatD.Matrix a -> Expr MatD.Matrix a
+ -- Matrix-matrix multiplication for symmetric and hermitian matrix
+ MultHerMM :: (BLAS3 a, MMatS.Conjugate a)
+ => () -> Side -> Expr MatS.Hermitian a -> Expr MatD.Matrix a -> Expr MatD.Matrix a
@@ -271,6 +277,12 @@ evalST cont (MulTMV q t m v) = evalTMV (cont q) 1 t m v
-- * op(Matrix) x op(Matrix)
evalST cont (Scale _ a (MulMM q tm m tn n)) = evalMM (cont q) a tm m tn n
evalST cont (MulMM q tm m tn n) = evalMM (cont q) 1 tm m tn n
+-- * Symmetric matrix x Matrix
+evalST cont (Scale _ α (MultSymMM q sd ma mb)) = evalSymMM (cont q) sd α ma mb
+evalST cont ( MultSymMM q sd ma mb) = evalSymMM (cont q) sd 1 ma mb
+-- * Hermitian matrix x Matrix
+evalST cont (Scale _ α (MultHerMM q sd ma mb)) = evalHerMM (cont q) sd α ma mb
+evalST cont ( MultHerMM q sd ma mb) = evalHerMM (cont q) sd 1 ma mb
--
-- * Scale data type in place
evalST cont (Scale q a x) = do
@@ -435,6 +447,32 @@ evalMM cont a tm m tn n = do
multMM a tm m_ tn n_ 0 r
return r
+evalSymMM :: (BLAS3 a)
+ => Cont s -> Side -> a
+ -> Expr MatS.Symmetric a
+ -> Expr MatD.Matrix a
+ -> ST s (Mutable MatD.Matrix s a)
+{-# INLINE evalSymMM #-}
+evalSymMM cont side α ma mb = do
+ ma_ <- pull cont ma
+ mb_ <- pull cont mb
+ mc_ <- cloneShape mb_
+ multSymMM side α ma_ mb_ 0 mc_
+ return mc_
+
+evalHerMM :: (BLAS3 a, MMatS.Conjugate a)
+ => Cont s -> Side -> a
+ -> Expr MatS.Hermitian a
+ -> Expr MatD.Matrix a
+ -> ST s (Mutable MatD.Matrix s a)
+{-# INLINE evalHerMM #-}
+evalHerMM cont side α ma mb = do
+ ma_ <- pull cont ma
+ mb_ <- pull cont mb
+ mc_ <- cloneShape mb_
+ multHerMM side α ma_ mb_ 0 mc_
+ return mc_
+
----------------------------------------------------------------

0 comments on commit 02fa0b3

Please sign in to comment.