1313from cuda .tile .compilation import CallingConvention , KernelSignature
1414
1515
16- def compile_with_version (pyfunc , args , version : str ):
17- kernel = ct .kernel (pyfunc )
16+ def compile_with_version (kernel , args , version : str ):
1817 cconv = CallingConvention .cutile_python_v1 ()
1918 sig = KernelSignature .from_kernel_args (kernel , args , cconv )
2019 ct .compilation .export_kernel (kernel , [sig ], output_file = BytesIO (),
@@ -27,6 +26,7 @@ def tensor(dtype=torch.float32):
2726
2827
2928def test_atan2_requires_13_2 ():
29+ @ct .kernel
3030 def kernel (x , y , z ):
3131 tx = ct .load (x , 0 , shape = 64 )
3232 ty = ct .load (y , 0 , shape = 64 )
@@ -37,6 +37,7 @@ def kernel(x, y, z):
3737
3838
3939def test_tanh_rounding_mode_requires_13_2 ():
40+ @ct .kernel
4041 def kernel (x , y ):
4142 tx = ct .load (x , 0 , shape = 64 )
4243 ct .store (y , 0 , tile = ct .tanh (tx , rounding_mode = RoundingMode .APPROX ))
@@ -47,6 +48,7 @@ def kernel(x, y):
4748
4849
4950def test_tanh_without_rounding_mode_works_on_13_1 ():
51+ @ct .kernel
5052 def kernel (x , y ):
5153 tx = ct .load (x , 0 , shape = 64 )
5254 ct .store (y , 0 , tile = ct .tanh (tx ))
@@ -56,6 +58,7 @@ def kernel(x, y):
5658
5759
5860def test_exp_rounding_mode_requires_13_3 ():
61+ @ct .kernel
5962 def kernel (x , y ):
6063 tx = ct .load (x , 0 , shape = 64 )
6164 ct .store (y , 0 , tile = ct .exp (tx , rounding_mode = RoundingMode .APPROX ))
@@ -66,9 +69,21 @@ def kernel(x, y):
6669
6770
def test_exp_without_rounding_mode_works_on_13_1():
    """Compile a plain ``ct.exp`` kernel (no rounding mode) for version "13.1"."""

    @ct.kernel
    def kernel(x, y):
        # Load a 64-element tile, exponentiate it, and write it back out.
        src_tile = ct.load(x, 0, shape=64)
        exp_tile = ct.exp(src_tile)
        ct.store(y, 0, tile=exp_tile)

    kernel_args = (tensor(), tensor())
    # The call itself is the check: it must complete without a version error.
    compile_with_version(kernel, kernel_args, "13.1")
79+
80+
def test_num_worker_warps_warns_below_13_3():
    """Expect a ``UserWarning`` when a ``num_worker_warps`` kernel targets "13.1"."""

    @ct.kernel(num_worker_warps=8)
    def kernel(x, y):
        # Simple copy kernel: the body is irrelevant, only the kernel option matters.
        loaded = ct.load(x, 0, shape=64)
        ct.store(y, 0, tile=loaded)

    expected_message = r"num_worker_warps requires tileiras 13\.3"
    kernel_args = (tensor(), tensor())
    with pytest.warns(UserWarning, match=expected_message):
        compile_with_version(kernel, kernel_args, "13.1")
0 commit comments