<a href="https://colab.research.google.com/github/JanNogga/Vision_Systems_Lab/blob/main/Project_Solution/prototyping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install segmentation-models-pytorch >> /dev/null
!pip install cmocean >> /dev/null

In [2]:
import numpy as np
import random as rand
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import os
import time
from tqdm.notebook import tqdm
import shutil
import random
from PIL import Image
import os
from torch.utils.data import Dataset
from torch.cuda.amp import autocast, GradScaler
import torch.jit as jit
import segmentation_models_pytorch as smp

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

### Models

In [3]:
# Dummy input geometry: batch size, channels, height, width.
B, C, H, W = 1, 3, 640, 480 #368, 1216
# Two independent random batches, sampled with torch.randn and then moved to
# the GPU with .cuda(); they are fed to the model further down the notebook.
batch_a = torch.randn(B, C, H, W).cuda()
batch_b = torch.randn(B, C, H, W).cuda()
print(batch_a.shape, batch_b.shape)
# Sanity check: the two batches differ from each other, each equals itself.
print(torch.allclose(batch_a, batch_b), torch.allclose(batch_a, batch_a), torch.allclose(batch_b, batch_b))

torch.Size([1, 3, 640, 480]) torch.Size([1, 3, 640, 480])
False True True


In [4]:
# Here we are adapting the source code of the Unet from https://github.com/qubvel/segmentation_models.pytorch
# to fit our Feature Fusion needs. Specifically, the Unet building blocks are inspired by the ones at: https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/unet/decoder.py
# The reason we adapt this code right here instead of writing it from scratch is because the pretrained segmentation backbones available at the repository above are then completely compatible with our Feature Fusion.
# This yields a huge variety of possible encoder backbones that we could use.

from segmentation_models_pytorch.base import modules as md

