|
| 1 | +{-# LANGUAGE RankNTypes #-} |
| 2 | +{-# LANGUAGE ScopedTypeVariables #-} |
1 | 3 | {-# LANGUAGE TypeApplications #-}
|
| 4 | + |
2 | 5 | module ArrayFire.ArithSpec where
|
3 | 6 |
|
4 |
| -import ArrayFire hiding (acos) |
5 |
| -import Prelude hiding (abs, sqrt, div, and, or, not, isNaN) |
6 |
| -import Test.Hspec |
| 7 | +import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector) |
| 8 | +import qualified ArrayFire |
| 9 | +import Control.Exception (throwIO) |
| 10 | +import Control.Monad (unless, when) |
7 | 11 | import Foreign.C
|
| 12 | +import GHC.Exts (IsList (..)) |
| 13 | +import GHC.Stack |
| 14 | +import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..)) |
| 15 | +import Test.Hspec |
| 16 | +import Test.Hspec.QuickCheck |
| 17 | +import Prelude hiding (div) |
| 18 | + |
| 19 | +compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation |
| 20 | +compareWith comparator result expected = |
| 21 | + unless (comparator result expected) $ do |
| 22 | + throwIO (HUnitFailure location $ ExpectedButGot Nothing expectedMsg actualMsg) |
| 23 | + where |
| 24 | + expectedMsg = show expected |
| 25 | + actualMsg = show result |
| 26 | + location = case reverse (toList callStack) of |
| 27 | + (_, loc) : _ -> Just loc |
| 28 | + [] -> Nothing |
| 29 | + |
| 30 | +class (Num a) => HasEpsilon a where |
| 31 | + eps :: a |
| 32 | + |
| 33 | +instance HasEpsilon Float where |
| 34 | + eps = 1.1920929e-7 |
| 35 | + |
| 36 | +instance HasEpsilon Double where |
| 37 | + eps = 2.220446049250313e-16 |
| 38 | + |
| 39 | +approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool |
| 40 | +approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b)) |
| 41 | + |
| 42 | +approx :: (Ord a, HasEpsilon a) => a -> a -> Bool |
| 43 | +approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b |
| 44 | + |
| 45 | +shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation |
| 46 | +shouldBeApprox = compareWith approx |
| 47 | + |
| 48 | +evalf :: (AFType a) => Array a -> a |
| 49 | +evalf = ArrayFire.getScalar |
| 50 | + |
| 51 | +shouldMatchBuiltin :: |
| 52 | + (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) => |
| 53 | + (Array a -> Array a) -> |
| 54 | + (a -> a) -> |
| 55 | + a -> |
| 56 | + Expectation |
| 57 | +shouldMatchBuiltin f f' x |
| 58 | + | isInfinite y && isInfinite y' = pure () |
| 59 | + | Prelude.isNaN y && Prelude.isNaN y' = pure () |
| 60 | + | otherwise = y `shouldBeApprox` y' |
| 61 | + where |
| 62 | + y = evalf (f (scalar x)) |
| 63 | + y' = f' x |
| 64 | + |
| 65 | +shouldMatchBuiltin2 :: |
| 66 | + (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) => |
| 67 | + (Array a -> Array a -> Array a) -> |
| 68 | + (a -> a -> a) -> |
| 69 | + a -> |
| 70 | + a -> |
| 71 | + Expectation |
| 72 | +shouldMatchBuiltin2 f f' a = shouldMatchBuiltin (f (scalar a)) (f' a) |
8 | 73 |
|
9 | 74 | spec :: Spec
|
10 | 75 | spec =
|
11 | 76 | describe "Arith tests" $ do
|
12 | 77 | it "Should negate scalar value" $ do
|
13 | 78 | negate (scalar @Int 1) `shouldBe` (-1)
|
14 | 79 | it "Should negate a vector" $ do
|
15 |
| - negate (vector @Int 3 [2,2,2]) `shouldBe` vector @Int 3 [-2,-2,-2] |
| 80 | + negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2] |
16 | 81 | it "Should add two scalar arrays" $ do
|
17 | 82 | scalar @Int 1 + 2 `shouldBe` 3
|
18 | 83 | it "Should add two scalar bool arrays" $ do
|
19 | 84 | scalar @CBool 1 + 0 `shouldBe` 1
|
20 | 85 | it "Should subtract two scalar arrays" $ do
|
21 | 86 | scalar @Int 4 - 2 `shouldBe` 2
|
22 | 87 | it "Should multiply two scalar arrays" $ do
|
23 |
| - scalar @Double 4 `mul` 2 `shouldBe` 8 |
| 88 | + scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8 |
24 | 89 | it "Should divide two scalar arrays" $ do
|
25 |
| - div @Double 8 2 `shouldBe` 4 |
| 90 | + ArrayFire.div @Double 8 2 `shouldBe` 4 |
26 | 91 | it "Should add two matrices" $ do
|
27 |
| - matrix @Int (2,2) [[1,1],[1,1]] + matrix @Int (2,2) [[1,1],[1,1]] |
28 |
| - `shouldBe` |
29 |
| - matrix @Int (2,2) [[2,2],[2,2]] |
30 |
| - -- Exact comparisons of Double don't make sense here, so we just check that the result is |
31 |
| - -- accurate up to some epsilon. |
32 |
| - it "Should take cubed root" $ do |
33 |
| - allTrueAll ((abs (3 - cbrt @Double 27)) `lt` 1.0e-14) `shouldBe` (1, 0) |
34 |
| - it "Should take square root" $ do |
35 |
| - allTrueAll ((abs (2 - sqrt @Double 4)) `lt` 1.0e-14) `shouldBe` (1, 0) |
| 92 | + matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]] |
| 93 | + `shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]] |
| 94 | + prop "Should take cubed root" $ \(x :: Double) -> |
| 95 | + evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x |
36 | 96 |
|
37 | 97 | it "Should lte Array" $ do
|
38 |
| - 2 `le` (3 :: Array Double) `shouldBe` 1 |
| 98 | + 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 |
39 | 99 | it "Should gte Array" $ do
|
40 |
| - 2 `ge` (3 :: Array Double) `shouldBe` 0 |
| 100 | + 2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0 |
41 | 101 | it "Should gt Array" $ do
|
42 |
| - 2 `gt` (3 :: Array Double) `shouldBe` 0 |
| 102 | + 2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0 |
43 | 103 | it "Should lt Array" $ do
|
44 |
| - 2 `le` (3 :: Array Double) `shouldBe` 1 |
| 104 | + 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 |
45 | 105 | it "Should eq Array" $ do
|
46 | 106 | 3 == (3 :: Array Double) `shouldBe` True
|
47 | 107 | it "Should and Array" $ do
|
48 |
| - (mkArray @CBool [1] [0] `and` mkArray [1] [1]) |
49 |
| - `shouldBe` mkArray [1] [0] |
| 108 | + (mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1]) |
| 109 | + `shouldBe` mkArray [1] [0] |
50 | 110 | it "Should and Array" $ do
|
51 |
| - (mkArray @CBool [2] [0,0] `and` mkArray [2] [1,0]) |
52 |
| - `shouldBe` mkArray [2] [0, 0] |
| 111 | + (mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0]) |
| 112 | + `shouldBe` mkArray [2] [0, 0] |
53 | 113 | it "Should or Array" $ do
|
54 |
| - (mkArray @CBool [2] [0,0] `or` mkArray [2] [1,0]) |
55 |
| - `shouldBe` mkArray [2] [1, 0] |
| 114 | + (mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0]) |
| 115 | + `shouldBe` mkArray [2] [1, 0] |
56 | 116 | it "Should not Array" $ do
|
57 |
| - not (mkArray @CBool [2] [1,0]) `shouldBe` mkArray [2] [0,1] |
| 117 | + ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1] |
58 | 118 | it "Should bitwise and array" $ do
|
59 |
| - bitAnd (scalar @Int 1) (scalar @Int 0) |
60 |
| - `shouldBe` |
61 |
| - 0 |
| 119 | + ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0) |
| 120 | + `shouldBe` 0 |
62 | 121 | it "Should bitwise or array" $ do
|
63 |
| - bitOr (scalar @Int 1) (scalar @Int 0) |
64 |
| - `shouldBe` |
65 |
| - 1 |
| 122 | + ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0) |
| 123 | + `shouldBe` 1 |
66 | 124 | it "Should bitwise xor array" $ do
|
67 |
| - bitXor (scalar @Int 1) (scalar @Int 1) |
68 |
| - `shouldBe` |
69 |
| - 0 |
| 125 | + ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1) |
| 126 | + `shouldBe` 0 |
70 | 127 | it "Should bitwise shift left an array" $ do
|
71 |
| - bitShiftL (scalar @Int 1) (scalar @Int 3) |
72 |
| - `shouldBe` |
73 |
| - 8 |
| 128 | + ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3) |
| 129 | + `shouldBe` 8 |
74 | 130 | it "Should cast an array" $ do
|
75 | 131 | getType (cast (scalar @Int 1) :: Array Double)
|
76 |
| - `shouldBe` |
77 |
| - F64 |
| 132 | + `shouldBe` ArrayFire.F64 |
78 | 133 | it "Should find the minimum of two arrays" $ do
|
79 | 134 | minOf (scalar @Int 1) (scalar @Int 0)
|
80 |
| - `shouldBe` |
81 |
| - 0 |
| 135 | + `shouldBe` 0 |
82 | 136 | it "Should find the max of two arrays" $ do
|
83 | 137 | maxOf (scalar @Int 1) (scalar @Int 0)
|
84 |
| - `shouldBe` |
85 |
| - 1 |
| 138 | + `shouldBe` 1 |
86 | 139 | it "Should take the clamp of 3 arrays" $ do
|
87 | 140 | clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3)
|
88 |
| - `shouldBe` |
89 |
| - 2 |
| 141 | + `shouldBe` 2 |
90 | 142 | it "Should check if an array has positive or negative infinities" $ do
|
91 | 143 | isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1
|
92 | 144 | isInf (scalar @Double 10) `shouldBe` scalar @Double 0
|
93 | 145 | it "Should check if an array has any NaN values" $ do
|
94 |
| - isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 |
95 |
| - isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 |
| 146 | + ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 |
| 147 | + ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 |
96 | 148 | it "Should check if an array has any Zero values" $ do
|
97 | 149 | isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0
|
98 | 150 | isZero (scalar @Double 0) `shouldBe` scalar @Double 1
|
99 | 151 | isZero (scalar @Double 1) `shouldBe` scalar @Double 0
|
| 152 | + |
| 153 | + prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x |
| 154 | + prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x |
| 155 | + prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x |
| 156 | + prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y |
| 157 | + prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x |
| 158 | + prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x |
| 159 | + prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x |
| 160 | + prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x |
| 161 | + prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x |
| 162 | + prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x |
| 163 | + prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x |
| 164 | + prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x |
| 165 | + prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x |
| 166 | + prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x |
| 167 | + prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x |
| 168 | + prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x |
0 commit comments