5
5
import arrayfire_wrapper .dtypes as dtype
6
6
import arrayfire_wrapper .lib as wrapper
7
7
8
-
9
8
dtype_map = {
10
- 'int16' : dtype .s16 ,
11
- 'int32' : dtype .s32 ,
12
- 'int64' : dtype .s64 ,
13
- 'uint8' : dtype .u8 ,
14
- 'uint16' : dtype .u16 ,
15
- 'uint32' : dtype .u32 ,
16
- 'uint64' : dtype .u64 ,
17
- 'float16' : dtype .f16 ,
18
- 'float32' : dtype .f32 ,
19
9
"int16" : dtype .s16 ,
20
10
"int32" : dtype .s32 ,
21
11
"int64" : dtype .s64 ,
28
18
# 'float64': dtype.f64,
29
19
# 'complex64': dtype.c64,
30
20
# 'complex32': dtype.c32,
31
- 'bool' : dtype .b8 ,
32
- 's16' : dtype .s16 ,
33
- 's32' : dtype .s32 ,
34
- 's64' : dtype .s64 ,
35
- 'u8' : dtype .u8 ,
36
- 'u16' : dtype .u16 ,
37
- 'u32' : dtype .u32 ,
38
- 'u64' : dtype .u64 ,
39
- 'f16' : dtype .f16 ,
40
- 'f32' : dtype .f32 ,
41
21
"bool" : dtype .b8 ,
42
22
"s16" : dtype .s16 ,
43
23
"s32" : dtype .s32 ,
51
31
# 'f64': dtype.f64,
52
32
# 'c32': dtype.c32,
53
33
# 'c64': dtype.c64,
54
- 'b8' : dtype .b8 ,
55
34
"b8" : dtype .b8 ,
56
35
}
57
36
60
39
"shape" ,
61
40
[
62
41
(),
63
- (random .randint (1 , 10 ), ),
42
+ (random .randint (1 , 10 ),),
64
43
(random .randint (1 , 10 ),),
65
44
(random .randint (1 , 10 ),),
66
45
(random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -87,11 +66,13 @@ def test_asinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
87
66
"""Test inverse hyperbolic sine operation for unsupported data types."""
88
67
with pytest .raises (RuntimeError ):
89
68
wrapper .asinh (wrapper .randu ((10 , 10 ), invdtypes ))
69
+
70
+
90
71
@pytest .mark .parametrize (
91
72
"shape" ,
92
73
[
93
74
(),
94
- (random .randint (1 , 10 ), ),
75
+ (random .randint (1 , 10 ),),
95
76
(random .randint (1 , 10 ),),
96
77
(random .randint (1 , 10 ),),
97
78
(random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -118,11 +99,13 @@ def test_acosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
118
99
"""Test inverse hyperbolic cosine operation for unsupported data types."""
119
100
with pytest .raises (RuntimeError ):
120
101
wrapper .acosh (wrapper .randu ((10 , 10 ), invdtypes ))
102
+
103
+
121
104
@pytest .mark .parametrize (
122
105
"shape" ,
123
106
[
124
107
(),
125
- (random .randint (1 , 10 ), ),
108
+ (random .randint (1 , 10 ),),
126
109
(random .randint (1 , 10 ),),
127
110
(random .randint (1 , 10 ),),
128
111
(random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -150,11 +133,12 @@ def test_atanh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
150
133
with pytest .raises (RuntimeError ):
151
134
wrapper .atanh (wrapper .randu ((10 , 10 ), invdtypes ))
152
135
136
+
153
137
@pytest .mark .parametrize (
154
138
"shape" ,
155
139
[
156
140
(),
157
- (random .randint (1 , 10 ), ),
141
+ (random .randint (1 , 10 ),),
158
142
(random .randint (1 , 10 ),),
159
143
(random .randint (1 , 10 ),),
160
144
(random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -182,11 +166,12 @@ def test_cosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
182
166
with pytest .raises (RuntimeError ):
183
167
wrapper .cosh (wrapper .randu ((10 , 10 ), invdtypes ))
184
168
169
+
185
170
@pytest .mark .parametrize (
186
171
"shape" ,
187
172
[
188
173
(),
189
- (random .randint (1 , 10 ), ),
174
+ (random .randint (1 , 10 ),),
190
175
(random .randint (1 , 10 ),),
191
176
(random .randint (1 , 10 ),),
192
177
(random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -214,11 +199,12 @@ def test_sinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
214
199
with pytest .raises (RuntimeError ):
215
200
wrapper .sinh (wrapper .randu ((10 , 10 ), invdtypes ))
216
201
202
+
217
203
@pytest .mark .parametrize (
218
204
"shape" ,
219
205
[
220
206
(),
221
- (random .randint (1 , 10 ), ),
207
+ (random .randint (1 , 10 ),),
222
208
(random .randint (1 , 10 ),),
223
209
(random .randint (1 , 10 ),),
224
210
(random .randint (1 , 10 ), random .randint (1 , 10 )),
0 commit comments