# Composing Models with `nn.Module`

With the data pipelines in place, we can now architect models. This notebook takes a top-down view: it starts with design principles, then moves on to reusable components that plug into the broader workflow.

_Environment note:_ Recommendations mirror best practices through October 2024.

## Learning Objectives

- Implement custom `nn.Module` classes and combine them with `nn.Sequential`.
- Apply normalization, activation, and dropout layers thoughtfully.
- Introduce residual connections to stabilize deeper networks.
- Lay the groundwork for more advanced modules (CNNs, transformers, and MoE models) without rewriting boilerplate.

## Modular Design Principles

- **Encoders** convert raw inputs into feature representations.
- **Heads** map representations to task-specific outputs.
- **Utilities** such as normalization and residual pathways keep activations healthy.

Organizing models into these roles makes debugging easier and encourages reuse.

In [None]:
import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

model = SimpleMLP(4, 32, 1)
dummy = torch.randn(8, 4)
print(model(dummy).shape)
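
The encoder/head split described under Modular Design Principles can be sketched as follows. This is a minimal illustration, not part of the notebook's later code: the class names `TinyEncoder`, `RegressionHead`, and `ClassificationHead`, and all dimensions, are hypothetical. One encoder produces a representation that two different heads consume.

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Hypothetical encoder: raw features -> fixed-size representation."""
    def __init__(self, input_dim: int, repr_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, repr_dim), nn.ReLU())

    def forward(self, x):
        return self.net(x)

class RegressionHead(nn.Module):
    """Hypothetical head: representation -> one scalar per sample."""
    def __init__(self, repr_dim: int):
        super().__init__()
        self.out = nn.Linear(repr_dim, 1)

    def forward(self, features):
        return self.out(features)

class ClassificationHead(nn.Module):
    """Hypothetical head: representation -> class logits."""
    def __init__(self, repr_dim: int, num_classes: int):
        super().__init__()
        self.out = nn.Linear(repr_dim, num_classes)

    def forward(self, features):
        return self.out(features)

encoder = TinyEncoder(input_dim=4, repr_dim=16)
reg_head = RegressionHead(16)
cls_head = ClassificationHead(16, num_classes=3)

x = torch.randn(8, 4)
features = encoder(x)          # computed once, shared by both heads
reg_out = reg_head(features)
cls_out = cls_head(features)
print(reg_out.shape, cls_out.shape)
```

Because the heads only depend on the representation size, either one can be replaced without touching the encoder.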


### Initialization & Stability

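- Use Kaiming (He) initialization for ReLU-heavy stacks and Xavier (Glorot) initialization for layers followed by tanh, sigmoid, or no nonlinearity.
- Normalization layers such as `LayerNorm` and `BatchNorm1d` rescale activations so their magnitude stays roughly constant from layer to layer.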
- Residual connections (`x + f(x)`) make optimization of deep networks tractable.

In [None]:
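def kaiming_init(module: nn.Module):
    # Re-initialize only Linear layers; other module types are left untouched.
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
        if module.bias is not None:  # layers built with bias=False have no bias
            nn.init.zeros_(module.bias)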

model.apply(kaiming_init);
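
To see why the initialization scale matters, the sketch below compares the output magnitude of a plain ReLU stack under Kaiming initialization and under a deliberately small Gaussian initialization. The helper names (`make_relu_stack`, `tiny_init`, `output_rms`), the width, the depth, and the `std=0.01` baseline are illustrative assumptions, not recommendations from this notebook.

```python
import torch
import torch.nn as nn

def make_relu_stack(width: int, depth: int) -> nn.Sequential:
    # depth repetitions of Linear -> ReLU, all with the same width
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.ReLU()]
    return nn.Sequential(*layers)

def kaiming_init(module: nn.Module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

def tiny_init(module: nn.Module):
    # Assumed "bad" baseline: weights far smaller than Kaiming would choose
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.01)
        nn.init.zeros_(module.bias)

def output_rms(init_fn, width: int = 256, depth: int = 10) -> float:
    # Root-mean-square of the final activations for unit-variance inputs
    torch.manual_seed(0)
    net = make_relu_stack(width, depth)
    net.apply(init_fn)
    with torch.no_grad():
        out = net(torch.randn(512, width))
    return out.pow(2).mean().sqrt().item()

kaiming_rms = output_rms(kaiming_init)
tiny_rms = output_rms(tiny_init)
print(f"Kaiming RMS: {kaiming_rms:.3e} | tiny-init RMS: {tiny_rms:.3e}")
```

With Kaiming initialization the activations should stay on the order of the input scale, while the small-weight stack shrinks them by a constant factor at every layer, which is the vanishing-signal problem the bullets above refer to.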


In [None]:
class ResidualMLPBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.ff(self.norm(x))

block = ResidualMLPBlock(32, 64)
print(block(torch.randn(4, 32)).shape)
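
Because the block contains `nn.Dropout`, its output depends on whether the module is in training or evaluation mode. The following sketch repeats the `ResidualMLPBlock` definition so it runs on its own; the variable name `demo_block` and the dropout rate of 0.5 are chosen only to make the effect obvious.

```python
import torch
import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    # Same definition as in the cell above, repeated so this sketch is self-contained
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.ff(self.norm(x))

torch.manual_seed(0)
demo_block = ResidualMLPBlock(32, 64, dropout=0.5)
x = torch.randn(4, 32)

demo_block.train()                      # dropout active: a fresh random mask per call
train_a, train_b = demo_block(x), demo_block(x)

demo_block.eval()                       # dropout disabled: forward pass is deterministic
with torch.no_grad():
    eval_a, eval_b = demo_block(x), demo_block(x)

print("train outputs identical:", torch.equal(train_a, train_b))
print("eval outputs identical: ", torch.equal(eval_a, eval_b))
```

Calling `model.eval()` before validation or inference avoids noisy metrics caused by dropout masks.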


## Mini Task – Linear-Norm-Activation Block

Create a module that applies Linear → BatchNorm1d → GELU and returns both the activated output and the pre-activation values. The same Linear → norm → activation pattern reappears in transformer feed-forward layers, where LayerNorm usually takes the place of BatchNorm1d.

Work through the starter cell before expanding the hidden solution.

In [None]:
class LinearNormActivation(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # TODO: define linear, batchnorm, activation layers

    def forward(self, x):
        # TODO: return tuple (post_activation, pre_activation)
        raise NotImplementedError


In [None]:
class LinearNormActivation(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.BatchNorm1d(out_dim)
        self.act = nn.GELU()

    def forward(self, x):
        pre = self.linear(x)
        normed = self.norm(pre)
        out = self.act(normed)
        return out, pre

block = LinearNormActivation(4, 8)
out, pre = block(torch.randn(3, 4))
print(out.shape, pre.shape)
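
One caveat with `BatchNorm1d` is that it normalizes with the statistics of the current batch in training mode, but with stored running statistics in evaluation mode. The sketch below, which uses a plain `nn.Sequential` stand-in rather than the `LinearNormActivation` class, checks that in evaluation mode a sample gets the same output whether it is processed alone or as part of a batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for the Linear -> BatchNorm1d -> GELU pattern (illustrative only)
layer = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.GELU())
batch = torch.randn(16, 4)

layer.train()
_ = layer(batch)            # training-mode pass updates running_mean / running_var

layer.eval()                # switch to the stored running statistics
with torch.no_grad():
    full = layer(batch)         # all 16 samples together
    single = layer(batch[:1])   # first sample on its own

same = torch.allclose(full[:1], single, atol=1e-5)
print("eval output independent of batch composition:", same)
```

In training mode the same comparison would generally fail, since each row is normalized using the mean and variance of whatever else is in the batch.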


## From Blocks to Networks

- Stack residual MLP blocks to prototype transformer feed-forward sublayers.
- Combine CNN backbones with task-specific heads using the same module patterns.
- Keep modules decoupled so you can swap them during experimentation without rewriting glue code.
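
As a rough sketch of the first and third points, the code below stacks several residual blocks between an input projection and a head, then swaps the head without touching the rest. The wrapper class `ResidualMLP` and its attribute names (`stem`, `blocks`, `final_norm`, `head`) are hypothetical; `ResidualMLPBlock` is repeated from earlier so the example runs on its own.

```python
import torch
import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    # Repeated from the earlier cell for self-containment
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return x + self.ff(self.norm(x))

class ResidualMLP(nn.Module):
    """Hypothetical wrapper: stem -> N residual blocks -> final norm -> head."""
    def __init__(self, input_dim, dim, hidden_dim, output_dim,
                 num_blocks: int = 3, dropout: float = 0.1):
        super().__init__()
        self.stem = nn.Linear(input_dim, dim)   # project inputs to the block width
        self.blocks = nn.Sequential(
            *[ResidualMLPBlock(dim, hidden_dim, dropout) for _ in range(num_blocks)]
        )
        self.final_norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, output_dim)  # task-specific output layer

    def forward(self, x):
        return self.head(self.final_norm(self.blocks(self.stem(x))))

net = ResidualMLP(input_dim=10, dim=32, hidden_dim=64, output_dim=1)
out = net(torch.randn(4, 10))
print(out.shape)

# Swap the head for a 5-way output; stem and blocks are reused unchanged
net.head = nn.Linear(32, 5)
out5 = net(torch.randn(4, 10))
print(out5.shape)
```

Assigning a new module to an attribute re-registers it, so the replacement head's parameters show up in `net.parameters()` automatically.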

## Comprehensive Exercise – Configurable Feedforward Network

Create an MLP that accepts a list of hidden dimensions and an optional dropout rate, and that adds a residual connection wherever consecutive layer widths match. Provide a `forward_with_intermediates` method that returns the per-layer activations alongside the output for debugging.

In [None]:
class ConfigurableMLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.0):
        super().__init__()
        # TODO: construct layers and residual bookkeeping

    def forward(self, x):
        raise NotImplementedError

    def forward_with_intermediates(self, x):
        # TODO: return (output, activations)
        raise NotImplementedError


In [None]:
class ConfigurableMLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.0):
        super().__init__()
        dims = [input_dim] + list(hidden_dims) + [output_dim]
        layers = []
        self.residual_flags = []
        for idx in range(len(dims) - 1):
            in_dim, out_dim = dims[idx], dims[idx + 1]
            layers.append(nn.Linear(in_dim, out_dim))
            if idx < len(dims) - 2:
                layers.append(nn.LayerNorm(out_dim))
                layers.append(nn.GELU())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
                self.residual_flags.append(in_dim == out_dim)
        self.residual_flags.append(False)
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        out, _ = self.forward_with_intermediates(x)
        return out

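    def forward_with_intermediates(self, x):
        activations = []  # one entry per Linear layer, recorded after its block finishes
        residual = x      # input to the current block, used for the skip connection
        idx = 0           # position in self.layers
        flag_idx = 0      # position in self.residual_flags (one flag per hidden block)
        while idx < len(self.layers):
            layer = self.layers[idx]  # always an nn.Linear at this point
            x = layer(x)
            idx += 1
            # Hidden blocks are Linear -> LayerNorm -> GELU [-> Dropout];
            # the output layer is a bare Linear and skips this branch.
            if idx < len(self.layers) and isinstance(self.layers[idx], nn.LayerNorm):
                norm = self.layers[idx]
                act = self.layers[idx + 1]
                x = act(norm(x))
                idx += 2
                if idx < len(self.layers) and isinstance(self.layers[idx], nn.Dropout):
                    x = self.layers[idx](x)
                    idx += 1
                # Add the skip connection only when input and output widths match.
                if self.residual_flags[flag_idx]:
                    x = x + residual
                residual = x
                flag_idx += 1
            activations.append(x)
        return x, activations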

mlp = ConfigurableMLP(10, [32, 32, 16], 1, dropout=0.1)
out, acts = mlp.forward_with_intermediates(torch.randn(4, 10))
print(out.shape, len(acts))
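
An alternative to a dedicated `forward_with_intermediates` method is to record activations with forward hooks, which works on modules you cannot easily modify. The sketch below is one possible approach; the helper name `capture_activations` and the small `nn.Sequential` model are hypothetical, while `register_forward_hook` and `named_modules` are standard `nn.Module` methods.

```python
import torch
import torch.nn as nn

def capture_activations(model: nn.Module, layer_types=(nn.Linear,)):
    """Hypothetical helper: record the output of every submodule of the given types."""
    records = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            # Bind `name` as a default argument so each hook keeps its own key
            def hook(mod, inputs, output, name=name):
                records[name] = output.detach()
            handles.append(module.register_forward_hook(hook))
    return records, handles

net = nn.Sequential(nn.Linear(10, 32), nn.GELU(), nn.Linear(32, 1))
records, handles = capture_activations(net)

out = net(torch.randn(4, 10))   # hooks fire during this forward pass

for handle in handles:          # remove hooks once the debugging pass is done
    handle.remove()

print({name: tuple(t.shape) for name, t in records.items()})
```

Hooks keep the debugging logic out of `forward`, at the cost of having to remember to remove them afterwards.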


## Further Reading

- PyTorch `nn` Module Reference: https://pytorch.org/docs/stable/nn.html
- He et al. (2015) – Deep Residual Learning for Image Recognition
- FastAI lessons on reusable model blocks