In [1]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch import nn

Pool = nn.MaxPool2d

def batchnorm(x):
    return nn.BatchNorm2d(x.size()[1])(x)

class Conv(nn.Module):
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True)
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU()
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
    
class DWConv(nn.Module):
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
        super(DWConv, self).__init__()
        self.inp_dim = inp_dim
        self.depthwise = nn.Conv2d(inp_dim, inp_dim, kernel_size, stride, padding=(kernel_size-1)//2, groups=inp_dim,bias=inp_dim)
        self.pointwise = nn.Conv2d(inp_dim, out_dim, kernel_size=1, groups=1)
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU()
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.depthwise(x)
        x = self.pointwise(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class depthwise_separable_conv(nn.Module):
    def init(self, nin, nout, kernel_size=3, stride = 1, bn = False, relu = True):
        super(depthwise_separable_conv, self).init()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=3, padding=1, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class Residual(nn.Module):
    def __init__(self, inp_dim, out_dim, Conv_method = Conv):
        super(Residual, self).__init__()
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        middle_dim = 1 if int(out_dim/2) <= 0 else int(out_dim/2)
        self.conv1 = Conv(inp_dim, middle_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(middle_dim)
        self.conv2 = Conv_method(middle_dim, middle_dim, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(middle_dim)
        self.conv3 = Conv(middle_dim, out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        if inp_dim == out_dim:
            self.need_skip = False
        else:
            self.need_skip = True
        
    def forward(self, x):
        if self.need_skip:
            residual = self.skip_layer(x)
        else:
            residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        out += residual
        return out 

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Sequential(nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1),
    nn.BatchNorm2d(out_channels))

class Hourglass(nn.Module):
    def __init__(self, n, f, bn=None, increase=0):
        super(Hourglass, self).__init__()
        nf = f + increase
        self.up1 = Residual(f, f)
        # Lower branch
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, nf)
        self.n = n
        # Recursive hourglass
        if self.n > 1:
            self.low2 = Hourglass(n-1, nf, bn=bn)
        else:
            self.low2 = Residual(nf, nf)
        self.low3 = Residual(nf, f)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        up1  = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2  = self.up2(low3)
        return up1 + up2


In [27]:
class Merge(nn.Module):
    def __init__(self, x_dim, y_dim):
        super(Merge, self).__init__()
        self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)

    def forward(self, x):
        return self.conv(x)
class HeatmapLoss(torch.nn.Module):
    """
    loss for detection heatmap
    """
    def __init__(self):
        super(HeatmapLoss, self).__init__()

    def forward(self, pred, gt):
        l = ((pred - gt)**2)
        l = l.mean(dim=3).mean(dim=2).mean(dim=1)
        return l ## l of dim bsize
    
class ResUnetBlock(torch.nn.Module):
    def __init__(self,nstack, inp_dim, oup_dim, bn, increase, conv_api):
        super(ResUnetBlock, self).__init__()
        self.hgs = nn.ModuleList( [
            nn.Sequential(
            Hourglass(4, inp_dim, bn, increase),
        ) for i in range(nstack)] )
        self.features = nn.ModuleList( [
            nn.Sequential(
            Residual(inp_dim, inp_dim, conv_api),
            conv_api(inp_dim, inp_dim, 1, bn=True, relu=True)
        ) for i in range(nstack)] )
        self.outs = nn.ModuleList( [conv_api(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
        self.nstack = nstack
    def forward(self, x):
        combined_hm_preds = []
        for i in range(self.nstack):
            hg = self.hgs[i](x)
            #print("hg:",hg.size())
            feature = self.features[i](hg)
            #print("feature:",feature.size())
            preds = self.outs[i](feature)
            #print("preds:", preds.size())
            combined_hm_preds.append(preds)
        return combined_hm_preds 
    
class RESUNet(nn.Module):
    def __init__(self, nstack, inp_dim, oup_dim, Conv_method = "Conv",image_shape = (1,256,256),bn=False, increase=0, **kwargs):
        super(RESUNet, self).__init__()
        self.Conv_method = Conv_method
        self.image_shape = image_shape
        self.nstack = nstack
        self.inp_dim = inp_dim 
        self.oup_dim = oup_dim
        self.conv_type_dict = {
            "DWConv":DWConv,
            "Conv":Conv,
        }
        print("using :",Conv_method)
        self.pre = nn.Sequential(
            self.conv_type_dict[self.Conv_method](image_shape[0], 64, 7, 2, bn=True, relu=True),
            Residual(64, 128, self.conv_type_dict[self.Conv_method]),
            Pool(2, 2),
            #Residual(64, 128, self.conv_type_dict[self.Conv_method]),
            Residual(128,128, self.conv_type_dict[self.Conv_method]),
        )
        self.break_up = Residual(128, 1, self.conv_type_dict[self.Conv_method])
        self.hgs_r_pip = ResUnetBlock(nstack, inp_dim, oup_dim, bn, increase, self.conv_type_dict[self.Conv_method])
        self.hgs_g_pip = ResUnetBlock(nstack, inp_dim, oup_dim, bn, increase, self.conv_type_dict[self.Conv_method])
        self.hgs_b_pip = ResUnetBlock(nstack, inp_dim, oup_dim, bn, increase, self.conv_type_dict[self.Conv_method])
        
        #self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] )
        #self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] )
        
        self.merge = nn.ModuleList( [
            nn.Sequential(
            nn.Conv2d(nstack*128,9,2,2),
            nn.Conv2d(9,9,2,2)
        ) for i in range(3)] )
        
        self.head = nn.ModuleList( 
            [nn.Sequential( nn.Conv2d(9,1,1,1) ) for i in range(3)]
        )
        self.nstack = nstack
        self.heatmapLoss = HeatmapLoss()
        self.up = nn.Upsample(scale_factor=4, mode='bicubic')
        
    def forward(self, imgs):
        ## our posenet
        P,C,W,H = imgs.size()

        if( C == 1 or C == 3):
            x = imgs
        else:
            x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim
            
        x_backup = x
        x_origin = x 
        x = self.pre(x)
        x = self.break_up(x)
        #print('res:',x.size())
        combined_hm_preds = []
        r = self.hgs_r_pip(x_backup)
        g = self.hgs_g_pip(x_backup)
        b = self.hgs_b_pip(x_backup)
        
        r_multi_map = torch.cat(r, 1)
        g_multi_map = torch.cat(g, 1)
        b_multi_map = torch.cat(b, 1)
        
        
        r_multi_map = self.merge[0](r_multi_map)
        g_multi_map = self.merge[1](g_multi_map)
        b_multi_map = self.merge[2](b_multi_map)
        
        r_attention_map = torch.mul(r_multi_map,x)
        g_attention_map = torch.mul(g_multi_map,x)
        b_attention_map = torch.mul(b_multi_map,x)
        #print(r_attention_map.size())
        color_r_offset = self.head[0](r_attention_map)
        color_g_offset = self.head[1](g_attention_map)
        color_b_offset = self.head[2](b_attention_map)
        color_offset = torch.cat([color_r_offset, color_g_offset, color_b_offset],1)
        x_color = self.up(color_offset)
        #print(x_color.size())
        return x_color #torch.stack(combined_hm_preds, 1)

    def calc_loss(self, combined_hm_preds, heatmaps):
        combined_loss = []
        for i in range(self.nstack):
            combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))
        combined_loss = torch.stack(combined_loss, dim=1)
        return combined_loss


In [28]:
image = torch.zeros(1,3,224,224)

In [29]:
model = RESUNet(2, 3, 128, Conv_method = "Conv",image_shape = (3,224,224),bn=False)

using : Conv


In [30]:
model(image).shape

res: torch.Size([1, 1, 56, 56])
torch.Size([1, 9, 56, 56])
torch.Size([1, 3, 224, 224])


torch.Size([1, 3, 224, 224])

In [210]:
summary(model.cuda(), input_size=(1,224,224))

res: torch.Size([2, 1, 56, 56])
hg: torch.Size([2, 1, 224, 224])
feature: torch.Size([2, 1, 224, 224])
preds: torch.Size([2, 128, 224, 224])
hg: torch.Size([2, 1, 224, 224])
feature: torch.Size([2, 1, 224, 224])
preds: torch.Size([2, 128, 224, 224])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           3,200
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
              Conv-4         [-1, 64, 112, 112]               0
            Conv2d-5        [-1, 128, 112, 112]           8,320
              Conv-6        [-1, 128, 112, 112]               0
       BatchNorm2d-7         [-1, 64, 112, 112]             128
              ReLU-8         [-1, 64, 112, 112]               0
            Conv2d-9         [-1, 64, 112, 112]           4,160
             Conv-10         [-1, 64, 112, 11