In [None]:
%load_ext autoreload
%autoreload 2

# models

> Module that implements different models for processing point cloud data.

# UNDER CONSTRUCTION...

In [None]:
#| default_exp models

In [None]:
#| export
import torch
from torch.nn import Module, Sequential, Conv2d, ConvTranspose2d, BatchNorm2d, ReLU, ModuleList, MaxPool2d

In [None]:
#| export
class Block(Module):
    """Two stacked `Conv2d -> BatchNorm2d -> ReLU` stages.

    Both convolutions use a 3x3 kernel, stride 1, padding 1 and
    `padding_mode='circular'`, so the spatial size of the input is kept and only
    the channel count changes from `in_channels` to `out_channels`.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by both convolutions.
    bn_momentum : float, default 0.99
        `momentum` passed to both `BatchNorm2d` layers.
    bn_eps : float, default 1e-05
        `eps` passed to both `BatchNorm2d` layers.
    """
    def __init__(self, in_channels, out_channels, bn_momentum=0.99, bn_eps=1e-05):
        super().__init__()
        # NOTE(review): in PyTorch `momentum` is the weight given to the *newest* batch
        # statistic (running = (1 - momentum) * running + momentum * batch), which is the
        # opposite of the Keras/TF convention. The 0.99 default is kept for backward
        # compatibility -- confirm it is intended and not a port of a Keras value.
        self.net = Sequential(
            Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                   bias=False, padding_mode='circular'),
            BatchNorm2d(out_channels, eps=bn_eps, momentum=bn_momentum),
            ReLU(),
            Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1,
                   bias=False, padding_mode='circular'),
            BatchNorm2d(out_channels, eps=bn_eps, momentum=bn_momentum),
            ReLU(),
        )

    def forward(self, x):
        "Apply both convolution stages to `x` and return the result."
        return self.net(x)

In [None]:
#| eval: false
from colorcloud.datatools import SemanticKITTIDataset, SphericalProjectionTransform, ToTensorTransform
from torchvision.transforms import v2

In [None]:
#| eval: false
# Location of the SemanticKITTI data, given relative to this notebook.
data_path = '../../Cloud2DImageConverter/point_clouds/semantic_kitti/'
ds = SemanticKITTIDataset(data_path)

# Transform pipeline applied to every sample: spherical projection configured with
# H=64, W=1024, followed by ToTensorTransform (both come from colorcloud.datatools).
tfms = v2.Compose([
    SphericalProjectionTransform(fov_up_deg=4., fov_down_deg=-26., W=1024, H=64),
    ToTensorTransform(),
])
ds.set_transform(tfms)
# Each dataset item unpacks into three values; only `img` is used below.
img, label, mask = ds[0]

# Smoke test of a single Block with 5 input channels and 64 output channels.
# `reshape(-1, *img.shape)` prepends a batch axis (the -1 resolves to 1), and the
# recorded output shape below shows the spatial size 64 x 1024 is preserved.
b = Block(5, 64)
activations = b(img.reshape(-1, *img.shape))
activations.shape

torch.Size([1, 64, 64, 1024])

In [None]:
#| export
class Encoder(Module):
    "Contracting path: a chain of `Block`s, each followed by the same 2x2 max-pooling."
    def __init__(self, channels=(2, 64, 128, 256, 512, 1024)):
        super().__init__()
        # Consecutive entries of `channels` give the (input, output) widths of each block.
        stages = [Block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.encBlocks = ModuleList(stages)
        self.pool = MaxPool2d(2)

    def forward(self, x):
        "Run `x` through every block, store each block output and apply the max-pooling; return the final pooled tensor and the list of stored outputs."
        skips = []
        out = x
        for stage in self.encBlocks:
            out = stage(out)
            # Keep the pre-pooling activation of every block.
            skips.append(out)
            out = self.pool(out)
        return out, skips

In [None]:
#| export
class Decoder(Module):
    """Expanding path: repeated (transposed-conv upsampling, skip concatenation, `Block`).

    `channels[i] -> channels[i+1]` gives the widths of stage `i`. Each stage's
    `Block` takes `channels[i]` input channels, so the skip feature concatenated at
    stage `i` must contribute `channels[i] - channels[i+1]` channels.
    """
    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.channels = channels
        # 2x2 transposed convolutions with stride 2: each one doubles H and W.
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i+1], 2, 2) for i in range(len(channels)-1)]
        )
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i+1]) for i in range(len(channels)-1)]
        )

    def forward(self, x, encFeatures):
        "Upsample `x` stage by stage, concatenating `encFeatures[i]` (cropped to match) on the channel axis before each block."
        for i in range(len(self.channels)-1):
            x = self.upconvs[i](x)
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        return x

    def crop(self, encFeatures, x):
        """Center-crop the last two axes of `encFeatures` to the spatial size of `x`.

        Implemented with plain slicing so the exported module does not need
        torchvision. The offsets use the same `int(round(diff / 2.0))` rule as
        torchvision's center crop.

        Raises
        ------
        ValueError
            If `encFeatures` is spatially smaller than `x` (nothing to crop from).
        """
        H, W = x.shape[-2:]
        encH, encW = encFeatures.shape[-2:]
        if encH < H or encW < W:
            raise ValueError(
                f"cannot crop features of size {(encH, encW)} to the larger size {(H, W)}"
            )
        top = int(round((encH - H) / 2.0))
        left = int(round((encW - W) / 2.0))
        return encFeatures[..., top:top + H, left:left + W]

In [None]:
#| export
class UNet(Module):
    """U-Net built from `Encoder`, `Decoder` and a 1x1 convolution head.

    Parameters
    ----------
    encChannels : sequence of int
        Channel widths passed to `Encoder`.
    decChannels : sequence of int
        Channel widths passed to `Decoder`.
    nbClasses : int
        Number of output channels of the head (one score map per class).
    """
    def __init__(self, encChannels=(2, 64, 128, 256, 512, 1024), decChannels=(1024, 512, 256, 128, 64), nbClasses=20):
        super().__init__()
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)

        # Segmentation head: 1x1 convolution mapping decoder features to per-class scores
        self.head = Conv2d(decChannels[-1], nbClasses, 1)

    def forward(self, x):
        "Encode `x`, decode from the deepest block output using the shallower ones as skips, and return the per-class score map."
        # The encoder returns (pooled output, list of per-block outputs). Only the list is
        # used: its last entry is the bottleneck, the others are the skip connections.
        _, blockOutputs = self.encoder(x)
        deepestFirst = blockOutputs[::-1]
        decFeatures = self.decoder(deepestFirst[0], deepestFirst[1:])

        # segmentation map
        projection = self.head(decFeatures)

        return projection

In [None]:
#| hide
# Export every cell marked `#| export` to the module named by `#| default_exp`.
import nbdev
nbdev.nbdev_export()