Permalink
Browse files

Add subtraction

  • Loading branch information...
1 parent ec2664e commit 1d887c7f6f0a468a1e4eea15e7a994c1911bd1df @Shimuuar committed Aug 27, 2012
Showing with 66 additions and 26 deletions.
  1. +4 −0 Numeric/BLAS.hs
  2. +62 −26 Numeric/BLAS/Expression.hs
View
@@ -61,9 +61,13 @@ import Numeric.BLAS.Mutable (MVectorBLAS)
-- | Addition for vectors and matrices.
class Add a where
(.+.) :: a -> a -> a
+ (.-.) :: a -> a -> a
instance (AddM (Mutable m) a, Freeze m a) => Add (m a) where
x .+. y = eval $ Add () (Lit x) (Lit y)
+ {-# INLINE (.+.) #-}
+ x .-. y = eval $ Sub () (Lit x) (Lit y)
+ {-# INLINE (.-.) #-}
-- | Scalar multiplication
class Scale v a where
@@ -109,13 +109,17 @@ class Scalable m a where
scale :: a -> m s a -> ST s ()
--- | Addition for mutable data. Second argument of 'addM' is modified
--- in place.
+-- | Elementwise addition and subtraction for mutable data. First
+-- argument of 'addM' is modified in place.
--
--- > y ← x + y
+-- > x ← x + y
+-- > x ← x - y
class AddM m a where
- addM :: m s a -- /x/
- -> m s a -- /y/
+ addM :: m s a -- ^ /x/
+ -> m s a -- ^ /y/
+ -> ST s ()
+ subM :: m s a -- ^ /x/
+ -> m s a -- ^ /y/
-> ST s ()
@@ -131,6 +135,9 @@ data Expr m a where
-- | Addition
Add :: (Freeze m a, AddM (Mutable m) a)
=> () -> Expr m a -> Expr m a -> Expr m a
+ -- | Subtraction
+ Sub :: (Freeze m a, AddM (Mutable m) a)
+ => () -> Expr m a -> Expr m a -> Expr m a
-- | Scalar-X multiplication
Scale :: (Freeze m a, BLAS1 a, Scalable (Mutable m) a)
=> () -> a -> Expr m a -> Expr m a
@@ -181,32 +188,59 @@ evalST cont (Scale q α (Scale _ β e)) = cont q $ Scale () (α*β) e
-- appear on the left
--
-- * Vector x trans(Vector)
-evalST cont (Add _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) 1 v u =<< m_
-evalST cont (Add _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) α v u =<< m_
-evalST cont (Add _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) 1 v u =<< m_
-evalST cont (Add _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) α v u =<< m_
+evalST cont (Add _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) 1 v u =<< m_
+evalST cont (Add _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) α v u =<< m_
+evalST cont (Add _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) 1 v u =<< m_
+evalST cont (Add _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) α v u =<< m_
+-- NOTE: A ← α·x·y' + A | for subtraction coefficient α must be negated
+evalST cont (Sub _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) (-1) v u =<< m_
+evalST cont (Sub _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) (-α) v u =<< m_
+evalST cont (Sub _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) (-1) v u =<< m_
+evalST cont (Sub _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) (-α) v u =<< m_
+--
-- * Matrix x Vector
-evalST cont (Add _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v 1 =<< u_
-evalST cont (Add _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v 1 =<< u_
-evalST cont (Add _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v β =<< u_
-evalST cont (Add _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v β =<< u_
+evalST cont (Add _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v 1 =<< u_
+evalST cont (Add _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v β =<< u_
+evalST cont (Add _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v β =<< u_
+-- NOTE: y ← α·A·x + β·y | for subtraction coefficient α must be negated
+evalST cont (Sub _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-1) m v 1 =<< u_
+evalST cont (Sub _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-α) m v 1 =<< u_
+evalST cont (Sub _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-1) m v β =<< u_
+evalST cont (Sub _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-α) m v β =<< u_
+--
-- * op(Matrix) x Vector
-evalST cont (Add _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v 1 =<< u_
-evalST cont (Add _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v 1 =<< u_
-evalST cont (Add _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v β =<< u_
-evalST cont (Add _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v β =<< u_
--- * No nice rules match. We have to use generic function
+evalST cont (Add _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v 1 =<< u_
+evalST cont (Add _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v 1 =<< u_
+evalST cont (Add _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v β =<< u_
+evalST cont (Add _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v β =<< u_
+-- NOTE: C ← α·op(A)·op(B) + β·C | for subtraction coefficient α must be negated
+evalST cont (Sub _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v 1 =<< u_
+evalST cont (Sub _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-α) t m v 1 =<< u_
+evalST cont (Sub _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v β =<< u_
+evalST cont (Sub _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-α) t m v β =<< u_
+-- * No nice rules match. We have to use generic function. But still
+-- let try to reuse temporaries as much as possible
evalST cont (Add q x y)
- | Just mx <- mutable (cont q) x = do y_ <- pull (cont q) y
- x_ <- mx
- addM y_ x_
+ | Just mx <- mutable (cont q) x = do x_ <- mx
+ y_ <- pull (cont q) y
+ addM x_ y_
return x_
| Just my <- mutable (cont q) y = do x_ <- pull (cont q) x
y_ <- my
- addM x_ y_
+ addM y_ x_
return y_
- | otherwise = do x_ <- pull (cont q) x
- y_ <- cont q y
+ | otherwise = do x_ <- cont q x
+ y_ <- pull (cont q) y
+ addM x_ y_
+ return x_
+evalST cont (Sub q x y)
+ | Just mx <- mutable (cont q) x = do y_ <- pull (cont q) y
+ x_ <- mx
+ subM x_ y_
+ return x_
+ | otherwise = do x_ <- cont q x
+ y_ <- pull (cont q) y
addM x_ y_
return y_
-- Multiplication
@@ -431,9 +465,11 @@ instance BLAS1 a => Scalable MS.MVector a where
scale = scaleVector
instance BLAS1 a => AddM MV.MVector a where
- addM x y = addVecScaled 1 x y
+ addM x y = addVecScaled 1 y x
+ subM x y = addVecScaled (-1) y x
instance BLAS1 a => AddM MS.MVector a where
- addM x y = addVecScaled 1 x y
+ addM x y = addVecScaled 1 y x
+ subM x y = addVecScaled (-1) y x

0 comments on commit 1d887c7

Please sign in to comment.