# 03. Training Loop and Optimization

In this notebook, we implement the training loop, add optimization techniques (SpecAugment, Learning Rate Scheduling), and evaluate the model. We also integrate **MLflow** for experiment tracking.

## 1. Modular Code
At this stage, we have moved our `Dataset` and `Model` classes into the `src/` directory for better organization. We can import them directly.

In [None]:
import os
import sys
sys.path.append(os.path.abspath('../'))

from src.data.dataset import SpokenDigitDataset
from src.models.model import SimpleCNN, DeeperCNN
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import mlflow
import mlflow.pytorch

print("Modules imported successfully!")

## 2. Setup Data and Splitting
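We load all files (including .m4a and .ogg) and split them 80/20 into training and validation sets, using a fixed `random_state` so the split is reproducible. The validation set is held in the `test_*` variables below.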

In [None]:
dataset = SpokenDigitDataset('../data/processed')
files = dataset.file_list
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)

train_dataset = SpokenDigitDataset(file_list=train_files, train=True, time_mask_param=30, freq_mask_param=15) # Optimized Augmentation
test_dataset = SpokenDigitDataset(file_list=test_files, train=False)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Training Samples: {len(train_dataset)}")
print(f"Validation Samples: {len(test_dataset)}")
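
The `time_mask_param` and `freq_mask_param` arguments above suggest SpecAugment-style masking, but the dataset's actual augmentation code lives in `src/` and is not shown here. As a rough illustration only, the sketch below zeroes one random band of frequency bins and one random band of time frames in a toy spectrogram stored as nested lists. The helper name `apply_mask` and the toy 64 x 100 shape are assumptions made for this example.

```python
import random

def apply_mask(spec, max_width, axis, rng):
    """Zero one random band of at most `max_width` rows (axis=0) or columns (axis=1)."""
    n_freq, n_time = len(spec), len(spec[0])
    size = n_freq if axis == 0 else n_time
    width = rng.randint(0, min(max_width, size))   # band width, possibly 0
    start = rng.randint(0, size - width)           # band start, kept in range
    masked = [row[:] for row in spec]              # copy; the input stays untouched
    for f in range(n_freq):
        for t in range(n_time):
            idx = f if axis == 0 else t
            if start <= idx < start + width:
                masked[f][t] = 0.0
    return masked

rng = random.Random(0)
spec = [[1.0] * 100 for _ in range(64)]            # toy spectrogram: 64 bins x 100 frames
augmented = apply_mask(spec, max_width=15, axis=0, rng=rng)       # frequency mask
augmented = apply_mask(augmented, max_width=30, axis=1, rng=rng)  # time mask

# Count rows and columns that were fully zeroed by the masks.
masked_rows = sum(1 for row in augmented if all(v == 0.0 for v in row))
masked_cols = sum(1 for t in range(100) if all(row[t] == 0.0 for row in augmented))
print(f"Masked frequency bins: {masked_rows}, masked time frames: {masked_cols}")
```

Masking is only applied to the training set (`train=True`), so validation accuracy is measured on unmodified spectrograms.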

## 3. Training Loop with Optimization and MLflow
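We use the `Adam` optimizer with a `ReduceLROnPlateau` scheduler, which halves the learning rate once validation accuracy has gone more than 3 consecutive epochs without improving. The whole loop runs inside `mlflow.start_run()` so that the parameters, per-epoch metrics, and final model are logged to a single run.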
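
To make the scheduler's behavior concrete, here is a simplified plain-Python sketch of plateau-based learning rate reduction with the settings used in the training cell (`mode='max'`, `factor=0.5`, `patience=3`). It is not PyTorch's implementation: the class name `PlateauTracker` is made up for this example, and the real scheduler also supports an improvement threshold, a cooldown period, and a minimum learning rate, all of which are omitted here.

```python
class PlateauTracker:
    """Simplified illustration of plateau-based LR reduction for a metric to maximize."""

    def __init__(self, lr, factor=0.5, patience=3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("-inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric > self.best:
            # New best value: remember it and reset the counter.
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only after MORE than `patience` epochs without improvement.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr

tracker = PlateauTracker(lr=0.001)
val_accuracies = [80.0, 85.0, 85.0, 85.0, 85.0, 85.0, 86.0]  # made-up values
lr_history = [tracker.step(acc) for acc in val_accuracies]
print(lr_history)
```

With these made-up accuracies, the learning rate stays at 0.001 through the third stalled epoch and drops to 0.0005 on the fourth.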

In [None]:
# Pick the best available backend: CUDA GPU, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = DeeperCNN(num_classes=10).to(device) # Using DeeperCNN (Best Model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

print(f"Training on {device}...")

# MLflow Setup
mlflow.set_experiment("Spoken Digit Recognition Notebook")

epochs = 50 # Optimal Epochs from hyperparameter tuning

with mlflow.start_run():
    mlflow.log_param("epochs", epochs)
    mlflow.log_param("model", "DeeperCNN")
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
        avg_train_loss = running_loss / len(train_loader)
        
        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)  # class with the highest logit
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_acc = 100 * correct / total
        
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_train_loss:.4f} - Val Acc: {val_acc:.2f}%")
        
        # Log metrics
        mlflow.log_metric("train_loss", avg_train_loss, step=epoch)
        mlflow.log_metric("val_accuracy", val_acc, step=epoch)
        
        scheduler.step(val_acc)
        
    # Log the trained model with a signature and an input example.
    # `images` is the last validation batch left over from the loop above.
    example_batch = images[:1]
    with torch.no_grad():
        example_output = model(example_batch).cpu().numpy()
    example_input = example_batch.cpu().numpy()
    signature = mlflow.models.signature.infer_signature(example_input, example_output)
    mlflow.pytorch.log_model(model, "model", signature=signature, input_example=example_input)
    print("Training Complete. Check MLflow UI!")
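
A small note on the logged `train_loss`: the loop divides the sum of per-batch mean losses by `len(train_loader)`, so every batch counts equally even when the last batch is smaller. The sketch below, using made-up loss values and batch sizes, compares that average with a per-sample weighted average. The difference is usually small, and this is only an illustration, not a required change.

```python
# Made-up per-batch mean losses and batch sizes (the last batch is smaller).
batch_losses = [0.5, 0.5, 2.0]
batch_sizes = [32, 32, 8]

# What the training loop computes: every batch has equal weight.
mean_of_batch_means = sum(batch_losses) / len(batch_losses)

# Per-sample average: weight each batch's mean loss by its size.
total_samples = sum(batch_sizes)
per_sample_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / total_samples

print(f"Mean of batch means: {mean_of_batch_means:.4f}")
print(f"Per-sample mean:     {per_sample_mean:.4f}")
```

In a real loop the weighted version would accumulate `loss.item() * labels.size(0)` and divide by the number of samples seen.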

## 4. Final Results
To view the results, run `mlflow ui` in your terminal and open the link (usually http://127.0.0.1:5000).