# AI Integration Project - Monodepth2
In this project, we reimplement the depth estimation model from the paper [Digging Into Self-Supervised Monocular Depth Estimation](https://arxiv.org/pdf/1806.01260.pdf). In the supplementary material, the authors present many versions and variants of Monodepth2. Based on their results table, we decided to reimplement the non-pretrained ResNet18-encoder Monodepth2 with auto-masking, minimum reprojection loss and full-resolution multi-scale estimation, trained with monocular self-supervision on the KITTI and NYUv2 datasets. We did not choose the ResNet50 encoder (which gives better results) because of its longer training and test times.



Group's members: \\
Nguyễn Đăng Hoài Nam : \\
- student's ID: 2152181 \\
- Mail: nam.nguyencshcmut@hcmut.edu.vn \\

Lê Trần Nguyên Khoa : \\
- student's ID: 2152674 \\
- Mail: khoa.lesteve@hcmut.edu.vn \\

Đinh Việt Thành : \\
- student's ID: 2152966 \\
- Mail: thanh.dinhalex19cs@hcmut.edu.vn \\

# Mount Drive
Remember to upload the project folder to your Google Drive before running this notebook.

In [41]:
# Mount Google Drive at /content/drive so the notebook can read the project
# folder (datasets, split files, checkpoints) stored there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Import Framework & Libs

In [42]:
from __future__ import absolute_import, division, print_function
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import DataLoader
from collections import OrderedDict
from torchvision import transforms
from collections import Counter
from PIL import Image
%matplotlib inline


import torch.utils.model_zoo as model_zoo
import torchvision.models as models
import torch.utils.data as data
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import PIL.Image as pil
import torch.nn as nn
import numpy as np
import skimage
import random
import torch
import copy
import time
import cv2
import os


# Compute device for the whole notebook: CUDA whenever PyTorch can see a GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Helper Functions

In [43]:
def readlines(filename):
    """Return every line of the text file *filename* as a list of strings.

    Line terminators are stripped (``str.splitlines`` semantics).
    """
    with open(filename, 'r') as handle:
        content = handle.read()
    return content.splitlines()

def pil_loader(path):
    """Open the image file at *path* and return an RGB-converted PIL image.

    The file is opened in binary mode and both the file handle and the
    original image object are closed by the ``with`` statement.
    """
    with open(path, 'rb') as stream, Image.open(stream) as raw:
        return raw.convert('RGB')

def disp_to_depth(disp, min_depth, max_depth):
    """Rescale the network's sigmoid output to a disparity range and invert it to depth.

    ``disp`` (expected in [0, 1]) is mapped linearly onto
    [1 / max_depth, 1 / min_depth], as described in the 'additional
    considerations' section of the Monodepth2 paper.

    Returns:
        ``(scaled_disp, depth)`` where ``depth == 1 / scaled_disp``.
    """
    lo = 1 / max_depth
    hi = 1 / min_depth
    scaled = lo + (hi - lo) * disp
    return scaled, 1 / scaled

def transformation_from_parameters(axisangle, translation, invert=False):
    """Turn the pose network's (axisangle, translation) output into 4x4 matrices.

    Args:
        axisangle: axis-angle rotations, forwarded to ``rot_from_axisangle``.
        translation: translations, forwarded to ``get_translation_matrix``.
        invert: if True, build the inverse motion ``R^T @ T(-t)`` instead of
            the forward motion ``T(t) @ R``.

    Returns:
        Batch of 4x4 homogeneous transformation matrices.
    """
    rotation = rot_from_axisangle(axisangle)

    if not invert:
        return torch.matmul(get_translation_matrix(translation), rotation)

    # Inverse of (T @ R) is R^T @ T(-t): undo the translation first, then the rotation.
    rotation_inv = rotation.transpose(1, 2)
    translation_inv = get_translation_matrix(-translation)
    return torch.matmul(rotation_inv, translation_inv)

def get_translation_matrix(translation_vector):
    """Build a batch of 4x4 homogeneous translation matrices.

    ``translation_vector`` must hold three values per batch element (it is
    viewed as ``(-1, 3)``). The result is ``(B, 4, 4)`` in the default float
    dtype on the same device, with an identity rotation part.
    """
    batch = translation_vector.shape[0]
    mats = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    mats = mats.to(device=translation_vector.device)
    mats[:, :3, 3] = translation_vector.contiguous().view(-1, 3)
    return mats

def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    Input 'vec' has to be Bx1x3

    Implements Rodrigues' rotation formula: the rotation angle is the L2 norm
    of ``vec`` (taken over dim 2) and the rotation axis is ``vec`` divided by
    that norm.

    Returns:
        Tensor of shape (B, 4, 4) on ``vec``'s device. The upper-left 3x3 block
        holds the rotation, entry [3, 3] is 1 and every other entry is 0.
    """
    angle = torch.norm(vec, 2, 2, True)
    # epsilon keeps a zero rotation from dividing by zero (axis becomes 0, R becomes I)
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    # per-sample axis components, each shaped (B, 1, 1)
    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    # shared sub-terms of the Rodrigues matrix entries
    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)

    # R = cos(a) I + sin(a) [k]_x + (1 - cos(a)) k k^T, written out entry by entry
    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot

def upsample(x):
    """Return *x* with every spatial dimension doubled via nearest-neighbour copying."""
    factor = 2
    return F.interpolate(x, scale_factor=factor, mode="nearest")

def get_smooth_loss(disp, img):
    """Edge-aware first-order smoothness loss for a disparity map.

    Absolute horizontal and vertical disparity differences are weighted by
    ``exp(-|image difference|)``, where the image difference is averaged over
    the channel dimension. Disparity may therefore change freely across colour
    edges. Both inputs are 4-D (N, C, H, W). Returns a scalar tensor.
    """
    def _edge_weighted_mean(disp_diff, img_diff):
        # channel-averaged colour gradient -> per-pixel down-weighting factor
        img_grad = img_diff.abs().mean(1, keepdim=True)
        return (disp_diff.abs() * torch.exp(-img_grad)).mean()

    loss_x = _edge_weighted_mean(disp[:, :, :, :-1] - disp[:, :, :, 1:],
                                 img[:, :, :, :-1] - img[:, :, :, 1:])
    loss_y = _edge_weighted_mean(disp[:, :, :-1, :] - disp[:, :, 1:, :],
                                 img[:, :, :-1, :] - img[:, :, 1:, :])
    return loss_x + loss_y

def compute_depth_errors(gt, pred):
    """Standard depth-evaluation metrics between ground-truth and predicted depths.

    Returns:
        ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)`` as scalar tensors,
        where ``a_k`` is the fraction of elements with
        ``max(gt / pred, pred / gt) < 1.25 ** k``.
    """
    ratio = torch.max((gt / pred), (pred / gt))
    a1, a2, a3 = ((ratio < 1.25 ** k).float().mean() for k in (1, 2, 3))

    diff = gt - pred
    sq_err = diff ** 2
    rmse = torch.sqrt(sq_err.mean())

    log_err = (torch.log(gt) - torch.log(pred)) ** 2
    rmse_log = torch.sqrt(log_err.mean())

    abs_rel = torch.mean(torch.abs(diff) / gt)
    sq_rel = torch.mean(sq_err / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


def sub2ind(matrixSize, rowSub, colSub):
    """Convert 0-based (row, col) matrix subscripts to row-major linear indices.

    Matches ``numpy.ravel_multi_index((rowSub, colSub), matrixSize)`` but does
    no bounds checking, so it works element-wise on scalars, NumPy arrays and
    tensors alike.

    Args:
        matrixSize: ``(rows, cols)`` shape of the matrix.
        rowSub: row subscript(s), 0-based.
        colSub: column subscript(s), 0-based.

    Returns:
        ``rowSub * cols + colSub``.
    """
    _, n = matrixSize
    # The row stride must be the full column count: a stride of ``n - 1`` maps
    # (r, n - 1) and (r + 1, 0) to the same index, so distinct pixels collide.
    return rowSub * n + colSub


def show_images(outputs):
    """Plot the scale-0 disparity ``outputs[("disp", 0)]`` resized to 512x1024 (HxW).

    The colour map is capped at the 95th percentile of the disparities so that
    a few very large values do not wash out the rest of the image.
    """
    disparity = outputs[("disp", 0)]
    upscaled = torch.nn.functional.interpolate(
        disparity, (512, 1024), mode="bilinear", align_corners=False)

    disp_np = upscaled.squeeze().detach().cpu().numpy()
    upper = np.percentile(disp_np, 95)

    plt.figure(figsize=(10, 10))
    plt.imshow(disp_np, cmap='magma', vmax=upper)
    plt.title("Disparity prediction", fontsize=22)
    plt.axis('off')
    plt.show()

def show_rgb(inputs):
    """Display one channel-first RGB image tensor with matplotlib.

    Singleton dimensions (e.g. a batch of 1) are squeezed away. The remaining
    (C, H, W) tensor is moved to the CPU and shown as (H, W, C).
    """
    chw = inputs.cpu().squeeze()
    hwc = chw.permute(1, 2, 0).numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(hwc)
    plt.title("Input", fontsize=22)
    plt.axis('off')
    plt.show()

def print_img(dir1, dir2):
    """Run a trained ``Monodepth`` model on one image and plot input and disparity.

    Args:
        dir1: folder containing the ``encoder.pth`` / ``depth.pth`` checkpoints.
        dir2: path of the image to run the model on.

    The image is resized to the checkpoint's training resolution. The predicted
    scale-0 disparity is resized back to the original image size, and both are
    drawn one above the other, with the colour scale capped at the 95th
    percentile of the disparity.
    """
    net = Monodepth()
    net.load_checkpoint(dir1)

    rgb = pil.open(dir2).convert('RGB')
    width, height = rgb.size

    net_input = rgb.resize((net.feed_width, net.feed_height), pil.LANCZOS)
    batch = transforms.ToTensor()(net_input).unsqueeze(0)

    prediction = net.forward(batch)[("disp", 0)]
    full_size = torch.nn.functional.interpolate(
        prediction, (height, width), mode="bilinear", align_corners=False)

    disp_np = full_size.squeeze().cpu().numpy()
    upper = np.percentile(disp_np, 95)

    plt.figure(figsize=(10, 10))

    plt.subplot(211)
    plt.imshow(rgb)
    plt.title("Input", fontsize=22)
    plt.axis('off')

    plt.subplot(212)
    plt.imshow(disp_np, cmap='magma', vmax=upper)
    plt.title("Disparity prediction", fontsize=22)
    plt.axis('off')

# Datasets

Monodepth dataset

In [44]:
class MonoDataset(data.Dataset):
    """Superclass for monocular depth datasets.

    Each entry of ``filenames`` is a whitespace-separated line
    ``"<folder> <frame_index> <side>"``. Lines with any other number of fields
    are read as frame 0 with no side. Subclasses must set ``self.K`` (a 4x4
    intrinsics matrix whose first two rows are normalised by image width and
    height) and implement ``get_color``, ``check_depth`` and ``get_depth``.
    """
    def __init__(self, data_path, filenames, height, width, frame_idxs, num_scales, is_train=False, img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        self.interp = Image.Resampling.LANCZOS
        self.frame_idxs = frame_idxs
        self.is_train = is_train
        self.img_ext = img_ext
        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # Probe whether this torchvision accepts (min, max) jitter ranges; if
        # get_params raises TypeError, fall back to scalar jitter strengths.
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        # One resize per pyramid scale: scale i is (height // 2**i, width // 2**i).
        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s), interpolation=self.interp)

        self.load_depth = self.check_depth()

    def _make_color_aug(self):
        """Sample one colour-jitter setting and return it as a PIL -> PIL callable.

        The parameters are drawn once, so every frame of an item receives the
        same jitter. Depending on the torchvision version,
        ``ColorJitter.get_params`` returns either a ready-to-use transform or a
        tuple ``(fn_idx, brightness, contrast, saturation, hue)``. Both are
        handled here.
        """
        params = transforms.ColorJitter.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)
        if callable(params):
            return params

        fn_idx, brightness, contrast, saturation, hue = params
        tf = transforms.functional
        adjusters = {
            0: (brightness, tf.adjust_brightness),
            1: (contrast, tf.adjust_contrast),
            2: (saturation, tf.adjust_saturation),
            3: (hue, tf.adjust_hue),
        }

        def color_aug(img):
            # apply the four adjustments in the randomly sampled order
            for fn_id in fn_idx:
                factor, adjust = adjusters[int(fn_id)]
                if factor is not None:
                    img = adjust(img, factor)
            return img

        return color_aug

    def preprocess(self, inputs, color_aug, flag=False):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.

        Args:
            inputs: item dict holding ``("color", frame_id, -1)`` images. It is
                updated in place with ``("color", f, s)`` and
                ``("color_aug", f, s)`` tensors for every scale ``s``.
            color_aug: callable applied to each PIL image before tensor conversion.
            flag: if True, colour entries that are not PIL images (e.g. arrays)
                are converted to PIL before augmentation.
        """
        for k in list(inputs):
            if "color" in k:
                n, im, _ = k
                for i in range(self.num_scales):
                    # each scale is resized from the previous one (scale -1 is the raw frame)
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                # to_pil_image rejects PIL input, so only convert non-PIL frames
                if flag and not isinstance(f, Image.Image):
                    f = to_pil_image(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)          for raw colour images,
            ("color_aug", <frame_id>, <scale>)      for augmented colour images,
            ("K", scale) or ("inv_K", scale)        for camera intrinsics,
            "depth_gt"                              for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index'

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1      images at native resolution as loaded from disk
            0       images resized to (self.width,      self.height     )
            1       images resized to (self.width // 2, self.height // 2)
            2       images resized to (self.width // 4, self.height // 4)
            3       images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        # Colour augmentation and flipping are currently disabled; the
        # training-time conditions are kept as comments for reference.
        # do_color_aug = self.is_train and random.random() > 0.5
        do_color_aug = False
        # do_flip = self.is_train and random.random() > 0.5
        do_flip = False

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
            side = line[2]
        else:
            frame_index = 0
            side = None
            print("Side=None at: ", index)

        for i in self.frame_idxs:
            inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = self._make_color_aug()
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        # native-resolution frames were only needed to build the pyramid
        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

NYUv2 Dataset


In [45]:
class NYUv2Dataset(MonoDataset):
    """Superclass for different types of NYUv2 dataset loaders

    Sets the camera intrinsics and implements colour loading. Subclasses
    provide ``get_image_path`` and ``get_depth``.
    """
    def __init__(self, *args, **kwargs):
        super(NYUv2Dataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.81, 0, 0.51, 0],
                           [0, 2.71, 1.32, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        # (width, height) that ground-truth depth maps are resized to in get_depth.
        # NOTE(review): (1242, 375) is the KITTI image size; NYUv2 frames are
        # typically 640x480 — confirm this value is intended.
        self.full_res_shape = (1242, 375)

    def check_depth(self):
        """Return True if a depth map exists for the first entry of ``filenames``.

        A line without a frame index is treated as frame 0, as in
        ``__getitem__``. An empty file list or an empty first line yields False
        instead of raising.
        """
        if not self.filenames:
            return False
        line = self.filenames[0].split()
        if not line:
            return False
        scene_name = line[0]
        frame_index = int(line[1]) if len(line) > 1 else 0

        depth_filename = os.path.join(
            self.data_path,
            scene_name,
            "depth/{:010d}.png".format(int(frame_index)))

        return os.path.isfile(depth_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        """Load the RGB frame for (folder, frame_index), mirrored horizontally if ``do_flip``."""
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(pil.FLIP_LEFT_RIGHT)

        return color


class NYUv2RAWDataset(NYUv2Dataset):
    """NYUv2 dataset which loads the original depth maps for ground truth.

    Files are looked up as ``<data_path>/<folder>/image/<10-digit index><img_ext>``
    and ``<data_path>/<folder>/depth/<10-digit index>.png``.
    """
    def __init__(self, *args, **kwargs):
        super(NYUv2RAWDataset, self).__init__(*args, **kwargs)

    def _frame_path(self, folder, frame_index, subdir, ext):
        """Path of ``frame_index`` in ``<folder>/<subdir>``, clamped at 0, with a one-frame fallback.

        A negative index is clamped to 0. If the requested file does not exist
        (e.g. the +1 neighbour of a sequence's last frame), the previous frame
        is used instead.
        """
        frame_index = max(int(frame_index), 0)
        base = os.path.join(self.data_path, folder, subdir)
        path = os.path.join(base, "{:010d}{}".format(frame_index, ext))
        if not os.path.isfile(path):
            path = os.path.join(base, "{:010d}{}".format(frame_index - 1, ext))
        return path

    def get_image_path(self, folder, frame_index, side):
        """Return the colour image path for ``frame_index`` (``side`` is ignored)."""
        return self._frame_path(folder, frame_index, "image", self.img_ext)

    def get_depth(self, folder, frame_index, side, do_flip):
        """Load the ground-truth depth map, resized to ``full_res_shape`` (nearest neighbour).

        Raises:
            FileNotFoundError: if the depth image cannot be read.
        """
        depth_filename = self._frame_path(folder, frame_index, "depth", ".png")

        depth_gt = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED)
        if depth_gt is None:
            # cv2.imread signals a missing/unreadable file by returning None
            raise FileNotFoundError("Could not read depth map: {}".format(depth_filename))

        # order=0 (nearest) avoids blending depth values across object boundaries
        depth_gt = skimage.transform.resize(
            depth_gt, self.full_res_shape[::-1], order=0, preserve_range=True, mode='constant')

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt

# Building Model

Building layers

In [46]:
class ConvBlock(nn.Module):
    """Padded 3x3 convolution (``Conv3x3``) followed by an in-place ELU activation."""
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        return self.nonlin(self.conv(x))


class Conv3x3(nn.Module):
    """3x3 convolution preceded by 1-pixel padding, so the spatial size is preserved.

    Reflection padding is used by default; ``use_refl=False`` pads with zeros.
    Channel counts are cast to ``int`` so NumPy integers are accepted.
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        pad_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = pad_layer(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        padded = self.pad(x)
        return self.conv(padded)


class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud

    For every pixel (x, y), the homogeneous coordinate [x, y, 1] is multiplied
    by ``inv_K[:, :3, :3]``, scaled by that pixel's depth and extended with a
    trailing 1. The result is homogeneous camera-space points of shape
    (batch_size, 4, height * width).

    The pixel grid and the row of ones are stored as non-trainable
    ``nn.Parameter``s (requires_grad=False), so they are part of
    ``state_dict`` and move with ``.to(device)``.
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        # id_coords: (2, height, width) float32; [0] holds x (column) and [1] holds y (row) indices
        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        # (batch_size, 1, H*W) ones: homogeneous coordinate for pixels and for the output points
        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        # pix_coords: (batch_size, 3, H*W) with rows [x; y; 1], pixels flattened row-major
        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        """Back-project ``depth`` into homogeneous camera-space points.

        Args:
            depth: depth map with batch_size * height * width elements
                (viewed as (batch_size, 1, H*W)).
            inv_K: inverse intrinsics; only ``inv_K[:, :3, :3]`` is used
                (presumably (batch_size, 4, 4) as built by the dataset — confirm).

        Returns:
            Tensor of shape (batch_size, 4, H*W).
        """
        # pixel rays: K^-1 [x, y, 1]^T
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        # scale each ray by its pixel's depth
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points


class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T

    The output is a sampling grid of shape (batch_size, height, width, 2),
    normalised so that the image border maps to -1 / +1 (the convention used by
    ``F.grid_sample``). ``eps`` is added to the projected depth before the
    perspective divide to avoid dividing by zero.
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """Project homogeneous ``points`` (B, 4, H*W) with ``K @ T`` into pixel space."""
        projection = torch.matmul(K, T)[:, :3, :]
        cam = torch.matmul(projection, points)

        # perspective divide by the (eps-stabilised) depth row
        depth = cam[:, 2:3, :] + self.eps
        pixels = cam[:, :2, :] / depth
        pixels = pixels.view(self.batch_size, 2, self.height, self.width)
        pixels = pixels.permute(0, 2, 3, 1)

        # map [0, W-1] x [0, H-1] onto [-1, 1] x [-1, 1]
        u = pixels[..., 0] / (self.width - 1)
        v = pixels[..., 1] / (self.height - 1)
        grid = torch.stack((u, v), dim=-1)
        return (grid - 0.5) * 2


class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images

    Returns the per-pixel dissimilarity ``(1 - SSIM) / 2`` clamped to [0, 1].
    Local means, variances and covariance come from 3x3 average pooling over
    reflection-padded inputs, so the output has the inputs' spatial size.
    """
    def __init__(self):
        super(SSIM, self).__init__()
        for pool_name in ("mu_x_pool", "mu_y_pool", "sig_x_pool", "sig_y_pool", "sig_xy_pool"):
            setattr(self, pool_name, nn.AvgPool2d(3, 1))

        self.refl = nn.ReflectionPad2d(1)

        # stabilising constants: the usual (0.01 L)^2 and (0.03 L)^2 with L = 1
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        x = self.refl(x)
        y = self.refl(y)

        mean_x = self.mu_x_pool(x)
        mean_y = self.mu_y_pool(y)

        var_x = self.sig_x_pool(x ** 2) - mean_x ** 2
        var_y = self.sig_y_pool(y ** 2) - mean_y ** 2
        cov_xy = self.sig_xy_pool(x * y) - mean_x * mean_y

        numerator = (2 * mean_x * mean_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mean_x ** 2 + mean_y ** 2 + self.C1) * (var_x + var_y + self.C2)

        dissimilarity = (1 - numerator / denominator) / 2
        return torch.clamp(dissimilarity, 0, 1)

Building network

In [47]:
class DepthDecoder(nn.Module):
    """Decoder that turns a list of encoder features into sigmoid disparity maps.

    Args:
        num_ch_enc: channel count of each encoder feature map. The last entry
            is the first feature consumed, and entry ``i - 1`` is the skip input
            for decoder level ``i``.
        scales: decoder levels at which a disparity map is emitted.
        num_output_channels: channels of each disparity map.
        use_skips: concatenate the matching encoder feature before the second
            conv of every level except level 0.
    """
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        super(DepthDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'  # NOTE(review): not read by forward(); upsample() hard-codes "nearest"
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        # output channels of decoder level i (level 0 is the last / finest level)
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            # level 4 reads the last encoder feature; other levels read the previous decoder level
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            # input widens by the skip feature's channels when skips are concatenated
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        # Register the convs so nn.Module tracks their parameters; the OrderedDict
        # insertion order above fixes the state_dict keys (decoder.0, decoder.1, ...).
        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        """Decode encoder features into disparity maps.

        Args:
            input_features: list of encoder feature maps. Decoding starts from
                ``input_features[-1]``, and ``input_features[i - 1]`` is the
                skip input at level ``i``.

        Returns:
            Dict mapping ``("disp", i)`` to a sigmoid disparity map for every
            ``i`` in ``self.scales``. The dict is also stored on ``self.outputs``.
        """
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            # double the resolution, then optionally append the skip feature along channels
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i in self.scales:
                self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))

        return self.outputs


class PoseDecoder(nn.Module):
    """Predict relative camera motion (axis-angle + translation) from encoder features.

    Args:
        num_ch_enc: encoder channel counts; only the last entry is used.
        num_input_features: number of feature lists passed to ``forward``.
        num_frames_to_predict_for: number of poses to predict; defaults to
            ``num_input_features - 1``.
        stride: stride of the two 3x3 pose convolutions.
    """
    def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
        super(PoseDecoder, self).__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features

        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_frames_to_predict_for = num_frames_to_predict_for

        self.convs = OrderedDict()
        # 1x1 conv reducing the last encoder feature to 256 channels
        self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
        self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
        self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
        # 6 outputs per predicted frame: 3 axis-angle + 3 translation
        self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)

        self.relu = nn.ReLU()

        # Registered in insertion order; this order fixes the state_dict keys (net.0, net.1, ...).
        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, input_features):
        """Predict poses.

        Args:
            input_features: list with one entry per input, each a list of
                encoder feature maps; only the last map of each entry is used.

        Returns:
            ``(axisangle, translation)``, each of shape
            (B, num_frames_to_predict_for, 1, 3).
        """
        last_features = [f[-1] for f in input_features]

        cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
        cat_features = torch.cat(cat_features, 1)

        out = cat_features
        for i in range(3):
            out = self.convs[("pose", i)](out)
            if i != 2:
                out = self.relu(out)

        # global average over the spatial dimensions (width, then height)
        out = out.mean(3).mean(2)

        # 0.01 scaling keeps the raw pose outputs small
        out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)

        axisangle = out[..., :3]
        translation = out[..., 3:]

        return axisangle, translation


