In [13]:
import torch
from torch import nn
from torch.nn import functional as F
import math

In [3]:
class SwishImplementation(torch.autograd.Function):
    """Swish / SiLU activation, ``i * sigmoid(i)``, with a hand-written backward.

    Only the input ``i`` is saved for the backward pass and ``sigmoid(i)`` is
    recomputed there, instead of letting autograd store the intermediates of
    ``i * sigmoid(i)``.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; the old `saved_variables`
        # alias is deprecated and missing from newer PyTorch releases.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    """``nn.Module`` wrapper so ``SwishImplementation`` can be used as a layer."""

    def forward(self, x):
        return SwishImplementation.apply(x)

In [4]:
def stochastic_depth(inputs, skip_probability, training):
    """Randomly drop whole samples of a residual branch (stochastic depth).

    While training, each sample along dim 0 is kept with probability
    ``1 - skip_probability``. Kept samples are rescaled by
    ``1 / (1 - skip_probability)`` so the expected value equals ``inputs``;
    dropped samples become zero. Outside training ``inputs`` is returned
    unchanged.

    Args:
        inputs: tensor whose first dimension is the batch (any rank >= 1).
        skip_probability: drop probability in [0, 1]; 0 or None disables dropping.
        training: whether the calling module is in training mode.

    Returns:
        Tensor with the same shape, dtype and device as ``inputs``.

    Raises:
        ValueError: if ``skip_probability`` is outside [0, 1] while training.
    """
    if not training or not skip_probability:
        return inputs
    if not 0 <= skip_probability <= 1:
        raise ValueError(f'skip_probability must be in [0, 1], got {skip_probability}')
    if skip_probability == 1:
        # Every sample is dropped; avoids the 0/0 -> NaN of the rescaling below.
        return torch.zeros_like(inputs)
    keep_prob = 1 - skip_probability
    # One draw per sample, broadcast over all remaining dims:
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask_shape = (inputs.shape[0],) + (1,) * (inputs.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)
    return inputs / keep_prob * binary_tensor

In [5]:
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation gate.

    Global-average-pools the input, passes it through a 1x1 reduce conv, the
    activation and a 1x1 expand conv, and returns per-channel gates in (0, 1)
    of shape ``(N, channel, 1, 1)``. The caller multiplies the feature map by
    the returned gate.

    Args:
        channel: number of channels of the input feature map.
        se_ratio: fraction of ``channel`` kept in the bottleneck (at least 1 channel).
        activation: module applied between the reduce and expand convs.
            When None, defaults to ``nn.SiLU()`` (swish). Previously None made
            ``forward`` fail by calling ``None``.
    """

    def __init__(self, channel, se_ratio, activation=None):
        super().__init__()
        self.squeezed_channel = max(1, int(channel * se_ratio))
        self.se_reduce = nn.Conv2d(channel, self.squeezed_channel, 1)
        self.se_expand = nn.Conv2d(self.squeezed_channel, channel, 1)
        self.activation = activation if activation is not None else nn.SiLU()

    def forward(self, x):
        x_squeezed = F.adaptive_avg_pool2d(x, 1)
        x_squeezed = self.se_expand(self.activation(self.se_reduce(x_squeezed)))
        return torch.sigmoid(x_squeezed)

