Permalink
Browse files

Add clever rules for matrix addition

  • Loading branch information...
1 parent 16cf4d4 commit 2a0d647924eb9355b26ccd776fbf6bfe915949ce @Shimuuar committed Aug 30, 2012
Showing with 74 additions and 8 deletions.
  1. +74 −8 Numeric/BLAS/Expression.hs
View
@@ -173,11 +173,11 @@ data Expr m a where
-> Trans -> Expr MatD.Matrix a
-> Expr MatD.Matrix a
-- Matrix-matrix multiplication for symmetric and dense matrix
- MultSymMM :: BLAS3 a
- => () -> Side -> Expr MatS.Symmetric a -> Expr MatD.Matrix a -> Expr MatD.Matrix a
+ MulSymMM :: BLAS3 a
+ => () -> Side -> Expr MatS.Symmetric a -> Expr MatD.Matrix a -> Expr MatD.Matrix a
-- Matrix-matrix multiplication for symmetric and hermitian matrix
- MultHerMM :: (BLAS3 a, MMatS.Conjugate a)
- => () -> Side -> Expr MatS.Hermitian a -> Expr MatD.Matrix a -> Expr MatD.Matrix a
+ MulHerMM :: (BLAS3 a, MMatS.Conjugate a)
+ => () -> Side -> Expr MatS.Hermitian a -> Expr MatD.Matrix a -> Expr MatD.Matrix a
@@ -230,6 +230,33 @@ evalST cont (Add _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mu
-- NOTE: C ← α·op(A)·op(B) + β·C | for subtraction coefficient α must be negated
evalST cont (Sub _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v 1 =<< u_
evalST cont (Sub _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v β =<< u_
+--
+-- * op(Matrix) x * op(Matrix)
+evalST cont (Add _ u (MulMM q tm m tn n)) | Just u_ <- mutable (cont q) u = inplaceEvalMM (cont q) 1 tm m tn n 1 =<< u_
+evalST cont (Add _ u (Scale _ α (MulMM q tm m tn n))) | Just u_ <- mutable (cont q) u = inplaceEvalMM (cont q) α tm m tn n 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulMM q tm m tn n)) | Just u_ <- mutable (cont q) u = inplaceEvalMM (cont q) 1 tm m tn n β =<< u_
+evalST cont (Add _ (Scale _ β u) (Scale _ α (MulMM q tm m tn n))) | Just u_ <- mutable (cont q) u = inplaceEvalMM (cont q) α tm m tn n β =<< u_
+-- Subtraction
+evalST cont (Add _ u (MulMM q tm m tn n)) | Just u_ <- mutable (cont q) u = inplaceEvalMM (cont q) (-1) tm m tn n 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulMM q tm m tn n)) | Just u_ <- mutable (cont q) u = inplaceEvalMM (cont q) (-1) tm m tn n β =<< u_
+--
+-- * op(Symmetric matrix) x * op(Matrix)
+evalST cont (Add _ u (MulSymMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalSymMM (cont q) sd 1 m n 1 =<< u_
+evalST cont (Add _ u (Scale _ α (MulSymMM q sd m n))) | Just u_ <- mutable (cont q) u = inplaceEvalSymMM (cont q) sd α m n 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulSymMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalSymMM (cont q) sd 1 m n β =<< u_
+evalST cont (Add _ (Scale _ β u) (Scale _ α (MulSymMM q sd m n))) | Just u_ <- mutable (cont q) u = inplaceEvalSymMM (cont q) sd α m n β =<< u_
+-- Subtraction
+evalST cont (Add _ u (MulSymMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalSymMM (cont q) sd (-1) m n 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulSymMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalSymMM (cont q) sd (-1) m n β =<< u_
+--
+-- * op(Hermitian matrix) x * op(Matrix)
+evalST cont (Add _ u (MulHerMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalHerMM (cont q) sd 1 m n 1 =<< u_
+evalST cont (Add _ u (Scale _ α (MulHerMM q sd m n))) | Just u_ <- mutable (cont q) u = inplaceEvalHerMM (cont q) sd α m n 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulHerMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalHerMM (cont q) sd 1 m n β =<< u_
+evalST cont (Add _ (Scale _ β u) (Scale _ α (MulHerMM q sd m n))) | Just u_ <- mutable (cont q) u = inplaceEvalHerMM (cont q) sd α m n β =<< u_
+-- Subtraction
+evalST cont (Add _ u (MulHerMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalHerMM (cont q) sd (-1) m n 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulHerMM q sd m n)) | Just u_ <- mutable (cont q) u = inplaceEvalHerMM (cont q) sd (-1) m n β =<< u_
-- * No nice rules match. We have to use generic function. But still
-- let try to reuse temporaries as much as possible
evalST cont (Add q x y)
@@ -274,11 +301,11 @@ evalST cont (MulTMV q t m v) = evalTMV (cont q) 1 t m v
evalST cont (Scale _ a (MulMM q tm m tn n)) = evalMM (cont q) a tm m tn n
evalST cont (MulMM q tm m tn n) = evalMM (cont q) 1 tm m tn n
-- * Symmetric matrix x Matrix
-evalST cont (Scale _ α (MultSymMM q sd ma mb)) = evalSymMM (cont q) sd α ma mb
-evalST cont ( MultSymMM q sd ma mb) = evalSymMM (cont q) sd 1 ma mb
+evalST cont (Scale _ α (MulSymMM q sd ma mb)) = evalSymMM (cont q) sd α ma mb
+evalST cont ( MulSymMM q sd ma mb) = evalSymMM (cont q) sd 1 ma mb
-- * Hermitian matrix x Matrix
-evalST cont (Scale _ α (MultHerMM q sd ma mb)) = evalHerMM (cont q) sd α ma mb
-evalST cont ( MultHerMM q sd ma mb) = evalHerMM (cont q) sd 1 ma mb
+evalST cont (Scale _ α (MulHerMM q sd ma mb)) = evalHerMM (cont q) sd α ma mb
+evalST cont ( MulHerMM q sd ma mb) = evalHerMM (cont q) sd 1 ma mb
--
-- * Scale data type in place
evalST cont (Scale q a x) = do
@@ -443,6 +470,19 @@ evalMM cont a tm m tn n = do
multMM a tm m_ tn n_ 0 r
return r
+inplaceEvalMM :: (BLAS3 a)
+ => Cont s
+ -> a -> Trans -> Expr MatD.Matrix a
+ -> Trans -> Expr MatD.Matrix a
+ -> a -> Mutable MatD.Matrix s a
+ -> ST s (Mutable MatD.Matrix s a)
+{-# INLINE inplaceEvalMM #-}
+inplaceEvalMM cont α ta mA tb mB β mC_ = do
+ mA_ <- pull cont mA
+ mB_ <- pull cont mB
+ multMM α ta mA_ tb mB_ β mC_
+ return mC_
+
evalSymMM :: (BLAS3 a)
=> Cont s -> Side -> a
-> Expr MatS.Symmetric a
@@ -456,6 +496,19 @@ evalSymMM cont side α ma mb = do
multSymMM side α ma_ mb_ 0 mc_
return mc_
+inplaceEvalSymMM :: (BLAS3 a)
+ => Cont s -> Side -> a
+ -> Expr MatS.Symmetric a
+ -> Expr MatD.Matrix a
+ -> a -> Mutable MatD.Matrix s a
+ -> ST s (Mutable MatD.Matrix s a)
+{-# INLINE inplaceEvalSymMM #-}
+inplaceEvalSymMM cont side α ma mb β mc_ = do
+ ma_ <- pull cont ma
+ mb_ <- pull cont mb
+ multSymMM side α ma_ mb_ β mc_
+ return mc_
+
evalHerMM :: (BLAS3 a, MMatS.Conjugate a)
=> Cont s -> Side -> a
-> Expr MatS.Hermitian a
@@ -469,6 +522,19 @@ evalHerMM cont side α ma mb = do
multHerMM side α ma_ mb_ 0 mc_
return mc_
+inplaceEvalHerMM :: (BLAS3 a, MMatS.Conjugate a)
+ => Cont s -> Side -> a
+ -> Expr MatS.Hermitian a
+ -> Expr MatD.Matrix a
+ -> a -> Mutable MatD.Matrix s a
+ -> ST s (Mutable MatD.Matrix s a)
+{-# INLINE inplaceEvalHerMM #-}
+inplaceEvalHerMM cont side α ma mb β mc_ = do
+ ma_ <- pull cont ma
+ mb_ <- pull cont mb
+ multHerMM side α ma_ mb_ β mc_
+ return mc_
+
----------------------------------------------------------------

0 comments on commit 2a0d647

Please sign in to comment.