[Bug] BufferError: Can't export tensors with layout other than torch.strided when model contains sparse tensors #18474

@LiSsHhUuAaIi

Description

When converting a PyTorch model that contains a sparse tensor (created with torch.sparse_coo_tensor) to a TVM Relax module via torch.export and from_exported_program, a BufferError is raised during DLPack conversion. DLPack only describes dense, strided tensors, and PyTorch's Tensor.__dlpack__ refuses to export any tensor whose layout is not torch.strided, such as the torch.sparse_coo layout used by sparse COO tensors.

Expected behavior

The PyTorch model with sparse tensors should be converted successfully to a TVM Relax module. Failing that, TVM should raise a clear error message and document its limitations around sparse tensors.
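For the second option, a frontend could scan the model's tensors before any DLPack conversion and report every sparse one by name. This is a minimal sketch, assuming the tensors are available as a name-to-tensor mapping (for example the exported program's parameters, buffers and lifted constants). find_non_strided and check_strided_tensors are hypothetical helpers, not existing TVM functions, and the choice of NotImplementedError is arbitrary.

```python
import torch


def find_non_strided(named_tensors):
    """Return the sorted names of tensors whose layout is not torch.strided."""
    return sorted(
        name
        for name, tensor in named_tensors.items()
        if isinstance(tensor, torch.Tensor) and tensor.layout != torch.strided
    )


def check_strided_tensors(named_tensors):
    """Hypothetical pre-check: fail early with a descriptive message instead
    of surfacing PyTorch's BufferError from deep inside DLPack conversion."""
    bad = find_non_strided(named_tensors)
    if bad:
        raise NotImplementedError(
            "Sparse (non-strided) tensors are not supported: "
            + ", ".join(bad)
            + ". Convert them with .to_dense() before calling torch.export."
        )


# Example input: a dense weight plus a sparse constant, keyed by name.
sparse = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 1], [1, 3, 2]]), torch.tensor([3.0, 4.0, 5.0]), [2, 4]
)
tensors = {"conv.weight": torch.randn(64, 3, 3, 3), "out": sparse}

try:
    check_strided_tensors(tensors)
    message = None
except NotImplementedError as exc:
    message = str(exc)

print(message)
```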

Actual behavior

from_exported_program raises BufferError: Can't export tensors with layout other than torch.strided. The exception comes from PyTorch's Tensor.__dlpack__ (see the traceback below), so the user sees a low-level DLPack error rather than a frontend message stating that sparse tensors are unsupported.

Environment

  • OS: Ubuntu 20.04.6 LTS
  • TVM version: 0.23.dev0
  • Python version: 3.11.14

Steps to reproduce

```python
import torch
import torch.nn as nn
import tvm
from tvm import relax

class TestModel(nn.Module):

    def __init__(self):
        super().__init__()
        i = torch.tensor([[0, 1, 1], [1, 3, 2]])  # row indices must be < 2 for size [2, 4]
        v = torch.tensor([3, 4, 5], dtype=torch.float32)
        out = torch.sparse_coo_tensor(i, v, [2, 4])  # layout: torch.sparse_coo
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)  # unused in forward()
        self.out = out

    def forward(self, x):
        return x.relu() + self.out.sum()

model = TestModel()
model.eval()

x = torch.randn(3, 3)

# PyTorch execution works
with torch.no_grad():
    output = model(x)

# PyTorch export works
exported_program = torch.export.export(model, (x,))

# TVM conversion fails
from tvm.relax.frontend.torch import from_exported_program
mod = from_exported_program(exported_program)  # BufferError here
```
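Until the frontend handles sparse tensors, one possible workaround is to densify them on the module before export. This is a minimal sketch under that assumption. densify_sparse_tensors and SparseConstModel are hypothetical names, the helper covers plain tensor attributes and registered buffers (not parameters), and it mutates the module in place.

```python
import torch
import torch.nn as nn


def densify_sparse_tensors(module: nn.Module) -> list:
    """Hypothetical workaround: replace every non-strided tensor attribute or
    buffer of `module` and its submodules with a dense copy, in place.
    Parameters are not handled. Returns the dotted names that were replaced."""
    replaced = []
    for prefix, sub in module.named_modules():
        # Plain tensor attributes (like self.out in the reproducer) live in the
        # instance __dict__; registered buffers are listed separately.
        candidates = [(k, v) for k, v in vars(sub).items() if isinstance(v, torch.Tensor)]
        candidates += list(sub.named_buffers(recurse=False))
        for attr, tensor in candidates:
            if tensor.layout != torch.strided:
                setattr(sub, attr, tensor.to_dense())
                replaced.append(f"{prefix}.{attr}" if prefix else attr)
    return replaced


class SparseConstModel(nn.Module):
    def __init__(self):
        super().__init__()
        i = torch.tensor([[0, 1, 1], [1, 3, 2]])
        v = torch.tensor([3.0, 4.0, 5.0])
        self.out = torch.sparse_coo_tensor(i, v, [2, 4])

    def forward(self, x):
        return x.relu() + self.out.sum()


model = SparseConstModel().eval()
x = torch.randn(3, 3)
with torch.no_grad():
    before = model(x)

replaced = densify_sparse_tensors(model)

with torch.no_grad():
    after = model(x)

print(replaced, model.out.layout)
```

The densified model can then go through torch.export.export and from_exported_program as in the reproducer. This should sidestep this particular BufferError, at the cost of storing the tensor densely.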

Error Log

```
Traceback (most recent call last):
  File "test.py", line 33, in <module>
    mod = from_exported_program(exported_program)  # BufferError here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ...
  File "python/tvm_ffi/cython/tensor.pxi", line 189, in core.from_dlpack
  File "python/tvm_ffi/cython/tensor.pxi", line 114, in core._from_dlpack_universal
  File "/home/miniconda3/envs/tvm-build/lib/python3.11/site-packages/torch/_tensor.py", line 1718, in __dlpack__
    raise BufferError(
BufferError: Can't export tensors with layout other than torch.strided
```

Triage

  • needs-triage
  • bug
  • frontend: pytorch
