In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from PIL import Image
from glob import glob

import os
import random
import numpy as np
import matplotlib.pyplot as plt
import time

from res.plot_lib import plot_data, plot_data_np, plot_model, set_default
# Apply the plotting defaults provided by the local res.plot_lib helper module
# (its implementation is not part of this notebook).
set_default()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Training hyper-parameters. The test batch size of 1 is relied on by the plotting
# cells, which squeeze the batch dimension away before transposing to HxWxC.
bs_train, bs_test = 8, 1
epochs = 10
lr = 0.0001

# Seed every RNG the notebook draws from: `random` (sample viewer, horizontal flips),
# NumPy, and torch (weight init, DataLoader shuffling), so Restart & Run All
# reproduces the same figures and losses.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def sorted_glob(pattern):
    """Return the paths matching `pattern` in sorted (deterministic) order.

    glob() yields paths in arbitrary filesystem order. The camera, lidar and depth
    lists below are paired purely by position, so each must be sorted for
    element i of every list to refer to the same frame.
    """
    return sorted(glob(pattern))

camera_dir = sorted_glob('dataset/camera/*')
lidar_dir = sorted_glob('dataset/lidar/*')
depth_dir = sorted_glob('dataset/depth/*')

camera_files = list(camera_dir)
lidar_files = list(lidar_dir)
depth_files = list(depth_dir)

# Index-based pairing only makes sense when every modality has the same number of frames.
assert len(camera_files) == len(lidar_files) == len(depth_files), (
    f'unpaired dataset: {len(camera_files)} camera, {len(lidar_files)} lidar, '
    f'{len(depth_files)} depth files')