class ResNetMultiImageInput(models.ResNet):
    """Constructs a resnet model with varying number of input images.

    Same structure as torchvision's ``ResNet``, except that ``conv1`` accepts
    ``num_input_images * 3`` channels, so several RGB frames can be stacked
    along the channel dimension.

    NOTE(review): ``num_classes`` is accepted but not forwarded to the base
    class, so the classifier head keeps the base-class default.
    """
    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        # Reset to the stem width so the stages below are rebuilt with the same
        # widths as a freshly constructed ResNet.
        self.inplanes = 64
        # Stem conv widened to take num_input_images stacked RGB frames.
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Re-initialise: Kaiming-normal (fan_out) convs, BatchNorm weight 1 / bias 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def _imagenet_resnet_url(num_layers):
    """Return the ImageNet checkpoint URL for resnet{num_layers} on old and new torchvision.

    Older torchvision exposes ``models.resnet.model_urls``. Newer releases
    removed it in favour of the ``ResNet*_Weights`` enums.
    """
    legacy_urls = getattr(models.resnet, "model_urls", None)
    if legacy_urls is not None:
        return legacy_urls['resnet{}'.format(num_layers)]
    weights = {18: models.ResNet18_Weights, 50: models.ResNet50_Weights}[num_layers]
    return weights.IMAGENET1K_V1.url