class CustomDecoderBlock(nn.Module):
    """Unet decoder stage: optional 2x upsampling, skip fusion and two 3x3 convs.

    Args:
        in_channels: channels of the incoming decoder tensor.
        skip_channels: channels of the skip tensor concatenated to it.
        out_channels: channels produced by both convolutions.
        use_batchnorm: forwarded to ``md.Conv2dReLU``.
        attention_type: forwarded to ``md.Attention``.
        upsample: when False the input resolution is kept (``nn.Identity``).
    """

    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
            use_batchnorm=True,
            attention_type=None,
            upsample=True
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        if upsample:
            self.upsample = nn.Upsample(scale_factor=2)
        else:
            self.upsample = nn.Identity()
        self.conv1 = md.Conv2dReLU(fused_channels, out_channels, **conv_options)
        self.attention1 = md.Attention(attention_type, in_channels=fused_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, **conv_options)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        """Run the stage on ``x``; ``skip`` (if given) is concatenated on dim 1."""
        x = self.upsample(x)
        if skip is not None:
            # The first attention only sees the fused (input + skip) tensor.
            x = self.attention1(torch.cat([x, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(x)))


class CustomCenterBlock(nn.Sequential):
    """Two stacked 3x3 ``md.Conv2dReLU`` layers applied to the encoder head.

    The first layer maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. ``use_batchnorm`` is forwarded to both layers.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        shared = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        super().__init__(
            md.Conv2dReLU(in_channels, out_channels, **shared),
            md.Conv2dReLU(out_channels, out_channels, **shared),
        )


class CustomUnetDecoder(nn.Module):
    """Unet-style decoder that returns the concatenation of all stage outputs.

    Compared with a plain Unet decoder, the last two stages do not upsample;
    their skip tensors are max-pooled once and twice respectively instead, and
    the outputs of every stage are concatenated along the channel dimension.

    Args:
        encoder_channels: channel counts of the encoder outputs, shallowest first.
        decoder_channels: output channels per decoder stage (one stage each).
        use_batchnorm: forwarded to the conv blocks.
        attention_type: forwarded to the decoder blocks.
        center: insert a ``CustomCenterBlock`` on the encoder head when True.
    """

    def __init__(
            self,
            encoder_channels,
            decoder_channels,
            use_batchnorm=True,
            attention_type=None,
            center=False,
    ):
        super().__init__()

        depth = len(decoder_channels)

        # Keep depth+1 encoder stages, drop the first one (same resolution as
        # the input) and order them from the deepest stage outwards.
        encoder_channels = encoder_channels[:depth + 1][1:][::-1]

        # One flag / pooling count per retained encoder stage; the final two
        # stages skip upsampling and pool their skip tensors 1x and 2x instead.
        upsample_flags = [True] * len(encoder_channels)
        upsample_flags[-2:] = [False, False]
        self.num_downsample = [0] * len(encoder_channels)
        self.num_downsample[-2:] = [1, 2]

        # Channel bookkeeping for every decoder stage.
        head_channels = encoder_channels[0]
        stage_inputs = [head_channels] + list(decoder_channels[:-1])
        stage_skips = list(encoder_channels[1:]) + [0]
        stage_outputs = decoder_channels

        if center:
            self.center = CustomCenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        stages = []
        for in_ch, skip_ch, out_ch, ups in zip(stage_inputs, stage_skips, stage_outputs, upsample_flags):
            stages.append(
                CustomDecoderBlock(
                    in_ch,
                    skip_ch,
                    out_ch,
                    upsample=ups,
                    use_batchnorm=use_batchnorm,
                    attention_type=attention_type,
                )
            )
        self.blocks = nn.ModuleList(stages)
        self.upsample = nn.Upsample(scale_factor=2)
        self.downsample = nn.MaxPool2d(kernel_size=2)

    def forward(self, *features):
        """Decode encoder ``features`` (shallowest first) into one fused map."""
        # Drop the first feature map and walk from the deepest stage outwards.
        ordered = features[1:][::-1]
        head = ordered[0]
        skips = ordered[1:]

        x = self.center(head)
        stage_maps = []
        for idx, block in enumerate(self.blocks):
            skip = None
            if idx < len(skips):
                skip = skips[idx]
            if skip is not None:
                # Pool the skip down to the resolution this stage works at.
                for _ in range(self.num_downsample[idx]):
                    skip = self.downsample(skip)
            x = block(x, skip)
            stage_maps.append(x)

        fused = stage_maps[0]
        remaining = stage_maps[1:]
        # All but the last two maps: upsample the running tensor, then append.
        for stage_map in remaining[:-2]:
            fused = torch.cat([self.upsample(fused), stage_map], dim=1)
        # The last two maps are appended as-is, last one first.
        for offset in (1, 2):
            fused = torch.cat([fused, remaining[-offset]], dim=1)
        return fused

class UnetFeatures(nn.Module):
    """Feature extractor: an ``smp.Unet`` encoder followed by ``CustomUnetDecoder``.

    Only the first ``encoder_depth`` entries of ``decoder_channels`` are used.
    The encoder is requested with ImageNet weights and is frozen on construction.
    """

    def __init__(self, encoder_depth=3, decoder_channels=(64, 32, 16, 12, 8), backbone='resnet34'):
        super().__init__()
        active_channels = decoder_channels[:encoder_depth]
        # The stock Unet is only built to obtain a matching encoder.  # "efficientnet-b3"
        reference_unet = smp.Unet(
            encoder_name=backbone,
            encoder_weights="imagenet",
            encoder_depth=encoder_depth,
            decoder_channels=active_channels,
        )
        self.encoder = reference_unet.encoder
        self.decoder = CustomUnetDecoder(
            encoder_channels=reference_unet.encoder._out_channels,
            decoder_channels=active_channels,
            center=True,
        )
        self._freeze_encoder()

    def _freeze_encoder(self):
        """Disable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad_(False)

    def _unfreeze_encoder(self):
        """Enable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        """Encode ``x`` and decode the resulting feature pyramid."""
        return self.decoder(*self.encoder(x))

In [5]:
# ed = 5
# tmp = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", encoder_depth=ed, decoder_channels=(64, 32, 16, 12, 8)[:ed]).cuda() # "efficientnet-b3", "mobilenet_v2", "resnet34"
# tmp_feats = tmp.encoder(batch_a)
# for map in tmp_feats:
#   print(map.shape)

# print(tmp.encoder._out_channels[:2])
# decoder = CustomUnetDecoder(encoder_channels=tmp.encoder._out_channels, decoder_channels=(64, 32, 16, 12, 8)[:ed], center=True).cuda()
# #print(decoder)
# out = decoder(*tmp_feats)
# print(out.shape)
# print(sum((64, 32, 16, 12, 8)[:ed]))

In [18]:
class StackedHourglass(nn.Module):
    """Three stacked 3D hourglass modules with one output head per hourglass.

    The input is a 5D volume with ``inplanes`` channels. Two stride-2
    convolutions inside each hourglass are undone by two stride-2 transposed
    convolutions, so the three trailing dimensions should be divisible by 4
    for the skip additions to line up.

    Args:
        inplanes: number of input channels ``C``; internal widths are ``C`` and
            ``C // 2``.
        norm_layer: callable building a normalisation layer from a channel
            count. Defaults to ``nn.BatchNorm3d``.

    Returns (from ``forward``):
        In training mode a tuple ``(c1, c2, c3)`` with the single-channel
        outputs of the three heads (each accumulates the previous one); in
        eval mode only ``c3``.
    """

    def __init__(self, inplanes, norm_layer = None):
        super(StackedHourglass, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        self.C = inplanes
        full = self.C
        half = self.C // 2

        def conv_bn(in_ch, out_ch, stride=1, relu=True):
            # 3x3x3 conv + norm (+ ReLU); returned as a list so that several
            # groups can be flattened into one nn.Sequential.
            layers = [
                nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride, bias=False),
                norm_layer(out_ch),
            ]
            if relu:
                layers.append(nn.ReLU(inplace=True))
            return layers

        def down(in_ch, out_ch, final_relu=True):
            # Stride-2 conv followed by a stride-1 conv.
            return nn.Sequential(*conv_bn(in_ch, out_ch, stride=2),
                                 *conv_bn(out_ch, out_ch, relu=final_relu))

        def up(in_ch, out_ch):
            # Transposed conv that exactly doubles every spatial dimension.
            return nn.Sequential(
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=3, padding=1,
                                   output_padding=1, stride=2, bias=False),
                norm_layer(out_ch))

        def head():
            # Maps C//2 channels to a single output channel.
            return nn.Sequential(*conv_bn(half, half),
                                 nn.Conv3d(half, 1, kernel_size=3, padding=1, bias=False))

        # Stem. Attribute names and Sequential layouts are kept unchanged so
        # existing state_dict keys still match.
        self.layer_1_2 = nn.Sequential(*conv_bn(full, half), *conv_bn(half, half))
        self.layer_2_4 = nn.Sequential(*conv_bn(half, half), *conv_bn(half, half, relu=False))
        # Hourglass 1
        self.layer_4_6 = down(half, full)
        self.layer_6_8 = down(full, full)
        self.layer_8_9 = up(full, full)
        self.layer_9_10 = up(full, half)
        # Hourglass 2
        self.layer_10_12 = down(half, full, final_relu=False)
        self.layer_12_14 = down(full, full)
        self.layer_14_15 = up(full, full)
        self.layer_15_16 = up(full, half)
        # Hourglass 3
        self.layer_16_18 = down(half, full, final_relu=False)
        self.layer_18_20 = down(full, full)
        self.layer_20_21 = up(full, full)
        self.layer_21_22 = up(full, half)

        # One output head per hourglass.
        self.dist1 = head()
        self.dist2 = head()
        self.dist3 = head()

    def forward(self, x):
        # Stem: C -> C//2, then a residual pair of convs at C//2.
        a2 = self.layer_1_2(x)
        a4 = F.relu(self.layer_2_4(a2) + a2, inplace=True)

        ## hourglass 1
        a6 = self.layer_4_6(a4)    # C//2 -> C, half resolution
        a8 = self.layer_6_8(a6)    # C -> C, quarter resolution
        a9 = F.relu(self.layer_8_9(a8) + a6, inplace=True)     # back to half resolution
        a10 = F.relu(self.layer_9_10(a9) + a4, inplace=True)   # back to full resolution, C//2

        ## hourglass 2 (reuses a9 / a6 / a4 as skip connections)
        a12 = F.relu(self.layer_10_12(a10) + a9, inplace=True)
        a14 = self.layer_12_14(a12)
        a15 = F.relu(self.layer_14_15(a14) + a6, inplace=True)
        a16 = F.relu(self.layer_15_16(a15) + a4, inplace=True)

        ## hourglass 3 (reuses a15 / a6 / a4 as skip connections)
        a18 = F.relu(self.layer_16_18(a16) + a15, inplace=True)
        a20 = self.layer_18_20(a18)
        a21 = F.relu(self.layer_20_21(a20) + a6, inplace=True)
        a22 = F.relu(self.layer_21_22(a21) + a4, inplace=True)

        # Each head refines the previous head's output.
        c1 = self.dist1(a10)
        c2 = self.dist2(a16) + c1
        # The third hourglass has its own head; applying dist2 here would
        # leave dist3 unused and therefore never trained.
        c3 = self.dist3(a22) + c2
        if self.training:
            return (c1, c2, c3)
        else:
            return c3

In [19]:
# squeeze layer
class Squeeze(nn.Module):
    """Module wrapper around ``squeeze`` for a fixed dimension ``dim``."""

    def __init__(self, dim):
        super(Squeeze, self).__init__()
        self.dim = dim

    def forward(self, x):
        # type: (Tensor) -> Tensor
        # Drops dimension ``self.dim`` when it has size 1; otherwise a no-op.
        squeezed = torch.squeeze(x, self.dim)
        return squeezed


# disparity expectation layer
class Expectation(nn.Module):
    """Weighted sum over dim 1 with weights ``0, 1, ..., max_disp - 1``.

    In this file the layer follows a softmax over dim 1, so the result is the
    expected index under that distribution.
    """

    def __init__(self, max_disp):
        super(Expectation, self).__init__()
        # Stored as a gradient-free parameter so that model.cuda() moves it
        # together with the rest of the model.
        disparity_values = torch.arange(max_disp).float()
        self.weights = nn.parameter.Parameter(disparity_values, requires_grad=False)

    def forward(self, x):
        # type: (Tensor) -> Tensor
        # Move dim 1 to the end so the matrix-vector product reduces over it.
        channels_last = x.permute(0, 2, 3, 1)
        expected = torch.matmul(channels_last, self.weights)
        # Re-insert a singleton dim 1 in place of the reduced one.
        return expected.unsqueeze(1)


# group-wise corr as per GWC net: https://github.com/xy-guo/GwcNet/blob/master/models/submodule.py
class GWC(nn.Module):
    """Group-wise correlation of two feature maps.

    The channel dimension is split into ``num_groups`` consecutive groups and
    the element-wise product of both inputs is averaged within each group.
    ``max_disp`` is stored but not used by ``forward``.
    """

    def __init__(self, max_disp, num_groups):
        super(GWC, self).__init__()
        self.max_disp = max_disp
        self.num_groups = num_groups

    def forward(self, left_feats, right_feats):
        # type: (Tensor, Tensor) -> Tensor
        batch, channels, height, width = left_feats.shape
        group_size = channels // self.num_groups
        product = left_feats * right_feats
        # (B, C, H, W) -> (B, groups, C // groups, H, W), then average each group.
        grouped = product.view([batch, self.num_groups, group_size, height, width])
        return grouped.mean(dim=2)

# cost as per GWC net: https://github.com/xy-guo/GwcNet/blob/master/models/submodule.py
class GWCBuilder(nn.Module):
    """Builds a group-wise correlation cost volume over ``max_disp // 4`` shifts.

    For shift ``i`` the left features from column ``i`` onwards are correlated
    (via ``GWC``) with the right features shifted ``i`` columns to the right;
    columns that have no counterpart stay zero.

    Args:
        max_disp: maximum disparity; the volume has ``max_disp // 4`` levels.
        num_groups: number of correlation groups (output channels).
    """

    def __init__(self, max_disp, num_groups):
        super(GWCBuilder, self).__init__()
        self.gwc = GWC(max_disp, num_groups)
        self.max_disp = max_disp
        self.num_groups = num_groups

    def forward(self, left_feats, right_feats):
        # type: (Tensor, Tensor) -> Tensor
        B, C, H_prime, W_prime = left_feats.shape
        num_levels = self.max_disp // 4
        # Allocate directly on the input's device instead of creating the
        # volume on the CPU and copying it over on every forward pass.
        cost = torch.zeros(B, self.num_groups, num_levels, H_prime, W_prime, device=left_feats.device)
        for i in range(num_levels):
            if i > 0:
                cost[:, :, i, :, i:] = self.gwc(left_feats[:, :, :, i:], right_feats[:, :, :, :-i])
            else:
                # ``:-0`` would be an empty slice, so shift 0 is handled separately.
                cost[:, :, i, :, :] = self.gwc(left_feats, right_feats)
        return cost


# cost volume construction is taken from the lab slides
class CostBuilder(nn.Module):
    """Builds a concatenation cost volume over ``max_disp // 4`` shifts.

    The output has ``2 * C`` channels: the first ``C`` hold the left features,
    the last ``C`` the right features shifted ``i`` columns to the right for
    level ``i``. Columns ``< i`` of level ``i`` stay zero in both halves.

    Args:
        max_disp: maximum disparity; the volume has ``max_disp // 4`` levels.
    """

    def __init__(self, max_disp):
        super(CostBuilder, self).__init__()
        self.max_disp = max_disp

    def forward(self, left_feats, right_feats):
        # type: (Tensor, Tensor) -> Tensor
        B, C, H_prime, W_prime = left_feats.shape
        num_levels = self.max_disp // 4
        # Allocate directly on the input's device instead of creating the
        # volume on the CPU and copying it over on every forward pass.
        cost = torch.zeros(B, 2*C, num_levels, H_prime, W_prime, device=left_feats.device)
        for i in range(num_levels):
            if i == 0:
                # ``:-0`` would be an empty slice, so shift 0 is handled separately.
                cost[:, :C, i, :, :] = left_feats
                cost[:, C:, i, :, :] = right_feats
            else:
                cost[:, :C, i, :, i:] = left_feats[:, :, :, i:]
                cost[:, C:, i, :, i:] = right_feats[:, :, :, :-i]
        return cost

# residual blocks are built like shown in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free 2D convolution with a 3x3 kernel and padding 1."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free 2D convolution with a 1x1 kernel (no padding)."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """2D residual block: two 3x3 convs with norm, plus a shortcut connection.

    Args:
        inplanes: input channels.
        planes: output channels.
        stride: stride of the first convolution (and of the shortcut projection).
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, inplanes, planes, stride = 1, norm_layer = None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # With stride != 1 both conv1 and the shortcut projection reduce resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        # A 1x1 projection is only needed when the shape changes along the way.
        if stride == 1 and inplanes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))

    def forward(self, x):
        # type: (Tensor) -> Tensor
        shortcut = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            shortcut = self.downsample(x)
        out += shortcut
        return self.relu(out)

