In [1]:
from utils import efficientnet
blocks_args, global_params = efficientnet(width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.7, image_size=[263, 15])

In [3]:
print(global_params)
print()
for idx, block_args in enumerate(blocks_args):
    print(f'MBConvblock {idx}: ', block_args)
    print()

GlobalParams(width_coefficient=1.0, depth_coefficient=1.0, image_size=[263, 15], dropout_rate=0.7, num_classes=1, batch_norm_momentum=0.99, batch_norm_epsilon=0.001, drop_connect_rate=0.2, depth_divisor=8, min_depth=None, include_top=None)

MBConvblock 0:  BlockArgs(num_repeat=1, kernel_size=(5, 1), stride=(1, 1), expand_ratio=1, input_filters=8, output_filters=16, se_ratio=0.25, id_skip=True)

MBConvblock 1:  BlockArgs(num_repeat=1, kernel_size=(3, 1), stride=(1, 1), expand_ratio=1, input_filters=16, output_filters=32, se_ratio=0.25, id_skip=True)

MBConvblock 2:  BlockArgs(num_repeat=1, kernel_size=(3, 3), stride=(1, 1), expand_ratio=1, input_filters=32, output_filters=64, se_ratio=0.25, id_skip=True)



In [11]:
import math
from torch import nn
from torch.nn import functional as F

class MaxPool2dStaticSamePadding_(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
    The padding mudule is calculated in construction function, then used in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        # self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = (
            [self.kernel_size] * 2
            if isinstance(self.kernel_size, int)
            else self.kernel_size
        )
        self.dilation = (
            [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
        )

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.kernel_size
        sh, sw = (
            (self.stride, self.stride) if isinstance(self.stride, int) else self.stride
        )
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * sh + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
        return x

In [12]:
from utils import Conv2dStaticSamePadding, MaxPool2dStaticSamePadding
from functools import partial

bn_mom = 1 - global_params.batch_norm_momentum
bn_eps = global_params.batch_norm_epsilon
image_size = global_params.image_size
Conv2d = partial(Conv2dStaticSamePadding, image_size=image_size)
MaxPool2d = partial(MaxPool2dStaticSamePadding_, image_size=image_size)

In [13]:
from torch import nn
from utils import round_filters, calculate_output_image_size, MemoryEfficientSwish

# Stem
image_size = [263, 15]
in_channels = 1
out_channels = round_filters(8, global_params)
_conv_stem = Conv2d(in_channels, out_channels, kernel_size=(5, 1), stride=(1, 1), bias=False, image_size=image_size)
_bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
_max_pool = MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
_swish = MemoryEfficientSwish()
image_size = calculate_output_image_size(image_size, stride=(1, 1))
image_size

[263, 15]

In [37]:
import torch

x = torch.randn(1, 1, 263, 15)
x = _conv_stem(x)
x = _bn0(x)
x = _max_pool(x)
x

tensor([[[[ 1.3046,  0.6385,  0.2950,  ...,  0.5212,  1.6209,  0.9231],
          [ 0.6213, -0.6765, -0.6158,  ...,  0.4885,  1.2506,  0.4115],
          [ 0.3602, -0.2707, -0.1748,  ...,  1.1710,  1.2817, -0.3123],
          ...,
          [ 1.1332, -0.8674, -1.1314,  ...,  0.2574,  0.4384,  0.6540],
          [-0.0487, -0.3110,  1.4396,  ..., -0.6007,  0.2222,  1.9080],
          [ 0.0000,  0.0000,  0.4522,  ...,  0.0000,  0.0000,  1.1264]],

         [[ 0.3695,  0.4974,  0.4579,  ..., -0.0723,  0.7462,  0.6954],
          [-0.0263, -0.5752, -0.7477,  ...,  0.9208, -0.1122, -0.3897],
          [ 0.6201, -0.0773, -0.1894,  ...,  1.1950,  0.0330,  0.3992],
          ...,
          [ 0.0201, -0.3949,  0.2088,  ...,  0.5793,  0.1122,  1.0368],
          [-0.9608, -0.5949,  0.9894,  ..., -0.6246, -0.0272,  1.4105],
          [ 0.0000,  0.0827,  0.3053,  ...,  0.1683,  0.0000,  0.0496]],

         [[ 0.0831,  0.3793,  0.2023,  ...,  0.0065,  0.1381,  0.6438],
          [ 0.8491,  0.2736, -

In [38]:
x = _swish(x)
x

tensor([[[[ 1.0262,  0.4179,  0.1691,  ...,  0.3270,  1.3533,  0.6606],
          [ 0.4042, -0.2280, -0.2160,  ...,  0.3028,  0.9722,  0.2475],
          [ 0.2122, -0.1171, -0.0798,  ...,  0.8938,  1.0033, -0.1320],
          ...,
          [ 0.8571, -0.2566, -0.2760,  ...,  0.1451,  0.2665,  0.4303],
          [-0.0237, -0.1315,  1.1638,  ..., -0.2128,  0.1234,  1.6614],
          [ 0.0000,  0.0000,  0.2764,  ...,  0.0000,  0.0000,  0.8506]],

         [[ 0.2185,  0.3093,  0.2805,  ..., -0.0349,  0.5062,  0.4640],
          [-0.0130, -0.2071, -0.2403,  ...,  0.6585, -0.0530, -0.1573],
          [ 0.4032, -0.0372, -0.0858,  ...,  0.9173,  0.0168,  0.2389],
          ...,
          [ 0.0101, -0.1590,  0.1153,  ...,  0.3713,  0.0592,  0.7654],
          [-0.2659, -0.2115,  0.7212,  ..., -0.2178, -0.0134,  1.1338],
          [ 0.0000,  0.0431,  0.1757,  ...,  0.0912,  0.0000,  0.0254]],

         [[ 0.0433,  0.2252,  0.1113,  ...,  0.0033,  0.0738,  0.4221],
          [ 0.5947,  0.1554, -