# Weather and MOR Classification for Autonomous Driving - Multi-task

In [1]:
# import dependencies
import os
import cv2
import random
from time import time
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import skimage
from skimage import io, util
from skimage.filters.rank import entropy
from skimage.morphology import disk
from skimage.color import rgb2hsv, rgb2gray, rgb2yuv
from skimage.io import imread
import torch
import torchvision
from torchvision import transforms, io, models
from torchvision.transforms import functional as TF
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam, SGD, AdamW
from torchinfo import summary
import torchmetrics
from torchmetrics import Metric
import albumentations as A
from albumentations.pytorch import ToTensorV2
import json
import seaborn as sns
from pyquaternion import Quaternion
import fnmatch
import argparse
import csv
import math
from torch.profiler import profile, record_function, ProfilerActivity
import ordinal_losses_theia as ordinal_losses # ordinal losses from https://github.com/rpmcruz/ordinal-losses

# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [2]:
# # for debugging purposes
# class Args:
#     BATCH=8
#     EPOCHS=1
#     NETWORK="MobileWeatherNet_Early" # "MobileNetV3_ViT", "ResNet_ViT", "MobileNetV3_Early", and "RangeWeatherNet_Early"
#     OPT_TECHNIQUE="Multi_Adaptive" # "Weighted"
#     TRAIN_MODEL=True
#     SEED=1998
# args=Args()

