In [99]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from PIL import Image
from glob import glob

import os
import random
import numpy as np
import matplotlib.pyplot as plt
import time

from res.plot_lib import plot_data, plot_data_np, plot_model, set_default
set_default()

In [100]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

bs_train, bs_test = 8, 1
epochs = 10
lr = 0.0001

cpu


In [101]:
def zoom(image, zoom_factor):
    width, height = image.size
    new_width = int(width / zoom_factor)
    new_height = int(height / zoom_factor)


    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height

    image = image.crop((left, top, right, bottom))
    image = image.resize((width, height), Image.Resampling.LANCZOS)
        
    return image

In [102]:
class DLFDatasetCam(Dataset): # Data-Level Fusion Dataset
    def __init__(self, camera_paths, lidar_paths, depth_paths, image_size : tuple, train = True):
        self._camera_paths = camera_paths
        self._lidar_paths = lidar_paths
        self._depth_paths = depth_paths
        self._image_size = image_size

    def transform(self, camera, lidar, depth):
        # 3 CHANNELS CONVERSION
        rgb = transforms.Lambda(lambda img: img.convert("RGB"))
        camera = rgb(camera)
        lidar = rgb(lidar)

        # 1 CHANNEL CONVERSION
        gray = transforms.Lambda(lambda img: img.convert("L"))
        depth = gray(depth)

        # RESIZE
        resize = transforms.Resize(self._image_size)
        camera = resize(camera)
        lidar = resize(lidar)
        depth = resize(depth)

        # ZOOM
        #zoom_factor = calibration()
        camera = zoom(camera, 1.8)
        lidar = zoom(lidar, 0.7)

        # CONVERT TO TENSOR
        camera_tensor = TF.to_tensor(camera)
        lidar_tensor = TF.to_tensor(lidar)
        groundtruth = TF.to_tensor(depth)
        
        # FUSION
        alpha = 0.6
        image = camera_tensor

        return image, groundtruth

    def __getitem__(self, index):
        camera = Image.open(self._camera_paths[index])
        lidar = Image.open(self._lidar_paths[index])
        depth = Image.open(self._depth_paths[index])
        x, y = self.transform(camera, lidar, depth)
        return x, y

    def __len__(self):
        return len(self._camera_paths)

In [103]:
class DLFDatasetLid(Dataset):
    def __init__(self, camera_paths, lidar_paths, depth_paths, image_size : tuple, train = True):
        self._camera_paths = camera_paths
        self._lidar_paths = lidar_paths
        self._depth_paths = depth_paths
        self._image_size = image_size

    def transform(self, camera, lidar, depth):
        # 3 CHANNELS CONVERSION
        rgb = transforms.Lambda(lambda img: img.convert("RGB"))
        camera = rgb(camera)
        lidar = rgb(lidar)

        # 1 CHANNEL CONVERSION
        gray = transforms.Lambda(lambda img: img.convert("L"))
        depth = gray(depth)

        # RESIZE
        resize = transforms.Resize(self._image_size)
        camera = resize(camera)
        lidar = resize(lidar)
        depth = resize(depth)

        # ZOOM
        #zoom_factor = calibration()
        camera = zoom(camera, 1.8)
        lidar = zoom(lidar, 0.7)

        # CONVERT TO TENSOR
        camera_tensor = TF.to_tensor(camera)
        lidar_tensor = TF.to_tensor(lidar)
        groundtruth = TF.to_tensor(depth)
        
        # FUSION
        alpha = 0.6
        image = lidar_tensor

        return image, groundtruth

    def __getitem__(self, index):
        camera = Image.open(self._camera_paths[index])
        lidar = Image.open(self._lidar_paths[index])
        depth = Image.open(self._depth_paths[index])
        x, y = self.transform(camera, lidar, depth)
        return x, y

    def __len__(self):
        return len(self._camera_paths)

