Skip to content
Browse files

Factor out handling of subtraction

  • Loading branch information...
1 parent 02fa0b3 commit ee4709b5699bf007dd3f7973f101afd4a5b5efa2 @Shimuuar committed Aug 30, 2012
Showing with 2 additions and 6 deletions.
  1. +2 −6 Numeric/BLAS/Expression.hs
View
8 Numeric/BLAS/Expression.hs
@@ -193,6 +193,8 @@ evalST :: (() -> Cont s) -> Expr m a -> ST s (Mutable m s a)
{-# INLINE evalST #-}
-- Reduce double scale.
evalST cont (Scale q α (Scale _ β e)) = cont q $ Scale ()*β) e
+-- Convert subtraction to addition
+evalST cont (Sub _ x (Scale q α y)) = cont q $ Add () x (Scale () (-α) y)
--
-- Addition
-- ========
@@ -209,9 +211,7 @@ evalST cont (Add _ m (VecH q v u)) | Just m_ <- mutable (cont q) m
evalST cont (Add _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) α v u =<< m_
-- NOTE: A ← α·x·y' + A | for subtraction coefficient α must be negated
evalST cont (Sub _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) (-1) v u =<< m_
-evalST cont (Sub _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) (-α) v u =<< m_
evalST cont (Sub _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) (-1) v u =<< m_
-evalST cont (Sub _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) (-α) v u =<< m_
--
-- * Matrix x Vector
evalST cont (Add _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v 1 =<< u_
@@ -221,8 +221,6 @@ evalST cont (Add _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutab
-- NOTE: y ← α·A·x + β·y | for subtraction coefficient α must be negated
evalST cont (Sub _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-1) m v 1 =<< u_
evalST cont (Sub _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-α) m v 1 =<< u_
-evalST cont (Sub _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-1) m v β =<< u_
-evalST cont (Sub _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-α) m v β =<< u_
--
-- * op(Matrix) x Vector
evalST cont (Add _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v 1 =<< u_
@@ -232,8 +230,6 @@ evalST cont (Add _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mu
-- NOTE: C ← α·op(A)·op(B) + β·C | for subtraction coefficient α must be negated
evalST cont (Sub _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v 1 =<< u_
evalST cont (Sub _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-α) t m v 1 =<< u_
-evalST cont (Sub _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v β =<< u_
-evalST cont (Sub _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-α) t m v β =<< u_
-- * No nice rules match. We have to use generic function. But still
-- let try to reuse temporaries as much as possible
evalST cont (Add q x y)

0 comments on commit ee4709b

Please sign in to comment.
Something went wrong with that request. Please try again.