In [None]:
def pil_loader(path):
    """Open the image at `path` and return a converted copy.

    Single-band images (e.g. modes 'L', 'P', '1', 'I') are converted to 'L';
    every multi-band image (e.g. 'RGB', 'RGBA', 'LA') is converted to 'RGB'.
    convert() reads the pixel data while the file is still open, so the returned
    image does not depend on the closed file handle.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # getbands() only inspects the mode, avoiding a throwaway NumPy copy of
        # all pixels just to count channels.
        if len(img.getbands()) == 1:
            return img.convert('L')
        return img.convert('RGB')

In [None]:
# Show `rows` randomly chosen (camera, lidar, depth) triplets, one triplet per row,
# to eyeball that the index-aligned file lists really correspond.
rows, cols = 5, 3  # cols must stay 3: one column per modality
fig = plt.figure(figsize = (10, 10))
for i in range(rows):
    # Sample from the real dataset size; a fixed upper bound raises IndexError on
    # smaller datasets and never shows files beyond it on larger ones.
    r = random.randrange(len(camera_files))
    for j, files in enumerate((camera_files, lidar_files, depth_files)):
        fig.add_subplot(rows, cols, i * cols + j + 1)
        plt.imshow(pil_loader(files[r]))
        plt.axis('off')
plt.show()

In [None]:
# Split the index-aligned lists in half: the first half is held out for testing,
# the second half is used for training.
ds_split = len(camera_files) // 2

modalities = (camera_files, lidar_files, depth_files)
camera_test, lidar_test, depth_test = (paths[:ds_split] for paths in modalities)
camera_train, lidar_train, depth_train = (paths[ds_split:] for paths in modalities)

print(len(camera_train), len(camera_test))

In [None]:
def zoom(image, zoom_factor):
    """Zoom a PIL image about its centre while keeping its size.

    A centred window of size (width / zoom_factor, height / zoom_factor) is cropped
    and resized back to the original size with Lanczos resampling.
    zoom_factor > 1 zooms in. zoom_factor < 1 crops a window larger than the image;
    PIL fills the out-of-bounds area with zeros, so the content shrinks (zoom out)
    inside a border.

    Args:
        image: PIL image.
        zoom_factor: strictly positive magnification factor.

    Returns:
        A new image with the same size as `image`.

    Raises:
        ValueError: if zoom_factor is not > 0 (including NaN).
    """
    # Checked first: 0 would divide by zero, and negatives give an inverted crop box.
    if not zoom_factor > 0:
        raise ValueError(f'zoom_factor must be > 0, got {zoom_factor!r}')

    width, height = image.size
    # Keep at least one pixel so extreme zoom-ins don't produce an empty crop.
    new_width = max(1, int(width / zoom_factor))
    new_height = max(1, int(height / zoom_factor))

    # Centre the window; negative offsets are valid and mean "zoom out".
    left = (width - new_width) // 2
    top = (height - new_height) // 2

    cropped = image.crop((left, top, left + new_width, top + new_height))
    return cropped.resize((width, height), Image.Resampling.LANCZOS)

In [None]:
class DLFDataset(Dataset): # Data-Level Fusion Dataset
    """Paired camera / LiDAR / depth images fused at the input (data) level.

    Each item is ``(image, groundtruth)``:

    - ``image`` is the pixel-wise blend ``alpha * camera + (1 - alpha) * lidar`` of
      the two images, both converted to RGB (3-channel tensor).
    - ``groundtruth`` is the depth image converted to 'L' (1-channel tensor).

    All three are resized to ``image_size``. The three path lists are paired by index.

    Args:
        camera_paths, lidar_paths, depth_paths: index-aligned file paths of equal length.
        image_size: target size passed to ``transforms.Resize``.
        train: when True, apply random horizontal flips (augmentation); when False
            every item is produced deterministically.
        camera_zoom, lidar_zoom: zoom factors passed to ``zoom`` for the camera and
            LiDAR images (the depth target is not zoomed).
        alpha: weight of the camera image in the blend.

    Raises:
        ValueError: if the three path lists differ in length.
    """
    def __init__(self, camera_paths, lidar_paths, depth_paths, image_size : tuple, train = True,
                 camera_zoom = 1.8, lidar_zoom = 0.7, alpha = 0.6):
        if not len(camera_paths) == len(lidar_paths) == len(depth_paths):
            raise ValueError(
                'camera/lidar/depth path lists must have equal length, got '
                f'{len(camera_paths)}/{len(lidar_paths)}/{len(depth_paths)}')
        self._camera_paths = camera_paths
        self._lidar_paths = lidar_paths
        self._depth_paths = depth_paths
        self._image_size = image_size
        self._train = train
        self._camera_zoom = camera_zoom
        self._lidar_zoom = lidar_zoom
        self._alpha = alpha

    def transform(self, camera, lidar, depth):
        """Convert, resize, zoom, optionally flip, and fuse one camera/LiDAR/depth triplet.

        Returns:
            (image, groundtruth): the blended 3-channel input tensor and the
            1-channel depth tensor.
        """
        # Camera and LiDAR are blended element-wise below, so both need 3 channels;
        # the depth target is single-channel.
        camera = camera.convert("RGB")
        lidar = lidar.convert("RGB")
        depth = depth.convert("L")

        resize = transforms.Resize(self._image_size)
        camera = resize(camera)
        lidar = resize(lidar)
        depth = resize(depth)

        # Per-sensor zoom, applied to the inputs only.
        # NOTE(review): presumably a hand-tuned stand-in for sensor calibration — confirm.
        camera = zoom(camera, self._camera_zoom)
        lidar = zoom(lidar, self._lidar_zoom)

        # Augmentation (training only): flip all three together so they stay aligned.
        if self._train and random.random() > 0.5:
            camera = TF.hflip(camera)
            lidar = TF.hflip(lidar)
            depth = TF.hflip(depth)

        camera_tensor = TF.to_tensor(camera)
        lidar_tensor = TF.to_tensor(lidar)
        groundtruth = TF.to_tensor(depth)

        # Data-level fusion: pixel-wise weighted average of camera and LiDAR.
        image = camera_tensor * self._alpha + lidar_tensor * (1 - self._alpha)

        return image, groundtruth

    def __getitem__(self, index):
        # Context managers release the file handles once transform() has converted
        # (and therefore fully read) each image.
        with Image.open(self._camera_paths[index]) as camera, \
                Image.open(self._lidar_paths[index]) as lidar, \
                Image.open(self._depth_paths[index]) as depth:
            return self.transform(camera, lidar, depth)

    def __len__(self):
        return len(self._camera_paths)

In [None]:
def dc_loss(pred, target, smooth=1.): # Dice loss function (NOT USED)
    """Soft Dice loss between two tensors with the same number of elements.

        loss = 1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)

    Both tensors are flattened, so all batch items are pooled into a single score.
    `smooth` keeps the ratio defined when both tensors sum to zero.

    Returns:
        A 0-dim tensor with the loss.
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. after transpose/permute.
    predf = pred.reshape(-1)
    targetf = target.reshape(-1)
    intersection = (predf * targetf).sum()

    return 1 - ((2. * intersection + smooth) /
              (predf.sum() + targetf.sum() + smooth))

