# Dataset

In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# Code from SalsaNext: https://github.com/Halmstad-University/SalsaNext
from laserscan import LaserScan, SemLaserScan
from laserscanvis import LaserScanVis

In [2]:
# Dataset directory; override with the WADS_DATA_DIR environment variable.
data_dir = os.environ.get("WADS_DATA_DIR", "/tmp/wads_dataset/clean_dataset/")


def collect_scan_label_paths(root):
    '''
    Collect paired point-cloud (.bin) and label (.label) file paths under root.

    Scans come from directories named exactly "velodyne" and labels from
    directories named exactly "labels". Matching is on the directory's own
    name, not on a substring of the whole path, so a root such as
    "/data/velodyne_labels/" cannot misroute files. Directories are walked in
    sorted order, so the returned order is reproducible across machines.

    Args:
        root: dataset root directory to walk.

    Returns:
        (scans, labels): two lists of paths. scans[i] and labels[i] live in
        the same sequence directory and share the same file stem.

    Raises:
        ValueError: if some scan has no matching label, or vice versa.
    '''
    scan_paths, label_paths = [], []
    for dirpath, dirnames, filenames in os.walk(root):
        # Sorting in place makes os.walk descend into subdirectories in a
        # deterministic order (the default listing order is filesystem-dependent).
        dirnames.sort()
        leaf = os.path.basename(os.path.normpath(dirpath))
        if leaf == "labels":
            label_paths.extend(
                os.path.join(dirpath, name) for name in sorted(filenames) if name.endswith(".label")
            )
        elif leaf == "velodyne":
            scan_paths.extend(
                os.path.join(dirpath, name) for name in sorted(filenames) if name.endswith(".bin")
            )

    def pair_key(path):
        # (sequence directory, file stem), e.g. ("<root>/24", "040908")
        sequence_dir = os.path.dirname(os.path.dirname(path))
        return sequence_dir, os.path.splitext(os.path.basename(path))[0]

    # Index-wise pairing is only valid if both lists line up key-for-key.
    scan_keys = [pair_key(p) for p in scan_paths]
    label_keys = [pair_key(p) for p in label_paths]
    if scan_keys != label_keys:
        unlabeled = set(scan_keys) - set(label_keys)
        orphan_labels = set(label_keys) - set(scan_keys)
        raise ValueError(
            f"scan/label files under {root!r} are not paired: "
            f"{len(unlabeled)} scan(s) without a label, "
            f"{len(orphan_labels)} label(s) without a scan"
        )
    return scan_paths, label_paths


# Lists of .bin scan paths and their matching .label paths
scans, labels = collect_scan_label_paths(data_dir)

In [3]:
# Peek at the first few scan and label paths; entries at the same index should share a file stem.
N_PREVIEW = 4
print(scans[:N_PREVIEW], "\n")
print(labels[:N_PREVIEW])

['/tmp/wads_dataset/clean_dataset/24/velodyne/040908.bin', '/tmp/wads_dataset/clean_dataset/24/velodyne/040909.bin', '/tmp/wads_dataset/clean_dataset/24/velodyne/040910.bin', '/tmp/wads_dataset/clean_dataset/24/velodyne/040911.bin'] 

['/tmp/wads_dataset/clean_dataset/24/labels/040908.label', '/tmp/wads_dataset/clean_dataset/24/labels/040909.label', '/tmp/wads_dataset/clean_dataset/24/labels/040910.label', '/tmp/wads_dataset/clean_dataset/24/labels/040911.label']


In [4]:
# Load the first scan: 4 float32 values per point (hence reshape(-1, 4)).
scans_0 = np.fromfile(scans[0], dtype=np.float32).reshape(-1, 4)
# One 32-bit word per point. Decode it explicitly as little-endian unsigned and
# split it into its low and high 16-bit halves. Column 0 = low half, column 1 =
# high half, the same column order as the old native int16 view on
# little-endian hosts, but values >= 32768 no longer wrap negative.
# NOTE(review): assumes the SemanticKITTI layout (low = semantic class,
# high = instance id) -- confirm for WADS.
raw_labels_0 = np.fromfile(labels[0], dtype="<u4")
labels_0 = np.stack([raw_labels_0 & 0xFFFF, raw_labels_0 >> 16], axis=1)
# Every point must have exactly one label.
assert scans_0.shape[0] == labels_0.shape[0], (scans_0.shape, labels_0.shape)

In [5]:
# The point count (first dimension) should match between the scan and its labels.
for loaded_array in (scans_0, labels_0):
    print(loaded_array.shape)

(97207, 4)
(97207, 2)


In [6]:
# Scan reader from SalsaNext's laserscan module. project=True requests the
# projected 2-D arrays (e.g. proj_xyz, inspected a few cells below).
laser_scan = LaserScan(project=True)

