Skip to content

Commit 27bef52

Browse files
authored
Enforce Complex on 'imag' and 'real' functions. (#16)
* Enfore Complex on 'imag' and 'real' functions. * Constrain real/imag on RealFrac.
1 parent 0470d34 commit 27bef52

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

src/ArrayFire/Arith.hs

+13-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{-# LANGUAGE TypeApplications #-}
2+
{-# LANGUAGE FlexibleContexts #-}
23
{-# LANGUAGE ScopedTypeVariables #-}
34
{-# LANGUAGE ViewPatterns #-}
45
--------------------------------------------------------------------------------
@@ -24,10 +25,11 @@
2425
--------------------------------------------------------------------------------
2526
module ArrayFire.Arith where
2627

27-
import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real)
28+
import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFrac)
2829

2930
import Data.Coerce
3031
import Data.Proxy
32+
import Data.Complex
3133

3234
import ArrayFire.FFI
3335
import ArrayFire.Internal.Arith
@@ -1195,31 +1197,31 @@ cplx = flip op1 af_cplx
11951197

11961198
-- | Execute real
11971199
--
1198-
-- >>> A.real (A.vector @Double 10 [1..])
1200+
-- >>> A.real (A.scalar @(Complex Double) (10 :+ 11)) :: Array Double
11991201
-- ArrayFire Array
12001202
-- [10 1 1 1]
1201-
-- 1.0000 2.0000 3.0000 4.0000 5.0000 6.0000 7.0000 8.0000 9.0000 10.0000
1203+
-- 10.0000
12021204
real
1203-
:: AFType a
1204-
=> Array a
1205+
:: (AFType a, AFType (Complex b), RealFrac a, RealFrac b)
1206+
=> Array (Complex b)
12051207
-- ^ Input array
12061208
-> Array a
12071209
-- ^ Result of calling 'real'
1208-
real = flip op1 af_real
1210+
real = flip op1d af_real
12091211

12101212
-- | Execute imag
12111213
--
1212-
-- >>> A.imag (A.vector @Double 10 [1..])
1214+
-- >>> A.imag (A.scalar @(Complex Double) (10 :+ 11)) :: Array Double
12131215
-- ArrayFire Array
12141216
-- [10 1 1 1]
1215-
-- 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
1217+
-- 11.0000
12161218
imag
1217-
:: AFType a
1218-
=> Array a
1219+
:: (AFType a, AFType (Complex b), RealFrac a, RealFrac b)
1220+
=> Array (Complex b)
12191221
-- ^ Input array
12201222
-> Array a
12211223
-- ^ Result of calling 'imag'
1222-
imag = flip op1 af_imag
1224+
imag = flip op1d af_imag
12231225

12241226
-- | Execute conjg
12251227
--

src/ArrayFire/FFI.hs

+16
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,22 @@ opw1 (Window fptr) op
220220
throwAFError =<< op p ptr
221221
peek p
222222

223+
op1d
224+
:: Array a
225+
-> (Ptr AFArray -> AFArray -> IO AFErr)
226+
-> Array b
227+
{-# NOINLINE op1d #-}
228+
op1d (Array fptr1) op =
229+
unsafePerformIO $ do
230+
withForeignPtr fptr1 $ \ptr1 -> do
231+
ptr <-
232+
alloca $ \ptrInput -> do
233+
throwAFError =<< op ptrInput ptr1
234+
peek ptrInput
235+
fptr <- newForeignPtr af_release_array_finalizer ptr
236+
pure (Array fptr)
237+
238+
223239
op1
224240
:: Array a
225241
-> (Ptr AFArray -> AFArray -> IO AFErr)

0 commit comments

Comments
 (0)