Skip to content

Commit

Permalink
Add subtraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Shimuuar committed Aug 27, 2012
1 parent ec2664e commit 1d887c7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 26 deletions.
4 changes: 4 additions & 0 deletions Numeric/BLAS.hs
Expand Up @@ -61,9 +61,13 @@ import Numeric.BLAS.Mutable (MVectorBLAS)
-- | Addition for vectors and matrices.
class Add a where
(.+.) :: a -> a -> a
(.-.) :: a -> a -> a

instance (AddM (Mutable m) a, Freeze m a) => Add (m a) where
x .+. y = eval $ Add () (Lit x) (Lit y)
{-# INLINE (.+.) #-}
x .-. y = eval $ Sub () (Lit x) (Lit y)
{-# INLINE (.-.) #-}

-- | Scalar multiplication
class Scale v a where
Expand Down
88 changes: 62 additions & 26 deletions Numeric/BLAS/Expression.hs
Expand Up @@ -109,13 +109,17 @@ class Scalable m a where
scale :: a -> m s a -> ST s ()


-- | Addition for mutable data. Second argument of 'addM' is modified
-- in place.
-- | Elementwise addition and subtraction for mutable data. First
-- argument of 'addM' is modified in place.
--
-- > y ← x + y
-- > x ← x + y
-- > x ← x - y
class AddM m a where
addM :: m s a -- /x/
-> m s a -- /y/
addM :: m s a -- ^ /x/
-> m s a -- ^ /y/
-> ST s ()
subM :: m s a -- ^ /x/
-> m s a -- ^ /y/
-> ST s ()


Expand All @@ -131,6 +135,9 @@ data Expr m a where
-- | Addition
Add :: (Freeze m a, AddM (Mutable m) a)
=> () -> Expr m a -> Expr m a -> Expr m a
-- | Subtraction
Sub :: (Freeze m a, AddM (Mutable m) a)
=> () -> Expr m a -> Expr m a -> Expr m a
-- | Scalar-X multiplication
Scale :: (Freeze m a, BLAS1 a, Scalable (Mutable m) a)
=> () -> a -> Expr m a -> Expr m a
Expand Down Expand Up @@ -181,32 +188,59 @@ evalST cont (Scale q α (Scale _ β e)) = cont q $ Scale () (α*β) e
-- appear on the left
--
-- * Vector x trans(Vector)
evalST cont (Add _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) 1 v u =<< m_
evalST cont (Add _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) α v u =<< m_
evalST cont (Add _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) 1 v u =<< m_
evalST cont (Add _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) α v u =<< m_
evalST cont (Add _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) 1 v u =<< m_
evalST cont (Add _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) α v u =<< m_
evalST cont (Add _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) 1 v u =<< m_
evalST cont (Add _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) α v u =<< m_
-- NOTE: A ← α·x·y' + A | for subtraction coefficient α must be negated
evalST cont (Sub _ m (VecT q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) (-1) v u =<< m_
evalST cont (Sub _ m (Scale _ α (VecT q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVT (cont q) (-α) v u =<< m_
evalST cont (Sub _ m (VecH q v u)) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) (-1) v u =<< m_
evalST cont (Sub _ m (Scale _ α (VecH q v u))) | Just m_ <- mutable (cont q) m = inplaceEvalVVH (cont q) (-α) v u =<< m_
--
-- * Matrix x Vector
evalST cont (Add _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v 1 =<< u_
evalST cont (Add _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v 1 =<< u_
evalST cont (Add _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v β =<< u_
evalST cont (Add _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v β =<< u_
evalST cont (Add _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v 1 =<< u_
evalST cont (Add _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v 1 =<< u_
evalST cont (Add _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) 1 m v β =<< u_
evalST cont (Add _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) α m v β =<< u_
-- NOTE: y ← α·A·x + β·y | for subtraction coefficient α must be negated
evalST cont (Sub _ u (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-1) m v 1 =<< u_
evalST cont (Sub _ u (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-α) m v 1 =<< u_
evalST cont (Sub _ (Scale _ β u) (MulMV q m v)) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-1) m v β =<< u_
evalST cont (Sub _ (Scale _ β u) (Scale _ α (MulMV q m v))) | Just u_ <- mutable (cont q) u = inplaceEvalMV (cont q) (-α) m v β =<< u_
--
-- * op(Matrix) x Vector
evalST cont (Add _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v 1 =<< u_
evalST cont (Add _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v 1 =<< u_
evalST cont (Add _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v β =<< u_
evalST cont (Add _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v β =<< u_
-- * No nice rules match. We have to use generic function
evalST cont (Add _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v 1 =<< u_
evalST cont (Add _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v 1 =<< u_
evalST cont (Add _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) 1 t m v β =<< u_
evalST cont (Add _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) α t m v β =<< u_
-- NOTE: C ← α·op(A)·op(B) + β·C | for subtraction coefficient α must be negated
evalST cont (Sub _ u (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v 1 =<< u_
evalST cont (Sub _ u (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-α) t m v 1 =<< u_
evalST cont (Sub _ (Scale _ β u) (MulTMV q t m v)) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-1) t m v β =<< u_
evalST cont (Sub _ (Scale _ β u) (Scale _ α (MulTMV q t m v))) | Just u_ <- mutable (cont q) u = inplaceEvalTMV (cont q) (-α) t m v β =<< u_
-- * No nice rules match. We have to use generic function. But still
-- let try to reuse temporaries as much as possible
evalST cont (Add q x y)
| Just mx <- mutable (cont q) x = do y_ <- pull (cont q) y
x_ <- mx
addM y_ x_
| Just mx <- mutable (cont q) x = do x_ <- mx
y_ <- pull (cont q) y
addM x_ y_
return x_
| Just my <- mutable (cont q) y = do x_ <- pull (cont q) x
y_ <- my
addM x_ y_
addM y_ x_
return y_
| otherwise = do x_ <- pull (cont q) x
y_ <- cont q y
| otherwise = do x_ <- cont q x
y_ <- pull (cont q) y
addM x_ y_
return x_
evalST cont (Sub q x y)
| Just mx <- mutable (cont q) x = do y_ <- pull (cont q) y
x_ <- mx
subM x_ y_
return x_
| otherwise = do x_ <- cont q x
y_ <- pull (cont q) y
addM x_ y_
return y_
-- Multiplication
Expand Down Expand Up @@ -431,9 +465,11 @@ instance BLAS1 a => Scalable MS.MVector a where
scale = scaleVector

instance BLAS1 a => AddM MV.MVector a where
addM x y = addVecScaled 1 x y
addM x y = addVecScaled 1 y x
subM x y = addVecScaled (-1) y x
instance BLAS1 a => AddM MS.MVector a where
addM x y = addVecScaled 1 x y
addM x y = addVecScaled 1 y x
subM x y = addVecScaled (-1) y x



Expand Down

0 comments on commit 1d887c7

Please sign in to comment.