In [7]:
# Load the first scan into the reader. With project=True this presumably also
# populates the proj_* arrays read below -- see laserscan.py to confirm.
laser_scan.open_scan(scans[0])

In [8]:
# Shape of the projected coordinate grid; the last axis of size 3 presumably holds x, y, z.
laser_scan.proj_xyz.shape

(64, 1024, 3)

In [9]:
# TODO:
# Visualize the projection output in 3D with LaserScanVis to verify that the projection is correct.

# WeatherNet
### Based on LiLaNet: https://github.com/TheCodez/pytorch-LiLaNet

In [10]:
class BasicConv2d(nn.Module):
    '''
    Bias-free Conv2d, then BatchNorm2d, then an in-place ReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        **kwargs: forwarded unchanged to nn.Conv2d (e.g. kernel_size,
            padding, dilation). Passing bias here raises a TypeError because
            bias is already fixed to False.
    '''
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The conv bias is redundant: BatchNorm2d applies its own learned shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)

In [11]:
class ModifiedLiLaBlock(nn.Module):
    '''
    Modified LiLa Block as described in the WeatherNet paper
    (https://arxiv.org/abs/1912.03874).

    The block has four parallel BasicConv2d branches: a (7, 3) conv, a 3x3
    conv, a 3x3 conv with dilation 2 and a (3, 7) conv. Their outputs are
    concatenated along dim 1 and fused back to n channels by a 1x1 conv.

    Padding is chosen so every branch maps H x W to (H - 2) x (W - 2); equal
    sizes are required for torch.cat. The 1x1 fusion conv then pads by one
    pixel, so the block output has the same spatial size as its input
    (requires H >= 3 and W >= 3).

    Args:
        in_channels: number of channels of the input tensor.
        n: number of output channels of every branch and of the block.
    '''
    def __init__(self, in_channels, n):
        super(ModifiedLiLaBlock, self).__init__()

        self.branch1 = BasicConv2d(in_channels, n, kernel_size=(7, 3), padding=(2, 0))
        self.branch2 = BasicConv2d(in_channels, n, kernel_size=3)
        # A dilation-2 3x3 kernel spans 5x5. padding=1 makes it shrink the map
        # by 2, like the other branches; unpadded it produced (H - 4) x (W - 4)
        # and torch.cat failed.
        self.branch3 = BasicConv2d(in_channels, n, kernel_size=3, dilation=2, padding=1)
        self.branch4 = BasicConv2d(in_channels, n, kernel_size=(3, 7), padding=(0, 2))
        # Four branches of n channels each are concatenated, so the fusion conv
        # consumes n * 4 channels.
        self.conv = BasicConv2d(n * 4, n, kernel_size=1, padding=1)

    def forward(self, x):
        branches = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]
        output = torch.cat(branches, 1)
        return self.conv(output)

In [12]:
class WeatherNet(nn.Module):
    '''
    WeatherNet as described in the paper (https://arxiv.org/abs/1912.03874).

    The network is five stacked ModifiedLiLaBlocks, with 2-D dropout before
    the last one, followed by a 1x1 conv classifier that outputs num_classes
    score channels.

    Args:
        num_classes: number of output channels of the classifier.
    '''

    def __init__(self, num_classes=3):
        super(WeatherNet, self).__init__()

        # 2 input channels: distance and reflectivity are concatenated in forward().
        self.lila1 = ModifiedLiLaBlock(2, 32)
        self.lila2 = ModifiedLiLaBlock(32, 64)
        self.lila3 = ModifiedLiLaBlock(64, 96)
        self.lila4 = ModifiedLiLaBlock(96, 96)
        self.dropout = nn.Dropout2d(p=0.5)
        self.lila5 = ModifiedLiLaBlock(96, 64)
        self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)

        # Weights and biases initialization:
        # - every Conv2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias if present
        # - every BatchNorm2d: weight 1, bias 0
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, distance, reflectivity):
        '''
        Args:
            distance: distance image. Presumably (B, 1, H, W), since it is
                concatenated with reflectivity on dim 1 to form lila1's
                2 input channels -- TODO confirm.
            reflectivity: reflectivity image, presumably shaped like distance.

        Returns:
            Class scores with num_classes channels on dim 1.
        '''
        x = torch.cat([distance, reflectivity], 1)
        x = self.lila1(x)
        x = self.lila2(x)
        x = self.lila3(x)
        x = self.lila4(x)
        # Fixed typo: was `sel.dropout`, which raised NameError on every forward pass.
        x = self.dropout(x)
        x = self.lila5(x)
        x = self.classifier(x)

        return x