___
# NN Training Study
## __By:__ José Luis Almendarez González

### Abstract

### Methodology

### Results

### Analysis

### Limitations

### References
___

### Python libraries *

In [1]:
from torchvision.datasets import MNIST
from torchvision import transforms
from PIL import Image
import os, random
from torch.utils.data import ConcatDataset
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import lib
import torchvision.models as models
import torch.nn as nn
import ipywidgets

 \* Many of these libraries are also referenced in `lib.py`, a helper module imported for presentation purposes.

### Data Import, Split, Augmentation & Preparation
___

In [None]:
transform = transforms.ToTensor()

# Download (if missing) both official MNIST splits; samples come back as tensors via `transform`.
train_data, test_data = (
    MNIST(root='./data', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

In [None]:
# Number of samples in the train split, then in the test split (one per line).
print(len(train_data), len(test_data), sep="\n")

In [None]:
base_dir = './data/MNIST/use/original'

# Ensure the (initially empty) train/ and test/ folders of the exported subset exist.
for split in ('train', 'test'):
    path = os.path.join(base_dir, split)
    os.makedirs(path, exist_ok=True)

In [None]:
def save_subset(dataset, split_dir, n_per_class=10, train_ratio=0.8):
    """Save the first ``n_per_class`` images of every digit label as JPEGs, split into train/test.

    Files go to ``<split_dir>/<train|test>/<label>/<label>_<k>.jpg``. Here ``k`` is the
    1-based running count for that label. The first ``floor(n_per_class * train_ratio)``
    images of a label go to ``train`` and the rest go to ``test``. Iteration stops as
    soon as every label 0..9 has ``n_per_class`` images, so a large dataset is not
    scanned to its end.

    Args:
        dataset: iterable of ``(image, label)`` pairs; ``label`` must be a key of
            ``range(10)`` and ``image`` something ``transforms.ToPILImage`` accepts.
        split_dir: root folder that receives the ``train`` / ``test`` sub-folders.
        n_per_class: number of images to keep per label.
        train_ratio: fraction of each label's images written to ``train``.

    Note:
        File names depend only on label and running count. Two calls with the same
        ``split_dir`` therefore overwrite each other's files.
    """
    counts = {i: 0 for i in range(10)}
    if n_per_class <= 0:
        return  # nothing would ever be saved

    # Compute the train quota once. The small epsilon absorbs float error
    # (100 * 0.29 == 28.999999999999996 would otherwise yield 28), while exact
    # values such as 3.5 still floor to 3 as with the previous `<=` comparison.
    n_train = int(n_per_class * train_ratio + 1e-9)
    to_pil = transforms.ToPILImage()

    for img, label in dataset:
        if counts[label] >= n_per_class:
            continue
        counts[label] += 1

        # Decide whether this image goes to train or test
        target_split = 'train' if counts[label] <= n_train else 'test'
        save_dir = os.path.join(split_dir, target_split, str(label))  # one sub-folder per class
        os.makedirs(save_dir, exist_ok=True)

        # Save the image with its label in the file name
        img_path = os.path.join(save_dir, f'{label}_{counts[label]}.jpg')
        to_pil(img).save(img_path)

        # Every class full: the rest of the dataset would only be skipped.
        if all(c >= n_per_class for c in counts.values()):
            break

In [None]:
# Export 60 images per digit into `base_dir`: the first 48 (80 %) to train/, the last 12 to test/.
# NOTE(review): both calls target the same `base_dir`, and save_subset names files only by label
# and running count (`{label}_{k}.jpg`). Assuming each split has at least 60 images per digit,
# every file written by the first call is overwritten by the second. The on-disk subset then
# comes from `test_data` only. Confirm this is intended.
save_subset(dataset=train_data, split_dir=base_dir, n_per_class=60, train_ratio=0.8)
save_subset(dataset=test_data, split_dir=base_dir, n_per_class=60, train_ratio=0.8)

In [None]:
base_dir = './data/MNIST/use/original'

# Print the number of .jpg files in train/ and then in test/, summed over all class folders.
for split in ('train', 'test'):
    split_path = os.path.join(base_dir, split)
    num_images = 0
    for cls in os.listdir(split_path):
        class_path = os.path.join(split_path, cls)
        if os.path.isdir(class_path):
            num_images += sum(1 for f in os.listdir(class_path) if f.endswith('.jpg'))
    print(num_images)

In [None]:
# Random augmentations used to synthesize extra digit images on disk (see generate_augmented).
# RandomHorizontalFlip was removed: a mirrored digit (e.g. 2, 3, 7) is no longer a valid
# example of its class, so flipping injects label noise for MNIST.
augment_transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
])

