# Custom CNN Model Optimization

This notebook focuses on optimizing the custom CNN model for musical instrument classification using the project's configurable training framework. We'll apply several optimization techniques and evaluate their impact on performance.

## Setup

Let's set up the environment by importing the necessary libraries and modules from our project structure.

In [None]:
import os
import sys
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Add the project root to the path to enable importing from our package
project_root = Path(os.getcwd()).parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
    
# Import from our project modules
from src.data.dataset import InstrumentDataset, get_transforms
from src.data.preprocessing import create_train_val_split
from src.models.custom_cnn import MusicInstrumentCNN, create_custom_cnn
from src.training.trainer import train_model, evaluate_model
from src.training.scheduler import get_scheduler

# Set random seed for reproducibility
def set_seed(seed=42):
    """
    Seed every random number generator this notebook may draw from.

    Args:
        seed (int): Seed applied to Python's `random`, NumPy and PyTorch
            (CPU always, CUDA when it is available)
    """
    # Function-scope import: `random` is not among the notebook's top-level imports
    import random

    # Python's built-in generator was previously left unseeded, so any code
    # drawing from it produced different results on every run
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN behaviour and turn off its benchmark mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Function to load configuration
def load_config(config_path):
    """
    Load a YAML configuration file.

    Args:
        config_path (str or Path): Location of the YAML file

    Returns:
        The parsed YAML content (a mapping for the configs used in this
        notebook); an empty dict when the file contains no document
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default encoding
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    # An empty file parses to None; return a dict so callers can keep using .get()
    return config if config is not None else {}

## Load the Optimized Configuration

We'll use our new optimized configuration to guide the model training.

In [None]:
# Load configuration from YAML file
# (project_root and load_config are defined in the setup cell)
config_path = project_root / "config" / "optimized_custom_cnn.yaml"
config = load_config(config_path)

# Display the configuration: one line per top-level key, value printed as parsed
print("Model Configuration:")
for key, value in config.items():
    print(f"{key}: {value}")

## Data Preparation with Enhanced Augmentation

We'll implement stronger data augmentation strategies to improve model generalization.

In [None]:
import torchvision.transforms as transforms

def get_enhanced_transforms(img_size=224, augmentation_strength='strong'):
    """
    Build the preprocessing pipelines for training and validation images

    Args:
        img_size (int): Side length of the square images produced by both pipelines
        augmentation_strength (str): 'light', 'medium' or 'strong'; any other
            value falls through to the 'strong' settings

    Returns:
        train_transform: Randomly augmenting pipeline for training data
        val_transform: Resize + normalize pipeline for validation data
    """
    def make_normalize():
        # Same per-channel mean/std for both pipelines
        # (presumably the ImageNet statistics — confirm)
        return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Per-strength settings:
    # (brightness/contrast/saturation jitter, hue jitter, erasing probability, rotation)
    if augmentation_strength == 'light':
        jitter_amount, hue_amount, erasing_prob, rotation_degrees = 0.1, 0.05, 0.1, 10
    elif augmentation_strength == 'medium':
        jitter_amount, hue_amount, erasing_prob, rotation_degrees = 0.15, 0.1, 0.2, 15
    else:  # strong
        jitter_amount, hue_amount, erasing_prob, rotation_degrees = 0.2, 0.1, 0.3, 20

    color_jitter = transforms.ColorJitter(
        brightness=jitter_amount,
        contrast=jitter_amount,
        saturation=jitter_amount,
        hue=hue_amount
    )

    # Training: geometric and color augmentations on the PIL image, then
    # tensor conversion, normalization and random erasing on the tensor
    train_steps = [
        transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(rotation_degrees),
        color_jitter,
        transforms.ToTensor(),
        make_normalize(),
        transforms.RandomErasing(p=erasing_prob)
    ]

    # Validation: no randomness, only resize and normalize
    val_steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        make_normalize()
    ]

    train_transform = transforms.Compose(train_steps)
    val_transform = transforms.Compose(val_steps)
    return train_transform, val_transform

# Set paths
data_dir = project_root / "data" / "raw" / "30_Musical_Instruments"

# Create train/validation split: 20% of the data is held out for validation,
# with a fixed seed so the split is the same on every run
train_files, val_files, classes = create_train_val_split(
    data_dir, 
    val_split=0.2,
    seed=42
)

print(f"Number of classes: {len(classes)}")
print(f"Number of training samples: {len(train_files)}")
print(f"Number of validation samples: {len(val_files)}")

# Get enhanced data transforms
# Both settings are optional in the config; missing keys fall back to 'strong' / 224
augmentation_strength = config.get('augmentation', {}).get('augmentation_strength', 'strong')
img_size = config.get('data', {}).get('img_size', 224)
train_transform, val_transform = get_enhanced_transforms(
    img_size=img_size,
    augmentation_strength=augmentation_strength
)

# Create datasets
# Only the training dataset receives the augmenting transform
train_dataset = InstrumentDataset(train_files, classes, transform=train_transform)
val_dataset = InstrumentDataset(val_files, classes, transform=val_transform)

# Create data loaders
# Missing config keys fall back to a batch size of 32 and 4 worker processes
batch_size = config.get('training', {}).get('batch_size', 32)
num_workers = config.get('data', {}).get('num_workers', 4)

# Training batches are reshuffled every epoch; validation order stays fixed
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

## Visualize Augmented Images

Let's visualize some augmented training images next to their originals to check that the augmentations look reasonable and the instruments remain recognizable.

In [None]:
def visualize_augmentations(dataset, num_samples=8, num_augmentations=5):
    """
    Plot a grid with one row per randomly chosen sample: the untransformed
    image in the first column, followed by repeated draws from the dataset's
    own transform pipeline

    Args:
        dataset: Dataset with augmentations applied; must provide
            `get_original(idx)` and a `labels` sequence
        num_samples: Number of different samples (rows) to visualize
        num_augmentations: Number of augmented versions shown per sample
    """
    # Pick distinct samples to display
    chosen_indices = np.random.choice(len(dataset), num_samples, replace=False)

    # One extra column holds the original image
    grid_cols = num_augmentations + 1
    plt.figure(figsize=(num_augmentations * 3, num_samples * 3))

    # Invert the normalization in two steps: scale by std, then shift by mean
    scale_back = transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225])
    shift_back = transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1])

    for row, sample_idx in enumerate(chosen_indices):
        # Position of this row's first cell (subplot indices are 1-based)
        row_start = row * grid_cols + 1

        untouched = dataset.get_original(sample_idx)
        plt.subplot(num_samples, grid_cols, row_start)
        plt.imshow(untouched)
        plt.title(f"Original\n{classes[dataset.labels[sample_idx]]}")
        plt.axis('off')

        for aug_no in range(num_augmentations):
            # Indexing the dataset re-runs its transform, giving a new random draw
            augmented, _ = dataset[sample_idx]
            # Back to the [0, 1] range and channel-last layout for imshow
            restored = torch.clamp(shift_back(scale_back(augmented)), 0, 1)
            pixels = restored.permute(1, 2, 0).numpy()

            plt.subplot(num_samples, grid_cols, row_start + aug_no + 1)
            plt.imshow(pixels)
            plt.title(f"Augmentation {aug_no+1}")
            plt.axis('off')

    plt.tight_layout()
    plt.show()

# Add a method to the InstrumentDataset class to get original (non-augmented) images
def get_original(self, idx):
    """
    Load a sample's image straight from disk, without applying any transform

    Args:
        idx (int): Position of the sample in `self.image_paths`

    Returns:
        PIL.Image: The image converted to RGB
    """
    source_path = self.image_paths[idx]
    raw_image = Image.open(source_path)
    # Convert so the result always has three channels
    return raw_image.convert('RGB')

# Attach the method to the class
# Image is imported only here, after get_original was defined; that works because
# the name is looked up when get_original is called, not when it is defined
from PIL import Image
InstrumentDataset.get_original = get_original

# Visualize the augmentations
# 4 samples x (1 original + 4 augmented versions) from the training dataset
visualize_augmentations(train_dataset, num_samples=4, num_augmentations=4)

## Initialize the Optimized Model

Let's set up our custom CNN model with the optimized parameters.

In [None]:
# Create model
# NOTE(review): only the class count and channel count are passed here; no value
# from `config` reaches the model constructor in this cell — confirm that is intended
model = create_custom_cnn(
    num_classes=len(classes), 
    input_channels=3
)

# Set device
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = model.to(device)

# Print model architecture and parameter count
print(model)
# Total counts every parameter; trainable counts only those with requires_grad set
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Enhanced Training Configuration

Now, we'll set up our training with optimized hyperparameters.

In [None]:
# Define loss function
criterion = nn.CrossEntropyLoss()

# Get optimizer parameters from config
# Every key is optional; the second argument of each .get() is the fallback value
optimizer_config = config.get('training', {}).get('optimizer', {})
optimizer_name = optimizer_config.get('name', 'adamw').lower()
lr = optimizer_config.get('learning_rate', 0.001)
weight_decay = optimizer_config.get('weight_decay', 0.001)
beta1 = optimizer_config.get('beta1', 0.9)
beta2 = optimizer_config.get('beta2', 0.999)

# Create optimizer
# Supported names: 'adam', 'adamw', 'sgd'; anything else raises ValueError
if optimizer_name == 'adam':
    optimizer = optim.Adam(
        model.parameters(), 
        lr=lr,
        weight_decay=weight_decay,
        betas=(beta1, beta2)
    )
elif optimizer_name == 'adamw':
    optimizer = optim.AdamW(
        model.parameters(), 
        lr=lr,
        weight_decay=weight_decay,
        betas=(beta1, beta2)
    )
elif optimizer_name == 'sgd':
    # momentum and nesterov are only read for SGD
    momentum = optimizer_config.get('momentum', 0.9)
    nesterov = optimizer_config.get('nesterov', True)
    optimizer = optim.SGD(
        model.parameters(), 
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=nesterov
    )
else:
    raise ValueError(f"Unsupported optimizer: {optimizer_name}")

# Get scheduler parameters from config
scheduler_config = config.get('training', {}).get('scheduler', {})
scheduler_name = scheduler_config.get('name', 'onecycle').lower()
num_epochs = config.get('training', {}).get('num_epochs', 75)

# Create scheduler
# Supported names: 'onecycle', 'cosine', 'step', 'plateau'.
# Unlike the optimizer, an unrecognized name does not raise: it leaves scheduler = None
if scheduler_name == 'onecycle':
    max_lr = scheduler_config.get('max_lr', 0.01)
    # The schedule length is steps_per_epoch * epochs, so this scheduler is meant
    # to be stepped once per training batch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        steps_per_epoch=len(train_loader),
        epochs=num_epochs,
        pct_start=scheduler_config.get('pct_start', 0.3),
        div_factor=25.0,  # initial_lr = max_lr / div_factor
        final_div_factor=10000.0  # final_lr = initial_lr / final_div_factor
    )
elif scheduler_name == 'cosine':
    t_max = scheduler_config.get('t_max', num_epochs)
    eta_min = scheduler_config.get('eta_min', 0.000001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=t_max,
        eta_min=eta_min
    )
elif scheduler_name == 'step':
    step_size = scheduler_config.get('step_size', 30)
    gamma = scheduler_config.get('gamma', 0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=step_size,
        gamma=gamma
    )
elif scheduler_name == 'plateau':
    patience = scheduler_config.get('patience', 5)
    factor = scheduler_config.get('factor', 0.1)
    # mode='min': the monitored quantity is expected to decrease (a loss)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=factor,
        patience=patience,
        min_lr=1e-6
    )
else:
    scheduler = None

# Print training configuration
# NOTE(review): with the 'onecycle' scheduler the starting rate is derived from
# max_lr / div_factor (see above), so the `lr` printed here is presumably not the
# rate actually used during training — confirm
print(f"Training for {num_epochs} epochs")
print(f"Optimizer: {optimizer_name}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Scheduler: {scheduler_name}")

## Implement Gradient Clipping

Gradient clipping helps stabilize training by preventing exploding gradients.

In [None]:
# Modified training function with gradient clipping
def train_model_with_clipping(model, dataloaders, criterion, optimizer, device,
                            scheduler=None, num_epochs=10, gradient_clip_val=1.0, verbose=True,
                            early_stopping_patience=None, early_stopping_min_delta=None):
    """
    Enhanced training function with gradient clipping, LR scheduling and early stopping

    Args:
        model (nn.Module): PyTorch model to train
        dataloaders (dict): Dictionary of PyTorch DataLoader objects for 'train' and 'val'
        criterion: Loss function
        optimizer: Optimizer to use
        device (torch.device): Device to train on (GPU or CPU)
        scheduler: Learning rate scheduler (optional). OneCycleLR is stepped after
            every training batch, ReduceLROnPlateau after every epoch with the
            validation loss, any other scheduler once per epoch
        num_epochs (int): Number of epochs to train for
        gradient_clip_val (float): Max norm for gradient clipping; None disables clipping
        verbose (bool): Whether to print progress
        early_stopping_patience (int): Number of consecutive epochs without a
            validation-loss improvement after which training stops. When None,
            read from config['regularization']['early_stopping']['patience'] (default 15)
        early_stopping_min_delta (float): Minimum decrease of the validation loss
            that counts as an improvement. When None, read from
            config['regularization']['early_stopping']['min_delta'] (default 0.001)

    Returns:
        model: The model with the weights of the best validation-accuracy epoch loaded
        history (dict): Per-epoch 'train_loss', 'train_acc', 'val_loss', 'val_acc' and 'lr'
        training_stats (dict): 'training_time', 'best_val_acc', 'best_epoch',
            'stopped_early' and, only when stopped early, 'stopped_epoch'
    """
    # Function-scope imports: neither name is imported at the top of the notebook.
    # They must come before first use; importing `copy` further down made it a
    # not-yet-bound local and every call failed with UnboundLocalError.
    import copy
    try:
        from tqdm import tqdm as progress_wrapper
    except ImportError:
        # Progress bars are a convenience only; train without them if tqdm is missing
        progress_wrapper = None

    # Fall back to the notebook-level config only for settings not passed explicitly
    if early_stopping_patience is None or early_stopping_min_delta is None:
        early_stopping_config = config.get('regularization', {}).get('early_stopping', {})
        if early_stopping_patience is None:
            early_stopping_patience = early_stopping_config.get('patience', 15)
        if early_stopping_min_delta is None:
            early_stopping_min_delta = early_stopping_config.get('min_delta', 0.001)
    patience = early_stopping_patience
    min_delta = early_stopping_min_delta

    since = time.time()
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}

    best_model_wts = copy.deepcopy(model.state_dict())
    # Kept as a plain float so it can be reported even if no epoch ever improves on it
    best_acc = 0.0
    best_epoch = 0

    # Early stopping state
    counter = 0
    best_loss = float('inf')

    steps_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)
    show_progress = verbose and progress_wrapper is not None

    def finish(stopped_epoch=None):
        """Restore the best weights and assemble the return values."""
        time_elapsed = time.time() - since
        elapsed_text = f"{time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s"
        if verbose:
            print(f'Training complete in {elapsed_text}')
            print(f'Best val Acc: {best_acc:.4f} at epoch {best_epoch}')

        training_stats = {
            'training_time': elapsed_text,
            'best_val_acc': best_acc,
            'best_epoch': best_epoch,
            'stopped_early': stopped_epoch is not None
        }
        if stopped_epoch is not None:
            training_stats['stopped_epoch'] = stopped_epoch

        # Load the best model weights
        model.load_state_dict(best_model_wts)
        return model, history, training_stats

    for epoch in range(num_epochs):
        if verbose:
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            dataloader = dataloaders[phase]
            if show_progress:
                batches = progress_wrapper(dataloader, desc=f'{phase} Epoch {epoch+1}/{num_epochs}')
            else:
                batches = dataloader

            for inputs, labels in batches:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass - track history only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimize only in training phase
                    if is_train:
                        loss.backward()

                        # Clip after backward and before the update so the
                        # optimizer only ever sees the clipped gradients
                        if gradient_clip_val is not None:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)

                        optimizer.step()

                        # Step OneCycleLR per iteration
                        if steps_per_batch:
                            scheduler.step()

                # Statistics, accumulated as plain Python numbers
                batch_count = inputs.size(0)
                batch_loss = loss.item()
                batch_correct = torch.sum(preds == labels).item()
                running_loss += batch_loss * batch_count
                running_corrects += batch_correct

                if show_progress:
                    batches.set_postfix({
                        'loss': batch_loss,
                        'accuracy': batch_correct / batch_count
                    })

            dataset_size = len(dataloader.dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            # Store the metrics
            if is_train:
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc)
                history['lr'].append(optimizer.param_groups[0]['lr'])
            else:
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc)

            if verbose:
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if is_train:
                continue

            # ---- validation phase only from here on ----

            # Epoch-level schedulers; OneCycleLR was already stepped per batch
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_loss)
            elif scheduler is not None and not steps_per_batch:
                scheduler.step()

            # Save the best model BEFORE the early-stopping check, so an epoch that
            # both sets a new best accuracy and triggers the stop is not lost
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_epoch = epoch + 1
                best_model_wts = copy.deepcopy(model.state_dict())
                if verbose:
                    print(f'New best model found! Val accuracy: {best_acc:.4f}')

            # Early stopping check on the validation loss
            if epoch_loss < best_loss - min_delta:
                best_loss = epoch_loss
                counter = 0
            else:
                counter += 1
                if verbose:
                    print(f"Early stopping counter: {counter}/{patience}")

                if counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    return finish(stopped_epoch=epoch + 1)

        if verbose:
            print()

    return finish()

## Train the Optimized Model

In [None]:
# Create dataloaders dictionary
# The training function looks up its loaders under the keys 'train' and 'val'
dataloaders = {
    'train': train_loader,
    'val': val_loader
}

# Gradient clipping value
# Optional in the config; defaults to a max norm of 1.0
gradient_clip_val = config.get('regularization', {}).get('gradient_clipping', {}).get('max_norm', 1.0)

# Train the model
# `model` is trained in place; the returned model is the same object with the
# best validation-accuracy weights loaded
print(f"Starting model training for {num_epochs} epochs...")
optimized_model, history, training_stats = train_model_with_clipping(
    model=model,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=num_epochs,
    gradient_clip_val=gradient_clip_val,
    verbose=True
)

# Print training summary
print("\nTraining summary:")
print(f"Best validation accuracy: {training_stats['best_val_acc']:.4f} at epoch {training_stats['best_epoch']}")
print(f"Training time: {training_stats['training_time']}")

## Visualize Training Metrics

Let's plot the learning curves to understand the training dynamics.

In [None]:
def plot_training_history(history):
    """
    Draw a 2x2 figure: loss curves, accuracy curves, learning rate per epoch,
    and a scatter of validation accuracy against learning rate

    Args:
        history (dict): Per-epoch lists under 'train_loss', 'val_loss',
            'train_acc', 'val_acc' and 'lr'
    """
    epochs = range(1, len(history['train_loss']) + 1)

    def draw_curves(position, title, y_label, curves):
        """Fill one panel with line plots; a legend is added only for labelled curves"""
        plt.subplot(2, 2, position)
        has_labels = False
        for history_key, line_format, curve_label in curves:
            if curve_label is None:
                plt.plot(epochs, history[history_key], line_format)
            else:
                plt.plot(epochs, history[history_key], line_format, label=curve_label)
                has_labels = True
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        if has_labels:
            plt.legend()
        plt.grid(True)

    plt.figure(figsize=(16, 10))

    # Top row: training vs. validation loss and accuracy
    draw_curves(1, 'Training and Validation Loss', 'Loss', [
        ('train_loss', 'b-', 'Training Loss'),
        ('val_loss', 'r-', 'Validation Loss'),
    ])
    draw_curves(2, 'Training and Validation Accuracy', 'Accuracy', [
        ('train_acc', 'b-', 'Training Accuracy'),
        ('val_acc', 'r-', 'Validation Accuracy'),
    ])

    # Bottom left: learning rate recorded at the end of each epoch
    draw_curves(3, 'Learning Rate', 'Learning Rate', [
        ('lr', 'g-', None),
    ])

    # Bottom right: validation accuracy against learning rate, log-scaled x axis
    plt.subplot(2, 2, 4)
    plt.scatter(history['lr'], history['val_acc'], alpha=0.7)
    plt.title('Validation Accuracy vs. Learning Rate')
    plt.xlabel('Learning Rate')
    plt.ylabel('Validation Accuracy')
    plt.xscale('log')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Plot the training history returned by train_model_with_clipping
plot_training_history(history)

## Evaluate on Test Set

Let's evaluate our optimized model on the test set to see if our optimizations improved performance.

In [None]:
# Create test dataset and loader
# NOTE(review): this is a second, independent call to create_train_val_split with
# different split arguments than the training split, and the FIRST returned list is
# used as the test set. Whether these files are disjoint from train_files depends
# on that helper — confirm the test set does not overlap the training data.
test_files, _, _ = create_train_val_split(
    data_dir, 
    val_split=0.0,  # Don't create a validation set
    test_split=0.2,  # Use 20% of data for testing
    seed=42
)

print(f"Number of test samples: {len(test_files)}")

test_dataset = InstrumentDataset(test_files, classes, transform=val_transform)  # Use validation transform for testing
test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

# Evaluate the model on the test set
# all_preds / all_labels are reused below for the confusion matrix and report
print("Evaluating model on test set...")
test_accuracy, all_preds, all_labels = evaluate_model(
    model=optimized_model,
    test_loader=test_loader,
    device=device,
    verbose=True
)

print("Test Results:")
# NOTE(review): the value is divided by 100 here, so evaluate_model presumably
# returns the accuracy as a percentage — confirm
print(f"- Accuracy: {test_accuracy/100:.4f}")

## Confusion Matrix and Detailed Analysis

Let's create a confusion matrix and classification report to better understand our model's strengths and weaknesses.

In [None]:
# Create confusion matrix
# Arguments are (true labels, predicted labels): rows are true classes,
# columns are predicted classes, matching the axis labels set below
cm = confusion_matrix(all_labels, all_preds)

# Plot confusion matrix
# Large figure so every class name fits; each cell is annotated with its integer count
plt.figure(figsize=(20, 16))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Generate classification report
# target_names replaces the numeric class indices with the class names
report = classification_report(all_labels, all_preds, target_names=classes)
print("Classification Report:")
print(report)

## Save the Optimized Model

Let's save our optimized model for future use.

In [None]:
# Create directory to save the model
# parents/exist_ok make this safe to re-run and create missing parent folders
save_dir = project_root / "experiments" / "optimized_custom_cnn"
save_dir.mkdir(parents=True, exist_ok=True)

# Save model weights
# Only the state_dict is stored, not the model object itself
model_save_path = save_dir / "optimized_model.pth"
torch.save(optimized_model.state_dict(), model_save_path)

# Save training history and statistics
import json
history_save_path = save_dir / "training_history.json"
with open(history_save_path, 'w') as f:
    # Every entry is cast to a built-in float so it can be written as JSON
    json.dump({
        'train_loss': [float(x) for x in history['train_loss']],
        'train_acc': [float(x) for x in history['train_acc']],
        'val_loss': [float(x) for x in history['val_loss']],
        'val_acc': [float(x) for x in history['val_acc']],
        'lr': [float(x) for x in history['lr']],
    }, f)

stats_save_path = save_dir / "training_stats.json"
with open(stats_save_path, 'w') as f:
    # Convert any tensor values to float
    # (non-tensor values are written unchanged)
    stats_dict = {}
    for k, v in training_stats.items():
        if isinstance(v, torch.Tensor):
            stats_dict[k] = v.item()
        else:
            stats_dict[k] = v
    json.dump(stats_dict, f)

print(f"Model saved to {model_save_path}")
print(f"Training history saved to {history_save_path}")
print(f"Training statistics saved to {stats_save_path}")

## Compare with Baseline Results

Let's compare our optimized model with the previous baseline custom model.

In [None]:
# Performance comparison
# The first two entries of each column are hard-coded figures for the earlier
# models; only the third entry comes from this notebook's run
comparison_data = {
    'Model': ['ResNet-18 (Transfer Learning)', 'Original Custom CNN', 'Optimized Custom CNN'],
    'Test Accuracy': ['100.00%', '80.67%', f'{test_accuracy:.2f}%'],
    'Training Time': ['11m 20s', '28m 47s', training_stats['training_time']],
    'Best Epoch': [8, 47, training_stats['best_epoch']],
    'Parameters': ['11.7 million', '8.6 million', f'{trainable_params/1e6:.1f} million']
}

comparison_df = pd.DataFrame(comparison_data)
print("Model Performance Comparison:")
# `display` is not imported in this file; presumably the built-in provided by
# IPython/Jupyter — it will not exist when run as a plain script
display(comparison_df)

# Create a summary of optimization changes
# Static text: it is not derived from the loaded config or from the run above
optimization_summary = """
## Key Optimizations Applied:

1. **Enhanced Data Augmentation**:
   - Increased augmentation strength to improve model generalization
   - Added random erasing and more advanced transformations

2. **Optimizer Improvements**:
   - Switched from Adam to AdamW for better weight decay handling
   - Increased weight decay for better regularization

3. **Learning Rate Scheduling**:
   - Implemented OneCycleLR policy for faster convergence
   - Used warmup period to stabilize early training

4. **Regularization Techniques**:
   - Applied gradient clipping to prevent exploding gradients
   - Implemented early stopping to prevent overfitting
   - Adjusted dropout rates for better feature learning

5. **Training Process Improvements**:
   - Increased number of epochs to allow for more learning
   - Maintained batch size for stable training
   - Added extensive monitoring of training metrics
"""

print(optimization_summary)