# MNIST

## Importing libraries

In [None]:
import os
import random
import umap
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import wandb

## Configuration

In [None]:
@dataclass
class Config:
    """Run settings for this notebook, gathered in one place."""
    debug: bool = True  # presumably gates the Weights & Biases login cell below - TODO confirm
    dataset_path: str = "data/"  # Root directory handed to the MNIST dataset constructors
    device: torch.device = torch.device('cpu')  # Placeholder; replaced by the result of configure_device() further down
    
    # Model
    d_embed: int = 16  # presumably the hidden width given to the MLP - TODO confirm at the call site
    
    # Training
    batch_size: int = 100  # Batch size of the training DataLoader
    max_steps: int = 600  # Total number of training samples = max_steps * batch_size = 60,000
    lr: float = 0.01  # Learning rate

    seed: int = 101  # Value passed to set_seed()
    
# Single shared settings instance read by the cells that follow
config = Config()

## Weights & Biases

In [None]:
# Log in to Weights & Biases only for non-debug runs. The flag lives on the
# notebook's `config` object (no `args` object is defined in this notebook);
# the API key is read from the WANDB_API_KEY environment variable.
if not config.debug:
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

## Reproducibility

In [None]:
def set_seed(seed: int):
    """
    Seed every random number generator this notebook relies on.

    Args:
        seed (int): Value used to seed Python's `random`, NumPy and PyTorch.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Feed the same value to each library's seeding function
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Request deterministic cuDNN kernels and switch off its benchmark mode
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print(f"Random seed set to {seed}")
    
# Seed all random number generators with the configured seed before anything else runs
set_seed(config.seed)

## Device

In [None]:
def configure_device() -> torch.device:
    """
    Select the device to train on, preferring CUDA, then MPS, then the CPU.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        gpu_count = torch.cuda.device_count()
        print(f"Running on {gpu_count} {torch.cuda.get_device_name()} GPU(s)")
        return chosen

    # No CUDA: use the MPS backend when it is available, otherwise the CPU
    backend = "mps" if torch.backends.mps.is_available() else "cpu"
    chosen = torch.device(backend)
    print(f"Running on {chosen}")
    return chosen

# Replace the placeholder device stored in the config with the detected one
config.device = configure_device()

## Dataset