In [None]:
def generate_augmented(input_dir, output_dir, target_count, seed=None, overwrite=True):
    """Write ``target_count`` augmented images per class folder of ``input_dir`` into ``output_dir``.

    For every class sub-folder of ``input_dir``, the function repeats ``target_count``
    times: pick a random ``.jpg`` of that class, pass it through the module-level
    ``augment_transform``, and save the result as ``<source-stem>_aug<k>.jpg``
    (k = 1..target_count) in the matching sub-folder of ``output_dir``.

    Args:
        input_dir: folder with one sub-folder per class containing ``.jpg`` images.
        output_dir: destination root; created if needed.
        target_count: number of augmented images to write per class.
        seed: optional seed for the choice of source images (``random.Random``) and for
            torch's global RNG, from which torchvision's random transforms draw.
            ``None`` keeps unseeded behaviour.
        overwrite: if True (default), existing ``*_aug*.jpg`` files in each output class
            folder are deleted first. Re-running the cell then yields exactly
            ``target_count`` files per class instead of accumulating more.
    """
    os.makedirs(output_dir, exist_ok=True)

    rng = random.Random(seed) if seed is not None else random
    if seed is not None:
        import torch
        torch.manual_seed(seed)

    # Sorted so that a given seed picks the same files regardless of directory order.
    classes = sorted(cls for cls in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, cls)))

    for cls in classes:
        input_class_dir = os.path.join(input_dir, cls)
        images = sorted(f for f in os.listdir(input_class_dir) if f.endswith('.jpg'))
        if not images:
            # random.choice([]) would raise IndexError; report and move on instead.
            print(f"Warning: no .jpg images in {input_class_dir}; class skipped")
            continue

        output_class_dir = os.path.join(output_dir, cls)
        os.makedirs(output_class_dir, exist_ok=True)
        if overwrite:
            for stale in os.listdir(output_class_dir):
                if stale.endswith('.jpg') and '_aug' in stale:
                    os.remove(os.path.join(output_class_dir, stale))

        for k in range(1, target_count + 1):
            img_name = rng.choice(images)
            # Context manager closes the source file; save inside it in case a transform
            # returns the input image object unchanged.
            with Image.open(os.path.join(input_class_dir, img_name)) as img:
                aug_img = augment_transform(img)
                new_name = f"{img_name.split('.')[0]}_aug{k}.jpg"
                aug_img.save(os.path.join(output_class_dir, new_name))

In [None]:
original_train, original_test = (f'./data/MNIST/use/original/{part}' for part in ('train', 'test'))
augmented_train, augmented_test = (f'./data/MNIST/use/augmented/{part}' for part in ('train', 'test'))

# With 48 original train / 12 original test images per digit (save_subset: 60 per class, 0.8 ratio),
# these counts bring each digit to 48 + 432 = 480 train and 12 + 108 = 120 test images.
generate_augmented(original_train, augmented_train, target_count=432)
generate_augmented(original_test, augmented_test, target_count=108)


In [None]:
base_dir = './data/MNIST/use/augmented'

# Print the number of augmented .jpg files in train/ and then in test/, over all class folders.
for split in ('train', 'test'):
    split_path = os.path.join(base_dir, split)
    num_images = 0
    for cls in os.listdir(split_path):
        class_path = os.path.join(split_path, cls)
        if os.path.isdir(class_path):
            num_images += sum(1 for f in os.listdir(class_path) if f.endswith('.jpg'))
    print(num_images)

In [3]:
# Untransformed (PIL) view of original + augmented images for each split; used only for counting.
train_dataset, test_dataset = (
    ConcatDataset([
        ImageFolder(f'./data/MNIST/use/{source}/{part}')
        for source in ('original', 'augmented')
    ])
    for part in ('train', 'test')
)

In [None]:
# Combined size of the train split, then of the test split (one per line).
print(len(train_dataset), len(test_dataset), sep="\n")

