55import math
66from typing import Callable
77from cuda .tile ._ir .ir import Block
8- from cuda .tile ._ir .ops import TypedConst
8+ from cuda .tile ._ir .ops import TileAtomicRMW , TileAtomicRedView , TypedConst , AtomicRMWMode
99from cuda .tile ._ir .type import TileTy , PointerTy , Type
10- from cuda .tile ._datatype import DType , float4_e2m1fn , float8_e4m3fn , float8_e5m2 , float8_e8m0fnu
10+ from cuda .tile ._datatype import (
11+ DType , float4_e2m1fn , float8_e4m3fn , float8_e5m2 , float8_e8m0fnu , bfloat16
12+ )
1113from cuda .tile ._bytecode .version import BytecodeVersion
1214from cuda .tile ._exception import TileUnsupportedFeatureError , TileValueError
1315
@@ -65,8 +67,33 @@ def _check_const_value(op: TypedConst):
6567 raise TileValueError (msg , loc = op .loc )
6668
6769
68- def _check_dtype (dtype : DType , sm_arch : str , version : BytecodeVersion , loc ):
69- sm_number = int (sm_arch .removeprefix ("sm_" ))
70+ def _check_atomic_rmw_dtype (op : TileAtomicRedView | TileAtomicRMW ,
71+ sm_arch : str ,
72+ sm_number : int ,
73+ version : BytecodeVersion ):
74+ dtypes = (_extract_dtypes (op .view .try_get_type ())
75+ if isinstance (op , TileAtomicRedView ) else
76+ _extract_dtypes (op .result_vars [0 ].try_get_type ()))
77+ if not (op .mode == AtomicRMWMode .ADD_FLOAT and bfloat16 in dtypes ):
78+ return
79+
80+ if sm_number < 90 :
81+ raise TileUnsupportedFeatureError (
82+ f"{ bfloat16 } is not supported by atomic add on { sm_arch } " ,
83+ loc = op .loc
84+ )
85+
86+ min_version = BytecodeVersion .V_13_3
87+ if version < min_version :
88+ raise TileUnsupportedFeatureError (
89+ f"{ bfloat16 } on atomic add requires tileiras"
90+ f" { min_version .as_string ()} or later."
91+ f" Current version is { version .as_string ()} ." ,
92+ loc = op .loc
93+ )
94+
95+
96+ def _check_dtype (dtype : DType , sm_arch : str , sm_number : int , version : BytecodeVersion , loc ):
7097 min_sm = _DTYPE_MIN_SM .get (dtype )
7198 if min_sm is not None and sm_number < min_sm :
7299 raise TileUnsupportedFeatureError (
@@ -78,17 +105,21 @@ def _check_dtype(dtype: DType, sm_arch: str, version: BytecodeVersion, loc):
78105 if min_version is not None and version < min_version :
79106 raise TileUnsupportedFeatureError (
80107 f"{ dtype } requires tileiras"
81- f" { min_version .major () } . { min_version . minor ()} or later."
82- f" Current version is { version .major () } . { version . minor ()} ." ,
108+ f" { min_version .as_string ()} or later."
109+ f" Current version is { version .as_string ()} ." ,
83110 loc = loc ,
84111 )
85112
86113
87114def check_dtype_support (root_block : Block , sm_arch : str , version : BytecodeVersion ) -> None :
115+ sm_number = int (sm_arch .removeprefix ("sm_" ))
88116 for op in root_block .traverse ():
89117 if isinstance (op , TypedConst ):
90118 _check_const_value (op )
91119
120+ if isinstance (op , (TileAtomicRedView , TileAtomicRMW )):
121+ _check_atomic_rmw_dtype (op , sm_arch , sm_number , version )
122+
92123 all_dtypes = set ().union (* (_extract_dtypes (v .try_get_type ()) for v in op .all_inputs ()))
93124 for dtype in all_dtypes :
94- _check_dtype (dtype , sm_arch , version , op .loc )
125+ _check_dtype (dtype , sm_arch , sm_number , version , op .loc )
0 commit comments