# Lecture 6: Batch Normalization and Residual Streams

In this lecture, we discuss two techniques that make deep neural networks much easier to train: Batch Normalization and Residual Streams (residual connections). We cover each technique in detail and show how it improves the training stability and performance of deep networks.

### Importing libraries

In [1]:
import os
import matplotlib.pyplot as plt
import itertools
from dataclasses import dataclass
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from src.utils import load_text, set_seed, configure_device

### Configuration

In [2]:
@dataclass
class MLPConfig:
    """Hyperparameters and paths for the character-level MLP experiments.

    The notebook reads and writes these as class attributes (e.g. ``MLPConfig.seed``,
    ``MLPConfig.device = ...``), so the defaults below act as shared, mutable settings.
    """

    # Paths: repository root (two directories above the working directory) and the dataset inside it
    root_dir: str = f"{os.getcwd()}/../../"
    dataset_path: str = "data/names.txt"
    # Defaults to CPU; the device cell overwrites it with the result of configure_device()
    device: torch.device = torch.device("cpu")

    # Tokenizer
    vocab_size: int = 0  # filled in once the character vocabulary is built

    # Model
    context_size: int = 3  # number of previous characters fed to the model
    d_embed: int = 16  # embedding width per character
    d_hidden: int = 256  # hidden-layer width

    # Training
    val_size: float = 0.1  # fraction of names held out for validation
    batch_size: int = 32
    max_steps: int = 1_000
    lr: float = 1e-2
    val_interval: int = 100  # steps between validation passes
    log_interval: int = 100  # steps between log lines

    seed: int = 101

## Reproducibility

In [4]:
# Seed the RNGs through the project helper so weight init and data shuffling are reproducible
# (NOTE(review): which libraries get seeded is defined in src.utils.set_seed — confirm there)
set_seed(MLPConfig.seed)

Random seed set to 101


## Device

In [5]:
# Let the project helper pick the compute device (the run below selected "mps") and store it
# on the shared config so the model and batches are moved to the same place later
MLPConfig.device = configure_device()

Running on mps


## Dataset

In [6]:
# Load the raw dataset text and split it into a list with one name per line
names = load_text(MLPConfig.root_dir + MLPConfig.dataset_path).splitlines()

Loaded text data from /Users/pathfinder/Documents/GitHub/LLM101/notebooks/Lectures/../../data/names.txt (length: 228145 characters).


## Tokenizer

In [7]:
# Vocabulary: "." sits at index 0 and serves as the padding/end-of-name token, followed by "a".."z"
chars = ["."] + [chr(code) for code in range(ord("a"), ord("z") + 1)]
MLPConfig.vocab_size = len(chars)
str2idx = {ch: i for i, ch in enumerate(chars)}  # character -> token id
idx2str = dict(enumerate(chars))  # token id -> character

## Preprocessing

In [8]:
# Train-Val Split: hold out MLPConfig.val_size of the *names* (before they are cut into
# character windows), so no single name contributes examples to both splits.
# random_state makes the split reproducible.
train_names, val_names = train_test_split(names, test_size=MLPConfig.val_size, random_state=MLPConfig.seed)

In [9]:
class NamesDataset(Dataset):
    """Sliding-window next-character dataset built from a list of names.

    One example is produced for every character of ``name + "."``. The input is the
    ``context_size`` preceding token ids (left-padded with id 0), and the target is the
    id of that character.

    Args:
        _names: iterable of name strings.
        context_size: number of preceding tokens in each input window.
        stoi: optional character -> id mapping. Defaults to the notebook-level ``str2idx``
            so existing calls keep working. Passing it explicitly removes the hidden
            dependency on a global defined in another cell.

    Raises:
        KeyError: if a name contains a character that is missing from the mapping.
    """

    def __init__(self, _names, context_size, stoi=None):
        if stoi is None:
            stoi = str2idx  # notebook-level vocabulary built in the tokenizer cell
        self.inputs, self.targets = [], []

        for name in _names:
            context = [0] * context_size  # start-of-name padding uses id 0

            for char in name + ".":  # trailing "." marks the end of the name
                idx = stoi[char]
                self.inputs.append(context)
                self.targets.append(idx)
                context = context[1:] + [idx]  # slide the window (builds a new list, no aliasing)

        # Tensorise once, instead of allocating two new tensors on every __getitem__ call.
        # reshape() keeps the (N, context_size) shape even when N == 0 or context_size == 0.
        self._input_ids = torch.tensor(self.inputs, dtype=torch.long).reshape(len(self.inputs), context_size)
        self._target_ids = torch.tensor(self.targets, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_id)``: a (context_size,) long tensor and a 0-d long tensor."""
        return self._input_ids[idx], self._target_ids[idx]

# Build sliding-window datasets for each split of names.
train_dataset = NamesDataset(train_names, MLPConfig.context_size)
val_dataset = NamesDataset(val_names, MLPConfig.context_size)
# Only the training loader shuffles; validation batches are always visited in the same order.
train_loader = DataLoader(train_dataset, batch_size=MLPConfig.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=MLPConfig.batch_size, shuffle=False)

## Model

In [10]:
class MLP(nn.Module):
    """Baseline character MLP: embedding lookup -> one tanh hidden layer -> vocabulary logits.

    All weights are raw ``nn.Parameter`` tensors drawn from a standard normal, with no scaling
    and no normalization. Later cells improve on this.
    """

    def __init__(self, vocab_size, context_size, d_embed, d_hidden):
        super().__init__()
        # Creation order (C, W1, b1, W2, b2) fixes how the seeded RNG is consumed.
        self.C = nn.Parameter(torch.randn(vocab_size, d_embed))  # embedding table
        self.W1 = nn.Parameter(torch.randn(context_size * d_embed, d_hidden))
        self.b1 = nn.Parameter(torch.randn(d_hidden))
        self.W2 = nn.Parameter(torch.randn(d_hidden, vocab_size))
        self.b2 = nn.Parameter(torch.randn(vocab_size))

    def forward(self, x):
        """Map token ids of shape (batch_size, context_size) to logits of shape (batch_size, vocab_size)."""
        batch = x.size(0)
        # Look up every context token, then concatenate the embeddings of one window into a single row
        flat = self.C[x].view(batch, -1)  # (batch_size, context_size * d_embed)
        hidden = torch.tanh(flat @ self.W1 + self.b1)  # (batch_size, d_hidden)
        return hidden @ self.W2 + self.b2  # (batch_size, vocab_size)

In [11]:
# Build the baseline model on the configured device and report its size
mlp = MLP(
    vocab_size=MLPConfig.vocab_size,
    context_size=MLPConfig.context_size,
    d_embed=MLPConfig.d_embed,
    d_hidden=MLPConfig.d_hidden,
).to(MLPConfig.device)  # nn.Module.to moves parameters in place and returns the same module
print(mlp)
print("Number of parameters:", sum(param.numel() for param in mlp.parameters()))

MLP()
Number of parameters: 19915


In [12]:
# Training
def _infinite_batches(loader: DataLoader):
    """Yield batches from `loader` forever, starting a fresh pass (and reshuffle) every epoch.

    Unlike ``itertools.cycle``, this does not cache every batch of the first epoch in memory,
    so a shuffling loader produces a new order each epoch.

    Raises:
        ValueError: if the loader yields no batches at all (would otherwise loop forever).
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("train_loader yielded no batches")


def _evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the per-sample mean cross-entropy of `model` over every batch in `loader`.

    Switches the model to eval mode and disables gradient tracking.

    Raises:
        ValueError: if the loader yields no samples (the mean would be undefined).
    """
    model.eval()
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            # Sum (not mean) so that a smaller final batch is weighted correctly
            total_loss += F.cross_entropy(logits, targets, reduction="sum").item()
            total_samples += targets.size(0)
    if total_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return total_loss / total_samples


def train(
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        max_steps: int,
        lr: float,
        val_interval: int,
        log_interval: int,
        device: torch.device,
        plot: bool = True,
):
    """Train `model` with plain SGD for `max_steps` mini-batches, then report (and plot) the losses.

    Every `val_interval` steps the full validation set is evaluated. Every `log_interval`
    steps, a line is printed with the mean training loss since the previous log line and the
    most recent validation loss ("n/a" if none has been computed yet). The two intervals are
    independent; neither needs to divide the other.

    Args:
        model: module mapping input batches to logits of shape (batch, num_classes).
        train_loader: yields (inputs, targets) training batches. It is re-iterated each epoch.
        val_loader: yields (inputs, targets) validation batches.
        max_steps: number of optimizer steps.
        lr: SGD learning rate.
        val_interval: steps between full validation passes.
        log_interval: steps between printed/recorded training-loss points.
        device: device that each batch is moved to (the model is assumed to be there already).
        plot: if True, draw the train/validation loss curves with matplotlib at the end.

    Returns:
        None. Results are printed and optionally plotted.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    log_steps, train_losses = [], []  # one point per log line
    val_steps, val_losses = [], []  # one point per validation pass
    batches = _infinite_batches(train_loader)
    window_loss, window_count = 0.0, 0  # training loss accumulated since the last log line
    val_loss = None  # most recent validation loss, if any

    for step in range(1, max_steps + 1):
        model.train()
        train_inputs, train_targets = next(batches)
        train_inputs, train_targets = train_inputs.to(device), train_targets.to(device)
        optimizer.zero_grad()
        logits = model(train_inputs)
        loss = F.cross_entropy(logits, train_targets)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        window_loss += loss_value
        window_count += 1
        if step == 1:
            print(f"Initial Train Loss: {loss_value:.4f}")

        if step % val_interval == 0:
            val_loss = _evaluate(model, val_loader, device)
            val_steps.append(step)
            val_losses.append(val_loss)

        if step % log_interval == 0:
            mean_train_loss = window_loss / window_count
            log_steps.append(step)
            train_losses.append(mean_train_loss)
            window_loss, window_count = 0.0, 0
            val_text = "n/a" if val_loss is None else f"{val_loss:.4f}"
            print(f"Step {step}/{max_steps} | Train Loss: {mean_train_loss:.4f} | Validation Loss: {val_text}")

    # Score the whole validation set with gradients disabled, not just a single batch
    final_val_loss = _evaluate(model, val_loader, device)
    print(f"Final Validation Loss: {final_val_loss:.4f}")

    if plot:
        # Each curve uses its own x-values, so the intervals are free to differ
        _, ax = plt.subplots()
        ax.plot(log_steps, train_losses, label="Train")
        ax.plot(val_steps, val_losses, label="Validation")
        ax.set(title="Training and validation loss", xlabel="Steps", ylabel="Loss")
        ax.legend()
        plt.show()

In [13]:
# Train the baseline MLP (no normalization) using the shared config hyperparameters
train(
    model=mlp,
    train_loader=train_loader,
    val_loader=val_loader,
    max_steps=MLPConfig.max_steps,
    lr=MLPConfig.lr,
    val_interval=MLPConfig.val_interval,
    log_interval=MLPConfig.log_interval,
    device=MLPConfig.device
)

Training: 100%|██████████| 1000/1000 [00:15<00:00, 65.26it/s, loss=13.9290]


0,1
Train Loss,█▆▅▄▃▃▂▂▁▁
Val Loss,█▆▄▃▃▂▂▂▁▁

0,1
Train Loss,13.92902
Val Loss,9.56011


### Part 1: Batch Normalization

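Recall what we did when initializing the weights of a neural network:

1. We don't want the initial logits to be too large; otherwise the softmax saturates, the network makes confidently wrong predictions, and the initial loss is huge.
   - Initialize the final layer with small values.
2. We don't want the gradients to vanish (e.g., because `tanh` units saturate).
   - Initialize the inner layers with small, appropriately scaled values.
   - Use different activation functions.

Ultimately, what we want is for the pre-activations of every layer to keep a roughly unit Gaussian distribution (zero mean, unit variance) throughout the network.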

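**Why not just normalize the activations directly?** This is the key idea behind [Batch Normalization](https://arxiv.org/pdf/1502.03167) (Ioffe & Szegedy, 2015). Standardize each hidden unit's pre-activations using the mean and variance of the current mini-batch, then rescale and shift them with learnable parameters $\gamma$ and $\beta$.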

In [14]:
# MLP Model with Batch Normalization
class MLPv2(nn.Module):
    """MLP with batch normalization applied to the hidden pre-activations.

    Forward pass: embed -> flatten -> linear -> batch norm -> tanh -> linear logits.

    In training mode, the hidden pre-activations are standardized with the statistics of the
    current batch, and exponential moving averages of those statistics are updated. In eval
    mode, the moving averages are used instead. Predictions then do not depend on the other
    examples in the batch, and a batch of size 1 is well-defined.

    Args:
        vocab_size: number of tokens (rows of the embedding table and width of the logits).
        context_size: number of tokens in each input window.
        d_embed: embedding width per token.
        d_hidden: hidden-layer width.
        eps: added to the variance before the square root, for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, vocab_size, context_size, d_embed, d_hidden, eps=1e-5, momentum=0.1):
        super().__init__()
        self.C = nn.Parameter(torch.randn(vocab_size, d_embed))
        self.W1 = nn.Parameter(torch.randn(context_size * d_embed, d_hidden))
        self.b1 = nn.Parameter(torch.randn(d_hidden))
        self.W2 = nn.Parameter(torch.randn(d_hidden, vocab_size))
        self.b2 = nn.Parameter(torch.randn(vocab_size))

        # Batch Normalization: learnable scale/shift applied after standardization
        self.gamma = nn.Parameter(torch.ones(1, d_hidden))
        self.beta = nn.Parameter(torch.zeros(1, d_hidden))
        self.eps = eps
        self.momentum = momentum
        # Buffers, not parameters: they follow .to(), are saved in state_dict, and the optimizer
        # never sees them, so the parameter count is the same as without running statistics.
        self.register_buffer("running_mean", torch.zeros(1, d_hidden))
        self.register_buffer("running_var", torch.ones(1, d_hidden))

    def forward(self, x):  # x: (batch_size, context_size)
        # Embedding
        x_embed = self.C[x]  # (batch_size, context_size, d_embed)
        x = x_embed.view(x.size(0), -1)  # (batch_size, context_size * d_embed)

        # Hidden layer
        x = x @ self.W1 + self.b1  # (batch_size, d_hidden)

        # Batch Normalization
        ################################################################################
        # TODO:                                                                        #
        # Implement batch normalization of x over the batch dimension (dim=0).         #
        ################################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        if self.training:
            mean = x.mean(dim=0, keepdim=True)  # (1, d_hidden)
            var = x.var(dim=0, keepdim=True, unbiased=False)  # (1, d_hidden), biased as in the paper
            with torch.no_grad():
                n = x.size(0)
                # Running variance tracks the unbiased estimate (undefined for a single sample)
                unbiased_var = var * n / (n - 1) if n > 1 else var
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean, var = self.running_mean, self.running_var
        x = self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta  # (batch_size, d_hidden)
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

        h = F.tanh(x)  # (batch_size, d_hidden)

        # Output layer
        logits = torch.matmul(h, self.W2) + self.b2  # (batch_size, vocab_size)
        return logits

In [15]:
# Build the batch-norm model on the configured device and report its size
mlpV2 = MLPv2(
    vocab_size=MLPConfig.vocab_size,
    context_size=MLPConfig.context_size,
    d_embed=MLPConfig.d_embed,
    d_hidden=MLPConfig.d_hidden,
).to(MLPConfig.device)  # nn.Module.to moves parameters in place and returns the same module
print(mlpV2)
print("Number of parameters:", sum(param.numel() for param in mlpV2.parameters()))

MLPv2()
Number of parameters: 20427


In [16]:
# Train the batch-norm variant with the same hyperparameters as the baseline, so the loss curves are comparable
train(
    model=mlpV2,
    train_loader=train_loader,
    val_loader=val_loader,
    max_steps=MLPConfig.max_steps,
    lr=MLPConfig.lr,
    val_interval=MLPConfig.val_interval,
    log_interval=MLPConfig.log_interval,
    device=MLPConfig.device
)

Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Error: You must call wandb.init() before wandb.watch()

In [None]:
# Scratch cell intentionally left empty (it must not raise, so Restart & Run All completes).

In [None]:
# MLP Model with more layers
class MLPv3(nn.Module):
    """Deeper MLP variant with optional batch normalization (work in progress).

    Parameters are created for four linear layers (W1/b1 .. W4/b4) and two pairs of
    batch-norm scale/shift tensors (gamma1/beta1, gamma2/beta2), each of width d_hidden.

    NOTE(review): as written, ``forward`` stops after hidden layer 1 and implicitly returns
    None. W2..W4, b2..b4, the gamma/beta tensors and ``self.batch_norm`` are never used.
    The rest of the template appears to be missing. TODO: finish ``forward`` before
    instantiating or training this class.
    """

    def __init__(self, vocab_size, context_size, d_embed, d_hidden, batch_norm=False):
        super().__init__()
        self.C = nn.Parameter(torch.randn(vocab_size, d_embed))
        self.W1 = nn.Parameter(torch.randn(context_size * d_embed, d_hidden))
        self.b1 = nn.Parameter(torch.randn(d_hidden))
        self.W2 = nn.Parameter(torch.randn(d_hidden, d_hidden))
        self.b2 = nn.Parameter(torch.randn(d_hidden))
        self.W3 = nn.Parameter(torch.randn(d_hidden, vocab_size))
        self.b3 = nn.Parameter(torch.randn(vocab_size))
        # NOTE(review): W4 maps vocab_size -> vocab_size, i.e. a linear layer on top of W3's
        # outputs; confirm this extra layer is intended
        self.W4 = nn.Parameter(torch.randn(vocab_size, vocab_size))
        self.b4 = nn.Parameter(torch.randn(vocab_size))

        # Batch Normalization
        # Flag presumably meant to toggle the normalization steps in forward (currently unused)
        self.batch_norm = batch_norm
        # Presumably for the outputs of hidden layers 1 and 2 (both d_hidden wide) — TODO confirm
        self.gamma1 = nn.Parameter(torch.ones(1, d_hidden))
        self.beta1 = nn.Parameter(torch.zeros(1, d_hidden))
        self.gamma2 = nn.Parameter(torch.ones(1, d_hidden))
        self.beta2 = nn.Parameter(torch.zeros(1, d_hidden))

    def forward(self, x):  # x: (batch_size, context_size)
        """Incomplete: computes only the first hidden pre-activation and returns None (see class note)."""
        # Embedding
        x_embed = self.C[x]  # (batch_size, context_size, d_embed)
        x = x_embed.view(x.size(0), -1)  # (batch_size, context_size * d_embed)

        # Hidden layer 1
        x = x @ self.W1 + self.b1  # (batch_size, d_hidden)

        # Batch Normalization 1
        ################################################################################
        # NOTE(review): the cell ends here. Nothing after this point is implemented, so
        # forward() returns None. TODO: add BN 1 / tanh, layer 2 + BN 2, layers 3-4, and return logits.

## Inference