In [None]:
def plot_images(base_dir, titulo, ncols=5):
    """Show the first image (by file name) of every class folder under ``base_dir``.

    Class folder names must be integers; they are sorted numerically. The grid has
    ``ncols`` columns and as many rows as needed, so any number of classes fits. Axes
    without a class are hidden. A class folder without ``.jpg`` files gets an empty
    titled panel instead of raising ``IndexError``.

    Args:
        base_dir: folder containing one sub-folder per class.
        titulo: figure title.
        ncols: number of grid columns (default 5, matching the original 2x5 layout
            for ten digits).
    """
    clases = sorted(
        (cls for cls in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, cls))),
        key=int,
    )
    nrows = max(1, -(-len(clases) // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2.5 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax in axes:
        ax.axis('off')  # unused panels stay blank

    for ax, cls in zip(axes, clases):
        class_dir = os.path.join(base_dir, cls)
        ax.set_title(f"Clase {cls}")
        jpgs = sorted(f for f in os.listdir(class_dir) if f.endswith('.jpg'))
        if not jpgs:
            continue
        with Image.open(os.path.join(class_dir, jpgs[0])) as img:
            img_tensor = transforms.ToTensor()(img)
        ax.imshow(img_tensor.squeeze(), cmap='gray')

    fig.suptitle(titulo, fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
# One sample per digit for the original and augmented train/test folders.
for folder, title in [
    ('./data/MNIST/use/original/train', "Original Train"),
    ('./data/MNIST/use/original/test', "Original Test"),
    ('./data/MNIST/use/augmented/train', "Augmented Train"),
    ('./data/MNIST/use/augmented/test', "Augmented Test"),
]:
    plot_images(folder, title)

In [None]:
print(sum(len(ds) for ds in (train_dataset, test_dataset)))  # total images across both splits

In [2]:
# Identical preprocessing for both splits (two separate Compose objects): resize to 224x224,
# reduce to one channel (after resizing), convert to a tensor, then Normalize(0.5, 0.5).
train_transform, test_transform = (
    transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),   # single channel
    ])
    for _pipeline_copy in range(2)
)

In [3]:
# Original + augmented images per split, with the matching preprocessing pipeline attached.
train_dataset, test_dataset = (
    ConcatDataset([
        ImageFolder(f'./data/MNIST/use/{source}/{part}', transform=part_transform)
        for source in ('original', 'augmented')
    ])
    for part, part_transform in (('train', train_transform), ('test', test_transform))
)

In [4]:
(train_dataset, test_dataset)  # display both dataset objects as the cell output

(<torch.utils.data.dataset.ConcatDataset at 0x10cbeb950>,
 <torch.utils.data.dataset.ConcatDataset at 0x10cc26720>)

### Build & Evaluate Model(s)
___

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from pathlib import Path
import json
import random
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd

# -------------------------
# Base configuration
# -------------------------
# Every experiment writes into results/<model>/exp_<k>/ (relative to the working directory).
results_dir = Path("results")
Path.mkdir(results_dir, exist_ok=True)

# -------------------------
# Model factory + weight initialization
# -------------------------
class CustomCNN(nn.Module):
    """Small CNN: three (3x3 conv -> ReLU -> 2x2 max-pool) stages, then a 256-128-num_classes head.

    The flattened feature size is measured with one forward pass of a zero tensor of shape
    ``(1, *input_size)``, so any input size that survives three 2x poolings works. Defined
    at module level (not inside get_model) so instances can be pickled.
    """

    def __init__(self, num_classes=10, dropout_rate=0.0, input_size=(1, 224, 224)):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_size[0], 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2)
        )
        with torch.no_grad():
            flatten_size = self.features(torch.zeros(1, *input_size)).view(1, -1).shape[1]
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flatten_size, 256), nn.ReLU(),
            nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.classifier(self.features(x))


def _replace_classifier_head(model, num_classes, dropout_rate):
    """Give a torchvision AlexNet/VGG a ``num_classes``-way output layer and set its dropout.

    torchvision's stock classifiers already contain ``nn.Dropout`` layers (p=0.5 by default).
    Their ``p`` is set to ``dropout_rate`` so that 0.0 really means "no dropout", as for
    CustomCNN. Previously the stock dropout was kept and, for rate > 0, an extra layer was stacked.
    """
    layers = list(model.classifier.children())
    layers[-1] = nn.Linear(layers[-1].in_features, num_classes)
    for layer in layers:
        if isinstance(layer, nn.Dropout):
            layer.p = dropout_rate
    model.classifier = nn.Sequential(*layers)


