# In[1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

# In[2]:
def conv3x3(in_, out):
    """Build a 3x3 2-D convolution with one pixel of padding.

    Args:
        in_: number of input channels.
        out: number of output channels.

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=3`` and ``padding=1``; every
        other setting is left at the ``nn.Conv2d`` default.
    """
    layer = nn.Conv2d(in_channels=in_, out_channels=out,
                      kernel_size=3, padding=1)
    return layer


class ConvRelu(nn.Module):
    """A ``conv3x3`` layer followed by an in-place ReLU.

    Args:
        in_: number of input channels of the convolution.
        out: number of output channels of the convolution.
    """

    def __init__(self, in_, out):
        super(ConvRelu, self).__init__()
        # Attribute names are part of the state-dict keys; keep them stable.
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then the ReLU, and return the result."""
        return self.activation(self.conv(x))


class NoOperation(nn.Module):
    """Identity module: ``forward`` hands its input back untouched."""

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x


class DecoderBlock(nn.Module):
    """Decoder stage: ``ConvRelu`` -> stride-2 transposed convolution -> ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        middle_channels: channels produced by the ``ConvRelu`` stage.
        out_channels: channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super(DecoderBlock, self).__init__()

        # The order fixes the Sequential indices (0, 1, 2) used in state-dict keys.
        stages = [
            ConvRelu(in_channels, middle_channels),
            nn.ConvTranspose2d(
                middle_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential block and return the output."""
        out = self.block(x)
        return out


class DecoderBlockV2(nn.Module):
    """Decoder stage with a choice of upsampling strategy.

    Args:
        in_channels: channels of the incoming feature map (also stored on
            the instance as ``in_channels``).
        middle_channels: channels after the first ``ConvRelu``.
        out_channels: channels of the block output.
        is_deconv: when true, upsample with a stride-2 transposed
            convolution; otherwise use bilinear ``nn.Upsample`` followed by
            two ``ConvRelu`` layers.
        output_padding: forwarded to ``nn.ConvTranspose2d``; it is not used
            by the bilinear branch.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True,
                 output_padding=0):
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            stages = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            # Parameters for the deconvolution were chosen to avoid
            # artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            stages = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(
                    middle_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    output_padding=output_padding,
                ),
                nn.ReLU(inplace=True),
            ]

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the selected upsampling block."""
        out = self.block(x)
        return out

class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with optional padding.

    Args:
        mode: interpolation mode passed to ``nn.functional.interpolate``.
        scale_factor: scale factor passed to ``nn.functional.interpolate``.
        align_corners: forwarded only for the modes listed in
            ``_ALIGN_CORNERS_MODES``; other modes are called without it.
        output_padding: when greater than zero, that many zeros are appended
            on the right and bottom of the interpolated tensor.
    """

    # Modes for which ``align_corners`` is forwarded. 'bicubic' was missing
    # before, so ``align_corners=True`` was silently dropped for that mode.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, mode='nearest', scale_factor=2,
                 align_corners=False, output_padding=0):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        self.pad = output_padding

    def forward(self, x):
        """Interpolate ``x`` and, if requested, zero-pad its right/bottom edges."""
        if self.mode in self._ALIGN_CORNERS_MODES:
            x = self.interp(x, mode=self.mode,
                            scale_factor=self.scale_factor,
                            align_corners=self.align_corners)
        else:
            # Passing align_corners together with the remaining modes is
            # not done here; they are called with mode and scale only.
            x = self.interp(x, mode=self.mode,
                            scale_factor=self.scale_factor)

        if self.pad > 0:
            # Padding tuple is (left, right, top, bottom).
            x = nn.ZeroPad2d((0, self.pad, 0, self.pad))(x)
        return x

class DecoderBlockV3(nn.Module):
    """Decoder stage that upsamples first and convolves afterwards.

    Args:
        in_channels: channels of the incoming feature map (also stored on
            the instance as ``in_channels``).
        middle_channels: channels after the first layer that changes the
            channel count.
        out_channels: channels of the block output.
        is_deconv: when true, a stride-2 transposed convolution is followed
            by one ``ConvRelu``; otherwise nearest-neighbour ``Interpolate``
            is followed by two ``ConvRelu`` layers.
        output_padding: forwarded to ``nn.ConvTranspose2d`` or to
            ``Interpolate``, depending on the branch.
    """

    def __init__(self, in_channels, middle_channels, out_channels,
                 is_deconv=True, output_padding=0):
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            # Alternative upsampler: nn.Upsample(scale_factor=2, mode='bilinear')
            stages = [
                Interpolate(mode='nearest', scale_factor=2,
                            output_padding=output_padding),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            # Parameters for the deconvolution were chosen to avoid
            # artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            stages = [
                nn.ConvTranspose2d(
                    in_channels,
                    middle_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    output_padding=output_padding,
                ),
                ConvRelu(middle_channels, out_channels),
            ]

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the selected upsampling block."""
        out = self.block(x)
        return out



class SE_Resnext(nn.Module):
    """Classifier built on the ``se_resnext50_32x4d`` encoder.

    A 1x1 convolution maps 4 input channels to 3, the encoder's child
    modules are applied in registration order up to and including
    ``avg_pool``, and a single linear layer maps the flattened pooled
    features (``csize`` values per sample) to ``num_classes`` outputs.

    Args:
        num_classes: size of the output of the final linear layer.
        num_filters: accepted for interface compatibility; not used.
        pretrained: when truthy the encoder is created with ImageNet
            weights, otherwise it is created without pretrained weights.
        is_deconv: accepted for interface compatibility; not used.
    """

    # Set to True (on the class or on an instance) to print the tensor
    # size after every stage of ``forward`` while debugging.
    print_sizes = False

    def __init__(self, num_classes, num_filters=32,
                 pretrained=True, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes

        self.conv4to3 = nn.Conv2d(4, 3, 1)

        # Honour the ``pretrained`` flag; it used to be ignored and the
        # ImageNet weights were always requested.
        weights = 'imagenet' if pretrained else None
        self.encoder = pretrainedmodels.__dict__['se_resnext50_32x4d'](
            num_classes=1000, pretrained=weights)

        self.csize = 2048 * 1 * 1
        self.fc1 = nn.Linear(self.csize, num_classes)

    def _print_size(self, label, value):
        """Print ``label`` and the size(s) of ``value`` when ``print_sizes`` is set."""
        if not self.print_sizes:
            return
        if isinstance(value, tuple):
            print(label, *(item.size() for item in value))
        else:
            print(label, value.size())

    def forward(self, x):
        """Return the ``fc1`` output computed from the encoder's pooled features.

        Raises:
            KeyError: if the encoder has no child module named ``avg_pool``.
            RuntimeError: if the pooled features of a sample do not contain
                exactly ``csize`` values.
        """
        if self.print_sizes:
            print('')
        self._print_size('x', x)

        x = self.conv4to3(x)
        self._print_size('4to3', x)

        pooled = None
        for index, (name, layer) in enumerate(self.encoder.named_children()):
            x = layer(x)
            self._print_size('{} {}'.format(index, name), x)
            if name == 'avg_pool':
                # Stop before the encoder's own head; fc1 replaces it.
                pooled = x
                break
        if pooled is None:
            raise KeyError('avg_pool')

        # Keep the batch dimension fixed so a pooled map with more than
        # ``csize`` values per sample fails loudly instead of being
        # reshaped into extra rows.
        x = pooled.view(pooled.size(0), self.csize)
        self._print_size('view', x)

        x = self.fc1(x)
        self._print_size('fc1', x)
        return x
        
