In [90]:
import argparse
import random
from typing import Any, Dict

import numpy as np
import torch
import torch.utils.data
from torch import nn
from torch.cuda import amp

from typing import List
import torchsparse
from torchsparse import SparseTensor
from torchsparse import nn as spnn
from torchsparse.nn import functional as F
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize
from typing import List, Tuple, Union

import numpy as np
from torch import nn

from torchsparse import SparseTensor
from torchsparse import nn as spnn

In [91]:
class RandomDataset:
    """Synthetic dataset that samples a new random point cloud for every item.

    Items are drawn from NumPy's global random state, so the index passed to
    ``__getitem__`` is ignored and the reported length is a fixed constant.
    """

    def __init__(self, input_size: int, voxel_size: float) -> None:
        # Number of raw points sampled per item, before quantization.
        self.input_size = input_size
        # Voxel size forwarded to ``sparse_quantize``.
        self.voxel_size = voxel_size

    def __getitem__(self, _: int) -> Dict[str, Any]:
        """Generate one random sample.

        Returns:
            A dict with two ``SparseTensor`` values sharing the same
            coordinates: ``"input"`` (float features, 4 columns) and
            ``"label"`` (integer class ids in ``[0, 10)``).
        """
        num_points = self.input_size
        points = np.random.uniform(-100, 100, size=(num_points, 4))
        point_labels = np.random.choice(10, size=num_points)

        # ``xyz`` is a view into ``points``, so the in-place shift below also
        # rewrites the first three feature columns.
        xyz = points[:, :3]
        xyz -= np.min(xyz, axis=0, keepdims=True)
        # ``kept`` indexes the points retained by quantization; it is used to
        # select the matching feature rows and labels.
        voxels, kept = sparse_quantize(xyz, self.voxel_size, return_index=True)

        voxel_coords = torch.tensor(voxels, dtype=torch.int)
        voxel_feats = torch.tensor(points[kept], dtype=torch.float)
        voxel_labels = torch.tensor(point_labels[kept], dtype=torch.long)

        sparse_input = SparseTensor(coords=voxel_coords, feats=voxel_feats)
        sparse_label = SparseTensor(coords=voxel_coords, feats=voxel_labels)
        return {"input": sparse_input, "label": sparse_label}

    def __len__(self):
        # Fixed nominal size; every item is generated on demand.
        return 100


In [92]:
# Seed the global RNGs of Python, NumPy and PyTorch so runs are repeatable.
# RandomDataset draws its samples from NumPy's global state, so the NumPy seed
# determines the generated data.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)


<torch._C.Generator at 0x7e8dbc1402f0>

In [93]:
# Each item starts from 10000 random points and is quantized with voxel size 0.2.
dataset = RandomDataset(input_size=10000, voxel_size=0.2)
# Batches of two samples are merged by torchsparse's sparse collate function;
# with len(dataset) == 100 this yields 50 batches per pass.
dataflow = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        collate_fn=sparse_collate_fn,
    )


In [94]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [95]:
class SparseConvBlock(nn.Sequential):
    """Sparse 3D convolution followed by batch normalization and ReLU.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Convolution kernel size (scalar or per-axis).
        stride: Convolution stride (scalar or per-axis).
        dilation: Convolution dilation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, List[int], Tuple[int, ...]],
        stride: Union[int, List[int], Tuple[int, ...]] = 1,
        dilation: int = 1,
    ) -> None:
        # Build the layers in execution order, then register them with
        # nn.Sequential under the indices 0, 1 and 2.
        conv = spnn.Conv3d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        norm = spnn.BatchNorm(out_channels)
        activation = spnn.ReLU(True)
        super().__init__(conv, norm, activation)


In [96]:
class SparseConvTransposeBlock(nn.Sequential):
    """Transposed sparse 3D convolution followed by batch norm and ReLU.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Convolution kernel size (scalar or per-axis).
        stride: Convolution stride (scalar or per-axis).
        dilation: Convolution dilation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, List[int], Tuple[int, ...]],
        stride: Union[int, List[int], Tuple[int, ...]] = 1,
        dilation: int = 1,
    ) -> None:
        # Same layout as a regular conv block, except the convolution is
        # created with transposed=True.
        deconv = spnn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            transposed=True,
        )
        norm = spnn.BatchNorm(out_channels)
        activation = spnn.ReLU(True)
        super().__init__(deconv, norm, activation)