def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Constructs a ResNet model.
    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        num_input_images (int): Number of frames stacked as input

    Returns:
        ResNetMultiImageInput whose first conv takes ``num_input_images * 3``
        channels.
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)

    if pretrained:
        loaded = model_zoo.load_url(_imagenet_resnet_url(num_layers))
        # Tile the RGB filters once per stacked frame and divide by the frame
        # count, so N identical frames produce the same response as one frame
        # through the original filter.
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model


class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder

    Wraps a torchvision ResNet-18/50 (or a multi-image variant) and returns the
    five intermediate feature maps. ``num_ch_enc`` lists their channel counts.
    """
    def __init__(self, num_layers, pretrained, num_input_images=1):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        builders = {18: models.resnet18,
                    50: models.resnet50}

        if num_layers not in builders:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        self.encoder = (
            resnet_multiimage_input(num_layers, pretrained, num_input_images)
            if num_input_images > 1
            else builders[num_layers](pretrained)
        )

        # Bottleneck ResNets (deeper than 34 layers) widen every stage after the stem 4x.
        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image):
        """Return ``[stem, layer1, layer2, layer3, layer4]`` activations (also kept on ``self.features``).

        The input is first normalised with a fixed mean of 0.45 and std of 0.225.
        """
        net = self.encoder
        x = (input_image - 0.45) / 0.225

        stem = net.relu(net.bn1(net.conv1(x)))
        self.features = [stem]

        out = net.layer1(net.maxpool(stem))
        self.features.append(out)
        for stage in (net.layer2, net.layer3, net.layer4):
            out = stage(out)
            self.features.append(out)

        return self.features

Monodepth forward model

In [48]:
class Monodepth:
    """Inference wrapper pairing a ``ResnetEncoder`` with a ``DepthDecoder``."""
    def __init__(self, num_layers=18, pretrained=False):
        self.encoder = ResnetEncoder(num_layers, pretrained)
        self.depth_decoder = DepthDecoder(num_ch_enc=self.encoder.num_ch_enc, scales=range(4))

    def load_checkpoint(self, model_path):
        """Load ``encoder.pth`` and ``depth.pth`` from ``model_path`` and switch to eval mode.

        The encoder checkpoint also stores the training input size, which is
        exposed as ``feed_height`` / ``feed_width``. Checkpoints are unpickled
        by ``torch.load``, so only load files from trusted sources.
        """
        enc_file = os.path.join(model_path, "encoder.pth")
        dec_file = os.path.join(model_path, "depth.pth")

        self.loaded_dict_enc = torch.load(enc_file, map_location='cpu')
        # keep only entries that are real encoder weights (drops e.g. 'height'/'width')
        expected_keys = self.encoder.state_dict()
        self.filtered_dict_enc = {
            name: value
            for name, value in self.loaded_dict_enc.items()
            if name in expected_keys
        }
        self.encoder.load_state_dict(self.filtered_dict_enc)

        self.loaded_dict = torch.load(dec_file, map_location='cpu')
        self.depth_decoder.load_state_dict(self.loaded_dict)

        self.feed_height = self.loaded_dict_enc['height']
        self.feed_width = self.loaded_dict_enc['width']

        for net in (self.encoder, self.depth_decoder):
            net.eval()

    def forward(self, img):
        """Return the decoder's disparity dict for ``img`` without tracking gradients."""
        with torch.no_grad():
            return self.depth_decoder(self.encoder(img))