def conv3d3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """Return a bias-free 3D convolution with a 3x3x3 kernel and padding 1."""
    layer = nn.Conv3d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv3d1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """Return a bias-free 3D convolution with a 1x1x1 kernel (no padding)."""
    layer = nn.Conv3d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer

    
class BasicBlock3D(nn.Module):
    """3D residual block: two 3x3x3 convs with norm, plus a shortcut connection.

    Args:
        inplanes: input channels.
        planes: output channels.
        stride: stride of the first convolution (and of the shortcut projection).
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm3d``.
    """

    def __init__(self, inplanes, planes, stride = 1, norm_layer = None):
        super(BasicBlock3D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        # With stride != 1 both conv1 and the shortcut projection reduce resolution.
        self.conv1 = conv3d3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3d3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        # A 1x1x1 projection is only needed when the shape changes along the way.
        if stride == 1 and inplanes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(conv3d1x1(inplanes, planes, stride), norm_layer(planes))

    def forward(self, x):
        # type: (Tensor) -> Tensor
        shortcut = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            shortcut = self.downsample(x)
        out += shortcut
        return self.relu(out)


class CustomModel(nn.Module):
    """Stereo network: Unet features, two cost volumes, stacked hourglass, regression.

    Args:
        in_channels: currently unused; kept so existing callers keep working.
        max_disp: maximum disparity; cost volumes have ``max_disp // 4`` levels.
        encoder_depth: forwarded to ``UnetFeatures``.
        decoder_channels: forwarded to ``UnetFeatures``; the first
            ``encoder_depth`` entries determine the feature channel count.
        backbone: encoder name forwarded to ``UnetFeatures``.

    Returns (from ``forward``):
        In training mode a tuple of three regressed maps (one per hourglass
        head); in eval mode only the final one.
    """

    def __init__(self, in_channels=3, max_disp=192, encoder_depth=3, decoder_channels=(256, 128, 64, 32, 16), backbone='resnet34'): #136
        super(CustomModel, self).__init__()
        # Channel budget of the cost volume fed to the 3D network.
        cat_planes = 12    # C_Cat: channels per image in the concat volume
        cor_planes = 160   # C_Cor: channels used for group-wise correlation
        num_groups = 40    # N_Groups: correlation groups (160 / 40 = 4 channels each)
        # features:
        # in:  B  x  3 x H    x W
        # out: B x  C x H//4 x W//4
        self.features = UnetFeatures(encoder_depth=encoder_depth, decoder_channels=decoder_channels, backbone=backbone)
        # The decoder concatenates the outputs of all its stages.
        C = sum(decoder_channels[:encoder_depth])
        # independent feature refinement
        # concat features
        # in:   B x C     x H//4 x W//4
        # out:  B x C_Cat x H//4 x W//4
        self.left_conv_cat = BasicBlock(inplanes=C, planes=cat_planes)
        self.right_conv_cat = BasicBlock(inplanes=C, planes=cat_planes)
        # correlation features
        # in:   B x C     x H//4 x W//4
        # out:  B x C_Cor x H//4 x W//4
        self.left_conv_cor = BasicBlock(inplanes=C, planes=cor_planes)
        self.right_conv_cor = BasicBlock(inplanes=C, planes=cor_planes)
        # concat cost volume
        # in: 2 of B x   C_Cat           x H//4 x W//4
        # out:     B x 2*C_Cat x Disp//4 x H//4 x W//4
        self.build_cost = CostBuilder(max_disp=max_disp)
        # group-wise cost volume
        # in: 2 of B x C_Cor              x H//4 x W//4
        # out:     B x N_Groups x Disp//4 x H//4 x W//4
        self.build_gwc = GWCBuilder(max_disp=max_disp, num_groups=num_groups)
        # 3d convs
        # in:  B x N_Groups + 2*C_Cat x Disp//4 x H//4 x W//4
        # out: B x 1                  x Disp//4 x H//4 x W//4
        # Derived from the two volumes above (40 + 2*12 = 64) so the width
        # cannot drift out of sync with them.
        C3d = num_groups + 2 * cat_planes
        self.stacked_hourglass = StackedHourglass(C3d)
        # regressor
        # in:  B x 1 x Disp//4 x H//4 x W//4
        # out: B x 1           x H    x W
        self.regressor = nn.Sequential(
            nn.Upsample(scale_factor=4),
            Squeeze(dim=1),
            torch.nn.Softmax(dim=1),
            Expectation(max_disp)
            )

    def forward(self, img_left, img_right):
        # type: (Tensor, Tensor) -> Tensor
        # Both images share the same feature extractor; the refinement blocks
        # afterwards are separate for the left and the right image.
        feat_left, feat_right = self.features(img_left), self.features(img_right)
        feat_left_cor, feat_right_cor = self.left_conv_cor(feat_left), self.right_conv_cor(feat_right)
        feat_left_cat, feat_right_cat = self.left_conv_cat(feat_left), self.right_conv_cat(feat_right)
        cost_gwc = self.build_gwc(feat_left_cor, feat_right_cor)
        cost_cat = self.build_cost(feat_left_cat, feat_right_cat)
        # Stack both volumes along the channel dimension.
        cost_volume = torch.cat([cost_gwc, cost_cat], dim=1)
        if self.training:
            # The hourglass returns all three head outputs in training mode.
            (c1, c2, c3) = self.stacked_hourglass(cost_volume)
            return (self.regressor(c1), self.regressor(c2), self.regressor(c3))
        else:
            c3 = self.stacked_hourglass(cost_volume)
            return self.regressor(c3)

In [25]:
import time
# Quick latency probe. The model is left in its default (training) mode, so
# the forward pass returns a tuple of three maps.
model = CustomModel(encoder_depth=3, backbone="mobilenet_v2").cuda()
with torch.no_grad():
  with autocast():
    features = model.features.encoder(batch_a)
    #print('feats:')
    #print(model.features.encoder._out_channels)
    # for map in features:
    #   print(map.shape)
    # CUDA kernels are launched asynchronously: wait for the encoder call
    # above to finish so that its work is not attributed to the timed pass.
    torch.cuda.current_stream().synchronize()
    start = time.perf_counter()
    disparity_map = model(batch_a, batch_b)
    # Wait for the forward pass itself before reading the clock.
    torch.cuda.current_stream().synchronize()
    print('Took', time.perf_counter()-start, 'seconds.')
#print()
print(isinstance(disparity_map, tuple))

Took 0.391345739364624 seconds.
True


### Image sizes in KITTI valid set

torch.Size([1, 374, 1238])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 374, 1238])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 370, 1224])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 374, 1238])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 374, 1238])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])
torch.Size([1, 376, 1241])
torch.Size([1, 375, 1242])
torch.Size([1, 375, 1242])