In [6]:
class MBConvBasicBlc(nn.Module):
    """A single MBConv (inverted-residual) block.

    Layout:
        optional 1x1 expansion conv -> BN -> swish   (skipped when expand_ratio == 1)
        depthwise kxk conv -> BN -> swish
        optional squeeze-and-excitation gating
        1x1 projection conv -> BN                    (linear: no activation)
        identity skip connection, with stochastic depth, when the output
        shape equals the input shape

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the projection conv.
        expand_ratio: channel expansion factor for the hidden width.
        stride: stride of the depthwise conv.
        kernel: kernel size of the depthwise conv.
        se_ratio: SE squeeze ratio; SE is used only when 0 < se_ratio < 1.
        skip_probability: stochastic-depth drop probability for the residual
            branch; a falsy value disables it.
        bn_momentum: running-average decay for BatchNorm (e.g. 0.99); passed to
            PyTorch as ``1 - bn_momentum``.
        bn_epsilon: BatchNorm epsilon.
    """

    def __init__(self, in_channel, out_channel, expand_ratio, stride, kernel,
                 se_ratio, skip_probability, bn_momentum, bn_epsilon):
        super().__init__()
        self.expand_option = (expand_ratio != 1)
        med_channel = in_channel * expand_ratio
        self.activation = MemoryEfficientSwish()
        self.skip_probability = skip_probability
        # PyTorch's BatchNorm momentum weights the *new* batch statistic,
        # so a decay of 0.99 becomes momentum 0.01.
        bn_kwargs = dict(momentum=(1 - bn_momentum), eps=bn_epsilon)
        if self.expand_option:
            self.expand = nn.Conv2d(in_channel, med_channel, 1)
            self.bn_expand = nn.BatchNorm2d(num_features=med_channel, **bn_kwargs)
        # For odd kernels with stride 1 or 2 this padding yields an output size
        # of ceil(H / stride) ("same"-style padding).
        self.depth_wise = nn.Conv2d(med_channel, med_channel, kernel, stride=stride,
                                    padding=math.ceil((kernel - stride) / 2),
                                    groups=med_channel)
        self.bn_depth_wise = nn.BatchNorm2d(num_features=med_channel, **bn_kwargs)

        if (se_ratio is not None) and (0 < se_ratio < 1):
            # NOTE(review): the squeeze width is derived from the expanded width
            # (med_channel * se_ratio); reference EfficientNet code appears to use
            # in_channel * se_ratio instead -- confirm. Kept to preserve weight shapes.
            self.se_operation = SqueezeExcitation(med_channel, se_ratio, self.activation)
        else:
            self.se_operation = None
        self.real_out = nn.Conv2d(med_channel, out_channel, 1)
        self.bn_out = nn.BatchNorm2d(num_features=out_channel, **bn_kwargs)

    def forward(self, inputs):
        x = inputs

        if self.expand_option:
            x = self.activation(self.bn_expand(self.expand(x)))

        x = self.activation(self.bn_depth_wise(self.depth_wise(x)))

        if self.se_operation is not None:
            # SE returns per-channel gates; apply them to the feature map.
            x = self.se_operation(x) * x

        # Linear bottleneck: the projection is NOT followed by an activation.
        # Applying swish here (as before) would squash the residual branch
        # before it is added to the identity path.
        x = self.bn_out(self.real_out(x))

        # Residual only when stride == 1 and in/out channels match (same shape).
        if x.shape == inputs.shape:
            if self.skip_probability:
                x = stochastic_depth(x, self.skip_probability, training=self.training)
            x = x + inputs
        return x

In [7]:
class MBConvBlc(nn.Module):
    """One EfficientNet stage: ``n_repeat`` MBConvBasicBlc blocks applied in sequence.

    Only the first block may change resolution and width: it maps
    ``in_channel -> out_channel`` with the given ``stride``. Every following
    block maps ``out_channel -> out_channel`` with stride 1. All blocks share
    the remaining hyper-parameters.
    """

    def __init__(self, in_channel, out_channel, expand_ratio, stride, kernel,
                 se_ratio, skip_probability, bn_momentum, bn_epsilon, n_repeat):
        super().__init__()
        self.blocks = nn.ModuleList()
        for position in range(n_repeat):
            is_first = position == 0
            self.blocks.append(MBConvBasicBlc(
                in_channel=in_channel if is_first else out_channel,
                out_channel=out_channel,
                expand_ratio=expand_ratio,
                stride=stride if is_first else 1,
                kernel=kernel,
                se_ratio=se_ratio,
                skip_probability=skip_probability,
                bn_momentum=bn_momentum,
                bn_epsilon=bn_epsilon,
            ))

    def forward(self, x):
        features = x
        for block in self.blocks:
            features = block(features)
        return features

