# AlexNet - Entrenamiento y Evaluación

Este notebook implementa el entrenamiento de AlexNet base y sus variantes.

## Configuración de Entrenamiento
- **Épocas**: 60
- **Batch size**: 64
- **Optimizador**: SGD (momentum=0.9, weight_decay=5e-4)
- **Learning rate**: 0.001
- **Criterio**: CrossEntropyLoss
- **Semilla fija**: 42 para reproducibilidad

## Variantes a Evaluar
- **AlexNet Base**: Arquitectura original adaptada
- **AlexNet + BatchNorm**: Con normalización por lotes
- **AlexNet Reduced**: Con menos capas fully connected

In [None]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
import os
import numpy as np
import random
import pandas as pd
import pickle

from models.alexnet import AlexNet
from models.alexnet_variants import AlexNet_BN, AlexNet_Reduced
from utils.dataset import create_data_loaders, get_dataset_info
from utils.training import train_model, test_model, count_parameters
from utils.visualization import plot_training_curves, plot_confusion_matrix

# Seed every random number generator used in this notebook from a single
# constant, so the seed cannot drift between the individual calls below.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    # Seed all visible GPUs, not only the current one.
    torch.cuda.manual_seed_all(SEED)
    # Seeding alone does not make cuDNN deterministic: ask for deterministic
    # kernels and disable the autotuner, which may otherwise pick different
    # algorithms between runs. NOTE(review): this can lower GPU throughput,
    # which affects the per-epoch timings reported later.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Dispositivo: {device}")

# Load the data (reusing the shared configuration).
data_dir = '../data'
batch_size = 64
input_size = 224

# create_data_loaders is a project helper; here it yields the train,
# validation and test loaders plus the list of class names.
train_loader, val_loader, test_loader, class_names = create_data_loaders(
    data_dir, batch_size=batch_size, input_size=input_size, seed=SEED
)

# Number of input channels is fixed to 3; the number of classes is derived
# from the class names returned with the loaders.
nChannels = 3
nClasses = len(class_names)
print(f"Canales: {nChannels}, Clases: {nClasses}")

## 1. AlexNet Base

In [None]:
# Build the baseline AlexNet on the selected device and report its size.
alexnet_base = AlexNet(nChannels, nClasses).to(device)
params_base = count_parameters(alexnet_base)
print(f"AlexNet Base - Parámetros: {params_base:,} ({params_base/1e6:.2f}M)")

# Create the checkpoint directory up front so a missing folder cannot abort
# the run when the checkpoint is written or read back.
save_path = '../checkpoints/alexnet_base_best.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print("\nEntrenando AlexNet Base...")
# train_model returns the training history and a timing value that is later
# reported as seconds per epoch (presumably an average; confirm in
# utils.training).
history_base, avg_time_base = train_model(
    alexnet_base, train_loader, val_loader,
    epochs=60, optimizer_name='SGD', lr=0.001,
    device=device, save_path=save_path
)

# Reload the weights stored at save_path before testing. map_location keeps
# the load working when the checkpoint was written on a different device
# (e.g. trained on GPU, evaluated on a CPU-only machine).
alexnet_base.load_state_dict(torch.load(save_path, map_location=device))
# The fourth value returned by test_model is not used in this notebook.
test_acc_base, test_f1_base, conf_matrix_base, _ = test_model(
    alexnet_base, test_loader, device, class_names
)

print(f"AlexNet Base - Test Acc: {test_acc_base:.2f}%, Test F1: {test_f1_base:.2f}%")

## 2. AlexNet + BatchNorm

In [None]:
# Build the batch-normalized AlexNet variant on the selected device and
# report its size.
alexnet_bn = AlexNet_BN(nChannels, nClasses).to(device)
params_bn = count_parameters(alexnet_bn)
print(f"AlexNet + BN - Parámetros: {params_bn:,} ({params_bn/1e6:.2f}M)")

# Create the checkpoint directory up front so a missing folder cannot abort
# the run when the checkpoint is written or read back.
save_path_bn = '../checkpoints/alexnet_bn_best.pth'
os.makedirs(os.path.dirname(save_path_bn), exist_ok=True)

print("\nEntrenando AlexNet + BatchNorm...")
# Same training settings as the baseline so the results are comparable.
history_bn, avg_time_bn = train_model(
    alexnet_bn, train_loader, val_loader,
    epochs=60, optimizer_name='SGD', lr=0.001,
    device=device, save_path=save_path_bn
)

# Reload the weights stored at save_path_bn before testing. map_location
# keeps the load working when the checkpoint was written on a different
# device than the one used for evaluation.
alexnet_bn.load_state_dict(torch.load(save_path_bn, map_location=device))
# The fourth value returned by test_model is not used in this notebook.
test_acc_bn, test_f1_bn, conf_matrix_bn, _ = test_model(
    alexnet_bn, test_loader, device, class_names
)

