
[PyTorch] Add aten::embedding_bag#12027

Closed
shingjan wants to merge 5 commits into apache:main from shingjan:torch_embedding_bag

Conversation


@shingjan shingjan commented Jul 7, 2022

This PR intends to add aten::embedding_bag to the PyTorch frontend. Note that the implementation of aten::numel is also changed so that any input that can be evaluated to a constant value is now evaluated at compile time.

Co-authored-by: Masahiro Masuda masahi129@gmail.com

cc: @masahi @vinx13
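For context on the aten::numel change, here is a minimal sketch in plain Python of the compile-time folding idea described above. The function name and shape representation are illustrative assumptions, not the actual frontend code: when every dimension of the input shape is statically known, the element count can be computed at conversion time; otherwise it has to remain a runtime expression.

```python
import math

# Hypothetical sketch of compile-time evaluation of aten::numel:
# if the input shape is fully static, numel folds to a Python int.
def try_fold_numel(shape):
    # shape: list of ints, with None marking a dynamic dimension
    if all(isinstance(dim, int) for dim in shape):
        return math.prod(shape)  # constant, usable at compile time
    return None  # dynamic shape: leave numel as a runtime expression
```

For example, `try_fold_numel([2, 4])` folds to `8`, while `try_fold_numel([None, 4])` returns `None` and would fall back to a runtime computation.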


shingjan commented Jul 7, 2022

@masahi This implementation still fails the test case with 1D input, as the offsets are parsed into:

free_var %input2: Tensor[(2), int64];
%input2

which cannot be folded to a constant at compile time, even though this 1D case is actually identical to the 2D case:

    # 1D case
    input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])
    # which is equivalent to  
    input = torch.tensor([[2, 2, 2, 2], [4, 3, 2, 9]])

My take is that we can support the 2D case here and see if there is any workaround for 1D cases like the one above. Note that sparse and per_sample_weights are also taken care of.
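For readers unfamiliar with the op, a hedged NumPy sketch of the 2D semantics this conversion targets (function and variable names are illustrative, not the actual frontend code): embedding_bag on a 2D index tensor is an embedding lookup followed by a reduction over the bag dimension, with per_sample_weights optionally scaling each looked-up row before the reduction.

```python
import numpy as np

# Illustrative sketch of 2D aten::embedding_bag semantics:
# gather rows, optionally scale them, then reduce over each bag.
def embedding_bag_2d(weight, indices, mode="sum", per_sample_weights=None):
    gathered = weight[indices]  # shape: (num_bags, bag_size, dim)
    if per_sample_weights is not None:
        # per_sample_weights has shape (num_bags, bag_size)
        gathered = gathered * per_sample_weights[..., None]
    if mode == "sum":
        return gathered.sum(axis=1)
    if mode == "mean":
        return gathered.mean(axis=1)
    raise ValueError(f"unsupported mode: {mode}")
```

Under this view, the 1D-input form with evenly spaced offsets reduces to the 2D form by a simple reshape, which is exactly the equivalence shown in the snippet above.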

@shingjan shingjan force-pushed the torch_embedding_bag branch from d9690c3 to 20eddfd on July 19, 2022

shingjan commented Oct 5, 2022

Closed with #12993.
