# Reproduction of Self-Supervised Learning with Geometric Constraints in Monocular Video

by Philip de Rijk and Seger Tak

In [300]:
# Mount Google Drive (Colab only) so the project code and data under '/content/drive/My Drive' are reachable.
# force_remount=True remounts even when the drive is already mounted in this runtime.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


## Import Modules

In [26]:
# 
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F  
import torch.optim as optim
import torchvision
import skimage.transform
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload # example: foo = reload(foo) where foo = module

root_dir = '/content/drive/My Drive/CV_by_DL_project/reproduction_chen/'

import sys
sys.path.append(root_dir)

# Models
from dispnet import DispNetS
from cameranet import CameraNet
from flownet import FlowNet
from PoseExpNet import PoseExpNet 

# Utility functions
import pytorch_ssim
from transforms_new_2 import ImageTransform
from dataloader import KittiLoader
from dataloader import gt_DepthLoader
from inverse_warp import inverse_warp, pose_vec2mat
from utils import to_device
import time

## Device

In [27]:
# Select the compute device: the GPU whenever CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Directories & File Locations

In [288]:
# Root of the KITTI raw data; every path in this cell is built from root_dir on the mounted drive.
kitti_dir = os.path.join(root_dir, 'data/kitti/')

# Split files listing the training / validation examples (the training file name suggests a subset of the Eigen split).
eigen_train_file = os.path.join(kitti_dir, 'eigen_train_files_subset.txt')
eigen_val_file = os.path.join(kitti_dir, 'eigen_val_files.txt')

# Ground-truth depth maps and matching RGB images of the 500-image test set.
gt_depth_500_dir = os.path.join(root_dir, 'gt_test500/')
image_500_dir = os.path.join(root_dir, 'imagetest500/')

# Ground-truth depth maps and matching RGB images of the 500-image validation set.
gt_depth_500_val_dir = os.path.join(root_dir, 'gt_valid_500/')
image_500_val_dir = os.path.join(root_dir, 'imagevalidate500/')

# File paths (not directories) used when saving model weights; save_models() strips the
# trailing '.pth' and appends '<epoch>_last_SEGER.pth'.
model_path_disp = os.path.join(root_dir,'data/models_seger/disp_apc_mvs_e_last.pth')
model_path_pose = os.path.join(root_dir,'data/models_seger/pose_apc_mvs_e_last.pth')
model_path_flow = os.path.join(root_dir,'data/models_seger/flow_apc_mvs_e_last.pth')

# Checkpoint files of previously trained networks (not loaded in the cells shown here — presumably used later).
model_path_disp_load = os.path.join(root_dir,'data/models_seger/disp_apc_working_last_SEGER4_last_SEGER.pth')
model_path_pose_load = os.path.join(root_dir,'data/models_seger/pose_apc_working_last_SEGER4_last_SEGER.pth')
model_path_flow_load = os.path.join(root_dir,'data/models_seger/flow_apc_working_last_SEGER4_last_SEGER.pth')

# Output directories for depth/image visualisations saved during training and evaluation.
output_directory_train = os.path.join(root_dir,'data/output/train_full_seger')
output_directory_eval = os.path.join(root_dir,'data/output/val_full_seger')

### Training Dataloader

In [301]:
# Per-split image transforms ('train' vs 'val' mode; what each mode does is defined in transforms_new_2).
datatransform_train = ImageTransform(mode='train')
datatransform_val = ImageTransform(mode='val')

# Training dataset: KITTI examples listed in the training split file.
train_set = KittiLoader(split_file=eigen_train_file, base_dir=kitti_dir, mode='train', transform=datatransform_train)

# Shuffled training batches of 4 samples; pin_memory speeds up host-to-GPU copies.
train_loader = DataLoader(train_set, batch_size = 4, shuffle = True, num_workers = 2, pin_memory=True)

# Validation dataset from the validation split file.
# NOTE(review): built with mode='train' (only the transform differs) — presumably so it yields the same
# target/previous/next frame samples the training loop reads; confirm against KittiLoader.
val_set = KittiLoader(split_file=eigen_val_file, base_dir=kitti_dir, mode='train', transform=datatransform_val)

# Validation batches of 1 sample, in file order.
val_loader = DataLoader(val_set, batch_size = 1, shuffle = False, num_workers = 2, pin_memory=True)

# Phase -> loader lookup; the training loop iterates data_loader[phase].
data_loader = {'train': train_loader, 'val': val_loader} 

print('the training set contains', len(train_set), 'images')
print('the val set contains', len(val_set), 'images')

SKIPPED:  /content/drive/My Drive/CV_by_DL_project/reproduction_chen/data/kitti/2011_09_28/2011_09_28_drive_0001_sync/image_02/data/0000000104.png  - does not exist
SKIPPED:  /content/drive/My Drive/CV_by_DL_project/reproduction_chen/data/kitti/2011_09_26/2011_09_26_drive_0057_sync/image_02/data/0000000359.png  - does not exist
SKIPPED:  /content/drive/My Drive/CV_by_DL_project/reproduction_chen/data/kitti/2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000727.png  - does not exist
SKIPPED:  /content/drive/My Drive/CV_by_DL_project/reproduction_chen/data/kitti/2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000106.png  - does not exist
SKIPPED:  /content/drive/My Drive/CV_by_DL_project/reproduction_chen/data/kitti/2011_09_30/2011_09_30_drive_0020_sync/image_02/data/-000000001.png  - does not exist
SKIPPED:  /content/drive/My Drive/CV_by_DL_project/reproduction_chen/data/kitti/2011_09_26/2011_09_26_drive_0079_sync/image_02/data/-000000001.png  - does not exist
SKIPPED:  

### Test Dataloader

In [263]:
# Evaluation-time image transforms.
datatransform_test = ImageTransform(mode='test')

# 500-image test set: ground-truth depth maps paired with their RGB images.
test_set = gt_DepthLoader(gt_depth_500_dir, image_500_dir, datatransform_test)

# One image per batch, fixed order, for evaluation.
test_loader = DataLoader(test_set, batch_size = 1, shuffle = False, num_workers = 2, pin_memory=True)