def _init_weights(model, weight_init):
    """Re-initialize every Conv2d/Linear weight with the named scheme and zero their biases."""
    scheme = weight_init.lower()
    if scheme == "kaiming":
        init_fn = lambda w: nn.init.kaiming_normal_(w, nonlinearity="relu")
    elif scheme == "xavier":
        init_fn = nn.init.xavier_normal_
    else:  # "orthogonal" (validated by get_model)
        init_fn = nn.init.orthogonal_
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            init_fn(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def get_model(model_name="alexnet", num_classes=10, weight_init="kaiming", dropout_rate=0.0, input_size=(1,224,224)):
    """Build an untrained AlexNet, VGG11 or CustomCNN for ``input_size[0]``-channel inputs.

    Args:
        model_name: "alexnet", "vgg" (VGG11) or "customcnn" (case-insensitive).
        num_classes: size of the output layer.
        weight_init: "kaiming", "xavier" or "orthogonal" (case-insensitive); applied to
            every Conv2d/Linear weight, with biases zeroed.
        dropout_rate: dropout probability for the classifier head, in [0, 1]; 0 disables it.
        input_size: (channels, height, width). CustomCNN uses it to size its head; AlexNet
            and VGG only use the channel count.

    Raises:
        ValueError: unknown model name or weight_init, or dropout_rate outside [0, 1].
    """
    if weight_init.lower() not in ("kaiming", "xavier", "orthogonal"):
        raise ValueError(f"Inicialización desconocida: {weight_init}")
    if not 0.0 <= dropout_rate <= 1.0:
        raise ValueError(f"dropout_rate fuera de [0, 1]: {dropout_rate}")

    name = model_name.lower()
    in_channels = input_size[0]
    if name == "alexnet":
        model = models.alexnet(weights=None)
        model.features[0] = nn.Conv2d(in_channels, 64, kernel_size=11, stride=4, padding=2)
        _replace_classifier_head(model, num_classes, dropout_rate)
    elif name == "vgg":
        model = models.vgg11(weights=None)
        model.features[0] = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        _replace_classifier_head(model, num_classes, dropout_rate)
    elif name == "customcnn":
        model = CustomCNN(num_classes=num_classes, dropout_rate=dropout_rate, input_size=input_size)
    else:
        raise ValueError(f"Modelo desconocido: {model_name}")

    _init_weights(model, weight_init)
    return model

# -------------------------
# Helper: constant LR with linear warmup
# -------------------------
def set_constant_with_warmup(optimizer, base_lr, epoch, warmup_epochs=5):
    """Set every param group's LR for ``epoch`` and return it.

    While ``0 < warmup_epochs`` and ``epoch <= warmup_epochs`` the LR ramps linearly as
    ``base_lr * epoch / warmup_epochs``; afterwards (or with no warmup) it is ``base_lr``.
    """
    in_warmup = warmup_epochs > 0 and epoch <= warmup_epochs
    lr = base_lr * epoch / float(warmup_epochs) if in_warmup else base_lr
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


def lr_range_test(model, train_loader, optimizer_class, lr_start=1e-6, lr_end=1, num_iters=100, device="cpu"):
    """Run an LR range test and return the learning rates tried and the loss at each.

    Trains ``model`` in place for ``num_iters`` batches (restarting ``train_loader`` when it
    runs out). Before each batch the LR grows geometrically, by a factor of
    ``(lr_end / lr_start) ** (1 / num_iters)``, starting from ``lr_start``. The loss curve is
    plotted on a log-LR axis.

    Returns:
        (lrs, losses): two lists of length ``num_iters``.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    opt = optimizer_class(model.parameters(), lr=lr_start)

    growth = (lr_end / lr_start) ** (1/num_iters)
    current_lr = lr_start
    history_lrs = []
    history_losses = []

    batches = iter(train_loader)
    for _ in range(num_iters):
        batch = next(batches, None)
        if batch is None:  # loader exhausted: start another pass
            batches = iter(train_loader)
            batch = next(batches)
        inputs, targets = (t.to(device) for t in batch)

        opt.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        opt.step()

        history_lrs.append(current_lr)
        history_losses.append(loss.item())

        current_lr *= growth
        for group in opt.param_groups:
            group['lr'] = current_lr

    plt.figure(figsize=(6,4))
    plt.plot(history_lrs, history_losses)
    plt.xscale("log")
    plt.xlabel("Learning Rate (log scale)")
    plt.ylabel("Loss")
    plt.title("LR Range Test")
    plt.show()

    return history_lrs, history_losses

# -------------------------
# Main training function
# -------------------------
def _resolve_device(device):
    """Return a ``torch.device``; when ``device`` is None pick CUDA, then MPS, then CPU and report it."""
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda"); print("Usando dispositivo: CUDA")
        elif torch.backends.mps.is_available():
            device = torch.device("mps"); print("Usando dispositivo: MPS")
        else:
            device = torch.device("cpu"); print("Usando dispositivo: CPU")
    return device if isinstance(device, torch.device) else torch.device(device)


def _with_extra_train_transform(dataset, extra):
    """Return a copy of ``dataset`` whose per-sample transform runs ``extra`` before the existing one.

    ``ConcatDataset`` only indexes into its children and never applies a transform itself, so
    its children are copied and wrapped recursively. Copies are shallow (sample lists are
    shared). The caller's dataset objects are never mutated, so repeated runs do not stack
    augmentations.

    Raises:
        TypeError: if a leaf dataset has no ``transform`` attribute to extend.
    """
    import copy

    if isinstance(dataset, torch.utils.data.ConcatDataset):
        return torch.utils.data.ConcatDataset(
            [_with_extra_train_transform(child, extra) for child in dataset.datasets]
        )
    if not hasattr(dataset, "transform"):
        raise TypeError(
            f"No se puede añadir aumento de datos a un dataset de tipo {type(dataset).__name__}"
        )
    augmented = copy.copy(dataset)
    base_transform = dataset.transform
    augmented.transform = extra if base_transform is None else transforms.Compose([extra, base_transform])
    return augmented


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over ``loader`` in eval mode without gradients.

    Returns ``(nan, 0.0)`` when the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * images.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return float('nan'), 0.0
    return loss_sum / total, correct / total


def _snapshot_state(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` holds references to the live parameters. Keeping it without copying
    means the "best" checkpoint silently tracks the latest weights.
    """
    return {name: tensor.detach().to("cpu", copy=True) for name, tensor in model.state_dict().items()}


def _next_experiment_dir(model_dir):
    """Create and return ``model_dir / "exp_<k>"``, with k one past the highest existing index.

    Counting existing folders would reuse a name after a deletion and then overwrite that
    experiment's files.
    """
    indices = []
    for path in model_dir.glob("exp_*"):
        suffix = path.name[len("exp_"):]
        if suffix.isdigit():
            indices.append(int(suffix))
    exp_path = model_dir / f"exp_{max(indices, default=0) + 1}"
    exp_path.mkdir()
    return exp_path


def train_model(model_name,
                train_dataset, test_dataset,
                batch_size=128,
                optimizer_name="SGD",
                nesterov=False,
                lr=0.01,
                weight_decay=5e-4,
                epochs=60,
                dropout_rate=0.0,
                weight_init="kaiming",
                label_smoothing=0.0,
                lr_schedule="constant+warmup",
                use_augment=False,
                protocol="Fixed Epochs",
                wall_clock_budget=None,
                seed=42,
                device=None,
                num_workers=2,
                pin_memory=True):
    """Train ``model_name`` on ``train_dataset`` and evaluate it on ``test_dataset`` after every epoch.

    The epoch with the highest test accuracy is kept and reloaded at the end. Because the
    test set is also used for this model selection, the final test accuracy is
    optimistically biased; there is no separate validation split. After training the
    function shows loss/accuracy plots and writes ``metrics.json`` (list of per-epoch
    records) and ``metadata.json`` to ``results/<model>/exp_<k>/``. It also displays a
    button that saves the selected weights to ``model.pth``.

    Args:
        model_name: passed to ``get_model`` ("alexnet", "vgg", "customcnn").
        train_dataset, test_dataset: map-style datasets yielding ``(image, label)``.
        batch_size: DataLoader batch size for both splits.
        optimizer_name: "SGD", "SGD+Nesterov", "Adam" or "AdamW" (case-insensitive).
        nesterov: kept for compatibility and ignored. Nesterov is selected via
            ``optimizer_name="SGD+Nesterov"``.
        lr: base (or, for one_cycle, maximum) learning rate.
        weight_decay: optimizer weight decay.
        epochs: number of epochs for "Fixed Epochs"; also the horizon of the cosine and
            one_cycle schedules.
        dropout_rate, weight_init: forwarded to ``get_model``.
        label_smoothing: ``CrossEntropyLoss`` label smoothing.
        lr_schedule: "constant+warmup" (5 linear warmup epochs), "step" (x0.1 at epochs
            30 and 50), "cosine", "one_cycle" (stepped per batch); any other value keeps the
            LR constant.
        use_augment: add random rotation/color jitter/perspective before the existing
            training transform (applied to copies of the dataset).
        protocol: "Fixed Epochs" or "Fixed Wall-Clock Time".
        wall_clock_budget: training budget in seconds; required for "Fixed Wall-Clock Time".
        seed: seed for torch, numpy and ``random``.
        device: torch device or name; None auto-selects CUDA > MPS > CPU.
        num_workers: DataLoader worker processes.
        pin_memory: pin host memory for the loaders (only effective on CUDA).

    Returns:
        pathlib.Path of the experiment folder.

    Raises:
        ValueError: unknown optimizer or protocol, or a wall-clock protocol without budget.
    """
    if protocol not in ("Fixed Epochs", "Fixed Wall-Clock Time"):
        raise ValueError(f"Protocolo desconocido: {protocol}")
    if protocol == "Fixed Wall-Clock Time" and wall_clock_budget is None:
        # Without a budget the loop below would never terminate.
        raise ValueError("wall_clock_budget es obligatorio con el protocolo 'Fixed Wall-Clock Time'")

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = _resolve_device(device)

    if use_augment:
        extra = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5)
        ])
        # Setting `.transform` on a ConcatDataset has no effect, so the augmentation is
        # attached to (copies of) the underlying datasets instead.
        train_dataset = _with_extra_train_transform(train_dataset, extra)

    # Pinned memory only speeds up host-to-CUDA copies.
    use_pin_memory = bool(pin_memory) and device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=use_pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=use_pin_memory)

    model = get_model(model_name=model_name, num_classes=10, weight_init=weight_init,
                      dropout_rate=dropout_rate, input_size=(1,224,224))
    model.to(device)

    opt_lower = optimizer_name.lower()
    if opt_lower == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=False)
    elif opt_lower == "sgd+nesterov":
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=True)
    elif opt_lower == "adam":
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt_lower == "adamw":
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError("Optimizador desconocido")

    scheduler = None
    if lr_schedule == "step":
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,50], gamma=0.1)
    elif lr_schedule == "cosine":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    elif lr_schedule == "one_cycle":
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs)
    per_batch_scheduler = lr_schedule == "one_cycle"
    warmup_epochs = 5

    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    logs = []
    epoch_times, epoch_lrs = [], []
    best_val_acc = -1.0
    best_model_state = None

    start_time = time.time()
    epoch = 1
    while True:
        if protocol == "Fixed Epochs" and epoch > epochs:
            break
        if protocol == "Fixed Wall-Clock Time" and (time.time() - start_time) > wall_clock_budget:
            break

        # Set the warmup LR *before* training this epoch. Previously it was set afterwards,
        # which ran epoch 1 at full LR and delayed the ramp by one epoch.
        if lr_schedule == "constant+warmup":
            set_constant_with_warmup(optimizer, base_lr=lr, epoch=epoch, warmup_epochs=warmup_epochs)
        cur_lr = optimizer.param_groups[0]['lr']  # LR in effect at the start of this epoch

        model.train()
        running_loss = 0.0
        running_correct = 0
        running_total = 0
        t0 = time.time()

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", disable=True):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # OneCycleLR must be stepped every batch. The guard stops it past `total_steps`,
            # which it rejects; that can happen under the wall-clock protocol.
            if per_batch_scheduler and scheduler.last_epoch < scheduler.total_steps:
                scheduler.step()

            running_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            running_correct += preds.eq(labels).sum().item()
            running_total += labels.size(0)

        epoch_time = time.time() - t0

        if scheduler is not None and not per_batch_scheduler:
            scheduler.step()

        epoch_loss = running_loss / running_total if running_total > 0 else float('nan')
        epoch_acc = running_correct / running_total if running_total > 0 else 0.0

        test_loss, test_acc = _evaluate(model, test_loader, criterion, device)

        # Checkpoint: copy the weights; the live state_dict would keep changing.
        if test_acc > best_val_acc:
            best_val_acc = test_acc
            best_model_state = _snapshot_state(model)

        epoch_times.append(epoch_time)
        epoch_lrs.append(cur_lr)
        logs.append({
            "epoch": epoch,
            "train_loss": float(epoch_loss),
            "train_acc": float(epoch_acc),
            "test_loss": float(test_loss),
            "test_acc": float(test_acc),
            "epoch_time_s": float(epoch_time),
            "lr": float(cur_lr)
        })

        print(f"Epoch {epoch}/{epochs} | loss {epoch_loss:.4f} | train_acc {epoch_acc:.4f} | test_acc {test_acc:.4f} | lr {cur_lr:.5g} | time {epoch_time:.1f}s")

        epoch += 1

    # Reload the selected weights (if any epoch ran) and report their test metrics.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    final_test_loss, final_test_acc = _evaluate(model, test_loader, criterion, device)
    print(f"\n🧪 Test final — Loss: {final_test_loss:.4f} | Acc: {final_test_acc:.4f}")

    # --- plots ---
    epoch_axis = [l["epoch"] for l in logs]
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(epoch_axis, [l["train_loss"] for l in logs], label="train_loss")
    axes[0].plot(epoch_axis, [l["test_loss"] for l in logs], label="test_loss")
    axes[0].set_title("Loss por época"); axes[0].set_xlabel("epoch"); axes[0].legend()
    axes[1].plot(epoch_axis, [l["train_acc"] for l in logs], label="train_acc")
    axes[1].plot(epoch_axis, [l["test_acc"] for l in logs], label="test_acc")
    axes[1].set_title("Accuracy por época"); axes[1].set_xlabel("epoch"); axes[1].legend()
    plt.show()

    model_dir = results_dir / model_name.lower()
    model_dir.mkdir(parents=True, exist_ok=True)
    exp_path = _next_experiment_dir(model_dir)

    # DataFrame.to_json(index=False) raises ValueError with the default orient="columns",
    # so the per-epoch records are written directly as a JSON list.
    with open(exp_path / "metrics.json", "w") as f:
        json.dump(logs, f, indent=4)

    metadata = {
        "model_name": model_name,
        "optimizer": optimizer_name,
        "nesterov_flag": optimizer_name.lower() == "sgd+nesterov",
        "lr": lr,
        "weight_decay": weight_decay,
        "weight_init": weight_init,
        "dropout_rate": dropout_rate,
        "label_smoothing": label_smoothing,
        "lr_schedule": lr_schedule,
        "batch_size": batch_size,
        "use_augment": use_augment,
        "protocol": protocol,
        "wall_clock_budget_s": wall_clock_budget,
        "seed": seed,
        "device": str(device),
        "num_epochs_run": len(logs),
        "total_train_time_s": sum(epoch_times),
        "epoch_times_s": epoch_times,
        "epoch_lrs": epoch_lrs,
        "best_val_acc": float(best_val_acc),
        "final_test_loss": float(final_test_loss),
        "final_test_acc": float(final_test_acc),
        "selection_metric": "test_acc",
        "tqdm_disabled_display": True
    }
    with open(exp_path / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=4)

    # --- Save button ---
    state_to_save = best_model_state if best_model_state is not None else model.state_dict()

    def _save_model(btn):
        torch.save(state_to_save, exp_path / "model.pth")
        metadata["model_file"] = str((exp_path / "model.pth").resolve())
        with open(exp_path / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=4)
        print(f"\n✓ Mejor modelo guardado en {exp_path / 'model.pth'} con test_acc = {best_val_acc:.4f}")

    save_btn = widgets.Button(description="💾 Guardar modelo y metadata", button_style="info")
    save_btn.on_click(_save_model)
    display(save_btn)

    print(f"\n✅ Métricas (metrics.json) y metadata.json creados en: {exp_path}")
    return exp_path

