
# VGG-16/VGG-19 Training Notebook

本 Notebook 演示如何使用 `train/train_vgg.py` 中的函数在 CIFAR-10 上训练 VGG-16 或 VGG-19，并进行结果可视化与超参数实验。


In [None]:

# Setup cell. Assumes the kernel was started inside PyTorch_Learning/notebooks,
# so the project root is the parent of the current working directory.
import os
import sys

# Put the project root on sys.path BEFORE importing any project package
# (utils, train). Importing first would raise ImportError when the kernel
# runs from notebooks/, because the project root is not on the search path yet.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(parent_dir)
print("项目根目录：", parent_dir)

from utils import PROJECT_ROOT

# Root directory for all artifacts (checkpoints, metrics, plots) written below.
outputs_dir = os.path.join(PROJECT_ROOT, "outputs")
print("Outputs 目录：", outputs_dir)

# Data loader, training entry point and result/plot helpers used by later cells.
from train.train_vgg import get_cifar10_loader, run_vgg_training
from utils.visualization import plot_training_metrics
from utils.experiment import save_experiment_results, load_experiment_results


## 1. 构建 DataLoader

In [None]:

# Build CIFAR-10 loaders for VGG. Both splits request img_size=224; the
# training split uses batches of 64 and the test split batches of 256.
# NOTE(review): these loaders are not passed to run_vgg_training below (it is
# given batch sizes instead), so presumably it builds its own loaders; confirm.
train_loader = get_cifar10_loader(batch_size=64, train=True, img_size=224)
test_loader = get_cifar10_loader(batch_size=256, train=False, img_size=224)

# Report how many batches each split yields.
for label, loader in (("训练集 batch 数：", train_loader),
                      ("测试集 batch 数：", test_loader)):
    print(label, len(loader))


## 2. 快速训练并可视化指标

In [None]:

# Quick smoke run: train VGG-16 for just 3 epochs to exercise the pipeline.
quick_output_dir = os.path.join(outputs_dir, "vgg_quick")
metrics = run_vgg_training(
    model_type="vgg16",
    epochs=3,
    # Batch sizes for the training and evaluation passes.
    train_batch_size=64,
    test_batch_size=256,
    # Optimizer hyperparameters.
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
    # Presumably early-stopping patience (in epochs) and logging frequency;
    # exact semantics live in run_vgg_training.
    patience=2,
    log_interval=200,
    output_dir=quick_output_dir,
)
print(metrics)

# Plot the training/validation loss and accuracy curves into their own
# directory, creating it first if needed.
metrics_dir = os.path.join(outputs_dir, "vgg_quick_metrics")
os.makedirs(metrics_dir, exist_ok=True)
# NOTE(review): save_path is given a directory here; confirm that
# plot_training_metrics expects a directory rather than a file path.
plot_training_metrics(metrics, save_path=metrics_dir)


## 3. 超参数网格实验

In [None]:

# Hyperparameter grid: each entry varies model depth, learning rate,
# momentum and training batch size.
param_grid = [
    {"model_type": "vgg16", "lr": 0.01, "momentum": 0.9,  "batch_size": 64},
    {"model_type": "vgg16", "lr": 0.01, "momentum": 0.95, "batch_size": 128},
    {"model_type": "vgg19", "lr": 0.01, "momentum": 0.9,  "batch_size": 64},
    {"model_type": "vgg19", "lr": 0.001, "momentum": 0.9, "batch_size": 256},
]

# Settings identical for every run. They are NOT stored in the saved "config",
# which records only the keys varied by the grid.
shared_run_kwargs = {
    "epochs": 10,
    "test_batch_size": 256,
    "weight_decay": 5e-4,
    "patience": 3,
    "log_interval": 200,
}
experiments_root = os.path.join(outputs_dir, "vgg_experiments")

# Run every configuration, each writing into a subdirectory named after its
# hyperparameters. Note this rebinds `metrics`, replacing the quick-run result.
all_results = []
for cfg in param_grid:
    run_name = (
        f"{cfg['model_type']}_lr{cfg['lr']}_mom{cfg['momentum']}_bs{cfg['batch_size']}"
    )
    metrics = run_vgg_training(
        model_type=cfg["model_type"],
        train_batch_size=cfg["batch_size"],
        lr=cfg["lr"],
        momentum=cfg["momentum"],
        output_dir=os.path.join(experiments_root, run_name),
        **shared_run_kwargs,
    )
    all_results.append({"config": cfg, "metrics": metrics})

# Persist all (config, metrics) records for later loading and plotting.
exp_path = os.path.join(outputs_dir, "vgg_experiment_results.json")
save_experiment_results(all_results, exp_path)


## 4. 实验结果对比与超参数并行坐标可视化

In [None]:

# Load the saved grid results and draw (1) an accuracy comparison across runs
# and (2) a hyperparameter parallel-coordinates plot.
from utils.visualization import plot_experiment_comparison, plot_hyperparam_parallel

# Rebuild the results path from outputs_dir (set in the setup cell) instead of
# relying on `exp_path` left over from the grid cell. This lets you re-plot
# saved results in a fresh session without re-running the long grid search.
# The path is the same one the grid cell writes to.
exp_path = os.path.join(outputs_dir, "vgg_experiment_results.json")
results = load_experiment_results(exp_path)

plot_experiment_comparison(results, metric="accuracy")
plot_hyperparam_parallel(results)
