1111from util import (
1212 assert_close , assert_equal , require_hopper_or_newer , torch_to_tf32 , is_ampere_or_ada
1313)
14- from conftest import dtype_id
14+ from conftest import dtype_id , get_tileiras_version
15+ from cuda .tile ._bytecode .version import BytecodeVersion
1516from cuda .tile ._exception import TileTypeError , TileUnsupportedFeatureError
1617
1718
def mma_kernel(A, B, C,
               tm: ct.Constant[int],
               tn: ct.Constant[int],
               tk: ct.Constant[int],
               use_fast_acc: ct.Constant[bool]):
    # Single-tile multiply-accumulate: reads one (tm, tk) tile of A, one
    # (tk, tn) tile of B and one (tm, tn) tile of C, all at tile index (0, 0),
    # combines them with ct.mma and writes the result back over C's tile.
    # `use_fast_acc` is a compile-time constant forwarded unchanged to ct.mma.
    lhs_tile = ct.load(A, index=(0, 0), shape=(tm, tk))
    rhs_tile = ct.load(B, index=(0, 0), shape=(tk, tn))
    acc_tile = ct.load(C, index=(0, 0), shape=(tm, tn))
    result_tile = ct.mma(lhs_tile, rhs_tile, acc_tile, use_fast_acc=use_fast_acc)
    # C serves as both the accumulator input and the output.
    ct.store(C, index=(0, 0), tile=result_tile)
3032
3133
@@ -110,32 +112,56 @@ def test_mma_regular_float(tile_size, case):
110112 C = torch .ones ((m , n ), dtype = case .acc_dtype , device = "cuda" )
111113 ref = torch .mm (A , B , out_dtype = C .dtype ) + C
112114 ct .launch (torch .cuda .current_stream (), (1 ,), mma_kernel ,
113- (A , B , C , m , n , k ))
115+ (A , B , C , m , n , k , False ))
114116 atol , rtol = get_tolerance (A .dtype )
115117 assert_close (C , ref , atol = atol , rtol = rtol )
116118
117119
@ct.kernel
def mma_fast_acc_kernel(A, B, C,
                        tm: ct.Constant[int],
                        tn: ct.Constant[int],
                        tk: ct.Constant[int]):
    # Same single-tile load / mma / store sequence as a plain MMA kernel, but
    # with use_fast_acc hard-wired to True so the flag is part of the kernel
    # itself rather than a launch argument.
    lhs_tile = ct.load(A, index=(0, 0), shape=(tm, tk))
    rhs_tile = ct.load(B, index=(0, 0), shape=(tk, tn))
    acc_tile = ct.load(C, index=(0, 0), shape=(tm, tn))
    result_tile = ct.mma(lhs_tile, rhs_tile, acc_tile, use_fast_acc=True)
    # The accumulator tile in C is overwritten with the MMA result.
    ct.store(C, index=(0, 0), tile=result_tile)
130+
131+
@require_hopper_or_newer()
@pytest.mark.parametrize("tile_size", [(16, 16, 16)])
@pytest.mark.parametrize("case", fp8_cases, ids=str)
@pytest.mark.parametrize("use_fast_acc", [True, False])
def test_mma_fp8(tile_size, case, use_fast_acc):
    """Compare ``mma_kernel`` on fp8 inputs against ``torch._scaled_mm``.

    The test runs with ``use_fast_acc`` both on and off. The fast-accumulate
    variant is skipped when the installed tileiras version is older than
    ``BytecodeVersion.V_13_3``. If torch refuses to build the reference with
    the known "two Float8_e5m2 matrices" error, the kernel is still launched
    but its output is not compared against anything.
    """
    fast_acc_unavailable = (
        use_fast_acc and get_tileiras_version() < BytecodeVersion.V_13_3
    )
    if fast_acc_unavailable:
        pytest.skip("use_fast_acc requires tileiras 13.3")

    rows, cols, depth = tile_size
    # Inputs are drawn in float32 and then cast to the fp8 dtype under test.
    A = torch.randn((rows, depth), dtype=torch.float32, device="cuda").to(case.dtype)
    # B is created as (n, k) and handed to both torch and the kernel transposed.
    B = torch.randn((cols, depth), dtype=torch.float32, device="cuda").to(case.dtype)
    C = torch.ones((rows, cols), dtype=case.acc_dtype, device="cuda")
    unit_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    B_t = B.T

    # The reference must be built before the launch: the kernel stores its
    # result into C, which is also the addend of the reference.
    try:
        product = torch._scaled_mm(A, B_t, unit_scale, unit_scale,
                                   out_dtype=C.dtype,
                                   use_fast_accum=use_fast_acc)
        ref = product + C
    except (RuntimeError, ValueError) as e:
        assert 'Multiplication of two Float8_e5m2 matrices is not supported' in str(e)
        ref = None

    launch_args = (A, B_t, C, rows, cols, depth, use_fast_acc)
    ct.launch(torch.cuda.current_stream(), (1,), mma_kernel, launch_args)

    if ref is None:
        # No torch reference for this dtype pair; a successful launch is all
        # that gets checked.
        return
    atol, rtol = get_tolerance(A.dtype)
    assert_close(C, ref, atol=atol, rtol=rtol)
