In [None]:
import os
import torchvision
import torchvision.transforms as transforms

# Fetch both CIFAR-10 splits into a local folder so the following cells can
# read the dataset from disk.
data_root = "../data"

# exist_ok=True replaces the original check-then-create pair: it is a no-op
# when the folder already exists and cannot fail if the folder appears between
# the existence check and the creation.
os.makedirs(data_root, exist_ok=True)

# One ToTensor transform instance is shared by the train and test splits.
to_tensor = transforms.ToTensor()

cifar10_train = torchvision.datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=to_tensor,
)
cifar10_test = torchvision.datasets.CIFAR10(
    root=data_root,
    train=False,
    download=True,
    transform=to_tensor,
)
print("✅ CIFAR10 disponible en carpeta '../data/'")

# VGG16 - Entrenamiento y Evaluación

Entrenamiento de VGG16 base y sus variantes con BatchNorm y Global Average Pooling.

In [None]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
import os
import numpy as np
import random
import pandas as pd
import pickle

from models.vgg16 import VGG16
from models.vgg16_variants import VGG16_BN, VGG16_GAP
from utils.dataset import create_data_loaders
from utils.training import train_model, test_model, count_parameters
from utils.visualization import plot_training_curves, plot_confusion_matrix

# Single source of truth for the seed: it is applied to every RNG below and
# also forwarded to create_data_loaders.
SEED = 42

