## 4. Training and Evaluation (Training Loop)
### This code trains the model on the training data and evaluates its accuracy on the validation set after each epoch.

## 4.1 Training Loop

In [None]:
import copy
import time
import torch

# Path of the best-model checkpoint file (stored in Drive, under PROJECT_PATH)
BEST_MODEL_PATH = os.path.join(PROJECT_PATH, 'best_simple_model.pth')

def train_model(model, criterion, optimizer, num_epochs=10):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs one optimisation pass over ``dataloaders['train']`` and
    then measures accuracy over ``dataloaders['val']``. Whenever the
    validation accuracy strictly improves on the best value seen so far, the
    weights are snapshotted in memory and written to ``BEST_MODEL_PATH``.

    Relies on the module-level names ``dataloaders``, ``dataset_sizes``,
    ``device`` and ``BEST_MODEL_PATH``.

    Args:
        model: Network to train; it is updated in place.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer used to update the model parameters.
        num_epochs: Number of train/validate epochs to run.

    Returns:
        The same ``model`` object with the best-scoring weights loaded. If no
        epoch reaches an accuracy above 0.0, the weights it had on entry are
        restored and no checkpoint file is written.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # --- Training phase ---
        model.train()
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # --- Validation phase ---
        model.eval()
        # Counted as a plain Python int so the accuracy computation below also
        # works when the validation loader yields no batches.
        running_corrects = 0
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                running_corrects += (preds == labels).sum().item()

        # Python float division: no tensor left on the device and no
        # dependency on float64 tensor support.
        epoch_acc = running_corrects / dataset_sizes['val']
        print(f'Validation Acc: {epoch_acc:.4f}')

        # Save best model (strict improvement only)
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f">>> New best model saved at: {BEST_MODEL_PATH}")

    # Load best model weights
    model.load_state_dict(best_model_wts)
    print(f"\nBest Validation Accuracy: {best_acc:.4f}")
    return model
# Run the training loop; model_ft is rebound to the returned model.
model_ft = train_model(model_ft, criterion, optimizer, num_epochs=10)

# Check that the checkpoint file exists after a short pause
# (presumably to let the Drive mount flush the write -- TODO confirm).
time.sleep(2)
if not os.path.exists(BEST_MODEL_PATH):
    print("ERROR: Model was NOT saved in Drive.")
else:
    # Report the checkpoint size in MB (bytes / 1024 / 1024).
    file_size = os.path.getsize(BEST_MODEL_PATH) / (1024*1024)
    print(
        "Model checkpoint saved in Drive",
        f"Location: {BEST_MODEL_PATH}",
        f"Size: {file_size:.2f} MB",
        sep="\n",
    )