137155
138156
def test_mma_fast_acc_non_fp8_error():
    """Launching the fast-accumulate kernel on float16 data raises TileTypeError.

    The error message is expected to match
    "use_fast_acc is only supported for fp8".
    """
    # NOTE(review): unlike test_mma_fp8 there is no tileiras version skip
    # here -- confirm the dtype check fires before any version check on
    # older toolchains.
    half = torch.float16
    A = torch.randn((2, 4), dtype=half, device="cuda")
    B = torch.randn((4, 2), dtype=half, device="cuda")
    C = torch.zeros((2, 2), dtype=half, device="cuda")
    launch_args = (A, B, C, 2, 2, 4)
    with pytest.raises(TileTypeError, match="use_fast_acc is only supported for fp8"):
        ct.launch(torch.cuda.current_stream(), (1,), mma_fast_acc_kernel, launch_args)
163+
164+
139165@pytest .mark .parametrize ("tile_size" , [(8 , 2 , 4 )])
140166def test_mma_tf32 (tile_size ):
141167 m , n , k = tile_size
@@ -163,7 +189,7 @@ def test_mma_int(tile_size, case):
163189 C = torch .ones ((m , n ), dtype = case .acc_dtype , device = "cuda" )
164190 ref = C + (A .to (torch .float32 ) @ B .to (torch .float32 )).to (C .dtype )
165191 ct .launch (torch .cuda .current_stream (), (1 ,), mma_kernel ,
166- (A , B , C , m , n , k ))
192+ (A , B , C , m , n , k , False ))
167193 assert_equal (C , ref )
168194
169195
@@ -175,7 +201,7 @@ def test_mma_mixed_int_uint(tile_size):
175201 C = torch .ones ((m , n ), dtype = torch .int32 , device = "cuda" )
176202 ref = C + (A .to (torch .float32 ) @ B .to (torch .float32 )).to (C .dtype )
177203 ct .launch (torch .cuda .current_stream (), (1 ,), mma_kernel ,
178- (A , B , C , m , n , k ))
204+ (A , B , C , m , n , k , False ))
179205 assert_equal (C , ref )
180206
181207
@@ -229,7 +255,7 @@ def test_mma_dtype_error(case):
229255 with pytest .raises (TileTypeError , match = case .message ):
230256 ct .launch (torch .cuda .current_stream (),
231257 (1 ,), mma_kernel ,
232- (A , B , C , 2 , 2 , 2 ))
258+ (A , B , C , 2 , 2 , 2 , False ))
233259
234260# ================ ct.matmul =================
235261
@@ -405,4 +431,4 @@ def test_ampere_fp8_error(dtype):
405431 with pytest .raises (TileUnsupportedFeatureError ,
406432 match = "is not supported on sm_80" ):
407433 ct .launch (torch .cuda .current_stream (), (1 ,), mma_kernel ,
408- (A , B , C , 16 , 16 , 16 ))
434+ (A , B , C , 16 , 16 , 16 , False ))
0 commit comments