In [None]:
def conv_layer(input_channels, output_channels):
    """Build a double 3x3 convolution block that preserves spatial size.

    Order: Conv -> ReLU -> Conv -> BatchNorm -> ReLU. The layer positions (and so the
    state_dict keys '0.*', '2.*', '3.*') must not change, because saved checkpoints
    are loaded with load_state_dict.
    """
    same_size_3x3 = dict(kernel_size=3, padding=1)
    layers = [
        nn.Conv2d(input_channels, output_channels, **same_size_3x3),
        nn.ReLU(),
        nn.Conv2d(output_channels, output_channels, **same_size_3x3),
        nn.BatchNorm2d(output_channels),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

class UNet(nn.Module):
    """U-Net mapping a 3-channel image to a 1-channel map squashed by a sigmoid.

    Four 2x2 max-pools on the way down and four stride-2 transposed convolutions on
    the way up, with skip connections that concatenate encoder features. Input height
    and width must be divisible by 16 (2**4) so upsampled maps line up with the skips.
    Attribute names and construction order are kept stable: checkpoints are loaded
    via load_state_dict.
    """
    def __init__(self):
        super().__init__()

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder. Sizes are for a 256x256 input: each conv_layer keeps the spatial
        # size (3x3, padding 1) and the max-pool between stages halves it.
        self.down_1 = conv_layer(3, 64)       # 64 x 256x256
        self.down_2 = conv_layer(64, 128)     # 128 x 128x128
        self.down_3 = conv_layer(128, 256)    # 256 x 64x64
        self.down_4 = conv_layer(256, 512)    # 512 x 32x32
        self.down_5 = conv_layer(512, 1024)   # 1024 x 16x16 (bottleneck)

        # Decoder: each transposed conv doubles the size and halves the channels; the
        # following conv_layer takes [upsampled, skip] concatenated, hence 2x channels.
        self.up_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_1 = conv_layer(1024, 512)
        self.up_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_2 = conv_layer(512, 256)
        self.up_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_3 = conv_layer(256, 128)
        self.up_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_4 = conv_layer(128, 64)

        # 1x1 projection to a single channel, then sigmoid.
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, padding=0)
        self.output_activation = nn.Sigmoid()

    def forward(self, img):
        """Run the network on `img` of shape (N, 3, H, W); returns (N, 1, H, W).

        Raises:
            ValueError: if H or W is not divisible by 16.
        """
        # Four 2x2 poolings: otherwise torch.cat on the skip connections fails with a
        # much less helpful size-mismatch error.
        height, width = img.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f'UNet input height and width must be divisible by 16, got {height}x{width}')

        x1 = self.down_1(img)
        x2 = self.max_pool(x1)
        x3 = self.down_2(x2)
        x4 = self.max_pool(x3)
        x5 = self.down_3(x4)
        x6 = self.max_pool(x5)
        x7 = self.down_4(x6)
        x8 = self.max_pool(x7)
        x9 = self.down_5(x8)

        x = self.up_1(x9)
        x = self.up_conv_1(torch.cat([x, x7], 1))
        x = self.up_2(x)
        x = self.up_conv_2(torch.cat([x, x5], 1))
        x = self.up_3(x)
        x = self.up_conv_3(torch.cat([x, x3], 1))
        x = self.up_4(x)
        x = self.up_conv_4(torch.cat([x, x1], 1))

        x = self.output(x)
        x = self.output_activation(x)

        return x
        

In [None]:
#Datasets
image_size = (256, 256)  # H and W must be divisible by 16 for the UNet
train_dataset = DLFDataset(camera_train, lidar_train, depth_train, image_size)
# Random horizontal flips are training-time augmentation; evaluate on un-augmented samples.
test_dataset = DLFDataset(camera_test, lidar_test, depth_test, image_size, train = False)

#Dataloaders
# The test loader is shuffled on purpose: the plotting cells take next(iter(test_dataloader))
# to show a random sample.
train_dataloader = DataLoader(train_dataset, batch_size = bs_train, shuffle = True)
test_dataloader = DataLoader(test_dataset, batch_size = bs_test, shuffle = True)

In [None]:
# Peek at one test sample: the fused camera/LiDAR input next to its depth target.
di, dm = next(iter(test_dataloader))
di = di.to(device)
dm = dm.to(device)[0]  # first (only) item of the bs_test = 1 batch

panels = (
    ('Original Image', np.squeeze(di.cpu().numpy()).transpose(1,2,0)),
    ('Original Mask', dm.cpu().numpy().transpose(1,2,0).squeeze(axis=2)),
)
plt.figure(figsize=(18,18))
for position, (title, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture)
    plt.title(title)
plt.show()

In [None]:
# Build the U-Net, move it to the selected device, and optimise it with Adam.
model = UNet()
model = model.to(device)
adam_betas = (0.9, 0.999)
optimizer = optim.Adam(params=model.parameters(), lr=lr, betas=adam_betas)

In [None]:
# Checkpoint file written periodically during training. Create its folder up front so
# torch.save does not fail on a fresh checkout where ./models/ does not exist yet.
PATH = './models/unet_test.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)

