## Imports

In [1]:
from typing import Callable, Tuple

import torch
import torch.nn as nn
from pytorchvideo.models.head import create_res_basic_head
from pytorchvideo.models.resnet import Net, create_bottleneck_block, create_res_stage
from pytorchvideo.models.stem import create_res_basic_stem

## Values

In [83]:
# Input clip configs.
input_channel: int = 3
# Model configs.
model_depth: int = 50
model_num_class: int = 400
dropout_rate: float = 0.5
# Normalization configs.
norm: Callable = nn.BatchNorm3d
# Activation configs.
activation: Callable = nn.ReLU
# Stem configs.
stem_dim_out: int = 64
stem_conv_kernel_size: Tuple[int] = (3, 7, 7)
stem_conv_stride: Tuple[int] = (1, 2, 2)
stem_pool: Callable = None
stem_pool_kernel_size: Tuple[int] = (1, 3, 3)
stem_pool_stride: Tuple[int] = (1, 2, 2)
# Stage configs.
stage_conv_a_kernel_size: Tuple[int] = (1, 1, 1)
stage_conv_b_kernel_size: Tuple[int] = (3, 3, 3)
stage_conv_b_width_per_group: int = 1
stage_spatial_stride: Tuple[int] = (1, 2, 2, 2)
stage_temporal_stride: Tuple[int] = (1, 2, 2, 2)
bottleneck: Callable = create_bottleneck_block
bottleneck_ratio: int = 4
# Head configs.
head_pool: Callable = nn.AvgPool3d
head_pool_kernel_size: Tuple[int] = (1, 7, 7)
head_output_size: Tuple[int] = (1, 1, 1)
head_activation: Callable = None
head_output_with_global_average: bool = True

In [16]:
# Number of blocks for different stages given the model depth.
_MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}

In [17]:
# Given a model depth, get the number of blocks for each stage.
assert (
    model_depth in _MODEL_STAGE_DEPTH.keys()
), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"

In [19]:
stage_depths = _MODEL_STAGE_DEPTH[model_depth]
stage_depths

(3, 4, 6, 3)

## Create the model

In [35]:
blocks = []

## Stem

