In [2]:
import os
import sys
from pathlib import Path

import pandas as pd
import numpy as np

import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set()

In [74]:
def make_sparse(x):
    nz = np.where(np.logical_not(np.isclose(x.numpy(), 0)))
    sparse_x = T.sparse_coo_tensor(nz, x[nz], size=x.size(), dtype=T.float32)
    
    return sparse_x


class SparseTensorDataset(Dataset):
    def __init__(self, X, Y):
        assert len(X) == len(Y)
        self.X = X
        self.Y = Y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        x = self.X[idx]
        y = self.Y[idx]
        
        sparse_x = make_sparse(x)
        
        return (sparse_x, y)

In [45]:
class Net(nn.Module):
    def __init__(self, input_size, output_size):
        super(Net, self).__init__()
        
        self.seq = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(inplace=True),
            nn.Linear(input_size, output_size),
        )
    
    def forward(self, sx):
        return T.sigmoid(self.seq(sx))

### 1. Using `SparseTensor` out-of-the-box

This will raise:
```
NotImplementedError: Cannot access storage of SparseTensorImpl
```

In [4]:
N = 10
M = 32

X = T.rand(N, M)
Y = T.randint(0, 2, size=(N,))

ds = SparseTensorDataset(X, Y)
dl = DataLoader(ds, batch_size=4, shuffle=False, pin_memory=True, num_workers=2)

for batch_idx, (sx, y) in enumerate(dl):
    print(sx, y)
    break

NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/alex/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/alex/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/alex/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/alex/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/alex/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 54, in default_collate
    storage = elem.storage()._new_shared(numel)
NotImplementedError: Cannot access storage of SparseTensorImpl


### 2. Using `SparseTensor` without multiprocessing and memory pinning

In [60]:
N = 1024 * 20
M = 8

X = T.rand(N, M) * T.randint(0, 2, size=(N, M))
Y = (X.sum(dim=-1) < 2).to(T.float32)

ds = SparseTensorDataset(X, Y)
dl = DataLoader(ds, batch_size=256, shuffle=False, pin_memory=False, num_workers=0)

net = Net(input_size=M, output_size=1)
loss_crit = nn.BCELoss()
optimizer = T.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.001)
scheduler = T.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
avg_loss = 0

net.cuda()

for epoch_idx in range(1, 60 + 1):
    
    for batch_idx, (sx, y) in enumerate(dl, start=1):
        sx, y = sx.cuda(), y.cuda()
        
        # print(f"{sx=}")
        # print('-' * 128)
        # print(f"{y=}")

        # print('-' * 128)

        y_pred = net(sx)
        # print(f'{y_pred=}')

        # print('-' * 128)

        optimizer.zero_grad()
        loss = loss_crit(y_pred.to(T.float32), y.unsqueeze(1).to(T.float32))
        loss.backward()
        optimizer.step()
        
        avg_loss += loss.item() / len(dl)
    # ---

    print(f'{epoch_idx=:2d} | lr={scheduler.get_last_lr()} | avg_loss={loss.item():.5f}', end='\r')

    scheduler.step()
    avg_loss = 0

epoch_idx=60 | lr=[0.0005] | avg_loss=0.04111

In [106]:
x = T.rand(1, M) * T.randint(0, 2, size=(1, M))
true = (x.sum(dim=-1) < 2).item()

net(x.cuda()), net(make_sparse(x).cuda()), true

(tensor([[0.9930]], device='cuda:0', grad_fn=<SigmoidBackward>),
 tensor([[0.9930]], device='cuda:0', grad_fn=<SigmoidBackward>),
 True)