22
33import pytest
44
5- import arrayfire_wrapper .dtypes as dtypes
65import arrayfire_wrapper .lib as wrapper
6+ from arrayfire_wrapper .dtypes import (
7+ Dtype ,
8+ c32 ,
9+ c64 ,
10+ c_api_value_to_dtype ,
11+ f16 ,
12+ f32 ,
13+ f64 ,
14+ s16 ,
15+ s32 ,
16+ s64 ,
17+ u8 ,
18+ u16 ,
19+ u32 ,
20+ u64 ,
21+ )
722
823invalid_shape = (
924 random .randint (1 , 10 ),
1429)
1530
1631
32+ all_types = [s16 , s32 , s64 , u8 , u16 , u32 , u64 , f16 , f32 , f64 , c32 , c64 ]
33+
34+
1735@pytest .mark .parametrize (
1836 "shape" ,
1937 [
2745def test_constant_shape (shape : tuple ) -> None :
2846 """Test if constant creates an array with the correct shape."""
2947 number = 5.0
30- dtype = dtypes . s16
48+ dtype = s16
3149
3250 result = wrapper .constant (number , shape , dtype )
3351
@@ -46,9 +64,8 @@ def test_constant_shape(shape: tuple) -> None:
4664)
4765def test_constant_complex_shape (shape : tuple ) -> None :
4866 """Test if constant_complex creates an array with the correct shape."""
49- dtype = dtypes . c32
67+ dtype = c32
5068
51- dtype = dtypes .c32
5269 rand_array = wrapper .randu ((1 , 1 ), dtype )
5370 number = wrapper .get_scalar (rand_array , dtype )
5471
@@ -71,7 +88,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
7188)
7289def test_constant_long_shape (shape : tuple ) -> None :
7390 """Test if constant_long creates an array with the correct shape."""
74- dtype = dtypes . s64
91+ dtype = s64
7592 rand_array = wrapper .randu ((1 , 1 ), dtype )
7693 number = wrapper .get_scalar (rand_array , dtype )
7794
@@ -93,7 +110,7 @@ def test_constant_long_shape(shape: tuple) -> None:
93110)
94111def test_constant_ulong_shape (shape : tuple ) -> None :
95112 """Test if constant_ulong creates an array with the correct shape."""
96- dtype = dtypes . u64
113+ dtype = u64
97114 rand_array = wrapper .randu ((1 , 1 ), dtype )
98115 number = wrapper .get_scalar (rand_array , dtype )
99116
@@ -109,15 +126,15 @@ def test_constant_shape_invalid() -> None:
109126 """Test if constant handles a shape with greater than 4 dimensions"""
110127 with pytest .raises (TypeError ):
111128 number = 5.0
112- dtype = dtypes . s16
129+ dtype = s16
113130
114131 wrapper .constant (number , invalid_shape , dtype )
115132
116133
117134def test_constant_complex_shape_invalid () -> None :
118135 """Test if constant_complex handles a shape with greater than 4 dimensions"""
119136 with pytest .raises (TypeError ):
120- dtype = dtypes . c32
137+ dtype = c32
121138 rand_array = wrapper .randu ((1 , 1 ), dtype )
122139 number = wrapper .get_scalar (rand_array , dtype )
123140
@@ -128,7 +145,7 @@ def test_constant_complex_shape_invalid() -> None:
128145def test_constant_long_shape_invalid () -> None :
129146 """Test if constant_long handles a shape with greater than 4 dimensions"""
130147 with pytest .raises (TypeError ):
131- dtype = dtypes . s64
148+ dtype = s64
132149 rand_array = wrapper .randu ((1 , 1 ), dtype )
133150 number = wrapper .get_scalar (rand_array , dtype )
134151
@@ -139,7 +156,7 @@ def test_constant_long_shape_invalid() -> None:
139156def test_constant_ulong_shape_invalid () -> None :
140157 """Test if constant_ulong handles a shape with greater than 4 dimensions"""
141158 with pytest .raises (TypeError ):
142- dtype = dtypes . u64
159+ dtype = u64
143160 rand_array = wrapper .randu ((1 , 1 ), dtype )
144161 number = wrapper .get_scalar (rand_array , dtype )
145162
@@ -148,50 +165,47 @@ def test_constant_ulong_shape_invalid() -> None:
148165
149166
150167@pytest .mark .parametrize (
151- "dtype_index " ,
152- [ i for i in range ( 13 )] ,
168+ "dtype " ,
169+ all_types ,
153170)
154- def test_constant_dtype (dtype_index : int ) -> None :
171+ def test_constant_dtype (dtype : Dtype ) -> None :
155172 """Test if constant creates an array with the correct dtype."""
156- if dtype_index in [ 1 , 3 ] or ( dtype_index == 2 and not wrapper . get_dbl_support () ):
173+ if is_cmplx_type ( dtype ) or not is_system_supported ( dtype ):
157174 pytest .skip ()
158175
159- dtype = dtypes .c_api_value_to_dtype (dtype_index )
160-
161176 rand_array = wrapper .randu ((1 , 1 ), dtype )
162177 value = wrapper .get_scalar (rand_array , dtype )
163178 shape = (2 , 2 )
164179 if isinstance (value , (int , float )):
165180 result = wrapper .constant (value , shape , dtype )
166- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
181+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
167182 else :
168183 pytest .skip ()
169184
170185
171186@pytest .mark .parametrize (
172- "dtype_index " ,
173- [ i for i in range ( 13 )] ,
187+ "dtype " ,
188+ all_types ,
174189)
175- def test_constant_complex_dtype (dtype_index : int ) -> None :
190+ def test_constant_complex_dtype (dtype : Dtype ) -> None :
176191 """Test if constant_complex creates an array with the correct dtype."""
177- if dtype_index not in [ 1 , 3 ] or ( dtype_index == 3 and not wrapper . get_dbl_support () ):
192+ if not is_cmplx_type ( dtype ) or not is_system_supported ( dtype ):
178193 pytest .skip ()
179194
180- dtype = dtypes .c_api_value_to_dtype (dtype_index )
181195 rand_array = wrapper .randu ((1 , 1 ), dtype )
182196 value = wrapper .get_scalar (rand_array , dtype )
183197 shape = (2 , 2 )
184198
185199 if isinstance (value , (int , float , complex )):
186200 result = wrapper .constant_complex (value , shape , dtype )
187- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
201+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
188202 else :
189203 pytest .skip ()
190204
191205
192206def test_constant_long_dtype () -> None :
193207 """Test if constant_long creates an array with the correct dtype."""
194- dtype = dtypes . s64
208+ dtype = s64
195209
196210 rand_array = wrapper .randu ((1 , 1 ), dtype )
197211 value = wrapper .get_scalar (rand_array , dtype )
@@ -200,14 +214,14 @@ def test_constant_long_dtype() -> None:
200214 if isinstance (value , (int , float )):
201215 result = wrapper .constant_long (value , shape , dtype )
202216
203- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
217+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
204218 else :
205219 pytest .skip ()
206220
207221
208222def test_constant_ulong_dtype () -> None :
209223 """Test if constant_ulong creates an array with the correct dtype."""
210- dtype = dtypes . u64
224+ dtype = u64
211225
212226 rand_array = wrapper .randu ((1 , 1 ), dtype )
213227 value = wrapper .get_scalar (rand_array , dtype )
@@ -216,6 +230,17 @@ def test_constant_ulong_dtype() -> None:
216230 if isinstance (value , (int , float )):
217231 result = wrapper .constant_ulong (value , shape , dtype )
218232
219- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
233+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
220234 else :
221235 pytest .skip ()
236+
237+
238+ def is_cmplx_type (dtype : Dtype ) -> bool :
239+ return dtype == c32 or dtype == c64
240+
241+
242+ def is_system_supported (dtype : Dtype ) -> bool :
243+ if dtype in [f64 , c64 ] and not wrapper .get_dbl_support ():
244+ return False
245+
246+ return True
0 commit comments