# -------------------------
# Widget UI
# -------------------------
# Architecture and optimizer selectors.
model_widget = widgets.Dropdown(
    options=["alexnet", "vgg", "customcnn"], value="alexnet", description="Modelo"
)
optimizer_widget = widgets.Dropdown(
    options=["SGD", "SGD+Nesterov", "Adam", "AdamW"], value="SGD", description="Optimizer"
)
# Starts hidden; on_opt_change reveals it only for "SGD+Nesterov".
nesterov_checkbox = widgets.Checkbox(value=True, description="Nesterov (when applicable)")
nesterov_checkbox.layout.display = "none"


def on_opt_change(change):
    """Show the Nesterov checkbox ("block") only while the "SGD+Nesterov" optimizer is selected."""
    wants_nesterov = change["new"].lower() == "sgd+nesterov"
    nesterov_checkbox.layout.display = "block" if wants_nesterov else "none"


optimizer_widget.observe(on_opt_change, names="value")
on_opt_change({"new": optimizer_widget.value})  # sync with the initial selection

# Optimization hyperparameters.
batch_widget = widgets.Dropdown(options=[32, 128, 512], value=128, description="Batch size")
lr_widget = widgets.FloatLogSlider(
    value=0.01, base=10, min=-4, max=-1, step=0.1, description="LR"
)
weight_decay_widget = widgets.FloatSlider(
    value=5e-4, min=0.0, max=0.01, step=1e-4, description="Weight Decay"
)
# Regularization / initialization / schedule options.
dropout_widget = widgets.Checkbox(value=False, description="Dropout")
dropout_rate_widget = widgets.FloatSlider(
    value=0.5, min=0.0, max=0.9, step=0.05, description="Dropout Rate"
)
weight_init_widget = widgets.Dropdown(
    options=["kaiming", "xavier", "orthogonal"], value="kaiming", description="Weight Init"
)
lr_schedule_widget = widgets.Dropdown(
    options=["constant+warmup", "step", "cosine", "one_cycle"],
    value="constant+warmup",
    description="LR Schedule",
)
label_smooth_widget = widgets.FloatSlider(
    value=0.0, min=0.0, max=0.2, step=0.01, description="Label Smoothing"
)
augment_widget = widgets.Checkbox(value=False, description="Augment Extra")

