# Scaling Laws

## Importing libraries

In [4]:
import os
from dataclasses import dataclass
import torch
from torch.utils.data import DataLoader
from models.mlp.mlp import MLP, MLPConfig
from models.gpt.gpt import GPT, GPTConfig
from src.utils import load_text, set_seed, configure_device
from src.tokenizer import CharTokenizer, BPETokenizer
from src.train import split_text, TextDataset, setup_optimizer, setup_scheduler, train_epoch, evaluate

## Configuration

In [None]:
# Notebook-wide settings. Although declared as a dataclass, the cells in this
# notebook never instantiate it: they read and write attributes directly on
# the class (e.g. CONFIG.debug, CONFIG.device), so the defaults below act as
# the effective configuration.
@dataclass
class CONFIG:
    debug: bool = False  # True skips the Weights & Biases setup entirely
    root_dir: str = os.getcwd() + "/../"  # Evaluated once, from the working directory at class creation
    dataset_path: str = 'data/raw/shakespeare.txt'  # Concatenated onto root_dir when the text is loaded
    device: torch.device = torch.device('cpu')  # Placeholder; replaced by configure_device() in the Device cell

    # wandb
    project: str = "LLM101-Scaling-Laws"  # Project name passed to wandb.init

    # Tokenizer
    tokenizer: str = "char"  # char or bpe

    # Model
    model: str = "gpt"  # gpt or mlp
    # NOTE: this if/elif chain runs exactly once, while the class body is being
    # executed, and compares against the default assigned on the line above.
    # Only the matching branch defines its fields; reassigning CONFIG.model
    # afterwards does NOT switch the hyperparameter set. To change model,
    # edit the default above and re-run this cell.
    if model == "mlp":
        context_size: int = 16
        d_embed: int = 256
        d_ff: int = 1024
    elif model == "gpt":
        context_size: int = 4
        n_layer: int = 2
        n_head: int = 2
        d_embed: int = 128
        d_ff: int = 512
        dropout: float = 0.2
        flash_attention: bool = False  # NOTE(review): not forwarded to GPTConfig by the Model cell
    elif model == "megabyte":
        # Placeholder: no fields are defined for this model yet, and the Model
        # cell raises ValueError for any value other than "mlp" or "gpt".
        pass

    # Training
    val_size: float = 0.05  # Passed to split_text; presumably the held-out fraction - TODO confirm
    epochs: int = 1
    batch_size: int = 64  # Used for both the training and validation DataLoader
    optimizer: str = "AdamW"  # AdamW or SGD
    learning_rate: float = 0.001
    weight_decay: float = 0.01
    scheduler: str = "cosine"  # cosine or linear
    warmup_ratio: float = 0.1  # Passed to setup_scheduler together with the total batch count
    grad_clip: float = 1.0  # Passed to train_epoch
    mixed_precision: bool = False  # NOTE(review): not read by any cell visible in this notebook
    seed: int = 101  # Passed to set_seed in the Reproducibility cell

## Weights & Biases

In [None]:
# Start a Weights & Biases run unless debug mode is on. In both cases the
# name `wandb_run` is defined afterwards (a run object, or None in debug mode).
if not CONFIG.debug:
    # Imported lazily so debug runs do not require wandb to be installed.
    import wandb

    # Log only the user-defined settings. `CONFIG.__dict__` is the class
    # mappingproxy and also carries dunder entries (`__module__`, `__init__`,
    # `__dataclass_fields__`, ...) that are not hyperparameters, so filter
    # those out and hand wandb a plain dict.
    run_config = {
        name: value
        for name, value in vars(CONFIG).items()
        if not name.startswith("__")
    }

    # When WANDB_API_KEY is unset, `key` is None and wandb uses its own
    # credential lookup.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    wandb_run = wandb.init(
        project=CONFIG.project,
        config=run_config,
        dir=CONFIG.root_dir
    )
    print(f"Wandb run initialized: {wandb_run.id}")
