Skip to content

Commit

Permalink
Pairs as representation types now work with the type trafo
Browse files Browse the repository at this point in the history
Ignore-this: 8e7ccfde83e3a1239a2ba21072298736

darcs-hash:20090729130619-6295e-e2f974440a4007363e8bf0c60a2ae3c709cef58a.gz
  • Loading branch information
mchakravarty committed Jul 29, 2009
1 parent 277a650 commit 0de33fd
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 156 deletions.
28 changes: 22 additions & 6 deletions Data/Array/Accelerate/AST.hs
Expand Up @@ -61,7 +61,7 @@ module Data.Array.Accelerate.AST (
-- friends
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Representation
import Data.Array.Accelerate.Array.Sugar (ElemRepr)
import Data.Array.Accelerate.Array.Sugar (ElemRepr, ElemRepr')


-- |Abstract syntax of array computations
Expand Down Expand Up @@ -234,26 +234,42 @@ data OpenExp env t where
Const :: TupleType t -> t -> OpenExp env t

-- |Tuples
Pair :: OpenExp env s -> OpenExp env t -> OpenExp env (s, t)
Fst :: OpenExp env (s, t) -> OpenExp env s
Snd :: OpenExp env (s, t) -> OpenExp env t
Pair :: s {- dummy to fix the type variable -}
-> t {- dummy to fix the type variable -}
-> OpenExp env (ElemRepr s)
-> OpenExp env (ElemRepr t)
-> OpenExp env (ElemRepr (s, t))
Fst :: s {- dummy to fix the type variable -}
-> t {- dummy to fix the type variable -}
-> OpenExp env (ElemRepr (s, t))
-> OpenExp env (ElemRepr s)
Snd :: s {- dummy to fix the type variable -}
-> t {- dummy to fix the type variable -}
-> OpenExp env (ElemRepr (s, t))
-> OpenExp env (ElemRepr t)

-- |Conditional expression (non-strict in 2nd and 3rd argument)
Cond :: OpenExp env ((), Bool) -> OpenExp env t -> OpenExp env t
Cond :: OpenExp env (ElemRepr Bool)
-> OpenExp env t
-> OpenExp env t
-> OpenExp env t

-- |Primitive constants
PrimConst :: PrimConst t -> OpenExp env (ElemRepr t)

-- |Primitive scalar operations
PrimApp :: PrimFun (a -> r) -> OpenExp env a -> OpenExp env r
PrimApp :: PrimFun (a -> r)
-> OpenExp env (ElemRepr a)
-> OpenExp env (ElemRepr r)

-- |Project a single scalar from an array
IndexScalar :: Arr dim t -> OpenExp env dim -> OpenExp env t

-- |Array shape
Shape :: Arr dim e -> OpenExp env dim

-- Cvt :: ElemRepr

-- |Primitive GPU constants
--
data PrimConst ty where
Expand Down
7 changes: 0 additions & 7 deletions Data/Array/Accelerate/Language.hs
Expand Up @@ -99,12 +99,6 @@ arr ! ix = wrapComp $
zip :: forall dim a b. (Ix dim, Elem a, Elem b)
=> Arr dim a -> Arr dim b -> AP (Arr dim (a, b))
zip = zipWith (\x y -> x `Pair` y)
{-
zip arr1 arr2
= wrapComp $
mkZip (undefined::dim) (undefined::a) (undefined::b)
(convertArr arr1) (convertArr arr2)
-}

map :: (Ix dim, Elem a, Elem b)
=> (Exp a -> Exp b) -> Arr dim a -> AP (Arr dim b)
Expand All @@ -113,7 +107,6 @@ map f arr = wrapComp $ Map (convertFun1 f) (convertArr arr)
zipWith :: (Ix dim, Elem a, Elem b, Elem c)
=> (Exp a -> Exp b -> Exp c) -> Arr dim a -> Arr dim b -> AP (Arr dim c)
zipWith f arr1 arr2
-- = zip arr1 arr2 >>= map (\xy -> f (Fst xy) (Snd xy))
= wrapComp $ ZipWith (convertFun2 f) (convertArr arr1) (convertArr arr2)

filter :: Elem a => (Exp a -> Exp Bool) -> Arr DIM1 a -> AP (Arr DIM1 a)
Expand Down
15 changes: 8 additions & 7 deletions Data/Array/Accelerate/Pretty.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE GADTs, FlexibleInstances, PatternGuards, TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |Embedded array processing language: pretty printing
--
Expand Down Expand Up @@ -121,18 +122,18 @@ prettyFun fun =
-- * Apply the wrapping combinator (1st argument) to any compound expressions.
--
prettyExp :: (Doc -> Doc) -> OpenExp env t -> Doc
--prettyExp wrap (Arg ty) = wrap $ text "arg ::" <+> prettyAnyType ty
prettyExp wrap (Var _ idx) = text $ "a" ++ show (count idx)
where
count :: Idx env t -> Int
count ZeroIdx = 0
count (SuccIdx idx) = 1 + count idx
prettyExp _ (Const ty v) = text $ runTupleShow ty v
prettyExp _ (Pair e1 e2) = prettyTuple (Pair e1 e2)
prettyExp wrap (Fst e) = wrap $ text "fst" <+> prettyExp parens e
prettyExp wrap (Snd e) = wrap $ text "snd" <+> prettyExp parens e
prettyExp _ e@(Pair _ _ _ _) = prettyTuple e
prettyExp wrap (Fst _ _ e) = wrap $ text "fst" <+> prettyExp parens e
prettyExp wrap (Snd _ _ e) = wrap $ text "snd" <+> prettyExp parens e
prettyExp wrap (Cond c t e)
= wrap $ sep [prettyExp parens c <+> char '?', prettyExp noParens (Pair t e)]
= wrap $ sep [prettyExp parens c <+> char '?',
parens (prettyExp noParens t <> comma <+> prettyExp noParens e)]
prettyExp _ (PrimConst a) = prettyConst a
prettyExp wrap (PrimApp p a) = wrap $ prettyPrim p <+> prettyExp parens a
prettyExp wrap (IndexScalar a i)
Expand All @@ -147,8 +148,8 @@ prettyTuple e = parens $ sep (map (<> comma) (init es) ++ [last es])
es = collect e
--
collect :: OpenExp env t -> [Doc]
collect (Pair e1 e2) = collect e1 ++ collect e2
collect e = [prettyExp noParens e]
collect (Pair _ _ e1 e2) = collect e1 ++ collect e2
collect e = [prettyExp noParens e]

-- |Pretty print a primitive constant
--
Expand Down
151 changes: 15 additions & 136 deletions Data/Array/Accelerate/Smart.hs
Expand Up @@ -58,75 +58,6 @@ import qualified Data.Array.Accelerate.AST as AST
import qualified Data.Array.Accelerate.Array.Representation as AST


-- |Conversion of surface to internal types
-- ----------------------------------------

-- Conversion of type representations
--
{-
convertIntegralType :: IntegralType a -> IntegralType (ElemRepr a)
convertIntegralType ty@(TypeInt _) = ty
convertIntegralType ty@(TypeInt8 _) = ty
convertIntegralType ty@(TypeInt16 _) = ty
convertIntegralType ty@(TypeInt32 _) = ty
convertIntegralType ty@(TypeInt64 _) = ty
convertIntegralType ty@(TypeWord _) = ty
convertIntegralType ty@(TypeWord8 _) = ty
convertIntegralType ty@(TypeWord16 _) = ty
convertIntegralType ty@(TypeWord32 _) = ty
convertIntegralType ty@(TypeWord64 _) = ty
convertIntegralType ty@(TypeCShort _) = ty
convertIntegralType ty@(TypeCUShort _) = ty
convertIntegralType ty@(TypeCInt _) = ty
convertIntegralType ty@(TypeCUInt _) = ty
convertIntegralType ty@(TypeCLong _) = ty
convertIntegralType ty@(TypeCULong _) = ty
convertIntegralType ty@(TypeCLLong _) = ty
convertIntegralType ty@(TypeCULLong _) = ty
convertFloatingType :: FloatingType a -> FloatingType (ElemRepr a)
convertFloatingType ty@(TypeFloat _) = ty
convertFloatingType ty@(TypeDouble _) = ty
convertFloatingType ty@(TypeCFloat _) = ty
convertFloatingType ty@(TypeCDouble _) = ty
convertNonNumType :: NonNumType a -> NonNumType (ElemRepr a)
convertNonNumType ty@(TypeBool _) = ty
convertNonNumType ty@(TypeChar _) = ty
convertNonNumType ty@(TypeCChar _) = ty
convertNonNumType ty@(TypeCSChar _) = ty
convertNonNumType ty@(TypeCUChar _) = ty
convertNumType :: NumType a -> NumType (ElemRepr a)
convertNumType (IntegralNumType ty) = IntegralNumType $ convertIntegralType ty
convertNumType (FloatingNumType ty) = FloatingNumType $ convertFloatingType ty
convertBoundedType :: BoundedType a -> BoundedType (ElemRepr a)
convertBoundedType (IntegralBoundedType ty)
= IntegralBoundedType (convertIntegralType ty)
convertBoundedType (NonNumBoundedType ty)
= NonNumBoundedType (convertNonNumType ty)
convertScalarType :: ScalarType a -> ScalarType (ElemRepr a)
convertScalarType (NumScalarType ty) = NumScalarType $ convertNumType ty
convertScalarType (NonNumScalarType ty)
= NonNumScalarType $ convertNonNumType ty
-}
{-
-- |Conversion of slice indices
--
convertSlice :: forall sl. SliceIx sl
=> sl -> SliceIndex (ToShapeRepr (Slice sl))
(ToShapeRepr (CoSlice sl))
(ToShapeRepr (SliceDim sl))
convertSlice = cvt . toShapeRepr
where
cvt :: ToShapeRepr sl -> SliceIndex (ToShapeRepr (Slice sl))
(ToShapeRepr (CoSlice sl))
(ToShapeRepr (SliceDim sl))
cvt () = SliceNil
-}

-- |HOAS AST
-- ---------

Expand Down Expand Up @@ -196,68 +127,27 @@ convertOpenExp :: forall t env.
convertOpenExp lyt = cvt
where
cvt :: forall t'. Exp t' -> AST.OpenExp env (ElemRepr t')
cvt (Tag i) = AST.Var (elemType (undefined::t')) (prjIdx i lyt)
cvt (Const v) = AST.Const (elemType (undefined::t')) (fromElem v)
-- FIXME:
-- cvt (Pair e1 e2) = AST.Pair (cvt e1) (cvt e2)
-- cvt (Fst e) = AST.Fst (cvt e)
-- cvt (Snd e) = AST.Snd (cvt e)
cvt (Cond e1 e2 e3) = AST.Cond (cvt e1) (cvt e2) (cvt e3)
-- cvt (PrimConst c) = AST.PrimConst (convertPrimConst c)
cvt (PrimConst c) = AST.PrimConst c
-- cvt (PrimApp p e) = AST.PrimApp (convertPrimFun p) (cvt e)
cvt (IndexScalar a e) = AST.IndexScalar (convertArr a) (cvt e)
cvt (Shape a) = AST.Shape (convertArr a)
cvt (Tag i) = AST.Var (elemType (undefined::t')) (prjIdx i lyt)
cvt (Const v) = AST.Const (elemType (undefined::t')) (fromElem v)
cvt (Pair (e1::Exp t1)
(e2::Exp t2)) = AST.Pair (undefined::t1)
(undefined::t2)
(cvt e1) (cvt e2)
cvt (Fst (e::Exp (t', t2)))
= AST.Fst (undefined::t') (undefined::t2) (cvt e)
cvt (Snd (e::Exp (t1, t')))
= AST.Snd (undefined::t1) (undefined::t') (cvt e)
cvt (Cond e1 e2 e3) = AST.Cond (cvt e1) (cvt e2) (cvt e3)
cvt (PrimConst c) = AST.PrimConst c
cvt (PrimApp p e) = AST.PrimApp p (cvt e)
cvt (IndexScalar a e) = AST.IndexScalar (convertArr a) (cvt e)
cvt (Shape a) = AST.Shape (convertArr a)

-- |Convert a closed expression
--
convertExp :: Exp t -> AST.Exp (ElemRepr t)
convertExp = convertOpenExp EmptyLayout

{-
-- |Convert a primitive constant
--
convertPrimConst :: PrimConst a -> PrimConst (ElemRepr a)
convertPrimConst (PrimMinBound ty) = PrimMinBound $ convertBoundedType ty
convertPrimConst (PrimMaxBound ty) = PrimMinBound $ convertBoundedType ty
convertPrimConst (PrimPi ty) = PrimPi $ convertFloatingType ty
-}

{-
-- |Convert a primitive operation
--
convertPrimFun :: PrimFun (a -> b) -> PrimFun (ElemRepr a -> ElemRepr b)
convertPrimFun (PrimAdd ty) = PrimAdd (convertNumType ty)
convertPrimFun (PrimSub ty) = PrimSub (convertNumType ty)
convertPrimFun (PrimMul ty) = PrimMul (convertNumType ty)
convertPrimFun (PrimNeg ty) = PrimNeg (convertNumType ty)
convertPrimFun (PrimAbs ty) = PrimAbs (convertNumType ty)
convertPrimFun (PrimSig ty) = PrimSig (convertNumType ty)
convertPrimFun (PrimQuot ty) = PrimQuot (convertIntegralType ty)
convertPrimFun (PrimRem ty) = PrimRem (convertIntegralType ty)
convertPrimFun (PrimIDiv ty) = PrimIDiv (convertIntegralType ty)
convertPrimFun (PrimMod ty) = PrimMod (convertIntegralType ty)
convertPrimFun (PrimBAnd ty) = PrimBAnd (convertIntegralType ty)
convertPrimFun (PrimBOr ty) = PrimBOr (convertIntegralType ty)
convertPrimFun (PrimBXor ty) = PrimBXor (convertIntegralType ty)
convertPrimFun (PrimBNot ty) = PrimBNot (convertIntegralType ty)
convertPrimFun (PrimFDiv ty) = PrimFDiv (convertFloatingType ty)
convertPrimFun (PrimRecip ty) = PrimRecip (convertFloatingType ty)
convertPrimFun (PrimLt ty) = PrimLt (convertScalarType ty)
convertPrimFun (PrimGt ty) = PrimGt (convertScalarType ty)
convertPrimFun (PrimLtEq ty) = PrimLtEq (convertScalarType ty)
convertPrimFun (PrimGtEq ty) = PrimGtEq (convertScalarType ty)
convertPrimFun (PrimEq ty) = PrimEq (convertScalarType ty)
convertPrimFun (PrimNEq ty) = PrimNEq (convertScalarType ty)
convertPrimFun (PrimMax ty) = PrimMax (convertScalarType ty)
convertPrimFun (PrimMin ty) = PrimMin (convertScalarType ty)
convertPrimFun PrimLAnd = PrimLAnd
convertPrimFun PrimLOr = PrimLOr
convertPrimFun PrimLNot = PrimLNot
convertPrimFun PrimOrd = PrimOrd
convertPrimFun PrimChr = PrimChr
convertPrimFun PrimRoundFloatInt = PrimRoundFloatInt
-}
-- |Convert surface array representation to the internal one
--
convertArray :: forall dim e.
Expand Down Expand Up @@ -394,17 +284,6 @@ mkReplicate :: forall slix e. (SliceIx slix, Elem e)
mkReplicate slix _ e arr
= Replicate (convertSliceIndex slix (sliceIndex (undefined::slix))) e arr

{-
mkZip :: (Ix dim, Elem a, Elem b)
=> dim {- dummy to fix the type variable -}
-> a {- dummy to fix the type variable -}
-> b {- dummy to fix the type variable -}
-> AST.Arr (ElemRepr dim) (ElemRepr a)
-> AST.Arr (ElemRepr dim) (ElemRepr' b)
-> Comp (AST.Arr (ElemRepr dim) (ElemRepr (a, b)))
mkZip _ _ _ arr1 arr2 = Zip arr1 arr2
-}


-- |Smart constructors to construct HOAS AST expressions
-- -----------------------------------------------------
Expand Down

0 comments on commit 0de33fd

Please sign in to comment.