diff --git a/Numeric/BLAS/Expression.hs b/Numeric/BLAS/Expression.hs index d61d13d..b333a9d 100644 --- a/Numeric/BLAS/Expression.hs +++ b/Numeric/BLAS/Expression.hs @@ -47,6 +47,8 @@ import qualified Data.Matrix.Symmetric as MatS import qualified Data.Matrix.Symmetric.Mutable as MMatS import Numeric.BLAS.Mutable +import Debug.Trace + @@ -182,6 +184,27 @@ data Expr m a where +-- | Evaluate expression. If expression is known statically which is +-- the case if it was built using combinators from 'Numeric.BLAS' it +-- will evaluated at compile time. +eval :: Freeze m a => Expr m a -> m a +{-# INLINE[1] eval #-} +eval x = runST $ do + -- trace (dumpExpressionTree x) $ return () + unsafeFreeze =<< evalST' () x + +-- Rewrite rules: +-- +-- Eliminate constructors and evals +{-# RULES "BLAS:Lit/eval" forall e. Lit (eval e) = e #-} +-- Forcefully inline evalST' +{-# RULES "BLAS:evalST" evalST' () = evalST evalST' #-} + + + + + + -- Continuation type type Cont s = forall v a. Expr v a -> ST s (Mutable v s a) @@ -341,23 +364,6 @@ evalST' :: () -> Expr m a -> ST s (Mutable m s a) evalST' _ = evalST evalST' --- | Evaluate expression. If expression is known statically which is --- the case if it was built using combinators from 'Numeric.BLAS' it --- will evaluated at compile time. -eval :: Freeze m a => Expr m a -> m a -{-# INLINE[1] eval #-} -eval x = runST $ do - -- trace (dumpExpressionTree x) $ return () - unsafeFreeze =<< evalST' () x - - --- Rewrite rules: --- --- Eliminate constructors and evals -{-# RULES "BLAS:Lit/eval" forall e. Lit (eval e) = e #-} --- Forcefully inline evalST' -{-# RULES "BLAS:evalST" evalST' () = evalST evalST' #-} - ---------------------------------------------------------------- @@ -380,6 +386,7 @@ evalVVT cont a v u = do inplaceEvalVVT :: ( BLAS2 a, MVectorBLAS (Mutable v) ) => Cont s -> a -> Expr v a -> Expr v a -> MMatD.MMatrix s a -> ST s (MMatD.MMatrix s a) +{-# INLINE inplaceEvalVVT #-} inplaceEvalVVT cont a v u m_ = do v_ <- pull cont v u_ <- pull cont u @@ -403,6 +410,7 @@ evalVVH cont a v u = do inplaceEvalVVH :: ( BLAS2 a, MVectorBLAS (Mutable v) ) => Cont s -> a -> Expr v a -> Expr v a -> MMatD.MMatrix s a -> ST s (MMatD.MMatrix s a) +{-# INLINE inplaceEvalVVH #-} inplaceEvalVVH cont a v u m_ = do v_ <- pull cont v u_ <- pull cont u @@ -609,22 +617,13 @@ instance BLAS1 a => AddM MMatD.MMatrix a where -- ---------------------------------------------------------------- -{- -dumpVec :: (MVectorBLAS v, Show a, MS.Storable a) => v s a -> IO () -dumpVec v = do - print $ V.unsafeFromForeignPtr (blasLength v) (blasStride v) (blasFPtr v) - - -boogie :: v s a -> v RealWorld a -boogie = unsafeCoerce - dumpExpressionTree :: Expr m a -> String -dumpExpressionTree (Lit _) = "_" -dumpExpressionTree (Add x y) = "(" ++ dumpExpressionTree x ++ ") + (" ++ dumpExpressionTree y ++ ")" -dumpExpressionTree (Scale _ y) = "S * ?(" ++ dumpExpressionTree y ++ ")" -dumpExpressionTree (VecT v u) = "===" -dumpExpressionTree (VecH v u) = "===" -dumpExpressionTree (MulMV x y) = "M(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")" -dumpExpressionTree (MulTMV _ x y) = "TM(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")" -dumpExpressionTree (MulMM _ x _ y) = "M(" ++ dumpExpressionTree x ++ ") * M(" ++ dumpExpressionTree y ++ ")" --} +dumpExpressionTree (Lit _) = "_" +dumpExpressionTree (Add _ x y) = "(" ++ dumpExpressionTree x ++ ") + (" ++ dumpExpressionTree y ++ ")" +dumpExpressionTree (Scale _ _ y) = "S * ?(" ++ dumpExpressionTree y ++ ")" +dumpExpressionTree (VecT _ v u) = "===" +dumpExpressionTree (VecH _ v u) = "===" +dumpExpressionTree (MulMV _ x y) = "M(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")" +dumpExpressionTree (MulTMV _ _ x y ) = "TM(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")" +dumpExpressionTree (MulMM _ _ x _ y) = "M(" ++ dumpExpressionTree x ++ ") * M(" ++ dumpExpressionTree y ++ ")" +