[Feature Request] Support for Spatial Transformer Network operations in PyTorch frontend #18475

@LiSsHhUuAaIi

Description

When converting a PyTorch model that contains Spatial Transformer Network (STN) operations to a TVM Relax module via torch.export, from_exported_program raises an AssertionError. TVM's PyTorch frontend does not currently support affine_grid_generator.default and grid_sampler.default. These are the ATen operators that F.affine_grid and F.grid_sample lower to, and every STN depends on them.

Expected behavior

The PyTorch model with STN operations should convert successfully to a TVM Relax module, so that spatial transformation models can be deployed on the hardware targets TVM supports.

Actual behavior

from_exported_program fails with an AssertionError reporting Unsupported function types ['affine_grid_generator.default', 'grid_sampler.default'], which shows that TVM's PyTorch frontend has no converters for these spatial transformation operations.

AssertionError: Unsupported function types ['affine_grid_generator.default', 'grid_sampler.default']
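The second unsupported op, grid_sampler.default, reads the input image at the normalized locations given by a sampling grid, such as the one F.affine_grid produces. F.grid_sample defaults to mode='bilinear' and padding_mode='zeros', which is what the repro exercises. The pure-Python sketch below covers only that 2-D bilinear, zero-padded case and follows my reading of the PyTorch documentation. `grid_sample_2d`, its helpers, and the nested-list layouts are hypothetical. The 'nearest'/'bicubic' modes and 'border'/'reflection' padding are not handled.

```python
import math
from typing import List, Tuple

# Hypothetical nested-list layouts (not TVM or PyTorch types).
Plane = List[List[float]]                     # H x W
Tensor4D = List[List[Plane]]                  # N x C x H x W
Grid = List[List[List[Tuple[float, float]]]]  # N x H_out x W_out of (x, y) in [-1, 1]


def _unnormalize(coord: float, size: int, align_corners: bool) -> float:
    """Map a normalized coordinate in [-1, 1] to a fractional pixel index."""
    if align_corners:
        return (coord + 1.0) / 2.0 * (size - 1)
    return ((coord + 1.0) * size - 1.0) / 2.0


def _pixel(plane: Plane, y: int, x: int) -> float:
    """Read one pixel; locations outside the image read as 0 (zero padding)."""
    if 0 <= y < len(plane) and 0 <= x < len(plane[0]):
        return plane[y][x]
    return 0.0


def _bilinear(plane: Plane, gx: float, gy: float, align_corners: bool) -> float:
    ix = _unnormalize(gx, len(plane[0]), align_corners)
    iy = _unnormalize(gy, len(plane), align_corners)
    x0, y0 = math.floor(ix), math.floor(iy)
    wx1, wy1 = ix - x0, iy - y0      # weights of the right and bottom neighbours
    wx0, wy0 = 1.0 - wx1, 1.0 - wy1  # weights of the left and top neighbours
    return (wy0 * wx0 * _pixel(plane, y0, x0)
            + wy0 * wx1 * _pixel(plane, y0, x0 + 1)
            + wy1 * wx0 * _pixel(plane, y0 + 1, x0)
            + wy1 * wx1 * _pixel(plane, y0 + 1, x0 + 1))


def grid_sample_2d(inp: Tensor4D, grid: Grid, align_corners: bool = False) -> Tensor4D:
    """Bilinear, zero-padded sampling; returns N x C x H_out x W_out."""
    return [
        [[[_bilinear(plane, gx, gy, align_corners) for gx, gy in row] for row in g]
         for plane in img]
        for img, g in zip(inp, grid)
    ]


img = [[[[1.0, 2.0], [3.0, 4.0]]]]
print(grid_sample_2d(img, [[[(0.0, 0.0)]]]))  # centre of a 2x2 image
```

Feeding the sketch an identity grid (pixel-centre coordinates of plus or minus 0.5 for a 2 x 2 input) returns the input unchanged. A point at the image corner, such as (1.0, 1.0), blends the corner pixel with zero padding.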

Environment

  • OS: Ubuntu 20.04.6 LTS
  • TVM version: 0.23.dev0
  • Python version: 3.11.14

Steps to reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F
import tvm
from tvm import relax

class MinimalSTNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Localization network: regresses the 6 parameters of a 2x3 affine matrix
        self.localizer = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(8, 6),
        )

    def forward(self, x):
        theta = self.localizer(x)
        theta = theta.view(-1, 2, 3)

        # Unsupported operations (align_corners=False is already the default;
        # passing it explicitly only silences PyTorch's UserWarning)
        grid = F.affine_grid(theta, x.size(), align_corners=False)  # affine_grid_generator.default
        x = F.grid_sample(x, grid, align_corners=False)             # grid_sampler.default

        return x

model = MinimalSTNModel()
model.eval()

x = torch.randn(1, 3, 32, 32)

# PyTorch execution works
with torch.no_grad():
    output = model(x)

# PyTorch export works
exported_program = torch.export.export(model, (x,))

# TVM conversion fails
from tvm.relax.frontend.torch import from_exported_program
mod = from_exported_program(exported_program)  # AssertionError here

Triage

  • needs-triage

Metadata

Assignees: No one assigned

Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