Multiples of 32 bracketing the sizes listed above:

height: 352 or 384

width: 1216 or 1248

### Criterions

In [None]:
# class NonZeroWrapper(nn.Module):
#     def __init__(self, criterion, eps=0):
#         super(NonZeroWrapper, self).__init__()
#         self.criterion = criterion
#         self.eps = eps
#     def forward(self, x, target):
#         mask = target > self.eps
#         loss = self.criterion(x, target)
#         return loss[mask].mean()

# class MaxDispWrapper(nn.Module):
#     def __init__(self, criterion, max_disp=192):
#         super(MaxDispWrapper, self).__init__()
#         self.criterion = criterion
#         self.max_disp = max_disp
#     def forward(self, x, target):
#         mask = target <= self.max_disp
#         loss = self.criterion(x, target)
#         return loss[mask].mean()

# class PE(nn.Module):
#     def __init__(self, threshold=3, reduction='mean'):
#         super(PE, self).__init__()
#         self.diff = nn.L1Loss(reduction='none')
#         self.threshold = threshold
#         self.reduction = reduction
#     def forward(self, x, target):
#         err = self.diff(x, target)
#         mask_err = torch.ge(err, self.threshold)
#         mask_ratio = torch.ge(err, 0.05*target)
#         mask = torch.logical_and(mask_err, mask_ratio)
#         loss = torch.where(mask, torch.ones_like(target), torch.zeros_like(target))
#         if self.reduction == 'mean':
#           return loss.mean()
#         else:
#           return loss

# criterion_smoothl1 = torch.nn.SmoothL1Loss(reduction='mean', beta=1.0)
# wrapped_criterion_smoothl1 = NonZeroWrapper(torch.nn.SmoothL1Loss(reduction='none', beta=1.0))
# criterion_3PE = PE()
# wrapped_criterion_3PE = NonZeroWrapper(PE(reduction='none'))
# max_wrapped_criterion_smoothl1 = MaxDispWrapper(torch.nn.SmoothL1Loss(reduction='none', beta=1.0))

In [None]:
# B, C, H, W = 5, 3, 368, 1216
# batch_a = 1*torch.randn(B, C, H, W).cuda()
# batch_b = 500*torch.ones(B, C, H, W).cuda()
# print(batch_a.shape, batch_b.shape)
# print(torch.allclose(batch_a, batch_b), torch.allclose(batch_a, batch_a), torch.allclose(batch_b, batch_b))
# mask = torch.randn(B, C, H, W).cuda() < 0.3
# batch_b[mask] = 0

