# Training Together 

In [None]:
%load_ext autoreload
%autoreload 2
from functools import partial
from pathlib import Path
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import wandb
import os

from ece496b_basics.adapters import *
from ece496b_basics.model import CustomModule

# Data and output directories are given relative to the notebook's working
# directory; resolve() turns them into absolute paths once, up front.
DATA_PATH = Path("../data").resolve()
OUTPUT_PATH = Path("outputs").resolve()
# Tokenized TinyStories array (loaded later with np.load).
ts_tokenized_path = OUTPUT_PATH / "tinystories_encoded.npy"

# Prefer the GPU, but fall back to CPU so the notebook still runs under
# "Restart & Run All" on machines without CUDA instead of failing at the
# first tensor allocation.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Parameter Sweep

In [None]:
# Default hyperparameters shared by every run of the sweep.
config = dict(
    # Model architecture
    vocab_size=10_000,
    context_length=128,
    d_model=512,
    num_layers=4,
    num_heads=16,
    d_ff=2048,
    # Dropout rates
    attn_pdrop=0.1,
    residual_pdrop=0.1,
    # Training schedule
    num_steps=2500,
    epochs=1,
    epochs_per_checkpoint=1,
    learning_rate=0.001,
)

# Grid search: one run per listed batch size; batch size is the only
# parameter being swept.
sweep_config = dict(
    method="grid",
    parameters=dict(
        batch_size=dict(values=[32, 64, 128, 192]),
    ),
)

# Register the sweep with Weights & Biases and keep its id for the agent.
sweep_id = wandb.sweep(sweep_config, project="training_together")

### Sweep Batch Size

In [None]:
# Training Loop
def train(config):
    """Train one model for a single W&B run and log the loss at every step.

    Args:
        config: Default hyperparameters passed to ``wandb.init``. All values
            used below are read back from the run's config, so anything the
            sweep supplies (e.g. ``batch_size``) is picked up from there.

    Returns:
        None. Per-step metrics are sent to Weights & Biases.
    """
    # Using the run as a context manager finishes it on every exit path, so a
    # failure mid-training does not leave the run open.
    with wandb.init(config=config) as run:
        cfg = run.config
        run.name = f"batch_size_{cfg.batch_size}"

        # Memory-map the token array read-only instead of loading it into RAM.
        dataset = np.load(ts_tokenized_path, mmap_mode="r")

        # NOTE(review): the config also carries attn_pdrop / residual_pdrop,
        # which are not forwarded here -- confirm whether CustomModule takes them.
        model = CustomModule(
            vocab_size=cfg.vocab_size,
            context_length=cfg.context_length,
            d_model=cfg.d_model,
            num_layers=cfg.num_layers,
            num_heads=cfg.num_heads,
            d_ff=cfg.d_ff,
            device=device,
        )
        optimizer = get_adamw_cls()(model.parameters(), lr=cfg.learning_rate)
        # TODO: no learning-rate scheduler is attached yet; a cosine schedule
        # via run_get_lr_cosine_schedule was sketched but never wired in.

        for step in range(cfg.num_steps):
            # Get batch
            inputs, targets = run_get_batch(
                dataset, cfg.batch_size, cfg.context_length, device
            )

            # Forward pass
            outputs = model(inputs)

            # Compute loss over the flattened (position, vocab) predictions
            loss = run_cross_entropy(
                outputs.view(-1, cfg.vocab_size), targets.view(-1)
            )

            # Backpropagation with gradient clipping at 1.0
            optimizer.zero_grad()
            loss.backward()
            run_gradient_clipping(model.parameters(), 1.0)
            optimizer.step()

            # Read the scalar once; each .item() call copies from the device.
            loss_value = loss.item()

            # Log to wandb. The "Loss:" key is kept verbatim so the metric
            # name stays the same across runs.
            run.log({"Loss:": loss_value, "batch_size": cfg.batch_size})

In [None]:
# A grid sweep has one run per combination of parameter values, so the agent
# is told to execute exactly that many runs.
values_per_parameter = [
    len(spec["values"]) for spec in sweep_config["parameters"].values()
]
num_combinations = np.prod(values_per_parameter)
wandb.agent(
    sweep_id,
    function=partial(train, config),
    count=int(num_combinations),
)