print('the test set contains', len(test_set), 'images')

the test set contains 500 images


In [264]:
# Evaluation-time image transforms (same 'test' mode as the test set).
datatransform_val500 = ImageTransform(mode='test')

# 500-image validation set: ground-truth depth maps paired with their RGB images.
val500_set = gt_DepthLoader(gt_depth_500_val_dir, image_500_val_dir, datatransform_val500)

# One image per batch, fixed order, for evaluation.
val500_loader = DataLoader(val500_set, batch_size = 1, shuffle = False, num_workers = 2, pin_memory=True)

print('the val set contains', len(val500_set), 'images')

the val set contains 500 images


## Utility functions

In [265]:
############################ SSIM import used for APC ################################
from torch.autograd import Variable
from math import exp

def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel of length `window_size`.

    The kernel is centred on index window_size // 2 and sums to 1.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = []
    for x in range(window_size):
        weights.append(exp(-(x - centre) ** 2 / denom))
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()

def create_window(window_size, channel):
    """Build a 2-D Gaussian window (sigma = 1.5) replicated per channel for grouped conv.

    Returns a tensor of shape [channel, 1, window_size, window_size].
    """
    g1d = gaussian(window_size, 1.5).unsqueeze(1)          # [W, 1]
    g2d = g1d.mm(g1d.t()).float()[None, None]               # [1, 1, W, W], outer product
    replicated = g2d.expand(channel, 1, window_size, window_size).contiguous()
    return Variable(replicated)

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1)#.mean(1).mean(1)

class SSIM(torch.nn.Module):
    """SSIM module that caches its Gaussian window between calls.

    The window is rebuilt (and moved to the input's device/dtype) whenever the input's
    channel count or tensor type differs from the cached one.
    With the default size_average=False the result is the per-pixel SSIM map averaged over channels.
    """
    def __init__(self, window_size = 11, size_average = False):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        cache_hit = (channel == self.channel
                     and self.window.data.type() == img1.data.type())
        if not cache_hit:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = False):
    """Functional SSIM: builds a fresh Gaussian window matching img1's channels, device and dtype."""
    _, channel, _, _ = img1.size()
    kernel = create_window(window_size, channel)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)
    return _ssim(img1, img2, kernel, window_size, channel, size_average)

In [292]:
def make_grid(depth_or_flow_pred_1d):
    # alternative to the set_id_grid from inverse_warp.py of Clement Pinard
    """Homogeneous pixel grid matching the spatial size of the input.

    Args:
        depth_or_flow_pred_1d: depth or flow map; only its H, W, dtype and device are used -- [B, C, H, W]
    Returns:
        pixel_coords: channel 0 = column index u, channel 1 = row index v, channel 2 = 1 -- [1, 3, H, W]
    """
    ref = depth_or_flow_pred_1d
    _, _, rows, cols = ref.size()
    u_coords = torch.arange(0, cols).view(1, 1, cols).expand(1, rows, cols).type_as(ref)
    v_coords = torch.arange(0, rows).view(1, rows, 1).expand(1, rows, cols).type_as(ref)
    homog = torch.ones(1, rows, cols).type_as(ref)
    return torch.stack([u_coords, v_coords, homog], dim=1)

def pixel2cam_v2(depth, intrinsics_inv):
    # alternative to the pixel2cam from inverse_warp.py of Clement Pinard.
    """Back-project every pixel into the camera frame: depth(p) * K^-1 [u, v, 1]^T.

    Args:
        depth: depth maps -- [B, 1, H, W]
        intrinsics_inv: inverse intrinsics matrix for each batch element -- [B, 3, 3]
    Returns:
        3-D points in the camera frame -- [B, 3, H, W]
    """
    batch, _, height, width = depth.size()
    homogeneous = make_grid(depth).expand(batch, 3, height, width).reshape(batch, 3, height * width)
    rays = torch.matmul(intrinsics_inv, homogeneous).reshape(batch, 3, height, width)
    return rays * depth

def scale_intrinsics(current_scale_img, normal_scale_img, intrinsics):
    """Rescale camera intrinsics from the full-resolution image to a lower-resolution scale.

    Resizing an image by 1/s divides fx, fy, cx and cy by s, i.e. the first two rows of K.
    The last row [0, 0, 1] is unchanged. (The previous version only divided cx and cy and
    left the focal lengths at full resolution.)

    Args:
        current_scale_img: tensor at the target resolution (only its height is read) -- [B, C, h, w]
        normal_scale_img: tensor at the resolution K refers to (only its height is read) -- [B, C, H, W]
        intrinsics: intrinsics matrices with any leading dims, e.g. [B, C=2, 3, 3] or [B, 3, 3]
    Returns:
        intrinsics_scaled: same shape, dtype and device as `intrinsics`
    """
    downscale = normal_scale_img.size(2) / current_scale_img.size(2)

    # Out-of-place so gradients can still flow back into K if it is predicted.
    intrinsics_scaled = torch.cat((intrinsics[..., :2, :] / downscale,
                                   intrinsics[..., 2:, :]), dim=-2)
    return intrinsics_scaled

def vec2skew_symmetric(vec):
    """Turn a batch of 3-vectors into their cross-product (skew-symmetric) matrices.

    For v = (x, y, z) the result is [[0, -z, y], [z, 0, -x], [-y, x, 0]], so that
    [v]_x @ u == v cross u.

    Args:
        vec: batch of 3-vectors -- [B, 3]
    Returns:
        skew symmetric matrices -- [B, 3, 3]
    """
    assert(vec.size()[1] == 3)
    assert(vec.ndim == 2)

    batch = vec.size()[0]
    x, y, z = vec[:, 0], vec[:, 1], vec[:, 2]
    o = z * 0  # zeros with vec's dtype/device, kept in the autograd graph

    entries = [o, -z, y,
               z, o, -x,
               -y, x, o]
    return torch.stack(entries, dim=1).reshape(batch, 3, 3)

