In [None]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class OriginDownSample(nn.Module):
    """Reference residual block that halves resolution and doubles channels.

    Main path: 3x3 conv (stride 2) -> ReLU -> 3x3 conv (stride 1).
    Shortcut:  1x1 conv (stride 2).
    Both paths map ``planes`` input channels to ``planes * 2`` output channels
    and their results are added. None of the convolutions has a bias.

    Args:
        planes: number of channels of the input feature map.
    """

    def __init__(self, planes):
        super().__init__()
        out_planes = planes * 2
        # Main path.
        self.conv1 = nn.Conv2d(planes, out_planes, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        # Shortcut path: strided 1x1 projection.
        self.down_sample = nn.Conv2d(planes, out_planes, kernel_size=1, stride=2, bias=False)

    def forward(self, x):
        """Return ``conv2(relu1(conv1(x))) + down_sample(x)``."""
        main = self.conv2(self.relu1(self.conv1(x)))
        shortcut = self.down_sample(x)
        return main + shortcut

In [None]:
class RMDownSampleV2Stage1(nn.Module):
    """Two-branch down-sampling block whose shortcut is built from 3x3 convs.

    Branch A: 3x3 conv (stride 2, planes -> planes * 2) -> ReLU
              -> 3x3 conv (stride 1, planes * 2 -> planes * 2).
    Branch B: 3x3 conv (stride 2, planes -> planes) initialised with
              ``nn.init.dirac_`` -> ReLU
              -> 3x3 conv (stride 1, planes -> planes * 2).
    The block returns the sum of both branches. No convolution has a bias.

    Note that branch B applies a ReLU to the output of the Dirac-initialised
    conv, so negative input values do not pass through that branch unchanged.

    Args:
        planes: number of channels of the input feature map.
    """

    def __init__(self, planes):
        super().__init__()
        out_planes = planes * 2
        # Branch A.
        self.conv1 = nn.Conv2d(planes, out_planes, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        # Branch B: the strided conv keeps the channel count and starts out as
        # a Dirac (identity-preserving) kernel.
        self.identity_down_sample = nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1, bias=False)
        nn.init.dirac_(self.identity_down_sample.weight)
        self.relu2 = nn.ReLU()
        self.down_sample2 = nn.Conv2d(planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        """Return the sum of branch A and branch B evaluated on ``x``."""
        branch_a = self.conv2(self.relu1(self.conv1(x)))
        branch_b = self.down_sample2(self.relu2(self.identity_down_sample(x)))
        return branch_a + branch_b

In [None]:
class RMDownSampleV2Stage2(nn.Module):
    """Single-path down-sampling block with a widened hidden layer.

    Structure: 3x3 conv (stride 2, planes -> planes * 3) -> ReLU
               -> 3x3 conv (stride 1, planes * 3 -> planes * 2).
    There is no shortcut branch and no convolution has a bias.

    Args:
        planes: number of channels of the input feature map.
    """

    def __init__(self, planes):
        super().__init__()
        hidden_planes = planes * 3
        out_planes = planes * 2
        self.conv1 = nn.Conv2d(planes, hidden_planes, kernel_size=3, stride=2, padding=1, bias=False)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        """Return ``conv2(relu1(conv1(x)))``."""
        return self.conv2(self.relu1(self.conv1(x)))

In [None]:
# Seed PyTorch's global RNG before building the blocks: every nn.Conv2d draws
# its initial weights from it, so without a seed each "Restart & Run All"
# would verify a different random instance and a failure could not be replayed.
torch.manual_seed(0)

# Number of input channels given to each block's constructor.
planes = 2
OriginResDownSample = OriginDownSample(planes)      # reference residual block
RMDownSample1 = RMDownSampleV2Stage1(planes)        # two-branch rewrite of the reference
RMDownSample2 = RMDownSampleV2Stage2(planes)        # single-path block

In [None]:
"""
Do Some Initialization
"""
# Copy the VALUES of the reference weights into stage 1. Re-binding the
# attribute (`a.weight = b.weight`) would register the very same Parameter
# object in both modules, so any later in-place change to one model would
# silently change the other as well.
with torch.no_grad():
    RMDownSample1.conv1.weight.copy_(OriginResDownSample.conv1.weight)
    RMDownSample1.conv2.weight.copy_(OriginResDownSample.conv2.weight)
    # Zero-pad the 1x1 shortcut kernel by one element on every spatial side,
    # which turns it into a 3x3 kernel whose only non-zero tap is the centre.
    RMDownSample1.down_sample2.weight.copy_(
        torch.nn.functional.pad(
            OriginResDownSample.down_sample.weight, [1, 1, 1, 1], value=0.0)
    )

In [None]:
# Random input drawn from [0, 1) with a fixed seed so the check is repeatable.
# The input is deliberately non-negative: stage 1 applies relu2 to the
# Dirac-down-sampled input, so it can only match the reference block (whose
# shortcut has no ReLU) when x >= 0.
input_rng = np.random.default_rng(0)
x = torch.tensor(input_rng.uniform(low=0.0, high=1, size=(1, planes, 4, 4)),
                 dtype=torch.float32)

# Pure inference: no autograd graph is needed for the comparison.
with torch.no_grad():
    original_res_output = OriginResDownSample(x)
    rmblock_output = RMDownSample1(x)

stage1_is_equal = np.allclose(original_res_output.numpy(),
                              rmblock_output.numpy(),
                              atol=1e-4)
print("RM output is equal?: ", stage1_is_equal)
# Fail loudly so that "Run All" cannot pass over a broken conversion.
assert stage1_is_equal, "stage-1 block does not reproduce the reference block"

In [None]:
"""
Do Some Initialization
"""
# Fuse the two stage-1 branches into the single-path stage-2 block.
# conv1: stack the filters of branch A's first conv and of the Dirac conv
# along the output-channel axis (dim 0), giving planes * 3 output channels.
RMDownSample2.conv1.weight = nn.Parameter(
    torch.cat(
        (RMDownSample1.conv1.weight, RMDownSample1.identity_down_sample.weight),
        dim=0,
    )
)
# conv2: join the second convs of both branches along the input-channel axis
# (dim 1), in the same order as the channels produced by the fused conv1.
RMDownSample2.conv2.weight = nn.Parameter(
    torch.cat(
        (RMDownSample1.conv2.weight, RMDownSample1.down_sample2.weight),
        dim=1,
    )
)

In [None]:
# Run the fused single-path block on the same input; inference only.
with torch.no_grad():
    rmblock_outputv2 = RMDownSample2(x)

stage2_is_equal = np.allclose(rmblock_output.detach().numpy(),
                              rmblock_outputv2.numpy(),
                              atol=1e-4)
print("RM output is equal?: ", stage2_is_equal)
# Fail loudly so that "Run All" cannot pass over a broken fusion.
assert stage2_is_equal, "stage-2 block does not reproduce the stage-1 block"