In [11]:
class EfficientNet(nn.Module):
    """EfficientNet-style network producing one output value per sample.

    Architecture:
        stem:   3x3 stride-2 conv (3 -> 32 channels) -> BN -> swish
        body:   one MBConvBlc stage per entry of ``blocks_args``
        head:   1x1 conv (-> 1280 channels) -> BN -> swish
        output: global average pool -> dropout(0.2) -> Linear(1280, 256) -> swish
                -> dropout(0.2) -> Linear(256, 1)

    Args:
        blocks_args (iterable of dict): keyword arguments for each MBConvBlc
            stage, in order. The first stage's ``in_channel`` must be 32 (the
            stem width). Each later stage's ``in_channel`` must equal the
            previous stage's ``out_channel``. The last stage's ``out_channel``
            sets the head conv's input width.

    Raises:
        TypeError: if ``blocks_args`` is None.
        ValueError: if ``blocks_args`` is empty.
    """

    def __init__(self, blocks_args=None):
        super().__init__()
        if blocks_args is None:
            raise TypeError('blocks_args must be a list of MBConvBlc keyword-argument dicts, got None')
        self._blocks_args = blocks_args

        # PyTorch BN momentum weights the new batch statistic: 1 - decay(0.99).
        bn_mom = 1 - 0.99
        bn_eps = 0.001

        # Stem
        in_channels = 3  # rgb
        out_channels = 32
        self._conv_stem = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2,
                                    padding=1)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Body: one MBConvBlc per stage specification.
        self._blocks = nn.ModuleList([])
        last_args = None
        for last_args in self._blocks_args:
            self._blocks.append(MBConvBlc(**last_args))
        if last_args is None:
            # Previously this surfaced as an UnboundLocalError on the leaked loop variable.
            raise ValueError('blocks_args must contain at least one block specification')

        # Head: its input width is the output width of the final stage.
        in_channels = last_args['out_channel']
        out_channels = 1280
        self._conv_head = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Pooling, hidden fully-connected layer and final linear output.
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(0.2)
        self._fc = nn.Linear(out_channels, 256)

        self.output = nn.Linear(256, 1)

        self._swish = MemoryEfficientSwish()

    def forward(self, inputs):
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        # MBConv stages
        for block in self._blocks:
            x = block(x)
        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        # Global pooling to (batch, channels), then the fully-connected layers.
        x = self._avg_pooling(x).flatten(1)
        x = self._dropout(x)
        x = self._swish(self._fc(x))
        x = self._dropout(x)
        return self.output(x)

In [9]:
# EfficientNet-B0 stage table: one dict of MBConvBlc keyword arguments per stage.
# All stages share se_ratio=1/4, skip_probability=0.2, bn_momentum=0.99, bn_epsilon=0.001.
block_args = [
    {'kernel': kernel, 'n_repeat': repeats, 'in_channel': c_in, 'out_channel': c_out,
     'expand_ratio': expand, 'stride': step, 'se_ratio': 1/4,
     'skip_probability': 0.2, 'bn_momentum': 0.99, 'bn_epsilon': 0.001}
    for kernel, repeats, c_in, c_out, expand, step in (
        # kernel, n_repeat, in, out, expand, stride
        (3, 1, 32, 16, 1, 1),
        (3, 2, 16, 24, 6, 2),
        (5, 2, 24, 40, 6, 2),
        (3, 3, 40, 80, 6, 2),
        (5, 3, 80, 112, 6, 1),
        (5, 4, 112, 192, 6, 2),
        (3, 1, 192, 320, 6, 1),
    )
]
mb_params = block_args[0]

In [14]:
# Smoke test: push a random batch of 20 RGB 224x224 images through the
# B0-configured model and check that it yields one output per sample.
torch.manual_seed(0)  # reproducible input and weight initialisation
x = torch.randn(20, 3, 224, 224)
expand = EfficientNet(block_args)
# eval(): disable dropout / stochastic depth and stop BN from updating its
# running statistics with random data during this check.
expand.eval()
with torch.no_grad():  # no autograd graph needed for a shape check
    expand_output = expand(x)
assert expand_output.shape == (20, 1)
expand_output.shape

torch.Size([20, 1])