# Stopping protocol: either an epoch count or a time budget (minutes), plus the seed.
protocol_widget = widgets.Dropdown(
    options=["Fixed Epochs", "Fixed Wall-Clock Time"], value="Fixed Epochs", description="Protocol"
)
epochs_box = widgets.BoundedIntText(value=60, min=1, max=100, description="Epochs")
time_box = widgets.BoundedIntText(value=90, min=1, max=120, description="Time (min)")
time_box.layout.display = "none"
seed_box = widgets.BoundedIntText(value=42, min=0, max=9999, description="Seed")


def on_protocol_change(change):
    """Show the epochs box for "Fixed Epochs" and the time box for the wall-clock protocol."""
    fixed_epochs = change["new"] == "Fixed Epochs"
    epochs_box.layout.display = "block" if fixed_epochs else "none"
    time_box.layout.display = "none" if fixed_epochs else "block"


protocol_widget.observe(on_protocol_change, names="value")
on_protocol_change({"new": protocol_widget.value})  # sync with the initial selection

# Action buttons and the shared output area.
train_button = widgets.Button(description="🚀 Entrenar", button_style="success")
lrtest_button = widgets.Button(description="🔎 LR Range Test", button_style="warning")
output_widget = widgets.Output(layout={'border': '1px solid black'})

