In [2]:
# 使用sys.path添加上级目录
import sys
import os
package_path = os.path.dirname(os.path.dirname(os.getcwd()))
file_path = os.path.join(package_path, "ch05", "05_bonus_hparam_tuning")
print(file_path)
sys.path.append(file_path)

/Users/young/project/llmProject/LLMs-from-scratch-CN/ch05/05_bonus_hparam_tuning


In [3]:
import itertools
import math
import os
import tiktoken
import torch
from previous_chapters import GPTModel, create_dataloader_v1

In [4]:
# 待搜索的超参数
HPARAM_GRID = {
    "batch_size": [2, 4, 8, 16],
    "drop_rate": [0.0, 0.1, 0.2],
    "warmup_iters": [10, 20, 30],
    "weight_decay": [0.1, 0.01, 0.0],
    "peak_lr": [0.0001, 0.0005, 0.001, 0.005],
    "initial_lr": [0.00005, 0.0001],
    "min_lr": [0.00005, 0.00001, 0.0001],
    "n_epochs": [5, 10, 15, 20, 25],
}

In [5]:
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    logits = model(input_batch)
    logits = logits.view(-1, logits.size(-1))
    loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
    return loss

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

def train_model(model, train_loader, val_loader, optimizer, device,
                n_epochs, eval_freq, eval_iter,
                encoded_start_context, tokenizer, warmup_iters=10,
                initial_lr=3e-05, min_lr=1e-6):
    global_step = 0

    # 计算学习率预热的增量
    max_lr = optimizer.param_groups[0]["lr"]
    total_training_iters = len(train_loader) * n_epochs
    lr_increment = (optimizer.param_groups[0]["lr"] - initial_lr) / warmup_iters

    for epoch in range(n_epochs):
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()

            global_step += 1

            # 学习率预热阶段
            if global_step <= warmup_iters:
                lr = initial_lr + global_step * lr_increment
            # 余弦退火阶段 / 余弦衰减
            else:
                progress = (global_step - warmup_iters) / (total_training_iters - warmup_iters)
                lr = min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
            
            # 应用计算出的学习率
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()

            if global_step >= warmup_iters:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
    
    train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)

    return train_loss, val_loss

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device

In [7]:
if __name__ == "__main__":
    # 生成搜索空间
    hyperparameter_combinations = list(itertools.product(*HPARAM_GRID.values()))
    total_combinations = len(hyperparameter_combinations)
    print(f"Total hyperparameter configurations: {total_combinations}")

    # 记录最优的验证集loss
    best_val_loss = float("inf")
    best_hparams = {}

    script_path = "./"  # os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    with open(os.path.join(script_dir, "the-verdict.txt"), "r", encoding="utf-8") as file:
        text_data = file.read()
    
    tokenizer = tiktoken.get_encoding("gpt2")
    device = get_device()

    train_ratio = 0.95
    split_idx = int(train_ratio * len(text_data))

    torch.manual_seed(123)

    interrupted = False
    current_config = 0
    for combination in hyperparameter_combinations:
        try:
            current_config += 1
            print(f"Evaluating configuration {current_config} of {total_combinations}")

            HPARAM_CONFIG = dict(zip(HPARAM_GRID.keys(), combination))

            GPT_CONFIG_124M = {
                "vocab_size": 50257,    # Vocabulary size
                "context_length": 256,  # Context length -- shortened from original 1024 tokens
                "emb_dim": 768,         # Embedding dimension
                "n_heads": 12,          # Number of attention heads
                "n_layers": 12,         # Number of layers
                "drop_rate": HPARAM_CONFIG["drop_rate"],
                "qkv_bias": False,     # Query-Key-Value bias
            }

            torch.manual_seed(123)
            train_loader = create_dataloader_v1(
                text_data[:split_idx],
                batch_size=HPARAM_CONFIG["batch_size"],
                max_length=GPT_CONFIG_124M["context_length"],
                stride=GPT_CONFIG_124M["context_length"],
                drop_last=True,
                shuffle=True,
                num_workers=0
            )

            val_loader = create_dataloader_v1(
                text_data[split_idx:],
                batch_size=HPARAM_CONFIG["batch_size"],
                max_length=GPT_CONFIG_124M["context_length"],
                stride=GPT_CONFIG_124M["context_length"],
                drop_last=False,
                shuffle=False,
                num_workers=0
            )

            model = GPTModel(GPT_CONFIG_124M)
            model.to(device)

            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=HPARAM_CONFIG["peak_lr"],
                weight_decay=HPARAM_CONFIG["weight_decay"]
            )

            encoded_start_context = tokenizer.encode("Nevertheless")
            encoded_tensor = torch.tensor(encoded_start_context).unsqueeze(0)

            train_loss, val_loss = train_model(
                model, train_loader, val_loader, optimizer, device,
                n_epochs=HPARAM_CONFIG["n_epochs"],
                eval_freq=5, eval_iter=1,
                encoded_start_context=encoded_tensor,
                tokenizer=tokenizer,
                warmup_iters=HPARAM_CONFIG["warmup_iters"],
                initial_lr=HPARAM_CONFIG["initial_lr"],
                min_lr=HPARAM_CONFIG["min_lr"]
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_train_loss = train_loss
                best_hparams = HPARAM_CONFIG

        except KeyboardInterrupt:
            print("Hyperparameter search completed.")
            print(f"Best hyperparameters: {best_hparams}")
            print(f"Best Val loss: {best_val_loss} | Training loss {train_loss}")
            interrupted = True
            break
    
    if not interrupted:
        print("Hyperparameter search completed.")
        print(f"Best hyperparameters: {best_hparams}")
        print(f"Best Val loss: {best_val_loss} | Training loss {train_loss}")

Total hyperparameter configurations: 12960
Evaluating configuration 1 of 12960
Evaluating configuration 2 of 12960
Evaluating configuration 3 of 12960
Evaluating configuration 4 of 12960
Evaluating configuration 5 of 12960
Evaluating configuration 6 of 12960
Evaluating configuration 7 of 12960
Hyperparameter search completed.
Best hyperparameters: {'batch_size': 2, 'drop_rate': 0.0, 'warmup_iters': 10, 'weight_decay': 0.1, 'peak_lr': 0.0001, 'initial_lr': 5e-05, 'min_lr': 5e-05, 'n_epochs': 15}
Best Val loss: 6.346194267272949 | Training loss 7.388454914093018