In [22]:
stem = create_res_basic_stem(
    in_channels=input_channel,
    out_channels=stem_dim_out,
    conv_kernel_size=stem_conv_kernel_size,
    conv_stride=stem_conv_stride,
    conv_padding=[size // 2 for size in stem_conv_kernel_size],
    pool=stem_pool,
    pool_kernel_size=stem_pool_kernel_size,
    pool_stride=stem_pool_stride,
    pool_padding=[size // 2 for size in stem_pool_kernel_size],
    norm=norm,
    activation=activation,
)
stem

ResNetBasicStem(
  (conv): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
  (norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation): ReLU()
)

### Check stem size

In [33]:
module = stem

torch.save(module.state_dict(), 'module.pth')
module_checkpoint = torch.load('module.pth')

checkpoint_iter = iter(module_checkpoint.keys())
for i in checkpoint_iter:
    print(f'{i}: {module_checkpoint[i].shape}')

conv.weight: torch.Size([64, 3, 3, 7, 7])
norm.weight: torch.Size([64])
norm.bias: torch.Size([64])
norm.running_mean: torch.Size([64])
norm.running_var: torch.Size([64])
norm.num_batches_tracked: torch.Size([])


### Add stem

In [36]:
blocks.append(stem)

In [84]:
stage_dim_out

256

In [85]:
stage_dim_in = stem_dim_out
stage_dim_out = stage_dim_in * 4
stage_dim_out

256

In [86]:
stage_dim_in

64

## Create each stage for CSN

In [60]:
for idx in range(len(stage_depths)):
    print(idx)

0
1
2
3


In [63]:
idx = 0

In [62]:
bottleneck_ratio

4

In [88]:
stage_dim_inner = stage_dim_out // bottleneck_ratio
stage_dim_inner

64

In [66]:
stage_depths

(3, 4, 6, 3)

In [89]:
depth = stage_depths[idx]
depth

3

In [90]:
stage_conv_b_stride = (
        stage_temporal_stride[idx],
        stage_spatial_stride[idx],
        stage_spatial_stride[idx],
    )

stage_conv_b_stride

(1, 1, 1)

In [108]:
stage_dim_in = 64
stage_dim_out = 256
stage_dim_inner = 64
stage = create_res_stage(
            depth=depth,
            dim_in=stage_dim_in,
            dim_inner=stage_dim_inner,
            dim_out=stage_dim_out,
            bottleneck=bottleneck,
            conv_a_kernel_size=stage_conv_a_kernel_size,
            conv_a_stride=(1, 1, 1),
            conv_a_padding=[size // 2 for size in stage_conv_a_kernel_size],
            conv_b_kernel_size=stage_conv_b_kernel_size,
            conv_b_stride=stage_conv_b_stride,
            conv_b_padding=[size // 2 for size in stage_conv_b_kernel_size],
            conv_b_num_groups=(stage_dim_inner // stage_conv_b_width_per_group),
            conv_b_dilation=(1, 1, 1),
            norm=norm,
            activation=activation,
        )
stage

ResStage(
  (res_blocks): ModuleList(
    (0): ResBlock(
      (branch1_conv): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (branch1_norm): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (branch2): BottleneckBlock(
        (conv_a): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (norm_a): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_a): ReLU()
        (conv_b): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=64, bias=False)
        (norm_b): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act_b): ReLU()
        (conv_c): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (norm_c): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (activation): ReLU()
    )
    (1): ResBlock(
      (branch2): Bottle

## Check stage 0 size

In [109]:
module = stage

torch.save(module.state_dict(), 'module.pth')
module_checkpoint = torch.load('module.pth')

checkpoint_iter = iter(module_checkpoint.keys())
for i in checkpoint_iter:
    print(f'{i}: {module_checkpoint[i].shape}')

res_blocks.0.branch1_conv.weight: torch.Size([256, 64, 1, 1, 1])
res_blocks.0.branch1_norm.weight: torch.Size([256])
res_blocks.0.branch1_norm.bias: torch.Size([256])
res_blocks.0.branch1_norm.running_mean: torch.Size([256])
res_blocks.0.branch1_norm.running_var: torch.Size([256])
res_blocks.0.branch1_norm.num_batches_tracked: torch.Size([])
res_blocks.0.branch2.conv_a.weight: torch.Size([64, 64, 1, 1, 1])
res_blocks.0.branch2.norm_a.weight: torch.Size([64])
res_blocks.0.branch2.norm_a.bias: torch.Size([64])
res_blocks.0.branch2.norm_a.running_mean: torch.Size([64])
res_blocks.0.branch2.norm_a.running_var: torch.Size([64])
res_blocks.0.branch2.norm_a.num_batches_tracked: torch.Size([])
res_blocks.0.branch2.conv_b.weight: torch.Size([64, 1, 3, 3, 3])
res_blocks.0.branch2.norm_b.weight: torch.Size([64])
res_blocks.0.branch2.norm_b.bias: torch.Size([64])
res_blocks.0.branch2.norm_b.running_mean: torch.Size([64])
res_blocks.0.branch2.norm_b.running_var: torch.Size([64])
res_blocks.0.branch

Need to remove branch 1

In [112]:
def create_res_stage(
    *,
    # Stage configs.
    depth: int,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable,
    # Conv configs.
    conv_a_kernel_size = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Create Residual Stage, which composes sequential blocks that make up a ResNet. These
    blocks could be, for example, Residual blocks, Non-Local layers, or
    Squeeze-Excitation layers.

    ::


                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock

    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Bottleneck examples include: create_bottleneck_block.

    Args:
        depth (init): number of blocks to create.

        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.

        conv_a_kernel_size (tuple or list of tuple): convolutional kernel size(s)
            for conv_a. If conv_a_kernel_size is a tuple, use it for all blocks in
            the stage. If conv_a_kernel_size is a list of tuple, the kernel sizes
            will be repeated until having same length of depth in the stage. For
            example, for conv_a_kernel_size = [(3, 1, 1), (1, 1, 1)], the kernel
            size for the first 6 blocks would be [(3, 1, 1), (1, 1, 1), (3, 1, 1),
            (1, 1, 1), (3, 1, 1)].
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple or list of tuple): convolutional padding(s) for
            conv_a. If conv_a_padding is a tuple, use it for all blocks in
            the stage. If conv_a_padding is a list of tuple, the padding sizes
            will be repeated until having same length of depth in the stage.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc

        norm (callable): a callable that constructs normalization layer. Examples
            include nn.BatchNorm3d, and None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer. Examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).

    Returns:
        (nn.Module): resnet basic stage layer.
    """
    res_blocks = []
    if isinstance(conv_a_kernel_size[0], int):
        conv_a_kernel_size = [conv_a_kernel_size]
    if isinstance(conv_a_padding[0], int):
        conv_a_padding = [conv_a_padding]
    # Repeat conv_a kernels until having same length of depth in the stage.
    conv_a_kernel_size = (conv_a_kernel_size * depth)[:depth]
    conv_a_padding = (conv_a_padding * depth)[:depth]

    for ind in range(depth):
        block = create_res_block(
            dim_in=dim_in if ind == 0 else dim_out,
            dim_inner=dim_inner,
            dim_out=dim_out,
            bottleneck=bottleneck,
            conv_a_kernel_size=conv_a_kernel_size[ind],
            conv_a_stride=conv_a_stride if ind == 0 else (1, 1, 1),
            conv_a_padding=conv_a_padding[ind],
            conv_a=conv_a,
            conv_b_kernel_size=conv_b_kernel_size,
            conv_b_stride=conv_b_stride if ind == 0 else (1, 1, 1),
            conv_b_padding=conv_b_padding,
            conv_b_num_groups=conv_b_num_groups,
            conv_b_dilation=conv_b_dilation,
            conv_b=conv_b,
            conv_c=conv_c,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation_bottleneck=activation,
            activation_block=activation,
        )
        res_blocks.append(block)
    return ResStage(res_blocks=nn.ModuleList(res_blocks))



# Number of blocks for different stages given the model depth.
_MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}

In [113]:
# Stage configs.
depth: int = 3
# Bottleneck Block configs.
dim_in: int = 64
dim_inner: int = 64
dim_out: int = 256
bottleneck: Callable
# Conv configs.
conv_a_kernel_size = (3, 1, 1)
conv_a_stride: Tuple[int] = (2, 1, 1)
conv_a_padding = (1, 0, 0)
conv_a: Callable = nn.Conv3d,
conv_b_kernel_size: Tuple[int] = (1, 3, 3)
conv_b_stride: Tuple[int] = (1, 2, 2)
conv_b_padding: Tuple[int] = (0, 1, 1)
conv_b_num_groups: int = 1
conv_b_dilation: Tuple[int] = (1, 1, 1)
conv_b: Callable = nn.Conv3d
conv_c: Callable = nn.Conv3d
# Norm configs.
norm: Callable = nn.BatchNorm3d
norm_eps: float = 1e-5
norm_momentum: float = 0.1
# Activation configs.
activation: Callable = nn.ReLU

In [115]:
res_blocks = []
if isinstance(conv_a_kernel_size[0], int):
    conv_a_kernel_size = [conv_a_kernel_size]
if isinstance(conv_a_padding[0], int):
    conv_a_padding = [conv_a_padding]
# Repeat conv_a kernels until having same length of depth in the stage.
conv_a_kernel_size = (conv_a_kernel_size * depth)[:depth]
conv_a_padding = (conv_a_padding * depth)[:depth]

In [118]:
def create_res_block(
    *,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable,
    use_shortcut: bool = False,
    branch_fusion: Callable = lambda x, y: x + y,
    # Conv configs.
    conv_a_kernel_size: Tuple[int] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Tuple[int] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    conv_skip: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation_bottleneck: Callable = nn.ReLU,
    activation_block: Callable = nn.ReLU,
) -> nn.Module:
    """
    Residual block. Performs a summation between an identity shortcut in branch1 and a
    main block in branch2. When the input and output dimensions are different, a
    convolution followed by a normalization will be performed.

    ::


                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation

    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Transform examples include: BottleneckBlock.

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.
        use_shortcut (bool): If true, use conv and norm layers in skip connection.
        branch_fusion (callable): a callable that constructs summation layer.
            Examples include: lambda x, y: x + y, OctaveSum.

        conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple): convolutional padding(s) for conv_a.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_skip (callable): a callable that constructs the conv_skip conv layer,
        examples include nn.Conv3d, OctaveConv, etc

        norm (callable): a callable that constructs normalization layer. Examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation_bottleneck (callable): a callable that constructs activation layer in
            bottleneck. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None
            (not performing activation).
        activation_block (callable): a callable that constructs activation layer used
            at the end of the block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid,
            and None (not performing activation).

    Returns:
        (nn.Module): resnet basic block layer.
    """
    branch1_conv_stride = tuple(map(np.prod, zip(conv_a_stride, conv_b_stride)))
    norm_model = None
    if use_shortcut or (
        norm is not None and (dim_in != dim_out or np.prod(branch1_conv_stride) != 1)
    ):
        norm_model = norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)

    return ResBlock(
        branch1_conv=conv_skip(
            dim_in,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=branch1_conv_stride,
            bias=False,
        )
        if (dim_in != dim_out or np.prod(branch1_conv_stride) != 1) or use_shortcut
        else None,
        branch1_norm=norm_model,
        branch2=bottleneck(
            dim_in=dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            conv_a_kernel_size=conv_a_kernel_size,
            conv_a_stride=conv_a_stride,
            conv_a_padding=conv_a_padding,
            conv_a=conv_a,
            conv_b_kernel_size=conv_b_kernel_size,
            conv_b_stride=conv_b_stride,
            conv_b_padding=conv_b_padding,
            conv_b_num_groups=conv_b_num_groups,
            conv_b_dilation=conv_b_dilation,
            conv_b=conv_b,
            conv_c=conv_c,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation=activation_bottleneck,
        ),
        activation=None if activation_block is None else activation_block(),
        branch_fusion=branch_fusion,
    )


In [120]:
depth=depth
dim_in=stage_dim_in
dim_inner=stage_dim_inner
dim_out=stage_dim_out
bottleneck=bottleneck
conv_a_kernel_size=stage_conv_a_kernel_size
conv_a_stride=(1, 1, 1)
conv_a_padding=[size // 2 for size in stage_conv_a_kernel_size]
conv_b_kernel_size=stage_conv_b_kernel_size
conv_b_stride=stage_conv_b_stride
conv_b_padding=[size // 2 for size in stage_conv_b_kernel_size]
conv_b_num_groups=(stage_dim_inner // stage_conv_b_width_per_group)
conv_b_dilation=(1, 1, 1)
norm=norm
activation=activation

In [122]:
import numpy as np

In [126]:
class ResBlock(nn.Module):
    """
    Residual block. Performs a summation between an identity shortcut in branch1 and a
    main block in branch2. When the input and output dimensions are different, a
    convolution followed by a normalization will be performed.

    ::


                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation

    The builder can be found in `create_res_block`.
    """

    def __init__(
        self,
        branch1_conv: nn.Module = None,
        branch1_norm: nn.Module = None,
        branch2: nn.Module = None,
        activation: nn.Module = None,
        branch_fusion: Callable = None,
    ) -> nn.Module:
        """
        Args:
            branch1_conv (torch.nn.modules): convolutional module in branch1.
            branch1_norm (torch.nn.modules): normalization module in branch1.
            branch2 (torch.nn.modules): bottleneck block module in branch2.
            activation (torch.nn.modules): activation module.
            branch_fusion: (Callable): A callable or layer that combines branch1
                and branch2.
        """
        super().__init__()
        set_attributes(self, locals())
        assert self.branch2 is not None


    def forward(self, x) -> torch.Tensor:
        if self.branch1_conv is None:
            x = self.branch_fusion(x, self.branch2(x))
        else:
            shortcut = self.branch1_conv(x)
            if self.branch1_norm is not None:
                shortcut = self.branch1_norm(shortcut)
            x = self.branch_fusion(shortcut, self.branch2(x))
        if self.activation is not None:
            x = self.activation(x)
        return x

In [127]:
ind = 0
block = create_res_block(
    dim_in=dim_in if ind == 0 else dim_out,
    dim_inner=dim_inner,
    dim_out=dim_out,
    bottleneck=bottleneck,
    conv_a_kernel_size=conv_a_kernel_size[ind],
    conv_a_stride=conv_a_stride if ind == 0 else (1, 1, 1),
    conv_a_padding=conv_a_padding[ind],
    conv_a=conv_a,
    conv_b_kernel_size=conv_b_kernel_size,
    conv_b_stride=conv_b_stride if ind == 0 else (1, 1, 1),
    conv_b_padding=conv_b_padding,
    conv_b_num_groups=conv_b_num_groups,
    conv_b_dilation=conv_b_dilation,
    conv_b=conv_b,
    conv_c=conv_c,
    norm=norm,
    norm_eps=norm_eps,
    norm_momentum=norm_momentum,
    activation_bottleneck=activation,
    activation_block=activation,
)

TypeError: 'tuple' object is not callable