Skip to content

Commit

Permalink
Reorder functions and add missing pragmas
Browse files Browse the repository at this point in the history
  • Loading branch information
Shimuuar committed Aug 31, 2012
1 parent 6463de2 commit b19ab16
Showing 1 changed file with 34 additions and 35 deletions.
69 changes: 34 additions & 35 deletions Numeric/BLAS/Expression.hs
Expand Up @@ -47,6 +47,8 @@ import qualified Data.Matrix.Symmetric as MatS
import qualified Data.Matrix.Symmetric.Mutable as MMatS

import Numeric.BLAS.Mutable
import Debug.Trace




Expand Down Expand Up @@ -182,6 +184,27 @@ data Expr m a where



-- | Evaluate expression. If expression is known statically which is
-- the case if it was built using combinators from 'Numeric.BLAS' it
-- will evaluated at compile time.
eval :: Freeze m a => Expr m a -> m a
{-# INLINE[1] eval #-}
eval x = runST $ do
-- trace (dumpExpressionTree x) $ return ()
unsafeFreeze =<< evalST' () x

-- Rewrite rules:
--
-- Eliminate constructors and evals
{-# RULES "BLAS:Lit/eval" forall e. Lit (eval e) = e #-}
-- Forcefully inline evalST'
{-# RULES "BLAS:evalST" evalST' () = evalST evalST' #-}






-- Continuation type
type Cont s = forall v a. Expr v a -> ST s (Mutable v s a)

Expand Down Expand Up @@ -341,23 +364,6 @@ evalST' :: () -> Expr m a -> ST s (Mutable m s a)
evalST' _ = evalST evalST'


-- | Evaluate expression. If expression is known statically which is
-- the case if it was built using combinators from 'Numeric.BLAS' it
-- will evaluated at compile time.
eval :: Freeze m a => Expr m a -> m a
{-# INLINE[1] eval #-}
eval x = runST $ do
-- trace (dumpExpressionTree x) $ return ()
unsafeFreeze =<< evalST' () x


-- Rewrite rules:
--
-- Eliminate constructors and evals
{-# RULES "BLAS:Lit/eval" forall e. Lit (eval e) = e #-}
-- Forcefully inline evalST'
{-# RULES "BLAS:evalST" evalST' () = evalST evalST' #-}



----------------------------------------------------------------
Expand All @@ -380,6 +386,7 @@ evalVVT cont a v u = do
inplaceEvalVVT
:: ( BLAS2 a, MVectorBLAS (Mutable v) )
=> Cont s -> a -> Expr v a -> Expr v a -> MMatD.MMatrix s a -> ST s (MMatD.MMatrix s a)
{-# INLINE inplaceEvalVVT #-}
inplaceEvalVVT cont a v u m_ = do
v_ <- pull cont v
u_ <- pull cont u
Expand All @@ -403,6 +410,7 @@ evalVVH cont a v u = do
inplaceEvalVVH
:: ( BLAS2 a, MVectorBLAS (Mutable v) )
=> Cont s -> a -> Expr v a -> Expr v a -> MMatD.MMatrix s a -> ST s (MMatD.MMatrix s a)
{-# INLINE inplaceEvalVVH #-}
inplaceEvalVVH cont a v u m_ = do
v_ <- pull cont v
u_ <- pull cont u
Expand Down Expand Up @@ -609,22 +617,13 @@ instance BLAS1 a => AddM MMatD.MMatrix a where
--
----------------------------------------------------------------

{-
dumpVec :: (MVectorBLAS v, Show a, MS.Storable a) => v s a -> IO ()
dumpVec v = do
print $ V.unsafeFromForeignPtr (blasLength v) (blasStride v) (blasFPtr v)
boogie :: v s a -> v RealWorld a
boogie = unsafeCoerce
dumpExpressionTree :: Expr m a -> String
dumpExpressionTree (Lit _) = "_"
dumpExpressionTree (Add x y) = "(" ++ dumpExpressionTree x ++ ") + (" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (Scale _ y) = "S * ?(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (VecT v u) = "==="
dumpExpressionTree (VecH v u) = "==="
dumpExpressionTree (MulMV x y) = "M(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (MulTMV _ x y) = "TM(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (MulMM _ x _ y) = "M(" ++ dumpExpressionTree x ++ ") * M(" ++ dumpExpressionTree y ++ ")"
-}
dumpExpressionTree (Lit _) = "_"
dumpExpressionTree (Add _ x y) = "(" ++ dumpExpressionTree x ++ ") + (" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (Scale _ _ y) = "S * ?(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (VecT _ v u) = "==="
dumpExpressionTree (VecH _ v u) = "==="
dumpExpressionTree (MulMV _ x y) = "M(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (MulTMV _ _ x y ) = "TM(" ++ dumpExpressionTree x ++ ") * V(" ++ dumpExpressionTree y ++ ")"
dumpExpressionTree (MulMM _ _ x _ y) = "M(" ++ dumpExpressionTree x ++ ") * M(" ++ dumpExpressionTree y ++ ")"

0 comments on commit b19ab16

Please sign in to comment.