print(f"AlexNet + BN - Test Acc: {test_acc_bn:.2f}%, Test F1: {test_f1_bn:.2f}%")
# Difference against the baseline; requires the AlexNet Base cell to have run.
print(f"Mejora vs Base - Acc: {test_acc_bn - test_acc_base:.2f}%, F1: {test_f1_bn - test_f1_base:.2f}%")

## 3. AlexNet Reduced

In [None]:
# Build the reduced AlexNet variant on the selected device and report its
# size.
alexnet_reduced = AlexNet_Reduced(nChannels, nClasses).to(device)
params_reduced = count_parameters(alexnet_reduced)
print(f"AlexNet Reduced - Parámetros: {params_reduced:,} ({params_reduced/1e6:.2f}M)")

# Create the checkpoint directory up front so a missing folder cannot abort
# the run when the checkpoint is written or read back.
save_path_reduced = '../checkpoints/alexnet_reduced_best.pth'
os.makedirs(os.path.dirname(save_path_reduced), exist_ok=True)

print("\nEntrenando AlexNet Reduced...")
# Same training settings as the baseline so the results are comparable.
history_reduced, avg_time_reduced = train_model(
    alexnet_reduced, train_loader, val_loader,
    epochs=60, optimizer_name='SGD', lr=0.001,
    device=device, save_path=save_path_reduced
)

# Reload the weights stored at save_path_reduced before testing.
# map_location keeps the load working when the checkpoint was written on a
# different device than the one used for evaluation.
alexnet_reduced.load_state_dict(torch.load(save_path_reduced, map_location=device))
# The fourth value returned by test_model is not used in this notebook.
test_acc_reduced, test_f1_reduced, conf_matrix_reduced, _ = test_model(
    alexnet_reduced, test_loader, device, class_names
)

print(f"AlexNet Reduced - Test Acc: {test_acc_reduced:.2f}%, Test F1: {test_f1_reduced:.2f}%")
# Difference against the baseline; requires the AlexNet Base cell to have run.
print(f"Mejora vs Base - Acc: {test_acc_reduced - test_acc_base:.2f}%, F1: {test_f1_reduced - test_f1_base:.2f}%")

## 4. Visualización y Resultados

In [None]:
# Every artifact below is written under ../results. Create the directory
# first so the CSV/pickle writes cannot fail with a missing folder after the
# long training runs have already finished.
os.makedirs('../results', exist_ok=True)

# Training curves, one figure per model.
plot_training_curves(history_base, "AlexNet Base", '../results/alexnet_base_curves.png')
plot_training_curves(history_bn, "AlexNet + BatchNorm", '../results/alexnet_bn_curves.png')
plot_training_curves(history_reduced, "AlexNet Reduced", '../results/alexnet_reduced_curves.png')

# Confusion matrices computed on the test set, one figure per model.
plot_confusion_matrix(conf_matrix_base, class_names, "AlexNet Base", '../results/alexnet_base_confusion.png')
plot_confusion_matrix(conf_matrix_bn, class_names, "AlexNet + BN", '../results/alexnet_bn_confusion.png')
plot_confusion_matrix(conf_matrix_reduced, class_names, "AlexNet Reduced", '../results/alexnet_reduced_confusion.png')

# Summary table: one row per model, columns aligned by list position.
# 'Val Acc' and 'Val F1' are each the maximum of their own history series,
# so the two values of a row may come from different epochs.
results_alexnet = {
    'Modelo': ['AlexNet Base', 'AlexNet + BN', 'AlexNet Reduced'],
    'Params (M)': [params_base/1e6, params_bn/1e6, params_reduced/1e6],
    't/epoca (s)': [avg_time_base, avg_time_bn, avg_time_reduced],
    'Val Acc': [max(history_base['val_acc']), max(history_bn['val_acc']), max(history_reduced['val_acc'])],
    'Val F1': [max(history_base['val_f1']), max(history_bn['val_f1']), max(history_reduced['val_f1'])],
    'Test Acc': [test_acc_base, test_acc_bn, test_acc_reduced],
    'Test F1': [test_f1_base, test_f1_bn, test_f1_reduced],
}

df_alexnet = pd.DataFrame(results_alexnet)
print("Resultados AlexNet y Variantes:")
# Rounding applies only to the printed view; the CSV keeps full precision.
print(df_alexnet.round(2))

df_alexnet.to_csv('../results/alexnet_results.csv', index=False)

# Persist the raw per-epoch histories so the curves can be re-plotted or
# compared later without retraining.
with open('../results/alexnet_histories.pkl', 'wb') as f:
    pickle.dump({
        'base': history_base,
        'bn': history_bn,
        'reduced': history_reduced,
    }, f)