In [None]:
def displayPrediction(model, img_pair):
    """Plot input image, ground-truth depth and model prediction for one batch.

    Args:
        model: network to run; it is not switched to eval mode here (the training
            loop does that before calling).
        img_pair: ``(images, targets)`` batch as yielded by the DataLoaders. Only the
            first target is shown, and the image/prediction panels squeeze the batch
            dimension, so they assume a batch of one.
    """
    (img, mask) = img_pair
    img = img.to(device)
    mask = mask.to(device)[0]
    # Inference only. Without no_grad the prediction requires grad, and converting
    # it to NumPy for imshow raises unless the caller wrapped us in no_grad.
    with torch.no_grad():
        pred = model(img)

    plt.figure(figsize=(12,12))
    plt.subplot(1,3,1)
    plt.imshow(np.squeeze(img.cpu().numpy()).transpose(1,2,0))
    plt.title('Original Image')
    plt.subplot(1,3,2)
    plt.imshow((mask.cpu().numpy()).transpose(1,2,0).squeeze(axis=2))
    plt.title('Original Mask')
    plt.subplot(1,3,3)
    plt.imshow(np.squeeze(pred.cpu().numpy()))
    plt.title('Prediction')
    plt.show()

In [None]:
from tqdm import tqdm

def _train_one_epoch(model, criterion, display_every):
    """Run one optimisation pass over `train_dataloader` and return the per-batch losses."""
    model.train()
    losses = []

    loop = tqdm(enumerate(train_dataloader), total = len(train_dataloader), leave = False)
    for batch, (images, targets) in loop:
        images = images.to(device)
        targets = targets.to(device)

        model.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        # Steps the notebook-level `optimizer`, which must be built over this model's parameters.
        optimizer.step()
        losses.append(loss.item())

        # Every `display_every` batches: checkpoint to PATH and show a prediction on a
        # test sample, then return to training mode.
        if display_every and batch % display_every == 0:
            with torch.no_grad():
                torch.save(model.state_dict(), PATH)
                model.eval()
                displayPrediction(model, next(iter(test_dataloader)))
                model.train()
                print(loss.item())

    return losses


def _evaluate(model, criterion):
    """Return the per-batch losses of `model` over `test_dataloader`, computed without gradients."""
    model.eval()
    losses = []
    with torch.no_grad():
        for test_images, test_targets in test_dataloader:
            test_images = test_images.to(device)
            test_targets = test_targets.to(device)
            losses.append(criterion(model(test_images), test_targets).item())
    return losses


def train(model, epochs, display_every = 5):
    """Train `model` with MSE loss and report mean train / validation loss per epoch.

    Relies on notebook globals: `optimizer`, `train_dataloader`, `test_dataloader`,
    `device`, `PATH`, `tqdm` and `displayPrediction`.

    Args:
        model: network to optimise (updated in place).
        epochs: number of passes over the training set.
        display_every: checkpoint and plot a test prediction every this many batches
            (default 5); 0 or None disables it.

    Returns:
        (model, avg_train_losses, avg_test_losses): the same model object after the
        last epoch, plus per-epoch mean training and validation losses.
    """
    criterion = nn.MSELoss()
    avg_train_losses = []
    avg_test_losses = []

    for _ in range(epochs):
        train_losses = _train_one_epoch(model, criterion, display_every)
        # Record the actual loss values (not loss-module objects) so np.mean works.
        test_losses = _evaluate(model, criterion)

        epoch_avg_train_loss = np.mean(train_losses)
        epoch_avg_test_loss = np.mean(test_losses)
        avg_train_losses.append(epoch_avg_train_loss)
        avg_test_losses.append(epoch_avg_test_loss)

        print_msg = (f'train_loss: {epoch_avg_train_loss:.5f} ' + f'valid_loss: {epoch_avg_test_loss:.5f}')
        print(print_msg)

    return model, avg_train_losses, avg_test_losses

In [None]:
#Train the model
# NOTE: train() returns the very `model` object it was given, as it stands after the
# final epoch (not a best-by-validation snapshot), so `best_model is model`.
best_model, avg_train_losses, avg_val_losses = train(model, epochs)

In [None]:
# Load pretrained weights for the evaluation below. The path gets its own name so the
# training checkpoint path PATH is not redirected onto this file, which a re-run of
# training would otherwise overwrite.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
PRETRAINED_PATH = './old_models/data_level_fusion.pth'
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=device))

In [None]:
model.eval()

# Compare the ground-truth depth with the prediction for one test sample.
(img, mask) = next(iter(test_dataloader))
img = img.to(device)
mask = mask.to(device)[0]
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    pred = model(img)

plt.figure(figsize=(12,12))
plt.subplot(1,2,1)
plt.imshow((mask.cpu().numpy()).transpose(1,2,0).squeeze(axis=2))
plt.title('Ground truth')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(np.squeeze(pred.cpu().numpy()))
plt.title('Prediction')
plt.axis('off')
plt.show()