# Training Loss
    We consider three types of losses during training.
## Appearance Matching Loss:

    If you give the left image as an input to the network, the network will generate 2 disparity maps,
    one for the right and one for the left side. Here we consider the disparity map created for the left side.
    So if we have the right image and the disparity map for the left image, we can recreate the left image.
    The appearance matching loss compares the reconstructed left image with the original left image.
    We combine an L1 loss and a similarity (SSIM) loss to define the appearance matching loss.
    
<img src="appearance.png" title="Title text" />
    
## Disparity Smoothness Loss:

    The generated disparity maps should not contain abrupt changes in the disparity values.
    To ensure that, we penalize the disparity map wherever it has high gradients.
    
<img src="disparity.png" title="Title text" />
    
## Left-Right Disparity Consistency Loss:

    left image + disparity map for left = right image
    right image + disparity map for right = left image
    The author assumes that, because the differences between the left and right images are minor, both disparity maps should be almost equal.
    
<img src="left-right.png" title="Title text" />


In [1]:
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
from torch.nn.functional import pad

def SSIM_loss(x, y):
    """Compute the per-pixel structural dissimilarity of two images

    Local means, variances and the covariance are estimated with 3x3
    average pooling (stride 1, no padding), so the result is 2 elements
    smaller than the inputs along each of the last two dimensions.

    Arguments:
        x {tensor} -- Image
        y {tensor} -- Image
    Returns:
        floatTensor -- SSIM loss : (1 - SSIM) / 2, clamped to [0, 1]
    """
    # Stabilising constants of the SSIM formula.
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    def local_mean(t):
        # 3x3 box filter, stride 1, no padding.
        return nn.functional.avg_pool2d(t, 3, 1, padding=0)

    mean_x = local_mean(x)
    mean_y = local_mean(y)

    # Var[x] = E[x^2] - E[x]^2, Cov[x, y] = E[xy] - E[x]E[y]
    var_x = local_mean(x ** 2) - mean_x ** 2
    var_y = local_mean(y ** 2) - mean_y ** 2
    cov_xy = local_mean(x * y) - mean_x * mean_y

    numerator = (2 * mean_x * mean_y + c1) * (2 * cov_xy + c2)
    denominator = (mean_x ** 2 + mean_y ** 2 + c1) * (var_x + var_y + c2)

    return torch.clamp((1 - numerator / denominator) / 2, 0, 1)

def gradient_x(image):
    """Find gradient along x (the last dimension)

    Each output element is ``image[..., j] - image[..., j + 1]``, so the
    result is one element shorter than the input along the last dimension.
    Ellipsis indexing is used so that batched (N, C, H, W) tensors as well
    as unbatched (C, H, W) or plain (H, W) tensors are accepted.

    Arguments:
        image {tensor} -- image with at least one dimension; x is the last one
    Returns:
        tensor -- dx(image)
    """
    return image[..., :-1] - image[..., 1:]

def gradient_y(image):
    """Find gradient along y (the second-to-last dimension)

    Each output element is ``image[..., i, :] - image[..., i + 1, :]``, so
    the result is one element shorter than the input along the
    second-to-last dimension. Ellipsis indexing is used so that batched
    (N, C, H, W) tensors as well as unbatched (C, H, W) or plain (H, W)
    tensors are accepted.

    Arguments:
        image {tensor} -- image with at least two dimensions; y is the second-to-last one
    Returns:
        tensor -- dy(image)
    """
    return image[..., :-1, :] - image[..., 1:, :]

def gradient_t(prev_image, image):
    """Find the temporal gradient between two frames

    The difference is taken as earlier frame minus later frame.

    Arguments:
        prev_image {tensor} -- image at timestamp 1
        image {tensor} -- image at timestamp 2
    Returns:
        tensor -- time derivative of image (prev_image - image)
    """
    temporal_difference = prev_image - image
    return temporal_difference

