<a href="https://colab.research.google.com/github/IANGECHUKI176/deeplearning/blob/main/pytorch/convnets/rir.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet in ResNet (RiR) in PyTorch

Sasha Targ, Diogo Almeida, Kevin Lyman.

Resnet in Resnet: Generalizing Residual Architectures

> https://arxiv.org/abs/1603.08029v1

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

In [None]:
class ResnetInit(nn.Module):
    """Generalized residual block with parallel residual and transient streams.

    Each stream owns a same-stream 3x3 convolution and a cross-stream 3x3
    convolution that feeds the other stream. The residual stream additionally
    receives a shortcut connection. Per stream, the contributions are summed
    and then passed through batch normalization followed by ReLU.

    Args:
        in_channels: number of channels of each incoming stream.
        out_channels: number of channels of each outgoing stream.
        stride: stride used by the four 3x3 convolutions and, when a
            projection is needed, by the 1x1 shortcut convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        def conv3x3():
            # All four stream convolutions share the same geometry.
            return nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=stride, padding=1
            )

        def bn_relu():
            # Batch norm + ReLU applied to a stream after summation.
            return nn.Sequential(
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # Same-stream filters: residual -> residual and transient -> transient.
        self.residual_stream_conv = conv3x3()
        self.transient_stream_conv = conv3x3()

        # Cross-stream filters: residual -> transient and transient -> residual.
        # Both streams use the same number of filters.
        self.residual_stream_conv_across = conv3x3()
        self.transient_stream_conv_across = conv3x3()

        self.residual_bn_relu = bn_relu()
        self.transient_bn_relu = bn_relu()

        # Shortcut of the residual stream: identity (empty Sequential) when the
        # tensor shape is preserved, otherwise a 1x1 projection convolution.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Advance both streams by one generalized residual layer.

        Args:
            x: pair ``(x_residual, x_transient)`` of input tensors.

        Returns:
            Pair ``(x_residual, x_transient)`` of output tensors.
        """
        res_in, trans_in = x

        # Contributions computed from the residual stream input.
        r_to_r = self.residual_stream_conv(res_in)
        r_to_t = self.residual_stream_conv_across(res_in)
        skip = self.shortcut(res_in)

        # Contributions computed from the transient stream input.
        t_to_t = self.transient_stream_conv(trans_in)
        t_to_r = self.transient_stream_conv_across(trans_in)

        # Same-stream and cross-stream activations are summed (plus the
        # shortcut for the residual stream only) before batch norm and ReLU.
        res_out = self.residual_bn_relu(r_to_r + t_to_r + skip)
        trans_out = self.transient_bn_relu(t_to_t + r_to_t)
        return res_out, trans_out

In [None]:
class RiRBlock(nn.Module):
    """Stack of generalized residual layers forming one ResNet-in-ResNet block.

    Args:
        in_channels: channels of each incoming stream.
        out_channel: channels of each stream produced by every layer.
        layer_num: number of stacked layers; the stack always contains at
            least one layer, even for values below 1.
        stride: stride of the first layer; all following layers use stride 1.
        layer: class (or factory) called as ``layer(in_channel, out_channel,
            stride)`` to build each layer. Defaults to ``ResnetInit``.
    """

    def __init__(self, in_channels, out_channel, layer_num, stride, layer=ResnetInit):
        super().__init__()
        # Forward `layer` explicitly; previously it was accepted but dropped,
        # so a caller-supplied layer class was silently replaced by ResnetInit.
        self.resnetinit = self._make_layers(
            in_channels, out_channel, layer_num, stride, layer
        )

    def forward(self, x):
        """Run the ``(x_residual, x_transient)`` pair through every layer."""
        x_residual, x_transient = self.resnetinit(x)
        return (x_residual, x_transient)

    # Replacing each convolutional layer of an original ResNet residual block
    # with a generalized residual block yields the ResNet in ResNet (RiR)
    # architecture.
    def _make_layers(self, in_channel, out_channel, layer_num, stride, layer=ResnetInit):
        """Build the sequential stack of ``layer`` instances.

        Only the first layer changes channel count and applies ``stride``;
        the remaining ones map ``out_channel`` to ``out_channel`` at stride 1.
        """
        strides = [stride] + [1] * (layer_num - 1)
        layers = nn.Sequential()
        for index, s in enumerate(strides):
            # The module name (spaces included) is part of the state_dict keys,
            # so it is kept unchanged for checkpoint compatibility.
            layers.add_module(
                "generalized layers {}".format(index),
                layer(in_channel, out_channel, s),
            )
            in_channel = out_channel

        return layers

