In [5]:
import math
import torch
import torch.nn as nn

Network Architecture

In [3]:

# Initial part of the network, which is fixed: one 5x5 convolution
# (3 -> 8 channels, padding 2) followed by a 2x2 max-pool.
class initial_part_network(nn.Module):
    """Fixed stem: Conv2d(3 -> 8, k=5, s=1, p=2) then MaxPool2d(k=2, s=2).

    For a batched input of shape (N, 3, H, W) the output is
    (N, 8, H // 2, W // 2): the padded 5x5 conv keeps H and W, the pool halves them.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv, then maxpool) determines the repr and the
        # state_dict keys ("conv.weight", "conv.bias"), so it is kept as is.
        self.conv = nn.Conv2d(3, 8, kernel_size=5, stride=1, padding=2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.maxpool(self.conv(x))

# Custom part of the network, which can be stacked as many times as required
# (0, 1, 2, 3, ... times). Each copy quadruples the channel count.
class custom_part(nn.Module):
    """Two conv + max-pool stages growing channels C -> 2C -> 4C.

    Args:
        in_channels: channel count C of the incoming feature map.

    For a batched input of shape (N, C, H, W) the output is
    (N, 4C, H // 4, W // 4): both convs are padded to keep H and W, and each
    of the two 2x2 max-pools halves them (flooring).
    """

    def __init__(self, in_channels):
        super().__init__()
        doubled = 2 * in_channels
        quadrupled = 4 * in_channels
        # 7x7 kernel with padding 3 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels, doubled, kernel_size=7, stride=1, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 5x5 kernel with padding 2 keeps the spatial size.
        self.conv2 = nn.Conv2d(doubled, quadrupled, kernel_size=5, stride=1, padding=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        for stage in (self.conv1, self.maxpool1, self.conv2, self.maxpool2):
            x = stage(x)
        return x

# Final network: the fixed initial part first, then `num_custom_block`
# stacked custom blocks, then a global-pool + linear classification head.
class final_network(nn.Module):
    """Stem + N stacked `custom_part` blocks + linear head.

    Args:
        num_custom_block: number of `custom_part` blocks to stack (0 or more).
        num_classes: output size of the final linear layer. Defaults to 3
            (a placeholder value; set it for the task at hand).

    Raises:
        ValueError: if `num_custom_block` is negative.
    """

    def __init__(self, num_custom_block=1, num_classes=3):
        super(final_network, self).__init__()
        if num_custom_block < 0:
            raise ValueError(f"num_custom_block must be >= 0, got {num_custom_block}")

        self.initial_part = initial_part_network()

        # Stack the custom blocks, tracking the channel count each block
        # produces so the next block (and the head) gets the right input size.
        in_channels = self.initial_part.conv.out_channels
        custom_blocks = []
        for _ in range(num_custom_block):
            block = custom_part(in_channels)
            custom_blocks.append(block)
            in_channels = block.conv2.out_channels
        self.custom_blocks = nn.Sequential(*custom_blocks)

        # Collapse whatever spatial size is left to 1x1 so the Linear layer
        # always receives exactly `in_channels` features. Without it, Flatten
        # yields C*H*W features and Linear(in_channels, ...) only works when
        # the input happens to shrink to 1x1. It has no parameters (state_dict
        # keys are unchanged) and is the identity on an already-1x1 map.
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Classification head; indices 0/1 are kept so checkpoint keys such as
        # "remaining_network.1.weight" still match.
        self.remaining_network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels, num_classes),
        )

    def forward(self, x):
        x = self.initial_part(x)
        x = self.custom_blocks(x)
        x = self.global_pool(x)
        x = self.remaining_network(x)
        return x


Model (for demonstration, using 3 custom blocks)

In [4]:
# Build the demo model with three custom blocks and display its layer summary
# (last expression of the cell, rendered by Jupyter instead of print()).
NUM_CUSTOM_BLOCKS = 3
model = final_network(num_custom_block=NUM_CUSTOM_BLOCKS)
model


final_network(
  (initial_part): initial_part_network(
    (conv): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (custom_blocks): Sequential(
    (0): custom_part(
      (conv1): Conv2d(8, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
      (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): custom_part(
      (conv1): Conv2d(32, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
      (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False

Receptive Field Calculation

In [None]:
# Receptive-field arithmetic for a single conv / pooling layer
def outFromIn(conv_params, layer_in):
    """Propagate receptive-field statistics through one conv or pooling layer.

    Args:
        conv_params: (kernel_size k, stride s, padding p) of the layer; the
            same values apply to height and width.
        layer_in: (n_h, n_w, jump, receptive_field, (start_h, start_w)) for the
            layer's input. n_* is the feature-map size, jump is the distance in
            original-input pixels between adjacent features, and start_* is the
            centre of the first feature in original-input coordinates.

    Returns:
        (n_out_h, n_out_w, j_out, r_out, (start_out_h, start_out_w)) -- the same
        layout as `layer_in`, so successive layers can be chained.
    """
    n_in_h, n_in_w, j_in, r_in, (start_in_h, start_in_w) = layer_in
    k, s, p = conv_params

    # Output size (floor division, as in PyTorch Conv2d/MaxPool2d with
    # dilation=1 and ceil_mode=False).
    n_out_h = (n_in_h - k + 2 * p) // s + 1
    n_out_w = (n_in_w - k + 2 * p) // s + 1

    # Jump grows by the stride; receptive field grows by (k - 1) input jumps.
    j_out = j_in * s
    r_out = r_in + (k - 1) * j_in

    # Centre of the first output feature. PyTorch pads `p` on both sides and
    # drops leftover pixels on the right, so the effective left padding is `p`.
    start_out_h = start_in_h + ((k - 1) / 2 - p) * j_in
    start_out_w = start_in_w + ((k - 1) / 2 - p) * j_in

    return (n_out_h, n_out_w, j_out, r_out, (start_out_h, start_out_w))


In [None]:
# Function to print layer information
def printLayer(layer, layer_name):
    """Print a layer's spatial size, jump, receptive field and optional start.

    Args:
        layer: tuple (n_h, n_w, jump, receptive_field[, (start_h, start_w)]),
            i.e. the layer description used by the receptive-field helpers.
        layer_name: label printed on the header line.
    """
    print(f"{layer_name}:")
    print(f"\t n (H, W): ({layer[0]}, {layer[1]})")
    print(f"\t jump: {layer[2]}")
    print(f"\t Receptive field: {layer[3]}")
    # The start position is optional so both 4- and 5-element layer
    # descriptions can be printed.
    if len(layer) > 4:
        start_h, start_w = layer[4]
        print(f"\t start (H, W): ({start_h}, {start_w})")