def str_to_bool(value):
    """Convert a command-line string into a boolean.

    Intended as an ``argparse`` ``type=`` callable. Matching is
    case-insensitive.

    Args:
        value: the raw string taken from the command line.

    Returns:
        True for "yes"/"true"/"1", False for "no"/"false"/"0".

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    known_values = {
        'yes': True, 'true': True, '1': True,
        'no': False, 'false': False, '0': False,
    }
    parsed = known_values.get(value.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Invalid value for boolean argument. Accepted values are "yes"/"true"/"1" or "no"/"false"/"0".')
    return parsed

# Command-line interface: six positional arguments, supplied in this order:
#   BATCH EPOCHS NETWORK OPT_TECHNIQUE TRAIN_MODEL SEED
# NOTE: parse_args() runs at module level, so this file reads sys.argv as soon
# as it is executed.
parser = argparse.ArgumentParser()
parser.add_argument('BATCH', type=int)
parser.add_argument('EPOCHS', type=int)
parser.add_argument('NETWORK', type=str)
parser.add_argument('OPT_TECHNIQUE', type=str)
# parsed with str_to_bool so that "no"/"false"/"0" really yield False
parser.add_argument('TRAIN_MODEL', type=str_to_bool, help='Specify whether to train the model (yes/true/1) or not (no/false/0)')
parser.add_argument('SEED', type=int)
args = parser.parse_args()
print(args)

## System Variables

In [3]:
# system variables
# values taken directly from the command line
BATCH_SIZE = args.BATCH
EPOCHS = args.EPOCHS
TRAIN_MODEL = args.TRAIN_MODEL
NETWORK = args.NETWORK
OPT_TECHNIQUE = args.OPT_TECHNIQUE
SEED = args.SEED
# number of sample indices that are shuffled and split into train/val/test
NUM_TOTAL_SAMPLES = 1293
# size of the image plane the lidar points are projected onto (pixels)
IMAGE_HEIGHT = 1024
IMAGE_WIDTH = 1920
# per-channel statistics handed to A.Normalize
MEAN = np.array([0.5, 0.5, 0.5])
STD = np.array([0.5, 0.5, 0.5])
N_CLASSES_WEATHER = 2
N_CLASSES_VISIBILITY = 3
# np.digitize edges: value < 40 -> class 0, 40 <= value < 200 -> class 1, value >= 200 -> class 2
# (presumably metres of meteorological visibility -- TODO confirm the unit)
VISIBILITY_BINS = np.array([40.0, 200.0])
# path for saving the models
path_best_model = "./MT_MM_W_MOR_Class_Best_Network_{}_Optimization_{}_Seed_{}.pth".format(NETWORK, OPT_TECHNIQUE, SEED)
path_last_model = "./MT_MM_W_MOR_Class_Last_Network_{}_Optimization_{}_Seed_{}.pth".format(NETWORK, OPT_TECHNIQUE, SEED)

## Transformations / Data Loader

In [4]:
# Training pipeline: centre crop to half the image height/width, random
# horizontal flip, random up-scaling (factor 1.1-1.25), normalisation and
# conversion to tensors. 'mask2' and 'mask3' are registered as extra mask
# targets so that several single-channel maps can be passed alongside 'mask'.
transform_aug = A.Compose([A.CenterCrop(int(IMAGE_HEIGHT*0.5), int(IMAGE_WIDTH*0.5), p=1.0),
                        A.HorizontalFlip(p=0.5),
                        A.Affine(scale=(1.1,1.25), keep_ratio=True, p=0.5),
                        A.Normalize(mean=MEAN, std=STD, always_apply=True),
                        ToTensorV2()],
                        additional_targets={'mask2': 'mask', 'mask3': 'mask'})

# Evaluation pipeline: same crop / normalisation / tensor conversion, without
# the random flip and scaling.
transform_base = A.Compose([A.CenterCrop(int(IMAGE_HEIGHT*0.5), int(IMAGE_WIDTH*0.5), p=1.0),
                            A.Normalize(mean=MEAN, std=STD, always_apply=True),
                            ToTensorV2()],
                            additional_targets={'mask2': 'mask', 'mask3': 'mask'})

In [5]:
class FogChamber(Dataset):
    """Dataset pairing camera images, lidar scans and weather/visibility labels.

    A sample exists for every base file name that is present in the RGB, the
    lidar and the weather/visibility folder. ``subset`` selects which of those
    (sorted) matches belong to this dataset instance.

    Each item is the tuple ``(weather class, visibility class, image entropy
    map, lidar range map, lidar intensity map)``.
    """

    def load_calib_data(self, path_total_dataset, name_camera_calib, tf_tree, velodyne_name='lidar_hdl64_s3_roof'):
        """Load the camera calibration and build the lidar -> image projection.

        Args:
            path_total_dataset: folder containing both calibration files.
            name_camera_calib: camera calibration JSON file name; must be one
                of the keys of ``calib_dict`` below.
            tf_tree: JSON file name holding the list of frame transforms.
            velodyne_name: frame id of the lidar in the tf tree.

        Returns:
            Tuple ``(velodyne_to_camera, camera_to_velodyne, P, R, vtc,
            zero_to_camera)`` where ``vtc = P @ R @ velodyne_to_camera``.

        Raises:
            AssertionError: if ``velodyne_name`` is not a supported frame id.
            KeyError: if ``name_camera_calib`` is not a known calibration file.
            ValueError: if the tf tree lacks the camera or the lidar frame.
        """
        assert velodyne_name in ['lidar_hdl64_s3_roof', 'lidar_vlp32_roof'], 'wrong frame id in tf_tree for velodyne_name'

        with open(os.path.join(path_total_dataset, name_camera_calib), 'r') as f:
            data_camera = json.load(f)
        with open(os.path.join(path_total_dataset, tf_tree), 'r') as f:
            data_extrinsics = json.load(f)

        calib_dict = {'calib_cam_stereo_left.json': 'cam_stereo_left_optical',
                    'calib_cam_stereo_right.json': 'cam_stereo_right_optical',
                    'calib_gated_bwv.json': 'bwv_cam_optical'}

        cam_name = calib_dict[name_camera_calib]

        # scan data extrinsics for the camera and lidar transforms
        # (when a frame id occurs more than once, the last entry wins)
        T_cam = None
        T_velodyne = None
        for item in data_extrinsics:
            frame_id = item['child_frame_id']
            if frame_id == cam_name:
                T_cam = item['transform']
            elif frame_id == velodyne_name:
                T_velodyne = item['transform']
        if T_cam is None or T_velodyne is None:
            raise ValueError('tf tree {} does not contain both frames {!r} and {!r}'.format(tf_tree, cam_name, velodyne_name))

        # setup the 4x4 transforms -> ROS spans transformation from its children to its parents,
        # therefore one inversion step is needed for zero_to_camera -> <parent_child>
        zero_to_camera = self._homogeneous_transform(T_cam)
        zero_to_velodyne = self._homogeneous_transform(T_velodyne)

        # calculate total extrinsic transformation to camera
        velodyne_to_camera = np.matmul(np.linalg.inv(zero_to_camera), zero_to_velodyne)
        camera_to_velodyne = np.matmul(np.linalg.inv(zero_to_velodyne), zero_to_camera)

        # read projection matrix P and camera rectification matrix R
        P = np.reshape(data_camera['P'], [3, 4])

        # rectification matrix R has to be equal to the identity as the projection matrix P contains the R matrix w.r.t KITTI definition
        R = np.identity(4)

        # calculate total transformation matrix from velodyne to camera
        vtc = np.matmul(np.matmul(P, R), velodyne_to_camera)
        return velodyne_to_camera, camera_to_velodyne, P, R, vtc, zero_to_camera

    @staticmethod
    def _homogeneous_transform(transform):
        """Build a 4x4 matrix from a tf entry with 'rotation' and 'translation'."""
        rotation = transform['rotation']
        translation = transform['translation']
        # NOTE(review): all four components are scaled by the same 360 / (2 * pi)
        # factor, exactly as before; this only matters if pyquaternion does not
        # normalise before building the rotation matrix -- confirm.
        quaternion = Quaternion(w=rotation['w'] * 360 / 2 / np.pi,
                                x=rotation['x'] * 360 / 2 / np.pi,
                                y=rotation['y'] * 360 / 2 / np.pi,
                                z=rotation['z'] * 360 / 2 / np.pi)
        matrix = np.identity(4)
        matrix[0:3, 0:3] = quaternion.rotation_matrix
        matrix[0:3, 3] = np.asarray([translation['x'], translation['y'], translation['z']])
        return matrix

    def py_func_project_3D_to_2D(self, points_3D, P):
        """Project 3D points onto the image plane.

        Args:
            points_3D: array of shape (3, N), one point per column.
            P: projection matrix with four columns (applied to homogeneous points).

        Returns:
            Array of shape (N, 2) holding the first two projected coordinates
            divided by the third.
        """
        homogeneous = np.vstack((points_3D, np.ones([1, np.shape(points_3D)[1]])))
        projected = np.matmul(P, homogeneous)
        # perspective division by the third coordinate
        return (projected[0:2] / projected[2]).transpose()

    def weather_digitize(self, weather_data):
        """Map a weather label string to its class index.

        Returns 0 for 'Fog Small Droplets' and 1 for 'Rain'.

        Raises:
            AssertionError: for any other label. Raised explicitly so that the
                check is not stripped when Python runs with ``-O``.
        """
        if weather_data == 'Fog Small Droplets':
            return 0
        if weather_data == 'Rain':
            return 1
        raise AssertionError("Invalid weather class.")

    def files_with_equal_names_in_diff_dir(self, dir1, dir2, dir3, ext1, ext2, ext3):
        """List file triples that share a base name across three folders.

        For every file in ``dir1`` matching ``*ext1`` the base name (without
        extension) is looked up with ``ext2`` in ``dir2`` and ``ext3`` in
        ``dir3``.

        Returns:
            List of ``(path_in_dir1, path_in_dir2, path_in_dir3)`` tuples.
        """
        files1 = os.listdir(dir1)
        files2 = os.listdir(dir2)
        files3 = os.listdir(dir3)
        matches = []

        for file1 in files1:
            if not fnmatch.fnmatch(file1, '*' + ext1):
                continue
            # NOTE: the base name is used as an fnmatch pattern, so glob
            # metacharacters in file names are interpreted as wildcards.
            name1 = os.path.splitext(file1)[0]
            hits2 = [file2 for file2 in files2 if fnmatch.fnmatch(file2, name1 + ext2)]
            if not hits2:
                continue
            # the dir3 scan does not depend on file2, so run it once per name
            hits3 = [file3 for file3 in files3 if fnmatch.fnmatch(file3, name1 + ext3)]
            for file2 in hits2:
                for file3 in hits3:
                    matches.append((os.path.join(dir1, file1), os.path.join(dir2, file2), os.path.join(dir3, file3)))
        return matches

    def __init__(self, root, subset, rgb_fold, lidar_fold, weather_visibility_fold, dict_transform=None):
        """Index the dataset folders and load the calibration.

        Args:
            root: dataset root containing the calibration files and sub-folders.
            subset: sequence of indices into the sorted list of matched files.
            rgb_fold: RGB sub-folder; must be 'cam_stereo_left_lut'.
            lidar_fold: lidar sub-folder; 'lidar_hdl64_strongest' or 'lidar_hdl64_last'.
            weather_visibility_fold: label sub-folder; must be 'cerema'.
            dict_transform: optional callable invoked as
                ``dict_transform(image=..., mask=..., mask2=..., mask3=...)``.

        Raises:
            AssertionError: if a folder name is not one of the supported values.
        """
        # validate before touching the disk; the one-element tuples need their
        # trailing comma, otherwise ``in`` performs a substring test on a str
        assert rgb_fold in ('cam_stereo_left_lut',)
        assert lidar_fold in ('lidar_hdl64_strongest', 'lidar_hdl64_last')
        assert weather_visibility_fold in ('cerema',)

        _, _, _, _, self.vtc,_ = self.load_calib_data(root, 'calib_cam_stereo_left.json', 'calib_tf_tree_full.json', 'lidar_hdl64_s3_roof')
        self.subset = subset

        self.rgb_path = os.path.join(root, rgb_fold)
        self.rgb_extension = '.png'

        self.lidar_path = os.path.join(root, lidar_fold)
        self.lidar_extension = '.bin'

        self.weather_visibility_path = os.path.join(root, weather_visibility_fold)
        self.weather_visibility_extension = '.json'

        self.file_matches = sorted(self.files_with_equal_names_in_diff_dir(self.rgb_path, self.lidar_path, self.weather_visibility_path, self.rgb_extension, self.lidar_extension, self.weather_visibility_extension))
        self.rgb_files = [i[0] for i in self.file_matches]
        self.lidar_files = [i[1] for i in self.file_matches]
        self.weather_visibility_files = [i[2] for i in self.file_matches]

        self.dict_transform = dict_transform

    def __len__(self):
        """Return the number of samples selected by ``subset``."""
        return len(self.subset)

    def _read_labels(self, j):
        """Parse the weather/visibility JSON of matched sample ``j`` (file is closed afterwards)."""
        with open(self.weather_visibility_files[j]) as f:
            return json.load(f)

    def _lidar_images(self, j):
        """Project lidar scan ``j`` into normalised range and intensity images.

        Returns:
            Two arrays of shape (IMAGE_HEIGHT, IMAGE_WIDTH); pixels without a
            lidar return stay 0.
        """
        # each point occupies five float32 values; columns 0-2 are used as
        # coordinates and column 3 as intensity
        lidar_data_raw = np.fromfile(self.lidar_files[j], dtype=np.float32).reshape((-1, 5))
        # filter away all points behind image plane and below distance threshold
        r = np.sqrt(lidar_data_raw[:, 0] ** 2 + lidar_data_raw[:, 1] ** 2 + lidar_data_raw[:, 2] ** 2)
        lidar_data_raw = lidar_data_raw[r > 1.5]
        lidar_data_raw = lidar_data_raw[lidar_data_raw[:, 0] > 2.5]
        # range calculation
        lidar_range = np.sqrt(lidar_data_raw[:, 0] ** 2 + lidar_data_raw[:, 1] ** 2 + lidar_data_raw[:, 2] ** 2)
        # 3D to 2D valid coordinates
        points_2D = self.py_func_project_3D_to_2D(lidar_data_raw[:, 0:3].transpose(), self.vtc)
        inside_width = np.logical_and(IMAGE_WIDTH > points_2D[:, 0], points_2D[:, 0] >= 0)
        inside_height = np.logical_and(IMAGE_HEIGHT > points_2D[:, 1], points_2D[:, 1] >= 0)
        valid_points = np.logical_and(inside_width, inside_height)
        img_coordinates = points_2D[valid_points, :].astype(dtype=np.int32)
        # images are filled as (x, y) and transposed to (row, column) afterwards
        # lidar range image (normalised)
        image_lidar_range = np.zeros((IMAGE_WIDTH, IMAGE_HEIGHT))
        image_lidar_range[img_coordinates[:, 0], img_coordinates[:, 1]] = lidar_range[valid_points] / 109.2780 # 109.2780 is the max range
        # lidar intensity image (normalised)
        image_lidar_intensity = np.zeros((IMAGE_WIDTH, IMAGE_HEIGHT))
        image_lidar_intensity[img_coordinates[:, 0], img_coordinates[:, 1]] = lidar_data_raw[:, 3][valid_points] / 255 # 255 is the max intensity
        return image_lidar_range.transpose(), image_lidar_intensity.transpose()

    def __getitem__(self, i, only_weather=False, only_visibility=False):
        """Return sample ``i`` of the subset.

        Args:
            i: position inside ``subset``.
            only_weather: if True, return only the weather class index.
            only_visibility: if True (and ``only_weather`` is False), return
                only the visibility class index.

        Returns:
            ``(weather_in_bins, visibility_in_bins, image_entropy,
            image_lidar_range, image_lidar_intensity)``; the three maps are the
            outputs of ``dict_transform`` when one was given.
        """
        j = self.subset[i]

        if only_weather:
            # get weather GT
            return self.weather_digitize(self._read_labels(j)['weather'])

        if only_visibility:
            # get visibility GT
            return np.digitize(self._read_labels(j)['metereological_visibility'], VISIBILITY_BINS)

        # get LUT RGB image
        rgb_image = imread(self.rgb_files[j])

        # image entropy (normalised)
        gray = rgb2gray(rgb_image)
        image_entropy = entropy(util.img_as_ubyte(gray), disk(5)) / 6.1917 # 6.1917 is the max image entropy for disk 5

        # lidar range / intensity images
        image_lidar_range, image_lidar_intensity = self._lidar_images(j)

        # get weather and visibility GT (the label file is parsed only once)
        labels = self._read_labels(j)
        weather_in_bins = self.weather_digitize(labels['weather'])
        visibility_in_bins = np.digitize(labels['metereological_visibility'], VISIBILITY_BINS)

        if self.dict_transform:
            # the RGB image is passed as the required 'image' target; only the
            # three maps are kept afterwards
            transformed = self.dict_transform(image=rgb_image, mask=image_lidar_range, mask2=image_lidar_intensity, mask3=image_entropy)
            image_lidar_range = transformed['mask']
            image_lidar_intensity = transformed['mask2']
            image_entropy = transformed['mask3']

        return weather_in_bins, visibility_in_bins, image_entropy, image_lidar_range, image_lidar_intensity

In [6]:
# to ensure reproducibility
# seed every random number generator used by the script with the same SEED
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
# ask cuDNN for deterministic kernels and disable its auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def worker_init_fn(worker_id):
    """Seed NumPy inside a DataLoader worker with ``SEED + worker_id``.

    NOTE(review): the seed does not depend on the epoch, so each worker is
    re-seeded with the same value whenever workers are (re)started; Python's
    ``random`` module is not touched here -- confirm this is intended.
    """
    np.random.seed(SEED + worker_id)

# 60 / 20 / 20 split over the shuffled sample indices
num_tr_samples = int(0.6 * NUM_TOTAL_SAMPLES)
num_vl_samples = int(0.2 * NUM_TOTAL_SAMPLES)
# NOTE: num_ts_samples is not used below; the test split takes all remaining indices
num_ts_samples = int(0.2 * NUM_TOTAL_SAMPLES)
data = list(range(0, NUM_TOTAL_SAMPLES))
random.shuffle(data) # does NOT prevent leakage between sequences, but NO better solution has been found

# only the training set uses the augmenting transform
tr_dataset = FogChamber(r'/data/auto/DENSE/FogchamberDataset/', data[:num_tr_samples], 'cam_stereo_left_lut', 'lidar_hdl64_strongest', 'cerema', transform_aug)
vl_dataset = FogChamber(r'/data/auto/DENSE/FogchamberDataset/', data[num_tr_samples:num_tr_samples + num_vl_samples], 'cam_stereo_left_lut', 'lidar_hdl64_strongest', 'cerema', transform_base)
ts_dataset = FogChamber(r'/data/auto/DENSE/FogchamberDataset/', data[num_tr_samples + num_vl_samples:], 'cam_stereo_left_lut', 'lidar_hdl64_strongest', 'cerema', transform_base)

# manage class imbalance
# read only the weather label of every training sample (no image / lidar loading)
weather_samples = [tr_dataset.__getitem__(i, only_weather=True) for i in range(len(tr_dataset))]
# counts are ordered by class index; assumes every weather class occurs in the training split
_, weather_class_sample_count = np.unique(weather_samples, return_counts=True)
print(weather_class_sample_count)
# weight each sample by the inverse frequency of its weather class
weather_class_weights = 1 / torch.tensor(weather_class_sample_count)
weight_per_sample = weather_class_weights[weather_samples]
weather_sampler = torch.utils.data.sampler.WeightedRandomSampler(weight_per_sample, len(weight_per_sample))

# training batches are drawn through the weighted sampler; validation / test keep their order
tr = DataLoader(tr_dataset, BATCH_SIZE, sampler=weather_sampler, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn)
vl = DataLoader(vl_dataset, BATCH_SIZE, num_workers=16, shuffle=False, pin_memory=True, worker_init_fn=worker_init_fn)
ts = DataLoader(ts_dataset, BATCH_SIZE, num_workers=16, shuffle=False, pin_memory=True, worker_init_fn=worker_init_fn)

[704  71]


## Network and its parameters

In [7]:
# RangeWeatherNet Custom - Multi-Task
def conv_batch_range(in_num, out_num, kernel_size=3, padding=1, stride=1, dropout=0.05):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.4) -> Dropout stack.

    Args:
        in_num: number of input channels.
        out_num: number of output channels.
        kernel_size, padding, stride: forwarded to the (bias-free) convolution.
        dropout: drop probability of the final Dropout layer.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    convolution = nn.Conv2d(in_num, out_num, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    normalisation = nn.BatchNorm2d(out_num)
    activation = nn.LeakyReLU(0.4)
    regulariser = nn.Dropout(p=dropout)
    return nn.Sequential(convolution, normalisation, activation, regulariser)

# residual block
class DarkResidualBlock(nn.Module):
    """Residual block: 1x1 channel-halving conv, 3x3 restoring conv, skip connection."""

    def __init__(self, in_channels):
        super().__init__()
        bottleneck_channels = int(in_channels/2)
        # squeeze with a 1x1 convolution, then expand back with the default 3x3 one
        self.layer1 = conv_batch_range(in_channels, bottleneck_channels, kernel_size=1, padding=0)
        self.layer2 = conv_batch_range(bottleneck_channels, in_channels)

    def forward(self, x):
        """Return ``layer2(layer1(x)) + x`` (the sum is accumulated in place)."""
        out = self.layer2(self.layer1(x))
        out += x
        return out

class Range_Weather_Net_Custom(nn.Module):
    """Convolutional backbone with residual stages and two classification heads.

    The backbone output is globally average-pooled and fed to a weather head
    and a visibility head, both made of 1x1 convolutions.
    """

    def make_layer(self, block, in_channels, num_blocks):
        """Stack ``num_blocks`` instances of ``block(in_channels)`` in a Sequential."""
        return nn.Sequential(*[block(in_channels) for _ in range(num_blocks)])

    def __init__(self, block, input_channels, num_classes_weather, num_classes_visibility, perception_output_features=512):
        super().__init__()

        # Stages are instantiated in the order they appear, which is also the
        # order in which they run; the spatial size is halved four times.
        self.backbone = nn.Sequential(
            conv_batch_range(input_channels, 16),
            conv_batch_range(16, 32, stride=2),
            self.make_layer(block, in_channels=32, num_blocks=1),
            conv_batch_range(32, 64, stride=2),
            self.make_layer(block, in_channels=64, num_blocks=2),
            conv_batch_range(64, 128, stride=2),
            self.make_layer(block, in_channels=128, num_blocks=8),
            conv_batch_range(128, 256, stride=2),
            self.make_layer(block, in_channels=256, num_blocks=8),
            nn.Conv2d(256, perception_output_features, kernel_size=(1, 1), stride=(1, 1)),
        )

        # head for weather
        weather_layers = [
            nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(128, num_classes_weather, kernel_size=(1, 1), stride=(1, 1)),
        ]
        self.head_weather = nn.Sequential(*weather_layers)

        # head for visibility
        visibility_layers = [
            nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv2d(128, num_classes_visibility, kernel_size=(1, 1), stride=(1, 1)),
        ]
        self.head_visibility = nn.Sequential(*visibility_layers)

    def forward(self, x):
        """Return ``(weather_logits, visibility_logits)``, each of shape (batch, n_classes)."""
        # global average pooling collapses the spatial dimensions to 1x1
        pooled = nn.AdaptiveAvgPool2d(1)(self.backbone(x))
        weather_map = self.head_weather(pooled)
        visibility_map = self.head_visibility(pooled)
        # drop the two singleton spatial dimensions
        return weather_map.flatten(1), visibility_map.flatten(1)

In [8]:
# MobileWeatherNet Custom - Multi-Task
class Mobile_Weather_Net_Custom(nn.Module):
    """Small three-stage convolutional backbone with weather and visibility heads.

    Every stage starts with a stride-2 convolution block followed by stride-1
    blocks at the same width (32, 64 and 128 channels).
    """

    def __init__(self, input_channels, num_classes_weather, num_classes_visibility, perception_output_features=512):
        super().__init__()

        def stage(in_num, out_num, depth):
            # one down-sampling block followed by (depth - 1) same-width blocks
            blocks = [self.conv_batch(in_num=in_num, out_num=out_num, stride=2)]
            for _ in range(depth - 1):
                blocks.append(self.conv_batch(in_num=out_num, out_num=out_num))
            return nn.Sequential(*blocks)

        self.backbone = nn.Sequential(
            stage(input_channels, 32, 2),
            stage(32, 64, 2),
            stage(64, 128, 3),
            nn.Conv2d(128, perception_output_features, kernel_size=(1, 1), stride=(1, 1)),
        )

        # head for weather
        weather_layers = [
            nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(128, num_classes_weather, kernel_size=(1, 1), stride=(1, 1)),
        ]
        self.head_weather = nn.Sequential(*weather_layers)

        # head for visibility
        visibility_layers = [
            nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv2d(128, num_classes_visibility, kernel_size=(1, 1), stride=(1, 1)),
        ]
        self.head_visibility = nn.Sequential(*visibility_layers)

    def forward(self, x):
        """Return ``(weather_logits, visibility_logits)``, each of shape (batch, n_classes)."""
        # global average pooling collapses the spatial dimensions to 1x1
        pooled = nn.AdaptiveAvgPool2d(1)(self.backbone(x))
        weather_map = self.head_weather(pooled)
        visibility_map = self.head_visibility(pooled)
        # drop the two singleton spatial dimensions
        return weather_map.flatten(1), visibility_map.flatten(1)

    def conv_batch(self, in_num, out_num, kernel_size=3, stride=1, padding=1, dropout=0.05):
        """Build a bias-free Conv2d -> BatchNorm2d -> PReLU -> Dropout stack."""
        layers = (
            nn.Conv2d(in_num, out_num, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_num),
            nn.PReLU(),
            nn.Dropout(p=dropout),
        )
        return nn.Sequential(*layers)

In [9]:
# MobileNetV3 Custom - Multi-Task
class Mobile_Net_v3_Custom(nn.Module):
    """MobileNetV3-small backbone (no pretrained weights) with weather and visibility heads."""

    def __init__(self, input_channels, num_classes_weather, num_classes_visibility, perception_output_features=512):
        super().__init__()
        self.backbone = torchvision.models.mobilenet_v3_small(weights=None)

        # adapt the stem to the requested number of input channels
        stem_conv = nn.Conv2d(input_channels, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.backbone.features[0][0] = stem_conv
        # project the last feature map to perception_output_features before pooling
        last_channels = self.backbone.features[12][0].out_channels
        self.backbone.avgpool = nn.Sequential(
            nn.Conv2d(last_channels, perception_output_features, (1, 1)),
            nn.AdaptiveAvgPool2d(output_size=1),
        )
        # the stock classifier is bypassed; the heads below take over
        self.backbone.classifier = nn.Identity()

        weather_layers = [
            nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(128, num_classes_weather, kernel_size=(1, 1), stride=(1, 1)),
        ]
        self.head_weather = nn.Sequential(*weather_layers)

        visibility_layers = [
            nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Conv2d(128, num_classes_visibility, kernel_size=(1, 1), stride=(1, 1)),
        ]
        self.head_visibility = nn.Sequential(*visibility_layers)

    def forward(self, x):
        """Return ``(weather_logits, visibility_logits)``, each of shape (batch, n_classes)."""
        features = self.backbone(x)
        batch = features.size(0)
        # reshape the backbone output to (batch_size, n_features, 1, 1) for the 1x1-conv heads
        features = features.view(batch, -1, 1, 1)
        weather_logits = self.head_weather(features).view(batch, -1)
        visibility_logits = self.head_visibility(features).view(batch, -1)
        return weather_logits, visibility_logits

In [10]:
# multi-head masked self-attention layer with a projection at the end
# code based on the article "Multi-Modal Fusion Transformer for End-to-End Autonomous Driving", Aditya Prakash, Kashyap Chitta, Andreas Geiger, 2021
class MSA(nn.Module):
    """Multi-head self-attention with an output projection.

    No attention mask is applied: every token attends to every other token.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        # the embedding is split evenly across the heads
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, T, C); the output has the same shape."""
        batch, tokens, channels = x.size()
        head_dim = channels // self.n_head

        def split_heads(projected):
            # (B, T, C) -> (B, nh, T, hs): the head axis becomes a batch axis
            return projected.view(batch, tokens, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        weights = self.attn_drop(F.softmax(scores, dim=-1))
        per_head = weights @ v # (B, nh, T, hs)
        # put the heads back side by side: (B, T, C)
        merged = per_head.transpose(1, 2).contiguous().view(batch, tokens, channels)

        return self.resid_drop(self.proj(merged))

class TransformerBlock(nn.Module): # transformer block
    """Pre-norm transformer block: self-attention and an MLP, each inside a residual connection."""

    def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop):
        super().__init__()
        hidden_width = block_exp * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = MSA(n_embd, n_head, attn_pdrop, resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(True), # changed from GELU
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        """Return ``h + mlp(ln2(h))`` with ``h = x + attn(ln1(x))``."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))
    
