Permalink
Browse files

Refactor type classes for addition

and multiplication by scalar
  • Loading branch information...
1 parent 7eeae1d commit 328012952de5f0f4ce798e1e5ed732d4d5182d71 @Shimuuar committed Aug 31, 2012
Showing with 70 additions and 76 deletions.
  1. +8 −16 Numeric/BLAS.hs
  2. +62 −60 Numeric/BLAS/Expression.hs
View
@@ -13,8 +13,7 @@
-- BLAS operations for immutable vectors and matrices.
module Numeric.BLAS (
-- * Type class based API
- Add(..)
- , Scale(..)
+ LinSpace(..)
, Mul(..)
, trans
, conj
@@ -60,25 +59,18 @@ import Numeric.BLAS.Mutable (MVectorBLAS)
----------------------------------------------------------------
-- | Addition for vectors and matrices.
-class Add a where
- (.+.) :: a -> a -> a
- (.-.) :: a -> a -> a
+class LinSpace m a where
+ (.+.) :: m a -> m a -> m a
+ (.-.) :: m a -> m a -> m a
+ ( *.) :: a -> m a -> m a
-instance (AddM (Mutable m) a, Freeze m a, Num a, Scalable (Mutable m) a) => Add (m a) where
+instance (LinSpaceM (Mutable m) a, Freeze m a, Num a) => LinSpace m a where
x .+. y = eval $ Add () (Lit x) (Lit y)
{-# INLINE (.+.) #-}
x .-. y = eval $ Sub () (Lit x) (Lit y)
{-# INLINE (.-.) #-}
-
-
--- | Multiplication by scalar.
-class Scale v a where
- (*.) :: a -> v a -> v a
-
-instance (Num a, Scalable (Mutable m) a, Freeze m a) => Scale m a where
- α *. v = eval $ Scale () α (Lit v)
- {-# INLINE (*.) #-}
-
+ α *. x = eval $ Scale () α (Lit x)
+ {-# INLINE (*.) #-}
-- | Very overloaded operator for matrix and vector multiplication.
class Mul v u where
View
@@ -23,8 +23,7 @@ module Numeric.BLAS.Expression (
-- * Supporting type classes
, Clonable(..)
, Freeze(..)
- , Scalable(..)
- , AddM(..)
+ , LinSpaceM(..)
) where
import Control.Monad
@@ -51,6 +50,9 @@ import Data.Matrix.Symmetric.Mutable (MSymmetricRaw(..),Conjugate,IsSymmetric,Is
import Numeric.BLAS.Mutable
-- import Debug.Trace
+import Data.Vector.Storable.Internal (updPtr)
+import Foreign.Marshal.Array (advancePtr)
+
@@ -115,23 +117,22 @@ class Clonable (Mutable m) a => Freeze m a where
unsafeThaw :: m a -> ST s (Mutable m s a)
--- | Scale every element of data structure in place
-class Scalable m a where
- scale :: a -> m s a -> ST s ()
-
-
-- | Elementwise addition and subtraction for mutable data. First
-- arguments of 'addM' and 'subM' are modified in place.
---
--- > x ← x + y
--- > x ← x - y
-class AddM m a where
+class Num a => LinSpaceM m a where
+ -- | > y ← x + y
addM :: m s a -- ^ /x/
-> m s a -- ^ /y/
-> ST s ()
- subM :: m s a -- ^ /x/
- -> m s a -- ^ /y/
- -> ST s ()
+ addM = addScaleM 1
+ {-# INLINE addM #-}
+ -- | > x ← α · x
+ scaleM :: a -> m s a -> ST s ()
+ -- | > y <- α·x + y
+ addScaleM :: a -- ^ /α/
+ -> m s a -- ^ /x/
+ -> m s a -- ^ /y/
+ -> ST s ()
@@ -144,13 +145,13 @@ data Expr m a where
-- Literal value. Could not be altered.
Lit :: Freeze m a => m a -> Expr m a
-- Addition
- Add :: (Freeze m a, AddM (Mutable m) a)
+ Add :: (Freeze m a, LinSpaceM (Mutable m) a)
=> () -> Expr m a -> Expr m a -> Expr m a
-- Subtraction
- Sub :: (Freeze m a, AddM (Mutable m) a, Num a, Scalable (Mutable m) a)
+ Sub :: (Freeze m a, LinSpaceM (Mutable m) a)
=> () -> Expr m a -> Expr m a -> Expr m a
-- Scalar-X multiplication
- Scale :: (Freeze m a, Num a, Scalable (Mutable m) a)
+ Scale :: (Freeze m a, LinSpaceM (Mutable m) a)
=> () -> a -> Expr m a -> Expr m a
-- vector x transposed vector => matrix
VecT :: (Freeze v a, MVectorBLAS (Mutable v), BLAS2 a)
@@ -163,17 +164,17 @@ data Expr m a where
, MVectorBLAS (Mutable v), G.Vector v a
, BLAS2 a
, Freeze mat a, Freeze v a
- , Scalable (Mutable mat) a
- , Scalable (Mutable v ) a
+ , LinSpaceM (Mutable mat) a
+ , LinSpaceM (Mutable v ) a
)
=> () -> Expr mat a -> Expr v a -> Expr v a
-- Transformed matrix-vector multiplication
MulTMV :: ( MultTMV (Mutable mat) a
, MVectorBLAS (Mutable v), G.Vector v a
, BLAS2 a
, Freeze mat a, Freeze v a
- , Scalable (Mutable mat) a
- , Scalable (Mutable v ) a
+ , LinSpaceM (Mutable mat) a
+ , LinSpaceM (Mutable v ) a
)
=> () -> Trans -> Expr mat a -> Expr v a -> Expr v a
-- Matrix-matrix multiplication for dense matrices
@@ -294,15 +295,6 @@ evalST cont (Add q x y)
y_ <- pull (cont q) y
addM x_ y_
return x_
-evalST cont (Sub q x y)
- | Just mx <- mutable (cont q) x = do y_ <- pull (cont q) y
- x_ <- mx
- subM x_ y_
- return x_
- | otherwise = do x_ <- cont q x
- y_ <- pull (cont q) y
- addM x_ y_
- return y_
-- Multiplication
-- ==============
--
@@ -332,7 +324,7 @@ evalST cont ( MulHerMM q sd ma mb) = evalHerMM (cont q) sd 1 ma mb
-- * Scale data type in place
evalST cont (Scale q a x) = do
m <- cont q x
- scale a m
+ scaleM a m
return m
-- * Copy literal so it could be mutated in subsequent operations. It
-- doesn't incur any unnecessary cost because in places where data
@@ -611,42 +603,52 @@ instance (Storable a, Conjugate a) => Freeze (SymmetricRaw IsHermitian) a where
unsafeThaw = Mat.unsafeThaw
-instance BLAS1 a => Scalable MV.MVector a where
- scale = scaleVector
-instance BLAS1 a => Scalable MS.MVector a where
- scale = scaleVector
-instance BLAS1 a => Scalable MMatrix a where
- scale α m = do
- forM_ [0 .. MMat.cols m - 1] $ \i -> do
+instance BLAS1 a => LinSpaceM MV.MVector a where
+ scaleM = scaleVector
+ {-# INLINE scaleM #-}
+ addScaleM = addVecScaled
+ {-# INLINE addScaleM #-}
+instance BLAS1 a => LinSpaceM MS.MVector a where
+ scaleM = scaleVector
+ {-# INLINE scaleM #-}
+ addScaleM = addVecScaled
+ {-# INLINE addScaleM #-}
+instance BLAS1 a => LinSpaceM MMatrix a where
+ scaleM α m = do
+ forM_ [0 .. MMat.cols m - 1] $ \i ->
scaleVector α $ MMatD.unsafeGetCol m i
-instance (BLAS1 a, MMat.IsMMatrix (MSymmetricRaw tag) a) => Scalable (MSymmetricRaw tag) a where
- scale α m = do
- forM_ [0 .. n-1] $ \i ->
- forM_ [i .. n-1] $ \j -> do
- MMat.write m (i,j) . (*α) =<< MMat.read m (i,j)
- where
- n = MMat.cols m
-
-instance BLAS1 a => AddM MV.MVector a where
- addM x y = addVecScaled 1 y x
- subM x y = addVecScaled (-1) y x
-instance BLAS1 a => AddM MS.MVector a where
- addM x y = addVecScaled 1 y x
- subM x y = addVecScaled (-1) y x
-instance BLAS1 a => AddM MMatrix a where
- addM x y = do
- forM_ [0 .. MMat.cols x - 1] $ \i -> do
- addVecScaled 1 (MMatD.unsafeGetCol y i) (MMatD.unsafeGetCol x i)
- subM x y = do
- forM_ [0 .. MMat.cols x - 1] $ \i -> do
- addVecScaled (-1) (MMatD.unsafeGetCol y i) (MMatD.unsafeGetCol x i)
+ {-# INLINE scaleM #-}
+ addScaleM α x y
+ | MMat.shape x /= MMat.shape y = error "QWE"
+ | otherwise =
+ forM_ [0 .. MMat.cols x - 1] $ \i ->
+ addVecScaled α (MMatD.unsafeGetCol x i) (MMatD.unsafeGetCol y i)
+ {-# INLINE addM #-}
+instance (BLAS1 a, MMat.IsMMatrix (MSymmetricRaw tag) a) => LinSpaceM (MSymmetricRaw tag) a where
+ scaleM α m = do
+ forM_ [0 .. MMat.cols m - 1] $ \i ->
+ scaleVector α $ symColumn m i
+ {-# INLINE scaleM #-}
+ addScaleM α x y
+ | MMat.shape x /= MMat.shape y = error "QWE"
+ | otherwise =
+ forM_ [0 .. MMat.cols x - 1] $ \i ->
+ addVecScaled α (symColumn x i) (symColumn y i)
+
+
+
+symColumn :: (Storable a) => MSymmetricRaw tag s a -> Int -> MV.MVector s a
+{-# INLINE symColumn #-}
+symColumn (MSymmetricRaw _ lda fp) i =
+ MV.unsafeFromForeignPtr (i+1) 1 (updPtr (`advancePtr` (lda*i)) fp)
----------------------------------------------------------------
--
----------------------------------------------------------------
+{-
dumpExpressionTree :: Expr m a -> String
dumpExpressionTree (Lit _) = "_"
dumpExpressionTree (Add _ x y) = "(" ++ dumpExpressionTree x ++ ") + (" ++ dumpExpressionTree y ++ ")"
@@ -656,4 +658,4 @@ dumpExpressionTree (VecH _ v u) = "==="
dumpExpressionTree (MulMV _ x y) = "M(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (MulTMV _ _ x y ) = "TM(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (MulMM _ _ x _ y) = "M(" ++ dumpExpressionTree x ++ ") * M(" ++ dumpExpressionTree y ++ ")"
-
+-}

0 comments on commit 3280129

Please sign in to comment.