Commit
[Dynamo] Small enhancements for graph dump ir and task arguments (hidet-org#172)

* exp, float

* wip

* chunk, groupnorm, softmax, baddbmm, empty

* add interpolate, lint and format

* revert changes of import hidet at top level to minimize changes for PR

* typo

* trigger actions

* trigger actions

* dummy commit

* dummy commit

* add some optimizations to skip certain operations based on alpha beta

* add group norm test

* format

* introduce a fix to torch.compile not dumping graph IR

* Revert "introduce a fix to torch.compile not dumping graph IR"

This reverts commit a1e8df0.

* add interpolate test and group norm test

* accidental push

* remove a random newline added

* minor hot fix

* add a function level static var for multi-graph save ir

* small line change

* revert previous changes

* move save dir to new utility function

* format lint and remove optional attr

* trigger actions

* trigger actions

* Update operator.py

---------

Co-authored-by: Xin Li <xin@centml.ai>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
3 people authored and AndreSlavescu committed Apr 25, 2023
1 parent 8cf28db commit 107582a
Showing 3 changed files with 14 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -17,7 +17,7 @@
 from hidet.ir.type import data_type
 from hidet.graph.ir.flow_graph import FlowGraph
 from hidet.graph.transforms import PassContext, optimize
-from .utils import serialize_output, deserialize_output
+from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph
 from .dynamo_config import dynamo_config


@@ -43,7 +43,8 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
         ctx.set_reduce_precision('float16')
         ctx.set_use_attention(use_attention)
         if save_dir:
-            ctx.save_graph_instrument(save_dir)
+            graph_dir = resolve_save_dir_multigraph(save_dir)
+            ctx.save_graph_instrument(graph_dir)
         if tensor_core:
             ctx.set_mma('mma' if tensor_core else 'simt')
         ctx.set_parallel_k(disabled=(parallel_k == 'disabled'), search=(parallel_k == 'search'))
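
The hunk above is the graph-dump half of the title: instead of writing the IR dump for every compiled graph into one shared directory, the backend now asks resolve_save_dir_multigraph for a numbered subdirectory, so successive subgraphs handed over by torch.compile no longer overwrite each other's dumps. A minimal sketch of how the changed lines fit together, assuming an illustrative wrapper (optimize_and_dump is not a function in the diff):

    from hidet.graph.ir.flow_graph import FlowGraph
    from hidet.graph.transforms import PassContext, optimize
    from hidet.graph.frontend.torch.utils import resolve_save_dir_multigraph

    def optimize_and_dump(flow_graph: FlowGraph, save_dir: str) -> FlowGraph:
        # Each call picks the next numbered folder under save_dir
        # (save_dir/graph_1, save_dir/graph_2, ...), one per compiled graph.
        graph_dir = resolve_save_dir_multigraph(save_dir)
        with PassContext() as ctx:
            ctx.save_graph_instrument(graph_dir)  # dump graph IR while optimizing
            return optimize(flow_graph)
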
9 changes: 9 additions & 0 deletions python/hidet/graph/frontend/torch/utils.py
@@ -10,6 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Tuple, Any, List, Union, Dict, Optional
+from pathlib import Path
 from hidet.graph.tensor import Tensor
 from hidet.ir.type import DataType
 from hidet.ir import dtypes
@@ -245,3 +246,11 @@ def relative_absolute_error(actual, expected) -> float:
     actual: torch.Tensor = actual.detach()
     expected: torch.Tensor = expected.detach()
     return float(torch.max(torch.abs(actual - expected) / (torch.abs(expected) + 1.0)))
+
+
+def resolve_save_dir_multigraph(save_dir: str) -> str:
+    func = resolve_save_dir_multigraph
+    if not hasattr(func, 'counter'):
+        func.counter = {}
+    func.counter[save_dir] = func.counter.get(save_dir, 0) + 1
+    return str(Path(save_dir) / "graph_{}".format(func.counter[save_dir]))
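
The new helper keeps a per-directory counter in a function attribute, so each call with the same save_dir returns the next numbered subdirectory. A quick illustration of that behavior (the paths are made up; POSIX-style separators shown):

    from hidet.graph.frontend.torch.utils import resolve_save_dir_multigraph

    print(resolve_save_dir_multigraph('outs/graph_ir'))  # outs/graph_ir/graph_1
    print(resolve_save_dir_multigraph('outs/graph_ir'))  # outs/graph_ir/graph_2
    print(resolve_save_dir_multigraph('outs/other'))     # outs/other/graph_1 (counters are independent per directory)
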
4 changes: 2 additions & 2 deletions python/hidet/graph/operator.py
@@ -40,7 +40,7 @@ def __init__(self, inputs: List[Tensor], attributes: Dict[str, Any], task: Optio

         self.name: str = get_operator_name(self)
         self.inputs: List[Tensor] = inputs
-        self.attrs: Dict[str, Any] = attributes if attributes is not None else {}
+        self.attrs: Dict[str, Any] = attributes
         self.task: Optional[Task] = task.specialize_for(self.inputs)
         self.outputs: List[Tensor] = []

@@ -126,7 +126,7 @@ def imperative_run(self, inputs: List[Tensor]) -> List[Tensor]:
         for a, b in zip(self.task.outputs, outputs):
             arg_remap[a] = b

-        args = [remap[param] for param in self.task.params]
+        args = [arg_remap[param] for param in self.task.params]
         self.task_func(*args)

         status = get_last_error()
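
These two hunks are the task-arguments half of the title: the first drops the is-None fallback for attributes (the commit's "remove optional attr"), and the second fixes a NameError in imperative_run, where the mapping is built as arg_remap but the argument list indexed an undefined remap. A sketch of the corrected pattern, with task_params and tensors as illustrative stand-ins rather than names from the repository:

    # Map each task parameter to its concrete tensor, then build the call
    # arguments in parameter order; indexing 'remap' here used to raise NameError.
    arg_remap = {}
    for param, tensor in zip(task_params, tensors):
        arg_remap[param] = tensor
    args = [arg_remap[param] for param in task_params]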