In [None]:
# MNIST train and test splits, downloaded into `config.dataset_path` when
# missing; every image is converted with `ToTensor` on access.
mnist_train, mnist_test = (
    datasets.MNIST(
        root=config.dataset_path,
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

# Training batches: shuffled, and an incomplete final batch is dropped
train_loader = DataLoader(
    mnist_train,
    batch_size=config.batch_size,
    drop_last=True,
    shuffle=True,
)

# Test batches: fixed order, batch size of 10000, nothing dropped
test_loader = DataLoader(
    mnist_test,
    batch_size=10000,
    drop_last=False,
    shuffle=False,
)

## Model

In [None]:
class MLP(nn.Module):
    """
    Two-layer perceptron: 784 inputs -> d_embed hidden units (ReLU) -> 10 logits.

    Args:
        d_embed (int): Width of the hidden layer.
    """

    def __init__(self, d_embed):
        super().__init__()
        self.linear1 = nn.Linear(784, d_embed)
        self.linear2 = nn.Linear(d_embed, 10)

    def forward(self, x):  # x: (batch_size, 1, 28, 28)
        """Flatten each sample and return the class logits, shape (batch_size, 10)."""
        x = x.view(x.size(0), -1)  # (batch_size, 784)
        x = F.relu(self.linear1(x))  # (batch_size, d_embed)
        return self.linear2(x)  # (batch_size, 10)

    def loss(self, logits, targets):
        """
        Cross-entropy between the logits and integer class labels.

        The training and evaluation loops in this notebook call
        `model.loss(logits, targets)`, so the model has to provide it.

        Args:
            logits: Output of `forward`, one row of class scores per sample.
            targets: Integer class index for each sample.

        Returns:
            Scalar tensor holding the mean cross-entropy over the batch.
        """
        return F.cross_entropy(logits, targets)

## Training

In [None]:
def train_epoch(
        model: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler.LambdaLR, current_epoch: int, total_epochs: int,
        grad_clip: float, device: torch.device,
        wandb_run: "wandb.sdk.wandb_run.Run | None") -> nn.Module:
    """
    Train the model for one epoch.

    The optimizer, scheduler and gradient-clipping helpers are reached through
    the `optim` / `nn` aliases the notebook already imports, so no extra
    imports are needed.

    Args:
        model (nn.Module): The model to train. Must provide `loss(logits, targets)`.
        dataloader (DataLoader): DataLoader for the training data.
        optimizer (optim.Optimizer): Optimizer to use.
        scheduler (optim.lr_scheduler.LambdaLR): Learning rate scheduler, stepped once per batch.
        current_epoch (int): Current epoch number (zero-based).
        total_epochs (int): Total number of epochs.
        grad_clip (float): Maximum gradient norm passed to `clip_grad_norm_`.
        device (torch.device): Device to run the model on.
        wandb_run (wandb.sdk.wandb_run.Run | None): Wandb run for logging, or
            None to train without any wandb logging.

    Returns:
        model (nn.Module): The trained model.
    """
    model.train()
    num_batches = len(dataloader)
    running_loss = 0.0
    progress_bar = tqdm(enumerate(dataloader), total=num_batches, desc=f"Epoch {current_epoch+1}/{total_epochs}")

    # Only hook the model into wandb when a run was actually supplied; the
    # per-batch logging below already treats `None` as "logging disabled".
    if wandb_run is not None:
        wandb_run.watch(model, log="all", log_freq=num_batches)

    for batch_idx, (inputs, targets) in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = model.loss(logits, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        # Read the scalar once and reuse it for the running mean and the log
        batch_loss = loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix(loss=f"{running_loss / (batch_idx + 1):.4f}")

        if wandb_run is not None:
            wandb_run.log(
                {"Train Loss": batch_loss,
                "Learning Rate": optimizer.param_groups[0]['lr']},
                step=current_epoch * num_batches + batch_idx
            )

    progress_bar.close()
    print(f"Epoch {current_epoch+1}/{total_epochs} Loss: {running_loss / num_batches:.4f}")

    return model

## Evaluation

In [None]:
def evaluate(model: nn.Module, dataloader: DataLoader, device: torch.device, wandb_run: wandb.sdk.wandb_run.Run) -> float:
    """
    Compute the mean per-batch loss of the model over a dataloader.

    Args:
        model (nn.Module): The model to evaluate; must provide `loss(logits, targets)`.
        dataloader (DataLoader): DataLoader for the validation data.
        device (torch.device): Device to run the model on.
        wandb_run (wandb.sdk.wandb_run.Run): Wandb run for logging; nothing is logged when it is None.

    Returns:
        loss (float): The average of the batch losses.
    """
    model.eval()
    num_batches = len(dataloader)
    loss_sum = 0.0
    batches = tqdm(enumerate(dataloader), total=num_batches, desc="Validation")

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for step, (inputs, targets) in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_loss = model.loss(model(inputs), targets)
            loss_sum += batch_loss.item()
            batches.set_postfix(loss=f"{loss_sum / (step + 1):.4f}")

    batches.close()

    avg_loss = loss_sum / num_batches

    if wandb_run is not None:
        wandb_run.log({"Validation Loss": avg_loss})

    print(f"Validation Loss: {avg_loss:.4f}")

    return avg_loss

## Scaling laws

In [None]:
def plot_scaling_laws(x, y, x_label, y_label, title, wandb_run):
    """
    Plot the given data in log-log scale and log the figure to wandb.

    A straight line is fitted to (log10(x), log10(y)) and drawn on top of the
    scatter plot as the power law y = 10**intercept * x**slope.

    Args:
        x: Sequence of x values.
        y: Sequence of y values, same length as `x`.
        x_label (str): Label of the x axis.
        y_label (str): Label of the y axis.
        title (str): Figure title; also used as the key of the logged image.
        wandb_run: Wandb run that receives the figure, or None to skip logging.
    """
    from scipy.stats import linregress

    # Convert to log-space and fit a line there
    x_log = np.log10(x)
    y_log = np.log10(y)
    slope, intercept, *_ = linregress(x_log, y_log)

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.scatter(x, y, label='Data', alpha=0.7)

        # Best fit line, evaluated in log space and mapped back for plotting
        fit_line_x = np.linspace(x_log.min(), x_log.max(), 100)
        fit_line_y = slope * fit_line_x + intercept
        ax.plot(10 ** fit_line_x, 10 ** fit_line_y, 'r--',
                label=f'Fit: y = {10 ** intercept:.2f} * x^{slope:.2f}')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

        # Log to the run that was passed in rather than to whichever run
        # happens to be the globally active one.
        if wandb_run is not None:
            wandb_run.log({title: wandb.Image(fig)})
    finally:
        # Always release the figure, even when plotting or logging fails
        plt.close(fig)

### Compute

In [None]:
def compute_experiment(
        model_sizes: dict, dataset_sizes: dict,
        train_text: str, val_text: str, tokenizer: "CharTokenizer | BPETokenizer",
        optimizer_name: str, lr: float, weight_decay: float, scheduler_type: str, warmup_ratio: float,
        grad_clip: float, device: torch.device, project: str, root_dir: str):
    """
    Compute vs test loss scaling laws.

    Trains one model for a single epoch for every (model size, dataset size)
    pair, each in its own wandb run, then logs a log-log plot of the compute
    estimate against the test loss to a separate summary run.

    Args:
        model_sizes (dict): Maps a size name to a dict with the keys
            "context_size", "n_layer", "n_head", "d_embed", "d_ff" and "dropout".
        dataset_sizes (dict): Maps a label to the fraction of `train_text` to train on.
        train_text (str): Text data for training.
        val_text (str): Text data for validation.
        tokenizer (CharTokenizer | BPETokenizer): Tokenizer instance.
        optimizer_name (str): Name of the optimizer.
        lr (float): Learning rate.
        weight_decay (float): Weight decay.
        scheduler_type (str): Type of the scheduler.
        warmup_ratio (float): Ratio of the warmup steps.
        grad_clip (float): Gradient clipping value.
        device (torch.device): Device for training.
        project (str): Name of the project.
        root_dir (str): Root directory of the project.
    """
    # Batch size per model-size name; any other name falls back to 128
    batch_size_by_name = {"small": 512, "medium": 128, "large": 64, "xl": 32}

    compute_values = []
    test_losses = []
    exp = 1

    for model_size, model_cfg in model_sizes.items():
        batch_size = batch_size_by_name.get(model_size, 128)

        for fraction in dataset_sizes.values():
            # The context manager finishes the run even if training raises
            with wandb.init(
                project=project,
                name=f"Compute vs Test Loss - exp {exp}",
                dir=root_dir
            ) as wandb_run:
                print(f"Wandb run initialized: {wandb_run.id}")

                # Subset the training data
                subset_train_text = train_text[:int(len(train_text) * fraction)]
                train_dataset = TextDataset(text=subset_train_text, tokenizer=tokenizer, context_size=model_cfg["context_size"])
                val_dataset = TextDataset(text=val_text, tokenizer=tokenizer, context_size=model_cfg["context_size"])
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
                print(f"Number of tokens: {len(train_dataset)}")

                # Initialize the model
                model = GPT(GPTConfig(
                    vocab_size=tokenizer.vocab_size,
                    context_size=model_cfg["context_size"],
                    n_layer=model_cfg["n_layer"],
                    n_head=model_cfg["n_head"],
                    d_embed=model_cfg["d_embed"],
                    d_ff=model_cfg["d_ff"],
                    dropout=model_cfg["dropout"]
                )).to(device)
                num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

                # FLOPs estimate: 2 * trainable parameters * len(train_dataset).
                # NOTE(review): training compute is often approximated as
                # 6 * N * D over tokens; a constant factor does not change the
                # fitted exponent, but confirm which estimate is intended.
                flops = 2 * num_params * len(train_dataset)
                print(f"FLOPs: {flops}")

                # Initialize the optimizer and scheduler
                optimizer = setup_optimizer(
                    model=model,
                    optimizer_name=optimizer_name,
                    lr=lr,
                    weight_decay=weight_decay
                )
                scheduler = setup_scheduler(
                    optimizer=optimizer,
                    scheduler_type=scheduler_type,
                    warmup_ratio=warmup_ratio,
                    total_steps=len(train_loader) * 1
                )

                # Train the model for one epoch
                model = train_epoch(
                    model=model,
                    dataloader=train_loader,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    current_epoch=0,
                    total_epochs=1,
                    grad_clip=grad_clip,
                    device=device,
                    wandb_run=wandb_run
                )
                test_loss = evaluate(
                    model=model,
                    dataloader=val_loader,
                    device=device,
                    wandb_run=wandb_run
                )

                compute_values.append(flops)
                test_losses.append(test_loss)
            exp += 1

    # Every per-experiment run is finished by now, so the plot gets its own
    # active run; logging to an already finished run would fail.
    with wandb.init(
        project=project,
        name="Compute vs Test Loss - summary",
        dir=root_dir
    ) as summary_run:
        plot_scaling_laws(
            x=compute_values,
            y=test_losses,
            x_label="Compute (FLOPs)",
            y_label="Test Loss",
            title="Compute vs Test Loss",
            wandb_run=summary_run
        )

### Dataset size

In [None]:
def dataset_size_experiment(
        dataset_sizes: dict, model_size: dict,
        train_text: str, val_text: str, tokenizer: "CharTokenizer | BPETokenizer", batch_size: int,
        optimizer_name: str, lr: float, weight_decay: float, scheduler_type: str, warmup_ratio: float,
        grad_clip: float, device: torch.device, project: str, root_dir: str):
    """
    Dataset size vs test loss scaling laws.

    Trains a fresh model of one fixed size for a single epoch on each dataset
    fraction, each in its own wandb run, then logs a log-log plot of the
    training-set length against the test loss to a separate summary run.

    Args:
        dataset_sizes (dict): Maps a label to the fraction of `train_text` to train on.
        model_size (dict): Model hyperparameters with the keys "context_size",
            "n_layer", "n_head", "d_embed", "d_ff" and "dropout".
        train_text (str): Text data for training.
        val_text (str): Text data for validation.
        tokenizer (CharTokenizer | BPETokenizer): Tokenizer instance.
        batch_size (int): Batch size of the DataLoaders.
        optimizer_name (str): Name of the optimizer.
        lr (float): Learning rate.
        weight_decay (float): Weight decay.
        scheduler_type (str): Type of the scheduler.
        warmup_ratio (float): Ratio of the warmup steps.
        grad_clip (float): Gradient clipping value.
        device (torch.device): Device for training.
        project (str): Name of the project.
        root_dir (str): Root directory of the project.
    """
    num_tokens = []
    test_losses = []
    exp = 1

    for fraction in dataset_sizes.values():
        # The context manager finishes the run even if training raises
        with wandb.init(
            project=project,
            name=f"Dataset Size vs Test Loss - exp {exp}",
            dir=root_dir
        ) as wandb_run:
            print(f"Wandb run initialized: {wandb_run.id}")

            # Subset the training data
            subset_train_text = train_text[:int(len(train_text) * fraction)]
            train_dataset = TextDataset(text=subset_train_text, tokenizer=tokenizer, context_size=model_size["context_size"])
            val_dataset = TextDataset(text=val_text, tokenizer=tokenizer, context_size=model_size["context_size"])
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
            print(f"Number of tokens: {len(train_dataset)}")

            # Initialize the model
            model = GPT(GPTConfig(
                vocab_size=tokenizer.vocab_size,
                context_size=model_size["context_size"],
                n_layer=model_size["n_layer"],
                n_head=model_size["n_head"],
                d_embed=model_size["d_embed"],
                d_ff=model_size["d_ff"],
                dropout=model_size["dropout"]
            )).to(device)

            # Initialize the optimizer and scheduler
            optimizer = setup_optimizer(
                model=model,
                optimizer_name=optimizer_name,
                lr=lr,
                weight_decay=weight_decay
            )
            scheduler = setup_scheduler(
                optimizer=optimizer,
                scheduler_type=scheduler_type,
                warmup_ratio=warmup_ratio,
                total_steps=len(train_loader) * 1
            )

            # Train the model for one epoch
            model = train_epoch(
                model=model,
                dataloader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                current_epoch=0,
                total_epochs=1,
                grad_clip=grad_clip,
                device=device,
                wandb_run=wandb_run
            )
            test_loss = evaluate(
                model=model,
                dataloader=val_loader,
                device=device,
                wandb_run=wandb_run
            )

            num_tokens.append(len(train_dataset))
            test_losses.append(test_loss)
        exp += 1

    # Every per-experiment run is finished by now, so the plot gets its own
    # active run; logging to an already finished run would fail.
    with wandb.init(
        project=project,
        name="Dataset Size vs Test Loss - summary",
        dir=root_dir
    ) as summary_run:
        plot_scaling_laws(
            x=num_tokens,
            y=test_losses,
            x_label="Dataset Size",
            y_label="Test Loss",
            title="Dataset Size vs Test Loss",
            wandb_run=summary_run
        )

### Model size

In [None]:
def model_size_experiment(
        model_sizes: dict, dataset_size: float,
        train_text: str, val_text: str, tokenizer: "CharTokenizer | BPETokenizer",
        optimizer_name: str, lr: float, weight_decay: float, scheduler_type: str, warmup_ratio: float,
        grad_clip: float, device: torch.device, project: str, root_dir: str):
    """
    Model size vs test loss scaling laws.

    Trains one model per entry of `model_sizes` for a single epoch on the same
    slice of the training text, each in its own wandb run, then logs a log-log
    plot of the trainable parameter count against the test loss to a separate
    summary run.

    Args:
        model_sizes (dict): Maps a size name to a dict with the keys
            "context_size", "n_layer", "n_head", "d_embed", "d_ff" and "dropout".
        dataset_size (float): Fraction of `train_text` to train on.
        train_text (str): Text data for training.
        val_text (str): Text data for validation.
        tokenizer (CharTokenizer | BPETokenizer): Tokenizer instance.
        optimizer_name (str): Name of the optimizer.
        lr (float): Learning rate.
        weight_decay (float): Weight decay.
        scheduler_type (str): Type of the scheduler.
        warmup_ratio (float): Ratio of the warmup steps.
        grad_clip (float): Gradient clipping value.
        device (torch.device): Device for training.
        project (str): Name of the project.
        root_dir (str): Root directory of the project.
    """
    # Batch size per model-size name; any other name falls back to 128
    batch_size_by_name = {"small": 512, "medium": 128, "large": 64, "xl": 32}

    # The training slice is the same for every model, so cut it once
    subset_train_text = train_text[:int(len(train_text) * dataset_size)]

    parameters = []
    test_losses = []
    exp = 1

    for model_size, model_cfg in model_sizes.items():
        batch_size = batch_size_by_name.get(model_size, 128)

        # The context manager finishes the run even if training raises
        with wandb.init(
            project=project,
            name=f"Parameters vs Test Loss - exp {exp}",
            dir=root_dir
        ) as wandb_run:
            print(f"Wandb run initialized: {wandb_run.id}")

            train_dataset = TextDataset(text=subset_train_text, tokenizer=tokenizer, context_size=model_cfg["context_size"])
            val_dataset = TextDataset(text=val_text, tokenizer=tokenizer, context_size=model_cfg["context_size"])
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

            # Initialize the model
            model = GPT(GPTConfig(
                vocab_size=tokenizer.vocab_size,
                context_size=model_cfg["context_size"],
                n_layer=model_cfg["n_layer"],
                n_head=model_cfg["n_head"],
                d_embed=model_cfg["d_embed"],
                d_ff=model_cfg["d_ff"],
                dropout=model_cfg["dropout"]
            )).to(device)
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"Number of parameters: {num_params}")

            # Initialize the optimizer and scheduler
            optimizer = setup_optimizer(
                model=model,
                optimizer_name=optimizer_name,
                lr=lr,
                weight_decay=weight_decay
            )
            scheduler = setup_scheduler(
                optimizer=optimizer,
                scheduler_type=scheduler_type,
                warmup_ratio=warmup_ratio,
                total_steps=len(train_loader) * 1
            )

            # Train the model for one epoch
            model = train_epoch(
                model=model,
                dataloader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                current_epoch=0,
                total_epochs=1,
                grad_clip=grad_clip,
                device=device,
                wandb_run=wandb_run
            )
            test_loss = evaluate(
                model=model,
                dataloader=val_loader,
                device=device,
                wandb_run=wandb_run
            )

            parameters.append(num_params)
            test_losses.append(test_loss)
        exp += 1

    # Every per-experiment run is finished by now, so the plot gets its own
    # active run; logging to an already finished run would fail.
    with wandb.init(
        project=project,
        name="Parameters vs Test Loss - summary",
        dir=root_dir
    ) as summary_run:
        plot_scaling_laws(
            x=parameters,
            y=test_losses,
            # The x values are trainable parameter counts, not layer counts
            x_label="Number of Parameters",
            y_label="Test Loss",
            title="Model Size vs Test Loss",
            wandb_run=summary_run
        )

### UMAP