## *Iris Segmentation Backbone (VGG16-UNET)*

In [1]:
import warnings
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
from torchsummary import summary

# filter future warning of tensorboard summary
warnings.filterwarnings('ignore',category=FutureWarning)

### *1. Layer Specification (helper functions adapted from the official PyTorch VGG implementation)*

In [2]:
# Encoder layer layouts consumed by make_layers: an integer entry adds a 3x3
# convolution with that many output channels, 'M' adds a 2x2 max-pool.
# The layouts contain 8 ('A'), 10 ('B'), 13 ('D') and 16 ('E') convolutions;
# get_model only uses 'D'.
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def make_layers(cfg, batch_norm=True):
    """Build the flat list of encoder layers described by ``cfg``.

    Args:
        cfg: sequence whose entries are either the string ``'M'`` (insert a
            2x2, stride-2 max-pool) or an integer (insert a 3x3, padding-1
            convolution with that many output channels).
        batch_norm: when true, every convolution is followed by
            ``BatchNorm2d`` before the in-place ``ReLU``.

    Returns:
        A plain Python list of ``nn.Module`` objects in network order. The
        first convolution expects a 3-channel input.
    """
    modules = []
    prev_channels = 3
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # conv -> (optional) batch norm -> ReLU
        stage = [nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1)]
        if batch_norm:
            stage.append(nn.BatchNorm2d(entry))
        stage.append(nn.ReLU(inplace=True))
        modules.extend(stage)
        prev_channels = entry
    return modules

# Decoder layer layouts consumed by make_layers_up: 'U' adds a x2 bilinear
# upsample, 'O' marks the start of the output head (no layer is added for it)
# and an integer entry adds a 3x3 convolution with that many output channels.
# 'A' ends in a 3-channel output convolution, 'B' in a 2-channel one.
cfgs_up = {
    'A' : ['U',512,256,'U',256,128,'U',128,64,'U',64,32,'O',32,3],
    'B' : ['U',512,256,'U',256,128,'U',128,64,'U',64,32,'O',32,2]
}

def make_layers_up(cfg, batch_norm=True, in_channels=1024):
    """Build the flat list of decoder layers described by ``cfg``.

    ``cfg`` entries are interpreted as follows:

    * ``'U'`` - append a x2 bilinear ``nn.Upsample`` (``align_corners=True``).
    * ``'O'`` - marks the start of the output head; no layer is appended.
    * integer - append a 3x3, padding-1 convolution with that many output
      channels.

    Once at least one convolution has been emitted, both markers triple the
    running channel count: the next convolution is fed the previous decoder
    output concatenated with an encoder skip tensor that has twice as many
    channels (see ``TheModel.forward``).

    The last convolution in ``cfg`` is the output head and is emitted bare
    (no batch norm, no ReLU) so that it produces raw logits. It is located by
    position rather than by its channel count, so heads of any width work,
    intermediate 2- or 3-channel convolutions keep their normalisation, and
    the head stays bare when ``batch_norm`` is false.

    Args:
        cfg: sequence of ``'U'`` / ``'O'`` markers and integer channel counts.
        batch_norm: when true, every non-final convolution is followed by
            ``BatchNorm2d`` before the in-place ``ReLU``.
        in_channels: channel count fed to the first convolution. The default
            of 1024 matches the ASPP output for a 512-channel bottleneck.

    Returns:
        A plain Python list of ``nn.Module`` objects in network order.
    """
    # Index of the output-head convolution (last non-marker entry), if any.
    conv_positions = [i for i, v in enumerate(cfg) if v not in ("U", "O")]
    last_conv = conv_positions[-1] if conv_positions else -1

    layers = []
    seen_conv = False  # no skip connection is concatenated before the first conv
    for idx, v in enumerate(cfg):
        if v == "U":
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            if seen_conv:
                in_channels = in_channels * 3
        elif v == "O":
            if seen_conv:
                in_channels = in_channels * 3
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if idx == last_conv:
                # Output head: raw logits only.
                layers.append(conv2d)
            elif batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
            seen_conv = True
    return layers

