<a href="https://colab.research.google.com/github/BDH-teacher/Deep_Learning_Audit_code/blob/main/Normalization%2C_Optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Normalization + Optimization (RNN) Colab runnable demo
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Reproducibility
def set_seed(seed=0):
    """Seed every random number generator used by this notebook.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG,
    PyTorch's default generator, and the generators of all CUDA devices.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Fix all RNG seeds once up front, then prefer the GPU whenever CUDA is usable.
set_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Toy text dataset (binary classification)
# - label depends on the LAST token (so learning is visible)
def build_vocab(vocab_size=200):
    """Build a synthetic token -> id vocabulary.

    Id 0 is reserved for ``"<pad>"``; ids ``1 .. vocab_size - 1`` map to
    the tokens ``"tok1" .. "tok{vocab_size - 1}"``.

    Args:
        vocab_size: Total number of entries, including the padding token.

    Returns:
        dict mapping token string to integer id, in ascending id order.
    """
    vocab = {"<pad>": 0}
    vocab.update({f"tok{idx}": idx for idx in range(1, vocab_size)})
    return vocab

# Shared vocabulary: 200 entries, id 0 being the padding token.
vocab_to_int = build_vocab(200)

def make_toy_sequences(n=800, seq_len=15, vocab_size=200, trigger_id=42, p_trigger=0.5, seed=0):
    """Generate a toy binary-classification set whose label depends on the last token.

    Every token is drawn uniformly from ``[1, vocab_size)`` (0 is left free
    for padding). Each sample is made positive with probability
    ``p_trigger``: its last token is overwritten with ``trigger_id`` and its
    label set to 1. All other samples get label 0 and a freshly drawn last
    token that is guaranteed to differ from ``trigger_id``.

    Args:
        n: Number of sequences.
        seq_len: Tokens per sequence.
        vocab_size: Exclusive upper bound of the generated token ids.
        trigger_id: Token id that marks a positive sample.
        p_trigger: Probability that a sample is positive.
        seed: Seed for the local NumPy generator.

    Returns:
        Tuple ``(X, y)``: ``X`` is an int64 array of shape ``(n, seq_len)``
        and ``y`` an int64 array of shape ``(n,)`` holding 0/1 labels.

    Raises:
        ValueError: If a negative sample drew ``trigger_id`` as its last
            token and the vocabulary holds no other non-padding id to
            replace it with.
    """
    rng = np.random.default_rng(seed)
    X = rng.integers(1, vocab_size, size=(n, seq_len), dtype=np.int64)
    y = np.zeros(n, dtype=np.int64)

    mask = rng.random(n) < p_trigger
    X[mask, -1] = trigger_id
    y[mask] = 1

    # Make sure non-trigger samples don't accidentally end with trigger_id.
    negatives = ~mask
    n_negative = int(negatives.sum())
    if n_negative > 0:
        last = rng.integers(1, vocab_size, size=(n_negative,), dtype=np.int64)
        collisions = last == trigger_id
        if collisions.any():
            # Swap in a neighbouring id that is still inside [1, vocab_size);
            # a blind `trigger_id + 1` would leave the vocabulary when the
            # trigger is the largest id.
            if trigger_id + 1 < vocab_size:
                replacement = trigger_id + 1
            elif trigger_id - 1 >= 1:
                replacement = trigger_id - 1
            else:
                raise ValueError(
                    "vocab_size is too small to give negative samples a "
                    "last token different from trigger_id"
                )
            last[collisions] = replacement
        X[negatives, -1] = last

    return X, y


# Generate the toy corpus with token ids bounded by the shared vocabulary size.
X, y = make_toy_sequences(n=800, seq_len=15, vocab_size=len(vocab_to_int), seed=0)

# Contiguous split: first 70% train, next 15% dev, remainder test.
n = len(X)
n_train = int(n * 0.7)
n_dev = int(n * 0.15)

X_train, X_dev, X_test = np.split(X, [n_train, n_train + n_dev])
y_train, y_dev, y_test = np.split(y, [n_train, n_train + n_dev])

class TextDataset(Dataset):
    """Map-style dataset yielding BERT-like 4-tuples for token sequences.

    Each item is ``(input_ids, attention_masks, segment_ids, labels)``:
    the attention mask is 1 wherever the token id is non-zero, and the
    segment ids are all zeros.
    """

    def __init__(self, X, y):
        # Both inputs are copied into int64 tensors.
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        token_ids = self.X[idx]
        return (
            token_ids,
            token_ids.ne(0).to(torch.long),  # 1 for real tokens, 0 for id 0
            torch.zeros_like(token_ids),     # single-segment input
            self.y[idx],
        )

# Only the training loader shuffles; dev/test keep a fixed order and use larger batches.
train_loader = DataLoader(dataset=TextDataset(X_train, y_train), batch_size=64, shuffle=True)
dev_loader = DataLoader(dataset=TextDataset(X_dev, y_dev), batch_size=128, shuffle=False)
test_loader = DataLoader(dataset=TextDataset(X_test, y_test), batch_size=128, shuffle=False)

In [2]:
# RNN + (BatchNorm or LayerNorm) before RNN forward

class RNNModel(nn.Module):
    """Embedding -> optional normalization -> multi-layer RNN -> linear classifier.

    The normalization applied to the embedded sequence is selected by the
    ``norm_mode`` attribute, which may be changed after construction:
    ``"bn"`` uses BatchNorm1d over the feature dimension, ``"ln"`` uses
    LayerNorm over the feature dimension, and any other value (e.g.
    ``None``) skips normalization. Both normalization layers are always
    created, so the parameter set does not depend on the mode.

    Args:
        input_size: Embedding dimension, which is also the RNN input size.
        hidden_size: RNN hidden state size.
        num_layers: Number of stacked RNN layers.
        output_size: Number of output logits.
        norm_mode: Initial normalization mode (``"bn"``, ``"ln"`` or ``None``).
        vocab_size: Number of embedding rows. Defaults to the size of the
            module-level ``vocab_to_int``.
        padding_idx: Embedding index treated as padding. Defaults to the
            ``"<pad>"`` entry of the module-level ``vocab_to_int``.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size,
                 norm_mode="ln", vocab_size=None, padding_idx=None):
        super().__init__()

        # Fall back to the notebook-wide vocabulary only when not told otherwise,
        # so the model can also be built without that global.
        if vocab_size is None:
            vocab_size = len(vocab_to_int)
        if padding_idx is None:
            padding_idx = vocab_to_int['<pad>']

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, input_size, padding_idx=padding_idx)

        # Batch normalization over the embedding features
        self.bn = nn.BatchNorm1d(input_size)

        # Layer normalization over the embedding features
        self.ln = nn.LayerNorm(input_size)

        # Recurrent encoder; inputs are (batch, seq, feature)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

        # Classifier head applied to the last time step
        self.fc = nn.Linear(hidden_size, output_size)

        # Normalization mode: "bn" / "ln" / None
        self.norm_mode = norm_mode

    def forward(self, x):
        """Return logits of shape (batch, output_size) for token ids ``x`` of shape (batch, seq)."""
        embedding = self.embedding(x)

        if self.norm_mode == "bn":
            # BatchNorm1d expects (batch, channels, length): swap the last
            # two axes for the call and swap them back afterwards.
            embedding = self.bn(embedding.permute(0, 2, 1)).permute(0, 2, 1)
        elif self.norm_mode == "ln":
            embedding = self.ln(embedding)

        # Forward propagate the RNN; the final hidden state is not needed.
        out, _ = self.rnn(embedding)

        # Classify from the output at the last time step.
        return self.fc(out[:, -1, :])