In [104]:
def conv_layer(input_channels, output_channels):
    conv = nn.Sequential(
        nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(output_channels),
        nn.ReLU()
    )
    return conv

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # input dim 256x256
        self.down_1 = conv_layer(3, 64) #128x128
        self.down_2 = conv_layer(64, 128) #64x64
        self.down_3 = conv_layer(128, 256) #32x32
        self.down_4 = conv_layer(256, 512) #16x16
        self.down_5 = conv_layer(512, 1024) #8x8
        
        self.up_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_1 = conv_layer(1024, 512)
        self.up_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_2 = conv_layer(512, 256)
        self.up_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_3 = conv_layer(256, 128)
        self.up_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_4 = conv_layer(128, 64)
        
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, padding=0)
        self.output_activation = nn.Sigmoid()
                
    def forward(self, img):
        x1 = self.down_1(img)
        x2 = self.max_pool(x1)
        x3 = self.down_2(x2)
        x4 = self.max_pool(x3)
        x5 = self.down_3(x4)
        x6 = self.max_pool(x5)
        x7 = self.down_4(x6)
        x8 = self.max_pool(x7)
        x9 = self.down_5(x8)
        
        x = self.up_1(x9)
        x = self.up_conv_1(torch.cat([x, x7], 1))
        x = self.up_2(x)
        x = self.up_conv_2(torch.cat([x, x5], 1)) 
        x = self.up_3(x)
        x = self.up_conv_3(torch.cat([x, x3], 1))
        x = self.up_4(x)
        x = self.up_conv_4(torch.cat([x, x1], 1))
        
        x = self.output(x)
        x = self.output_activation(x)
        
        return x

In [105]:
camera_files = []
lidar_files = []
depth_files = []

camera_dir = glob('dataset/camera/*')
lidar_dir = glob('dataset/lidar/*')
depth_dir = glob('dataset/depth/*')

for c in camera_dir:
    camera_files.append(c)
for l in lidar_dir:
    lidar_files.append(l)
for d in depth_dir:
    depth_files.append(d)

ds_split = len(camera_files)//2

camera_pred_test, lidar_pred_test, depth_pred_test = camera_files[:ds_split], lidar_files[:ds_split], depth_files[:ds_split]

In [106]:
#Datasets
test_camera_dataset = DLFDatasetCam(camera_pred_test, lidar_pred_test, depth_pred_test, (256, 256))
test_lidar_dataset = DLFDatasetLid(camera_pred_test, lidar_pred_test, depth_pred_test, (256, 256))

#Dataloaders
test_camera_dataloader = DataLoader(test_camera_dataset, batch_size = bs_test, shuffle = False)
test_lidar_dataloader = DataLoader(test_lidar_dataset, batch_size = bs_test, shuffle = False)

In [107]:
model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr = lr, betas = (0.9,0.999))

In [None]:
import tqdm as tqdm

PATH = './old_models/camera_only.pth'
model.load_state_dict(torch.load(PATH))
model.eval()

for i, (image, target) in enumerate(test_camera_dataloader):
    image = image.to(device)
    target = target.to(device)
    target = target[0]
    pred = model(image)

    target = (target.cpu().numpy()).transpose(1,2,0).squeeze(axis=2)
    pred = np.squeeze(pred.cpu().detach().numpy())

    plt.imsave(f'dataset_ensemble/depth_pred/depth_{i}.png', target, cmap='gray')
    plt.imsave(f'dataset_ensemble/camera_pred/cam_{i}.png', pred, cmap='gray')

PATH = './old_models/lidar_only.pth'
model.load_state_dict(torch.load(PATH))
model.eval()

for i, (image, target) in enumerate(test_lidar_dataloader):
    image = image.to(device)
    target = target.to(device)
    target = target[0]
    pred = model(image)

    pred = np.squeeze(pred.cpu().detach().numpy())

    plt.imsave(f'dataset_ensemble/lidar_pred/lid_{i}.png', pred, cmap='gray')


  model.load_state_dict(torch.load(PATH))
  model.load_state_dict(torch.load(PATH))