### *2. Define VGG-UNET Model*

#### *2.1 ASPP Attention Sub Module Definition*

In [3]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block used as a gated attention module.

    The input is first passed through a 3x3, stride-1 max-pool. Five branches
    are then computed from the pooled map (global average pooling, a 1x1
    convolution and three 3x3 atrous convolutions with dilation 6, 12 and
    18), concatenated along the channel axis, and reduced by ``conv_gate`` to
    a gate with ``in_channel`` channels. The result is the pooled map
    concatenated with its gated copy, i.e. ``2 * in_channel`` channels at the
    input's spatial size.

    Args:
        in_channel: number of channels of the incoming feature map.
        depth: number of channels produced by each of the five branches.
    """

    def __init__(self, in_channel=512, depth=256):
        super().__init__()
        # Size-preserving max-pool applied before every branch.
        self.map = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # Image-level branch: global average pool -> 1x1 conv.
        self.gvp = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Conv2d(in_channel, depth, kernel_size=1, stride=1),
            nn.BatchNorm2d(depth), nn.ReLU()
        )

        self.atrous_block1 = nn.Sequential(
            nn.Conv2d(in_channel, depth, kernel_size=1, stride=1),
            nn.BatchNorm2d(depth), nn.ReLU(inplace=True))

        # padding == dilation keeps the spatial size for a 3x3 kernel.
        self.atrous_block6 = nn.Sequential(
            nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6),
            nn.BatchNorm2d(depth), nn.ReLU(inplace=True))

        self.atrous_block12 = nn.Sequential(
            nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12),
            nn.BatchNorm2d(depth), nn.ReLU(inplace=True))

        self.atrous_block18 = nn.Sequential(
            nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18),
            nn.BatchNorm2d(depth), nn.ReLU(inplace=True))

        # NOTE(review): ReLU followed by Sigmoid confines the gate to
        # [0.5, 1). Left unchanged so that existing checkpoints keep their
        # behaviour - confirm whether the ReLU is intended.
        self.conv_gate = nn.Sequential(
            nn.Conv2d(depth * 5, in_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel), nn.ReLU(inplace=True), nn.Sigmoid())

    def forward(self, x):
        """Return ``cat([pooled_x, gate * pooled_x], dim=1)``."""
        size = x.shape[2:]
        # Initial pooling
        x = self.map(x)
        # 5 components
        gvp = self.gvp(x)
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # what it resolved to for bilinear mode, so the numerics are unchanged.
        gvp = F.interpolate(gvp, size=size, mode='bilinear', align_corners=False)

        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)
        # Gated x
        gate = self.conv_gate(
            torch.cat([
                atrous_block1, atrous_block6, atrous_block12,
                atrous_block18, gvp], dim=1))
        gated_x = torch.mul(gate, x)
        # Concat
        net = torch.cat([x, gated_x], dim=1)
        return net

#### *2.2 Main UNet Module*

In [4]:
# U-Net style model: only the convolutional (encoder) part of VGG16-BN is reused; the decoder is new.
class TheModel(nn.Module):
    """U-Net style segmentation network with an ASPP attention bottleneck.

    Args:
        features: flat list of encoder layers laid out like
            ``make_layers(cfgs['D'])`` (conv/BN/ReLU triples separated by
            max-pools); entries 0..42 are consumed.
        features2: flat list of decoder layers laid out like
            ``make_layers_up(...)``: four groups of seven layers followed by
            the output head.
        init_weights: when true, run ``_initialize_weights`` after building.
    """

    def __init__(self, features, features2, init_weights=True):
        super().__init__()
        # (start, stop) slice of every encoder stage; the layer sitting at
        # ``stop`` is the max-pool that separates it from the next stage.
        # Sub-modules are registered in the order conv_1, down_1, conv_2, ...
        # so the encoder parameters come first in the state dict.
        stage_bounds = ((0, 6), (7, 13), (14, 23), (24, 33), (34, 43))
        for number, (start, stop) in enumerate(stage_bounds, start=1):
            setattr(self, "conv_%d" % number, nn.Sequential(*features[start:stop]))
            if number < len(stage_bounds):
                setattr(self, "down_%d" % number, features[stop])
        # ASPP attention module
        self.attention = ASPP(512, 256)
        # Four decoder stages of seven layers each (upsample + two conv triples)
        for number in range(1, 5):
            group = features2[7 * (number - 1):7 * number]
            setattr(self, "up%d" % number, nn.Sequential(*group))
        # Output head
        self.out = nn.Sequential(*features2[28:])
        if init_weights:
            self._initialize_weights()

    def _cat(self, x1, x2):
        """Pad ``x1`` (NCHW) to the spatial size of ``x2`` and concatenate on channels."""
        gap_h = x2.shape[2] - x1.shape[2]
        gap_w = x2.shape[3] - x1.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # F.pad takes the last dimension first: (left, right, top, bottom)
        padded = F.pad(x1, [left, gap_w - left, top, gap_h - top])
        return torch.cat([padded, x2], dim=1)

    def forward(self, x):
        """Run encoder, attention bottleneck and decoder.

        The shape comments are the original annotations for an
        (8, 3, 321, 321) input.
        """
        enc1 = self.conv_1(x)                   # [8, 64, 321, 321]
        enc2 = self.conv_2(self.down_1(enc1))   # [8, 128, 160, 160]
        enc3 = self.conv_3(self.down_2(enc2))   # [8, 256, 80, 80]
        enc4 = self.conv_4(self.down_3(enc3))   # [8, 512, 40, 40]
        enc5 = self.conv_5(self.down_4(enc4))   # [8, 512, 20, 20]

        bottleneck = self.attention(enc5)

        dec4 = self.up1(bottleneck)             # [8, 256, 40, 40]
        dec3 = self.up2(self._cat(dec4, enc4))  # [8, 128, 80, 80]
        dec2 = self.up3(self._cat(dec3, enc3))  # [8, 64, 160, 160]
        dec1 = self.up4(self._cat(dec2, enc2))  # [8, 32, 320, 320]

        return self.out(self._cat(dec1, enc1))  # [8, 3, 321, 321]

    def _initialize_weights(self):
        """Kaiming-init convolutions, reset batch norms, small-normal-init linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

