# EfficientNetV2 in TensorFlow

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers as ly
import math
from functools import partial



In [2]:
# Per-variant EfficientNetV2 architecture settings.
# Values follow the configs in
# https://github.com/leondgarse/keras_efficientnet_v2/blob/main/keras_efficientnet_v2/efficientnet_v2.py
# Six variants are defined: b0, b1, b2, s, m, l.
# Fields of each entry:
#   widths           -- stem width followed by the output width of every stage
#                       (one more entry than `depths`)
#   depths           -- number of blocks in each stage
#   strides          -- stride of the first block of each stage
#   convs            -- index into `layer_map`, selecting the block type per stage
#   output_conv_size -- channels of the final 1x1 conv before pooling

# The six-stage (b0/b1/b2/s) and seven-stage (m/l) families share these layouts.
_STRIDES_6 = (1, 2, 2, 2, 1, 2)
_CONVS_6 = (0, 1, 1, 2, 3, 3)
_STRIDES_7 = (1, 2, 2, 2, 1, 2, 1)
_CONVS_7 = (0, 1, 1, 2, 3, 3, 3)


def _variant(widths, depths, strides, convs, output_conv_size=1280):
    """Pack one variant's settings into a dict with fresh (unshared) lists."""
    return {
        "widths": list(widths),
        "depths": list(depths),
        "strides": list(strides),
        "convs": list(convs),
        "output_conv_size": output_conv_size,
    }


CONFIGS = {
    "b0": _variant([32, 16, 32, 48, 96, 112, 192], [1, 2, 2, 3, 5, 8], _STRIDES_6, _CONVS_6),
    "b1": _variant([32, 16, 32, 48, 96, 112, 192], [2, 3, 3, 4, 6, 9], _STRIDES_6, _CONVS_6),
    "b2": _variant([32, 16, 32, 56, 104, 120, 208], [2, 3, 3, 4, 6, 10], _STRIDES_6, _CONVS_6,
                   output_conv_size=1408),
    "s": _variant([24, 24, 48, 64, 128, 160, 256], [2, 4, 4, 6, 9, 15], _STRIDES_6, _CONVS_6),
    "m": _variant([24, 24, 48, 80, 160, 176, 304, 512], [3, 5, 5, 7, 14, 18, 5],
                  _STRIDES_7, _CONVS_7),
    "l": _variant([32, 32, 64, 96, 192, 224, 384, 640], [4, 7, 7, 10, 19, 25, 7],
                  _STRIDES_7, _CONVS_7),
}

In [3]:
def bn_act(x, bn=True, act=True, epsilon=1e-3, momentum=0.999):
    """Optionally apply BatchNormalization, then a swish activation, to ``x``.

    Args:
        x: input tensor.
        bn: if true, add a ``BatchNormalization`` layer.
        act: if true, add a ``tf.nn.swish`` activation layer.
        epsilon: BatchNormalization epsilon (default 1e-3, as before).
        momentum: BatchNormalization moving-average momentum (default 0.999, as
            before). NOTE(review): Keras' own default is 0.99 and reference
            EfficientNetV2 ports appear to use 0.9 -- confirm 0.999 is intended.

    Returns:
        The transformed tensor (``x`` itself if both flags are false).
    """
    if bn:
        x = ly.BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    if act:
        x = ly.Activation(tf.nn.swish)(x)
    return x

In [4]:
def SEBlock(x, c, r=24):
    """Squeeze-and-Excitation: rescale each channel of ``x`` by a learned gate.

    Args:
        x: feature map with ``c`` channels on its last axis (assumes
            channels-last data format -- TODO confirm for other configs).
        c: number of channels in ``x``.
        r: reduction ratio; the bottleneck has ``c // r`` channels, clamped to
            at least 1.

    Returns:
        ``x`` multiplied channel-wise by a sigmoid gate in (0, 1).
    """
    # Without the clamp, c < r would request a Conv2D with 0 filters, which fails.
    hidden = max(1, c // r)
    squeeze = ly.GlobalAveragePooling2D()(x)
    # Restore singleton spatial dims so the 1x1 convs (and the broadcast multiply) apply.
    squeeze = ly.Reshape((1, 1, c))(squeeze)
    ex = ly.Conv2D(hidden, (1, 1), padding='same')(squeeze)
    ex = ly.Activation(tf.nn.swish)(ex)
    ex = ly.Conv2D(c, (1, 1), padding='same')(ex)
    ex = ly.Activation(tf.nn.sigmoid)(ex)
    x = ly.multiply([x, ex])
    return x

In [59]:
def MBConv(x, n_in, n_out, expansion, ks=3, strides=1, dropout=0.1, r=24):
    """Inverted-residual MBConv block: expand -> depthwise -> SE -> project.

    Args:
        x: input feature map with ``n_in`` channels (assumed channels-last).
        n_in: number of input channels.
        n_out: number of output channels.
        expansion: hidden-width multiplier; 1 skips the 1x1 expansion conv.
        ks: depthwise kernel size.
        strides: depthwise stride.
        dropout: dropout rate applied to the block output before the residual
            add (only when the residual path is used; 0 disables it).
        r: squeeze-and-excitation reduction ratio passed to ``SEBlock``.

    Returns:
        Output tensor with ``n_out`` channels.
    """
    residual = x
    hidden = expansion * n_in
    # Identity shortcut only when spatial size and channel count are preserved.
    use_residual = strides == 1 and n_in == n_out
    if expansion != 1:
        # Expand pointwise to the hidden width.
        x = ly.Conv2D(hidden, kernel_size=1, padding='same', use_bias=False,
                      activation=None)(x)
        x = bn_act(x)
    # Depthwise spatial conv (carries the block's stride).
    x = ly.DepthwiseConv2D(kernel_size=ks, strides=strides, activation=None, use_bias=False,
                           padding='same')(x)
    x = bn_act(x)
    x = SEBlock(x, hidden, r=r)
    # Linear 1x1 projection to n_out (BN, no activation).
    x = ly.Conv2D(n_out, (1, 1), padding='same', activation=None, use_bias=False)(x)
    x = bn_act(x, act=False)
    if use_residual:
        if dropout:
            x = ly.Dropout(dropout)(x)
        x = ly.add([x, residual])
    return x

In [65]:
def FusedMBConv(x, n_in, n_out, expansion, ks=3, strides=1, dropout=0.1, r=24):
    """
    FusedMBConv block from the EfficientNetV2 paper.

    The 1x1 expansion + depthwise pair of MBConv is replaced by one regular
    ``ks`` x ``ks`` convolution. No squeeze-and-excitation is applied here.

    Args:
        x: input feature map with ``n_in`` channels (assumed channels-last).
        n_in: number of input channels.
        n_out: number of output channels.
        expansion: width multiplier for the fused conv; 1 means a single
            ``ks`` x ``ks`` conv straight from ``n_in`` to ``n_out``.
        ks: kernel size of the fused (spatial) convolution.
        strides: stride of the fused convolution.
        dropout: dropout rate applied to the block output before the residual
            add (only when the residual path is used; 0 disables it).
        r: unused; accepted so FusedMBConv and MBConv share one call signature.

    Returns:
        Output tensor with ``n_out`` channels.
    """
    residual = x
    # Identity shortcut only when spatial size and channel count are preserved.
    use_residual = strides == 1 and n_in == n_out
    if expansion != 1:
        # Fused expand: a full ks x ks conv (carrying the stride) widens to expansion * n_in.
        x = ly.Conv2D(expansion * n_in, kernel_size=ks, strides=strides, padding='same',
                      use_bias=False, activation=None)(x)
        x = bn_act(x)
        # Linear 1x1 projection to n_out (BN, no activation).
        x = ly.Conv2D(n_out, (1, 1), padding='same', activation=None, use_bias=False)(x)
        x = bn_act(x, act=False)
    else:
        # No expansion: this ks x ks conv is the block's only conv, so it gets BN +
        # activation; with BN alone every e=1 block (and thus a stack of them) is affine.
        x = ly.Conv2D(n_out, kernel_size=ks, padding='same', strides=strides,
                      activation=None, use_bias=False)(x)
        x = bn_act(x)
    if use_residual:
        if dropout:
            x = ly.Dropout(dropout)(x)
        x = ly.add([x, residual])
    return x

In [66]:
# Block constructors indexed by the `convs` entries of CONFIGS:
#   0 -> FusedMBConv, expansion 1
#   1 -> FusedMBConv, expansion 4
#   2 -> MBConv,      expansion 4
#   3 -> MBConv,      expansion 6
layer_map = [
    partial(block_fn, expansion=ratio)
    for block_fn, ratio in (
        (FusedMBConv, 1),
        (FusedMBConv, 4),
        (MBConv, 4),
        (MBConv, 6),
    )
]

In [67]:
def EfficientNetV2(cfg, n_classes=1000, input_shape=(224, 224, 3), drop_rate=0.2):
    """Build an EfficientNetV2 Keras model from a CONFIGS-style dict.

    Args:
        cfg: dict with 'widths' (stem width followed by one output width per
            stage), 'depths', 'strides' and 'convs' (one entry per stage;
            'convs' indexes ``layer_map``) plus 'output_conv_size'.
        n_classes: number of units in the final Dense layer. The outputs are raw
            logits; no softmax is applied.
        input_shape: shape of a single input image (default 224x224x3).
        drop_rate: dropout rate passed to every block. Blocks apply it only on
            their residual path.

    Returns:
        A ``keras.Model`` named "efficientnetv2".
    """
    widths, depths, strides, convs = cfg['widths'], cfg['depths'], cfg['strides'], cfg['convs']
    outconv_size = cfg['output_conv_size']

    inputs = keras.Input(shape=input_shape)
    # Stem: pad by 1, then a 'valid' 3x3 stride-2 conv.
    x = ly.ZeroPadding2D(padding=1)(inputs)
    x = ly.Conv2D(widths[0], (3, 3), strides=(2, 2), padding='valid', use_bias=False)(x)
    x = bn_act(x)

    for i, depth in enumerate(depths):
        block = layer_map[convs[i]]
        w_in, w_out = widths[i], widths[i + 1]
        # SE ratio is only consumed by MBConv; FusedMBConv accepts and ignores it.
        se_r = 4 if i == 0 else 24
        # Only the first block of a stage changes stride and width.
        x = block(x, w_in, w_out, ks=3, strides=strides[i], dropout=drop_rate, r=se_r)
        for _ in range(1, depth):
            x = block(x, w_out, w_out, ks=3, dropout=drop_rate, r=se_r)

    # Head: 1x1 conv -> BN + swish -> global pooling -> linear classifier.
    x = ly.Conv2D(outconv_size, kernel_size=1, use_bias=False)(x)
    x = bn_act(x)
    x = ly.GlobalAveragePooling2D()(x)
    x = ly.Dense(n_classes)(x)
    return keras.Model(inputs=inputs, outputs=x, name="efficientnetv2")

In [68]:
def _build_variant(key, n_classes):
    """Build the EfficientNetV2 variant whose settings live in CONFIGS[key]."""
    return EfficientNetV2(CONFIGS[key], n_classes=n_classes)


def efficientnetv2_b0(n_classes=1000):
    """EfficientNetV2-B0 with an ``n_classes``-unit output layer."""
    return _build_variant('b0', n_classes)


def efficientnetv2_b1(n_classes=1000):
    """EfficientNetV2-B1 with an ``n_classes``-unit output layer."""
    return _build_variant('b1', n_classes)


def efficientnetv2_b2(n_classes=1000):
    """EfficientNetV2-B2 with an ``n_classes``-unit output layer."""
    return _build_variant('b2', n_classes)


def efficientnetv2_s(n_classes=1000):
    """EfficientNetV2-S with an ``n_classes``-unit output layer."""
    return _build_variant('s', n_classes)


def efficientnetv2_m(n_classes=1000):
    """EfficientNetV2-M with an ``n_classes``-unit output layer."""
    return _build_variant('m', n_classes)


def efficientnetv2_l(n_classes=1000):
    """EfficientNetV2-L with an ``n_classes``-unit output layer."""
    return _build_variant('l', n_classes)

In [69]:
# Build every variant once with the default 1000-class head (built in the order listed).
v2_b0, v2_b1, v2_b2, v2_s, v2_m, v2_l = (
    efficientnetv2_b0(),
    efficientnetv2_b1(),
    efficientnetv2_b2(),
    efficientnetv2_s(),
    efficientnetv2_m(),
    efficientnetv2_l(),
)

In [73]:
# All built variants, smallest config first: b0, b1, b2, s, m, l.
models = [
    v2_b0, v2_b1, v2_b2,
    v2_s, v2_m, v2_l,
]

In [71]:
def fmat(n):
    """Render ``n`` in millions with two decimals, e.g. 7_190_000 -> '7.19M'."""
    millions = n / 1e6
    return f"{millions:.2f}M"

In [72]:
def params(model, f = True):
    """Count the scalar entries across every variable in ``model.variables``.

    For a Keras model this presumably includes non-trainable variables such as
    BatchNorm statistics -- verify if only trainable weights are wanted.

    Returns:
        The count formatted by ``fmat`` (e.g. '7.19M') when ``f`` is true,
        otherwise the raw integer count.
    """
    total = sum(int(np.prod(var.shape)) for var in model.variables)
    if not f:
        return total
    return fmat(total)

In [74]:
# One formatted parameter count per variant, in `models` order.
for net in models:
    print(params(net))

7.19M
8.19M
10.15M
21.55M
54.32M
118.80M
