Skip to content

Commit 338e061

Browse files
committed
Add RandomEngine section, update FFI, remove whitespace
Adjust in place functions.
1 parent 164aa47 commit 338e061

File tree

13 files changed

+381
-64
lines changed

13 files changed

+381
-64
lines changed

arrayfire.cabal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ library
7373
-fPIC
7474
extra-libraries:
7575
af
76+
c-sources:
77+
cbits/wrapper.c
7678
build-depends:
7779
base < 5, vector
7880
hs-source-dirs:

cbits/wrapper.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "arrayfire.h"
2+
3+
af_err af_random_engine_set_type_(af_random_engine engine, const af_random_engine_type rtype) { return af_random_engine_set_type(&engine, rtype); }
4+
5+
af_err af_random_engine_set_seed_(af_random_engine engine, const unsigned long long seed) {
6+
return af_random_engine_set_seed(&engine, seed);
7+
}

gen/Main.hs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,8 @@ writeToDisk fileName = do
5151
let name = makeName (reverse . drop 2 . reverse $ fileName)
5252
T.writeFile (makePath name) $
5353
file name <> T.intercalate "\n" (genBinding <$> successes)
54-
when (name == "array") (T.appendFile (makePath name) (T.pack extraRelease))
5554
printf "Wrote bindings to %s\n" (makePath name)
5655

57-
extraRelease :: String
58-
extraRelease =
59-
"\nforeign import ccall unsafe \"&af_release_array\"\n\
60-
\ af_release_array_finalizer :: FunPtr (AFArray -> IO ())"
61-
6256
-- | Filename remappings
6357
makeName :: String -> String
6458
makeName n

src/ArrayFire.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
module ArrayFire where
2+

src/ArrayFire/BLAS.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ transpose :: Array a -> Bool -> Array a
3232
transpose arr1 b =
3333
arr1 `op1` (\x y -> af_transpose x y b)
3434

35-
transposeInPlace :: Array a -> Bool -> Array a
35+
transposeInPlace :: Array a -> Bool -> IO ()
3636
transposeInPlace arr b =
3737
arr `inPlace` (`af_transpose_inplace` b)

src/ArrayFire/Data.hs