else:
    wandb_run = None
    print("Debug mode enabled.")

## Reproducibility

In [None]:
# Seed via the project's set_seed helper before any model or DataLoader is
# created (presumably seeds the Python/NumPy/torch RNGs - see src.utils).
set_seed(CONFIG.seed)

## Device

In [None]:
# Replace the default device stored on CONFIG with whatever configure_device()
# returns; later cells read CONFIG.device when moving the model and training.
# NOTE(review): this runs after the Weights & Biases cell, so the config sent
# to wandb.init was captured before this assignment - confirm that is intended.
CONFIG.device = configure_device()

## Tokenizer

In [None]:
# Pick the tokenizer implementation named by CONFIG.tokenizer; anything other
# than the two supported names is rejected up front.
if CONFIG.tokenizer not in ("char", "bpe"):
    raise ValueError("Invalid tokenizer type. Choose 'char' or 'bpe'.")

tokenizer = CharTokenizer() if CONFIG.tokenizer == "char" else BPETokenizer()

# Read the dataset text and build the tokenizer's vocabulary from it.
vocab_text = load_text(CONFIG.root_dir + CONFIG.dataset_path)
tokenizer.build_vocab(vocab_text)

## Model

In [None]:
# Build the model selected by CONFIG.model, move it to the configured device
# and print it.
if CONFIG.model not in ("mlp", "gpt"):
    raise ValueError("Invalid model type. Choose 'mlp' or 'gpt'.")

# Settings that both model configs take.
shared_kwargs = {
    "vocab_size": tokenizer.vocab_size,
    "context_size": CONFIG.context_size,
    "d_embed": CONFIG.d_embed,
    "d_ff": CONFIG.d_ff,
}

if CONFIG.model == "mlp":
    model = MLP(MLPConfig(**shared_kwargs))
else:
    # The GPT config additionally takes the transformer-specific settings.
    model = GPT(GPTConfig(
        n_layer=CONFIG.n_layer,
        n_head=CONFIG.n_head,
        dropout=CONFIG.dropout,
        **shared_kwargs,
    ))

model.to(CONFIG.device)
print(model)

## Dataset

In [None]:
# Read the training corpus from disk.
corpus_path = CONFIG.root_dir + CONFIG.dataset_path
text = load_text(corpus_path)

# Divide the corpus into a training part and a validation part.
train_text, val_text = split_text(text, CONFIG.val_size)

# Wrap each part in a dataset built with the shared tokenizer and context size.
train_dataset = TextDataset(tokenizer, train_text, CONFIG.context_size)
val_dataset = TextDataset(tokenizer, val_text, CONFIG.context_size)

# Only the training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG.batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG.batch_size,
    shuffle=False,
)

## Training

In [None]:
# Setup optimizer and scheduler.
optimizer = setup_optimizer(model, CONFIG.optimizer, CONFIG.learning_rate, CONFIG.weight_decay)
# The scheduler receives the total number of training batches over all epochs.
total_steps = len(train_loader) * CONFIG.epochs
scheduler = setup_scheduler(optimizer, CONFIG.scheduler, CONFIG.warmup_ratio, total_steps)

# Train model: one training pass followed by one validation pass per epoch.
# The loop is wrapped in try/finally so the wandb run is always closed, even
# when training raises or the kernel is interrupted; otherwise the run would
# stay open.
training_succeeded = False
try:
    for epoch in range(CONFIG.epochs):
        train_epoch(model, train_loader, optimizer, scheduler, epoch, CONFIG.epochs, CONFIG.grad_clip, CONFIG.device, wandb_run)
        evaluate(model, val_loader, epoch, wandb_run)
    training_succeeded = True
finally:
    # Finish wandb run (wandb_run is None in debug mode). A non-zero exit
    # code marks the run as failed instead of reporting it as a clean finish.
    if wandb_run is not None:
        if training_succeeded:
            wandb_run.finish()
        else:
            wandb_run.finish(exit_code=1)