def optical_flow_displacement(src2tgt_flow, src_img):
    """ Map every source pixel to its target-image location via the flow field.

    Coordinates stay in pixel units (u = column, v = row) of the flow map's resolution.
    Normalisation for grid_sample is the caller's job.

    Args:
        src2tgt_flow = source to target image flow (pixel displacements) -- [B, 2, H, W]
        src_img = source image; kept for interface compatibility, not read -- [B, 3, H, W]
    Return:
        pixel_coords_tgt = target-image pixel coordinates (u + du, v + dv) for each source pixel -- [B, 2, H, W]
        valid_points = True where the displaced pixel lies inside [0, W-1] x [0, H-1] -- [B, H, W] (bool)
    """
    b, _, h, w = src2tgt_flow.size()

    # Non-homogeneous pixel grid of the source image: channel 0 = column u, channel 1 = row v.
    u = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type_as(src2tgt_flow)
    v = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type_as(src2tgt_flow)
    pixel_coords_src = torch.stack((u, v), dim=1)  # [1, 2, H, W]

    # Pixels in the target image corresponding to the source pixels via optical flow.
    pixel_coords_tgt = pixel_coords_src.repeat(b, 1, 1, 1) + src2tgt_flow  # [B, 2, H, W]

    # Valid = lands inside the image. The coordinates are in pixels, so check against the
    # pixel bounds. (The previous version tested |coord| <= 1, i.e. normalised bounds,
    # which marked almost every pixel invalid.)
    x_tgt, y_tgt = pixel_coords_tgt[:, 0], pixel_coords_tgt[:, 1]
    valid_points = (x_tgt >= 0) & (x_tgt <= w - 1) & (y_tgt >= 0) & (y_tgt <= h - 1)  # [B, H, W]

    return pixel_coords_tgt, valid_points