# Training

Hyperparameters

In [49]:
class Hyperparameters:
    """Container for the training configuration consumed by ``Trainer``.

    Sequence options (``module_load``, ``frame_ids``, ``scales``) are copied
    into fresh lists, so instances never share or mutate the defaults. Invalid
    values silently fall back to defaults: ``frame_ids`` must be non-empty and
    start with 0, and ``height`` / ``width`` must be multiples of 32.

    ``log_path``, ``dataset``, ``data_path`` and ``split_dir`` have no useful
    defaults and should be supplied by the caller.
    """
    def __init__(self, height=192, width=640, save_freq=1, showimg_freq=1, batch_size=1,
                 log_path="", checkpoint_dir="", dataset="", data_path="", split_dir="",
                 learning_rate=1e-4, scheduler_step_size=15, num_epochs=1, num_workers=2,
                 module_load=("encoder", "depth", "pose_encoder", "pose"), pretrained=False,
                 num_pose_frames=2, frame_ids=(0, -1, 1), scales=(0, 1, 2, 3), img_type=".png",
                 log_frequency=250, min_depth=0.1, max_depth=100.0, disparity_smoothness=1e-3):
        # The target frame (id 0) must come first; otherwise use the standard [0, -1, 1] triplet.
        if len(frame_ids) > 0 and frame_ids[0] == 0:
            self.frame_ids = list(frame_ids)
        else:
            self.frame_ids = [0, -1, 1]
        self.disparity_smoothness = disparity_smoothness
        self.scheduler_step_size = scheduler_step_size
        # The depth decoder upsamples 5 times (2**5 = 32), so sizes must divide evenly.
        self.height = height if height % 32 == 0 else 192
        self.width = width if width % 32 == 0 else 640
        self.num_pose_frames = num_pose_frames
        self.checkpoint_dir = checkpoint_dir
        self.log_frequency = log_frequency
        self.learning_rate = learning_rate
        self.showimg_freq = showimg_freq
        self.num_workers = num_workers
        self.module_load = list(module_load)
        self.pretrained = pretrained
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.data_path = data_path
        self.split_dir = split_dir
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.save_freq = save_freq
        self.log_path = log_path
        self.img_type = img_type
        self.dataset = dataset
        self.device = DEVICE
        self.scales = list(scales)
        self.num_layers = 18
        self.v1_multiscale = True
        self.avg_reprojection = False