def on_lrtest_clicked(btn):
    """Run ``lr_range_test`` for the model, optimizer and batch size selected in the widgets.

    A fresh model is built, so the sweep never touches a trained model. The seed from
    ``seed_box`` is applied for reproducibility. Errors are printed with their traceback
    into ``output_widget``; otherwise they would be lost in the widget callback log.
    """
    import traceback

    with output_widget:
        clear_output(wait=True)
        print("🔎 Ejecutando LR Range Test...\n")
        try:
            torch.manual_seed(int(seed_box.value))
            device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
            model = get_model(model_name=model_widget.value, num_classes=10,
                              weight_init=weight_init_widget.value,
                              dropout_rate=dropout_rate_widget.value if dropout_widget.value else 0.0)
            model.to(device)
            # Plain optimizers (no weight decay) keyed by the dropdown label.
            optimizer_class = {
                "SGD": lambda params, lr: optim.SGD(params, lr=lr, momentum=0.9),
                "SGD+Nesterov": lambda params, lr: optim.SGD(params, lr=lr, momentum=0.9, nesterov=True),
                "Adam": lambda params, lr: optim.Adam(params, lr=lr),
                "AdamW": lambda params, lr: optim.AdamW(params, lr=lr)
            }[optimizer_widget.value]

            train_loader = DataLoader(train_dataset, batch_size=int(batch_widget.value), shuffle=True)
            lr_range_test(model, train_loader, optimizer_class, device=device)
        except Exception as e:
            print(f"❌ Error durante LR Range Test: {e}")
            traceback.print_exc()

