<a href="https://colab.research.google.com/github/ShantanuKadam3115/MachineLearningBasics/blob/ML_implementations/ResidualBlock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn

class VerboseResBlock(nn.Module):
    """Basic two-conv residual block that can print tensor shapes as it runs.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
    Shortcut:   identity when stride == 1 and channel counts match, otherwise
                a 1x1 conv (same stride) followed by BN.
    The two paths are summed and passed through a final ReLU.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
            The output spatial size is ceil(H / stride) for an H-sized input.
        verbose: Keyword-only. When True (the default, matching the original
            behavior) shapes are printed during construction and on every
            forward pass. Pass False to use the block silently inside a larger
            model.
    """

    def __init__(self, in_channels, out_channels, stride=1, *, verbose=True):
        super().__init__()
        self.verbose = verbose

        self._log(f"\n--- Initializing ResBlock: In={in_channels}, Out={out_channels}, Stride={stride} ---")

        # Main path. bias=False because the BatchNorm that follows subtracts the
        # per-channel mean, which would cancel a conv bias anyway.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Stateless, so one instance is shared by both activation points.
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: a projection is only needed when the main path changes
        # the channel count or the spatial size; otherwise pass x through as-is.
        if stride != 1 or in_channels != out_channels:
            self._log("   -> Creating 1x1 Conv Shortcut (Shape Mismatch detected)")
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self._log("   -> Identity Shortcut (Direct Connection)")
            self.shortcut = nn.Identity()

    def _log(self, message):
        """Print *message* only when the block was built with verbose=True."""
        if self.verbose:
            print(message)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x)).

        Args:
            x: Input of shape (N, in_channels, H, W). BatchNorm2d requires
                4-D input.

        Returns:
            Tensor of shape (N, out_channels, ceil(H / stride), ceil(W / stride)).
        """
        self._log(f"\n[Forward Pass] Input Shape: {list(x.shape)}")

        # 1. Main Path Step 1
        out = self.relu(self.bn1(self.conv1(x)))
        self._log(f"   Main Path (Conv1+BN+ReLU): {list(out.shape)}")

        # 2. Main Path Step 2 (the activation is deferred until after the add)
        out = self.bn2(self.conv2(out))
        self._log(f"   Main Path (Conv2+BN):      {list(out.shape)}")

        # 3. Shortcut Path
        shortcut_val = self.shortcut(x)
        self._log(f"   Shortcut Path result:      {list(shortcut_val.shape)}")

        # 4. Addition. The two shapes MUST match; the shortcut choice in
        # __init__ ensures they do. The in-place add writes into `out` (a fresh
        # tensor produced by bn2), never into x, so the caller's input is untouched.
        out += shortcut_val
        self._log(f"   Add (Main + Shortcut):     {list(out.shape)}")

        # 5. Final Activation
        out = self.relu(out)
        self._log(f"   Final ReLU Output:         {list(out.shape)}")

        return out

# --- TEST CASE 1: The Identity Block (Easy) ---
# Same channels (64->64), Same size (32x32)
print("\n=== TEST CASE 1: Identity Block ===")
x1 = torch.randn(1, 64, 32, 32)
block1 = VerboseResBlock(in_channels=64, out_channels=64, stride=1)
y1 = block1(x1)
# An identity block must preserve the input shape exactly.
assert y1.shape == x1.shape, f"unexpected output shape {tuple(y1.shape)}"

# --- TEST CASE 2: The Downsample Block (Hard) ---
# Double channels (64->128), Half size (32 -> 16)
print("\n=== TEST CASE 2: Downsample Block ===")
x2 = torch.randn(1, 64, 32, 32)
block2 = VerboseResBlock(in_channels=64, out_channels=128, stride=2)
# Call the module itself (not .forward) so nn.Module.__call__ runs any
# registered hooks, consistent with test case 1.
y2 = block2(x2)
assert y2.shape == (1, 128, 16, 16), f"unexpected output shape {tuple(y2.shape)}"


=== TEST CASE 1: Identity Block ===

--- Initializing ResBlock: In=64, Out=64, Stride=1 ---
   -> Identity Shortcut (Direct Connection)

[Forward Pass] Input Shape: [1, 64, 32, 32]
   Main Path (Conv1+BN+ReLU): [1, 64, 32, 32]
   Main Path (Conv2+BN):      [1, 64, 32, 32]
   Shortcut Path result:      [1, 64, 32, 32]
   Add (Main + Shortcut):     [1, 64, 32, 32]
   Final ReLU Output:         [1, 64, 32, 32]

=== TEST CASE 2: Downsample Block ===

--- Initializing ResBlock: In=64, Out=128, Stride=2 ---
   -> Creating 1x1 Conv Shortcut (Shape Mismatch detected)

[Forward Pass] Input Shape: [1, 64, 32, 32]
   Main Path (Conv1+BN+ReLU): [1, 128, 16, 16]
   Main Path (Conv2+BN):      [1, 128, 16, 16]
   Shortcut Path result:      [1, 128, 16, 16]
   Add (Main + Shortcut):     [1, 128, 16, 16]
   Final ReLU Output:         [1, 128, 16, 16]