class GPT(nn.Module):
    """Transformer that fuses a camera feature map with a lidar feature map.

    Both maps are flattened into channel-last tokens, concatenated, given a
    learnable positional embedding, passed through ``n_layer`` transformer
    blocks, and finally turned back into two feature maps with the shapes of
    the inputs.

    NOTE: the output maps are now permuted back to channel-first layout (see
    ``_tokens_to_map``). Checkpoints trained with the previous reshape saw
    differently arranged fusion outputs and should be re-trained or
    re-validated.
    """

    def __init__(self, n_embd, n_head, block_exp, n_layer, camera_vert_anchors, camera_horz_anchors, lidar_vert_anchors, lidar_horz_anchors, seq_len, embd_pdrop, attn_pdrop, resid_pdrop):
        super().__init__()
        self.n_embd = n_embd
        self.seq_len = seq_len # only support seq len 1
        self.camera_vert_anchors = camera_vert_anchors
        self.camera_horz_anchors = camera_horz_anchors
        self.lidar_vert_anchors = lidar_vert_anchors
        self.lidar_horz_anchors = lidar_horz_anchors
        # positional embedding parameter (learnable), camera + lidar
        n_tokens = self.seq_len * camera_vert_anchors * camera_horz_anchors + self.seq_len * lidar_vert_anchors * lidar_horz_anchors
        self.pos_emb = nn.Parameter(torch.zeros(1, n_tokens, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[TransformerBlock(n_embd, n_head, block_exp, attn_pdrop, resid_pdrop) for _ in range(n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(n_embd)
        self.block_size = self.seq_len

    def _tokens_to_map(self, tokens, height, width):
        """Turn channel-last tokens (B, seq_len * H * W, C) back into a (B * seq_len, C, H, W) map."""
        bz = tokens.shape[0]
        feature_map = tokens.contiguous().view(bz * self.seq_len, height, width, self.n_embd)
        # inverse of the channel-last permutation applied when the tokens were built
        return feature_map.permute(0, 3, 1, 2).contiguous()

    def forward(self, camera_tensor, lidar_tensor):
        """Fuse the two feature maps.

        Args:
            camera_tensor: (B, n_embd, camera_vert_anchors, camera_horz_anchors).
            lidar_tensor: (B, n_embd, lidar_vert_anchors, lidar_horz_anchors).

        Returns:
            ``(camera_tensor_out, lidar_tensor_out)`` with the same shapes and
            the same channel-first layout as the inputs.
        """
        bz = lidar_tensor.shape[0]
        lidar_h, lidar_w = lidar_tensor.shape[2:4]
        camera_h, camera_w = camera_tensor.shape[2:4]

        assert self.seq_len == 1
        # channel-last tokens: (B, seq, C, H, W) -> (B, seq, H, W, C) -> (B, seq * H * W, C)
        camera_tokens = camera_tensor.view(bz, self.seq_len, -1, camera_h, camera_w).permute(0,1,3,4,2).contiguous().view(bz, -1, self.n_embd)
        lidar_tokens = lidar_tensor.view(bz, self.seq_len, -1, lidar_h, lidar_w).permute(0,1,3,4,2).contiguous().view(bz, -1, self.n_embd)
        token_embeddings = torch.cat((camera_tokens, lidar_tokens), dim=1)

        x = self.drop(self.pos_emb + token_embeddings)
        x = self.blocks(x) # (B, an * T, C)
        x = self.ln_f(x) # (B, an * T, C)

        # the camera tokens come first, the lidar tokens follow
        n_camera_tokens = self.seq_len * self.camera_vert_anchors * self.camera_horz_anchors
        camera_tensor_out = self._tokens_to_map(x[:, :n_camera_tokens, :], camera_h, camera_w)
        lidar_tensor_out = self._tokens_to_map(x[:, n_camera_tokens:, :], lidar_h, lidar_w)

        return camera_tensor_out, lidar_tensor_out
    
class MobileNetV3_Encoder(nn.Module):
    """Pair of MobileNetV3-small backbones (camera and lidar) without classifiers.

    ``camera_n_features`` / ``lidar_n_features`` store the input width of the
    classifier that was removed from each backbone.
    """

    def __init__(self, camera_in_shape, lidar_in_shape):
        super().__init__()
        self.backbone_camera, self.camera_n_features = self._headless_backbone(camera_in_shape)
        self.backbone_lidar, self.lidar_n_features = self._headless_backbone(lidar_in_shape)

    @staticmethod
    def _headless_backbone(in_channels):
        """Create an untrained MobileNetV3-small with a custom stem and an Identity classifier."""
        net = torchvision.models.mobilenet_v3_small(weights=None)
        # adapt the first convolution to the modality's channel count
        net.features[0][0] = nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1, bias=False)
        n_features = net.classifier[0].in_features
        net.classifier = nn.Identity()
        return net, n_features
    
class ResNet_Encoder(nn.Module):
    """ResNet-34 camera backbone and ResNet-18 lidar backbone without their fc layers.

    ``camera_n_features`` / ``lidar_n_features`` store the input width of the
    fully connected layer that was removed from each backbone.
    """

    def __init__(self, camera_in_shape, lidar_in_shape):
        super().__init__()
        self.backbone_camera, self.camera_n_features = self._headless_resnet(torchvision.models.resnet34, camera_in_shape)
        self.backbone_lidar, self.lidar_n_features = self._headless_resnet(torchvision.models.resnet18, lidar_in_shape)

    @staticmethod
    def _headless_resnet(factory, in_channels):
        """Create an untrained ResNet via ``factory`` with a custom stem and an Identity fc."""
        net = factory(weights=None)
        # adapt the first convolution to the modality's channel count
        net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        n_features = net.fc.in_features
        net.fc = nn.Identity()
        return net, n_features

class ViT_FusionNet(nn.Module):
    """Two-stream camera/lidar classifier with transformer fusion between backbone stages.

    Camera and lidar images run through separate CNN backbones. After selected
    backbone stages, both feature maps are average-pooled to a fixed anchor
    grid, exchanged through a ``GPT`` module, bilinearly resized back to the
    stage resolution and added to the stage output as a residual. The final
    feature maps are projected to ``perception_output_features`` channels,
    globally pooled, summed across modalities, and fed to two 1x1-convolution
    heads (weather and visibility).

    Args:
        num_classes_weather: number of outputs of the weather head.
        num_classes_visibility: number of outputs of the visibility head.
        architecture: ``"ResNet_ViT"`` (fusion after ``layer1``..``layer4``) or
            ``"MobileNetV3_ViT"`` (fusion after ``features[6]`` and ``features[12]``).
        camera_in_channels: channels of the camera input.
        lidar_in_channels: channels of the lidar input.
        n_head, block_exp, n_layer, seq_len, embd_pdrop, attn_pdrop, resid_pdrop:
            forwarded unchanged to every ``GPT`` fusion module.
        perception_output_features: channel count fed to the classification heads.

    Raises:
        ValueError: if ``architecture`` is not one of the supported names.
    """

    _ARCHITECTURES = ("ResNet_ViT", "MobileNetV3_ViT")

    def __init__(self, num_classes_weather, num_classes_visibility, architecture="ResNet_ViT", camera_in_channels=1, lidar_in_channels=2, n_head=4, block_exp=4, n_layer=8, seq_len=1, embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1, perception_output_features=512):
        super().__init__()

        # Validate up front with a real exception: an `assert` would be stripped
        # under `python -O` and the model would be built without encoders.
        if architecture not in self._ARCHITECTURES:
            raise ValueError("Invalid architecture.")
        self.architecture = architecture

        # size of the token grid each feature map is pooled to before fusion
        camera_vert_anchors = 16
        camera_horz_anchors = 30
        lidar_vert_anchors = 16
        lidar_horz_anchors = 30

        self.avgpool_camera = nn.AdaptiveAvgPool2d((camera_vert_anchors, camera_horz_anchors))
        self.avgpool_lidar = nn.AdaptiveAvgPool2d((lidar_vert_anchors, lidar_horz_anchors))

        # every fusion transformer shares these settings; only n_embd differs per stage
        gpt_kwargs = dict(n_head=n_head,
                          block_exp=block_exp,
                          n_layer=n_layer,
                          camera_vert_anchors=camera_vert_anchors,
                          camera_horz_anchors=camera_horz_anchors,
                          lidar_vert_anchors=lidar_vert_anchors,
                          lidar_horz_anchors=lidar_horz_anchors,
                          seq_len=seq_len,
                          embd_pdrop=embd_pdrop,
                          attn_pdrop=attn_pdrop,
                          resid_pdrop=resid_pdrop)

        if self.architecture == "ResNet_ViT":
            self.encoders = ResNet_Encoder(camera_in_channels, lidar_in_channels)
            camera_net = self.encoders.backbone_camera

            # n_embd is the channel width produced by the corresponding residual stage
            self.transformer1 = GPT(n_embd=camera_net.layer1[1].conv2.out_channels, **gpt_kwargs)
            self.transformer2 = GPT(n_embd=camera_net.layer2[1].conv2.out_channels, **gpt_kwargs)
            self.transformer3 = GPT(n_embd=camera_net.layer3[1].conv2.out_channels, **gpt_kwargs)
            self.transformer4 = GPT(n_embd=camera_net.layer4[1].conv2.out_channels, **gpt_kwargs)

            backbone_out_channels = camera_net.layer4[1].conv2.out_channels

        else:  # "MobileNetV3_ViT"
            self.encoders = MobileNetV3_Encoder(camera_in_channels, lidar_in_channels)
            camera_net = self.encoders.backbone_camera

            self.transformer1 = GPT(n_embd=camera_net.features[6].block[3][0].out_channels, **gpt_kwargs)
            self.transformer2 = GPT(n_embd=camera_net.features[12][0].out_channels, **gpt_kwargs)

            backbone_out_channels = camera_net.features[12][0].out_channels

        # 1x1 projections to the head width; an empty Sequential acts as a no-op
        # when the backbone already outputs that width. Both projections are sized
        # from the camera backbone, as the lidar stream is fused with it at the
        # same channel width.
        if backbone_out_channels != perception_output_features:
            self.change_channel_conv_camera = nn.Conv2d(backbone_out_channels, perception_output_features, (1, 1))
            self.change_channel_conv_lidar = nn.Conv2d(backbone_out_channels, perception_output_features, (1, 1))
        else:
            self.change_channel_conv_camera = nn.Sequential()
            self.change_channel_conv_lidar = nn.Sequential()

        # classification heads
        self.head_weather = nn.Sequential(nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
                                        nn.ReLU(),
                                        nn.Dropout(p=0.2),
                                        nn.Conv2d(128, num_classes_weather, kernel_size=(1, 1), stride=(1, 1)))

        self.head_visibility = nn.Sequential(nn.Conv2d(perception_output_features, 128, kernel_size=(1, 1), stride=(1, 1)),
                                            nn.ReLU(),
                                            nn.Dropout(p=0.25),
                                            nn.Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1)),
                                            nn.ReLU(),
                                            nn.Dropout(p=0.25),
                                            nn.Conv2d(128, num_classes_visibility, kernel_size=(1, 1), stride=(1, 1)))

    def _fuse(self, camera_features, lidar_features, transformer):
        """Exchange information between the two streams and add it back as a residual."""
        camera_tokens = self.avgpool_camera(camera_features)
        lidar_tokens = self.avgpool_lidar(lidar_features)
        camera_fused, lidar_fused = transformer(camera_tokens, lidar_tokens)
        # bring the fused maps back to the resolution of the stage they came from
        camera_fused = F.interpolate(camera_fused, size=(camera_features.shape[2], camera_features.shape[3]), mode='bilinear', align_corners=False)
        lidar_fused = F.interpolate(lidar_fused, size=(lidar_features.shape[2], lidar_features.shape[3]), mode='bilinear', align_corners=False)
        return camera_features + camera_fused, lidar_features + lidar_fused

    @staticmethod
    def _run_blocks(features, x, start, stop):
        """Apply ``features[start]`` .. ``features[stop - 1]`` to ``x`` in order."""
        for index in range(start, stop):
            x = features[index](x)
        return x

    def forward(self, in_camera, in_lidar):
        """Classify one batch.

        Args:
            in_camera: camera batch, ``(N, camera_in_channels, H, W)``.
            in_lidar: lidar batch, ``(N, lidar_in_channels, H, W)``.

        Returns:
            ``(out_weather, out_visibility)`` with shapes
            ``(N, num_classes_weather)`` and ``(N, num_classes_visibility)``.
        """
        camera_net = self.encoders.backbone_camera
        lidar_net = self.encoders.backbone_lidar

        if self.architecture == "ResNet_ViT":
            # ResNet stem: conv -> bn -> relu -> maxpool
            camera_features = camera_net.maxpool(camera_net.relu(camera_net.bn1(camera_net.conv1(in_camera))))
            lidar_features = lidar_net.maxpool(lidar_net.relu(lidar_net.bn1(lidar_net.conv1(in_lidar))))

            # fuse after every residual stage
            stages = ((camera_net.layer1, lidar_net.layer1, self.transformer1),
                      (camera_net.layer2, lidar_net.layer2, self.transformer2),
                      (camera_net.layer3, lidar_net.layer3, self.transformer3),
                      (camera_net.layer4, lidar_net.layer4, self.transformer4))
            for camera_stage, lidar_stage, transformer in stages:
                camera_features = camera_stage(camera_features)
                lidar_features = lidar_stage(lidar_features)
                camera_features, lidar_features = self._fuse(camera_features, lidar_features, transformer)

        else:
            camera_features, lidar_features = in_camera, in_lidar
            # fuse after features[6] and after features[12]
            for start, stop, transformer in ((0, 7, self.transformer1), (7, 13, self.transformer2)):
                camera_features = self._run_blocks(camera_net.features, camera_features, start, stop)
                lidar_features = self._run_blocks(lidar_net.features, lidar_features, start, stop)
                camera_features, lidar_features = self._fuse(camera_features, lidar_features, transformer)

        # project both streams to the head width (no-op if they already match)
        camera_features = self.change_channel_conv_camera(camera_features)
        lidar_features = self.change_channel_conv_lidar(lidar_features)

        # global pooling with each backbone's own avgpool, then element-wise sum of the modalities
        camera_features = torch.flatten(camera_net.avgpool(camera_features), 1)
        lidar_features = torch.flatten(lidar_net.avgpool(lidar_features), 1)
        fused_features = camera_features + lidar_features

        # reshape the output to (batch_size, n_features, 1, 1)
        fused_features = fused_features.view(fused_features.size(0), -1, 1, 1)
        # forward pass through the heads
        out_weather = self.head_weather(fused_features).view(fused_features.size(0), -1)
        out_visibility = self.head_visibility(fused_features).view(fused_features.size(0), -1)

        return out_weather, out_visibility

In [11]:
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from class *probabilities*.

    Per sample the loss is ``-alpha[t] * (1 - p_t)^gamma * log(p_t + eps)``,
    where ``p_t`` is the predicted probability of the true class ``t``. Useful
    when there is a large class imbalance.

    Args:
        alpha (Tensor, optional): per-class weights, passed to ``nn.NLLLoss`` as
            ``weight``. Defaults to None (no weighting).
        gamma (float, optional): focusing exponent. Defaults to 0.
        reduction (str, optional): 'mean', 'sum', or 'none'. 'mean' is a plain
            average over the non-ignored samples.
        ignore_index (int, optional): class label to ignore. Defaults to -100.

    Raises:
        ValueError: if ``reduction`` is not one of the accepted values.
    """

    def __init__(self, alpha=None, gamma=0.0, reduction='mean', ignore_index=-100):
        super().__init__()

        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.eps = 0.001 # avoid grad explode

        self.nll_loss = nn.NLLLoss(weight=self.alpha, reduction='none', ignore_index=self.ignore_index)

    def forward(self, pred_prob, target):
        """Return the focal loss of ``pred_prob`` (probabilities, class axis 1) against ``target``.

        With ``reduction='none'`` the result has one entry per non-ignored
        sample. If every target is ignored, a scalar zero is returned.
        """
        if pred_prob.ndim > 2:
            # (N, C, d1, d2, ..., dk) -> (N * d1 * ... * dk, C)
            c = pred_prob.shape[1]
            pred_prob = pred_prob.permute(0, *range(2, pred_prob.ndim), 1).reshape(-1, c)
            # reshape (not view) so non-contiguous targets are accepted too
            target = target.reshape(-1)

        unignored_mask = target != self.ignore_index
        target = target[unignored_mask]
        if target.numel() == 0:
            # Sum of an empty slice: a zero on the input's device/dtype that is
            # still attached to the graph, so `.backward()` on it does not fail.
            return pred_prob[:0].sum()
        pred_prob = pred_prob[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = torch.log(pred_prob + self.eps)
        ce = self.nll_loss(log_p, target)

        # get true class column from each row (gather keeps everything on one device)
        log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        # pt here is p + eps and can exceed 1; clamp the base at 0 so a
        # fractional gamma cannot produce NaN
        pt = log_pt.exp()
        focal_term = (1 - pt).clamp(min=0) ** self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

In [12]:
# loss functions
# weather: focal loss on class probabilities, left unreduced (callers average it themselves)
loss_function_weather = FocalLoss(alpha=None, gamma=2.0, reduction='none', ignore_index=-100).to(device)
# visibility: ordinal-encoding loss; it also dictates how many outputs the visibility head needs
loss_function_visibility = ordinal_losses.OrdinalEncoding(K=N_CLASSES_VISIBILITY).to(device)
n_outputs_visibility = loss_function_visibility.how_many_outputs()

# network
# `shared_modules` lists, in order, every sub-module that feeds BOTH heads; it is
# only used to split the parameters between the two task optimizers below.
if NETWORK == "MobileNetV3_ViT":
    net = ViT_FusionNet(num_classes_weather=N_CLASSES_WEATHER, num_classes_visibility=n_outputs_visibility, architecture="MobileNetV3_ViT", camera_in_channels=1, lidar_in_channels=2, n_head=2, block_exp=2, n_layer=4).to(device)
    shared_modules = [net.encoders,
                      net.transformer1,
                      net.transformer2,
                      net.change_channel_conv_camera,
                      net.change_channel_conv_lidar]

elif NETWORK == "ResNet_ViT":
    net = ViT_FusionNet(num_classes_weather=N_CLASSES_WEATHER, num_classes_visibility=n_outputs_visibility, architecture="ResNet_ViT", camera_in_channels=1, lidar_in_channels=2).to(device)
    shared_modules = [net.encoders,
                      net.transformer1,
                      net.transformer2,
                      net.transformer3,
                      net.transformer4,
                      net.change_channel_conv_camera,
                      net.change_channel_conv_lidar]

elif NETWORK == "MobileNetV3_Early":
    net = Mobile_Net_v3_Custom(3, N_CLASSES_WEATHER, n_outputs_visibility).to(device)
    shared_modules = [net.backbone]

elif NETWORK == "MobileWeatherNet_Early":
    net = Mobile_Weather_Net_Custom(3, N_CLASSES_WEATHER, n_outputs_visibility).to(device)
    shared_modules = [net.backbone]

elif NETWORK == "RangeWeatherNet_Early":
    net = Range_Weather_Net_Custom(DarkResidualBlock, 3, N_CLASSES_WEATHER, n_outputs_visibility).to(device)
    shared_modules = [net.backbone]

else:
    # raise instead of `assert False`: assertions are removed when Python runs with -O
    raise ValueError("Invalid option. Valid options: MobileNetV3_ViT, ResNet_ViT, MobileNetV3_Early, RangeWeatherNet_Early, and MobileWeatherNet_Early.")

# optimizers
if OPT_TECHNIQUE == "Multi_Adaptive":
    # one optimizer per task: each owns the shared trunk plus its own head
    shared_params = [param for module in shared_modules for param in module.parameters()]
    opt_weather = AdamW(shared_params + list(net.head_weather.parameters()), lr=1e-5, weight_decay=1e-4)
    opt_visibility = AdamW(shared_params + list(net.head_visibility.parameters()), lr=1e-5, weight_decay=1e-4)

elif OPT_TECHNIQUE == "Weighted":
    # single optimizer over the whole network
    opt = AdamW(net.parameters(), lr=1e-5, weight_decay=1e-4)

else:
    raise ValueError("Invalid optimisation technique.")

print(summary(net))

Layer (type:depth-idx)                   Param #
Mobile_Weather_Net_Custom                --
├─Sequential: 1-1                        --
│    └─Sequential: 2-1                   --
│    │    └─Sequential: 3-1              929
│    │    └─Sequential: 3-2              9,281
│    └─Sequential: 2-2                   --
│    │    └─Sequential: 3-3              18,561
│    │    └─Sequential: 3-4              36,993
│    └─Sequential: 2-3                   --
│    │    └─Sequential: 3-5              73,985
│    │    └─Sequential: 3-6              147,713
│    │    └─Sequential: 3-7              147,713
│    └─Conv2d: 2-4                       66,048
├─Sequential: 1-2                        --
│    └─Conv2d: 2-5                       65,664
│    └─ReLU: 2-6                         --
│    └─Dropout: 2-7                      --
│    └─Conv2d: 2-8                       258
├─Sequential: 1-3                        --
│    └─Conv2d: 2-9                       65,664
│    └─ReLU: 2-10               

## Train and Validation

In [None]:
if TRAIN_MODEL:
    # Train for up to EPOCHS epochs, validating after each one. The model with
    # the lowest mean validation loss is written to `path_best_model`, training
    # stops early after PATIENCE consecutive epochs without a new best, and the
    # final model is written to `path_last_model`.
    print("[INFO] Network training and validation...")
    # early-stopping budget: 80% of the epoch count, rounded down
    PATIENCE = int(0.8*EPOCHS)
    # best validation loss so far
    # NOTE(review): a validation loss that is NaN or stays above 1e6 never triggers a best-model save
    vl_loss_min = 1e6
    # consecutive epochs without a new best validation loss
    wait = 0
    # per-epoch mean losses (weather + visibility), used for the plot at the end
    loss_avg_tr = []
    loss_avg_vl = []

    # loop over EPOCHS
    for epoch in range(EPOCHS):
        print(f'* Epoch {epoch+1}/{EPOCHS}')

        # running sums of the per-batch losses for this epoch
        loss_total_tr = 0
        loss_total_vl = 0

        tic = time()
        net.train()

        for weather_gt_tr, visibility_gt_tr, camera_image_tr, lidar_image_range_tr, lidar_image_intensity_tr in tr: 
            weather_gt_tr = weather_gt_tr.to(device)
            visibility_gt_tr = visibility_gt_tr.to(device)
            camera_image_tr = camera_image_tr.to(device)
            lidar_image_range_tr = lidar_image_range_tr.to(device)
            lidar_image_intensity_tr = lidar_image_intensity_tr.to(device)

            if NETWORK == "MobileNetV3_ViT" or NETWORK == "ResNet_ViT":
                # fusion networks: camera (1 channel) and lidar (range + intensity, 2 channels) go in separately
                camera_image_tr = camera_image_tr[:,np.newaxis,:,:].float()
                lidar_image_tr = torch.cat((lidar_image_range_tr[:,np.newaxis,:,:], lidar_image_intensity_tr[:,np.newaxis,:,:]), dim=1).float()
                logits_weather_tr, logits_visibility_tr = net(camera_image_tr, lidar_image_tr)
            else:
                # early-fusion networks: the three images are stacked into one 3-channel input
                data_tr = torch.cat((camera_image_tr[:,np.newaxis,:,:], lidar_image_range_tr[:,np.newaxis,:,:], lidar_image_intensity_tr[:,np.newaxis,:,:]), dim=1).float()
                logits_weather_tr, logits_visibility_tr = net(data_tr)
            
            # the weather loss is fed probabilities, the visibility loss raw logits
            proba_weather_tr = torch.nn.functional.softmax(logits_weather_tr, dim=1)

            # per-task losses, each averaged over the batch; their sum is logged
            loss_weather_tr = loss_function_weather(proba_weather_tr, weather_gt_tr).mean()
            loss_visibility_tr = loss_function_visibility(logits_visibility_tr, visibility_gt_tr).mean()
            loss_total_tr += (loss_weather_tr.item() + loss_visibility_tr.item())

            # backward
            if OPT_TECHNIQUE == "Multi_Adaptive":
                # Each task has its own optimizer. Both losses are back-propagated
                # separately, the resulting gradients are snapshotted, and each
                # snapshot is restored right before the matching optimizer steps.
                opt_weather.zero_grad()
                # retain_graph: the visibility backward below runs through the same forward graph
                loss_weather_tr.backward(retain_graph=True)
                grads_weather_tr = [(param, param.grad.clone()) for param in net.parameters() if param.grad is not None]
                
                opt_visibility.zero_grad()
                loss_visibility_tr.backward()
                grads_visibility_tr = [(param, param.grad.clone()) for param in net.parameters() if param.grad is not None]

                # step 1: weather gradients -> weather optimizer
                for param, grad in grads_weather_tr:
                    param.grad = grad
                opt_weather.step()

                # step 2: visibility gradients -> visibility optimizer
                for param, grad in grads_visibility_tr:
                    param.grad = grad
                opt_visibility.step()

            else:
                # single optimizer on the equally weighted sum of both task losses
                loss_tr = (1.0 * loss_weather_tr) + (1.0 * loss_visibility_tr)
                # zero the gradients
                opt.zero_grad()
                # compute gradients
                loss_tr.backward()
                # adjust learning weights
                opt.step()
            
        toc = time()
        print(f'  Elapsed training time: {toc-tic}s')

        tic = time()
        net.eval()

        # validation pass: same input preparation and losses as training, without gradients
        with torch.no_grad():
            for weather_gt_vl, visibility_gt_vl, camera_image_vl, lidar_image_range_vl, lidar_image_intensity_vl in vl:
                weather_gt_vl = weather_gt_vl.to(device)
                visibility_gt_vl = visibility_gt_vl.to(device)
                camera_image_vl = camera_image_vl.to(device)
                lidar_image_range_vl = lidar_image_range_vl.to(device)
                lidar_image_intensity_vl = lidar_image_intensity_vl.to(device)

                if NETWORK == "MobileNetV3_ViT" or NETWORK == "ResNet_ViT":
                    camera_image_vl = camera_image_vl[:,np.newaxis,:,:].float()
                    lidar_image_vl = torch.cat((lidar_image_range_vl[:,np.newaxis,:,:], lidar_image_intensity_vl[:,np.newaxis,:,:]), dim=1).float()
                    logits_weather_vl, logits_visibility_vl = net(camera_image_vl, lidar_image_vl)
                else:
                    data_vl = torch.cat((camera_image_vl[:,np.newaxis,:,:], lidar_image_range_vl[:,np.newaxis,:,:], lidar_image_intensity_vl[:,np.newaxis,:,:]), dim=1).float()
                    logits_weather_vl, logits_visibility_vl = net(data_vl)
                
                proba_weather_vl = torch.nn.functional.softmax(logits_weather_vl, dim=1)

                # per-task validation losses, each averaged over the batch
                loss_weather_vl = loss_function_weather(proba_weather_vl, weather_gt_vl).mean()
                loss_visibility_vl = loss_function_visibility(logits_visibility_vl, visibility_gt_vl).mean()
                loss_total_vl += (loss_weather_vl.item() + loss_visibility_vl.item())

        toc = time()

        # epoch averages: summed per-batch losses divided by the number of batches
        loss_avg_tr.append(loss_total_tr / len(tr))
        loss_avg_vl.append(loss_total_vl / len(vl))

        print(f'  Elapsed validation time: {toc-tic}s')
        print(f'  Tr Loss: {loss_avg_tr[epoch]}, Vl Loss: {loss_avg_vl[epoch]}')

        # save model if validation loss has decreased (ties also overwrite the checkpoint)
        if loss_avg_vl[epoch] <= vl_loss_min:
            print(f'  The best model was saved!')
            # stores the whole module object, not just its state_dict
            torch.save(net, path_best_model)
            vl_loss_min = loss_avg_vl[epoch]
            wait = 0
        # early stopping
        else:
            wait += 1
            if wait >= PATIENCE:
                print(f"Terminated training for early stopping at epoch {epoch+1}")
                break

    # the model as it is after the final (or early-stopped) epoch
    print(f'The last model was saved!')
    torch.save(net, path_last_model)

    # plot loss: one point per completed epoch, written to a PDF named after the run configuration
    epochs_plot = range(1, (len(loss_avg_vl)+1))
    plt.plot(epochs_plot, loss_avg_tr)
    plt.plot(epochs_plot, loss_avg_vl)
    plt.xlabel("Epoch #")
    plt.ylabel("Loss")
    plt.xticks(epochs_plot)
    plt.legend(('Training loss', 'Validation loss'), loc='upper right')
    plt.savefig('MT_MM_W_MOR_Class_Train_Val_Loss_Network_{}_Optimization_{}_Seed_{}.pdf'.format(NETWORK, OPT_TECHNIQUE, SEED))
    plt.close()

## Test

In [None]:
# Load the best checkpoint and estimate the model's GPU memory footprint as the
# growth in CUDA-allocated bytes measured immediately before and after loading.
target_device = torch.device(device)
mem_before = torch.cuda.memory_allocated(device)
saved_model = torch.load(path_best_model, map_location=target_device)
mem_after = torch.cuda.memory_allocated(device)

# Report where the weights ended up by inspecting the first parameter tensor.
first_param = next(saved_model.parameters())
print("Is the model on cuda: ", first_param.is_cuda)

# Convert the byte delta to megabytes (1 MB = 1024**2 bytes here).
bytes_per_mb = 1024**2
model_memory = (mem_after - mem_before)/bytes_per_mb
print("Total memory of the model:", model_memory, "MB")

In [None]:
# analyze the mean inference time
# Single-sample random inputs: a 1-channel camera image and a 2-channel lidar
# image, plus their channel-wise concatenation for the early-fusion networks.
camera_image_dummy = torch.randn(1, 1, 512, 960).to(device)
lidar_image_dummy = torch.randn(1, 2, 512, 960).to(device)
data_image_dummy = torch.cat((camera_image_dummy, lidar_image_dummy), dim=1)
repetitions = 10000
timings = np.zeros((repetitions, 1))  # per-repetition latency in milliseconds

# Benchmark in evaluation mode: this keeps any dropout layers out of the timing
# and prevents the random inputs from altering any batch-norm running
# statistics before the model is evaluated on the test set.
saved_model.eval()

# The two-stream networks take (camera, lidar) separately; every other network
# takes the stacked tensor. NETWORK is constant, so decide once, not per pass.
if NETWORK == "MobileNetV3_ViT" or NETWORK == "ResNet_ViT":
    dummy_inputs = (camera_image_dummy, lidar_image_dummy)
else:
    dummy_inputs = (data_image_dummy,)

# CUDA events are only usable on a CUDA device; fall back to wall-clock timing
# when the notebook runs on CPU (device is 'cpu' when CUDA is unavailable).
use_cuda_events = torch.device(device).type == 'cuda'
if use_cuda_events:
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
else:
    starter = ender = None

with torch.no_grad():
    # gpu-warm-up
    for _ in range(100):
        _ = saved_model(*dummy_inputs)
    # measure performance
    for rep in range(repetitions):
        if use_cuda_events:
            starter.record()
            _ = saved_model(*dummy_inputs)
            ender.record()
            # wait for GPU sync so that the elapsed time between events is valid
            torch.cuda.synchronize()
            timings[rep] = starter.elapsed_time(ender)
        else:
            t_start = time()
            _ = saved_model(*dummy_inputs)
            # convert seconds to milliseconds to match Event.elapsed_time
            timings[rep] = (time() - t_start) * 1000.0

mean_syn = np.sum(timings) / repetitions
std_syn = np.std(timings)
print(mean_syn, "ms")

In [None]:
print("[INFO] Testing the network...")
saved_model.eval() # set model to evaluation mode


def _build_task_metrics(num_classes):
    """Return ([accuracy, Cohen's kappa, F1], confusion matrix) for one task, all on `device`."""
    scalar_metrics = [
        torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes, average='weighted').to(device),
        torchmetrics.CohenKappa(task='multiclass', num_classes=num_classes).to(device),
        torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='weighted').to(device),
    ]
    confusion = torchmetrics.classification.MulticlassConfusionMatrix(num_classes=num_classes, normalize='none').to(device)
    return scalar_metrics, confusion


metrics_weather, cm_weather = _build_task_metrics(N_CLASSES_WEATHER)
metrics_visibility, cm_visibility = _build_task_metrics(N_CLASSES_VISIBILITY)

# The two-stream networks take camera and lidar tensors separately; the other
# networks take a single tensor with the three channels stacked.
two_stream = NETWORK in ("MobileNetV3_ViT", "ResNet_ViT")

tic = time()
with torch.no_grad(): # turn off gradient tracking
    for batch in ts:
        # move the labels and the three input images of the batch to the device
        (weather_gt_ts, visibility_gt_ts, camera_image_ts,
         lidar_image_range_ts, lidar_image_intensity_ts) = (item.to(device) for item in batch)

        # insert a channel axis at position 1 in each image
        camera_ch = camera_image_ts[:, None, :, :]
        range_ch = lidar_image_range_ts[:, None, :, :]
        intensity_ch = lidar_image_intensity_ts[:, None, :, :]

        if two_stream:
            camera_image_ts = camera_ch.float()
            lidar_image_ts = torch.cat((range_ch, intensity_ch), dim=1).float()
            logits_weather_ts, logits_visibility_ts = saved_model(camera_image_ts, lidar_image_ts)
        else:
            data_ts = torch.cat((camera_ch, range_ch, intensity_ch), dim=1).float()
            logits_weather_ts, logits_visibility_ts = saved_model(data_ts)

        # weather: most probable class after softmax over the class axis
        proba_weather_ts = F.softmax(logits_weather_ts, dim=1)
        class_weather_ts = proba_weather_ts.argmax(dim=1)

        # visibility: the loss object converts its own outputs to class indices
        class_visibility_ts = loss_function_visibility.to_classes(logits_visibility_ts)

        # accumulate scalar metrics first, then the confusion matrix, per task
        for tracker in (*metrics_weather, cm_weather):
            tracker.update(class_weather_ts, weather_gt_ts)

        for tracker in (*metrics_visibility, cm_visibility):
            tracker.update(class_visibility_ts, visibility_gt_ts)

toc = time()
print(f'Elapsed test time: {toc-tic}s')

# Draw the two confusion matrices (weather on top, MOR below) and save as PDF.
fig, axs = plt.subplots(2)
fig.tight_layout(pad=2.0, w_pad=1.0, h_pad=5.0)

# One entry per panel: (axes, confusion-matrix metric, title, class tick labels)
panels = (
    (axs[0], cm_weather, 'Weather Classification', ['Fog', 'Rain']),
    (axs[1], cm_visibility, 'MOR Classification', ['0-40', '40-200', '>200']),
)

# First pass: render both heatmaps with the raw counts annotated in each cell.
for panel_ax, panel_cm, _, _ in panels:
    sns.heatmap(panel_cm.compute().cpu(), ax=panel_ax, annot=True, fmt='g', annot_kws={"size": 12})

# Second pass: titles, axis labels and class names on both axes of each panel.
for panel_ax, _, panel_title, class_names in panels:
    panel_ax.set_title(panel_title, fontsize=18)
    panel_ax.set_xlabel('Predicted Label', fontsize=16)
    panel_ax.set_ylabel('True Label', fontsize=16)
    panel_ax.xaxis.set_ticklabels(class_names, fontsize=12)
    panel_ax.yaxis.set_ticklabels(class_names, fontsize=12, rotation=90)

confusion_pdf = 'MT_MM_W_MOR_Class_Confusion_Matrix_Network_{}_Optimization_{}_Seed_{}.pdf'.format(NETWORK, OPT_TECHNIQUE, SEED)
plt.savefig(confusion_pdf)
plt.close()

# save metrics
# Write one "Metric,Value" row per test metric (weather first, then visibility),
# followed by the model memory footprint and the mean inference time.
# newline='' is required by the csv module: without it the writer's own line
# terminators are translated again on some platforms, producing blank rows.
with open('MT_MM_W_MOR_Class_Metrics_Network_{}_Optimization_{}_Seed_{}.csv'.format(NETWORK, OPT_TECHNIQUE, SEED), 'w', newline='') as f:
    writer = csv.writer(f, dialect='excel')
    writer.writerow(["Metric", "Value"])
    # Row label is the task prefix followed by the metric's class name.
    for task_prefix, task_metrics in (("Weather", metrics_weather), ("Visibility", metrics_visibility)):
        for metric in task_metrics:
            writer.writerow([task_prefix + type(metric).__name__, metric.compute().item()])

    writer.writerow(["ModelMemory", model_memory])
    writer.writerow(["MeanInferenceTime", mean_syn])