def disparity_smoothness(image, disparity):
    """Calculate "Edge aware" Disparity Smoothness loss

    For every scale the disparity gradients along x and y are weighted by
    exp(-mean_c |image gradient|), so disparity changes are penalised less
    where the image itself has an edge. The absolute weighted x and y terms
    are summed per pixel and averaged; the loss of scale i is divided by
    2 ** i.

    Arguments:
        image {sequence of tensors} -- images, one per scale; the channel
            mean is taken over dim 1, so (N, C, H, W) tensors are expected
        disparity {sequence of tensors} -- disparities, one per scale, with
            spatial sizes matching the corresponding image
    Returns:
        list of FloatTensor -- one scalar loss per scale
    """
    disp_loss = []
    # Iterate over however many scales were passed in rather than assuming 4.
    for scale, (img, disp) in enumerate(zip(image, disparity)):
        # Edge-aware weights: close to 1 in flat image regions, small at edges.
        weights_x = torch.exp(-torch.mean(torch.abs(gradient_x(img)), 1, keepdim=True))
        weights_y = torch.exp(-torch.mean(torch.abs(gradient_y(img)), 1, keepdim=True))

        smoothness_x = gradient_x(disp) * weights_x
        smoothness_y = gradient_y(disp) * weights_y

        # The finite differences are one element short along their own axis;
        # zero-pad them back (x: right edge, y: bottom edge) so both terms
        # share one shape and can be added per pixel.
        smoothness_x = torch.nn.functional.pad(smoothness_x, (0, 1), mode='constant')
        smoothness_y = torch.nn.functional.pad(smoothness_y, (0, 0, 0, 1), mode='constant')

        # Sum the magnitudes element-wise. Adding the two Python lists instead
        # would concatenate them and silently drop the y term from the loss.
        smoothness = torch.abs(smoothness_x) + torch.abs(smoothness_y)

        disp_loss.append(torch.mean(smoothness) / 2 ** scale)
    return disp_loss

def LR_disparity_consistency(input_images, x_offset, wrap_mode='border', tensor_type = 'torch.FloatTensor'):
    """Implementation of Bilinear Sampling

    Despite its name, this function warps ``input_images`` horizontally:
    every output pixel (h, w) is sampled from the input at
    (h, w + x_offset * width), with linear interpolation between the two
    neighbouring columns.

    Arguments:
        input_images {tensor} -- Image batch of shape (N, C, H, W)
        x_offset {tensor} -- horizontal offsets as a fraction of the image
            width; must hold N * H * W elements and live on the same device
            as input_images
    Keyword Arguments:
        wrap_mode {str} -- method of warp (default: {'border'}).
            'border' zero-pads the image by one pixel on every side, so
            samples that fall outside fade to zero; 'edge' clamps the
            sampling position to the image instead.
        tensor_type {str} -- datatype (default: {'torch.FloatTensor'}).
            Kept for backward compatibility only and ignored: the sampling
            grid is created as float32 on the device of input_images.
    Returns:
        tensor -- image applied bilinear sampling, shape (N, C, H, W);
        None if wrap_mode is neither 'border' nor 'edge'
    """
    num_batch, num_channels, height, width = input_images.size()
    # Work on whatever device the input lives on (CPU or any GPU) instead of
    # assuming "cuda:0".
    device = input_images.device

    # Handle both texture border types
    if wrap_mode == 'border':
        edge_size = 1
        # Pad last and second-to-last dimensions by 1 (zeros) from both sides
        input_images = pad(input_images, (1, 1, 1, 1))
    elif wrap_mode == 'edge':
        edge_size = 0
    else:
        return None

    padded_width = width + 2 * edge_size
    padded_height = height + 2 * edge_size
    max_col = padded_width - 1

    # Put channels to slowest dimension and flatten batch with respect to others
    im_flat = input_images.permute(1, 0, 2, 3).contiguous().view(num_channels, -1)

    # Column coordinate of every output pixel in (n, h, w) order, shifted by
    # the padding.
    cols = torch.arange(width, dtype=torch.float32, device=device) + edge_size
    x = cols.repeat(num_batch * height)

    # Convert the disparity from a fraction of the width to pixels and shift
    # the sampling position by it.
    x = x + x_offset.contiguous().view(-1) * width
    # Make sure we don't go outside of the (padded) image
    x = torch.clamp(x, 0.0, max_col)

    # Neighbouring integer columns used for the linear interpolation.
    x0 = torch.floor(x)
    x1 = x0 + 1
    col_l = x0.long()
    # Only the gather index is clamped; the interpolation weights below use
    # the unclamped x1 so that they always sum to one.
    col_r = torch.clamp(x1, max=max_col).long()

    # Flat index of the first pixel of every (image, row) pair, computed with
    # integers so that large batches do not lose precision.
    rows = torch.arange(height, device=device) + edge_size
    batch_start = torch.arange(num_batch, device=device).view(-1, 1, 1) * (padded_width * padded_height)
    row_start = rows.view(1, -1, 1) * padded_width
    base = (batch_start + row_start).expand(num_batch, height, width).reshape(-1)

    idx_l = (base + col_l).unsqueeze(0).expand(num_channels, -1)
    idx_r = (base + col_r).unsqueeze(0).expand(num_channels, -1)

    # Sample pixels from images
    pix_l = im_flat.gather(1, idx_l)
    pix_r = im_flat.gather(1, idx_r)

    # Apply linear interpolation to account for fractional offsets
    weight_l = x1 - x
    weight_r = x - x0
    output = weight_l * pix_l + weight_r * pix_r

    # Reshape back into image batch and permute back to (N,C,H,W) shape
    return output.view(num_channels, num_batch, height, width).permute(1, 0, 2, 3)