fx.export_and_import hangs #4157

Open
@justin-ngo-arm

Description

I have a simple program:

import torch
from torch_mlir import fx

class Conv2D(torch.nn.Module):

    def __init__(
        self,
        kernel_size=3,
        in_channels=8,
        out_channels=16,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)

if __name__ == "__main__":
    model = Conv2D(
        kernel_size=(3, 3),
        in_channels=3,
        out_channels=8,
        stride=(1, 2),
        padding=(1, 1),
        dilation=(1, 1),
        bias=False,
    )
    model.eval()  # Set to evaluation mode
    example_input = torch.randn(2, 3, 5, 32, requires_grad=True, device="cpu")
    prog = torch.export.export(model, (example_input,))
    torch_module = fx.export_and_import(
        prog,
        func_name="temp",
        enable_graph_printing=False,
        import_symbolic_shape_expressions=True,
    )
    print(torch_module)

When I run it, it generates the Torch-MLIR module as expected. However, once the program finishes, the process does not exit cleanly; it hangs, and I have to press Ctrl+C to exit. The same thing happens with one of the examples, projects/pt1/examples/fximporter_resnet18.py (I have not checked the other examples).
I ran my program under a debugger, and it looks like some internal Python cleanup at interpreter shutdown gets stuck in a loop or something similar. I'm not sure what causes it.
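
One way to narrow down where the shutdown gets stuck might be the standard-library faulthandler module, which can dump every thread's traceback if the process is still alive after a timeout. The sketch below is only a diagnostic aid, not a fix: the helper names arm_hang_watchdog and live_non_daemon_threads are hypothetical, the 30-second timeout is arbitrary, and it assumes the hang happens early enough in interpreter finalization for faulthandler to still be active.

```python
import faulthandler
import sys
import threading


def arm_hang_watchdog(timeout_s=30.0, file=None):
    """Dump all thread tracebacks if the process is still alive after timeout_s.

    Hypothetical helper: call it just before the script's last statement so
    the dump (if any) reflects the state during shutdown.
    """
    target = file if file is not None else sys.stderr
    # exit=False: only report; the process is left running for a debugger.
    faulthandler.dump_traceback_later(timeout_s, exit=False, file=target)


def live_non_daemon_threads():
    """Return running non-daemon threads other than the main thread.

    The interpreter waits for these threads before it exits, so a leftover
    one is a common reason for a script that never returns to the shell.
    """
    main = threading.main_thread()
    return [t for t in threading.enumerate() if t is not main and not t.daemon]
```

Calling arm_hang_watchdog() right after print(torch_module) and printing live_non_daemon_threads() at the same point may show whether a lingering thread or a stuck cleanup routine is keeping the process alive.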