def flow_warp(src2tgt_flow, tgt_img):
    """ Backward-warp a target image into the source image plane using a flow field.

    The output at source pixel (u, v) is `tgt_img` bilinearly sampled at (u + du, v + dv).
    Locations outside the image read as 0.

    F.grid_sample expects coordinates normalised to [-1, 1]. With align_corners=True,
    -1 maps to pixel 0 and +1 to pixel (size - 1). The previous version passed raw pixel
    coordinates, so almost every sample fell outside the image and the result was ~0.

    Args:
        src2tgt_flow = source to target image flow, in pixels -- [B, 2, H, W]
        tgt_img = target image from which pixels will be sampled; expected at the
                  flow's resolution -- [B, 3, H, W]
    Return:
        warped_img = [B, 3, H, W]
    """
    b, _, h, w = src2tgt_flow.size()
    _, _, h_img, w_img = tgt_img.size()

    # Source pixel grid displaced by the flow -> sampling locations in tgt_img (pixel units).
    u = torch.arange(0, w).view(1, 1, w).expand(b, h, w).type_as(src2tgt_flow)
    v = torch.arange(0, h).view(1, h, 1).expand(b, h, w).type_as(src2tgt_flow)
    x_tgt = u + src2tgt_flow[:, 0]
    y_tgt = v + src2tgt_flow[:, 1]

    # Normalise relative to the sampled image's size (guard size-1 dims against /0).
    x_norm = 2 * x_tgt / max(w_img - 1, 1) - 1
    y_norm = 2 * y_tgt / max(h_img - 1, 1) - 1
    grid = torch.stack((x_norm, y_norm), dim=3)  # [B, H, W, 2]

    warped_img = F.grid_sample(tgt_img, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
    return warped_img

def focal2intrinsics(focal_lengths, img):
    """ Assemble the pinhole intrinsics matrix K from per-sample focal lengths.

    The principal point is placed at the image centre (W/2, H/2).

    Args:
        focal_lengths = focal lengths (fx, fy) -- [B, 2]
        img = image whose resolution H x W defines the principal point -- [B, 3, H, W]
    Returns:
        K = [[fx, 0, W/2], [0, fy, H/2], [0, 0, 1]] per sample -- [B, 3, 3]
    """
    batch, _, height, width = img.size()

    fx, fy = focal_lengths[:, 0], focal_lengths[:, 1]

    # Constant entries built from fx so they share its dtype/device and batch size.
    zero = fx * 0
    one = zero + 1
    cx = (one * width) / 2
    cy = (one * height) / 2

    rows = [fx,   zero, cx,
            zero, fy,   cy,
            zero, zero, one]
    return torch.stack(rows, dim=1).reshape(batch, 3, 3)

def save_images(output_directory_train, depths_tgt, tgt_img, i):
    """ Save the first sample's depth map (plasma colormap) and target image as PNGs.

    Both are resized to 375 x 1242 and written to `output_directory_train` with `i` as
    the filename prefix.
    """
    out_size = [375, 1242]

    depth = depths_tgt[0][0, :, :, :].squeeze().cpu().detach()
    image = tgt_img[0, :, :, :].squeeze().permute(1, 2, 0).cpu().detach()

    depth_to_img = skimage.transform.resize(depth, out_size, mode='constant')
    tgt_to_img = skimage.transform.resize(image, out_size, mode='constant')

    depth_path = os.path.join(output_directory_train, str(i) + '_apc-mvs_e_traindepth_SEGER.png')
    image_path = os.path.join(output_directory_train, str(i) + '_apc-mvs_e_trainimage_SEGER.png')
    plt.imsave(depth_path, depth_to_img, cmap='plasma')
    plt.imsave(image_path, tgt_to_img)

def save_images_norm(output_directory_train, depths_tgt, tgt_img, i):
    """ Save the scaled depth, disparity and target image of the first batch element.

    Writes three PNGs to `output_directory_train`, each prefixed with `i`:
      * '<i>_apc-mvs_e_traindepth_SEGER.png': depth * 256 cast to uint16, resized to 352 x 1216
      * '<i>_apc-mvs_e_traindisp_SEGER.png': (1 / depth) * 256, resized to 375 x 1242, plasma colormap
      * '<i>_apc-mvs_e_trainimage_SEGER.png': |tgt_img|, resized to 375 x 1242

    `depths_tgt` is not modified. The previous version multiplied depths_tgt[0] by 256
    in place, which corrupted the caller's depth tensor and could break autograd.

    Args:
        output_directory_train: directory the PNGs are written to
        depths_tgt: list whose first entry is a 4-D depth tensor (presumably [B, 1, H, W] — confirm)
        tgt_img: target image batch -- [B, 3, H, W]
        i: index used as the filename prefix
    """
    depth = depths_tgt[0].detach()

    # Scale by 256 before saving, as before (presumably the KITTI 16-bit PNG convention — confirm).
    disp_tgt = (1 / depth) * 256
    depth_scaled = depth * 256

    depth_to_img = depth_scaled[0].squeeze().cpu().numpy().astype(np.uint16)
    depth_to_img = skimage.transform.resize(depth_to_img, [352, 1216], mode='constant')
    disp_to_img = skimage.transform.resize(disp_tgt[0, :, :, :].squeeze().cpu(), [375, 1242], mode='constant')
    plt.imsave(os.path.join(output_directory_train, str(i) + '_apc-mvs_e_traindepth_SEGER.png'), depth_to_img)

    tgt_to_img = skimage.transform.resize(tgt_img[0].abs().squeeze().permute(1, 2, 0).cpu().detach(), [375, 1242], mode='constant')
    plt.imsave(os.path.join(output_directory_train, str(i) + '_apc-mvs_e_traindisp_SEGER.png'), disp_to_img, cmap='plasma')
    plt.imsave(os.path.join(output_directory_train, str(i) + '_apc-mvs_e_trainimage_SEGER.png'), tgt_to_img)

def save_models(disp_net, pose_net, flow_net, model_path_disp, epoch, pose_path=None, flow_path=None):
    """ Save the full networks (whole modules via torch.save), one file per network.

    Each file name is the given path with its trailing '.pth' replaced by
    '<epoch>_last_SEGER.pth'. The previous version wrote the flow network to the pose
    path, overwriting the pose checkpoint.

    Args:
        disp_net, pose_net, flow_net: objects to save
        model_path_disp: base path for the depth network (must end in '.pth')
        epoch: epoch number embedded in the file names
        pose_path: base path for the pose network; defaults to the global `model_path_pose`
        flow_path: base path for the flow network; defaults to the global `model_path_flow`
    """
    if pose_path is None:
        pose_path = model_path_pose
    if flow_path is None:
        flow_path = model_path_flow

    suffix = str(epoch) + '_last_SEGER.pth'
    torch.save(disp_net, model_path_disp[:-4] + suffix)
    torch.save(pose_net, pose_path[:-4] + suffix)
    torch.save(flow_net, flow_path[:-4] + suffix)

def plot_train_curve(b100loss):
    """ Show the running training loss, one point per 100-batch window.
    """
    print()
    plt.plot(b100loss)
    plt.grid(True)
    plt.legend(['training curve'])
    plt.ylabel('training loss')
    plt.xlabel('100 batches')
    plt.show()

def plot_curve(loss_train, loss_val):
    """ Show per-epoch training and validation loss on one set of axes.
    """
    print()
    for series in (loss_train, loss_val):
        plt.plot(series)
    plt.grid(True)
    plt.legend(['training loss', 'validation loss'])
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.show()

def plot_metrics(abs_diff, abs_rel, sq_rel, rmse, rmse_log):
    """ Plot the per-batch test metrics in two figures.

    sq_rel gets its own figure, presumably because its scale differs from the other metrics.
    """
    def _panel(series_list, names):
        # One figure: every series on shared axes with a legend and grid.
        print()
        for series in series_list:
            plt.plot(series)
        plt.xlabel('batch')
        plt.ylabel('error')
        plt.legend(names)
        plt.grid(True)
        plt.show()

    _panel([abs_diff, abs_rel, rmse, rmse_log], ['abs_diff', 'abs_rel', 'rmse', 'rmse_log'])
    _panel([sq_rel], ['sq_rel'])


## Load Models

In [295]:
# Build the depth, camera/pose and flow networks on `device`, then initialise their weights.
# NOTE(review): keep this order — construction/initialisation may draw from the global RNG.
disp_net = DispNetS().to(device)
pose_net = CameraNet(nb_ref_imgs=2).to(device)  # 2 reference frames: previous and next source image
flow_net = FlowNet(6, 0.1).to(device)  # NOTE(review): 6 presumably = two stacked RGB frames; meaning of 0.1 not visible here — confirm in flownet.py
disp_net.init_weights()
pose_net.init_weights()
flow_net.init_weight()  # NOTE(review): singular name, unlike the other two — confirm FlowNet's method name

### pretrained dispnet

In [296]:
# Initialise DispNet from a pretrained checkpoint whose weights are stored under 'state_dict'.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime (and vice versa).
pretrained_disp_path = os.path.join(root_dir, 'models/pretrained_models/dispnet_model_best.pth.tar')
weights = torch.load(pretrained_disp_path, map_location=device)
disp_net.load_state_dict(weights['state_dict'])  

<All keys matched successfully>

## Optimizer

In [297]:
# Adam over all three networks: one parameter group per network, all sharing the same learning rate.
lr = 2E-4
b1, b2 = 0.9, 0.999  # Adam beta coefficients
optim_params = [{'params': net.parameters(), 'lr': lr} for net in (disp_net, pose_net, flow_net)]
optimizer = optim.Adam(optim_params, betas=(b1, b2), eps=1e-08)

## Losses

### Photometric losses

In [283]:
def adaptive_photometric_loss(tgt_img, src_imgs, depths, flows, pose, intrinsics):
    """ Adaptive photometric (APC) loss: per pixel, the smaller of the rigid and flow errors.

    For each scale, the source image is compared against two reconstructions:
    - one obtained by rigidly warping `tgt_img` with depth + pose (inverse_warp);
    - one obtained by warping the rescaled target image with the flow (flow_warp).
    Each comparison uses 0.85 * (1 - SSIM) / 2 + 0.15 * L1. The per-pixel minimum of the
    two errors is averaged, and the per-scale values are summed.

    Args:
        tgt_img: target image                                     -- [B, 3, H, W]
        src_imgs: list of the source images (previous & next)     -- [[B, 3, H, W], [B, 3, H, W]]
        depths: depth maps; entries 0-3 are paired with the previous source image,
                the rest with the next one                        -- [[B, 1, H, W]....8x]
        flows: flow maps, indexed like `depths`                   -- [[B, 2, H, W]....8x]
        pose: 6DoF pose parameters per source image               -- [B, C=2, 6]
        intrinsics: camera intrinsic matrix                       -- [B, C=2, 3, 3]
    Return:
        (total adaptive photometric loss, 0, 0); the trailing zeros are placeholders
    """
    def one_scale_apc(depth, nr_src_img, i):
        """APC loss for scale index `i` against source image `nr_src_img`."""
        assert(pose.size(1) == len(src_imgs))

        _, _, h, w = depth.size()

        # Bring the images down to this scale's resolution.
        tgt_small = F.interpolate(tgt_img, (h, w), mode='area')
        src_small = F.interpolate(src_imgs[nr_src_img], (h, w), mode='area')

        K = scale_intrinsics(depth, tgt_img, intrinsics)[:, nr_src_img]  # [B, 3, 3]

        # Rigid reconstruction (depth + pose) and flow-based reconstruction.
        rigid_img, _valid = inverse_warp(tgt_img, depth[:, 0], pose[:, nr_src_img], K)
        flow_img = flow_warp(flows[i], tgt_small)

        ssim_loss = SSIM()

        def photometric(synth):
            # Per-pixel SSIM + L1 mix against the rescaled source image.
            l1 = (src_small - synth).abs().mean(1)
            structural = ssim_loss(src_small, synth)
            return 0.85 * ((1-structural)/2) + (1-0.85) * l1

        return torch.min(photometric(rigid_img), photometric(flow_img)).mean()

    total_apc_loss = 0
    for i, depth in enumerate(depths):
        total_apc_loss += one_scale_apc(depth, 0 if i < 4 else 1, i)

    warped_results, diff_results = 0, 0
    return total_apc_loss, warped_results, diff_results

def standard_photometric_loss(tgt_img, src_imgs, depths, pose, intrinsics):
    """ Rigid-only photometric loss (not adaptive), summed over scales and source images.

    For each scale and each source image, `tgt_img` is rigidly warped with depth + pose.
    The source image and the reconstruction are masked by the warp's valid points and
    compared with 0.85 * (1 - SSIM) / 2 + 0.15 * L1.

    Args:
        tgt_img: target image                                     -- [B, 3, H, W]
        src_imgs: list of the source images (previous & next)     -- [[B, 3, H, W], [B, 3, H, W]]
        depths: list of depth maps of the target image per scale  -- [[B, 1, H, W], ...]
        pose: 6DoF pose parameters per source image               -- [B, C=2, 6]
        intrinsics: camera intrinsic matrix                       -- [B, C=2, 3, 3]
    Return:
        (total photometric loss, 0, 0); the trailing zeros are placeholders
    """
    def one_scale(depth):
        """Sum of the rigid photometric losses of all source images at one scale."""
        assert(pose.size(1) == len(src_imgs))

        _, _, h, w = depth.size()
        src_small = [F.interpolate(s, (h, w), mode='area') for s in src_imgs]
        K_scaled = scale_intrinsics(depth, tgt_img, intrinsics)  # [B, C=2, 3, 3]

        reconstruction_loss = 0
        for idx, src_img in enumerate(src_small):
            projected_image, valid_points = inverse_warp(tgt_img, depth[:, 0], pose[:, idx], K_scaled[:, idx])

            mask = valid_points.unsqueeze(1).float()
            src_masked = src_img * mask
            proj_masked = projected_image * mask

            structural = pytorch_ssim.SSIM(window_size = 11)(src_masked, proj_masked)
            diff_abs = ((src_masked - projected_image) * mask).abs().mean()

            reconstruction_loss += 0.85 * ((1-structural)/2) + (1-0.85) * diff_abs
        return reconstruction_loss

    total_loss = 0
    for depth in depths:
        total_loss += one_scale(depth)

    return total_loss, 0, 0


### Epipolar loss

In [272]:

########################## Epipolar loss ######################################

def epipolar_loss(intrinsics, pose, src_imgs, flows):
    """ Epipolar-constraint loss between flow correspondences and the predicted camera.

    For every pixel p of a flow map and its flow-displaced match p' = p + flow(p), the
    residual p^T F p' is computed with F = K^-T R [t]_x K^-1. Here (R, t) come from the
    predicted pose and K from the intrinsics rescaled to that flow's resolution. The loss
    per scale is |residual| summed over the batch and averaged over pixels; scales are summed.

    Fixes vs. the previous version:
      * used K^T instead of K^-T;
      * averaged the signed residual, which is unbounded below;
      * only divided cx, cy when rescaling K (fx, fy are now scaled too).

    Args:
        intrinsics: camera intrinsic matrix                       -- [B, C=2, 3, 3]
        pose: 6DoF pose parameters per source image               -- [B, C=2, 6]
        src_imgs: list of the source images (previous & next)     -- [[B, 3, H, W], [B, 3, H, W]]
        flows: flow maps in pixels; entries 0-3 belong to the previous source image,
               the rest to the next one                           -- [[B, 2, H, W]....8x]
    Returns:
        epipolar loss summed over all scales (0 if `flows` is empty)
    """

    def one_scale_flow(flow_local, nr_src_img):
        """ Epipolar loss for one flow map: |p^T F p'| summed over the batch, mean over pixels.
        Args:
            flow_local: flow map at one scale -- [B, 2, H, W]
            nr_src_img: index of the source image (prev=0, next=1)
        """
        batch, _, h, w = flow_local.size()
        downscale = src_imgs[0].size(2) / h

        # Rescale fx, fy, cx, cy (first two rows of K) to this resolution.
        intrinsics_scaled = torch.cat((intrinsics[..., :2, :] / downscale,
                                       intrinsics[..., 2:, :]), dim=-2)  # [B, C=2, 3, 3]
        inv_intrinsics = torch.inverse(intrinsics_scaled[:, nr_src_img])  # [B, 3, 3]

        # Rotation and translation of this source image's pose.
        transform = pose_vec2mat(pose[:, nr_src_img])  # [B, 3, 4]
        rotation_mat, translation_vec = transform[:, :, :3], transform[:, :, -1:]  # [B, 3, 3], [B, 3, 1]
        translation_mat = vec2skew_symmetric(translation_vec.squeeze(2))  # [B, 3, 3]

        # NOTE(review): the textbook essential matrix for x2 = R x1 + t is [t]_x R. The R @ [t]_x
        # order is kept from the original — confirm against pose_vec2mat's convention.
        fundamental = inv_intrinsics.transpose(1, 2) @ rotation_mat @ translation_mat @ inv_intrinsics  # [B, 3, 3]

        # Homogeneous source pixels p and their flow-displaced matches p'.
        pixel_coords_src = make_grid(flow_local).reshape(1, 3, -1).expand(batch, 3, h * w)  # [B, 3, H*W]
        pixel_coords_tgt, _ = optical_flow_displacement(flow_local, src_imgs[nr_src_img])  # [B, 2, H, W]
        ones = torch.ones(batch, 1, h * w).type_as(pixel_coords_tgt)
        pixel_coords_tgt = torch.cat((pixel_coords_tgt.reshape(batch, 2, -1), ones), dim=1)  # [B, 3, H*W]

        # p^T F p' for every pixel, vectorised over the batch.
        residual = (pixel_coords_src * (fundamental @ pixel_coords_tgt)).sum(dim=1)  # [B, H*W]

        return residual.abs().sum(dim=0).mean()

    total_loss = 0
    for idx, flow in enumerate(flows):
        nr_src_img = 0 if idx < 4 else 1
        total_loss += one_scale_flow(flow, nr_src_img)

    return total_loss



### Multi-View 3D Consistency loss

In [273]:
def mvs_loss(intrinsics, pose, src_imgs, tgt_img, depths_src, depths_tgt):
    """ Multi-view 3D structure consistency loss.

    At every scale, both the target depth and the matching source depth are back-projected
    into 3-D with K^-1. The source points are moved into the target frame with the predicted
    pose, and the L1 distance to the target points is taken. Losses are summed over scales
    and source images.

    K is rescaled to each depth map's resolution; fx, fy, cx and cy are all divided by the
    downscale factor. (The previous version only divided cx and cy.)

    Args:
        pose: 6DoF pose parameters per source image                 -- [B, C=2, 6]
        intrinsics: camera intrinsic matrix                         -- [B, C=2, 3, 3]
        src_imgs: list of the source images (previous & next)       -- [[B, 3, H, W], ... x2]
        tgt_img: target image (not used by the computation)         -- [B, 3, H, W]
        depths_src: source-image depth maps; the first len(depths_tgt) entries belong to
                    source image 0, the next len(depths_tgt) to source image 1 -- [[B, 1, H, W], .... x8]
        depths_tgt: depth maps of the target image per scale        -- [[B, 1, H, W], .... x4]
    Return:
        mvs_loss: multi-view 3D structure consistency loss
    """

    def one_scale_mvs(local_depth_tgt, local_depth_src, nr_src_img):
        """ L1 loss between target-frame 3-D points of the target and one source image at one scale.
            local_depth_tgt: depth map of target image on one scale - [B, 1, H, W]
            local_depth_src: depth map of source image on one scale - [B, 1, H, W]
            nr_src_img: either 0 for previous src or 1 for next image
        """
        b, _, h, w = local_depth_tgt.size()
        downscale = src_imgs[0].size(2) / h

        # Rescale fx, fy, cx, cy (first two rows of K) to this resolution and invert.
        intrinsics_scaled = torch.cat((intrinsics[..., :2, :] / downscale,
                                       intrinsics[..., 2:, :]), dim=-2)  # [B, C=2, 3, 3]
        inv_intrinsics = torch.inverse(intrinsics_scaled[:, nr_src_img])  # [B, 3, 3]

        # NOTE(review): the pose is applied to *source* points to reach the target frame; confirm
        # this matches pose_vec2mat's direction for pose[:, nr_src_img].
        transform = pose_vec2mat(pose[:, nr_src_img])  # [B, 3, 4]
        extra_row = torch.tensor([0, 0, 0, 1]).type_as(transform).view(1, 1, 4).expand(b, 1, 4)
        transform_4d = torch.cat((transform, extra_row), dim=1)  # [B, 4, 4]

        # Back-project to 3-D in the target and source camera frames: depth * K^-1 * p.
        tgt_coords_tgt_frame = pixel2cam_v2(local_depth_tgt, inv_intrinsics)  # [B, 3, H, W]
        src_coords_src_frame = pixel2cam_v2(local_depth_src, inv_intrinsics)  # [B, 3, H, W]

        # Homogeneous 3-D points.
        ones = torch.ones(b, 1, h, w).type_as(src_coords_src_frame)
        tgt_coords_tgt_frame_4d = torch.cat((tgt_coords_tgt_frame, ones), dim=1)  # [B, 4, H, W]
        src_coords_src_frame_4d = torch.cat((src_coords_src_frame, ones), dim=1)  # [B, 4, H, W]

        # Move every source point into the target camera frame (one batched matmul).
        src_coords_tgt_frame = (transform_4d @ src_coords_src_frame_4d.reshape(b, 4, -1)).reshape(b, 4, h, w)

        # NOTE(review): points are compared at the *same* pixel location in both images (no
        # flow/rigid correspondence is used) — confirm this is intended.
        return F.l1_loss(tgt_coords_tgt_frame_4d, src_coords_tgt_frame)

    n_scales = len(depths_tgt)
    batch_mvs_loss = 0
    for nr_src_img in range(len(src_imgs)):
        for scale, depth_tgt in enumerate(depths_tgt):
            depth_src = depths_src[scale + nr_src_img * n_scales]
            batch_mvs_loss += one_scale_mvs(depth_tgt, depth_src, nr_src_img)

    return batch_mvs_loss




### Training and validation

In [None]:
def validate_with_gt_during_training(test_loader, disp_net, abs_diff, abs_rel, sq_rel, rmse, rmse_log):
    """Evaluate depth against ground truth mid-training and extend the metric histories.

    The disparity network and the module-level ``pose_net`` / ``flow_net``
    globals are switched to eval mode for the evaluation and switched back to
    train mode afterwards.

    Args:
        test_loader: loader forwarded unchanged to ``validate_with_gt``.
        disp_net: disparity network under evaluation.
        abs_diff, abs_rel, sq_rel, rmse, rmse_log: lists that are appended to
            in place with entries 0..4 of the ``validate_with_gt`` result.

    Returns:
        Tuple ``(abs_diff, abs_rel, sq_rel, rmse, rmse_log)`` of the same
        (now extended) list objects.
    """
    networks = (disp_net, pose_net, flow_net)

    for net in networks:
        net.eval()

    metrics = validate_with_gt(test_loader, disp_net)

    # Restore training mode so the surrounding training loop continues correctly.
    for net in networks:
        net.train()

    histories = (abs_diff, abs_rel, sq_rel, rmse, rmse_log)
    for position, history in enumerate(histories):
        history.append(metrics[position])
    return histories

In [298]:
############################## Training Loop (n_epochs) ###########################
def standard_train(disp_net, pose_net, photometric_loss, epipolar_loss, mvs_loss, optimizer, n_epochs, train_loader, val_loader, test_loader):
    """Jointly train the depth, pose and flow networks for ``n_epochs`` epochs.

    Each epoch runs a training pass over ``train_loader`` followed by a
    validation pass over ``val_loader``. The flow network is taken from the
    module-level ``flow_net`` global (it is not a parameter).

    Args:
        disp_net: disparity network, called on each image batch.
        pose_net: called as ``pose_net(tgt_img, [src_prev, src_next])``; must
            return ``(pose, focal_lengths)`` where ``focal_lengths`` is indexed
            as ``focal_lengths[:, c, :]`` per source image.
        photometric_loss: callable
            ``(tgt_img, src_images, depths_src, flow, pose, intrinsics)``
            returning ``(loss, warped, diff)``.
        epipolar_loss: callable ``(intrinsics, pose, src_images, flow)``.
        mvs_loss: callable
            ``(intrinsics, pose, src_images, tgt_img, depths_src, depths_tgt)``.
        optimizer: optimizer stepped once per training batch.
        n_epochs: number of epochs to run.
        train_loader, val_loader: iterables of dicts with keys
            ``'target_image'``, ``'source_image_prev'``, ``'source_image_next'``.
            ``len()`` must return the number of batches (as for a DataLoader).
        test_loader: ground-truth loader evaluated every 100 training batches.
    """
    abs_diff, abs_rel, sq_rel, rmse, rmse_log = [], [], [], [], []
    loss_train, loss_val, b100loss_list = [], [], []
    loaders = {'train': train_loader, 'val': val_loader}

    for epoch in range(n_epochs):
        # Losses are accumulated as Python floats (loss.item()). Summing the
        # loss tensors directly would keep every batch's autograd graph alive.
        running_loss = 0.0
        running_loss_val = 0.0
        b100loss = 0.0
        n_b100 = 0  # number of batches accumulated into b100loss

        for phase in ['train', 'val']:
            is_train = phase == 'train'

            # set the network architectures to training / evaluation mode
            for net in (disp_net, pose_net, flow_net):
                if is_train:
                    net.train()
                else:
                    net.eval()

            for i, data in enumerate(loaders[phase]):

                # extract source and target images for one forward pass
                tgt_img = data['target_image'].to(device)
                src_img_prev = data['source_image_prev'].to(device)
                src_img_next = data['source_image_next'].to(device)
                src_images = [src_img_prev, src_img_next]
                # pair the target with each source along the channel axis
                concat_tgt_src = [torch.cat((tgt_img, src_img), 1) for src_img in src_images]

                with torch.set_grad_enabled(is_train):  # track history only during training

                    ############################# DEPTH #############################
                    disparities_tgt = disp_net(tgt_img)
                    disparities_src_prev = disp_net(src_img_prev)
                    disparities_src_next = disp_net(src_img_next)
                    disparities_src = [*disparities_src_prev, *disparities_src_next]

                    # convert disparity to depth
                    if is_train:
                        depths_tgt = [1/disp for disp in disparities_tgt]
                        depths_src = [1/disp for disp in disparities_src]
                    else:
                        # NOTE(review): in eval mode disp_net appears to return a
                        # single tensor rather than a per-scale list, so iterating
                        # walks the batch dimension and unsqueeze(1) restores a
                        # leading axis. This only matches the train-mode layout
                        # for a validation batch size of 1 — confirm.
                        depths_tgt = [(1/disp).unsqueeze(1) for disp in disparities_tgt]
                        depths_src = [(1/disp).unsqueeze(1) for disp in disparities_src]

                    #############################  POSE  ##############################
                    pose, focal_lengths = pose_net(tgt_img, src_images)
                    # one 3x3 intrinsics matrix per source image, stacked on dim 1
                    intrinsics = torch.stack(
                        [focal2intrinsics(focal_lengths[:, c, :], tgt_img) for c in range(len(src_images))],
                        dim=1)

                    ############################## FLOW ################################
                    flow = [*flow_net(concat_tgt_src[0]), *flow_net(concat_tgt_src[1])]

                    ########################## LOSS AND BACKPROP ###########################
                    loss_pc, _warped, _diff = photometric_loss(tgt_img, src_images, depths_src, flow, pose, intrinsics)
                    loss_e = epipolar_loss(intrinsics, pose, src_images, flow)
                    loss_mvs = mvs_loss(intrinsics, pose, src_images, tgt_img, depths_src, depths_tgt)
                    loss = loss_pc + 0.1 * loss_mvs + 0.001 * loss_e

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                batch_loss = loss.item()

                if not is_train:
                    running_loss_val += batch_loss
                    continue

                running_loss += batch_loss
                b100loss += batch_loss
                n_b100 += 1

                # print intermediate results every 100 training batches
                if i % 100 == 0 and i != 0:
                    # divide by the real count: the first window spans batches 0..100
                    avg_100_batch_loss = b100loss / n_b100
                    b100loss_list.append(avg_100_batch_loss)
                    print('batch: ', i)
                    print('100 batch avg loss = ', avg_100_batch_loss)
                    print('loss e =', loss_e)
                    print('loss mvs =', loss_mvs)
                    print('loss ap =', loss_pc)

                    # Save the results and plot the training curve
                    save_images_norm(output_directory_train, depths_tgt, tgt_img, i)
                    plot_train_curve(b100loss_list)

                    b100loss = 0.0
                    n_b100 = 0

                    ############## VALIDATE EVERY 100 BATCHES #####################
                    abs_diff, abs_rel, sq_rel, rmse, rmse_log = validate_with_gt_during_training(
                        test_loader, disp_net, abs_diff, abs_rel, sq_rel, rmse, rmse_log)
                    plot_metrics(abs_diff, abs_rel, sq_rel, rmse, rmse_log)

                    save_models(disp_net, pose_net, flow_net, model_path_disp, epoch)
                    print('Model_saved')

        ######################## CALCULATE LOSSES ######################
        # Average training and validation loss per batch and plot the curves.
        running_loss_avg = running_loss / max(len(train_loader), 1)
        running_loss_val_avg = running_loss_val / max(len(val_loader), 1)
        loss_train.append(running_loss_avg)
        loss_val.append(running_loss_val_avg)
        plot_curve(loss_train, loss_val)

        print('epoch:', epoch)
        print('Training Loss: {:.4f}'.format(running_loss_avg))
        print('Validation Loss: {:.4f}'.format(running_loss_val_avg))
        

In [276]:
def compute_errors(gt, pred, crop=True):
    abs_diff, abs_rel, sq_rel, a1, a2, a3 = 0,0,0,0,0,0
    batch_size = gt.size(0)

    '''
    crop used by Garg ECCV16 to reprocude Eigen NIPS14 results
    construct a mask of False values, with the same size as target
    and then set to True values inside the crop
    '''
    if crop:
        crop_mask = gt[0] != gt[0]
        y1,y2 = int(0.40810811 * gt.size(1)), int(0.99189189 * gt.size(1))
        x1,x2 = int(0.03594771 * gt.size(2)), int(0.96405229 * gt.size(2))
        crop_mask[y1:y2,x1:x2] = 1

    for current_gt, current_pred in zip(gt, pred):
        valid = (current_gt > 0) & (current_gt < 80)
        if crop:
            valid = valid & crop_mask

        valid_gt = current_gt[valid]
        valid_pred = current_pred[valid].clamp(1e-3, 80)

        valid_pred = valid_pred * torch.median(valid_gt)/torch.median(valid_pred)

        thresh = torch.max((valid_gt / valid_pred), (valid_pred / valid_gt))
        a1 += (thresh < 1.25).float().mean()
        a2 += (thresh < 1.25 ** 2).float().mean()
        a3 += (thresh < 1.25 ** 3).float().mean()

        rmse = (valid_gt - valid_pred) ** 2
        rmse = torch.sqrt(rmse.mean())

        rmse_log = (torch.log(valid_gt) - torch.log(valid_pred)) ** 2
        rmse_log = torch.sqrt(rmse_log.mean())

        abs_diff += torch.mean(torch.abs(valid_gt - valid_pred))
        abs_rel += torch.mean(torch.abs(valid_gt - valid_pred) / valid_gt)

        sq_rel += torch.mean(((valid_gt - valid_pred)**2) / valid_gt)

    return [metric.item() / batch_size for metric in [abs_diff, abs_rel, sq_rel, rmse, rmse_log, a3]]

In [277]:
def validate_with_gt(val_loader, disp_net, crop=False):
    """Evaluate ``disp_net`` against ground-truth depth and print the metrics.

    The module-level ``pose_net`` and ``flow_net`` globals are also switched to
    eval mode (and left there); callers that continue training must switch them
    back. Every 25th batch, the predicted and ground-truth depth of the first
    sample are saved as colour-mapped PNGs in ``output_directory_eval``.

    Args:
        val_loader: iterable of dicts with keys ``'image'`` and ``'gt_depth'``.
        disp_net: disparity network. Its output is indexed as
            ``output_disp[:, 0]``, i.e. a single tensor with a channel axis.
        crop: forwarded to ``compute_errors`` (default False, as before).

    Returns:
        numpy array ``[abs_diff, abs_rel, sq_rel, rmse, rmse_log, a3]``
        averaged over batches.
    """
    disp_net.eval()
    pose_net.eval()
    flow_net.eval()
    errors = []

    with torch.no_grad():
        for i, data in enumerate(val_loader):
            tgt_image = data['image'].to(device)
            gt_depth = data['gt_depth'].to(device)

            # one forward pass in eval mode; invert disparity to obtain depth
            output_disp = disp_net(tgt_image)
            output_depth = 1 / output_disp[:, 0]

            if i % 25 == 0:
                # Save the first sample only: squeezing the whole batch would
                # hand skimage a 3-D array whenever the batch size is > 1.
                pred_np = output_depth[0].squeeze().cpu().numpy()
                gt_np = gt_depth[0].squeeze().cpu().numpy()
                depth_to_img = skimage.transform.resize(pred_np, [375, 1242], mode='constant')
                gt_to_img = skimage.transform.resize(gt_np, [375, 1242], mode='constant')
                plt.imsave(os.path.join(output_directory_eval, str(i) + '_L1_pred.png'), depth_to_img, cmap='plasma')
                plt.imsave(os.path.join(output_directory_eval, str(i) + '_L1_gt.png'), gt_to_img, cmap='plasma')

            errors.append(compute_errors(gt_depth, output_depth, crop=crop))

    mean_errors = np.array(errors).mean(0)

    print("\n  " + ("{:>8} | " * 6).format("abs_diff", "abs_rel", "sq_rel", "rmse", "rmse_log", "a3"))
    print(("&{: 8.3f}  " * 6).format(*mean_errors.tolist()) + "\\\\")
    return mean_errors

## Start training

In [None]:
# Launch training: adaptive photometric + epipolar + MVS losses for 5 epochs.
standard_train(
    disp_net=disp_net,
    pose_net=pose_net,
    photometric_loss=adaptive_photometric_loss,
    epipolar_loss=epipolar_loss,
    mvs_loss=mvs_loss,
    optimizer=optimizer,
    n_epochs=5,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)