In [97]:
class SparseResBlock(nn.Module):
    """Residual block built from two sparse 3D convolutions.

    The main branch is conv -> BN -> ReLU -> conv -> BN. The shortcut is the
    identity when the channel count and stride leave the tensor layout
    unchanged, and a 1x1x1 conv + BN projection otherwise. The two branches
    are summed and passed through a final ReLU.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel size of both main-branch convolutions.
        stride: Stride of the first main-branch convolution and of the
            shortcut projection.
        dilation: Dilation of both main-branch convolutions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, List[int], Tuple[int, ...]],
        stride: Union[int, List[int], Tuple[int, ...]] = 1,
        dilation: int = 1,
    ) -> None:
        super().__init__()

        # Main branch: only the first convolution applies the stride.
        main_layers = [
            spnn.Conv3d(
                in_channels, out_channels, kernel_size, dilation=dilation, stride=stride
            ),
            spnn.BatchNorm(out_channels),
            spnn.ReLU(True),
            spnn.Conv3d(out_channels, out_channels, kernel_size, dilation=dilation),
            spnn.BatchNorm(out_channels),
        ]
        self.main = nn.Sequential(*main_layers)

        # A projection is needed whenever the main branch changes the channel
        # count or strides the input.
        needs_projection = in_channels != out_channels or np.prod(stride) != 1
        if not needs_projection:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                spnn.Conv3d(in_channels, out_channels, 1, stride=stride),
                spnn.BatchNorm(out_channels),
            )

        self.relu = spnn.ReLU(True)

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Return ``relu(main(x) + shortcut(x))``."""
        main_out = self.main(x)
        shortcut_out = self.shortcut(x)
        return self.relu(main_out + shortcut_out)

In [102]:
class SparseResUNet(nn.Module):
    """Sparse residual U-Net backbone.

    The network is a two-convolution stem followed by ``N`` encoder stages and
    ``N`` decoder stages, where ``N = len(encoder_channels)``. Every encoder
    stage downsamples with a stride-2 convolution and then applies two
    residual blocks. Every decoder stage upsamples with a stride-2 transposed
    convolution, concatenates the result with the encoder-path tensor of the
    same level, and applies two residual blocks.

    Args:
        stem_channels: Width of the stem, before ``width_multiplier``.
        encoder_channels: Output width of each encoder stage, outermost first.
        decoder_channels: Output width of each decoder stage, innermost first.
            Must have the same length as ``encoder_channels``.
        in_channels: Number of input feature channels.
        width_multiplier: Factor applied to every width; the product is
            truncated with ``int``.

    Raises:
        ValueError: If ``encoder_channels`` and ``decoder_channels`` have
            different lengths.
    """

    def __init__(
        self,
        stem_channels: int,
        encoder_channels: List[int],
        decoder_channels: List[int],
        *,
        in_channels: int = 4,
        width_multiplier: float = 1.0,
    ) -> None:
        super().__init__()
        # The recursive forward pass consumes one encoder and one decoder per
        # level, so both paths must have the same depth.
        if len(encoder_channels) != len(decoder_channels):
            raise ValueError(
                "encoder_channels and decoder_channels must have the same "
                f"length, got {len(encoder_channels)} and {len(decoder_channels)}"
            )

        self.stem_channels = stem_channels
        self.encoder_channels = encoder_channels
        self.decoder_channels = decoder_channels
        self.in_channels = in_channels
        self.width_multiplier = width_multiplier

        num_stages = len(encoder_channels)

        # Flat width table: [stem, enc_0 .. enc_{N-1}, dec_0 .. dec_{N-1}].
        num_channels = [stem_channels] + list(encoder_channels) + list(decoder_channels)
        num_channels = [int(width_multiplier * nc) for nc in num_channels]

        self.stem = nn.Sequential(
            spnn.Conv3d(in_channels, num_channels[0], 3),
            spnn.BatchNorm(num_channels[0]),
            spnn.ReLU(True),
            spnn.Conv3d(num_channels[0], num_channels[0], 3),
            spnn.BatchNorm(num_channels[0]),
            spnn.ReLU(True),
        )

        # Encoder stage k maps width num_channels[k] -> num_channels[k + 1].
        self.encoders = nn.ModuleList()
        for k in range(num_stages):
            self.encoders.append(
                nn.Sequential(
                    SparseConvBlock(
                        num_channels[k],
                        num_channels[k],
                        2,
                        stride=2,
                    ),
                    SparseResBlock(num_channels[k], num_channels[k + 1], 3),
                    SparseResBlock(num_channels[k + 1], num_channels[k + 1], 3),
                )
            )

        # Decoder stage k maps num_channels[N + k] -> num_channels[N + k + 1].
        # Its skip input is the encoder-path tensor of the matching level,
        # whose width is num_channels[N - 1 - k].
        self.decoders = nn.ModuleList()
        for k in range(num_stages):
            up_in = num_channels[num_stages + k]
            up_out = num_channels[num_stages + k + 1]
            skip = num_channels[num_stages - 1 - k]
            self.decoders.append(
                nn.ModuleDict(
                    {
                        "upsample": SparseConvTransposeBlock(
                            up_in,
                            up_out,
                            2,
                            stride=2,
                        ),
                        "fuse": nn.Sequential(
                            SparseResBlock(
                                up_out + skip,
                                up_out,
                                3,
                            ),
                            SparseResBlock(
                                up_out,
                                up_out,
                                3,
                            ),
                        ),
                    }
                )
            )

    def _unet_forward(
        self,
        x: SparseTensor,
        encoders: nn.ModuleList,
        decoders: nn.ModuleList,
    ) -> List[SparseTensor]:
        """Run one U-Net level and recurse into the levels below it.

        Args:
            x: Input tensor of the current level.
            encoders: Remaining encoder stages; the first one is used here.
            decoders: Remaining decoder stages; the last one is used here.

        Returns:
            ``[x] + inner outputs + [fused output of this level]``. At the
            bottom of the recursion (no stages left) this is just ``[x]``.
        """
        if not encoders and not decoders:
            return [x]

        # downsample
        xd = encoders[0](x)

        # inner recursion
        outputs = self._unet_forward(xd, encoders[1:], decoders[:-1])
        yd = outputs[-1]

        # upsample and fuse
        u = decoders[-1]["upsample"](yd)
        y = decoders[-1]["fuse"](torchsparse.cat([u, x]))

        return [x] + outputs + [y]

    def forward(self, x: SparseTensor) -> List[SparseTensor]:
        """Apply the stem and the full encoder/decoder recursion.

        Returns:
            A list of ``2 * N + 1`` tensors: the stem output, the ``N``
            encoder outputs (outermost first), then the ``N`` decoder outputs
            (innermost first).
        """
        return self._unet_forward(self.stem(x), self.encoders, self.decoders)