Lines changed: 165 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
{-# LANGUAGE TypeApplications #-}
88
{-# LANGUAGE ViewPatterns #-}
99
{-# LANGUAGE KindSignatures #-}
10+
{-# LANGUAGE TypeFamilies #-}
11+
{-# LANGUAGE FlexibleContexts #-}
1012
module ArrayFire.Data where
1113

1214
import Control.Exception
1315
import Control.Monad
1416

17+
import Data.Complex
18+
import Data.Proxy
19+
import Data.Word
20+
1521
import Foreign.C.String
1622
import Foreign.C.Types
1723
import Foreign.Marshal hiding (void)
@@ -20,24 +26,17 @@ import Foreign.ForeignPtr
2026
import Foreign.Ptr
2127
import Foreign.Storable
2228

23-
import Data.Proxy
29+
import GHC.Int
30+
import GHC.TypeLits
2431

2532
import ArrayFire.Internal.Array
2633

2734
import ArrayFire.Exception
35+
import ArrayFire.FFI
2836
import ArrayFire.Types
2937
import ArrayFire.Internal.Defines
3038
import ArrayFire.Internal.Data
3139

32-
-- /**
33-
-- \param[out] arr is the generated array of given type
34-
-- \param[in] val is the value of each element in the generated array
35-
-- \param[in] ndims is size of dimension array \p dims
36-
-- \param[in] dims is the array containing sizes of the dimension
37-
-- \param[in] type is the type of array to generate
38-
-- \ingroup data_func_constant
39-
-- /
40-
4140
constant
4241
:: forall dims
4342
. (Dims dims)
@@ -46,6 +45,7 @@ constant val = do
4645
ptr <- alloca $ \ptrPtr -> mask_ $ do
4746
dimArray <- newArray dimt
4847
throwAFError =<< af_constant ptrPtr val n dimArray typ
48+
free dimArray
4949
peek ptrPtr
5050
Array <$>
5151
newForeignPtr
@@ -55,3 +55,158 @@ constant val = do
5555
n = fromIntegral (length dimt)
5656
dimt = toDims (Proxy @ dims)
5757
typ = afType (Proxy @ Double)
58+
59+
constantComplex
60+
:: forall dims
61+
. (Dims dims)
62+
=> Complex Double
63+
-> IO (Array (Complex Double))
64+
constantComplex val = do
65+
ptr <- alloca $ \ptrPtr -> mask_ $ do
66+
dimArray <- newArray dimt
67+
throwAFError =<< af_constant_complex ptrPtr (realPart val) (imagPart val) n dimArray typ
68+
free dimArray
69+
peek ptrPtr
70+
Array <$>
71+
newForeignPtr
72+
af_release_array_finalizer
73+
ptr
74+
where
75+
n = fromIntegral (length dimt)
76+
dimt = toDims (Proxy @ dims)
77+
typ = afType (Proxy @ (Complex Double))
78+
79+
constantLong
80+
:: forall dims
81+
. (Dims dims)
82+
=> Int64
83+
-> IO (Array Int64)
84+
constantLong val = do
85+
ptr <- alloca $ \ptrPtr -> mask_ $ do
86+
dimArray <- newArray dimt
87+
throwAFError =<< af_constant_long ptrPtr (fromIntegral val) n dimArray
88+
free dimArray
89+
peek ptrPtr
90+
Array <$>
91+
newForeignPtr
92+
af_release_array_finalizer
93+
ptr
94+
where
95+
n = fromIntegral (length dimt)
96+
dimt = toDims (Proxy @ dims)
97+
typ = afType (Proxy @ Int64)
98+
99+
constantULong
100+
:: forall dims
101+
. (Dims dims)
102+
=> Word64
103+
-> IO (Array Word64)
104+
constantULong val = do
105+
ptr <- alloca $ \ptrPtr -> mask_ $ do
106+
dimArray <- newArray dimt
107+
throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val) n dimArray
108+
free dimArray
109+
peek ptrPtr
110+
Array <$>
111+
newForeignPtr
112+
af_release_array_finalizer
113+
ptr
114+
where
115+
n = fromIntegral (length dimt)
116+
dimt = toDims (Proxy @ dims)
117+
typ = afType (Proxy @ (Complex Double))
118+
119+
range
120+
:: forall dims a
121+
. (Dims dims, AFType a)
122+
=> Int
123+
-> IO (Array a)
124+
range k = do
125+
ptr <- alloca $ \ptrPtr -> mask_ $ do
126+
dimArray <- newArray dimt
127+
throwAFError =<< af_range ptrPtr n dimArray k typ
128+
free dimArray
129+
130+
peek ptrPtr
131+
Array <$>
132+
newForeignPtr
133+
af_release_array_finalizer
134+
ptr
135+
where
136+
n = fromIntegral (length dimt)
137+
dimt = toDims (Proxy @ dims)
138+
typ = afType (Proxy @ a)
139+
140+
iota
141+
:: forall dims tdims a
142+
. (Dims dims, Dims tdims, AFType a, KnownNat tdims)
143+
=> IO (Array a)
144+
iota = do
145+
ptr <- alloca $ \ptrPtr -> mask_ $ do
146+
dimArray <- newArray dimt
147+
tdimArray <- newArray tdimt
148+
throwAFError =<< af_iota ptrPtr n dimArray tn tdimArray typ
149+
free dimArray
150+
peek ptrPtr
151+
Array <$>
152+
newForeignPtr
153+
af_release_array_finalizer
154+
ptr
155+
where
156+
n = fromIntegral (length dimt)
157+
dimt = toDims (Proxy @ dims)
158+
tn = fromIntegral (length dimt)
159+
tdimt = toDims (Proxy @ tdims)
160+
typ = afType (Proxy @ a)
161+
162+
identity
163+
:: forall dims a
164+
. (Dims dims, AFType a)
165+
=> IO (Array a)
166+
identity = do
167+
ptr <- alloca $ \ptrPtr -> mask_ $ do
168+
dimArray <- newArray dimt
169+
throwAFError =<< af_identity ptrPtr n dimArray typ
170+
free dimArray
171+
peek ptrPtr
172+
Array <$>
173+
newForeignPtr
174+
af_release_array_finalizer
175+
ptr
176+
where
177+
n = fromIntegral (length dimt)
178+
dimt = toDims (Proxy @ dims)
179+
typ = afType (Proxy @ a)
180+
181+
diagCreate
182+
:: AFType (a :: *)
183+
=> Array a
184+
-> Int
185+
-> Array a
186+
diagCreate x n =
187+
x `op1` (\p a -> af_diag_create p a n)
188+
189+
diagExtract
190+
:: AFType (a :: *)
191+
=> Array a
192+
-> Int
193+
-> Array a
194+
diagExtract x n =
195+
x `op1` (\p a -> af_diag_extract p a n)
196+
197+
198+
-- af_err af_join(af_array *out, const int dim, const af_array first, const af_array second);
199+
-- af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays, const af_array *inputs);
200+
-- af_err af_tile(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w);
201+
-- af_err af_reorder(af_array *out, const af_array in, const unsigned x, const unsigned y, const unsigned z, const unsigned w);
202+
-- af_err af_shift(af_array *out, const af_array in, const int x, const int y, const int z, const int w);
203+
-- af_err af_moddims(af_array *out, const af_array in, const unsigned ndims, const dim_t * const dims);
204+
-- af_err af_flat(af_array *out, const af_array in);
205+
-- af_err af_flip(af_array *out, const af_array in, const unsigned dim);
206+
-- af_err af_lower(af_array *out, const af_array in, bool is_unit_diag);
207+
-- af_err af_upper(af_array *out, const af_array in, bool is_unit_diag);
208+
-- af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b);
209+
-- af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b);
210+
-- af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
211+
-- af_err af_replace(af_array a, const af_array cond, const af_array b);
212+
-- af_err af_replace_scalar(af_array a, const af_array cond, const double b);

src/ArrayFire/Exception.hs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Control.Monad
77
import Foreign.C.String
88
import Foreign.Marshal
99
import Foreign.Storable
10+
import Foreign.Ptr
1011

1112
import ArrayFire.Internal.Exception
1213
import ArrayFire.Internal.Defines
@@ -78,8 +79,8 @@ throwAFError exitCode =
7879
afExceptionMsg <- errorToString exitCode
7980
throwIO AFException {..}
8081

81-
-- foreign import ccall unsafe "af_get_last_error"
82-
-- af_get_last_error :: Ptr (Ptr CChar) -> Ptr DimT -> IO ()
82+
foreign import ccall unsafe "&af_release_random_engine"
83+
af_release_random_engine_finalizer :: FunPtr (AFRandomEngine -> IO ())
8384

84-
-- foreign import ccall unsafe "af_err_to_string"
85-
-- af_err_to_string :: AFErr -> IO (Ptr CChar)
85+
foreign import ccall unsafe "&af_release_array"
86+
af_release_array_finalizer :: FunPtr (AFArray -> IO ())

src/ArrayFire/FFI.hs

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ op1f
143143
-> (Ptr AFFeatures -> AFFeatures -> IO AFErr)
144144
-> Features
145145
op1f (Features x) op =
146-
unsafePerformIO $ do
146+
unsafePerformIO . mask_ $ do
147147
withForeignPtr x $ \ptr1 -> do
148148
ptr <-
149149
alloca $ \ptrInput -> do
@@ -152,6 +152,19 @@ op1f (Features x) op =
152152
fptr <- newForeignPtr af_release_features ptr
153153
pure (Features fptr)
154154

155+
op1re
156+
:: RandomEngine
157+
-> (Ptr AFRandomEngine -> AFRandomEngine -> IO AFErr)
158+
-> IO RandomEngine
159+
op1re (RandomEngine x) op = mask_ $
160+
withForeignPtr x $ \ptr1 -> do
161+
ptr <-
162+
alloca $ \ptrInput -> do
163+
throwAFError =<< op ptrInput ptr1
164+
peek ptrInput
165+
fptr <- newForeignPtr af_release_random_engine_finalizer ptr
166+
pure (RandomEngine fptr)
167+
155168
op1b
156169
:: Storable b
157170
=> Array a
@@ -171,13 +184,15 @@ op1b (Array fptr1) op =
171184
afCall
172185
:: IO AFErr
173186
-> IO ()
174-
afCall = (throwAFError =<<)
187+
afCall = mask_ . (throwAFError =<<)
188+
189+
inPlace :: Array a -> (AFArray -> IO AFErr) -> IO ()
190+
inPlace (Array fptr) op =
191+
mask_ . withForeignPtr fptr $ (throwAFError <=< op)
175192

176-
inPlace :: Array a -> (AFArray -> IO AFErr) -> Array a
177-
inPlace r@(Array fptr) op =
178-
(unsafePerformIO $
179-
withForeignPtr fptr $ \ptr ->
180-
throwAFError =<< op ptr) `seq` r
193+
inPlaceEng :: RandomEngine -> (AFRandomEngine -> IO AFErr) -> IO ()
194+
inPlaceEng (RandomEngine fptr) op =
195+
mask_ . withForeignPtr fptr $ (throwAFError <=< op)
181196

182197
afCall1
183198
:: Storable a
@@ -213,6 +228,18 @@ infoFromFeatures (Features fptr1) op =
213228
throwAFError =<< op ptrInput ptr1
214229
peek ptrInput
215230

231+
infoFromRandomEngine
232+
:: Storable a
233+
=> RandomEngine
234+
-> (Ptr a -> AFRandomEngine -> IO AFErr)
235+
-> IO a
236+
infoFromRandomEngine (RandomEngine fptr1) op =
237+
mask_ $ do
238+
withForeignPtr fptr1 $ \ptr1 -> do
239+
alloca $ \ptrInput -> do
240+
throwAFError =<< op ptrInput ptr1
241+
peek ptrInput
242+
216243
infoFromArray
217244
:: Storable a
218245
=> Array b

src/ArrayFire/Features.hs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
{-# LANGUAGE ViewPatterns #-}
22
module ArrayFire.Features where
33

4-
54
import Control.Exception hiding (TypeError)
65
import Data.Typeable
76
import Control.Monad
@@ -17,10 +16,10 @@ import ArrayFire.FFI
1716
import ArrayFire.Exception
1817
import ArrayFire.Internal.Defines
1918

20-
createFeatures
19+
createFeatures
2120
:: Int
2221
-> Features
23-
createFeatures (fromIntegral -> n) =
22+
createFeatures (fromIntegral -> n) =
2423
unsafePerformIO $ do
2524
ptr <-
2625
alloca $ \ptrInput -> do
@@ -58,7 +57,7 @@ getFeaturesOrientation
5857
:: Features
5958
-> Array a
6059
getFeaturesOrientation = (`featuresToArray` af_get_features_orientation)
61-
60+
6261
getFeaturesSize
6362
:: Features
6463
-> Array a

src/ArrayFire/Internal/Array.hsc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,3 @@ foreign import ccall unsafe "af_is_sparse"
7171
af_is_sparse :: Ptr Bool -> AFArray -> IO AFErr
7272
foreign import ccall unsafe "af_get_scalar"
7373
af_get_scalar :: Ptr () -> AFArray -> IO AFErr
74-
foreign import ccall unsafe "&af_release_array"
75-
af_release_array_finalizer :: FunPtr (AFArray -> IO ())

0 commit comments

Comments
 (0)