# Seed each random number generator used in this notebook.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    # manual_seed_all covers every visible GPU, whereas manual_seed only
    # seeds the current one.
    torch.cuda.manual_seed_all(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Dispositivo: {device}")

# Load data
data_dir = '../data'
train_loader, val_loader, test_loader, class_names = create_data_loaders(
    data_dir, batch_size=64, input_size=224, seed=SEED
)
# 3 input channels; the number of classes follows the loader's class_names.
nChannels, nClasses = 3, len(class_names)

In [None]:
# VGG16 Base: build the model, train it, reload the checkpoint written to
# save_path, and evaluate on the test loader.
vgg16_base = VGG16(nChannels, nClasses).to(device)
params_base = count_parameters(vgg16_base)
print(f"VGG16 Base - Parámetros: {params_base:,} ({params_base/1e6:.2f}M)")

# Keep the checkpoint path in one variable so the path handed to train_model
# and the path reloaded afterwards cannot drift apart.
ckpt_base = '../checkpoints/vgg16_base_best.pth'
# Make sure the checkpoint folder exists before a long training run starts.
# NOTE(review): train_model is defined outside this notebook; if it already
# creates the folder, this call is a harmless no-op.
os.makedirs(os.path.dirname(ckpt_base), exist_ok=True)

history_base, avg_time_base = train_model(
    vgg16_base, train_loader, val_loader, epochs=60,
    optimizer_name='SGD', lr=0.001, device=device,
    save_path=ckpt_base
)

# map_location puts the loaded tensors on the active device no matter which
# device the checkpoint was written from.
# NOTE(review): presumably the file holds the best-validation weights saved by
# train_model -- confirm against utils.training.
vgg16_base.load_state_dict(torch.load(ckpt_base, map_location=device))
test_acc_base, test_f1_base, conf_matrix_base, _ = test_model(
    vgg16_base, test_loader, device, class_names
)
print(f"VGG16 Base - Test Acc: {test_acc_base:.2f}%, Test F1: {test_f1_base:.2f}%")

In [None]:
# VGG16 + BatchNorm: build the model, train it, reload the checkpoint written
# to save_path, and evaluate on the test loader.
vgg16_bn = VGG16_BN(nChannels, nClasses).to(device)
params_bn = count_parameters(vgg16_bn)
print(f"VGG16 + BN - Parámetros: {params_bn:,} ({params_bn/1e6:.2f}M)")

# Keep the checkpoint path in one variable so the path handed to train_model
# and the path reloaded afterwards cannot drift apart.
ckpt_bn = '../checkpoints/vgg16_bn_best.pth'
# Make sure the checkpoint folder exists before a long training run starts.
# NOTE(review): train_model is defined outside this notebook; if it already
# creates the folder, this call is a harmless no-op.
os.makedirs(os.path.dirname(ckpt_bn), exist_ok=True)

history_bn, avg_time_bn = train_model(
    vgg16_bn, train_loader, val_loader, epochs=60,
    optimizer_name='SGD', lr=0.001, device=device,
    save_path=ckpt_bn
)

# map_location puts the loaded tensors on the active device no matter which
# device the checkpoint was written from.
# NOTE(review): presumably the file holds the best-validation weights saved by
# train_model -- confirm against utils.training.
vgg16_bn.load_state_dict(torch.load(ckpt_bn, map_location=device))
test_acc_bn, test_f1_bn, conf_matrix_bn, _ = test_model(
    vgg16_bn, test_loader, device, class_names
)
print(f"VGG16 + BN - Test Acc: {test_acc_bn:.2f}%, Test F1: {test_f1_bn:.2f}%")

In [None]:
# VGG16 + GAP: build the model, train it, reload the checkpoint written to
# save_path, and evaluate on the test loader.
vgg16_gap = VGG16_GAP(nChannels, nClasses).to(device)
params_gap = count_parameters(vgg16_gap)
print(f"VGG16 + GAP - Parámetros: {params_gap:,} ({params_gap/1e6:.2f}M)")

# Keep the checkpoint path in one variable so the path handed to train_model
# and the path reloaded afterwards cannot drift apart.
ckpt_gap = '../checkpoints/vgg16_gap_best.pth'
# Make sure the checkpoint folder exists before a long training run starts.
# NOTE(review): train_model is defined outside this notebook; if it already
# creates the folder, this call is a harmless no-op.
os.makedirs(os.path.dirname(ckpt_gap), exist_ok=True)

history_gap, avg_time_gap = train_model(
    vgg16_gap, train_loader, val_loader, epochs=60,
    optimizer_name='SGD', lr=0.001, device=device,
    save_path=ckpt_gap
)

# map_location puts the loaded tensors on the active device no matter which
# device the checkpoint was written from.
# NOTE(review): presumably the file holds the best-validation weights saved by
# train_model -- confirm against utils.training.
vgg16_gap.load_state_dict(torch.load(ckpt_gap, map_location=device))
test_acc_gap, test_f1_gap, conf_matrix_gap, _ = test_model(
    vgg16_gap, test_loader, device, class_names
)
print(f"VGG16 + GAP - Test Acc: {test_acc_gap:.2f}%, Test F1: {test_f1_gap:.2f}%")

In [None]:
# Visualization and results: plot the curves and confusion matrices for the
# three runs, tabulate their metrics, and persist the table and histories.

# Everything below is written under ../results; create the folder first so the
# CSV and pickle writes cannot fail on a fresh checkout.
# NOTE(review): the plotting helpers live outside this notebook and may create
# the folder themselves; this call is a no-op in that case.
os.makedirs('../results', exist_ok=True)

# Training curves, one figure per model.
plot_training_curves(history_base, "VGG16 Base", '../results/vgg16_base_curves.png')
plot_training_curves(history_bn, "VGG16 + BatchNorm", '../results/vgg16_bn_curves.png')
plot_training_curves(history_gap, "VGG16 + GAP", '../results/vgg16_gap_curves.png')

# Test-set confusion matrices, one figure per model.
plot_confusion_matrix(conf_matrix_base, class_names, "VGG16 Base", '../results/vgg16_base_confusion.png')
plot_confusion_matrix(conf_matrix_bn, class_names, "VGG16 + BN", '../results/vgg16_bn_confusion.png')
plot_confusion_matrix(conf_matrix_gap, class_names, "VGG16 + GAP", '../results/vgg16_gap_confusion.png')

# Summary table: one row per model, in the order base / BN / GAP.
# NOTE(review): 'Val Acc' and 'Val F1' take the maximum of each history list
# independently, so the two values may come from different epochs.
results_vgg16 = {
    'Modelo': ['VGG16 Base', 'VGG16 + BN', 'VGG16 + GAP'],
    'Params (M)': [params_base/1e6, params_bn/1e6, params_gap/1e6],
    't/epoca (s)': [avg_time_base, avg_time_bn, avg_time_gap],
    'Val Acc': [max(history_base['val_acc']), max(history_bn['val_acc']), max(history_gap['val_acc'])],
    'Val F1': [max(history_base['val_f1']), max(history_bn['val_f1']), max(history_gap['val_f1'])],
    'Test Acc': [test_acc_base, test_acc_bn, test_acc_gap],
    'Test F1': [test_f1_base, test_f1_bn, test_f1_gap]
}

df_vgg16 = pd.DataFrame(results_vgg16)
print("Resultados VGG16 y Variantes:")
print(df_vgg16.round(2))

# The CSV keeps the unrounded values; rounding above is for display only.
df_vgg16.to_csv('../results/vgg16_results.csv', index=False)

# Persist the per-epoch histories of all three runs in a single pickle file.
with open('../results/vgg16_histories.pkl', 'wb') as f:
    pickle.dump({
        'base': history_base,
        'bn': history_bn,
        'gap': history_gap
    }, f)