class SparseResUNet42(SparseResUNet):
    """``SparseResUNet`` preset with a 32-wide stem and four stages per path.

    Any keyword arguments are forwarded unchanged to
    ``SparseResUNet.__init__``.
    """

    def __init__(self, **kwargs) -> None:
        encoder_widths = [32, 64, 128, 256]
        decoder_widths = [256, 128, 96, 96]
        super().__init__(
            stem_channels=32,
            encoder_channels=encoder_widths,
            decoder_channels=decoder_widths,
            **kwargs,
        )

In [103]:
# Create the model with the constructor defaults
# (in_channels=4, width_multiplier=1.0).
model = SparseResUNet42()

In [106]:
# Move the model's parameters and buffers to the selected device. As the last
# expression in the cell, the returned module is what the notebook displays.
model.to(device)

SparseResUNet42(
  (stem): Sequential(
    (0): Conv3d(4, 32, kernel_size=(3, 3, 3), bias=False)
    (1): BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), bias=False)
    (4): BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (encoders): ModuleList(
    (0): Sequential(
      (0): SparseConvBlock(
        (0): Conv3d(32, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
        (1): BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): SparseResBlock(
        (main): Sequential(
          (0): Conv3d(32, 32, kernel_size=(3, 3, 3), bias=False)
          (1): BatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(32, 32, kernel_size=(3, 3, 3), bias=False)
          (

In [107]:
# Loss and optimizer for training the model.
# NOTE(review): neither is used by the forward-only loop visible in this
# excerpt — confirm a training step follows.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [110]:
# Forward-only smoke test: push every batch through the model and print the
# shapes of the inputs, the labels and the first returned tensor.
for k, feed_dict in enumerate(dataflow):
    inputs = feed_dict["input"].to(device=device)
    labels = feed_dict["label"].to(device=device)
    print("inputs coords = ", inputs.coords.shape, "inputs feats =", inputs.feats.shape)
    print("labels coords = ", labels.coords.shape,  "labels feats=", labels.feats.shape)

    # The model returns a list of sparse tensors; index 0 is the stem output.
    outputs = model(inputs)
    # The attribute is `feats`; the previous `featss` raised AttributeError.
    print("outputs =", outputs[0].coords.shape, outputs[0].feats.shape)
            

       


inputs coords =  torch.Size([20000, 4]) inputs feats = torch.Size([20000, 4])
labels coords =  torch.Size([20000, 4]) labels feats= torch.Size([20000])
outputs = 9
inputs coords =  torch.Size([19999, 4]) inputs feats = torch.Size([19999, 4])
labels coords =  torch.Size([19999, 4]) labels feats= torch.Size([19999])
outputs = 9
inputs coords =  torch.Size([20000, 4]) inputs feats = torch.Size([20000, 4])
labels coords =  torch.Size([20000, 4]) labels feats= torch.Size([20000])
outputs = 9
inputs coords =  torch.Size([19999, 4]) inputs feats = torch.Size([19999, 4])
labels coords =  torch.Size([19999, 4]) labels feats= torch.Size([19999])
outputs = 9
inputs coords =  torch.Size([20000, 4]) inputs feats = torch.Size([20000, 4])
labels coords =  torch.Size([20000, 4]) labels feats= torch.Size([20000])
outputs = 9
inputs coords =  torch.Size([20000, 4]) inputs feats = torch.Size([20000, 4])
labels coords =  torch.Size([20000, 4]) labels feats= torch.Size([20000])
outputs = 9
inputs coords = 