In [3]:
import torch
from torch import nn
from torch.nn import functional as F 

In [4]:
class Residual(nn.Module):
    """Two-convolution residual block with an optional 1x1 projection shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(X))))) + shortcut(X))`` where
    ``shortcut`` is the identity, or a 1x1 convolution when ``use_1x1conv`` is
    true.

    Args:
        input_channel: number of channels of the input ``X``.
        output_channel: number of channels produced by the block.
        use_1x1conv: when true, ``X`` is passed through a 1x1 convolution
            (with the same ``stride``) before being added to the main path.
            Without it ``X`` is added unchanged, so ``X`` must already be
            addable to the main-path output.
        stride: stride of the first 3x3 convolution and of the 1x1 shortcut.
    """

    def __init__(self, input_channel, output_channel, use_1x1conv=False,
                 stride=1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=3,
                               padding=1, stride=stride)
        # Only conv1 (and the shortcut) downsample. If conv2 were strided as
        # well, the main path would be downsampled twice while the shortcut is
        # downsampled once, and the addition in forward() would fail for any
        # stride > 1.
        self.conv2 = nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=3,
                               padding=1, stride=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=1,
                                   padding=0, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(output_channel)
        self.bn2 = nn.BatchNorm2d(output_channel)

    def forward(self, X):
        """Apply the block to ``X`` and return the activated sum of both paths."""
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: do not rely on the truthiness of a module.
        if self.conv3 is not None:
            X = self.conv3(X)
        Y = Y + X
        return F.relu(Y)