def on_train_clicked(btn):
    """Train with the hyperparameters selected in the widgets; all output goes to ``output_widget``.

    Uses the notebook-level ``train_dataset`` and ``test_dataset``. The wall-clock budget is
    converted from minutes to seconds. The Nesterov checkbox value is forwarded, but
    ``train_model`` selects Nesterov from the optimizer name ("SGD+Nesterov"). A failure is
    reported with its message *and* full traceback, so its origin can be located.
    """
    import traceback

    with output_widget:
        clear_output(wait=True)
        print("🚀 Iniciando entrenamiento...\n")
        wall_budget = None
        if protocol_widget.value == "Fixed Wall-Clock Time":
            wall_budget = time_box.value * 60

        try:
            exp_path = train_model(
                model_name = model_widget.value,
                train_dataset = train_dataset,  # must exist at notebook level
                test_dataset  = test_dataset,   # must exist at notebook level
                batch_size = int(batch_widget.value),
                optimizer_name = optimizer_widget.value,
                nesterov = bool(nesterov_checkbox.value),
                lr = float(lr_widget.value),
                weight_decay = float(weight_decay_widget.value),
                epochs = int(epochs_box.value),
                dropout_rate = float(dropout_rate_widget.value) if dropout_widget.value else 0.0,
                weight_init = weight_init_widget.value,
                label_smoothing = float(label_smooth_widget.value),
                lr_schedule = lr_schedule_widget.value,
                use_augment = bool(augment_widget.value),
                protocol = protocol_widget.value,
                wall_clock_budget = wall_budget,
                seed = int(seed_box.value),
                device = None,
                num_workers = 2,
                pin_memory = True
            )
            print(f"\n🏁 Experimento completado. Archivos en: {exp_path}")
        except Exception as e:
            print(f"❌ Error durante entrenamiento: {e}")
            traceback.print_exc()

# Wire the buttons to their handlers.
train_button.on_click(on_train_clicked)
lrtest_button.on_click(on_lrtest_clicked)

# Control panel: one HBox per group of settings, then the buttons, then the shared output area.
ui = widgets.VBox(children=[
    widgets.HBox([model_widget, optimizer_widget, nesterov_checkbox]),          # architecture / optimizer
    widgets.HBox([batch_widget, lr_widget, weight_decay_widget]),               # optimization
    widgets.HBox([dropout_widget, dropout_rate_widget, weight_init_widget]),    # regularization / init
    widgets.HBox([lr_schedule_widget, label_smooth_widget, augment_widget]),    # schedule / loss / data
    widgets.HBox([protocol_widget, epochs_box, time_box, seed_box]),            # stopping rule + seed
    widgets.HBox([train_button, lrtest_button]),
    output_widget,
])
display(ui)


VBox(children=(HBox(children=(Dropdown(description='Modelo', options=('alexnet', 'vgg', 'customcnn'), value='a…

___