Training process

In [50]:
class Trainer:
    """Train and evaluate self-supervised Monodepth2 (depth net + separate pose net).

    ``hyperparameters`` is expected to be a ``Hyperparameters`` instance. The
    constructor builds the networks, the Adam optimiser and StepLR schedule,
    optionally restores weights from ``hyperparameters.checkpoint_dir`` and
    creates train/val/test loaders from ``{split_dir}/{train,val,test}_files.txt``.
    """

    def __init__(self, hyperparameters):
        # prepare attributes
        self.hp = hyperparameters
        self.device = torch.device(self.hp.device)
        self.models = {}
        self.writers = {}
        self.project_3d = {}
        self.backproject_depth = {}
        self.parameters_to_train = []
        self.log_path = self.hp.log_path
        self.save_freq = self.hp.save_freq
        self.num_scales = len(self.hp.scales)
        self.showimg_freq = self.hp.showimg_freq
        self.num_input_frames = len(self.hp.frame_ids)
        self.num_pose_frames = self.hp.num_pose_frames

        # prepare network: depth encoder/decoder plus a separate pose encoder/decoder
        self.models["encoder"] = ResnetEncoder(self.hp.num_layers, self.hp.pretrained)
        self.models["depth"] = DepthDecoder(self.models["encoder"].num_ch_enc, self.hp.scales)
        self.models["pose_encoder"] = ResnetEncoder(
            self.hp.num_layers, self.hp.pretrained, num_input_images=self.num_pose_frames)
        self.models["pose"] = PoseDecoder(
            self.models["pose_encoder"].num_ch_enc, num_input_features=1, num_frames_to_predict_for=2)
        for name in ("encoder", "depth", "pose_encoder", "pose"):
            self.models[name].to(self.device)
            self.parameters_to_train += list(self.models[name].parameters())

        self.ssim = SSIM()
        self.ssim.to(self.device)
        self.model_optimizer = optim.Adam(self.parameters_to_train, self.hp.learning_rate)
        # LR is multiplied by 0.1 every `scheduler_step_size` epochs.
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.hp.scheduler_step_size, 0.1)
        self.depth_metric_names = ["de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]

        if self.hp.checkpoint_dir != "":
            self.load_model()

        # projection helpers, one pair per output scale
        for scale in self.hp.scales:
            h = self.hp.height // (2 ** scale)
            w = self.hp.width // (2 ** scale)
            self.backproject_depth[scale] = BackprojectDepth(self.hp.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)
            self.project_3d[scale] = Project3D(self.hp.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        # prepare data
        datasets_dict = {"nyuv2": NYUv2RAWDataset}
        if self.hp.dataset not in datasets_dict:
            raise KeyError("Unknown dataset {!r}; expected one of {}".format(
                self.hp.dataset, sorted(datasets_dict)))
        self.dataset = datasets_dict[self.hp.dataset]
        fpath = os.path.join(self.hp.split_dir, "{}_files.txt")
        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        test_filenames = readlines(fpath.format("test"))
        img_ext = self.hp.img_type
        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.hp.batch_size * self.hp.num_epochs

        train_dataset = self.dataset(self.hp.data_path, train_filenames, self.hp.height, self.hp.width, self.hp.frame_ids, 4, is_train=True, img_ext=img_ext)
        val_dataset = self.dataset(self.hp.data_path, val_filenames, self.hp.height, self.hp.width, self.hp.frame_ids, 4, is_train=False, img_ext=img_ext)
        test_dataset = self.dataset(self.hp.data_path, test_filenames, self.hp.height, self.hp.width, self.hp.frame_ids, 4, is_train=False, img_ext=img_ext)
        self.train_loader = DataLoader(train_dataset, self.hp.batch_size, True, num_workers=self.hp.num_workers, pin_memory=True, drop_last=True)
        # val stays shuffled: val() only looks at one minibatch per call.
        self.val_loader = DataLoader(val_dataset, self.hp.batch_size, True, num_workers=self.hp.num_workers, pin_memory=True, drop_last=True)
        # Test order is kept fixed so the (drop_last) evaluation is deterministic.
        self.test_loader = DataLoader(test_dataset, self.hp.batch_size, False, num_workers=self.hp.num_workers, pin_memory=True, drop_last=True)
        self.val_iter = iter(self.val_loader)

        print("There are {:d} training items, {:d} validation items and {:d} test items\n".format(len(train_dataset), len(val_dataset), len(test_dataset)))

    def set_train(self):
        """Convert all models to training mode"""
        for m in self.models.values():
            m.train()

    def set_eval(self):
        """Convert all models to testing/evaluation mode"""
        for m in self.models.values():
            m.eval()

    def train(self):
        """Run the entire training pipeline, saving weights after every epoch."""
        self.epoch = 0
        self.step = 0
        self.start_time = time.time()
        for self.epoch in range(self.hp.num_epochs):
            self.run_epoch()
            self.save_model()
            print("save!")

    def run_epoch(self):
        """Run a single epoch of training, with periodic logging, saving and validation.

        Every ``hp.log_frequency`` batches the current outputs are shown, the
        model is saved and one validation minibatch is evaluated.
        """
        print("Training")
        self.set_train()

        for batch_idx, inputs in enumerate(self.train_loader):
            print("|--------------------|")
            print("Batch: ", batch_idx)
            for k in inputs:
                inputs[k] = inputs[k].to(self.device)
            outputs, losses = self.process_batch(inputs)
            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()
            if batch_idx % self.hp.log_frequency == 0:
                if "depth_gt" in inputs:
                    self.compute_depth_losses(inputs, outputs, losses)
                show_images(outputs)
                self.save_model()
                print("save!")
                self.val()

            self.step += 1
            print("Done Batch!")

        # Since PyTorch 1.1 the scheduler must be stepped after the optimizer;
        # stepping it first skips the initial LR and decays one epoch early.
        self.model_lr_scheduler.step()

    def testing(self):
        """Evaluate depth metrics on the test split.

        Models are switched to eval mode and run without gradient tracking,
        then restored to train mode. Per-batch values and running averages are
        printed.

        Returns:
            dict mapping each name in ``self.depth_metric_names`` to its average
            over all test batches, or an empty dict if the loader yields nothing.
        """
        print("Testing")
        self.set_eval()
        totals = {metric: 0 for metric in self.depth_metric_names}
        num = 0
        with torch.no_grad():
            for batch_idx, inputs in enumerate(self.test_loader):
                print("|--------------------|")
                print("test: ", batch_idx)
                num += 1
                for k in inputs:
                    inputs[k] = inputs[k].to(self.device)
                outputs, losses = self.process_batch(inputs)
                self.compute_depth_losses(inputs, outputs, losses)
                show_rgb(inputs["color_aug", 0, 0])
                show_images(outputs)
                for metric in self.depth_metric_names:
                    print(metric, ":  ", losses[metric])
                    totals[metric] += losses[metric]
                    print(metric, "average: ", totals[metric] / num)
                print("Done!")
        self.set_train()

        if num == 0:
            print("No test batches were produced; nothing to report.")
            return {}

        averages = {metric: totals[metric] / num for metric in self.depth_metric_names}
        print('Final result:')
        for metric, value in averages.items():
            print(metric, ":  ", value)
        return averages

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses"""
        features = self.models["encoder"](inputs["color_aug", 0, 0])
        outputs = self.models["depth"](features)
        outputs.update(self.predict_poses(inputs, features))
        self.generate_images_pred(inputs, outputs)
        losses = self.compute_losses(inputs, outputs)

        return outputs, losses

    def predict_poses(self, inputs, features):
        """Predict poses between the target frame (0) and each source frame.

        ``features`` is unused; poses come from the separate pose encoder.
        """
        outputs = {}
        if self.num_pose_frames == 2:
            # In this setting, we compute the pose to each source frame
            pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.hp.frame_ids}

            for f_i in self.hp.frame_ids[1:]:
                # To maintain ordering we always pass frames in temporal order
                if f_i < 0:
                    pose_inputs = [pose_feats[f_i], pose_feats[0]]
                else:
                    pose_inputs = [pose_feats[0], pose_feats[f_i]]
                pose_inputs = [self.models["pose_encoder"](torch.cat(pose_inputs, 1))]
                axisangle, translation = self.models["pose"](pose_inputs)
                outputs[("axisangle", 0, f_i)] = axisangle
                outputs[("translation", 0, f_i)] = translation
                # Invert for past frames because they were fed in (f_i, 0) order.
                outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(axisangle[:, 0], translation[:, 0], invert=(f_i < 0))

        else:
            # Here we input all frames to the pose net (and predict all poses) together
            pose_inputs = torch.cat([inputs[("color_aug", i, 0)] for i in self.hp.frame_ids], 1)
            pose_inputs = [self.models["pose_encoder"](pose_inputs)]
            axisangle, translation = self.models["pose"](pose_inputs)
            for i, f_i in enumerate(self.hp.frame_ids[1:]):
                outputs[("axisangle", 0, f_i)] = axisangle
                outputs[("translation", 0, f_i)] = translation
                outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(axisangle[:, i], translation[:, i])

        return outputs

    def val(self):
        """Validate the model on a single minibatch, restarting the val iterator when exhausted."""
        self.set_eval()
        try:
            inputs = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)

        for k in inputs:
            inputs[k] = inputs[k].to(self.device)

        with torch.no_grad():
            outputs, losses = self.process_batch(inputs)
            if "depth_gt" in inputs:
                self.compute_depth_losses(inputs, outputs, losses)
            del inputs, outputs, losses
        self.set_train()

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.

        For every scale, disparity is converted to depth, back-projected with
        ``inv_K``, projected into each source frame with ``K`` and the predicted
        pose, and the source image is sampled at those coordinates. Results are
        saved into ``outputs``.
        """
        for scale in self.hp.scales:
            disp = outputs[("disp", scale)]
            source_scale = scale
            _, depth = disp_to_depth(disp, self.hp.min_depth, self.hp.max_depth)
            outputs[("depth", 0, scale)] = depth

            for frame_id in self.hp.frame_ids[1:]:
                T = outputs[("cam_T_cam", 0, frame_id)]
                cam_points = self.backproject_depth[source_scale](depth, inputs[("inv_K", source_scale)])
                pix_coords = self.project_3d[source_scale](cam_points, inputs[("K", source_scale)], T)

                outputs[("sample", frame_id, scale)] = pix_coords
                # align_corners=False is grid_sample's default, made explicit to keep
                # behaviour and silence the warning. NOTE(review): if Project3D
                # normalises by (w - 1)/(h - 1), align_corners=True would match it.
                outputs[("color", frame_id, scale)] = F.grid_sample(
                    inputs[("color", frame_id, source_scale)], outputs[("sample", frame_id, scale)],
                    padding_mode="border", align_corners=False)
                outputs[("color_identity", frame_id, scale)] = inputs[("color", frame_id, source_scale)]

    def compute_reprojection_loss(self, pred, target):
        """Photometric loss per pixel: 0.85 * SSIM term + 0.15 * L1, averaged over channels."""
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1, True)
        ssim_loss = self.ssim(pred, target).mean(1, True)
        reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

        return reprojection_loss

    def compute_losses(self, inputs, outputs):
        """Compute the per-pixel-minimum reprojection loss with auto-masking plus
        edge-aware smoothness for each scale, averaged over scales."""
        losses = {}
        total_loss = 0
        for scale in self.hp.scales:
            loss = 0
            source_scale = scale
            disp = outputs[("disp", scale)]
            color = inputs[("color", 0, scale)]
            target = inputs[("color", 0, source_scale)]

            reprojection_loss = torch.cat(
                [self.compute_reprojection_loss(outputs[("color", frame_id, scale)], target)
                 for frame_id in self.hp.frame_ids[1:]], 1)

            # Auto-masking: loss of the un-warped source frames against the target.
            identity_reprojection_loss = torch.cat(
                [self.compute_reprojection_loss(inputs[("color", frame_id, source_scale)], target)
                 for frame_id in self.hp.frame_ids[1:]], 1)

            # add random numbers to break ties
            identity_reprojection_loss += torch.randn(identity_reprojection_loss.shape, device=self.device) * 0.00001
            combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1)

            if combined.shape[1] == 1:
                to_optimise = combined
            else:
                to_optimise, idxs = torch.min(combined, dim=1)
                # 1 where a warped frame beat every identity frame (pixel kept).
                outputs["identity_selection/{}".format(scale)] = (
                    idxs > identity_reprojection_loss.shape[1] - 1).float()

            loss += to_optimise.mean()

            mean_disp = disp.mean(2, True).mean(3, True)
            norm_disp = disp / (mean_disp + 1e-7)
            smooth_loss = get_smooth_loss(norm_disp, color)

            loss += self.hp.disparity_smoothness * smooth_loss / (2 ** scale)
            total_loss += loss
            losses["loss/{}".format(scale)] = loss

        total_loss /= self.num_scales
        losses["loss"] = total_loss
        return losses

    def compute_depth_losses(self, inputs, outputs, losses):
        """Compute depth metrics, to allow monitoring during training

        This isn't particularly accurate as it averages over the entire batch,
        so is only used to give an indication of validation performance.
        The prediction is resized to the ground-truth resolution, median-scaled
        and clamped to [1e-3, 80] (a KITTI-style cap) before comparison.
        """
        depth_gt = inputs["depth_gt"]
        gt_height, gt_width = depth_gt.shape[-2:]

        depth_pred = outputs[("depth", 0, 0)]
        depth_pred = torch.clamp(
            F.interpolate(depth_pred, [gt_height, gt_width], mode="bilinear", align_corners=False), 1e-3, 80)
        depth_pred = depth_pred.detach()

        mask = depth_gt > 0

        # Garg/Eigen crop expressed as fractions of the GT size; at 375x1242
        # this is exactly rows 153:371 and columns 44:1197.
        crop_mask = torch.zeros_like(mask)
        crop_mask[:, :,
                  int(0.40810811 * gt_height):int(0.99189189 * gt_height),
                  int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
        mask = mask * crop_mask

        depth_gt = depth_gt[mask]
        depth_pred = depth_pred[mask]
        # Monocular depth is scale-ambiguous: align medians before scoring.
        depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)

        depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)

        depth_errors = compute_depth_errors(depth_gt, depth_pred)

        for i, metric in enumerate(self.depth_metric_names):
            losses[metric] = np.array(depth_errors[i].cpu())

    def save_model(self):
        """Save every model's weights and the Adam state to
        ``{log_path}/models/weights_{epoch}/``."""
        save_folder = os.path.join(self.log_path, "models", "weights_{}".format(self.epoch))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        for model_name, model in self.models.items():
            save_path = os.path.join(save_folder, "{}.pth".format(model_name))
            to_save = model.state_dict()
            if model_name == 'encoder':
                # save the sizes - these are needed at prediction time
                to_save['height'] = self.hp.height
                to_save['width'] = self.hp.width
            torch.save(to_save, save_path)

        save_path = os.path.join(save_folder, "{}.pth".format("adam"))
        torch.save(self.model_optimizer.state_dict(), save_path)

    def load_model(self):
        """Load the models listed in ``hp.module_load`` (and Adam state, if present)
        from ``hp.checkpoint_dir``; checkpoint keys unknown to a model are ignored."""
        print("Loading model from "+self.hp.checkpoint_dir)
        for n in self.hp.module_load:
            path = os.path.join(self.hp.checkpoint_dir, "{}.pth".format(n))
            model_dict = self.models[n].state_dict()
            pretrained_dict = torch.load(path, map_location=torch.device(self.device))
            # Drops extra entries such as the encoder's saved 'height'/'width'.
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.models[n].load_state_dict(model_dict)

        # loading adam state
        optimizer_load_path = os.path.join(self.hp.checkpoint_dir, "adam.pth")
        if os.path.isfile(optimizer_load_path):
            optimizer_dict = torch.load(optimizer_load_path, map_location=torch.device(self.device))
            self.model_optimizer.load_state_dict(optimizer_dict)
        print("Finished loading.")

