Skip to content

Commit

Permalink
dynamicAcc2: handle scalar functions, iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
tmcdonell committed Nov 4, 2014
1 parent 0f835e1 commit a48ceee
Showing 1 changed file with 60 additions and 14 deletions.
74 changes: 60 additions & 14 deletions backend-kit/Data/Array/Accelerate/DynamicAcc2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ data SealedExp = SealedExp
}
deriving Show

data SealedFun = SealedFun
{ funTyIn :: [S.Type]
, funTyOut :: S.Type
, funDyn :: Dynamic
} deriving Show

data SealedAcc = SealedAcc
{ arrTy :: ArrTy
, accDyn :: Dynamic
Expand All @@ -129,6 +135,12 @@ sealExp x = SealedExp ety (toDyn x)
where
ety = expType (undefined :: AST.Exp () a)

sealFun1 :: forall a b. (Elt a, Elt b) => (A.Exp a -> A.Exp b) -> SealedFun
sealFun1 f = SealedFun [ta] tb (toDyn f)
where
ta = expType (undefined :: AST.Exp () a)
tb = expType (undefined :: AST.Exp () b)

sealAcc :: (Arrays a, Typeable a) => Acc a -> SealedAcc
sealAcc x =
dbgtrace (" ** Creating arrTy: "++show ty0++" for "++show x) $
Expand All @@ -149,11 +161,11 @@ downcastE (SealedExp _ d) =
error$"Attempt to unpack SealedExp "++show d
++ ", expecting type Exp "++ show (toDyn (unused::a))

downcastF1 :: forall a b. (Typeable a, Typeable b) => SealedExp -> (A.Exp a -> A.Exp b)
downcastF1 (SealedExp _ d) =
downcastF1 :: forall a b. (Typeable a, Typeable b) => SealedFun -> (A.Exp a -> A.Exp b)
downcastF1 (SealedFun _ _ d) =
case fromDynamic d of
Just e -> e
Nothing -> error $ printf "Attempt to unpack SealedExp %s, expecting type Exp %s"
Nothing -> error $ printf "Attempt to unpack SealedFun %s, expecting type Exp %s"
(show d) (show (toDyn (unused :: a -> b)))

unused :: a
Expand Down Expand Up @@ -542,13 +554,46 @@ resealTup components =

-- | Convert open scalar functions
--
convertOpenFun1 :: S.Fun1 S.Exp -> EnvPack -> SealedExp -> SealedExp
convertOpenFun1 (S.Lam1 (var,ty) body) env x =
convertOpenExp (extendE var ty x env) body
convertFun1 :: S.Fun1 S.Exp -> S.Type -> SealedFun
convertFun1 = convertOpenFun1 emptyEnvPack

convertFun2 :: S.Fun2 S.Exp -> S.Type -> SealedFun
convertFun2 = convertOpenFun2 emptyEnvPack


-- | Convert open scalar functions
--
convertOpenFun1 :: EnvPack -> S.Fun1 S.Exp -> S.Type -> SealedFun
convertOpenFun1 env (S.Lam1 (va,ta) body) tb
| SealedEltTuple (_ :: EltTuple a) <- scalarTypeD ta
, SealedEltTuple (_ :: EltTuple b) <- scalarTypeD tb
= let
f :: A.Exp a -> A.Exp b
f x =
let x' = sealExp x
env' = extendE va ta x' env
in
downcastE $ convertOpenExp env' body
in
SealedFun [ta] tb (toDyn f)


convertOpenFun2 :: EnvPack -> S.Fun2 S.Exp -> S.Type -> SealedFun
convertOpenFun2 env (S.Lam2 (va,ta) (vb,tb) body) tc
| SealedEltTuple (_ :: EltTuple a) <- scalarTypeD ta
, SealedEltTuple (_ :: EltTuple b) <- scalarTypeD tb
, SealedEltTuple (_ :: EltTuple c) <- scalarTypeD tc
= let
f :: A.Exp a -> A.Exp b -> A.Exp c
f x y =
let x' = sealExp x
y' = sealExp y
env' = extendE vb tb y' (extendE va ta x' env) -- TLM: which order to push things in?
in
downcastE $ convertOpenExp env' body
in
SealedFun [ta,tb] tc (toDyn f)

convertOpenFun2 :: S.Fun2 S.Exp -> EnvPack -> SealedExp -> SealedExp -> SealedExp
convertOpenFun2 (S.Lam2 (var1,ty1) (var2,ty2) body) env x1 x2 =
convertOpenExp (extendE var2 ty2 x2 (extendE var1 ty1 x1 env)) body

-- | Convert a closed scalar expression
--
Expand All @@ -565,8 +610,8 @@ convertOpenExp ep@(EnvPack envE envA mp) ex
$ dbgtrace(printf " @ Converted exp result: %s " (show result))
$ result
where
cvtF1 :: S.Fun1 S.Exp -> SealedExp -> SealedExp
cvtF1 f = convertOpenFun1 f ep
cvtF1 :: S.Fun1 S.Exp -> S.Type -> SealedFun
cvtF1 = convertOpenFun1 ep

cvtE :: S.Exp -> SealedExp
cvtE e =
Expand Down Expand Up @@ -704,12 +749,13 @@ convertOpenExp ep@(EnvPack envE envA mp) ex
--
ewhile :: S.Fun1 S.Exp -> S.Fun1 S.Exp -> S.Exp -> SealedExp
ewhile p f e
| SealedEltTuple (_ :: EltTuple e) <- scalarTypeD (S.recoverExpType typeEnv e)
| te <- S.recoverExpType typeEnv e
, SealedEltTuple (_ :: EltTuple e) <- scalarTypeD te
= let p' :: Exp e -> Exp Bool
p' = downcastF1 (cvtF1 p undefined)
p' = downcastF1 (cvtF1 p S.TBool)

f' :: Exp e -> Exp e
f' = downcastF1 (cvtF1 f undefined)
f' = downcastF1 (cvtF1 f te)

e' :: Exp e
e' = downcastE (cvtE e)
Expand Down

0 comments on commit a48ceee

Please sign in to comment.