Skip to content

Commit

Permalink
add dimension check
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Apr 23, 2024
1 parent f7a6121 commit fb5071f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2204,10 +2204,13 @@ def _matmul_check(
return False

enable_matmul: None | bool = get_compile_option("nv_enable_matmul", "Enable nvFuser matmul.")
if enable_matmul is None:
enable_matmul = False

return enable_matmul and is_supported_tensor(a) and is_supported_tensor(b)
if not enable_matmul:
return False
if not are_supported_tensors(a, b):
return False
if not (a.ndim == b.ndim and a.ndim == 2):
return False
return True


def matmul(
Expand Down

0 comments on commit fb5071f

Please sign in to comment.