In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras import layers

In [22]:
class ResNeXtBlock(layers.Layer):
    """ResNeXt-style bottleneck residual block.

    The main path is 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN. Its result is added to a shortcut branch (the input
    itself, or a 1x1-projected copy of it) and passed through a final ReLU.

    Args:
        num_channels: Number of output channels of the block.
        groups: Divisor used to derive the grouped convolution's group count:
            the 3x3 convolution is built with ``bot_channels // groups``
            groups. In other words this value behaves like a per-group
            channel width rather than the number of groups itself.
        bot_mul: Bottleneck multiplier; the two inner convolutions use
            ``int(round(num_channels * bot_mul))`` channels.
        use_1x1conv: If True, the shortcut branch is a 1x1 convolution
            (with the same ``strides``) followed by batch normalization.
            Required whenever the main path changes the tensor shape, i.e.
            when ``strides != 1`` or the input channel count differs from
            ``num_channels``.
        strides: Stride of the 3x3 convolution and of the shortcut projection.
    """

    def __init__(self, num_channels, groups, bot_mul, use_1x1conv=False, strides=1):
        super().__init__()
        bot_channels = int(round(num_channels * bot_mul))

        # Main path: 1x1 reduce -> grouped 3x3 -> 1x1 expand.
        self.conv1 = layers.Conv2D(bot_channels, kernel_size=1, strides=1)
        self.conv2 = layers.Conv2D(
            bot_channels,
            kernel_size=3,
            strides=strides,
            padding='same',
            # Conv2D's `groups` is the NUMBER of groups; the constructor
            # argument is divided into the bottleneck width to obtain it.
            groups=bot_channels // groups,
        )
        self.conv3 = layers.Conv2D(num_channels, kernel_size=1, strides=1)
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.bn3 = layers.BatchNormalization()

        # Shortcut path: optional projection so the skip connection matches
        # the main path's spatial size and channel count.
        if use_1x1conv:
            self.conv4 = layers.Conv2D(num_channels, kernel_size=1, strides=strides)
            self.bn4 = layers.BatchNormalization()
        else:
            self.conv4 = None
            self.bn4 = None

    def call(self, inp, training=None):
        """Apply the block to ``inp``.

        Args:
            inp: Input feature map (assumed channels-last, matching the
                Conv2D default data format -- confirm against the Keras
                config in use).
            training: Forwarded to the batch-normalization layers. ``None``
                defers to the mode Keras resolves from the calling context.

        Returns:
            ``relu(main_path(inp) + shortcut(inp))``.
        """
        x = tf.nn.relu(self.bn1(self.conv1(inp), training=training))
        x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
        x = self.bn3(self.conv3(x), training=training)

        # The projection belongs on the skip connection (the block input),
        # not on the main-path tensor; otherwise strided blocks downsample
        # the main path twice and the addition below fails on shape.
        shortcut = inp
        if self.conv4 is not None:
            shortcut = self.bn4(self.conv4(inp), training=training)

        return tf.nn.relu(x + shortcut)

In [23]:
# Smoke test: build a stride-1 block with a projection shortcut, feed it a
# random batch, and print the output shape.
blk = ResNeXtBlock(num_channels=32, groups=16, bot_mul=1, use_1x1conv=True)
X = tf.random.normal(shape=(4, 96, 96, 32))
result = blk(X)
print(result.shape)

(4, 96, 96, 32)
