Description
Hi, I'm working on using torch_np to connect numpy to torch dynamo. One issue we are seeing is that all of the torch_np APIs return `torch_np.ndarray`, which means `torch_np.ndarray` objects end up in our captured fx graph. This gives our downstream backends a hard time because they don't recognize `torch_np.ndarray` (and arguably shouldn't).
So I'm wondering if we can have a "mode"-style switch that makes all of the torch_np APIs return `torch.Tensor` instead of `torch_np.ndarray`. Is this easy to do?
For example, given this function:

```python
def fn(x, y):
    a = x.numpy()
    b = y.numpy()
    return np.add(a, 1), np.add(b, 1)
```
If we capture it through dynamo, it gives
```
[2023-05-15 11:40:25,582] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
 __compiled_fn_0 <eval_with_key>.0
opcode         name       target                               args             kwargs
-------------  ---------  -----------------------------------  ---------------  --------
placeholder    l_x_       L_x_                                 ()               {}
placeholder    l_y_       L_y_                                 ()               {}
call_function  ndarray    <class 'torch_np._ndarray.ndarray'>  (l_x_,)          {}
call_function  ndarray_1  <class 'torch_np._ndarray.ndarray'>  (l_y_,)          {}
call_function  add        <function add at 0x12de77a60>        (ndarray, 1)     {}
call_function  add_1      <function add at 0x12de77a60>        (ndarray_1, 1)   {}
output         output     output                               ((add, add_1),)  {}
```
Notice that both `a` and `b` are converted to `torch_np.ndarray`, and their results after `add` are still `torch_np.ndarray`. As a result, the outputs of the graph are also `torch_np.ndarray`, and torch dynamo's backends can't work with these graphs.
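Absent such a mode, one workaround would be a post-processing pass that unwraps the wrapper arrays from the graph outputs back to plain tensors. A minimal sketch of the idea, assuming the wrapper exposes its underlying value via a `tensor` attribute (a stand-in class is used here since torch_np's internal attribute name may differ):

```python
# Sketch: unwrap wrapper-array outputs back to the underlying values.
# `FakeNdarray` is a stand-in for torch_np.ndarray; the assumption is
# that the real wrapper exposes its torch.Tensor via `.tensor`.

class FakeNdarray:
    def __init__(self, tensor):
        self.tensor = tensor  # the wrapped value (a torch.Tensor in practice)

def unwrap(value):
    """Return the underlying value if `value` is a wrapper, else pass it through."""
    return value.tensor if isinstance(value, FakeNdarray) else value

def unwrap_outputs(outputs):
    """Apply `unwrap` to every element of a graph's output tuple."""
    return tuple(unwrap(v) for v in outputs)

outs = unwrap_outputs((FakeNdarray(42), 7))
print(outs)  # (42, 7)
```

This only papers over the output boundary, though; intermediate nodes in the graph would still carry `torch_np.ndarray`, which is why a library-level return-tensor mode would be the cleaner fix.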