Skip to content

Commit

Permalink
fix(xla): fix indexing problem when partial trace
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 08c9e3387d0974facd7297f1e475bec22cd66069
  • Loading branch information
megvii-mge committed Dec 1, 2023
1 parent 75c4d9c commit 88b3b80
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 273 deletions.
4 changes: 3 additions & 1 deletion imperative/python/megengine/jit/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,9 @@ def tensor_wrapper_resethook(obj, other):
# process the returned value of traced function
outlist, self.outdef = tree_flatten(outputs)
for i, out in enumerate(outlist):
assert isinstance(out, RawTensor), f"get out of type {type(out)}"
assert isinstance(
out, RawTensor
), f"return value of traced function must be tensor, get {type(out)}"
outlist[i] = get_marked_output_tensor(self.output_num, out)
del out
self.out_list.append(self.output_num)
Expand Down
4 changes: 3 additions & 1 deletion imperative/python/megengine/traced_module/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def qparams_unflatten(qparam_type, inp, aux_data):
dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
)
_register_supported_type(
defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
defaultdict,
partial(_dict_flatten, False),
partial(_dict_unflatten, partial(defaultdict, None)),
)
_register_supported_type(
OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
Expand Down
Loading

0 comments on commit 88b3b80

Please sign in to comment.