In [1]:
import torch
from torch import nn
from torch.nn import functional as F 

# print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
print(device)

cuda:0


In [2]:
class Basicblock(nn.Module):
    def __init__(self, input_channel, output_channel, use_1x1conv=False,
                 stride=1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=3, 
                               padding=1, stride=stride)
        self.conv2 = nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=3, 
                               padding=1, stride=stride)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=1, 
                               padding=0, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(output_channel)
        self.bn2 = nn.BatchNorm2d(output_channel)
    
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y = Y + X
        return F.relu(Y)
    
class GC_Extractor(nn.Module):
    def __init__(self, input_channel, output_channel, num_resblock) -> None:
        super(GC_Extractor, self).__init__()
        self.conv_in = nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=5, padding=2, stride=2)
        self.bn_in = nn.BatchNorm2d(output_channel)
        self.resblock = self._make_layer(num_channel=output_channel, num_resblock=num_resblock)
        self.conv_last = nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=3, padding=1, stride=1)
    
    def _make_layer(self, num_channel, num_resblock):
        resblk = []
        for i in range(num_resblock):
            resblk.append(Basicblock(num_channel, num_channel))
        return nn.Sequential(*resblk)    
    
    def forward(self, x):
        x = self.conv_in(x)
        x = self.bn_in(x)
        x = F.relu(x)
        x = self.resblock(x)
        Y = self.conv_last(x)
        return Y

In [3]:
class downsampleblock(nn.Module):
    def __init__(self, input_channel, output_channel, stride=2) -> None:
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels=input_channel, out_channels=output_channel, kernel_size=3, 
                               padding=1, stride=stride)        
        self.conv2 = nn.Conv3d(in_channels=output_channel, out_channels=output_channel, kernel_size=3, 
                               padding=1, stride=1)
        self.conv3 = nn.Conv3d(in_channels=output_channel, out_channels=output_channel, kernel_size=3, 
                               padding=1, stride=1)

        self.bn1 = nn.BatchNorm3d(output_channel)
        self.bn2 = nn.BatchNorm3d(output_channel)
        self.bn3 = nn.BatchNorm3d(output_channel)
    
    def forward(self, cost):
        Y = F.relu(self.bn1(self.conv1(cost)))
        Y = F.relu(self.bn2(self.conv2(Y)))
        Y = F.relu(self.bn3(self.conv3(Y)))
        return Y
    
class Hourglass(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv_in1 = nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3, padding=1, stride=1)
        self.conv_in2 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1)
        self.bn_in1 = nn.BatchNorm3d(32)
        self.bn_in2 = nn.BatchNorm3d(32)

        # downsample layer
        self.downsample1 = nn.Sequential(downsampleblock(32,64,2))
        self.downsample2 = nn.Sequential(downsampleblock(64,64,2))
        self.downsample3 = nn.Sequential(downsampleblock(64,64,2))
        self.downsample4 = nn.Sequential(downsampleblock(64,128,2))

        # upsample layer
        self.upsample1 = nn.ConvTranspose3d(128, 64, kernel_size=3, padding=1, output_padding=1, stride=2)
        self.debn1 = nn.BatchNorm3d(64)
        self.upsample2 = nn.ConvTranspose3d(64, 64, kernel_size=3, padding=1, output_padding=1, stride=2)
        self.debn2 = nn.BatchNorm3d(64)
        self.upsample3 = nn.ConvTranspose3d(64, 64, kernel_size=3, padding=1, output_padding=1, stride=2)
        self.debn3 = nn.BatchNorm3d(64)
        self.upsample4 = nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1, output_padding=1, stride=2)
        self.debn4 = nn.BatchNorm3d(32)
        self.upsample5 = nn.ConvTranspose3d(32, 1, kernel_size=3, padding=1, output_padding=1, stride=2)

    def forward(self, cost):
        cost_in1 = F.relu(self.bn_in1(self.conv_in1(cost)))
        cost_in1 = F.relu(self.bn_in2(self.conv_in2(cost_in1)))

        #downsample
        cost_down1 = self.downsample1(cost_in1)
        cost_down2 = self.downsample2(cost_down1)   
        cost_down3 = self.downsample3(cost_down2)
        cost_down4 = self.downsample4(cost_down3)

        #upsample
        cost_up1 = self.debn1(self.upsample1(cost_down4))
        cost_up1 = F.relu( cost_up1 + cost_down3 )
        cost_up2 = self.debn2(self.upsample2(cost_up1))
        cost_up2 = F.relu( cost_up2 + cost_down2 )
        cost_up3 = self.debn3(self.upsample3(cost_up2))
        cost_up3 = F.relu( cost_up3 + cost_down1 )
        cost_up4 = self.debn4(self.upsample4(cost_up3))
        cost_up4 = F.relu( cost_up4 + cost_in1 )
        cost_out = self.upsample5(cost_up4)
        return cost_out


