
Slices with zero values at the beginning and end of tensor are lost during dense -> sparse -> dense conversion #316

Closed
AlexeyGB opened this issue Feb 16, 2021 · 1 comment

Comments

@AlexeyGB
Contributor

Describe the bug
I have a 2D dense tensor with zeros in its first and last columns. I convert it to a sparse tensor as described in the docs and then convert it back to a dense tensor. The shape of the resulting dense tensor differs from the original.
If I pass the correct shape to the sparse tensor's dense() method, the resulting dense tensor has the correct shape, but the values are shifted to the left.
I noticed the same behaviour for higher-dimensional tensors with zero values at the beginning and/or end of an axis.

To Reproduce

import torch
import MinkowskiEngine as ME

def to_sparse_coo(data):
    # An intuitive way to extract coordinates and features
    coords, feats = [], []
    for i, row in enumerate(data):
        for j, val in enumerate(row):
            if val != 0:
                coords.append([i, j])
                feats.append([val])
    return torch.IntTensor(coords), torch.FloatTensor(feats)

data_batch_0 = [
    [0, 0, 2.1, 0, 0],
    [0, 1, 1.4, 3, 0],
    [0, 0, 4.0, 0, 0]
]

coords0, feats0 = to_sparse_coo(data_batch_0)
coords0, feats0 = ME.utils.sparse_collate(coords=[coords0], feats=[feats0])

A = ME.SparseTensor(coordinates=coords0, features=feats0)
print(A.dense()[0].shape)

Returns:

torch.Size([1, 1, 3, 3])
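A spatial size of 3 x 3 is consistent with the dense tensor covering only the bounding box of the nonzero coordinates: along the column axis the nonzeros sit at indices 1 to 3, so columns 0 and 4 fall outside. A minimal pure-Python sketch of that bounding-box computation, assuming this is how the pre-fix dense() sized its output (`bounding_box` is a hypothetical helper, not part of MinkowskiEngine):

```python
# Pure-Python sketch (no MinkowskiEngine needed): collect the nonzero
# coordinates the same way to_sparse_coo does, then measure their bounding box.
data_batch_0 = [
    [0, 0, 2.1, 0, 0],
    [0, 1, 1.4, 3, 0],
    [0, 0, 4.0, 0, 0],
]

coords = [(i, j) for i, row in enumerate(data_batch_0)
          for j, val in enumerate(row) if val != 0]


def bounding_box(coords):
    """Return (per-axis minimum, per-axis extent) of a list of coordinate tuples."""
    ndim = len(coords[0])
    mins = tuple(min(c[d] for c in coords) for d in range(ndim))
    maxs = tuple(max(c[d] for c in coords) for d in range(ndim))
    extent = tuple(hi - lo + 1 for lo, hi in zip(mins, maxs))
    return mins, extent


mins, extent = bounding_box(coords)
print(mins, extent)  # (0, 1) (3, 3)
```

The extent (3, 3) matches the trailing dimensions of torch.Size([1, 1, 3, 3]) above; the helper works for any number of axes, which fits the observation that higher-dimensional tensors lose their leading and trailing zero slices too.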

One more attempt, this time passing the expected shape to dense():

data_batch_0 = [
    [0, 0, 2.1, 0, 0],
    [0, 1, 1.4, 3, 0],
    [0, 0, 4.0, 0, 0]
]

coords0, feats0 = to_sparse_coo(data_batch_0)
coords0, feats0 = ME.utils.sparse_collate(coords=[coords0], feats=[feats0])

A = ME.SparseTensor(coordinates=coords0, features=feats0)
data_batch_restored = A.dense(torch.Size([1, 1, 3, 5]))[0]

print(data_batch_restored.shape)
print(data_batch_restored)

Returns:

torch.Size([1, 1, 3, 5])
tensor([[[[0.0000, 2.1000, 0.0000, 0.0000, 0.0000],
          [1.0000, 1.4000, 3.0000, 0.0000, 0.0000],
          [0.0000, 4.0000, 0.0000, 0.0000, 0.0000]]]])
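The left shift in this output matches what happens if every coordinate is translated by the per-axis minimum coordinate, (0, 1) here, before being written into a buffer of the requested shape. A minimal pure-Python emulation, assuming that translation is the mechanism (consistent with the maintainer's explanation in this thread); `coo_to_dense` is a hypothetical helper, not MinkowskiEngine code:

```python
# Hypothetical helper emulating a COO -> dense conversion in which the dense
# index is (coordinate - min_coordinate). Not MinkowskiEngine code.
def coo_to_dense(coords, feats, shape, min_coordinate):
    rows, cols = shape
    dense = [[0.0] * cols for _ in range(rows)]
    for (i, j), val in zip(coords, feats):
        di, dj = i - min_coordinate[0], j - min_coordinate[1]
        # Guard against Python's negative indexing silently wrapping around.
        if not (0 <= di < rows and 0 <= dj < cols):
            raise IndexError(f"coordinate {(i, j)} falls outside shape {shape}")
        dense[di][dj] = val
    return dense


data_batch_0 = [
    [0, 0, 2.1, 0, 0],
    [0, 1, 1.4, 3, 0],
    [0, 0, 4.0, 0, 0],
]
# Both lists are built in the same row-major order, so they stay aligned.
coords = [(i, j) for i, row in enumerate(data_batch_0)
          for j, val in enumerate(row) if val != 0]
feats = [val for row in data_batch_0 for val in row if val != 0]

# Origin moved to the minimum coordinate (0, 1): every value lands one column left.
shifted = coo_to_dense(coords, feats, (3, 5), min_coordinate=(0, 1))
# Origin kept at (0, 0): the original layout is recovered.
restored = coo_to_dense(coords, feats, (3, 5), min_coordinate=(0, 0))
print(shifted[1])  # [1, 1.4, 3, 0.0, 0.0]
```

`shifted` reproduces the layout printed above, while `restored` equals the input, which is the round trip the report expects.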


Expected behavior
I expect to obtain the same dense tensor after conversion.

Desktop:

  • OS: Ubuntu 18.04.4 LTS
  • Python version: 3.8.1
  • CUDA version: 10.2
  • NVIDIA Driver version: 440.64
  • Minkowski Engine version: 0.5.1
  • Output of the diagnostics script:
wget -q https://raw.githubusercontent.com/NVIDIA/MinkowskiEngine/master/MinkowskiEngine/diagnostics.py ; python diagnostics.py
==========System==========
Linux-5.3.0-28-generic-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=18.04
DISTRIB_CODENAME=bionic
DISTRIB_DESCRIPTION="Ubuntu 18.04.4 LTS"
3.8.1 (default, Jan  8 2020, 22:29:32) 
[GCC 7.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 440.64
CUDA Version 10.2
VBIOS Version 90.02.2E.00.0C
Image Version G001.0000.02.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
==========CC==========
/usr/bin/c++
c++ (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.1
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 10020
CUDART version MinkowskiEngine is compiled: 10020


@chrischoy
Contributor

This is intended behavior in MinkowskiEngine, but since it is not obvious and the issue keeps coming up, I changed the behavior in commit aba0db2.

For background: MinkowskiEngine uses a generalized sparse tensor, which allows coordinates to be negative.

Negative coordinates cannot be converted to a dense tensor without first translating the origin to the minimum coordinate. Previous versions therefore subtracted the minimum coordinate (min_coordinate) from the sparse tensor's coordinates automatically. This is no longer the default: if any coordinate is negative, dense() raises a ValueError, and you must provide min_coordinate manually to resolve it.
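The new default can be sketched in plain Python. `to_dense_indices` is a hypothetical helper, not MinkowskiEngine code: with no min_coordinate it keeps the origin at 0 and raises ValueError on negative coordinates; with an explicit min_coordinate it translates every coordinate by that value. In MinkowskiEngine itself the corresponding call would presumably look like `A.dense(min_coordinate=torch.IntTensor([0, 0]))`, assuming min_coordinate takes one integer per spatial axis; check the SparseTensor.dense documentation for your version.

```python
# Hypothetical sketch of the post-fix default; not MinkowskiEngine's code.
def to_dense_indices(coords, min_coordinate=None):
    """Map sparse coordinates to dense indices by subtracting an origin."""
    ndim = len(coords[0])
    if min_coordinate is None:
        # Default: origin stays at 0, so negative coordinates cannot be mapped.
        if any(x < 0 for c in coords for x in c):
            raise ValueError(
                "negative coordinates found; pass min_coordinate explicitly")
        min_coordinate = (0,) * ndim
    return [tuple(x - m for x, m in zip(c, min_coordinate)) for c in coords]


# Nonnegative coordinates keep their positions, so leading zero slices survive.
print(to_dense_indices([(0, 2), (1, 1)]))            # [(0, 2), (1, 1)]
# Negative coordinates need an explicit origin.
print(to_dense_indices([(-1, 0), (2, 3)], (-1, 0)))  # [(0, 0), (3, 3)]
```

Keeping the origin at 0 by default is what preserves zero-valued leading slices such as column 0 in the report; translating by the minimum is still available, but only when requested explicitly.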