In [None]:
# loss = wrapped_criterion_smoothl1(batch_a, batch_b)
# print(loss)
# loss_max_disp = max_wrapped_criterion_smoothl1(batch_a, batch_b)
# print(loss_max_disp)
# loss_nw = criterion_smoothl1(batch_a, batch_b)
# print(loss_nw)
# loss_3pe = criterion_3PE(batch_a, batch_b)
# print(loss_3pe)
# loss_3pe_wrapped = wrapped_criterion_3PE(batch_a, batch_b)
# print(loss_3pe_wrapped)

### Visualization

In [None]:
# Error cmap for GwC Net is available: https://github.com/xy-guo/GwcNet/blob/master/utils/visualization.py

# For disparity maps, use Turbo: https://ai.googleblog.com/2019/08/turbo-improved-rainbow-colormap-for.html

In [None]:
# # As per: https://gist.github.com/mikhailov-work/ee72ba4191942acecc03fe6da94fc73f
# import matplotlib.cm
# from matplotlib.colors import ListedColormap
# turbo_colormap_data = [[0.18995,0.07176,0.23217],[0.19483,0.08339,0.26149],[0.19956,0.09498,0.29024],[0.20415,0.10652,0.31844],[0.20860,0.11802,0.34607],[0.21291,0.12947,0.37314],[0.21708,0.14087,0.39964],[0.22111,0.15223,0.42558],[0.22500,0.16354,0.45096],[0.22875,0.17481,0.47578],[0.23236,0.18603,0.50004],[0.23582,0.19720,0.52373],[0.23915,0.20833,0.54686],[0.24234,0.21941,0.56942],[0.24539,0.23044,0.59142],[0.24830,0.24143,0.61286],[0.25107,0.25237,0.63374],[0.25369,0.26327,0.65406],[0.25618,0.27412,0.67381],[0.25853,0.28492,0.69300],[0.26074,0.29568,0.71162],[0.26280,0.30639,0.72968],[0.26473,0.31706,0.74718],[0.26652,0.32768,0.76412],[0.26816,0.33825,0.78050],[0.26967,0.34878,0.79631],[0.27103,0.35926,0.81156],[0.27226,0.36970,0.82624],[0.27334,0.38008,0.84037],[0.27429,0.39043,0.85393],[0.27509,0.40072,0.86692],[0.27576,0.41097,0.87936],[0.27628,0.42118,0.89123],[0.27667,0.43134,0.90254],[0.27691,0.44145,0.91328],[0.27701,0.45152,0.92347],[0.27698,0.46153,0.93309],[0.27680,0.47151,0.94214],[0.27648,0.48144,0.95064],[0.27603,0.49132,0.95857],[0.27543,0.50115,0.96594],[0.27469,0.51094,0.97275],[0.27381,0.52069,0.97899],[0.27273,0.53040,0.98461],[0.27106,0.54015,0.98930],[0.26878,0.54995,0.99303],[0.26592,0.55979,0.99583],[0.26252,0.56967,0.99773],[0.25862,0.57958,0.99876],[0.25425,0.58950,0.99896],[0.24946,0.59943,0.99835],[0.24427,0.60937,0.99697],[0.23874,0.61931,0.99485],[0.23288,0.62923,0.99202],[0.22676,0.63913,0.98851],[0.22039,0.64901,0.98436],[0.21382,0.65886,0.97959],[0.20708,0.66866,0.97423],[0.20021,0.67842,0.96833],[0.19326,0.68812,0.96190],[0.18625,0.69775,0.95498],[0.17923,0.70732,0.94761],[0.17223,0.71680,0.93981],[0.16529,0.72620,0.93161],[0.15844,0.73551,0.92305],[0.15173,0.74472,0.91416],[0.14519,0.75381,0.90496],[0.13886,0.76279,0.89550],[0.13278,0.77165,0.88580],[0.12698,0.78037,0.87590],[0.12151,0.78896,0.86581],[0.11639,0.79740,0.85559],[0.11167,0.80569,0.84525],[0.10738,0.81381,0.83484],[0.10357,0.82177,0.82437],[0.10026,0.82955,0.81389]
,[0.09750,0.83714,0.80342],[0.09532,0.84455,0.79299],[0.09377,0.85175,0.78264],[0.09287,0.85875,0.77240],[0.09267,0.86554,0.76230],[0.09320,0.87211,0.75237],[0.09451,0.87844,0.74265],[0.09662,0.88454,0.73316],[0.09958,0.89040,0.72393],[0.10342,0.89600,0.71500],[0.10815,0.90142,0.70599],[0.11374,0.90673,0.69651],[0.12014,0.91193,0.68660],[0.12733,0.91701,0.67627],[0.13526,0.92197,0.66556],[0.14391,0.92680,0.65448],[0.15323,0.93151,0.64308],[0.16319,0.93609,0.63137],[0.17377,0.94053,0.61938],[0.18491,0.94484,0.60713],[0.19659,0.94901,0.59466],[0.20877,0.95304,0.58199],[0.22142,0.95692,0.56914],[0.23449,0.96065,0.55614],[0.24797,0.96423,0.54303],[0.26180,0.96765,0.52981],[0.27597,0.97092,0.51653],[0.29042,0.97403,0.50321],[0.30513,0.97697,0.48987],[0.32006,0.97974,0.47654],[0.33517,0.98234,0.46325],[0.35043,0.98477,0.45002],[0.36581,0.98702,0.43688],[0.38127,0.98909,0.42386],[0.39678,0.99098,0.41098],[0.41229,0.99268,0.39826],[0.42778,0.99419,0.38575],[0.44321,0.99551,0.37345],[0.45854,0.99663,0.36140],[0.47375,0.99755,0.34963],[0.48879,0.99828,0.33816],[0.50362,0.99879,0.32701],[0.51822,0.99910,0.31622],[0.53255,0.99919,0.30581],[0.54658,0.99907,0.29581],[0.56026,0.99873,0.28623],[0.57357,0.99817,0.27712],[0.58646,0.99739,0.26849],[0.59891,0.99638,0.26038],[0.61088,0.99514,0.25280],[0.62233,0.99366,0.24579],[0.63323,0.99195,0.23937],[0.64362,0.98999,0.23356],[0.65394,0.98775,0.22835],[0.66428,0.98524,0.22370],[0.67462,0.98246,0.21960],[0.68494,0.97941,0.21602],[0.69525,0.97610,0.21294],[0.70553,0.97255,0.21032],[0.71577,0.96875,0.20815],[0.72596,0.96470,0.20640],[0.73610,0.96043,0.20504],[0.74617,0.95593,0.20406],[0.75617,0.95121,0.20343],[0.76608,0.94627,0.20311],[0.77591,0.94113,0.20310],[0.78563,0.93579,0.20336],[0.79524,0.93025,0.20386],[0.80473,0.92452,0.20459],[0.81410,0.91861,0.20552],[0.82333,0.91253,0.20663],[0.83241,0.90627,0.20788],[0.84133,0.89986,0.20926],[0.85010,0.89328,0.21074],[0.85868,0.88655,0.21230],[0.86709,0.87968,0.21391],[0.87530,0.87267,0.2155
5],[0.88331,0.86553,0.21719],[0.89112,0.85826,0.21880],[0.89870,0.85087,0.22038],[0.90605,0.84337,0.22188],[0.91317,0.83576,0.22328],[0.92004,0.82806,0.22456],[0.92666,0.82025,0.22570],[0.93301,0.81236,0.22667],[0.93909,0.80439,0.22744],[0.94489,0.79634,0.22800],[0.95039,0.78823,0.22831],[0.95560,0.78005,0.22836],[0.96049,0.77181,0.22811],[0.96507,0.76352,0.22754],[0.96931,0.75519,0.22663],[0.97323,0.74682,0.22536],[0.97679,0.73842,0.22369],[0.98000,0.73000,0.22161],[0.98289,0.72140,0.21918],[0.98549,0.71250,0.21650],[0.98781,0.70330,0.21358],[0.98986,0.69382,0.21043],[0.99163,0.68408,0.20706],[0.99314,0.67408,0.20348],[0.99438,0.66386,0.19971],[0.99535,0.65341,0.19577],[0.99607,0.64277,0.19165],[0.99654,0.63193,0.18738],[0.99675,0.62093,0.18297],[0.99672,0.60977,0.17842],[0.99644,0.59846,0.17376],[0.99593,0.58703,0.16899],[0.99517,0.57549,0.16412],[0.99419,0.56386,0.15918],[0.99297,0.55214,0.15417],[0.99153,0.54036,0.14910],[0.98987,0.52854,0.14398],[0.98799,0.51667,0.13883],[0.98590,0.50479,0.13367],[0.98360,0.49291,0.12849],[0.98108,0.48104,0.12332],[0.97837,0.46920,0.11817],[0.97545,0.45740,0.11305],[0.97234,0.44565,0.10797],[0.96904,0.43399,0.10294],[0.96555,0.42241,0.09798],[0.96187,0.41093,0.09310],[0.95801,0.39958,0.08831],[0.95398,0.38836,0.08362],[0.94977,0.37729,0.07905],[0.94538,0.36638,0.07461],[0.94084,0.35566,0.07031],[0.93612,0.34513,0.06616],[0.93125,0.33482,0.06218],[0.92623,0.32473,0.05837],[0.92105,0.31489,0.05475],[0.91572,0.30530,0.05134],[0.91024,0.29599,0.04814],[0.90463,0.28696,0.04516],[0.89888,0.27824,0.04243],[0.89298,0.26981,0.03993],[0.88691,0.26152,0.03753],[0.88066,0.25334,0.03521],[0.87422,0.24526,0.03297],[0.86760,0.23730,0.03082],[0.86079,0.22945,0.02875],[0.85380,0.22170,0.02677],[0.84662,0.21407,0.02487],[0.83926,0.20654,0.02305],[0.83172,0.19912,0.02131],[0.82399,0.19182,0.01966],[0.81608,0.18462,0.01809],[0.80799,0.17753,0.01660],[0.79971,0.17055,0.01520],[0.79125,0.16368,0.01387],[0.78260,0.15693,0.01264],[0.77377,0.15028,0.01
148],[0.76476,0.14374,0.01041],[0.75556,0.13731,0.00942],[0.74617,0.13098,0.00851],[0.73661,0.12477,0.00769],[0.72686,0.11867,0.00695],[0.71692,0.11268,0.00629],[0.70680,0.10680,0.00571],[0.69650,0.10102,0.00522],[0.68602,0.09536,0.00481],[0.67535,0.08980,0.00449],[0.66449,0.08436,0.00424],[0.65345,0.07902,0.00408],[0.64223,0.07380,0.00401],[0.63082,0.06868,0.00401],[0.61923,0.06367,0.00410],[0.60746,0.05878,0.00427],[0.59550,0.05399,0.00453],[0.58336,0.04931,0.00486],[0.57103,0.04474,0.00529],[0.55852,0.04028,0.00579],[0.54583,0.03593,0.00638],[0.53295,0.03169,0.00705],[0.51989,0.02756,0.00780],[0.50664,0.02354,0.00863],[0.49321,0.01963,0.00955],[0.47960,0.01583,0.01055]]
# matplotlib.cm.register_cmap('turbo', cmap=ListedColormap(turbo_colormap_data))



# batch_disp = torch.tile(torch.arange(192).unsqueeze(1), (1,192)).float()
# batch_disp[100:160, 20:60] = 175
# batch_disp[25:85, 130:170] = 17
# print(batch_disp.shape)

# plt.imshow(batch_disp, cmap='turbo', vmin=0, vmax=192)
# plt.colorbar()
# plt.show()

In [None]:
# # https://matplotlib.org/cmocean/
# import cmocean

# batch_diff = torch.tile(torch.arange(200).unsqueeze(1), (1,200)).float()
# batch_diff[100:160, 20:60] = 200
# batch_diff[25:85, 130:170] = 0
# batch_diff -= 100
# batch_diff = torch.abs(batch_diff)
# print(batch_diff.shape)

# plt.imshow(batch_diff, cmap='Spectral_r', vmin=0, vmax=90)
# plt.title('spectral_r')
# plt.colorbar()
# plt.show()

# plt.imshow(batch_diff, cmap='cmo.thermal', vmin=0, vmax=90)
# plt.title('thermal')
# plt.colorbar()
# plt.show()

# # Clipped colormap might be interesting for errors
# newcmap = cmocean.tools.crop_by_percent(cmocean.cm.oxy_r, 20, which='min', N=None)

# plt.imshow(batch_diff, cmap=newcmap, vmin=0, vmax=90)
# plt.title('oxy')
# plt.colorbar()
# plt.show()

# # Also try: https://stackoverflow.com/questions/16400241/how-to-redefine-a-color-for-a-specific-value-in-a-matplotlib-colormap for masked output

In [None]:
# import copy
# # make a function to plot disparity maps
# # make a function to plot disparity map errors
# # both should be able to write onto an ax or produce an individual plot
# # disparity map should have an option to highlight invalid values
# # disparity map errors should have two color map modes and an option to exclude incomplete groundtruth



