In [1]:
!pip install pytorchvideo torch torchvision --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.7/132.7 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.7/38.7 MB[0m [31m45.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pytorchvideo (setup.py) ... [?25l[?25hdone
  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Building wheel for iopath (setup.py) ... [?25l[?25hdone


In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models

###############################################################################
# 1. The Unit3D Block
#
# In the original Sonnet code, Unit3D applies a 3D convolution followed by
# optional batch normalization and an activation. In Keras we implement it as a
# custom layer that uses tf.keras.layers.Conv3D, BatchNormalization, and Activation.
###############################################################################

class Unit3D(layers.Layer):
    """3D convolution, then optional batch normalization, then optional activation.

    The convolution always uses 'same' padding. Batch normalization runs in
    training or inference mode according to the `training` flag given to
    `call`; the constructor's `is_training` argument is accepted but ignored.
    """

    def __init__(self,
                 output_channels,
                 kernel_size=(1, 1, 1),
                 strides=(1, 1, 1),
                 activation='relu',
                 use_batch_norm=True,
                 use_bias=False,
                 is_training=False,
                 **kwargs):
        """
        Parameters:
          output_channels: filter count of the convolution.
          kernel_size: 3D kernel size of the convolution.
          strides: stride of the convolution.
          activation: activation identifier (e.g. 'relu'), or None to skip
            the nonlinearity entirely.
          use_batch_norm: if true, a BatchNormalization layer follows the conv.
          use_bias: whether the convolution carries a bias term.
          is_training: unused here (the mode comes from `call`'s `training`).
          kwargs: forwarded to the base Layer constructor.
        """
        super().__init__(**kwargs)

        # The wrapped convolution is given the same name as this layer
        # whenever the caller supplied one (otherwise Keras picks a name).
        conv_name = kwargs.get("name")
        self.conv3d = layers.Conv3D(
            filters=output_channels,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            use_bias=use_bias,
            name=conv_name,
        )

        # `self.bn` only exists when batch normalization is requested.
        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.bn = layers.BatchNormalization()

        # A None activation means the block has no nonlinearity.
        if activation is None:
            self.activation = None
        else:
            self.activation = layers.Activation(activation)

    def call(self, inputs, training=False):
        """Run conv -> (batch norm) -> (activation) on `inputs`."""
        features = self.conv3d(inputs)
        if self.use_batch_norm:
            features = self.bn(features, training=training)
        if self.activation is None:
            return features
        return self.activation(features)

###############################################################################
# 2. The MixedLayerBranch
#
# Each branch in an Inception module applies (optionally) a convolution (using
# Unit3D) or a pooling operation followed by another convolution. In the Sonnet
# code a lambda was used for the “do nothing” branch. Here we define a custom
# layer that sets up branch A and branch B.
###############################################################################

class MixedLayerBranch(layers.Layer):
    """A single Inception branch: stage A followed by stage B.

    Stage A is a Unit3D when both `a_output_channels` and `a_kernel_size` are
    given, otherwise a 3x3x3 stride-1 'same' max pooling. Stage B is a Unit3D
    when both `b_output_channels` and `b_kernel_size` are given, otherwise a
    pass-through that returns its input unchanged.
    """

    def __init__(self,
                 a_output_channels,
                 a_kernel_size,
                 a_name,
                 b_output_channels,
                 b_kernel_size,
                 b_name,
                 is_training=False,
                 **kwargs):
        """
        Parameters:
          a_output_channels / a_kernel_size / a_name: configuration of the
            stage-A Unit3D; if either of the first two is None, stage A is
            max pooling instead.
          b_output_channels / b_kernel_size / b_name: configuration of the
            stage-B Unit3D; if either of the first two is None, stage B is
            the identity.
          is_training: stored on the instance and forwarded to Unit3D.
          kwargs: forwarded to the base Layer constructor.
        """
        super().__init__(**kwargs)
        self.is_training = is_training

        # Stage A: max pooling stands in when no convolution is configured.
        if a_output_channels is None or a_kernel_size is None:
            self.branch_a = layers.MaxPooling3D(
                pool_size=(3, 3, 3),
                strides=(1, 1, 1),
                padding='same',
                name="MaxPool3d_0a_3x3",
            )
        else:
            self.branch_a = Unit3D(
                output_channels=a_output_channels,
                kernel_size=a_kernel_size,
                strides=(1, 1, 1),
                is_training=is_training,
                name=a_name,
            )

        # Stage B: a plain identity callable when no convolution is configured.
        if b_output_channels is None or b_kernel_size is None:
            self.branch_b = lambda x, training=False: x
        else:
            self.branch_b = Unit3D(
                output_channels=b_output_channels,
                kernel_size=b_kernel_size,
                strides=(1, 1, 1),
                is_training=is_training,
                name=b_name,
            )

    def call(self, inputs, training=False):
        """Apply stage A and then stage B to `inputs`."""
        x = inputs
        for stage in (self.branch_a, self.branch_b):
            # Objects exposing `call` (Keras layers) receive the training
            # flag; anything else (the identity lambda) gets the tensor only.
            if hasattr(stage, 'call'):
                x = stage(x, training=training)
            else:
                x = stage(x)
        return x

###############################################################################
# 3. The MixedLayer (Inception Module)
#
# The MixedLayer consists of four branches (branch 0, 1, 2, and 3). Each branch
# is built using a MixedLayerBranch. The outputs of the branches are concatenated
# along the channel dimension.
###############################################################################

class MixedLayer(layers.Layer):
    """Inception module: four MixedLayerBranch outputs joined on the last axis.

    Branch 0 configures stage A only. Branches 1 and 2 configure both stages.
    Branch 3 passes None for stage A (so MixedLayerBranch falls back to max
    pooling there) and configures stage B only.
    """

    def __init__(self,
                 branch_0_a_output_channels,
                 branch_1_a_output_channels,
                 branch_1_b_output_channels,
                 branch_2_a_output_channels,
                 branch_2_b_output_channels,
                 branch_3_b_output_channels,
                 branch_0_a_name="Conv3d_0a_1x1",
                 branch_0_a_kernel_size=(1, 1, 1),
                 branch_1_a_name="Conv3d_0a_1x1",
                 branch_1_a_kernel_size=(1, 1, 1),
                 branch_1_b_name="Conv3d_0b_3x3",
                 branch_1_b_kernel_size=(3, 3, 3),
                 branch_2_a_name="Conv3d_0a_1x1",
                 branch_2_a_kernel_size=(1, 1, 1),
                 branch_2_b_name="Conv3d_0b_3x3",
                 branch_2_b_kernel_size=(3, 3, 3),
                 branch_3_b_name="Conv3d_0b_1x1",
                 branch_3_b_kernel_size=(1, 1, 1),
                 is_training=False,
                 **kwargs):
        """
        Parameters:
          branch_N_a_* / branch_N_b_*: output channels, kernel size and name
            of stage A / stage B of branch N, passed to MixedLayerBranch.
          is_training: forwarded to every MixedLayerBranch.
          kwargs: forwarded to the base Layer constructor.
        """
        super().__init__(**kwargs)

        # MixedLayerBranch positional order:
        #   (a_channels, a_kernel, a_name, b_channels, b_kernel, b_name)

        # Branch 0: stage A only.
        self.branch_0 = MixedLayerBranch(
            branch_0_a_output_channels, branch_0_a_kernel_size, branch_0_a_name,
            None, None, None,
            is_training=is_training)

        # Branch 1: stage A then stage B.
        self.branch_1 = MixedLayerBranch(
            branch_1_a_output_channels, branch_1_a_kernel_size, branch_1_a_name,
            branch_1_b_output_channels, branch_1_b_kernel_size, branch_1_b_name,
            is_training=is_training)

        # Branch 2: stage A then stage B.
        self.branch_2 = MixedLayerBranch(
            branch_2_a_output_channels, branch_2_a_kernel_size, branch_2_a_name,
            branch_2_b_output_channels, branch_2_b_kernel_size, branch_2_b_name,
            is_training=is_training)

        # Branch 3: no stage-A convolution, stage B only.
        self.branch_3 = MixedLayerBranch(
            None, None, None,
            branch_3_b_output_channels, branch_3_b_kernel_size, branch_3_b_name,
            is_training=is_training)

    def call(self, inputs, training=False):
        """Feed `inputs` to every branch and concatenate the results."""
        branches = (self.branch_0, self.branch_1, self.branch_2, self.branch_3)
        outputs = [branch(inputs, training=training) for branch in branches]
        # The last axis is the channel axis in channels_last layout.
        return tf.concat(outputs, axis=-1)

###############################################################################
# 4. The Logits Layer
#
# This layer applies average pooling, dropout, and a final 1x1x1 convolution
# (without batch norm) to produce the class logits. Then spatial dimensions are
# “squeezed” and an average is taken over the time dimension.
###############################################################################

class Logits(layers.Layer):
    """Classification head: average pool -> dropout -> 1x1x1 conv -> mean over axis 1."""

    def __init__(self,
                 num_classes,
                 spatial_squeeze=True,
                 dropout_keep_prob=1.0,
                 is_training=False,
                 **kwargs):
        """
        Parameters:
          num_classes: number of output channels of the final convolution.
          spatial_squeeze: if true, axes 2 and 3 are squeezed out in `call`.
          dropout_keep_prob: keep probability; the Dropout rate is
            1 - dropout_keep_prob.
          is_training: forwarded to the Unit3D constructor.
          kwargs: forwarded to the base Layer constructor.
        """
        super().__init__(**kwargs)
        self.spatial_squeeze = spatial_squeeze
        self.dropout_keep_prob = dropout_keep_prob

        # Window [2, 7, 7] with unit stride and no padding, as in the original.
        self.avg_pool = layers.AveragePooling3D(
            pool_size=(2, 7, 7),
            strides=(1, 1, 1),
            padding='valid',
        )

        drop_rate = 1 - dropout_keep_prob
        self.dropout = layers.Dropout(rate=drop_rate)

        # Final projection to class scores: biased, no batch norm, no activation.
        self.conv_logits = Unit3D(
            output_channels=num_classes,
            kernel_size=(1, 1, 1),
            activation=None,
            use_batch_norm=False,
            use_bias=True,
            is_training=is_training,
            name="Conv3d_0c_1x1",
        )

    def call(self, inputs, training=False):
        """Compute the logits for `inputs`, averaged over axis 1."""
        pooled = self.avg_pool(inputs)
        dropped = self.dropout(pooled, training=training)
        scores = self.conv_logits(dropped, training=training)
        if self.spatial_squeeze:
            # Remove the height and width axes (both must have size 1).
            scores = tf.squeeze(scores, axis=[2, 3])
        # Average over the time axis.
        return tf.reduce_mean(scores, axis=1)

###############################################################################
# 5. The InceptionI3d Model
#
# This is the main I3D model. It sequentially “stacks” the various layers and
# modules defined above. (Because of the branching Inception modules, we cannot
# use a pure Sequential model here; instead we subclass tf.keras.Model and call
# each layer in order.) The model also collects “endpoints” (intermediate outputs),
# which are useful for inspection or multi-scale losses.
###############################################################################

class InceptionI3d(tf.keras.Model):
    """Inception I3D video model built from Unit3D, MixedLayer and Logits.

    Calling the model runs the endpoints listed in `valid_endpoints` in order,
    stops after `final_endpoint`, and returns the last output together with a
    dict of every intermediate output computed on the way.
    """

    def __init__(self,
                 num_classes=400,
                 spatial_squeeze=True,
                 is_training=False,
                 dropout_keep_prob=1.0,
                 final_endpoint="Logits",
                 **kwargs):
        """
        Parameters:
          num_classes: number of classes for the final logits.
          spatial_squeeze: whether to squeeze the spatial dimensions.
          is_training: forwarded to the sub-layer constructors, which do not
            act on it. Whether BatchNormalization and Dropout run in training
            mode is decided by the `training` argument passed to `call`.
          dropout_keep_prob: probability of keeping a unit.
          final_endpoint: the last endpoint to run (useful for “partial”
            networks); must be one of the names in `valid_endpoints`.
          kwargs: forwarded to the tf.keras.Model constructor.

        Raises:
          ValueError: if `final_endpoint` is not a known endpoint name.
        """
        super(InceptionI3d, self).__init__(**kwargs)

        # In the original code a fixed list of endpoint names is used.
        valid_endpoints = [
            "Conv3d_1a_7x7", "MaxPool3d_2a_3x3", "Conv3d_2b_1x1", "Conv3d_2c_3x3",
            "MaxPool3d_3a_3x3", "Mixed_3b", "Mixed_3c", "MaxPool3d_4a_3x3",
            "Mixed_4b", "Mixed_4c", "Mixed_4d", "Mixed_4e", "Mixed_4f",
            "MaxPool3d_5a_2x2", "Mixed_5b", "Mixed_5c", "Logits", "Predictions"
        ]

        # Reject unknown names up front: otherwise the loop in `call` would
        # never hit its stop condition and would silently run every endpoint,
        # including the final softmax.
        if final_endpoint not in valid_endpoints:
            raise ValueError("Unknown final endpoint %s" % final_endpoint)

        self.valid_endpoints = valid_endpoints
        self.final_endpoint = final_endpoint

        # Define each layer as in the original model. Each attribute name
        # matches its endpoint name because `call` looks layers up by name.
        self.Conv3d_1a_7x7 = Unit3D(64,
                                    kernel_size=(7, 7, 7),
                                    strides=(2, 2, 2),
                                    is_training=is_training,
                                    name="Conv3d_1a_7x7")

        self.MaxPool3d_2a_3x3 = layers.MaxPooling3D(pool_size=(1, 3, 3),
                                                    strides=(1, 2, 2),
                                                    padding='same',
                                                    name="MaxPool3d_2a_3x3")

        self.Conv3d_2b_1x1 = Unit3D(64,
                                    kernel_size=(1, 1, 1),
                                    is_training=is_training,
                                    name="Conv3d_2b_1x1")

        self.Conv3d_2c_3x3 = Unit3D(192,
                                    kernel_size=(3, 3, 3),
                                    is_training=is_training,
                                    name="Conv3d_2c_3x3")

        self.MaxPool3d_3a_3x3 = layers.MaxPooling3D(pool_size=(1, 3, 3),
                                                    strides=(1, 2, 2),
                                                    padding='same',
                                                    name="MaxPool3d_3a_3x3")

        self.Mixed_3b = MixedLayer(64, 96, 128, 16, 32, 32,
                                   is_training=is_training,
                                   name="Mixed_3b")

        self.Mixed_3c = MixedLayer(128, 128, 192, 32, 96, 64,
                                   is_training=is_training,
                                   name="Mixed_3c")

        self.MaxPool3d_4a_3x3 = layers.MaxPooling3D(pool_size=(3, 3, 3),
                                                    strides=(2, 2, 2),
                                                    padding='same',
                                                    name="MaxPool3d_4a_3x3")

        self.Mixed_4b = MixedLayer(192, 96, 208, 16, 48, 64,
                                   is_training=is_training,
                                   name="Mixed_4b")

        self.Mixed_4c = MixedLayer(160, 112, 224, 24, 64, 64,
                                   is_training=is_training,
                                   name="Mixed_4c")

        self.Mixed_4d = MixedLayer(128, 128, 256, 24, 64, 64,
                                   is_training=is_training,
                                   name="Mixed_4d")

        self.Mixed_4e = MixedLayer(112, 144, 288, 32, 64, 64,
                                   is_training=is_training,
                                   name="Mixed_4e")

        self.Mixed_4f = MixedLayer(256, 160, 320, 32, 128, 128,
                                   is_training=is_training,
                                   name="Mixed_4f")

        self.MaxPool3d_5a_2x2 = layers.MaxPooling3D(pool_size=(2, 2, 2),
                                                    strides=(2, 2, 2),
                                                    padding='same',
                                                    name="MaxPool3d_5a_2x2")

        self.Mixed_5b = MixedLayer(256, 160, 320, 32, 128, 128,
                                   is_training=is_training,
                                   name="Mixed_5b")

        self.Mixed_5c = MixedLayer(384, 192, 384, 48, 128, 128,
                                   is_training=is_training,
                                   name="Mixed_5c")

        self.Logits = Logits(num_classes,
                             spatial_squeeze=spatial_squeeze,
                             dropout_keep_prob=dropout_keep_prob,
                             is_training=is_training,
                             name="Logits")
        # The Predictions endpoint applies softmax.
        self.Predictions = layers.Softmax(name="Predictions")

    def call(self, inputs, training=False):
        """Run the network up to and including `final_endpoint`.

        Parameters:
          inputs: tensor of shape [batch, num_frames, 224, 224, channels].
          training: forwarded to every layer that is called.

        Returns:
          A tuple `(output, endpoints)`: the output of `final_endpoint`, and a
          dict mapping each executed endpoint name to its output.

        Raises:
          ValueError: if `inputs` is not rank 5 or its axes 2 and 3 are not
            both 224.
        """
        # Check that input has shape [batch, num_frames, 224, 224, channels]
        if inputs.shape.ndims != 5 or inputs.shape[2] != 224 or inputs.shape[3] != 224:
            raise ValueError("Input tensor shape must be [batch, num_frames, 224, 224, channels]")

        endpoints = {}
        x = inputs

        # For each valid endpoint, call the corresponding layer.
        # (In the original implementation, all endpoints up to final_endpoint are returned.)
        for endpoint in self.valid_endpoints:
            layer_or_block = getattr(self, endpoint)
            # Most of our layers accept a training flag.
            # (Note: the Softmax layer does not use training, but that’s harmless.)
            x = layer_or_block(x, training=training)
            endpoints[endpoint] = x
            if endpoint == self.final_endpoint:
                break

        return x, endpoints

In [3]:
# Create an instance of the I3D model.
# (For example, for Kinetics you might have num_classes=400.)
model = InceptionI3d(num_classes=3, is_training=True, dropout_keep_prob=0.5)

# Dimensions of the dummy input, named so the description cannot drift from
# the tensor that is actually built.
BATCH_SIZE = 2
NUM_FRAMES = 16
FRAME_SIZE = 224   # InceptionI3d.call rejects any other height/width
NUM_CHANNELS = 2   # two channels per frame, i.e. not 3-channel RGB

# Create a dummy input: a batch of 2 clips, each with 16 frames of 224x224
# two-channel images.
dummy_input = tf.random.uniform(
    shape=(BATCH_SIZE, NUM_FRAMES, FRAME_SIZE, FRAME_SIZE, NUM_CHANNELS))

# Run a forward pass.
logits, endpoints = model(dummy_input, training=True)

# Print the final output shape and the keys for all endpoints.
print("Logits shape:", logits.shape)
print("Endpoints:", list(endpoints.keys()))

Logits shape: (2, 3)
Endpoints: ['Conv3d_1a_7x7', 'MaxPool3d_2a_3x3', 'Conv3d_2b_1x1', 'Conv3d_2c_3x3', 'MaxPool3d_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'MaxPool3d_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool3d_5a_2x2', 'Mixed_5b', 'Mixed_5c', 'Logits']


In [4]:
# Print the Keras summary of `model`, which was instantiated and called on
# `dummy_input` in the previous cell.
model.summary()