In [4]:
# a = torch.tensor(
#     [[[[1,2,3],[4,5,6],[7,8,9]], 
#     [[9,8,7],[6,5,4],[3,2,1]]],
#     [[[1,2,3],[4,5,6],[7,8,9]], 
#     [[9,8,7],[6,5,4],[3,2,1]]]]
# )
# print(a.shape)
# print(a)
# b = a.unsqueeze(2)
# b = b.repeat(1,1,5,1,1)
# print(b.shape)
# print(b[1,:,1,1,0])

In [5]:
# def cost_volume(feaL, feaR, min_disp, max_disp):
    # B, C, H, W = feaL.shape
    # max_disp = int(max_disp/2)
    # min_disp = int(min_disp/2)
    # cost = torch.zeros(B, C*2, max_disp-min_disp, H, W)
    # cost[:, 0:C, :, :, :] = feaL.unsqueeze(2).repeat(1,1,max_disp-min_disp,1,1)
    # for i in range(min_disp, max_disp):
    #     if i < 0:
    #         cost[:, C:, i, :, :W+i] = feaR[:, :, :, -i:]
    #     if i >= 0:
    #         cost[:, C:, i, :, i:] = feaR[:, :, :, :W-i]
    # return cost

In [6]:
class GCNet(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.fea1 = nn.Sequential( GC_Extractor(3, 32, 8) )
        self.fea2 = nn.Sequential( GC_Extractor(3, 32, 8) )
        self.hourglass = nn.Sequential( Hourglass() )


    def cost_volume(self, feaL, feaR, min_disp, max_disp):
        B, C, H, W = feaL.shape

        # feature map has been downsample, so disparity should be devided by 2
        max_disp = int(max_disp/2)
        min_disp = int(min_disp/2)
        cost = torch.zeros(B, C*2, max_disp-min_disp, H, W).to(self.device)
        cost[:, 0:C, :, :, :] = feaL.unsqueeze(2).repeat(1,1,max_disp-min_disp,1,1)
        for i in range(min_disp, max_disp):
            if i < 0:
                cost[:, C:, i, :, :W+i] = feaR[:, :, :, -i:]
            if i >= 0:
                cost[:, C:, i, :, i:] = feaR[:, :, :, :W-i]
        return cost
    
    def softargmax(self, cost):
        cost_softmax = F.softmax(cost, dim = 2)
        vec = torch.arange(-64, 64).to(device)
        vec = vec.unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4)
        vec = vec.expand_as(cost_softmax).type_as(cost_softmax)
        disp = torch.sum(vec*cost_softmax, dim=2)
        return disp
    
    def forward(self, imgL, imgR, min_disp, max_disp):
        #extract feature map
        featureL = self.fea1(imgL)
        featureR = self.fea2(imgR)
        print(featureL.shape, featureR.shape)

        # construct cost volume
        cost_vol = self.cost_volume(featureL, featureR, min_disp, max_disp) # B * 2C * maxdisp-mindisp * H * W
        print(cost_vol.shape)

        # cost filtering
        cost_vol = self.hourglass(cost_vol)

        # disparity regression
        disp = self.softargmax(cost_vol)
        return disp
        
        

In [7]:
net = GCNet(device)
net = net.to(device)
x_l = torch.rand(2,3,512,512).to(device)
x_r = torch.rand(2,3,512,512).to(device)

disp = net(x_l, x_r, -64, 64)
print(disp.shape)



torch.Size([2, 32, 256, 256]) torch.Size([2, 32, 256, 256])
torch.Size([2, 64, 64, 256, 256])
torch.Size([2, 1, 512, 512])
