Skip to content

Unpack tuple/list input#4376

Closed
csukuangfj wants to merge 4 commits intoTencent:masterfrom
csukuangfj:fuse-input-unpack-1128
Closed

Unpack tuple/list input#4376
csukuangfj wants to merge 4 commits intoTencent:masterfrom
csukuangfj:fuse-input-unpack-1128

Conversation

@csukuangfj
Copy link
Contributor

For some networks, e.g., RNN/LSTM, we need to provide some previous states as input.

Example code

#!/usr/bin/env python3

import torch
import torch.nn as nn
from typing import List


class Foo(nn.Module):
    def forward(self, x, states: List[torch.Tensor]):
        return x, states[0] + states[1]


def main():
    f = Foo()

    x = torch.rand(2, 3, 4)
    y = [x, x]
    m = torch.jit.trace(f, (x, y))
    print(m.graph)
    m.save("m.pt")


if __name__ == "__main__":
    torch.manual_seed(20221128)
    main()

@nihui nihui self-requested a review November 29, 2022 08:08
@nihui
Copy link
Member

nihui commented Jan 4, 2023

we need level0 pass otherwise shape inference will fail

@csukuangfj
Copy link
Contributor Author

we need level0 pass otherwise shape inference will fail

Do you mean we should move it to pass_level0 ?

@nihui
Copy link
Member

nihui commented Jan 4, 2023

we need level0 pass otherwise shape inference will fail

Do you mean we should move it to pass_level0 ?

yeah, implement flatten inputs on torch graph, and apply the pass before shape inference

reference code
https://github.com/pytorch/pytorch/blob/3120054c151d37fdf43963dbb60ab420908f48cf/torch/csrc/jit/passes/lower_tuples.cpp#L185

@nihui nihui mentioned this pull request Jan 29, 2023
14 tasks
@nihui
Copy link
Member

nihui commented Jan 31, 2023

move to #4498

@nihui nihui closed this Jan 31, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants