In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras import layers

In [4]:
class ResNeXtBlock(layers.Layer):
    """ResNeXt bottleneck residual block.

    Main path: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN. The result is added to the shortcut and passed
    through a final ReLU.

    Args:
        num_channels: Number of output channels of the block.
        groups: Channels per group of the 3x3 convolution. Keras receives
            ``bot_channels // groups`` as its group count, so the bottleneck
            width must be divisible by that count.
        bot_mul: Bottleneck multiplier; the bottleneck width is
            ``int(round(num_channels * bot_mul))``.
        strides: Stride of the 3x3 convolution and of the projection
            shortcut (when one is used).
        use_1x1conv: If True, the shortcut is a 1x1 conv + BN projection
            instead of the identity. A projection is also used automatically
            when ``strides != 1``, because the identity shortcut cannot match
            the downsampled main path. Pass True as well when the input
            channel count differs from ``num_channels``.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_channels, groups, bot_mul, strides=1,
                 use_1x1conv=False, **kwargs):
        super().__init__(**kwargs)
        bot_channels = int(round(num_channels * bot_mul))
        self.conv1 = layers.Conv2D(bot_channels, kernel_size=1, strides=1)
        # Keras Conv2D only accepts 'valid'/'same' padding; 'same' keeps the
        # spatial size for stride 1, which is what a 3x3 kernel with
        # one pixel of padding does.
        self.conv2 = layers.Conv2D(bot_channels, kernel_size=3, strides=strides,
                                   padding='same', groups=bot_channels // groups)
        self.conv3 = layers.Conv2D(num_channels, kernel_size=1, strides=1)
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.bn3 = layers.BatchNormalization()
        # Projection shortcut: needed whenever the main path changes the
        # spatial size (strides != 1) or when explicitly requested
        # (e.g. channel count changes).
        if use_1x1conv or strides != 1:
            self.conv4 = layers.Conv2D(num_channels, kernel_size=1, strides=strides)
            self.bn4 = layers.BatchNormalization()
        else:
            self.conv4 = None
            self.bn4 = None

    def call(self, inp, training=None):
        """Apply the block to ``inp`` and return the activated residual sum.

        ``training`` is forwarded to every BatchNormalization layer so that
        batch statistics are used during training and moving averages at
        inference.
        """
        x = self.conv1(inp)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv3(x)
        x = self.bn3(x, training=training)
        shortcut = inp
        if self.conv4 is not None:
            shortcut = self.bn4(self.conv4(inp), training=training)
        return tf.nn.relu(x + shortcut)