@@ -336,7 +336,7 @@ def test_scalar_rounding(shape, tile, is_constant, dtype, op, tmp_path):
336336@requires_tileiras (BytecodeVersion .V_13_2 )
337337@pytest .mark .use_mlir
338338@pytest .mark .parametrize ("dtype" , float_dtypes , ids = dtype_id )
339- @pytest .mark .parametrize ("rounding_mode" , [RMd .FULL , RMd .APPROX ])
339+ @pytest .mark .parametrize ("rounding_mode" , [RMd .FULL , RMd .APPROX , None ])
340340def test_array_tanh_rounding_mode (shape , tile , dtype , rounding_mode , tmp_path ):
341341 should_raise_dtype = rounding_mode in [RMd .FULL , RMd .APPROX ] and dtype != torch .float32
342342 x = make_tensor (shape , dtype = dtype , device = 'cuda' )
@@ -353,7 +353,7 @@ def test_array_tanh_rounding_mode(shape, tile, dtype, rounding_mode, tmp_path):
353353 launch_unary (kernel , x , y , tile )
354354 else :
355355 bytecode = get_bytecode (kernel , (x , y , tile ))
356- if rounding_mode is RMd .FULL :
356+ if rounding_mode in ( RMd .FULL , None ) :
357357 # FULL is the default, not included in mlir text
358358 check_directive = "// CHECK: %[[RES:.*]] = tanh %[[A:.*]]{{[[:space:]]*}}:"
359359 else :
@@ -363,3 +363,35 @@ def test_array_tanh_rounding_mode(shape, tile, dtype, rounding_mode, tmp_path):
363363 filecheck (bytecode , check_directive )
364364 launch_unary (kernel , x , y , tile )
365365 assert_close (y , y_ref )
366+
367+
368+ @requires_tileiras (BytecodeVersion .V_13_3 )
369+ @pytest .mark .use_mlir
370+ @pytest .mark .parametrize ("dtype" , float_dtypes , ids = dtype_id )
371+ @pytest .mark .parametrize ("rounding_mode" , [RMd .FULL , RMd .APPROX , None ])
372+ def test_array_exp_rounding_mode (shape , tile , dtype , rounding_mode , tmp_path ):
373+ should_raise_dtype = rounding_mode in [RMd .FULL , RMd .APPROX ] and dtype != torch .float32
374+ x = make_tensor (shape , dtype = dtype , device = 'cuda' )
375+ y_ref = torch .exp (x )
376+ y = torch .zeros_like (y_ref , device = "cuda" )
377+ kernel = array_kernel ("exp_rounding_mode" ,
378+ f"ty = ct.exp(tx, rounding_mode={ rounding_mode } )" ,
379+ tmp_path ,
380+ globals = {"RoundingMode" : RMd })
381+ if should_raise_dtype :
382+ with pytest .raises (TileTypeError ,
383+ match = fr"Rounding mode { rounding_mode .value } can only be used for "
384+ "float32 type" ):
385+ launch_unary (kernel , x , y , tile )
386+ else :
387+ bytecode = get_bytecode (kernel , (x , y , tile ))
388+ if rounding_mode in (RMd .FULL , None ):
389+ # FULL is the default, not included in mlir text
390+ check_directive = "// CHECK: %[[RES:.*]] = exp %[[A:.*]]{{[[:space:]]*}}:"
391+ else :
392+ check_directive = (
393+ f"// CHECK: %[[RES:.*]] = exp %[[A:.*]] rounding<{ rounding_mode .value } >"
394+ )
395+ filecheck (bytecode , check_directive )
396+ launch_unary (kernel , x , y , tile )
397+ assert_close (y , y_ref )
0 commit comments