77{-# LANGUAGE TypeApplications #-}
88{-# LANGUAGE ViewPatterns #-}
99{-# LANGUAGE KindSignatures #-}
10+ {-# LANGUAGE TypeFamilies #-}
11+ {-# LANGUAGE FlexibleContexts #-}
1012module ArrayFire.Data where
1113
1214import Control.Exception
1315import Control.Monad
1416
17+ import Data.Complex
18+ import Data.Proxy
19+ import Data.Word
20+
1521import Foreign.C.String
1622import Foreign.C.Types
1723import Foreign.Marshal hiding (void )
@@ -20,24 +26,17 @@ import Foreign.ForeignPtr
2026import Foreign.Ptr
2127import Foreign.Storable
2228
23- import Data.Proxy
29+ import GHC.Int
30+ import GHC.TypeLits
2431
2532import ArrayFire.Internal.Array
2633
2734import ArrayFire.Exception
35+ import ArrayFire.FFI
2836import ArrayFire.Types
2937import ArrayFire.Internal.Defines
3038import ArrayFire.Internal.Data
3139
32- -- /**
33- -- \param[out] arr is the generated array of given type
34- -- \param[in] val is the value of each element in the generated array
35- -- \param[in] ndims is size of dimension array \p dims
36- -- \param[in] dims is the array containing sizes of the dimension
37- -- \param[in] type is the type of array to generate
38- -- \ingroup data_func_constant
39- -- /
40-
4140constant
4241 :: forall dims
4342 . (Dims dims )
@@ -46,6 +45,7 @@ constant val = do
4645 ptr <- alloca $ \ ptrPtr -> mask_ $ do
4746 dimArray <- newArray dimt
4847 throwAFError =<< af_constant ptrPtr val n dimArray typ
48+ free dimArray
4949 peek ptrPtr
5050 Array <$>
5151 newForeignPtr
@@ -55,3 +55,158 @@ constant val = do
5555 n = fromIntegral (length dimt)
5656 dimt = toDims (Proxy @ dims)
5757 typ = afType (Proxy @ Double )
58+
59+ constantComplex
60+ :: forall dims
61+ . (Dims dims )
62+ => Complex Double
63+ -> IO (Array (Complex Double ))
64+ constantComplex val = do
65+ ptr <- alloca $ \ ptrPtr -> mask_ $ do
66+ dimArray <- newArray dimt
67+ throwAFError =<< af_constant_complex ptrPtr (realPart val) (imagPart val) n dimArray typ
68+ free dimArray
69+ peek ptrPtr
70+ Array <$>
71+ newForeignPtr
72+ af_release_array_finalizer
73+ ptr
74+ where
75+ n = fromIntegral (length dimt)
76+ dimt = toDims (Proxy @ dims)
77+ typ = afType (Proxy @ (Complex Double ))
78+
79+ constantLong
80+ :: forall dims
81+ . (Dims dims )
82+ => Int64
83+ -> IO (Array Int64 )
84+ constantLong val = do
85+ ptr <- alloca $ \ ptrPtr -> mask_ $ do
86+ dimArray <- newArray dimt
87+ throwAFError =<< af_constant_long ptrPtr (fromIntegral val) n dimArray
88+ free dimArray
89+ peek ptrPtr
90+ Array <$>
91+ newForeignPtr
92+ af_release_array_finalizer
93+ ptr
94+ where
95+ n = fromIntegral (length dimt)
96+ dimt = toDims (Proxy @ dims)
97+ typ = afType (Proxy @ Int64 )
98+
99+ constantULong
100+ :: forall dims
101+ . (Dims dims )
102+ => Word64
103+ -> IO (Array Word64 )
104+ constantULong val = do
105+ ptr <- alloca $ \ ptrPtr -> mask_ $ do
106+ dimArray <- newArray dimt
107+ throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val) n dimArray
108+ free dimArray
109+ peek ptrPtr
110+ Array <$>
111+ newForeignPtr
112+ af_release_array_finalizer
113+ ptr
114+ where
115+ n = fromIntegral (length dimt)
116+ dimt = toDims (Proxy @ dims)
117+ typ = afType (Proxy @ (Complex Double ))
118+
119+ range
120+ :: forall dims a
121+ . (Dims dims , AFType a )
122+ => Int
123+ -> IO (Array a )
124+ range k = do
125+ ptr <- alloca $ \ ptrPtr -> mask_ $ do
126+ dimArray <- newArray dimt
127+ throwAFError =<< af_range ptrPtr n dimArray k typ
128+ free dimArray
129+
130+ peek ptrPtr
131+ Array <$>
132+ newForeignPtr
133+ af_release_array_finalizer
134+ ptr
135+ where
136+ n = fromIntegral (length dimt)
137+ dimt = toDims (Proxy @ dims)
138+ typ = afType (Proxy @ a)
139+
140+ iota
141+ :: forall dims tdims a
142+ . (Dims dims , Dims tdims , AFType a , KnownNat tdims )
143+ => IO (Array a )
144+ iota = do
145+ ptr <- alloca $ \ ptrPtr -> mask_ $ do
146+ dimArray <- newArray dimt
147+ tdimArray <- newArray tdimt
148+ throwAFError =<< af_iota ptrPtr n dimArray tn tdimArray typ
149+ free dimArray
150+ peek ptrPtr
151+ Array <$>
152+ newForeignPtr
153+ af_release_array_finalizer
154+ ptr
155+ where
156+ n = fromIntegral (length dimt)
157+ dimt = toDims (Proxy @ dims)
158+ tn = fromIntegral (length dimt)
159+ tdimt = toDims (Proxy @ tdims)
160+ typ = afType (Proxy @ a)
161+
162+ identity
163+ :: forall dims a
164+ . (Dims dims , AFType a )
165+ => IO (Array a )
166+ identity = do
167+ ptr <- alloca $ \ ptrPtr -> mask_ $ do
168+ dimArray <- newArray dimt
169+ throwAFError =<< af_identity ptrPtr n dimArray typ
170+ free dimArray
171+ peek ptrPtr
172+ Array <$>
173+ newForeignPtr
174+ af_release_array_finalizer
175+ ptr
176+ where
177+ n = fromIntegral (length dimt)
178+ dimt = toDims (Proxy @ dims)
179+ typ = afType (Proxy @ a)
180+
181+ diagCreate
182+ :: AFType (a :: * )
183+ => Array a
184+ -> Int
185+ -> Array a
186+ diagCreate x n =
187+ x `op1` (\ p a -> af_diag_create p a n)
188+
189+ diagExtract
190+ :: AFType (a :: * )
191+ => Array a
192+ -> Int
193+ -> Array a
194+ diagExtract x n =
195+ x `op1` (\ p a -> af_diag_extract p a n)
196+
197+
198+ -- af_err af_join(af_array *out, const int dim, const af_array first, const af_array second);
199+ -- af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays, const af_array *inputs);
200+ -- af_err af_tile(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w);
201+ -- af_err af_reorder(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w);
202+ -- af_err af_shift(af_array *out, const af_array in, const int x, const int y, const int z, const int w);
203+ -- af_err af_moddims(af_array *out, const af_array in, const unsigned ndims, const dim_t * const dims);
204+ -- af_err af_flat(af_array *out, const af_array in);
205+ -- af_err af_flip(af_array *out, const af_array in, const unsigned dim);
206+ -- af_err af_lower(af_array *out, const af_array in, bool is_unit_diag);
207+ -- af_err af_upper(af_array *out, const af_array in, bool is_unit_diag);
208+ -- af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b);
209+ -- af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b);
210+ -- af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
211+ -- af_err af_replace(af_array a, const af_array cond, const af_array b);
212+ -- af_err af_replace_scalar(af_array a, const af_array cond, const double b);
0 commit comments