# How to train and test?
First, create and configure a `Hyperparameters` object:
- `log_path` is the path of the folder where model weights are saved.
- `checkpoint_dir` is the path of your checkpoint folder (optional).
- `dataset` is the name of the dataset (currently only `"nyuv2"` is supported).
- `data_path` is the path of the data folder.
- `split_dir` is the path of the folder containing the train, validation and test split files (`train_files.txt`, `val_files.txt`, `test_files.txt`).

To train or test, create a `Trainer` object and pass it the `Hyperparameters` you have created. Call `train()` to run training and `testing()` to evaluate on the test split.


# Running Trainer

In [51]:
# HP = Hyperparameters(batch_size=1, num_epochs=5, log_frequency=50, scheduler_step_size=15, img_type=".jpg",
#                      log_path="",
#                      checkpoint_dir="",
#                      dataset="nyuv2",
#                      data_path="/content/drive/MyDrive/nyuv2",
#                      split_dir="/content/drive/MyDrive/nyuv2/splits",)
# trainer = Trainer(HP)
# trainer.train()

# Testing


In [52]:
# HP = Hyperparameters(batch_size=1, num_epochs=5, log_frequency=50, scheduler_step_size=15, img_type=".jpg",
#                      log_path="/content/drive/MyDrive/nyu_result",
#                      checkpoint_dir="/content/drive/MyDrive/Checkpoint_Final/NYUv2",
#                      dataset="nyuv2",
#                      data_path="/content/drive/MyDrive/nyuv2",
#                      split_dir="/content/drive/MyDrive/nyuv2/splits",)
# trainer = Trainer(HP)
# trainer.testing()