# def show_disparity(disparity, ax=None, vmin=0, vmax=192, remove_invalid=True, title='Disparity Map'):
#     disparity = np.squeeze(disparity)
#     turbo_colormap_data = [[0.18995,0.07176,0.23217],[0.19483,0.08339,0.26149],[0.19956,0.09498,0.29024],[0.20415,0.10652,0.31844],[0.20860,0.11802,0.34607],[0.21291,0.12947,0.37314],[0.21708,0.14087,0.39964],[0.22111,0.15223,0.42558],[0.22500,0.16354,0.45096],[0.22875,0.17481,0.47578],[0.23236,0.18603,0.50004],[0.23582,0.19720,0.52373],[0.23915,0.20833,0.54686],[0.24234,0.21941,0.56942],[0.24539,0.23044,0.59142],[0.24830,0.24143,0.61286],[0.25107,0.25237,0.63374],[0.25369,0.26327,0.65406],[0.25618,0.27412,0.67381],[0.25853,0.28492,0.69300],[0.26074,0.29568,0.71162],[0.26280,0.30639,0.72968],[0.26473,0.31706,0.74718],[0.26652,0.32768,0.76412],[0.26816,0.33825,0.78050],[0.26967,0.34878,0.79631],[0.27103,0.35926,0.81156],[0.27226,0.36970,0.82624],[0.27334,0.38008,0.84037],[0.27429,0.39043,0.85393],[0.27509,0.40072,0.86692],[0.27576,0.41097,0.87936],[0.27628,0.42118,0.89123],[0.27667,0.43134,0.90254],[0.27691,0.44145,0.91328],[0.27701,0.45152,0.92347],[0.27698,0.46153,0.93309],[0.27680,0.47151,0.94214],[0.27648,0.48144,0.95064],[0.27603,0.49132,0.95857],[0.27543,0.50115,0.96594],[0.27469,0.51094,0.97275],[0.27381,0.52069,0.97899],[0.27273,0.53040,0.98461],[0.27106,0.54015,0.98930],[0.26878,0.54995,0.99303],[0.26592,0.55979,0.99583],[0.26252,0.56967,0.99773],[0.25862,0.57958,0.99876],[0.25425,0.58950,0.99896],[0.24946,0.59943,0.99835],[0.24427,0.60937,0.99697],[0.23874,0.61931,0.99485],[0.23288,0.62923,0.99202],[0.22676,0.63913,0.98851],[0.22039,0.64901,0.98436],[0.21382,0.65886,0.97959],[0.20708,0.66866,0.97423],[0.20021,0.67842,0.96833],[0.19326,0.68812,0.96190],[0.18625,0.69775,0.95498],[0.17923,0.70732,0.94761],[0.17223,0.71680,0.93981],[0.16529,0.72620,0.93161],[0.15844,0.73551,0.92305],[0.15173,0.74472,0.91416],[0.14519,0.75381,0.90496],[0.13886,0.76279,0.89550],[0.13278,0.77165,0.88580],[0.12698,0.78037,0.87590],[0.12151,0.78896,0.86581],[0.11639,0.79740,0.85559],[0.11167,0.80569,0.84525],[0.10738,0.81381,0.83484],[0.10357,0.82177,0.82437],[0.10026,0.82955,0.81
389],[0.09750,0.83714,0.80342],[0.09532,0.84455,0.79299],[0.09377,0.85175,0.78264],[0.09287,0.85875,0.77240],[0.09267,0.86554,0.76230],[0.09320,0.87211,0.75237],[0.09451,0.87844,0.74265],[0.09662,0.88454,0.73316],[0.09958,0.89040,0.72393],[0.10342,0.89600,0.71500],[0.10815,0.90142,0.70599],[0.11374,0.90673,0.69651],[0.12014,0.91193,0.68660],[0.12733,0.91701,0.67627],[0.13526,0.92197,0.66556],[0.14391,0.92680,0.65448],[0.15323,0.93151,0.64308],[0.16319,0.93609,0.63137],[0.17377,0.94053,0.61938],[0.18491,0.94484,0.60713],[0.19659,0.94901,0.59466],[0.20877,0.95304,0.58199],[0.22142,0.95692,0.56914],[0.23449,0.96065,0.55614],[0.24797,0.96423,0.54303],[0.26180,0.96765,0.52981],[0.27597,0.97092,0.51653],[0.29042,0.97403,0.50321],[0.30513,0.97697,0.48987],[0.32006,0.97974,0.47654],[0.33517,0.98234,0.46325],[0.35043,0.98477,0.45002],[0.36581,0.98702,0.43688],[0.38127,0.98909,0.42386],[0.39678,0.99098,0.41098],[0.41229,0.99268,0.39826],[0.42778,0.99419,0.38575],[0.44321,0.99551,0.37345],[0.45854,0.99663,0.36140],[0.47375,0.99755,0.34963],[0.48879,0.99828,0.33816],[0.50362,0.99879,0.32701],[0.51822,0.99910,0.31622],[0.53255,0.99919,0.30581],[0.54658,0.99907,0.29581],[0.56026,0.99873,0.28623],[0.57357,0.99817,0.27712],[0.58646,0.99739,0.26849],[0.59891,0.99638,0.26038],[0.61088,0.99514,0.25280],[0.62233,0.99366,0.24579],[0.63323,0.99195,0.23937],[0.64362,0.98999,0.23356],[0.65394,0.98775,0.22835],[0.66428,0.98524,0.22370],[0.67462,0.98246,0.21960],[0.68494,0.97941,0.21602],[0.69525,0.97610,0.21294],[0.70553,0.97255,0.21032],[0.71577,0.96875,0.20815],[0.72596,0.96470,0.20640],[0.73610,0.96043,0.20504],[0.74617,0.95593,0.20406],[0.75617,0.95121,0.20343],[0.76608,0.94627,0.20311],[0.77591,0.94113,0.20310],[0.78563,0.93579,0.20336],[0.79524,0.93025,0.20386],[0.80473,0.92452,0.20459],[0.81410,0.91861,0.20552],[0.82333,0.91253,0.20663],[0.83241,0.90627,0.20788],[0.84133,0.89986,0.20926],[0.85010,0.89328,0.21074],[0.85868,0.88655,0.21230],[0.86709,0.87968,0.21391],[0.87530,0.87267,0.
21555],[0.88331,0.86553,0.21719],[0.89112,0.85826,0.21880],[0.89870,0.85087,0.22038],[0.90605,0.84337,0.22188],[0.91317,0.83576,0.22328],[0.92004,0.82806,0.22456],[0.92666,0.82025,0.22570],[0.93301,0.81236,0.22667],[0.93909,0.80439,0.22744],[0.94489,0.79634,0.22800],[0.95039,0.78823,0.22831],[0.95560,0.78005,0.22836],[0.96049,0.77181,0.22811],[0.96507,0.76352,0.22754],[0.96931,0.75519,0.22663],[0.97323,0.74682,0.22536],[0.97679,0.73842,0.22369],[0.98000,0.73000,0.22161],[0.98289,0.72140,0.21918],[0.98549,0.71250,0.21650],[0.98781,0.70330,0.21358],[0.98986,0.69382,0.21043],[0.99163,0.68408,0.20706],[0.99314,0.67408,0.20348],[0.99438,0.66386,0.19971],[0.99535,0.65341,0.19577],[0.99607,0.64277,0.19165],[0.99654,0.63193,0.18738],[0.99675,0.62093,0.18297],[0.99672,0.60977,0.17842],[0.99644,0.59846,0.17376],[0.99593,0.58703,0.16899],[0.99517,0.57549,0.16412],[0.99419,0.56386,0.15918],[0.99297,0.55214,0.15417],[0.99153,0.54036,0.14910],[0.98987,0.52854,0.14398],[0.98799,0.51667,0.13883],[0.98590,0.50479,0.13367],[0.98360,0.49291,0.12849],[0.98108,0.48104,0.12332],[0.97837,0.46920,0.11817],[0.97545,0.45740,0.11305],[0.97234,0.44565,0.10797],[0.96904,0.43399,0.10294],[0.96555,0.42241,0.09798],[0.96187,0.41093,0.09310],[0.95801,0.39958,0.08831],[0.95398,0.38836,0.08362],[0.94977,0.37729,0.07905],[0.94538,0.36638,0.07461],[0.94084,0.35566,0.07031],[0.93612,0.34513,0.06616],[0.93125,0.33482,0.06218],[0.92623,0.32473,0.05837],[0.92105,0.31489,0.05475],[0.91572,0.30530,0.05134],[0.91024,0.29599,0.04814],[0.90463,0.28696,0.04516],[0.89888,0.27824,0.04243],[0.89298,0.26981,0.03993],[0.88691,0.26152,0.03753],[0.88066,0.25334,0.03521],[0.87422,0.24526,0.03297],[0.86760,0.23730,0.03082],[0.86079,0.22945,0.02875],[0.85380,0.22170,0.02677],[0.84662,0.21407,0.02487],[0.83926,0.20654,0.02305],[0.83172,0.19912,0.02131],[0.82399,0.19182,0.01966],[0.81608,0.18462,0.01809],[0.80799,0.17753,0.01660],[0.79971,0.17055,0.01520],[0.79125,0.16368,0.01387],[0.78260,0.15693,0.01264],[0.77377,0.15028,
0.01148],[0.76476,0.14374,0.01041],[0.75556,0.13731,0.00942],[0.74617,0.13098,0.00851],[0.73661,0.12477,0.00769],[0.72686,0.11867,0.00695],[0.71692,0.11268,0.00629],[0.70680,0.10680,0.00571],[0.69650,0.10102,0.00522],[0.68602,0.09536,0.00481],[0.67535,0.08980,0.00449],[0.66449,0.08436,0.00424],[0.65345,0.07902,0.00408],[0.64223,0.07380,0.00401],[0.63082,0.06868,0.00401],[0.61923,0.06367,0.00410],[0.60746,0.05878,0.00427],[0.59550,0.05399,0.00453],[0.58336,0.04931,0.00486],[0.57103,0.04474,0.00529],[0.55852,0.04028,0.00579],[0.54583,0.03593,0.00638],[0.53295,0.03169,0.00705],[0.51989,0.02756,0.00780],[0.50664,0.02354,0.00863],[0.49321,0.01963,0.00955],[0.47960,0.01583,0.01055]]
#     cmap = ListedColormap(turbo_colormap_data)
#     if remove_invalid:
#         cmap.set_under('k')
#     if ax:
#         ln = ax.imshow(disparity, cmap=cmap, vmin=vmin, vmax=vmax)
#         ln = ax.set_title(title)
#         ln = ax.axis('off')
#         return ln
#     else:
#         plt.imshow(disparity, cmap=cmap, vmin=vmin, vmax=vmax)
#         plt.title(title)
#         plt.axis('off')
#         plt.show()
#     return None