### *3. Use pretrained weights to initialize the model*


In [5]:
def get_model(c):
    """Build the VGG16-UNet and copy pretrained VGG16-BN weights into it.

    Args:
        c (str): key of ``cfgs_up`` selecting the decoder layout. With the
            layouts defined in this notebook, ``'A'`` ends in a 3-channel
            output convolution and ``'B'`` in a 2-channel one.

    Returns:
        TheModel: the network, with the leading entries of its state dict
        replaced by the ``features.*`` tensors of torchvision's pretrained
        ``vgg16_bn``.
    """
    model = TheModel(make_layers(cfgs['D']), make_layers_up(cfgs_up[c]), init_weights=True)
    backbone = models.vgg16_bn(pretrained=True)

    merged_state = model.state_dict()
    target_keys = list(merged_state.keys())

    # 1. Keep only the convolutional part of the pretrained network.
    feature_tensors = [
        tensor for name, tensor in backbone.state_dict().items()
        if name.split('.')[0] == "features"
    ]

    # 2. Copy by position: the i-th pretrained tensor overwrites the i-th
    #    entry of the model's state dict. This relies on the encoder being
    #    registered first in TheModel and mirroring the VGG layer order.
    for position, tensor in enumerate(feature_tensors):
        merged_state[target_keys[position]] = tensor

    model.load_state_dict(merged_state)
    return model

### *4. GPU/CPU configuration and Model Summary*

In [6]:
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model("A")
    print(list(model.children()))
    # Move the model to the selected device instead of unconditionally calling
    # .cuda(), which fails on CPU-only machines.
    model.to(device)
    # torchsummary takes the device as a plain string ('cuda' or 'cpu').
    summary(model, input_size=(3,321,321), device=device.type)

[Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Sequential(
  (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Sequential(
  (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1)