# Lab 2.1.1: Custom Module Lab - SOLUTIONS

This notebook contains complete solutions for the exercises in the Custom Module Lab.

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Type, Union, List

# Prefer the GPU when PyTorch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

---

## Exercise 1 Solution: BasicBlock with Dropout

Adding dropout regularization to the BasicBlock can help prevent overfitting, especially on smaller datasets.

In [None]:
class BasicBlockWithDropout(nn.Module):
    """
    BasicBlock with optional dropout regularization.

    Main path: conv3x3 -> BN -> ReLU -> dropout -> conv3x3 -> BN. The result
    is added to the shortcut branch and passed through a final ReLU. The
    dropout is applied after the first ReLU, which is a common position
    for regularization in residual networks.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        stride: Stride for the first convolution (default: 1)
        dropout: Dropout probability in [0, 1] (default: 0.0, meaning no dropout)

    Raises:
        ValueError: If ``dropout`` is outside the range [0, 1].
    """

    # Ratio of the block's output channels to ``out_channels``
    expansion = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        dropout: float = 0.0
    ):
        super().__init__()

        # Reject invalid probabilities up front. Without this check a negative
        # value would silently fall through to nn.Identity and disable dropout.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {dropout}"
            )

        # First conv (carries the stride, so it may downsample)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Dropout layer - uses Dropout2d for spatial dropout
        # Dropout2d drops entire channels, which is more suitable for CNNs.
        # nn.Identity keeps forward() branch-free when dropout is disabled.
        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else nn.Identity()

        # Second conv (always stride 1, keeps the spatial size)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut: an empty Sequential passes the input through unchanged;
        # a 1x1 conv + BN projection is used when the shape changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual block on ``x`` and return the activated sum."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)

        # Apply dropout after first ReLU
        out = self.dropout(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # In-place add writes into the main-path tensor, never into ``x``
        out += self.shortcut(identity)
        out = F.relu(out, inplace=True)

        return out


# Smoke test for BasicBlockWithDropout: build one block, run it in both
# training and eval mode, and confirm that dropout makes training stochastic.
print("=== Testing BasicBlockWithDropout ===")

# 64 -> 64 channels at the default stride, so the shortcut is the identity;
# dropout=0.1 selects nn.Dropout2d rather than nn.Identity.
block = BasicBlockWithDropout(64, 64, dropout=0.1)
print(f"Block structure:\n{block}")

# Random input of shape (2, 64, 32, 32), reused for every pass below
x = torch.randn(2, 64, 32, 32)
block.train()  # Dropout active
y_train = block(x)
print(f"\nTraining mode - Output shape: {y_train.shape}")

block.eval()  # Dropout inactive
with torch.no_grad():
    y_eval = block(x)
print(f"Eval mode - Output shape: {y_eval.shape}")

# Verify dropout is being applied: two training-mode passes over the same
# input should produce different outputs, so the printed value should be True.
block.train()
y1 = block(x)
y2 = block(x)
print(f"\nOutputs differ in train mode: {not torch.allclose(y1, y2)}")

### Why Use Dropout2d Instead of Dropout?

- **Dropout**: Drops individual values randomly. For images, this can create "holes" in feature maps.
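- **Dropout2d**: Drops entire channels. This is better for CNNs because neighboring activations within a feature map are highly correlated: when only individual values are dropped, the network can recover them from their neighbors, so element-wise dropout regularizes very little.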

By dropping whole channels, we force the network to learn redundant representations across channels, improving generalization.

---

## Challenge Solution: ResNet with SE Blocks

Squeeze-and-Excitation (SE) blocks add channel attention to improve accuracy.

In [None]:
class SEBlock(nn.Module):
    """
    Channel-attention module (Squeeze-and-Excitation).

    Each channel of the input is rescaled by a gate in (0, 1) that is
    computed from the spatial average of all channels.

    Steps:
        1. Squeeze: global average pooling reduces every channel to one value
        2. Excitation: FC -> ReLU -> FC -> Sigmoid turns those values into gates
        3. Scale: the input is multiplied channel-wise by the gates

    Args:
        channels: Number of input/output channels
        reduction: Reduction ratio for the bottleneck FC layer (default: 16)
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()

        # Bottleneck width; clamped so it never collapses to zero units
        hidden_units = max(1, channels // reduction)

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, hidden_units, bias=False)
        self.fc2 = nn.Linear(hidden_units, channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with every channel multiplied by its learned gate."""
        n, c, _, _ = x.shape

        # Squeeze: (n, c, h, w) -> (n, c)
        pooled = self.global_pool(x).view(n, c)

        # Excitation: bottleneck MLP followed by a sigmoid
        hidden = F.relu(self.fc1(pooled), inplace=True)
        gates = torch.sigmoid(self.fc2(hidden))

        # Scale: (n, c, 1, 1) gates broadcast over the spatial dimensions
        return x * gates.view(n, c, 1, 1)


class SEBasicBlock(nn.Module):
    """
    Residual BasicBlock whose main path is re-weighted by an SE block.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN -> SE. The SE output is
    added to the shortcut branch and the sum goes through a final ReLU, so
    the attention acts before the residual addition.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        stride: Stride for the first convolution (default: 1)
        reduction: Reduction ratio passed to the SE block (default: 16)
    """

    expansion = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        reduction: int = 16
    ):
        super().__init__()

        # First 3x3 conv carries the stride
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Second 3x3 conv keeps the resolution
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Channel attention over the main-path features
        self.se = SEBlock(out_channels, reduction)

        # Shortcut: 1x1 projection when the shape changes, pass-through otherwise
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the SE residual block on ``x``."""
        features = F.relu(self.bn1(self.conv1(x)), inplace=True)
        features = self.bn2(self.conv2(features))

        # Re-weight channels, then merge with the shortcut branch
        features = self.se(features)
        features += self.shortcut(x)

        return F.relu(features, inplace=True)


# Smoke test for the standalone SE block: it should return a tensor with the
# same shape as its input.
print("=== Testing SE Block ===")
se = SEBlock(64)  # default reduction=16 -> bottleneck of 64 // 16 = 4 units
x = torch.randn(2, 64, 32, 32)
y = se(x)
print(f"SE Block: {x.shape} -> {y.shape}")
print(f"SE Block parameters: {sum(p.numel() for p in se.parameters()):,}")

# Smoke test for SEBasicBlock with a channel change (64 -> 128) and stride 2,
# which exercises the 1x1 projection shortcut.
print("\n=== Testing SE BasicBlock ===")
block = SEBasicBlock(64, 128, stride=2)
x = torch.randn(2, 64, 32, 32)
y = block(x)
print(f"SE BasicBlock: {x.shape} -> {y.shape}")
print(f"SE BasicBlock parameters: {sum(p.numel() for p in block.parameters()):,}")

In [None]:
class SEResNet(nn.Module):
    """
    SE-ResNet: ResNet with Squeeze-and-Excitation blocks.

    The network is a 3x3 stem convolution (stride 1), four residual stages
    of widths 64/128/256/512 (the last three downsample with stride 2),
    global average pooling and a linear classifier.

    Args:
        block: Residual block class used to build every stage
        layers: Number of blocks in each of the four stages
        num_classes: Size of the classifier output (default: 10)
        reduction: SE reduction ratio forwarded to every block (default: 16)
    """

    def __init__(
        self,
        block: Type[SEBasicBlock],
        layers: List[int],
        num_classes: int = 10,
        reduction: int = 16
    ):
        super().__init__()

        # Running channel count, updated by _make_layer as stages are built
        self.in_channels = 64
        self.reduction = reduction

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stages layer1..layer4 as (width, stride of the first block)
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, layers[index], stride=first_stride)
            setattr(self, f"layer{index + 1}", stage)

        # Classifier head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        # At least one block is always created, even if num_blocks < 1
        block_strides = [stride] + [1] * (num_blocks - 1)

        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_channels, channels, block_stride, self.reduction))
            # Every later block consumes the width produced by the first one
            self.in_channels = channels * block.expansion

        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming-initialize convolutions and reset BatchNorm to identity."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to class logits."""
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)


def se_resnet18(num_classes: int = 10, reduction: int = 16) -> SEResNet:
    """Build SE-ResNet-18: four stages with two SEBasicBlocks each."""
    blocks_per_stage = [2, 2, 2, 2]
    return SEResNet(
        SEBasicBlock,
        blocks_per_stage,
        num_classes=num_classes,
        reduction=reduction,
    )


# Compare the parameter counts of a plain ResNet-18 and the SE-ResNet-18
print("=== Model Comparison ===")

# Standard ResNet-18 baseline (same 3x3 stem and stage widths, no SE blocks),
# written compactly below because it is only used for the comparison.
# NOTE(review): `Union` is already imported in the first cell and is not used
# by the code in this cell; this re-import looks redundant.
from typing import Union

class BasicBlock(nn.Module):
    """
    Plain residual BasicBlock used as the baseline for the SE comparison.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, added to the shortcut
    and followed by a ReLU. The shortcut is a 1x1 conv + BN projection when
    the stride or the channel count changes, otherwise a pass-through.
    """

    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(
            out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_ch)

        needs_projection = stride != 1 or in_ch != out_ch
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = F.relu(main, inplace=True)
        main = self.conv2(main)
        main = self.bn2(main)

        # Out-of-place add, so the caller's tensor is never modified
        merged = main + self.shortcut(x)
        return F.relu(merged, inplace=True)

class ResNet18(nn.Module):
    """
    Baseline ResNet-18 without SE blocks, used for the parameter comparison.

    3x3 stem convolution, four stages of two BasicBlocks each with widths
    64/128/256/512, global average pooling and a linear classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running channel count, updated by _make_layer
        self.in_ch = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages: (width, number of blocks, stride of the first block)
        self.layer1 = self._make_layer(64, 2, 1)
        self.layer2 = self._make_layer(128, 2, 2)
        self.layer3 = self._make_layer(256, 2, 2)
        self.layer4 = self._make_layer(512, 2, 2)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, ch, blocks, stride):
        """Build one stage; only the first block applies ``stride``."""
        stage = [BasicBlock(self.in_ch, ch, stride)]
        self.in_ch = ch
        stage.extend(BasicBlock(self.in_ch, ch) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to class logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.flatten(1)
        return self.fc(x)

# Instantiate both 10-class models so their sizes can be compared
resnet = ResNet18(10)
se_resnet = se_resnet18(10)

# Total number of parameter elements in each model
resnet_params = sum(p.numel() for p in resnet.parameters())
se_resnet_params = sum(p.numel() for p in se_resnet.parameters())

# Report the absolute counts, then the SE overhead both as a parameter
# difference and as a percentage of the baseline
print(f"ResNet-18: {resnet_params:,} parameters")
print(f"SE-ResNet-18: {se_resnet_params:,} parameters")
print(f"SE overhead: {(se_resnet_params - resnet_params):,} parameters ({100*(se_resnet_params/resnet_params - 1):.2f}%)")

### Key Insights

1. **SE blocks add minimal overhead** (under 1% more parameters for this SE-ResNet-18 — about 87K extra weights, as printed by the comparison above)
2. **But provide significant accuracy gains** (~1% on ImageNet)
3. **The attention mechanism** learns which channels are most important for the task
4. **Position matters** - SE is applied after convs but before residual addition

---

## Alternative Implementation: Using nn.Sequential

For simpler blocks, you can use `nn.Sequential` for cleaner code:

In [None]:
class BasicBlockSequential(nn.Module):
    """
    BasicBlock whose main path is a single nn.Sequential.

    Defined at module level (not inside the factory) so that every block
    shares one class: instances compare as the same type and the class has
    an importable name instead of a ``<locals>`` qualname.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        stride: Stride for the first convolution (default: 1)
    """

    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        # conv3x3 -> BN -> ReLU -> conv3x3 -> BN
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        # Identity shortcut unless the shape changes, then a 1x1 projection
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ReLU(main(x) + shortcut(x))."""
        return F.relu(self.main(x) + self.shortcut(x), inplace=True)


def make_basic_block(in_channels: int, out_channels: int, stride: int = 1):
    """
    Alternative BasicBlock implementation using nn.Sequential.

    This is more concise but less flexible than the class-based approach.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        stride: Stride for the first convolution (default: 1)

    Returns:
        A new ``BasicBlockSequential`` module.
    """
    return BasicBlockSequential(in_channels, out_channels, stride)


# Smoke test: a 64 -> 128 channel block with stride 2 uses the projection
# shortcut; print the input and output shapes.
block = make_basic_block(64, 128, stride=2)
x = torch.randn(2, 64, 32, 32)
y = block(x)
print(f"Sequential BasicBlock: {x.shape} -> {y.shape}")

---

## Performance Comparison on CIFAR-10

Here's a summary of expected performance for different models:

| Model | Parameters | CIFAR-10 Accuracy (100 epochs) |
|-------|------------|--------------------------------|
| ResNet-18 | 11.2M | ~93-94% |
| SE-ResNet-18 | 11.3M | ~94-95% |
| ResNet-34 | 21.3M | ~94-95% |
| ResNet-50 | 23.5M | ~94-95% |

**Note:** The small image size (32×32) of CIFAR-10 means deeper networks don't provide as much benefit as they would on ImageNet (224×224).

In [None]:
# Cleanup
import gc

# Run the garbage collector first so tensors held only by unreachable Python
# objects are freed, then release the now-unused cached CUDA memory.
# Emptying the cache before collecting would leave that memory cached.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")