# Train / Eval (Optimization + Monitoring logs)
def accuracy_percent(logits, labels):
    """Return the percentage (0-100) of rows whose arg-max class equals the label.

    Args:
        logits: Score tensor; the predicted class is the arg-max along dim 1.
        labels: Tensor of target class indices, one per row of ``logits``.

    Returns:
        Python float in [0, 100].
    """
    predicted = torch.argmax(logits, dim=1)
    hit_rate = torch.eq(predicted, labels).float().mean()
    return hit_rate.item() * 100.0

@torch.no_grad()
def evaluate(loader, model, criterion):
    """Compute the sample-weighted mean loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled for the
    whole call. Per-batch results are weighted by batch size, so a smaller
    final batch does not distort the averages.

    Args:
        loader: Iterable of ``(input_ids, attention_masks, segment_ids,
            labels)`` batches; only ``input_ids`` and ``labels`` are used.
        model: Module mapping ``input_ids`` to class logits.
        criterion: Loss callable ``criterion(logits, labels)``. Assumed to
            return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``) -- the weighting relies on this.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` of Python floats. Both are
        NaN when the loader yields no samples.
    """
    model.eval()

    # Run batches on whatever device the model lives on; fall back to the
    # notebook-wide `device` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for input_ids, _attention_masks, _segment_ids, labels in loader:
        input_ids = input_ids.to(target_device)
        labels = labels.to(target_device)

        out = model(input_ids)
        loss = criterion(out, labels)

        batch_size = labels.size(0)
        # Undo the per-batch mean so every sample carries equal weight.
        total_loss += loss.item() * batch_size
        total_correct += (out.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        return float("nan"), float("nan")

    return total_loss / total_samples, total_correct / total_samples * 100.0

In [3]:
def train_one_experiment(
    title: str,
    lr: float,
    epochs: int,
    norm_mode: str = "ln",
    max_grad_norm: float = 0.0,
    stop_if_loss_gt: float = 1e8):
    """Train a fresh RNNModel with plain SGD and print per-epoch metrics.

    Seeds are reset to 0 before the model is built. Each epoch prints the
    mean of the per-batch training loss/accuracy plus the dev-set loss and
    accuracy; if every epoch completes, the test-set metrics are printed
    last.

    Args:
        title: Heading printed in the banner.
        lr: SGD learning rate.
        epochs: Maximum number of epochs.
        norm_mode: Normalization mode assigned to the model ("bn", "ln" or None).
        max_grad_norm: Gradient-norm clipping threshold; clipping is skipped
            unless this is greater than 0.
        stop_if_loss_gt: Training stops early as soon as a batch loss, or an
            epoch's mean training loss, is non-finite or exceeds this value.

    Returns:
        None. All results are reported through ``print``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print(f"{title} | lr={lr:g} | norm={norm_mode} | max_grad_norm={max_grad_norm}")
    print(banner)

    # Identical initial weights and shuffle order for every experiment.
    set_seed(0)
    model = RNNModel(input_size=10, hidden_size=20, num_layers=2, output_size=2).to(device)
    model.norm_mode = norm_mode

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    use_clipping = max_grad_norm > 0.

    for epoch in range(1, epochs + 1):
        model.train()
        batch_losses, batch_accs = [], []

        for input_ids, _attention_masks, _segment_ids, labels in train_loader:
            input_ids, labels = input_ids.to(device), labels.to(device)

            out = model(input_ids)
            loss = criterion(out, labels)
            loss_value = loss.item()

            # Bail out before back-propagating a diverged loss.
            if (not torch.isfinite(loss)) or (loss_value > stop_if_loss_gt):
                print(f"{epoch} epoch, train_loss = {loss_value:.6f}  -> (stopped early: non-finite or exploding)")
                return

            optimizer.zero_grad()
            loss.backward()
            if use_clipping:
                # Gradient clipping
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            batch_losses.append(loss_value)
            batch_accs.append(accuracy_percent(out, labels))

        train_loss = float(np.mean(batch_losses))
        train_acc = float(np.mean(batch_accs))
        eval_loss, eval_acc = evaluate(dev_loader, model, criterion)

        print(f"{epoch} epoch, train_loss = {train_loss:.6f}, train_acc:{train_acc:.1f}, eval_loss:{eval_loss:.6f}, eval_acc:{eval_acc:.1f}")

        if (not np.isfinite(train_loss)) or (train_loss > stop_if_loss_gt):
            print("(stopped early: average loss exploded)")
            return

    test_loss, test_acc = evaluate(test_loader, model, criterion)
    print(f"[FINAL TEST] loss:{test_loss:.6f}, acc:{test_acc:.1f}")

In [4]:
# CASE 1: learning rate far too small, so the loss barely moves.
train_one_experiment(
    "CASE 1) LR too small -> loss barely changes",
    1e-5,
    35,
    norm_mode="ln",
    max_grad_norm=0.0,
)


CASE 1) LR too small -> loss barely changes | lr=1e-05 | norm=ln | max_grad_norm=0.0
1 epoch, train_loss = 0.810851, train_acc:38.7, eval_loss:0.798933, eval_acc:40.8
2 epoch, train_loss = 0.809332, train_acc:38.9, eval_loss:0.798894, eval_acc:40.8
3 epoch, train_loss = 0.809024, train_acc:39.1, eval_loss:0.798854, eval_acc:40.8
4 epoch, train_loss = 0.810841, train_acc:38.7, eval_loss:0.798814, eval_acc:40.8
5 epoch, train_loss = 0.810063, train_acc:38.7, eval_loss:0.798774, eval_acc:40.8
6 epoch, train_loss = 0.810356, train_acc:38.7, eval_loss:0.798735, eval_acc:40.8
7 epoch, train_loss = 0.808025, train_acc:39.2, eval_loss:0.798696, eval_acc:40.8
8 epoch, train_loss = 0.808075, train_acc:39.2, eval_loss:0.798656, eval_acc:40.8
9 epoch, train_loss = 0.809855, train_acc:38.8, eval_loss:0.798617, eval_acc:40.8
10 epoch, train_loss = 0.808708, train_acc:39.1, eval_loss:0.798577, eval_acc:40.8
11 epoch, train_loss = 0.808219, train_acc:39.2, eval_loss:0.798538, eval_acc:40.8
12 epoch, 

In [5]:
# CASE 2: learning rate far too large; training aborts once the loss passes 1e6.
train_one_experiment(
    "CASE 2) LR too large -> loss explodes (stop before NaN)",
    1e3,
    15,
    norm_mode="ln",
    max_grad_norm=0.0,
    stop_if_loss_gt=1e6,
)


CASE 2) LR too large -> loss explodes (stop before NaN) | lr=1000 | norm=ln | max_grad_norm=0.0
1 epoch, train_loss = 3492.913072, train_acc:51.7, eval_loss:7009.353027, eval_acc:50.8
2 epoch, train_loss = 4956.795519, train_acc:50.8, eval_loss:9199.971680, eval_acc:50.0
3 epoch, train_loss = 5993.430664, train_acc:48.1, eval_loss:4981.220215, eval_acc:50.0
4 epoch, train_loss = 5732.435791, train_acc:48.8, eval_loss:1449.972412, eval_acc:50.0
5 epoch, train_loss = 4616.103638, train_acc:53.2, eval_loss:2200.235352, eval_acc:50.8
6 epoch, train_loss = 4872.315111, train_acc:51.6, eval_loss:5641.900879, eval_acc:50.8
7 epoch, train_loss = 4441.189134, train_acc:52.4, eval_loss:10206.578125, eval_acc:50.0
8 epoch, train_loss = 5686.153388, train_acc:50.2, eval_loss:5731.222168, eval_acc:50.0
9 epoch, train_loss = 5239.835626, train_acc:51.0, eval_loss:1762.473755, eval_acc:50.0
10 epoch, train_loss = 6167.294908, train_acc:45.9, eval_loss:2676.535645, eval_acc:50.8
11 epoch, train_loss 

In [6]:
# CASE 3: a reasonable learning rate, so the loss decreases steadily.
train_one_experiment(
    "CASE 3) LR=1e-3 -> learning works",
    1e-3,
    100,
    norm_mode="ln",
    max_grad_norm=0.0,
)


CASE 3) LR=1e-3 -> learning works | lr=0.001 | norm=ln | max_grad_norm=0.0
1 epoch, train_loss = 0.808923, train_acc:38.7, eval_loss:0.795032, eval_acc:40.0
2 epoch, train_loss = 0.803179, train_acc:38.6, eval_loss:0.791240, eval_acc:40.0
3 epoch, train_loss = 0.798767, train_acc:38.5, eval_loss:0.787550, eval_acc:40.0
4 epoch, train_loss = 0.796386, train_acc:38.0, eval_loss:0.783918, eval_acc:40.0
5 epoch, train_loss = 0.791721, train_acc:38.0, eval_loss:0.780402, eval_acc:40.0
6 epoch, train_loss = 0.788136, train_acc:37.7, eval_loss:0.776972, eval_acc:40.0
7 epoch, train_loss = 0.782601, train_acc:38.1, eval_loss:0.773693, eval_acc:40.0
8 epoch, train_loss = 0.779080, train_acc:37.7, eval_loss:0.770491, eval_acc:40.0
9 epoch, train_loss = 0.776995, train_acc:36.9, eval_loss:0.767334, eval_acc:40.0
10 epoch, train_loss = 0.772754, train_acc:37.0, eval_loss:0.764276, eval_acc:40.0
11 epoch, train_loss = 0.769076, train_acc:36.9, eval_loss:0.761304, eval_acc:40.8
12 epoch, train_loss