In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from pytorchcv.model_provider import get_model as ptcv_get_model

In [2]:
def conv3x3(in_, out):
    """Build a 3x3 2-D convolution with one pixel of padding.

    Args:
        in_: number of input channels.
        out: number of output channels.

    Returns:
        An ``nn.Conv2d`` with a 3x3 kernel and ``padding=1``; every other
        option (stride, bias, ...) is left at the ``nn.Conv2d`` default, so
        the spatial size of the input is preserved.
    """
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)


class ConvRelu(nn.Module):
    """A ``conv3x3`` layer followed by an in-place ReLU."""

    def __init__(self, in_, out):
        """Create the convolution and its activation.

        Args:
            in_: number of input channels, passed to ``conv3x3``.
            out: number of output channels, passed to ``conv3x3``.
        """
        super().__init__()
        self.conv = conv3x3(in_, out)
        # inplace=True overwrites the convolution output instead of
        # allocating a second tensor for the activation.
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``activation(conv(x))``."""
        return self.activation(self.conv(x))


class NoOperation(nn.Module):
    """Identity module: ``forward`` hands its input back untouched."""

    def forward(self, x):
        """Return ``x`` itself (the same object, no copy is made)."""
        return x


class DecoderBlock(nn.Module):
    """Decoder stage: ``ConvRelu``, stride-2 transposed convolution, ReLU.

    The three layers are stored, in that order, in ``self.block``.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        """Assemble the stage.

        Args:
            in_channels: channels of the incoming feature map.
            middle_channels: channels produced by the ``ConvRelu`` layer.
            out_channels: channels produced by the transposed convolution.
        """
        super().__init__()

        refine = ConvRelu(in_channels, middle_channels)
        # kernel 3 / stride 2 / padding 1 / output_padding 1 yields an output
        # whose height and width are exactly twice those of the input.
        upsample = nn.ConvTranspose2d(
            middle_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.block = nn.Sequential(refine, upsample, nn.ReLU(inplace=True))

    def forward(self, x):
        """Run ``x`` through ``self.block`` and return the result."""
        return self.block(x)


class DecoderBlockV2(nn.Module):
    """Decoder stage that upsamples by a factor of two.

    Two variants are available, selected by ``is_deconv``:

    * ``is_deconv=True``: ``ConvRelu`` -> ``ConvTranspose2d`` (kernel 4,
      stride 2, padding 1) -> ``ReLU``.  The output height/width is
      ``2 * input + output_padding``.
    * ``is_deconv=False``: bilinear ``nn.Upsample`` (scale 2) followed by two
      ``ConvRelu`` layers.

    Args:
        in_channels: channels of the incoming feature map (also stored as
            ``self.in_channels``).
        middle_channels: channels after the first ``ConvRelu``.
        out_channels: channels of the block output.
        is_deconv: choose the transposed-convolution variant when true,
            the upsample variant otherwise.
        output_padding: forwarded to ``nn.ConvTranspose2d``.  It has no
            effect when ``is_deconv`` is false.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True,
                 output_padding=0):
        super().__init__()
        self.in_channels = in_channels

        if is_deconv:
            # Parameters for the deconvolution were chosen to avoid
            # checkerboard artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1, output_padding=output_padding),
                nn.ReLU(inplace=True),
            )
        else:
            # align_corners=False is what PyTorch already applies when the
            # argument is omitted; spelling it out keeps the result identical
            # while avoiding the "default upsampling behavior" UserWarning
            # that some PyTorch releases emit on every forward pass.
            self.block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        """Run ``x`` through the selected layer stack and return the result."""
        return self.block(x)

class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with optional padding.

    After resizing, ``output_padding`` rows/columns of zeros can be appended
    to the bottom and right edges of the result.

    Args:
        mode: interpolation algorithm passed to ``interpolate``.
        scale_factor: spatial multiplier passed to ``interpolate``.
        align_corners: passed to ``interpolate`` for the modes that accept it
            (``linear``, ``bilinear``, ``bicubic``, ``trilinear``); ignored
            for all other modes, which reject the argument.
        output_padding: number of zero rows/columns added at the bottom and
            right after interpolation; ``0`` disables padding.
    """

    # Modes for which nn.functional.interpolate accepts ``align_corners``.
    # 'bicubic' was previously missing, so align_corners was silently dropped
    # for bicubic interpolation.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, mode='nearest', scale_factor=2,
                 align_corners=False, output_padding=0):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.pad = output_padding

    def forward(self, x):
        """Resize ``x`` and, if configured, zero-pad its bottom/right edges."""
        if self.mode in self._ALIGN_CORNERS_MODES:
            x = self.interp(x, mode=self.mode,
                            scale_factor=self.scale_factor,
                            align_corners=self.align_corners)
        else:
            # Passing align_corners with these modes raises in PyTorch, so it
            # is deliberately left out.
            x = self.interp(x, mode=self.mode,
                            scale_factor=self.scale_factor)

        if self.pad > 0:
            # Padding tuple is (left, right, top, bottom).
            x = nn.ZeroPad2d((0, self.pad, 0, self.pad))(x)
        return x

