## Imports

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from video_dataset import VideoFrameDataset, ImglistToTensor

In [103]:
# Load the pretrained ir-CSN weights. The displayed keys below show this file is a
# flat state dict. map_location="cpu" lets it load on machines without the GPU it
# may have been saved from; only tensor shapes are inspected afterwards.
model_checkpoint = torch.load('ircsn_from_scratch_r50_ig65m_20210617-ce545a37.pth', map_location='cpu')

In [104]:
# Collect the checkpoint's parameter names in their stored order and display them.
checkpoint_keys = list(model_checkpoint)
checkpoint_keys

['conv1.conv.weight',
 'conv1.bn.weight',
 'conv1.bn.bias',
 'conv1.bn.running_mean',
 'conv1.bn.running_var',
 'conv1.bn.num_batches_tracked',
 'layer1.0.conv1.conv.weight',
 'layer1.0.conv1.bn.weight',
 'layer1.0.conv1.bn.bias',
 'layer1.0.conv1.bn.running_mean',
 'layer1.0.conv1.bn.running_var',
 'layer1.0.conv1.bn.num_batches_tracked',
 'layer1.0.conv2.0.conv.weight',
 'layer1.0.conv2.0.bn.weight',
 'layer1.0.conv2.0.bn.bias',
 'layer1.0.conv2.0.bn.running_mean',
 'layer1.0.conv2.0.bn.running_var',
 'layer1.0.conv2.0.bn.num_batches_tracked',
 'layer1.0.conv3.conv.weight',
 'layer1.0.conv3.bn.weight',
 'layer1.0.conv3.bn.bias',
 'layer1.0.conv3.bn.running_mean',
 'layer1.0.conv3.bn.running_var',
 'layer1.0.conv3.bn.num_batches_tracked',
 'layer1.0.downsample.conv.weight',
 'layer1.0.downsample.bn.weight',
 'layer1.0.downsample.bn.bias',
 'layer1.0.downsample.bn.running_mean',
 'layer1.0.downsample.bn.running_var',
 'layer1.0.downsample.bn.num_batches_tracked',
 'layer1.1.conv1.conv.

In [105]:
# Print the shape of one tensor from `checkpoint` and one from `ours` for a
# manually paired pair of key names.
# NOTE(review): neither `checkpoint` nor `ours` is defined in any visible cell;
# presumably `checkpoint` is the pretrained state dict and `ours` a target
# model's state dict — confirm.
# The recorded output is KeyError: 'layer1.0.bn1.weight', i.e. `ours` has no
# key named `our_key`; it must be changed to match the naming used by `ours`.
key = 'layer1.0.conv1.bn.weight'
our_key = 'layer1.0.bn1.weight'
values = checkpoint[key]
our_value = ours[our_key]
print(values.shape)
print(our_value.shape)

KeyError: 'layer1.0.bn1.weight'

In [66]:
# Record the manual name mapping: our key -> checkpoint key.
# NOTE(review): `replacements` is not defined in any visible cell; presumably a
# dict created elsewhere — confirm.
replacements[our_key] = key

In [34]:
# Replace the entry named `key` in `checkpoint` with an entry named `our_key`.
# NOTE(review): the stored value is `our_value` (read from `ours`), not
# `checkpoint[key]`, so the checkpoint's own tensor under `key` is discarded.
# If the intent is a pure rename, this should probably be
# `checkpoint[our_key] = checkpoint.pop(key)` — confirm.
# NOTE(review): if `checkpoint` is an insertion-ordered dict, `our_key` is added
# at the end, which shifts the positional key pairing (iter/next) used by the
# shape-comparison cells further down.
checkpoint[our_key] = our_value
del checkpoint[key]

## PyTorch CSN Implementation

In [125]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, Tuple

import torch
import torch.nn as nn
from pytorchvideo.models.head import create_res_basic_head
from pytorchvideo.models.resnet import Net, create_bottleneck_block, create_res_stage
from pytorchvideo.models.stem import create_res_basic_stem


def create_csn(
    *,
    # Input clip configs.
    input_channel: int = 3,
    # Model configs.
    model_depth: int = 50,
    model_num_class: int = 400,
    dropout_rate: float = 0.5,
    # Normalization configs.
    norm: Callable = nn.BatchNorm3d,
    # Activation configs.
    activation: Callable = nn.ReLU,
    # Stem configs.
    stem_dim_out: int = 64,
    stem_conv_kernel_size: Tuple[int, ...] = (3, 7, 7),
    stem_conv_stride: Tuple[int, ...] = (1, 2, 2),
    stem_pool: Callable = None,
    stem_pool_kernel_size: Tuple[int, ...] = (1, 3, 3),
    stem_pool_stride: Tuple[int, ...] = (1, 2, 2),
    # Stage configs.
    stage_conv_a_kernel_size: Tuple[int, ...] = (1, 1, 1),
    stage_conv_b_kernel_size: Tuple[int, ...] = (3, 3, 3),
    stage_conv_b_width_per_group: int = 1,
    stage_spatial_stride: Tuple[int, ...] = (1, 2, 2, 2),
    stage_temporal_stride: Tuple[int, ...] = (1, 2, 2, 2),
    bottleneck: Callable = create_bottleneck_block,
    bottleneck_ratio: int = 4,
    # Head configs.
    head_pool: Callable = nn.AvgPool3d,
    head_pool_kernel_size: Tuple[int, ...] = (1, 7, 7),
    head_output_size: Tuple[int, ...] = (1, 1, 1),
    head_activation: Callable = None,
    head_output_with_global_average: bool = True,
    with_head: bool = False,
) -> nn.Module:
    """
    Build Channel-Separated Convolutional Networks (CSN):
    Video classification with channel-separated convolutional networks.
    Du Tran, Heng Wang, Lorenzo Torresani, Matt Feiszli. ICCV 2019.

    CSN follows the ResNet style architecture including three parts: Stem,
    Stages and Head. The three parts are assembled in the following order:

    ::

                                         Input
                                           ↓
                                         Stem
                                           ↓
                                         Stage 1
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                         Stage N
                                           ↓
                                         Head (only when with_head=True)

    CSN uses depthwise convolution. To further reduce the computational cost, it uses
    low resolution (112x112), short clips (4 frames), different striding and kernel
    size, etc.

    Args:

        input_channel (int): number of channels for the input video clip.

        model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
        model_num_class (int): the number of classes for the video dataset.
            Only used when ``with_head`` is True.
        dropout_rate (float): dropout rate of the head. Only used when
            ``with_head`` is True.

        norm (callable): a callable that constructs normalization layer.

        activation (callable): a callable that constructs activation layer.

        stem_dim_out (int): output channel size to stem.
        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
        stem_conv_stride (tuple): convolutional stride size(s) of stem.
        stem_pool (callable): a callable that constructs the stem pooling layer.
        stem_pool_kernel_size (tuple): pooling kernel size(s).
        stem_pool_stride (tuple): pooling stride size(s).

        stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        stage_conv_b_width_per_group(int): the width of each group for conv_b. Set
            it to 1 for depthwise convolution.
        stage_spatial_stride (tuple): the spatial stride for each stage; needs at
            least one entry per stage.
        stage_temporal_stride (tuple): the temporal stride for each stage; needs at
            least one entry per stage.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.
        bottleneck_ratio (int): the ratio between inner and outer dimensions for
            the bottleneck block.

        head_pool (callable): a callable that constructs resnet head pooling layer.
        head_pool_kernel_size (tuple): the pooling kernel size.
        head_output_size (tuple): the size of output tensor for head.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.
        with_head (bool): if True, append the classification head built from the
            head configs, ``model_num_class`` and ``dropout_rate``. Defaults to
            False, in which case the returned model ends after the last stage and
            the head configs are ignored.

    Returns:
        (nn.Module): the csn model.
    """

    torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_csn")

    # Number of bottleneck blocks in each of the four stages, keyed by model depth.
    _MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}

    assert (
        model_depth in _MODEL_STAGE_DEPTH
    ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
    stage_depths = _MODEL_STAGE_DEPTH[model_depth]

    blocks = []
    # Stem. Padding of size // 2 per dimension gives "same"-style padding for odd
    # kernel sizes.
    stem = create_res_basic_stem(
        in_channels=input_channel,
        out_channels=stem_dim_out,
        conv_kernel_size=stem_conv_kernel_size,
        conv_stride=stem_conv_stride,
        conv_padding=[size // 2 for size in stem_conv_kernel_size],
        pool=stem_pool,
        pool_kernel_size=stem_pool_kernel_size,
        pool_stride=stem_pool_stride,
        pool_padding=[size // 2 for size in stem_pool_kernel_size],
        norm=norm,
        activation=activation,
    )
    blocks.append(stem)

    stage_dim_in = stem_dim_out
    stage_dim_out = stage_dim_in * 4

    # Residual stages.
    for idx, depth in enumerate(stage_depths):
        stage_dim_inner = stage_dim_out // bottleneck_ratio

        # Each stage downsamples in conv_b (conv_a always has stride 1); the
        # stride tuple is ordered (temporal, spatial, spatial).
        stage_conv_b_stride = (
            stage_temporal_stride[idx],
            stage_spatial_stride[idx],
            stage_spatial_stride[idx],
        )

        stage = create_res_stage(
            depth=depth,
            dim_in=stage_dim_in,
            dim_inner=stage_dim_inner,
            dim_out=stage_dim_out,
            bottleneck=bottleneck,
            conv_a_kernel_size=stage_conv_a_kernel_size,
            conv_a_stride=(1, 1, 1),
            conv_a_padding=[size // 2 for size in stage_conv_a_kernel_size],
            conv_b_kernel_size=stage_conv_b_kernel_size,
            conv_b_stride=stage_conv_b_stride,
            conv_b_padding=[size // 2 for size in stage_conv_b_kernel_size],
            # With width_per_group == 1 every group holds one channel, i.e. a
            # depthwise conv_b (the "channel-separated" part of CSN).
            conv_b_num_groups=(stage_dim_inner // stage_conv_b_width_per_group),
            conv_b_dilation=(1, 1, 1),
            norm=norm,
            activation=activation,
        )

        blocks.append(stage)
        # The next stage consumes this stage's output and doubles the width.
        stage_dim_in = stage_dim_out
        stage_dim_out = stage_dim_out * 2

    # The head is only built when requested, so no unused layers are created
    # and the default state dict contains only stem and stage parameters.
    if with_head:
        head = create_res_basic_head(
            in_features=stage_dim_in,
            out_features=model_num_class,
            pool=head_pool,
            output_size=head_output_size,
            pool_kernel_size=head_pool_kernel_size,
            dropout_rate=dropout_rate,
            activation=head_activation,
            output_with_global_average=head_output_with_global_average,
        )
        blocks.append(head)

    return Net(blocks=nn.ModuleList(blocks))

## Check that the checkpoint and the model have matching state-dict keys and key counts

In [126]:
# Instantiate the pytorchvideo CSN, write its state dict to disk, read it back,
# and list its parameter names in stored order.
csn = create_csn()
torch.save(csn.state_dict(), 'csn_pytorch.pth')
csn_checkpoint = torch.load('csn_pytorch.pth')
list(csn_checkpoint)

['blocks.0.conv.weight',
 'blocks.0.norm.weight',
 'blocks.0.norm.bias',
 'blocks.0.norm.running_mean',
 'blocks.0.norm.running_var',
 'blocks.0.norm.num_batches_tracked',
 'blocks.1.res_blocks.0.branch1_conv.weight',
 'blocks.1.res_blocks.0.branch1_norm.weight',
 'blocks.1.res_blocks.0.branch1_norm.bias',
 'blocks.1.res_blocks.0.branch1_norm.running_mean',
 'blocks.1.res_blocks.0.branch1_norm.running_var',
 'blocks.1.res_blocks.0.branch1_norm.num_batches_tracked',
 'blocks.1.res_blocks.0.branch2.conv_a.weight',
 'blocks.1.res_blocks.0.branch2.norm_a.weight',
 'blocks.1.res_blocks.0.branch2.norm_a.bias',
 'blocks.1.res_blocks.0.branch2.norm_a.running_mean',
 'blocks.1.res_blocks.0.branch2.norm_a.running_var',
 'blocks.1.res_blocks.0.branch2.norm_a.num_batches_tracked',
 'blocks.1.res_blocks.0.branch2.conv_b.weight',
 'blocks.1.res_blocks.0.branch2.norm_b.weight',
 'blocks.1.res_blocks.0.branch2.norm_b.bias',
 'blocks.1.res_blocks.0.branch2.norm_b.running_mean',
 'blocks.1.res_blocks.0.

In [122]:
# NOTE(review): `dict_keys` is not defined in any visible cell. Its displayed
# output is identical to the pretrained checkpoint's key list shown for
# `checkpoint_keys` above — confirm where it is assigned.
dict_keys

['conv1.conv.weight',
 'conv1.bn.weight',
 'conv1.bn.bias',
 'conv1.bn.running_mean',
 'conv1.bn.running_var',
 'conv1.bn.num_batches_tracked',
 'layer1.0.conv1.conv.weight',
 'layer1.0.conv1.bn.weight',
 'layer1.0.conv1.bn.bias',
 'layer1.0.conv1.bn.running_mean',
 'layer1.0.conv1.bn.running_var',
 'layer1.0.conv1.bn.num_batches_tracked',
 'layer1.0.conv2.0.conv.weight',
 'layer1.0.conv2.0.bn.weight',
 'layer1.0.conv2.0.bn.bias',
 'layer1.0.conv2.0.bn.running_mean',
 'layer1.0.conv2.0.bn.running_var',
 'layer1.0.conv2.0.bn.num_batches_tracked',
 'layer1.0.conv3.conv.weight',
 'layer1.0.conv3.bn.weight',
 'layer1.0.conv3.bn.bias',
 'layer1.0.conv3.bn.running_mean',
 'layer1.0.conv3.bn.running_var',
 'layer1.0.conv3.bn.num_batches_tracked',
 'layer1.0.downsample.conv.weight',
 'layer1.0.downsample.bn.weight',
 'layer1.0.downsample.bn.bias',
 'layer1.0.downsample.bn.running_mean',
 'layer1.0.downsample.bn.running_var',
 'layer1.0.downsample.bn.num_batches_tracked',
 'layer1.1.conv1.conv.

In [123]:
# Number of entries in the pretrained checkpoint.
len(model_checkpoint)

318

In [91]:
# Number of entries in the pytorchvideo CSN state dict.
len(csn_checkpoint)

318

## Check layer sizes


In [107]:
# Take the first parameter name from each state dict for a single-pair check.
# The checkpoint side iterates `model_checkpoint` — the dict the next cells
# index with `checkpoint_key` — rather than a separately defined `checkpoint`.
csn_iter = iter(csn_checkpoint)
checkpoint_iter = iter(model_checkpoint)
model_key = next(csn_iter)
checkpoint_key = next(checkpoint_iter)

In [109]:
# Shape of the CSN tensor for the currently selected key.
csn_checkpoint[model_key].size()

torch.Size([64, 3, 3, 7, 7])

In [111]:
# Shape of the pretrained tensor for the currently selected key.
model_checkpoint[checkpoint_key].size()

torch.Size([64, 3, 3, 7, 7])

In [112]:
# True when the selected CSN and checkpoint tensors have the same shape.
csn_checkpoint[model_key].size() == model_checkpoint[checkpoint_key].size()

True

In [128]:
# Pair parameters of the pretrained checkpoint and the pytorchvideo CSN by
# position (dict iteration follows insertion order) and print each pair's names
# and shapes so the two layouts can be aligned.
# Keys are taken from `model_checkpoint` itself, the same dict the shapes are
# read from, and zip() avoids a hard-coded count that raised StopIteration
# whenever either dict had fewer entries.
show_only_mismatches = False  # set True to print only pairs whose shapes differ

if len(model_checkpoint) != len(csn_checkpoint):
    # zip() below stops at the shorter dict; make the difference visible.
    print(f'Key count mismatch: checkpoint has {len(model_checkpoint)}, csn has {len(csn_checkpoint)}')

for checkpoint_key, model_key in zip(model_checkpoint, csn_checkpoint):
    checkpoint_shape = model_checkpoint[checkpoint_key].shape
    model_shape = csn_checkpoint[model_key].shape
    if show_only_mismatches and checkpoint_shape == model_shape:
        continue
    print(f'{checkpoint_key}: {checkpoint_shape}; {model_key}: {model_shape}')

conv1.conv.weight: torch.Size([64, 3, 3, 7, 7]); blocks.0.conv.weight: torch.Size([64, 3, 3, 7, 7])
conv1.bn.weight: torch.Size([64]); blocks.0.norm.weight: torch.Size([64])
conv1.bn.bias: torch.Size([64]); blocks.0.norm.bias: torch.Size([64])
conv1.bn.running_mean: torch.Size([64]); blocks.0.norm.running_mean: torch.Size([64])
conv1.bn.running_var: torch.Size([64]); blocks.0.norm.running_var: torch.Size([64])
conv1.bn.num_batches_tracked: torch.Size([]); blocks.0.norm.num_batches_tracked: torch.Size([])
layer1.0.conv1.conv.weight: torch.Size([64, 64, 1, 1, 1]); blocks.1.res_blocks.0.branch1_conv.weight: torch.Size([256, 64, 1, 1, 1])
layer1.0.conv1.bn.weight: torch.Size([64]); blocks.1.res_blocks.0.branch1_norm.weight: torch.Size([256])
layer1.0.conv1.bn.bias: torch.Size([64]); blocks.1.res_blocks.0.branch1_norm.bias: torch.Size([256])
layer1.0.conv1.bn.running_mean: torch.Size([64]); blocks.1.res_blocks.0.branch1_norm.running_mean: torch.Size([256])
layer1.0.conv1.bn.running_var: tor

## Check the model

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, Tuple

import torch
import torch.nn as nn
from pytorchvideo.models.head import create_res_basic_head
from pytorchvideo.models.resnet import Net, create_bottleneck_block, create_res_stage
from pytorchvideo.models.stem import create_res_basic_stem


def create_csn(
    *,
    # Input clip configs.
    input_channel: int = 3,
    # Model configs.
    model_depth: int = 50,
    model_num_class: int = 400,
    dropout_rate: float = 0.5,
    # Normalization configs.
    norm: Callable = nn.BatchNorm3d,
    # Activation configs.
    activation: Callable = nn.ReLU,
    # Stem configs.
    stem_dim_out: int = 64,
    stem_conv_kernel_size: Tuple[int, ...] = (3, 7, 7),
    stem_conv_stride: Tuple[int, ...] = (1, 2, 2),
    stem_pool: Callable = None,
    stem_pool_kernel_size: Tuple[int, ...] = (1, 3, 3),
    stem_pool_stride: Tuple[int, ...] = (1, 2, 2),
    # Stage configs.
    stage_conv_a_kernel_size: Tuple[int, ...] = (1, 1, 1),
    stage_conv_b_kernel_size: Tuple[int, ...] = (3, 3, 3),
    stage_conv_b_width_per_group: int = 1,
    stage_spatial_stride: Tuple[int, ...] = (1, 2, 2, 2),
    stage_temporal_stride: Tuple[int, ...] = (1, 2, 2, 2),
    bottleneck: Callable = create_bottleneck_block,
    bottleneck_ratio: int = 4,
    # Head configs.
    head_pool: Callable = nn.AvgPool3d,
    head_pool_kernel_size: Tuple[int, ...] = (1, 7, 7),
    head_output_size: Tuple[int, ...] = (1, 1, 1),
    head_activation: Callable = None,
    head_output_with_global_average: bool = True,
    with_head: bool = False,
) -> nn.Module:
    """
    Build Channel-Separated Convolutional Networks (CSN):
    Video classification with channel-separated convolutional networks.
    Du Tran, Heng Wang, Lorenzo Torresani, Matt Feiszli. ICCV 2019.

    CSN follows the ResNet style architecture including three parts: Stem,
    Stages and Head. The three parts are assembled in the following order:

    ::

                                         Input
                                           ↓
                                         Stem
                                           ↓
                                         Stage 1
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                         Stage N
                                           ↓
                                         Head (only when with_head=True)

    CSN uses depthwise convolution. To further reduce the computational cost, it uses
    low resolution (112x112), short clips (4 frames), different striding and kernel
    size, etc.

    Args:

        input_channel (int): number of channels for the input video clip.

        model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
        model_num_class (int): the number of classes for the video dataset.
            Only used when ``with_head`` is True.
        dropout_rate (float): dropout rate of the head. Only used when
            ``with_head`` is True.

        norm (callable): a callable that constructs normalization layer.

        activation (callable): a callable that constructs activation layer.

        stem_dim_out (int): output channel size to stem.
        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
        stem_conv_stride (tuple): convolutional stride size(s) of stem.
        stem_pool (callable): a callable that constructs the stem pooling layer.
        stem_pool_kernel_size (tuple): pooling kernel size(s).
        stem_pool_stride (tuple): pooling stride size(s).

        stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        stage_conv_b_width_per_group(int): the width of each group for conv_b. Set
            it to 1 for depthwise convolution.
        stage_spatial_stride (tuple): the spatial stride for each stage; needs at
            least one entry per stage.
        stage_temporal_stride (tuple): the temporal stride for each stage; needs at
            least one entry per stage.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.
        bottleneck_ratio (int): the ratio between inner and outer dimensions for
            the bottleneck block.

        head_pool (callable): a callable that constructs resnet head pooling layer.
        head_pool_kernel_size (tuple): the pooling kernel size.
        head_output_size (tuple): the size of output tensor for head.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.
        with_head (bool): if True, append the classification head built from the
            head configs, ``model_num_class`` and ``dropout_rate``. Defaults to
            False, in which case the returned model ends after the last stage and
            the head configs are ignored.

    Returns:
        (nn.Module): the csn model.
    """

    torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_csn")

    # Number of bottleneck blocks in each of the four stages, keyed by model depth.
    _MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}

    assert (
        model_depth in _MODEL_STAGE_DEPTH
    ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
    stage_depths = _MODEL_STAGE_DEPTH[model_depth]

    blocks = []
    # Stem. Padding of size // 2 per dimension gives "same"-style padding for odd
    # kernel sizes.
    stem = create_res_basic_stem(
        in_channels=input_channel,
        out_channels=stem_dim_out,
        conv_kernel_size=stem_conv_kernel_size,
        conv_stride=stem_conv_stride,
        conv_padding=[size // 2 for size in stem_conv_kernel_size],
        pool=stem_pool,
        pool_kernel_size=stem_pool_kernel_size,
        pool_stride=stem_pool_stride,
        pool_padding=[size // 2 for size in stem_pool_kernel_size],
        norm=norm,
        activation=activation,
    )
    blocks.append(stem)

    stage_dim_in = stem_dim_out
    stage_dim_out = stage_dim_in * 4

    # Residual stages.
    for idx, depth in enumerate(stage_depths):
        stage_dim_inner = stage_dim_out // bottleneck_ratio

        # Each stage downsamples in conv_b (conv_a always has stride 1); the
        # stride tuple is ordered (temporal, spatial, spatial).
        stage_conv_b_stride = (
            stage_temporal_stride[idx],
            stage_spatial_stride[idx],
            stage_spatial_stride[idx],
        )

        stage = create_res_stage(
            depth=depth,
            dim_in=stage_dim_in,
            dim_inner=stage_dim_inner,
            dim_out=stage_dim_out,
            bottleneck=bottleneck,
            conv_a_kernel_size=stage_conv_a_kernel_size,
            conv_a_stride=(1, 1, 1),
            conv_a_padding=[size // 2 for size in stage_conv_a_kernel_size],
            conv_b_kernel_size=stage_conv_b_kernel_size,
            conv_b_stride=stage_conv_b_stride,
            conv_b_padding=[size // 2 for size in stage_conv_b_kernel_size],
            # With width_per_group == 1 every group holds one channel, i.e. a
            # depthwise conv_b (the "channel-separated" part of CSN).
            conv_b_num_groups=(stage_dim_inner // stage_conv_b_width_per_group),
            conv_b_dilation=(1, 1, 1),
            norm=norm,
            activation=activation,
        )

        blocks.append(stage)
        # The next stage consumes this stage's output and doubles the width.
        stage_dim_in = stage_dim_out
        stage_dim_out = stage_dim_out * 2

    # The head is only built when requested, so no unused layers are created
    # and the default state dict contains only stem and stage parameters.
    if with_head:
        head = create_res_basic_head(
            in_features=stage_dim_in,
            out_features=model_num_class,
            pool=head_pool,
            output_size=head_output_size,
            pool_kernel_size=head_pool_kernel_size,
            dropout_rate=dropout_rate,
            activation=head_activation,
            output_with_global_average=head_output_with_global_average,
        )
        blocks.append(head)

    return Net(blocks=nn.ModuleList(blocks))