# def show_error(error, ax=None, vmin=0, vmax=100, remove_invalid=True, title='Absolute Error', highlight_large=False, diverging=False):
#     error = np.squeeze(error)
#     cmap = matplotlib.cm.get_cmap('Spectral_r') if diverging else cmocean.cm.thermal
#     if highlight_large:
#         cmap = cmocean.tools.crop_by_percent(cmocean.cm.oxy_r, 20, which='min', N=None)
#     cmap = copy.deepcopy(cmap)
#     if remove_invalid:
#         cmap.set_under('k')
#     if ax:
#         ln = ax.imshow(error, cmap=cmap) if highlight_large else ax.imshow(error, cmap=cmap, vmin=vmin, vmax=vmax)
#         ln = ax.set_title(title)
#         ln = ax.axis('off')
#         return ln
#     else:
#         plt.imshow(error, cmap=cmap) if highlight_large else plt.imshow(error, cmap=cmap, vmin=vmin, vmax=vmax)
#         plt.title(title)
#         plt.axis('off')
#         plt.show()
#     return None

# def show_image(image, ax=None, title='Frame'):
#     image = image.permute(1,2,0)
#     if ax:
#         ln = ax.imshow(image)
#         ln = ax.set_title(title)
#         ln = ax.axis('off')
#         return ln
#     else:
#         plt.imshow(image)
#         plt.title(title)
#         plt.axis('off')
#         plt.show()
#     return None

In [None]:
# batch_imgs = torch.rand((3,256,512))

# batch_disp_masked = batch_disp.clone()
# mask = torch.rand_like(batch_disp_masked) < 0.
# batch_disp_masked[mask] = -1

# fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
# show_disparity(batch_disp_masked, ax1)
# show_error(batch_diff, ax2)
# show_image(batch_imgs, ax3)
# plt.show()


# show_disparity(batch_disp_masked)
# show_error(batch_diff)
# show_image(batch_imgs)

In [None]:
def conv3d3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """Create a bias-free 3D convolution with kernel size 3 and padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride forwarded to ``nn.Conv3d`` (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv3d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_kwargs)


def conv3d1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """Create a bias-free pointwise (kernel size 1) 3D convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride forwarded to ``nn.Conv3d`` (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv3d`` layer.
    """
    pointwise = nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

    
class BasicBlock3D(nn.Module):
    """Residual block built from two kernel-size-3 3D convolutions.

    Each convolution is followed by a normalization layer. The input is added
    back onto the result before the final ReLU. When the block changes the
    channel count or uses a stride other than 1, the shortcut is a strided
    1x1 convolution plus normalization so that both summands agree in shape.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution and of the shortcut projection.
        norm_layer: Callable that takes a channel count and returns a
            normalization module; ``nn.BatchNorm3d`` is used when ``None``.
    """

    def __init__(self, inplanes, planes, stride=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm3d if norm_layer is None else norm_layer

        # Main branch. Only conv1 carries the stride; conv2 keeps resolution.
        self.conv1 = conv3d3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3d3x3(planes, planes)
        self.bn2 = make_norm(planes)

        # Shortcut branch. A projection is only needed when the main branch
        # alters the channel count or the spatial resolution.
        needs_projection = stride != 1 or inplanes != planes
        if needs_projection:
            self.downsample = nn.Sequential(
                conv3d1x1(inplanes, planes, stride),
                make_norm(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        # type: (Tensor) -> Tensor
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # In-place residual sum, then the shared (in-place) ReLU.
        out += shortcut
        return self.relu(out)

# in:  B x N_Groups + 2*C_Cat x Disp//4 x H//4 x W//4
# out: B x 1                  x Disp//4 x H//4 x W//4

# Prototype stack of four residual 3D blocks that reduces the channel
# dimension step by step: C3d -> C3d//2 -> C3d//4 -> C3d//8 -> 1.
C3d = 64
test_conv_3d = nn.Sequential(*(
    BasicBlock3D(inplanes=n_in, planes=n_out)
    for n_in, n_out in [
        (C3d, C3d // 2),
        (C3d // 2, C3d // 4),
        (C3d // 4, C3d // 8),
        (C3d // 8, 1),
    ]
))

In [None]:
# Smoke test: run one random batch through the 3D stack on the GPU and report
# the output shape. The input is (batch=2, channels=64) followed by three
# quarter-resolution axes; presumably Disp//4, H//4, W//4 for a maximum
# disparity of 192 and a 256x512 frame (TODO confirm).
test_conv_3d = test_conv_3d.cuda()
tmp_batch = torch.rand(2, 64, 192 // 4, 256 // 4, 512 // 4).cuda()
tmp_out = test_conv_3d(tmp_batch)
print(tmp_out.shape)

In [None]:
# Notes: Apply relu to sum
# batchnorm after every conv3d
