In [1]:
import torch 
from torch import nn 
import torch.nn.functional as F

In [3]:
# Interpreter version this notebook was executed with (handy when reproducing results).
from platform import python_version

python_version()

'3.9.6'

In [22]:
class SkipBlock(nn.Module):
    """Two 3x3 convolutions with a residual (skip) connection and optional 2x2 max-pool.

    Forward computation::

        relu(pool(conv2(relu(conv1(x))) + skip_conv(x)))

    ``skip_conv`` is a parameter-free identity when ``in_channels == out_channels``.
    Otherwise it is a 3x3 convolution that projects the input to ``out_channels`` so
    the two branches can be summed. All convolutions use ``padding=1``, so they keep
    the spatial size. The output therefore has ``out_channels`` channels and the
    input's H/W, halved (floor) when pooling is enabled.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the block.
        pool: if truthy, apply ``MaxPool2d(kernel_size=2, stride=2)`` after the
            residual addition; otherwise ``self.pool`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, pool=True):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # `None` (rather than a module) marks "no pooling", so forward can test it explicitly.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if pool else None

        # Shortcut path. nn.Identity has no parameters, just like the empty nn.Sequential
        # previously used here, so state_dict keys are unchanged. The projection conv is
        # still created after conv1/conv2, which preserves the RNG order of weight init.
        if in_channels != out_channels:
            self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        else:
            self.skip_conv = nn.Identity()

    def forward(self, x):
        """Apply the residual block to ``x`` (batched or unbatched Conv2d input)."""
        shortcut = self.skip_conv(x)

        out = F.relu(self.conv1(x))
        out = self.conv2(out) + shortcut

        if self.pool is not None:
            out = self.pool(out)

        # ReLU and max-pool are both monotone, so applying ReLU after pooling
        # gives the same result as applying it before.
        return F.relu(out)

In [23]:
# Seed PyTorch's global RNG so this random input, and any layer weights initialised
# after this cell, are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Conv2d treats a 3-D input as unbatched (C, H, W); here C=3 matches block_1's in_channels.
tensor = torch.randn(3, 224, 224)
tensor.shape

torch.Size([3, 224, 224])

In [24]:
# First block widens 3 -> 32 channels at full resolution; the second widens to 64 and max-pools.
block_1 = SkipBlock(3, 32, pool=False)
block_2 = SkipBlock(32, 64, pool=True)

In [25]:
# Chain the two blocks; only block_2 pools, so H and W are halved once.
out = block_2(block_1(tensor))
out.shape

torch.Size([64, 112, 112])

In [47]:
class LargeSkipBlock(nn.Module):
    """Four 3x3 convolutions and two 2x2 max-pools with a residual shortcut.

    Main path::

        conv1 -> relu -> pool -> conv2 -> relu -> conv3 -> relu -> pool -> conv4

    Shortcut path::

        add_conv -> pool -> pool

    The two paths are summed and the sum is returned as-is (no final activation).

    ``add_conv`` is a pass-through (empty ``nn.Sequential``) when
    ``in_channels == out_channels``. Otherwise it is a 3x3 convolution projecting to
    ``out_channels``. Convolutions use ``padding=1``, so only the two pools change the
    spatial size; each halves H and W (floor).

    NOTE(review): unlike SkipBlock.forward, no ReLU is applied after the residual
    sum. Confirm this is intentional.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv3x3(c_in):
            # Size-preserving 3x3 conv producing `out_channels` maps.
            return nn.Conv2d(in_channels=c_in, out_channels=out_channels, kernel_size=3, padding=1)

        # Creation order (conv1..conv4, then the projection) fixes the RNG order used
        # for weight initialisation.
        self.conv1 = conv3x3(in_channels)
        self.conv2 = conv3x3(out_channels)
        self.conv3 = conv3x3(out_channels)
        self.conv4 = conv3x3(out_channels)

        # MaxPool2d has no parameters, so one instance serves every pooling step.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # An empty nn.Sequential returns its input unchanged, i.e. an identity shortcut.
        self.add_conv = conv3x3(in_channels) if in_channels != out_channels else nn.Sequential()

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``; see the class docstring for both paths."""
        h = self.pool(F.relu(self.conv1(x)))
        h = F.relu(self.conv2(h))
        h = self.pool(F.relu(self.conv3(h)))
        h = self.conv4(h)

        # Pool the shortcut twice so its spatial size matches the main path.
        shortcut = self.pool(self.pool(self.add_conv(x)))

        return h + shortcut
    

torch.Size([32, 56, 56])

In [48]:
# Draw a fresh random input of the same shape for the LargeSkipBlock experiment.
# This rebinds `tensor`, so its values are not those of the earlier input.
tensor = torch.randn((3, 224, 224))

In [51]:
# 3 -> 32 channels; the block pools twice, so H and W shrink by 4x.
block = LargeSkipBlock(3, 32)

In [52]:
# Run the LargeSkipBlock on the fresh input and show the resulting shape.
out = block(tensor)
out.size()

torch.Size([32, 56, 56])