Commit

Gemma+torch.compile fixes (autocast, rtruediv) (hidet-org#159)
Gemma+torch.compile fixes:
 - process `_enter_autocast` and `_exit_autocast` as nop
 - support `truediv(float, Tensor)`
 - add eager-mode support to `tests/benchmarks`
vadiklyutiy committed Apr 22, 2024
1 parent 1444c20 commit a9d4aa8
Showing 3 changed files with 14 additions and 2 deletions.
2 changes: 2 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -1175,6 +1175,8 @@ def torch_conj(x: Tensor) -> Tensor:
 @register_function(torch._C._log_api_usage_once)
 @register_function(torch._assert_async)
 @register_function(torch.cuda.synchronize)
+@register_function(torch.amp.autocast_mode._enter_autocast)
+@register_function(torch.amp.autocast_mode._exit_autocast)
 def torch_noop(*args, **kwargs):
     return

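These two registrations matter because torch.compile (Dynamo) lowers a `with torch.autocast(...)` block into explicit `_enter_autocast`/`_exit_autocast` calls in the traced graph; mapping both to `torch_noop` lets hidet's frontend skip them instead of failing, which is what the Gemma path exercises. A minimal sketch of the kind of model this unblocks (the module, shapes, and dtype below are illustrative, not taken from the commit):

import torch

class AutocastBlock(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        # Dynamo records _enter_autocast/_exit_autocast for this context manager;
        # with this commit, hidet's torch frontend treats both calls as no-ops.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            return self.linear(x)

model = AutocastBlock().cuda()
compiled = torch.compile(model, backend='hidet')  # hidet registered as a Dynamo backend
y = compiled(torch.randn(4, 16, device='cuda'))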
5 changes: 5 additions & 0 deletions python/hidet/graph/tensor.py
@@ -221,6 +221,11 @@ def __truediv__(self, other) -> Tensor:
 
         return divide(self, utils.convert_to_tensor(other, self))
 
+    def __rtruediv__(self, other) -> Tensor:
+        from .ops import divide, utils
+
+        return divide(utils.convert_to_tensor(other, self), self)
+
     def __mod__(self, other) -> Tensor:
         from .ops import mod, utils

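`__rtruediv__` is the reflected operator Python falls back to when the left operand of `/` is not a hidet `Tensor`, so expressions like `float / Tensor` (the `truediv(float, Tensor)` case from the commit message) now dispatch to hidet's `divide` instead of failing. A small usage sketch (values illustrative):

import hidet

x = hidet.ones([2, 3])  # hidet Tensor of ones
y = 1.0 / x             # calls x.__rtruediv__(1.0), i.e. divide(convert_to_tensor(1.0, x), x)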
9 changes: 7 additions & 2 deletions tests/benchmarks/bench_utils.py
@@ -2,8 +2,11 @@
 class Backend:
     def __init__(self, backend, dtype) -> None:
         assert (
-            backend == 'hidet' or backend == 'max-autotune' or backend == 'max-autotune-no-cudagraphs'
-        ), 'backend is hidet or max-autotune or max-autotune-no-cudagraphs supported only'
+            backend == 'hidet'
+            or backend == 'max-autotune'
+            or backend == 'max-autotune-no-cudagraphs'
+            or backend == 'eager'
+        ), 'backend is hidet or max-autotune or max-autotune-no-cudagraphs or eager supported only'
         self.backend = backend
         self.dtype = dtype
         if self.backend == 'hidet':
@@ -41,6 +44,8 @@ def compile(self, model):
 
         if self.backend == 'hidet':
             model = torch.compile(model, backend=self.backend)
+        elif self.backend == 'eager':
+            pass
         else:
             model = torch.compile(model, mode=self.backend)
         return model
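With the new `elif` branch, `Backend('eager', ...)` leaves the model uncompiled, giving the benchmark scripts a plain PyTorch eager baseline to compare against. A hedged usage sketch (the linear model and dtype below are stand-ins, not from the repository):

import torch
from bench_utils import Backend  # assumes tests/benchmarks is on sys.path

model = torch.nn.Linear(16, 16).cuda().half()
bench = Backend('eager', 'float16')
model = bench.compile(model)  # 'eager' skips torch.compile and returns the model as-is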