In [5]:
class GC_Extractor(nn.Module):
    """Feature extractor: strided 5x5 stem, a stack of residual blocks, 3x3 head.

    Args:
        input_channel: number of channels of the input image tensor.
        output_channel: number of feature channels used by every layer.
        num_resblock: how many ``Residual`` blocks are stacked between the
            stem and the final convolution.
    """

    def __init__(self, input_channel, output_channel, num_resblock) -> None:
        super(GC_Extractor, self).__init__()
        # Stem: 5x5 convolution with stride 2, followed by batch norm.
        self.conv_in = nn.Conv2d(input_channel, output_channel,
                                 kernel_size=5, stride=2, padding=2)
        self.bn_in = nn.BatchNorm2d(output_channel)
        # Body: channel-preserving residual blocks.
        self.resblock = self._make_layer(num_channel=output_channel,
                                         num_resblock=num_resblock)
        # Head: plain 3x3 convolution; forward() applies no norm or activation after it.
        self.conv_last = nn.Conv2d(output_channel, output_channel,
                                   kernel_size=3, stride=1, padding=1)

    def _make_layer(self, num_channel, num_resblock):
        """Return ``num_resblock`` identity-shortcut residual blocks as a Sequential."""
        blocks = [Residual(num_channel, num_channel) for _ in range(num_resblock)]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run stem -> residual stack -> head and return the feature map."""
        stem = F.relu(self.bn_in(self.conv_in(x)))
        return self.conv_last(self.resblock(stem))

In [24]:
# Small (2, 2, 3, 3) integer tensor used to check how unsqueeze + repeat tile
# a feature map along a newly inserted axis.
a = torch.tensor([
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[9, 8, 7], [6, 5, 4], [3, 2, 1]]],
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[9, 8, 7], [6, 5, 4], [3, 2, 1]]],
])
print(a.shape)
print(a)
# Insert a singleton axis at position 2, then copy it 5 times along that axis.
b = a.unsqueeze(2).repeat(1, 1, 5, 1, 1)
print(b.shape)
# Every copy along the new axis holds the same values as `a`.
print(b[1, :, 1, 1, 0])

torch.Size([2, 2, 3, 3])
tensor([[[[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]],

         [[9, 8, 7],
          [6, 5, 4],
          [3, 2, 1]]],


        [[[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]],

         [[9, 8, 7],
          [6, 5, 4],
          [3, 2, 1]]]])
torch.Size([2, 2, 5, 3, 3])
tensor([4, 6])


In [30]:
def cost_volume(feaL, feaR, min_disp, max_disp):
    """Build a concatenation cost volume from left and right feature maps.

    For every disparity ``d`` in ``range(min_disp, max_disp)`` the left
    features are paired with the right features shifted by ``d`` columns, so
    that ``cost[:, C:, k, :, x] == feaR[:, :, :, x - d]`` with
    ``k = d - min_disp``. Positions whose shifted column falls outside the
    image stay zero.

    Args:
        feaL: left feature map of shape ``(B, C, H, W)``.
        feaR: right feature map; it is sliced along its last axis and written
            into ``(B, C, H, W - |d|)`` windows, so it is expected to have the
            same shape as ``feaL``.
        min_disp: smallest disparity (inclusive); may be negative.
        max_disp: largest disparity (exclusive).

    Returns:
        Tensor of shape ``(B, 2 * C, max_disp - min_disp, H, W)`` with the
        dtype and device of ``feaL``. Channels ``[0, C)`` hold the left
        features, channels ``[C, 2C)`` the shifted right features.
    """
    B, C, H, W = feaL.shape
    num_disp = max_disp - min_disp
    # Allocate next to the features so the volume keeps their dtype and device
    # instead of always being a float32 CPU tensor.
    cost = feaL.new_zeros((B, C * 2, num_disp, H, W))
    # Broadcasting over the disparity axis avoids materialising a repeated copy.
    cost[:, :C] = feaL.unsqueeze(2)
    for idx, disp in enumerate(range(min_disp, max_disp)):
        # `idx` (= disp - min_disp) is the position along the disparity axis.
        # Indexing with the raw disparity goes out of bounds for min_disp > 0
        # and wraps around to the wrong slot for negative disparities.
        if abs(disp) >= W:
            # Shift is at least one full image width: nothing overlaps.
            continue
        if disp < 0:
            cost[:, C:, idx, :, :W + disp] = feaR[:, :, :, -disp:]
        else:
            cost[:, C:, idx, :, disp:] = feaR[:, :, :, :W - disp]
    return cost

In [44]:
# net = GC_Extractor(3, 32, 8)
# x_l = torch.rand(4,3,99,99)
# y_l = net(x_l)
# x_r = torch.rand(4,3,99,99)
# y_r = net(x_r)
# print(y_l.shape, y_r.shape)
# cost = cost_volume(y_l, y_r, -10, 10)
# print(cost.shape)
# print(cost[0, :, 1:3, 0, 5])


In [None]:
class GCNet(nn.Module):
    """Stereo network skeleton: shared feature extractor plus cost-volume construction.

    The same ``GC_Extractor`` instance (and therefore the same weights) is
    applied to the left and the right image.
    """

    def __init__(self):
        super().__init__()
        # Use the extractor module directly: an nn.Module is not iterable, so
        # unpacking it into nn.Sequential(*...) raises TypeError.
        self.fea = GC_Extractor(3, 32, 8)

    @staticmethod
    def cost_volume(feaL, feaR, min_disp, max_disp):
        """Build a concatenation cost volume from left and right feature maps.

        For every disparity ``d`` in ``range(min_disp, max_disp)`` the left
        features are paired with the right features shifted by ``d`` columns:
        ``cost[:, C:, d - min_disp, :, x] == feaR[:, :, :, x - d]``. Positions
        whose shifted column falls outside the image stay zero.

        Args:
            feaL: left feature map of shape ``(B, C, H, W)``.
            feaR: right feature map, expected to have the same shape as ``feaL``.
            min_disp: smallest disparity (inclusive); may be negative.
            max_disp: largest disparity (exclusive).

        Returns:
            Tensor of shape ``(B, 2 * C, max_disp - min_disp, H, W)`` with the
            dtype and device of ``feaL``.
        """
        B, C, H, W = feaL.shape
        num_disp = max_disp - min_disp
        # Keep the dtype/device of the features instead of float32 on CPU.
        cost = feaL.new_zeros((B, C * 2, num_disp, H, W))
        # Broadcast the left features over the disparity axis.
        cost[:, :C] = feaL.unsqueeze(2)
        for idx, disp in enumerate(range(min_disp, max_disp)):
            # Index by position (disp - min_disp), not by the raw disparity,
            # which would go out of bounds or wrap around.
            if abs(disp) >= W:
                # Shift is at least one full image width: nothing overlaps.
                continue
            if disp < 0:
                cost[:, C:, idx, :, :W + disp] = feaR[:, :, :, -disp:]
            else:
                cost[:, C:, idx, :, disp:] = feaR[:, :, :, :W - disp]
        return cost

    def forward(self, imgL, imgR, min_disp, max_disp):
        """Extract features from both images and build their cost volume.

        NOTE(review): the visible code stops after the cost volume is built and
        returns nothing; the rest of the network looks unfinished - confirm.
        """
        featureL = self.fea(imgL)
        featureR = self.fea(imgR)
        # Call the method defined on this class rather than a same-named global.
        cost_vol = self.cost_volume(featureL, featureR, min_disp, max_disp) # B * 2C * maxdisp-mindisp * H * W
        