# <div style="text-align: center; color: #1a5276;">Advanced Training</div>

## <font color='blue'>  Table of Contents </font>

1. [Introduction](#1)
2. [Setup](#2)
3. [Helper Functions](#3) 
4. [Data](#4) 
5. [Model](#5)
6. [Training](#6) <br>
    6.1. [Basic Training](#6.1) <br>
    6.2. [Including a progress bar](#6.2) <br>
    6.3. [Including a validation set and a custom metric](#6.3) <br>
    6.4. [Code improvement](#6.4) <br>
8. [References](#references)

<a name="1"></a>
## <font color='blue'> 1. Introduction </font>

https://gemini.google.com/app/c49f12a5fb1aef13?hl=es

<a name="2"></a>
## <font color='blue'> 2. Setup </font>

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

<a name="3"></a>
## <font color='blue'> 3. Data </font>

In [5]:
# binary data
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset

# Synthetic binary data
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=0)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Split: 70% train, 15% val, 15% test
X_train, X_tmp, y_train, y_tmp = train_test_split(X, y, test_size=0.3, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_tmp, y_tmp, test_size=0.5, random_state=0)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=32)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=32)

## Example

1.1. Simple trainign


In [9]:
# Model (binary)
model = torch.nn.Sequential(
    torch.nn.Linear(20, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 2)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Sequential(
  (0): Linear(in_features=20, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=2, bias=True)
)

In [6]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, metric_fn=None):
    model.train()
    running_loss, running_metric, total_samples = 0.0, 0.0, 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        if metric_fn:
            running_metric += metric_fn(outputs, labels).item() * batch_size
        total_samples += batch_size

    return running_loss / total_samples, running_metric / total_samples if metric_fn else 0

In [7]:
def evaluate(model, dataloader, criterion, device, metric_fn=None):
    model.eval()
    running_loss, running_metric, total_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            if metric_fn:
                running_metric += metric_fn(outputs, labels).item() * batch_size
            total_samples += batch_size

    return running_loss / total_samples, running_metric / total_samples if metric_fn else 0

In [8]:
def train_model(
    model, train_loader, val_loader, test_loader,
    criterion, optimizer, device,
    metric_fn=None, epochs=10,
    scheduler=None, early_stopping_patience=None
):
    best_val_loss = float('inf')
    best_model_state = None
    wait = 0
    history = {'train_loss': [], 'val_loss': [], 'train_metric': [], 'val_metric': []}

    for epoch in range(1, epochs + 1):
        train_loss, train_metric = train_one_epoch(model, train_loader, criterion, optimizer, device, metric_fn)
        val_loss, val_metric = evaluate(model, val_loader, criterion, device, metric_fn)

        if scheduler:
            scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_metric'].append(train_metric)
        history['val_metric'].append(val_metric)

        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f}, Metric: {train_metric:.4f} | "
              f"Val Loss: {val_loss:.4f}, Metric: {val_metric:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()
            wait = 0
        else:
            wait += 1
            if early_stopping_patience and wait >= early_stopping_patience:
                print("Early stopping triggered.")
                break

    if best_model_state:
        model.load_state_dict(best_model_state)
        print("Loaded best model.")

    # Final test set evaluation (only once)
    test_loss, test_metric = evaluate(model, test_loader, criterion, device, metric_fn)
    print(f"\nFinal Test Loss: {test_loss:.4f}, Test Metric: {test_metric:.4f}")

    return model, history

In [10]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)

def accuracy_fn(outputs, labels):
    preds = torch.argmax(outputs, dim=1)
    return (preds == labels).float().mean()

# Train with test set evaluation
model, history = train_model(
    model, train_loader, val_loader, test_loader,
    criterion, optimizer, device,
    metric_fn=accuracy_fn,
    epochs=20,
    scheduler=scheduler,
    early_stopping_patience=4
)

Epoch 01 | Train Loss: 0.4410, Metric: 0.8314 | Val Loss: 0.2145, Metric: 0.9333
Epoch 02 | Train Loss: 0.1712, Metric: 0.9329 | Val Loss: 0.1727, Metric: 0.9400
Epoch 03 | Train Loss: 0.1241, Metric: 0.9586 | Val Loss: 0.1787, Metric: 0.9400
Epoch 04 | Train Loss: 0.1004, Metric: 0.9686 | Val Loss: 0.1749, Metric: 0.9400
Epoch 05 | Train Loss: 0.0870, Metric: 0.9729 | Val Loss: 0.1803, Metric: 0.9400
Epoch 06 | Train Loss: 0.0720, Metric: 0.9829 | Val Loss: 0.1810, Metric: 0.9400
Early stopping triggered.
Loaded best model.

Final Test Loss: 0.1209, Test Metric: 0.9800


### Multiple metrics



### using torchmetrics to compute metrics