class DecoderBlockV3(nn.Module):
    """Decoder stage that upsamples first and convolves afterwards.

    Two variants are available, selected by ``is_deconv``:

    * ``is_deconv=True``: ``ConvTranspose2d`` (kernel 4, stride 2, padding 1,
      no activation of its own) -> ``ConvRelu``.
    * ``is_deconv=False``: nearest-neighbour ``Interpolate`` (scale 2)
      followed by two ``ConvRelu`` layers.

    Args:
        in_channels: channels of the incoming feature map (also stored as
            ``self.in_channels``).
        middle_channels: channels after the first layer that changes the
            channel count.
        out_channels: channels of the block output.
        is_deconv: choose the transposed-convolution variant when true,
            the interpolation variant otherwise.
        output_padding: forwarded to ``nn.ConvTranspose2d`` or to
            ``Interpolate``, depending on the variant.
    """

    def __init__(self, in_channels, middle_channels, out_channels,
                 is_deconv=True, output_padding=0):
        super().__init__()
        self.in_channels = in_channels

        if is_deconv:
            # Parameters for the deconvolution were chosen to avoid
            # checkerboard artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            self.block = nn.Sequential(
                nn.ConvTranspose2d(in_channels, middle_channels, kernel_size=4, stride=2,
                                   padding=1, output_padding=output_padding),
                ConvRelu(middle_channels, out_channels),
            )
        else:
            self.block = nn.Sequential(
                Interpolate(mode='nearest', scale_factor=2,
                            output_padding=output_padding),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        """Run ``x`` through the selected layer stack and return the result."""
        return self.block(x)
    

class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max-pooling and adaptive average-pooling results.

    Both pools share the same target size, so the output has twice as many
    channels as the input: the max-pooled channels come first, followed by
    the average-pooled ones.
    """

    def __init__(self, sz=None):
        """Create both pooling layers.

        Args:
            sz: target output size handed to both pools.  Any falsy value
                (``None``, ``0``, ...) selects ``(1, 1)``.
        """
        super().__init__()
        target_size = sz if sz else (1, 1)
        self.ap = nn.AdaptiveAvgPool2d(target_size)
        self.mp = nn.AdaptiveMaxPool2d(target_size)

    def forward(self, x):
        """Return ``[max_pool(x), avg_pool(x)]`` joined along dimension 1."""
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)

class Resnet(nn.Module):
    """Convolutional classifier head on a pytorchcv ``airnext50_32x4d_r2`` encoder.

    Pipeline executed by ``forward``:

    1. ``InstanceNorm2d(3)`` on the input.
    2. The encoder: the first child module of the pytorchcv model
       (presumably its feature extractor -- confirm against pytorchcv)
       followed by ``AdaptiveConcatPool2d()``.
    3. ``ReLU -> BatchNorm2d(4096) -> Dropout(0.5) -> 1x1 Conv2d(4096, 256)``.
    4. ``ReLU -> BatchNorm2d(256) -> Dropout(0.5) -> 1x1 Conv2d(256, num_classes)``.

    Args:
        num_classes: number of output channels of the final convolution
            (also stored as ``self.num_classes``).
        num_filters: unused; kept so existing call sites keep working.
        pretrained: forwarded to ``ptcv_get_model`` to decide whether
            pretrained encoder weights are loaded.
        is_deconv: unused; kept so existing call sites keep working.
    """

    # Set to True (on the class or on an instance) to print the tensor shape
    # after every stage of ``forward``; intended for debugging only.
    print_sizes = False

    def __init__(self, num_classes, num_filters=32,
                 pretrained=True, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes

        self.in0 = nn.InstanceNorm2d(3)

        # Keep only the first child of the pytorchcv model, dropping whatever
        # follows it, and honour the caller's ``pretrained`` choice (it used
        # to be hard-coded to True).
        backbone = ptcv_get_model("airnext50_32x4d_r2", pretrained=pretrained)
        layers = list(backbone.children())[:1]
        layers += [AdaptiveConcatPool2d()]
        self.encoder = nn.Sequential(*layers)

        # The head is sized for 4096 encoder channels.
        self.act1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(4096)
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = nn.Conv2d(4096, 256, kernel_size=(1, 1),
                               stride=(1, 1), padding=0)

        self.act2 = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(256)
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = nn.Conv2d(256, num_classes, kernel_size=(1, 1),
                               stride=(1, 1), padding=0)

    def forward(self, x):
        """Return the output of the final 1x1 convolution for input ``x``.

        The result has ``num_classes`` channels; because both head
        convolutions are 1x1 with stride 1 and no padding, its spatial size
        equals that of the encoder output.
        """
        verbose = self.print_sizes
        if verbose:
            print('')
            print('x', x.shape)

        x = self.in0(x)

        # Step through the encoder child by child so each intermediate shape
        # can be reported in debug mode.  Intermediate outputs are not
        # retained: only the running tensor is kept.
        for i, (name, layer) in enumerate(self.encoder._modules.items()):
            x = layer(x)
            if verbose:
                if isinstance(x, tuple):
                    print(i, name, x[0].size(), x[1].size())
                else:
                    print(i, name, x.size())
        if verbose:
            print('encoder', x.shape)

        x = self.act1(x)
        x = self.bn1(x)
        x = self.do1(x)
        x = self.conv1(x)
        if verbose:
            print('conv1', x.shape)

        x = self.act2(x)
        x = self.bn2(x)
        x = self.do2(x)
        x = self.conv2(x)
        if verbose:
            print('conv2', x.shape)

        return x
    