# 03. Training Loop and Optimization

In this notebook, we implement the training loop, add optimization techniques (SpecAugment, Learning Rate Scheduling), and evaluate the model.

## 1. Modular Code
At this stage, we have moved our `Dataset` and `Model` classes into the `src/` directory for better organization. We can import them directly.

In [None]:
import os
import sys
sys.path.append(os.path.abspath('../'))

from src.data.dataset import SpokenDigitDataset
from src.models.model import SimpleCNN
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

print("Modules imported successfully!")

## 2. Setup Data and Splitting
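We load all files (including .m4a and .ogg) and split them 80/20 into training and validation sets, using a fixed random seed so the split is reproducible. The `test_*` variables below hold the validation split.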

In [None]:
dataset = SpokenDigitDataset('../data/processed')
files = dataset.file_list
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)

train_dataset = SpokenDigitDataset(file_list=train_files, train=True) # Augmentation Enabled
test_dataset = SpokenDigitDataset(file_list=test_files, train=False)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Training Samples: {len(train_dataset)}")
print(f"Validation Samples: {len(test_dataset)}")
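
The `train=True` flag enables augmentation inside `SpokenDigitDataset`, whose implementation lives in `src/` and is not shown in this notebook. As a rough illustration only, SpecAugment-style masking could look like the sketch below. The function name `spec_augment`, the mask widths, and the `(n_mels, n_frames)` input shape are assumptions made for this example, not the project's actual code, and the time-warping step of SpecAugment is omitted.

```python
import torch


def spec_augment(spec, freq_mask=8, time_mask=10, generator=None):
    """Return a copy of `spec` with one frequency band and one time band zeroed.

    spec: 2-D tensor of shape (n_mels, n_frames) (assumed layout).
    freq_mask / time_mask: maximum width of each mask (hypothetical defaults).
    """
    spec = spec.clone()  # leave the caller's tensor untouched
    n_mels, n_frames = spec.shape

    # Frequency mask: random width in [0, freq_mask], random start position
    max_f = min(freq_mask, n_mels)
    f = torch.randint(0, max_f + 1, (1,), generator=generator).item()
    f0 = torch.randint(0, n_mels - f + 1, (1,), generator=generator).item()
    spec[f0:f0 + f, :] = 0.0

    # Time mask: random width in [0, time_mask], random start position
    max_t = min(time_mask, n_frames)
    t = torch.randint(0, max_t + 1, (1,), generator=generator).item()
    t0 = torch.randint(0, n_frames - t + 1, (1,), generator=generator).item()
    spec[:, t0:t0 + t] = 0.0

    return spec


# Usage on a dummy 40-band, 50-frame spectrogram
dummy = torch.ones(40, 50)
augmented = spec_augment(dummy)
print("Masked cells:", int((augmented == 0).sum()))
```

Masking should only be applied to training samples, which is why the validation dataset is built with `train=False`. `torchaudio.transforms.FrequencyMasking` and `TimeMasking` provide comparable masking if torchaudio is available.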

## 3. Training Loop with Optimization
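We use the `Adam` optimizer with an initial learning rate of 0.001 and a `ReduceLROnPlateau` scheduler. The scheduler halves the learning rate (`factor=0.5`) when the monitored metric has not improved for 3 consecutive epochs (`patience=3`). Because it is configured with `mode='max'`, it must be stepped with a metric that should increase, such as validation accuracy, rather than with the loss.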

In [None]:
# Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu' and torch.backends.mps.is_available():
    device = torch.device('mps')

model = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# mode='max': call scheduler.step() with a metric that should increase (e.g. validation accuracy)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

print(f"Training on {device}...")
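
The cell above only sets up the model, loss, optimizer, and scheduler; the loop itself lives in `train_model.py` and is not reproduced here. The following is a minimal sketch of what such a loop might look like, not the project's actual script. The helper names `train_one_epoch`, `evaluate`, and `fit` are hypothetical, and the sketch assumes each batch is an `(inputs, labels)` pair. To stay self-contained it runs on random stand-in data with a linear model in place of `SimpleCNN`.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over the training data and return the mean loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for inputs, labels in loader:  # assumed batch format: (inputs, labels)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        n_samples += labels.size(0)
    return total_loss / n_samples


@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy in [0, 1] on the given loader."""
    model.eval()
    correct, n_samples = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        n_samples += labels.size(0)
    return correct / n_samples


def fit(model, train_loader, val_loader, criterion, optimizer, scheduler,
        device, epochs):
    """Train for `epochs` epochs; return a list of (loss, val_acc, lr) tuples."""
    history = []
    for epoch in range(epochs):
        loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_acc = evaluate(model, val_loader, device)
        scheduler.step(val_acc)  # mode='max', so step on accuracy, not loss
        lr = optimizer.param_groups[0]['lr']
        history.append((loss, val_acc, lr))
        print(f"Epoch {epoch + 1}: loss={loss:.4f} val_acc={val_acc:.3f} lr={lr:.6f}")
    return history


# Stand-in data and model (NOT the real dataset or SimpleCNN)
torch.manual_seed(0)
demo_device = torch.device('cpu')
features = torch.randn(64, 20)
targets = torch.randint(0, 10, (64,))
demo_train = DataLoader(TensorDataset(features[:48], targets[:48]), batch_size=16, shuffle=True)
demo_val = DataLoader(TensorDataset(features[48:], targets[48:]), batch_size=16)

demo_model = nn.Linear(20, 10).to(demo_device)
demo_criterion = nn.CrossEntropyLoss()
demo_optimizer = optim.Adam(demo_model.parameters(), lr=0.001)
demo_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    demo_optimizer, mode='max', factor=0.5, patience=3)

history = fit(demo_model, demo_train, demo_val, demo_criterion,
              demo_optimizer, demo_scheduler, demo_device, epochs=3)
```

In the notebook, the same `fit` call would take `model`, `train_loader`, `test_loader`, `criterion`, `optimizer`, `scheduler`, and `device` from the cells above. Wrapping evaluation in `torch.no_grad()` and calling `model.eval()` keeps dropout and batch-norm layers, if the model has any, in inference mode during validation.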

## 4. Final Results
We train for 50 epochs, targeting a validation accuracy above 90%. The `train_model.py` script encapsulates the full loop.
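
A single accuracy figure does not show which digits the model confuses with each other. As one possible way to inspect this, the sketch below builds a confusion matrix and per-class recall from predicted and true labels. The function names `confusion_matrix` and `per_class_recall` are hypothetical helpers, and the label values are made up purely to demonstrate the calculation; they are not results from this model.

```python
import torch


def confusion_matrix(preds, labels, num_classes=10):
    """Return a (num_classes, num_classes) count matrix.

    Rows are true classes and columns are predicted classes.
    """
    # Encode each (true, predicted) pair as a single index, then count
    flat_index = labels * num_classes + preds
    counts = torch.bincount(flat_index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def per_class_recall(cm):
    """Fraction of each true class that was predicted correctly."""
    row_totals = cm.sum(dim=1).clamp(min=1)  # avoid division by zero
    return cm.diag().float() / row_totals.float()


# Made-up labels for illustration only
true_labels = torch.tensor([0, 1, 2, 2])
predictions = torch.tensor([0, 2, 2, 2])

cm = confusion_matrix(predictions, true_labels, num_classes=3)
recall = per_class_recall(cm)
print(cm)
print(recall)
```

With a real model, `predictions` and `true_labels` would be collected by concatenating the per-batch outputs from the validation loader.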