In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# models

> Module that implements different models for processing point cloud data.

# UNDER CONSTRUCTION...

In [None]:
#| default_exp models

In [None]:
#| export
import torch
from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU, ModuleList, MaxPool2d, ConvTranspose2d

In [None]:
#| export
class Block(Module):
    """Two conv -> batch-norm -> ReLU stages applied back to back.

    Each convolution uses a 3x3 kernel, stride 1, circular padding of 1 and no
    bias, so the spatial size of the input is preserved. The first stage maps
    `in_channels` to `out_channels`; the second keeps `out_channels`.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage changes the channel count, second stage keeps it.
        for stage_in in (in_channels, out_channels):
            stages += [
                Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1,
                       bias=False, padding_mode='circular'),
                BatchNorm2d(out_channels),
                ReLU(),
            ]
        self.net = Sequential(*stages)

    def forward(self, x):
        """Run `x` through both stages and return the resulting feature map."""
        out = self.net(x)
        return out

In [None]:
#| eval: false
from colorcloud.datatools import SemanticKITTIDataset, SphericalProjectionTransform, ToTensorTransform
from torchvision.transforms import v2

In [None]:
#| eval: false
# Path to a local SemanticKITTI copy, relative to this notebook; adjust it for your checkout.
data_path = '../../Cloud2DImageConverter/point_clouds/semantic_kitti/'
ds = SemanticKITTIDataset(data_path)

# Transform pipeline applied to every sample: spherical projection with H=64, W=1024,
# then conversion to tensors (see colorcloud.datatools for the exact semantics).
tfms = v2.Compose([
    SphericalProjectionTransform(fov_up_deg=4., fov_down_deg=-26., W=1024, H=64),
    ToTensorTransform(),
])
ds.set_transform(tfms)
img, label, mask = ds[0]

# Smoke-test Block on one sample; the reshape prepends a batch dimension of size 1.
b = Block(5, 64)
activations = b(img.reshape(-1, *img.shape))
activations.shape

torch.Size([1, 64, 64, 1024])

In [None]:
#| export
class Encoder(Module):
    """Contracting path: a chain of `Block`s with 2x2 max pooling between them.

    `channels` lists the channel count at every stage boundary: `channels[0]`
    is the number of input channels and each following entry is the output
    width of one `Block`, so `len(channels) - 1` blocks are created.
    """
    def __init__(self, channels=(5, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.blocks = ModuleList(
            [Block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Return the list of feature maps produced by every block, shallowest first.

        The output of each block is stored before pooling, so every entry has
        half the spatial size of the previous one.
        """
        enc_features = []
        last = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            x = block(x)
            enc_features.append(x)
            # Only pool between blocks: the pooled output of the last block is
            # never used, and pooling it would raise when its spatial size is 1.
            if i < last:
                x = self.pool(x)
        return enc_features

In [None]:
#| eval: false
# Run the encoder on the sample loaded above (batch dimension prepended by the reshape)
# and list the shape of the feature map produced by each block.
enc = Encoder()
activations = enc(img.reshape(-1, *img.shape))
[a.shape for a in activations]

[torch.Size([1, 64, 64, 1024]),
 torch.Size([1, 128, 32, 512]),
 torch.Size([1, 256, 16, 256]),
 torch.Size([1, 512, 8, 128]),
 torch.Size([1, 1024, 4, 64])]

In [None]:
#| export
class Decoder(Module):
    """Expanding path: upsample, concatenate a skip feature map, refine with a `Block`.

    `channels` lists the channel count at every stage boundary, deepest first.
    Stage `i` upsamples from `channels[i]` to `channels[i+1]` channels with a
    2x2 stride-2 transposed convolution, concatenates a skip map that is
    expected to have `channels[i+1]` channels as well, and reduces the result
    back to `channels[i+1]` channels.
    """
    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i+1], 2, 2) for i in range(len(channels)-1)]
        )
        # The block input is the concatenation of the upsampled tensor and the
        # skip map, i.e. 2*channels[i+1] channels. Sizing it from channels[i]
        # only works when every stage halves the channel count exactly.
        self.blocks = ModuleList(
            [Block(2*channels[i+1], channels[i+1]) for i in range(len(channels)-1)]
        )

    def forward(self, enc_features):
        """Decode a list of encoder feature maps ordered shallowest first.

        `enc_features[-1]` is the starting tensor; stage `i` uses
        `enc_features[-(i+2)]` as its skip connection. Returns the output of
        the last block.
        """
        x = enc_features[-1]
        for i, (upconv, block) in enumerate(zip(self.upconvs, self.blocks)):
            x = upconv(x)
            # Concatenate along the channel dimension with the matching skip map.
            x = torch.cat([x, enc_features[-(i+2)]], dim=1)
            x = block(x)
        return x

In [None]:
#| eval: false
# Feed the encoder feature list from the previous cell through the decoder.
# The result gets its own name so `activations` stays intact and this cell can be re-run.
dec = Decoder()
decoded = dec(activations)
decoded.shape

torch.Size([1, 64, 64, 1024])

In [None]:
#| export
class UNet(Module):
    """Encoder-decoder backbone followed by a 1x1 convolution classification head.

    `in_channels` is the channel count of the input, `hidden_channels` the
    widths of the encoder blocks (the decoder uses them in reverse order), and
    `n_classes` the number of output channels produced by the head.
    """
    def __init__(self, in_channels=5, hidden_channels=(64, 128, 256, 512, 1024), n_classes=20):
        super().__init__()
        encoder = Encoder((in_channels, *hidden_channels))
        decoder = Decoder(hidden_channels[::-1])
        # The encoder's list of feature maps is passed straight to the decoder.
        self.backbone = Sequential(encoder, decoder)
        # The decoder ends at the shallowest width, hidden_channels[0].
        self.head = Conv2d(hidden_channels[0], n_classes, 1)

    def forward(self, x):
        """Return the head's output for the backbone features of `x`."""
        return self.head(self.backbone(x))

In [None]:
#| eval: false
# End-to-end check: run the full model on the sample loaded above (batch dimension
# prepended by the reshape) and inspect the shape of the output.
model = UNet()
logits = model(img.reshape(-1, *img.shape))
logits.shape

torch.Size([1, 20, 64, 1024])

In [None]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out as library code.
import nbdev; nbdev.nbdev_export()