In [None]:
# Smoke test: a two-layer RiR block taking both streams from 32 to 64 channels.
blk2 = RiRBlock(in_channels=32, out_channel=64, layer_num=2, stride=1)

# Random tensor shaped (batch_size, channels, height, width); it is not passed
# to the block below.
input_tensor = torch.randn((1, 32, 224, 224))

# One random input per stream, drawn in residual-then-transient order.
x_residual, x_transient = (torch.randn((1, 32, 224, 224)) for _ in range(2))

# The block expects both streams packed into a single tuple and returns the
# updated pair.
output_residual, output_transient = blk2((x_residual, x_transient))

In [None]:
class ResnetInResnet(nn.Module):
    """ResNet-in-ResNet image classifier built from eight ``RiRBlock`` stages.

    The input image is fed to two independent stem convolutions, one per
    stream. Both streams then pass through the RiR stages, are concatenated
    along the channel axis, reduced by a strided convolution and classified
    by a two-layer fully connected head.

    Args:
        n_classes: number of output classes. It sizes the channel count of the
            final convolution, the input width of the classifier and the
            number of logits returned by ``forward``.

    Note:
        The classifier input width assumes 224x224 inputs: two stride-2 stages
        reduce them to 56x56 and the unpadded 3x3 stride-2 convolution then
        yields a 27x27 map per class channel.
    """

    # Spatial size (27 * 27) of the conv1 output for 224x224 inputs.
    _HEAD_SPATIAL = 27 * 27

    def __init__(self, n_classes=10):
        super().__init__()
        base = int(96 / 2)

        # Separate stems so each stream starts from its own learned features.
        self.residual_pre_conv = nn.Sequential(
            nn.Conv2d(3, base, 3, padding=1, bias=False),
            nn.BatchNorm2d(base),
            nn.ReLU(inplace=True),
        )
        self.transient_pre_conv = nn.Sequential(
            nn.Conv2d(3, base, 3, padding=1, bias=False),
            nn.BatchNorm2d(base),
            nn.ReLU(inplace=True),
        )

        # Three groups of stages; rir3 and rir6 double the width and halve
        # the resolution (stride 2).
        self.rir1 = RiRBlock(base, base, 2, 1)
        self.rir2 = RiRBlock(base, base, 2, 1)
        self.rir3 = RiRBlock(base, base * 2, 2, 2)
        self.rir4 = RiRBlock(base * 2, base * 2, 2, 1)
        self.rir5 = RiRBlock(base * 2, base * 2, 2, 1)
        self.rir6 = RiRBlock(base * 2, base * 4, 2, 2)
        self.rir7 = RiRBlock(base * 4, base * 4, 2, 1)
        self.rir8 = RiRBlock(base * 4, base * 4, 2, 1)

        # Both streams (base * 4 channels each) are concatenated before conv1,
        # hence base * 4 * 2 input channels (384 for base == 48).
        self.conv1 = nn.Sequential(
            nn.Conv2d(base * 4 * 2, n_classes, 3, stride=2),
            nn.BatchNorm2d(n_classes),
            nn.ReLU(inplace=True),
        )

        # Sizes follow n_classes so a non-default class count works; for the
        # default n_classes == 10 this is Linear(7290, 450) -> Linear(450, 10).
        self.classifier = nn.Sequential(
            nn.Linear(n_classes * self._HEAD_SPATIAL, 450),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(450, n_classes),
        )
        self._weight_init()

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)`` for images ``x``."""
        streams = (self.residual_pre_conv(x), self.transient_pre_conv(x))
        stages = (
            self.rir1, self.rir2, self.rir3, self.rir4,
            self.rir5, self.rir6, self.rir7, self.rir8,
        )
        for stage in stages:
            streams = stage(streams)

        # Merge residual and transient streams along the channel axis.
        h = torch.cat(streams, 1)
        h = self.conv1(h)
        h = h.view(h.size(0), -1)
        h = self.classifier(h)
        return h

    def _weight_init(self):
        """Kaiming-normal init for every conv weight; conv biases set to 0.01."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:  # stem convolutions are bias-free
                    m.bias.data.fill_(0.01)

In [None]:
def resnet_in_resnet():
    """Build a ``ResnetInResnet`` model with its default constructor arguments."""
    model = ResnetInResnet()
    return model

In [None]:
# Build the default network and run one random 224x224 RGB image through it.
blk4 = resnet_in_resnet()
# Left as a bare trailing expression so the notebook displays the logits.
blk4(torch.randn((1, 3, 224, 224)))

tensor([[-0.3725, -0.1105, -0.3854, -0.2425, -0.4050,  0.2341,  0.2290, -0.4518,
          0.1459,  0.3158]], grad_fn=<AddmmBackward0>)