# Baseline ResNet18 for Musical Instrument Classification

This notebook demonstrates how to leverage our project's organized structure to train a ResNet18 model for musical instrument classification. The project structure offers:

- Modular code organization with reusable components
- Configuration-based model setup using YAML files
- Simplified training workflows
- Integrated GPU detection for acceleration
- Comprehensive evaluation metrics and visualizations

## Architecture Overview

Our model development strategy follows these key steps:

1. **Environment Setup**: Setting up our environment and project imports
2. **Configuration Loading**: Loading parameters from YAML configuration files
3. **Dataset Preparation**: Using our data utilities for consistent processing
4. **Model Creation**: Leveraging the baseline module for ResNet18 setup
5. **Training Execution**: Using the trainer module for efficient training
6. **Evaluation**: Assessing model performance with our metrics module
7. **Visualization**: Generating insightful visualizations of results

Let's begin by setting up our environment and importing the necessary modules from our project structure.

In [None]:
# Setup: Add project root to path to enable imports from src
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml
from pathlib import Path
import time
import copy
from tqdm.notebook import tqdm

current_dir = os.getcwd()
print(f"Current directory: {current_dir}")

# Add project root to path to ensure imports work correctly
project_root = os.path.join(current_dir, "MIC-MA1")
sys.path.insert(0, project_root)
print(f"Project root added to path: {project_root}")

# Import our project modules
from scripts.colab_integration import setup_colab_environment, check_gpu
from src.data.preprocessing import get_preprocessing_transforms
from src.data.augmentation import AdvancedAugmentation
from src.data.dataloader import load_datasets
from src.models.baseline import get_resnet18_model, unfreeze_layers
from src.training.trainer import train_model, evaluate_model
from src.training.metrics import compute_metrics, get_confusion_matrix
from src.visualization.plotting import plot_training_history, plot_confusion_matrix, plot_sample_predictions
from src.models.model_utils import save_model

# Check if we're running in Colab and set up the environment
import importlib.util
IN_COLAB = importlib.util.find_spec("google.colab") is not None

if IN_COLAB:
    print("🚀 Running in Google Colab - setting up environment...")
    setup_colab_environment()  # This handles all the Colab-specific setup
else:
    print("💻 Running locally - using local environment")

# Prefer a TPU when torch_xla is installed and a device can be acquired;
# otherwise fall back to the project's GPU detection utility
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    print("Using TPU:", device)
except (ImportError, RuntimeError):
    device = check_gpu()  # Project utility for GPU detection
    print("Using device:", device)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

## 1. Configuration Loading

With our new project structure, we can load configurations directly from YAML files, which provides better consistency and reproducibility across experiments.
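
As a rough guide to what the configuration cell expects, the sketch below writes the parsed YAML as the Python dictionary that `yaml.safe_load` would return and checks it for the keys this notebook reads without a fallback. The key names come from the lookups in this notebook; every value shown, along with the helper `missing_keys`, is an illustrative assumption and not the contents of the project's actual `baseline_resnet18.yaml`.

```python
# Illustrative stand-in for the parsed YAML; all values are placeholders.
example_config = {
    "model": {"name": "resnet18", "pretrained": True, "feature_extracting": True},
    "training": {
        "batch_size": 32,
        "num_epochs": 20,
        "optimizer": {"name": "adam", "learning_rate": 0.001},
    },
    "data": {"data_dir": "data/raw/30_Musical_Instruments", "img_size": 224, "num_workers": 2},
    "augmentation": {"augmentation_strength": None},
}

# Keys the notebook reads without a fallback value (a non-exhaustive list).
REQUIRED_KEYS = [
    "model.name",
    "training.batch_size",
    "training.num_epochs",
    "training.optimizer.name",
    "training.optimizer.learning_rate",
    "data.data_dir",
    "data.img_size",
    "data.num_workers",
]

def missing_keys(config, required=REQUIRED_KEYS):
    """Return the dotted keys from `required` that are absent from `config`."""
    missing = []
    for dotted in required:
        node = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing

print(missing_keys(example_config))  # an empty list means every required key is present
```

Running a check like this right after loading the file turns a late `KeyError` in a training cell into an early, readable message.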

In [None]:
# Load configuration from YAML file
config_path = os.path.join(project_root, "config", "baseline_resnet18.yaml")
with open(config_path, "r") as file:
    config = yaml.safe_load(file)

# Display the configuration for verification
print("Configuration loaded from:", config_path)
print("\nModel configuration:")
print(f"- Architecture: {config['model']['name']}")
print(f"- Pretrained: {config['model'].get('pretrained', True)}")
print(f"- Feature extracting: {config['model'].get('feature_extracting', True)}")
print(f"- Num classes: {config['model'].get('num_classes', 30)}")

print("\nTraining configuration:")
print(f"- Batch size: {config['training']['batch_size']}")
print(f"- Num epochs: {config['training']['num_epochs']}")
print(f"- Optimizer: {config['training']['optimizer']['name']}")
print(f"- Learning rate: {config['training']['optimizer']['learning_rate']}")

# Set the data directory - this will be the location of our dataset
if IN_COLAB:
    data_dir = "../data/raw/30_Musical_Instruments"
    if not os.path.exists(data_dir):
        print("Please upload the dataset to Google Drive or adjust the path")
else:
    # Use the path from config or default to the project's data directory
    data_dir = os.path.join(project_root, config['data']['data_dir'])
    
print(f"\nUsing data directory: {data_dir}")

## 2. Data Preprocessing and Loading

Using our project's data utilities for preprocessing and loading:

In [None]:
# Get preprocessing transforms with the appropriate image size from config
img_size = config['data']['img_size']

# Check if we should use data augmentation (tolerates a missing 'augmentation' section)
augmentation_strength = config.get('augmentation', {}).get('augmentation_strength')
if augmentation_strength:
    print(f"Using advanced augmentation with strength: {augmentation_strength}")
    transforms = AdvancedAugmentation.get_advanced_transforms(
        img_size=img_size,
        augmentation_strength=augmentation_strength
    )
else:
    print("Using standard preprocessing (no advanced augmentation)")
    transforms = get_preprocessing_transforms(img_size=img_size)

# Load datasets using our utility function
data = load_datasets(
    data_dir=data_dir,
    transforms=transforms,
    batch_size=config['training']['batch_size'],
    num_workers=config['data']['num_workers'],
    pin_memory=config['data'].get('pin_memory', torch.cuda.is_available())
)

# Access the components
train_loader = data['dataloaders']['train']
valid_loader = data['dataloaders']['val']
test_loader = data['dataloaders']['test']

# Get class information
class_names = list(data['class_mappings']['idx_to_class'].values())
num_classes = data['num_classes']

print(f"\nDataset loaded successfully:")
print(f"- Number of classes: {num_classes}")
print(f"- Training samples: {len(data['datasets']['train'])}")
print(f"- Validation samples: {len(data['datasets']['val'])}")
print(f"- Test samples: {len(data['datasets']['test'])}")

# Display a few class names
print(f"\nSample classes: {class_names[:5]}...")

### Visualize Sample Images

Let's visualize a few images using our project's visualization utilities:

In [None]:
# Use our dataset visualization function from the project structure
try:
    # Access datasets from our data dictionary
    train_dataset = data['datasets']['train']
    valid_dataset = data['datasets']['val']
    
    # Use plot_sample_images to visualize raw samples (no predictions involved)
    from src.visualization.plotting import plot_sample_images

    print("Sample training images:")
    plot_sample_images(
        dataset=train_dataset,
        class_mapping=data['class_mappings']['idx_to_class'],
        num_images=5,
        title="Sample Training Images"
    )

    print("Sample validation images:")
    plot_sample_images(
        dataset=valid_dataset,
        class_mapping=data['class_mappings']['idx_to_class'],
        num_images=5,
        title="Sample Validation Images"
    )
except Exception as e:
    print(f"Error visualizing images: {e}")

### Understanding the Data Transformations

The data transformations applied to our musical instrument images serve several critical purposes:

1. **Size Standardization** (`transforms.Resize((224, 224))`)
   - **Why needed**: Neural networks require consistent input dimensions. Images in our dataset may have different original sizes and aspect ratios.
   - **Why 224x224**: This is the standard input size for ResNet-18 and many other pre-trained CNN architectures. Since we're using transfer learning, matching the input size that the network was originally trained with is important.

2. **Data Augmentation** (applied only to training data)
   - **RandomHorizontalFlip**: Musical instruments generally maintain their identity when flipped horizontally
   - **RandomRotation(15)**: Adds robustness to slight orientation variations in the images
   - **ColorJitter**: Helps the model become invariant to lighting conditions and color variations

3. **Normalization** (`transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])`)
   - **Purpose**: Standardizes the pixel values to have similar ranges, which helps with training stability and convergence
   - **Values**: These specific values represent the mean and standard deviation of the RGB channels in the ImageNet dataset, which our pre-trained ResNet-18 was trained on

4. **Different Transforms for Training vs. Validation/Test**:
   - Training data receives data augmentation to artificially expand the dataset and improve generalization
   - Validation and test data only receive resizing and normalization to evaluate the model on clean, unmodified images

These transformations are essential for achieving good performance with transfer learning and help our model generalize better to unseen images of musical instruments.
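
To make the normalization step concrete, here is a small worked example of the arithmetic on single pixels. It assumes the usual torchvision ordering in which 8-bit values are first scaled to [0, 1] (as `ToTensor` does) before `Normalize` is applied; the helper `normalize_pixel` is hypothetical and only mirrors that arithmetic.

```python
# ImageNet channel statistics quoted in the text above (RGB order).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_uint8, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb_uint8, mean, std))

for name, pixel in [("black", (0, 0, 0)), ("mid-gray", (128, 128, 128)), ("white", (255, 255, 255))]:
    print(name, [round(c, 3) for c in normalize_pixel(pixel)])
```

After normalization the channel values are roughly centered on zero, which is the input range the pre-trained weights were fitted to.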

## 3. Model Creation

With our new project structure, we can use the baseline module to create our ResNet-18 model:

In [None]:
# Create the ResNet-18 model using our baseline module
model_config = config['model']
num_classes = data['num_classes']
pretrained = model_config.get('pretrained', True)
feature_extracting = model_config.get('feature_extracting', True)

try:
    # Create the model using our project's module
    model = get_resnet18_model(
        num_classes=num_classes,
        pretrained=pretrained,
        feature_extracting=feature_extracting
    )
    
    # Move model to the appropriate device
    model = model.to(device)
    
    # Display information about unfreezing layers if applicable
    if not feature_extracting and 'unfreeze_layers' in model_config:
        print(f"Unfreezing specified layers: {model_config['unfreeze_layers']}")
        model, unfrozen_params = unfreeze_layers(model, model_config['unfreeze_layers'])
        print(f"Number of parameters unfrozen: {len(unfrozen_params)}")
    
    # Calculate model statistics
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Print model architecture summary
    print(f"\nModel: ResNet-18")
    print(f"Pretrained: {pretrained}")
    print(f"Feature extracting (frozen backbone): {feature_extracting}")
    print(f"Number of classes: {num_classes}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Print layer structure to understand the architecture
    print("\nModel Architecture:")
    for name, child in model.named_children():
        print(f"Layer: {name}")
        if name == 'fc':
            print(f"  Output size: {child.out_features}")
    
    print(f"\nModel created successfully and moved to {device}")
except Exception as e:
    print(f"Error creating model: {e}")

### Why ResNet-18 is an Excellent Baseline Model Choice

ResNet-18 is a well-justified choice for our baseline model for several compelling reasons:

1. **Balanced Complexity**
   - With 11.7 million parameters, ResNet-18 offers substantial representational capacity while still being lightweight enough for academic projects
   - Trains efficiently even on limited GPU resources (like those provided for free by Google Colab)

2. **Strong Architecture Design**
   - Features residual connections that address the vanishing gradient problem
   - Deep enough to capture complex features in musical instruments (texture, shape, structure)
   - Proven performance on a wide range of image classification tasks

3. **Transfer Learning Benefits**
   - Pre-trained on over 1 million ImageNet images across 1,000 classes
   - Lower layers already capture universal visual features (edges, textures, patterns)
   - Requires less data to fine-tune for our specific task

4. **Practical Considerations**
   - Well-documented with extensive examples in PyTorch
   - Has excellent convergence properties during training
   - Robust to hyperparameter choices, making it more forgiving for initial experiments

5. **Benchmark Status**
   - Continues to be a standard benchmark in computer vision literature
   - Provides a meaningful baseline against which to compare custom architectures

## 4. Training Configuration and Execution

With our new project structure, we can use the training utilities from our src module.
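
The configuration cell can attach a `ReduceLROnPlateau` scheduler with a default `patience` of 3 and `factor` of 0.1. The following is a simplified pure-Python model of how that scheduler reacts to a stalled validation loss; it ignores cooldown and minimum learning rate, and the loss values are made up for illustration.

```python
def simulate_reduce_on_plateau(val_losses, lr, factor=0.1, patience=3, threshold=1e-4):
    """Return the learning rate in effect after each epoch's validation loss.

    Simplified model of ReduceLROnPlateau in 'min' mode with a relative
    threshold; cooldown and min_lr are ignored.
    """
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # The rate drops only once the stall has lasted longer than `patience` epochs.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs

# Hypothetical validation losses that stop improving after the second epoch.
losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]
print(simulate_reduce_on_plateau(losses, lr=1e-3))
```

With `patience=3`, the reduction happens on the fourth consecutive epoch without improvement, not the third.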

In [None]:
# Set up optimizer and criterion based on config
optimizer_config = config['training']['optimizer']
optimizer_name = optimizer_config.get('name', 'adam').lower()
learning_rate = optimizer_config.get('learning_rate', 0.001)
weight_decay = optimizer_config.get('weight_decay', 0.0)

# Configure optimizer
if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
elif optimizer_name == 'sgd':
    momentum = optimizer_config.get('momentum', 0.9)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
else:
    raise ValueError(f"Unsupported optimizer: {optimizer_name}")

# Configure loss function
criterion = torch.nn.CrossEntropyLoss()

# Configure scheduler if specified
scheduler_config = config['training'].get('scheduler', {})
scheduler = None
if scheduler_config:
    scheduler_name = scheduler_config.get('name', '').lower()
    patience = scheduler_config.get('patience', 3)
    factor = scheduler_config.get('factor', 0.1)
    
    if scheduler_name == 'reducelronplateau':
        # `verbose` is deprecated in recent PyTorch releases, so it is omitted here
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=factor, patience=patience
        )

# Set up training parameters
num_epochs = config['training']['num_epochs']
print(f"\nTraining Configuration:")
print(f"- Optimizer: {optimizer_name}")
print(f"- Learning rate: {learning_rate}")
print(f"- Weight decay: {weight_decay}")
print(f"- Number of epochs: {num_epochs}")
if scheduler:
    print(f"- Scheduler: {scheduler_name}")
    print(f"  - Patience: {patience}")
    print(f"  - Factor: {factor}")

### Start Training

Now we'll use our project's training utilities to train the model:

In [None]:
# Prepare dataloaders dictionary
dataloaders = {
    'train': train_loader,
    'val': valid_loader
}

# Train the model using our training module
model, history, training_stats = train_model(
    model=model,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    device=device
)

# Display training results
print(f"\nTraining Results:")
print(f"- Best validation accuracy: {training_stats['best_val_acc']:.4f}")
print(f"- Best epoch: {training_stats['best_epoch']}")
print(f"- Training time: {training_stats['training_time']}")

### Visualize Training History

Let's visualize the training progress using our visualization utilities:

In [None]:
# Plot training history
plot_training_history(history)

## Initial Training Results Analysis

### Training Performance Summary

The initial training of our ResNet-18 model using transfer learning showed excellent results. Here's a comprehensive breakdown of the performance:

#### Basic Training Metrics

| Metric                | Value                      | Notes                                           |
|-----------------------|----------------------------|------------------------------------------------|
| **Training Duration** | 25m 54s                    | With GPU acceleration                           |
| **Best Validation Accuracy** | **97.33%**          | Achieved at Epoch 8                             |
| **Final Training Accuracy** | 95.45%               | After 20 epochs                                 |
| **Trainable Parameters** | 15,390                  | Only classifier layer trained initially         |
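
The trainable-parameter count in the table can be checked by hand. This short sketch assumes the standard ResNet-18 head, a single fully connected layer with 512 input features, replaced here by one with 30 outputs; `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of learnable values in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)

# ResNet-18's final layer takes 512 features; this dataset has 30 classes.
head_params = linear_param_count(in_features=512, out_features=30)
print(f"{head_params:,}")  # 15,390
```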

#### Epoch-by-Epoch Performance

The training progressed with significant improvements in the early epochs:

- **Rapid initial convergence**: Training accuracy jumped from 53.16% to 83.35% between epochs 1 and 2
- **Key breakthrough point**: Epoch 4 showed validation accuracy of 96.67%
- **Stabilization**: After epoch 8, validation accuracy remained between 95.33% and 97.33%

#### Learning Dynamics

| Epoch | Training Loss | Training Acc | Validation Loss | Validation Acc | Notes                 |
|-------|---------------|--------------|-----------------|----------------|------------------------|
| 1     | 2.1005        | 53.16%       | 0.7260          | 88.67%         | Initial adaptation    |
| 2     | 0.9131        | 83.35%       | 0.3618          | 93.33%         | Large improvement     |
| 3     | 0.6229        | 88.32%       | 0.2478          | 94.67%         | Continued improvement |
| 4     | 0.4925        | 90.36%       | 0.1881          | 96.67%         | Strong performance    |
| 8     | 0.3128        | 93.16%       | 0.1122          | **97.33%**     | **Best model**        |
| 20    | 0.1704        | 95.45%       | 0.0958          | 95.33%         | Final model          |

#### Observations

1. **Validation outperforming training**: Validation accuracy exceeded training accuracy in the early epochs (88.67% vs. 53.16% at epoch 1). This is typical when augmentation is applied only to the training data, and it suggests the model was not overfitting at that stage.

2. **Loss plateau**: After epoch 12, the training loss decreased more gradually (from 0.2262 to 0.1704), indicating we were approaching the limits of what could be achieved by training only the classifier layer.

3. **Validation fluctuations**: Small fluctuations in validation accuracy in later epochs (between 95.33% and 97.33%) may indicate that we've reached a performance plateau with the current architecture and training approach.

These results demonstrate that transfer learning using a pre-trained ResNet-18 model is highly effective for our musical instrument classification task. The next step is to fine-tune deeper layers to potentially improve performance further.
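
The size of the validation fluctuations is easier to judge in terms of images than percentages. The sketch below searches for the smallest validation-set size at which every reported accuracy is a whole number of correct images. This is an inference from rounded figures, not a number stated in this notebook, so treat the result as an assumption.

```python
# Validation accuracies (%) reported in the Learning Dynamics table above.
reported = [88.67, 93.33, 94.67, 96.67, 97.33, 95.33]

def is_consistent(n, accuracies, tol=0.006):
    """True if every accuracy equals k/n (as a percentage) for some integer k, up to rounding."""
    return all(abs(100 * round(a * n / 100) / n - a) < tol for a in accuracies)

smallest = next(n for n in range(1, 1001) if is_consistent(n, reported))
step = 100 / smallest
print(f"smallest consistent validation-set size: {smallest}")
print(f"one image corresponds to {step:.2f} percentage points")
print(f"95.33% -> 97.33% is a difference of {round((97.33 - 95.33) / step)} images")
```

If the validation set really is this small, the late-epoch fluctuations amount to a handful of images and should not be over-interpreted.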

## Fine-tuning Phase

After our successful initial training phase where we only trained the classifier layer, we'll now fine-tune the deeper layers of the model to potentially improve performance further. We'll focus on unfreezing layer4 of ResNet-18, as it contains the high-level feature extractors most relevant to our task.
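
The project's `unfreeze_layers` helper is not shown in this notebook. One common way to implement such a helper is to select parameters by name prefix, sketched below on a short list of names in torchvision's ResNet-18 naming scheme. The function `trainable_flags` is hypothetical and only illustrates the selection rule.

```python
def trainable_flags(param_names, unfrozen_prefixes):
    """Map each parameter name to True if it should receive gradients."""
    prefixes = tuple(unfrozen_prefixes)
    return {name: name.startswith(prefixes) for name in param_names}

# A few parameter names in torchvision's ResNet-18 naming scheme (not the full list).
names = [
    "conv1.weight",
    "layer3.1.conv2.weight",
    "layer4.0.conv1.weight",
    "layer4.0.downsample.0.weight",
    "layer4.1.bn2.bias",
    "fc.weight",
    "fc.bias",
]

flags = trainable_flags(names, unfrozen_prefixes=["layer4.", "fc."])
for name, trainable in flags.items():
    print(f"{name:32s} requires_grad={trainable}")

# With a real model the same rule would be applied as:
#   for name, param in model.named_parameters():
#       param.requires_grad = name.startswith(("layer4.", "fc."))
```

Keeping the trailing dot in each prefix avoids accidentally matching a differently named module that merely starts with the same characters.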

In [None]:
# Unfreeze layer4 for fine-tuning
print("Starting fine-tuning phase...")
model, unfrozen_params = unfreeze_layers(model, ['layer4'])
print(f"Number of parameters unfrozen in layer4: {len(unfrozen_params):,}")

# Configure fine-tuning hyperparameters
fine_tune_config = {
    'learning_rate': 1e-4,  # Lower learning rate for fine-tuning
    'num_epochs': 10,
    'weight_decay': 1e-4    # Add regularization to prevent overfitting
}

# Create new optimizer for fine-tuning
fine_tune_optimizer = torch.optim.Adam(
    model.parameters(),
    lr=fine_tune_config['learning_rate'],
    weight_decay=fine_tune_config['weight_decay']
)

# Configure learning rate scheduler for fine-tuning
# (`verbose` is deprecated in recent PyTorch releases, so it is omitted here)
fine_tune_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    fine_tune_optimizer,
    mode='min',
    factor=0.1,
    patience=3
)

print("\nFine-tuning Configuration:")
print(f"- Learning rate: {fine_tune_config['learning_rate']}")
print(f"- Number of epochs: {fine_tune_config['num_epochs']}")
print(f"- Weight decay: {fine_tune_config['weight_decay']}")
print(f"- Scheduler: ReduceLROnPlateau")

# Start fine-tuning
print("\nStarting fine-tuning training...")
model_ft, history_ft, training_stats_ft = train_model(
    model=model,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=fine_tune_optimizer,
    scheduler=fine_tune_scheduler,
    num_epochs=fine_tune_config['num_epochs'],
    device=device
)

print("\nFine-tuning Results:")
print(f"- Best validation accuracy: {training_stats_ft['best_val_acc']:.4f}")
print(f"- Best epoch: {training_stats_ft['best_epoch']}")
print(f"- Training time: {training_stats_ft['training_time']}")

In [None]:
# Plot fine-tuning history
plot_training_history(history_ft)

## Fine-tuning Results Analysis

### Performance Summary

The fine-tuning phase of our ResNet-18 model shows remarkable improvements over the initial training phase, achieving perfect validation accuracy. Here's a detailed breakdown:

#### Key Training Metrics

| Metric | Value | Notes |
|--------|-------|-------|
| **Training Duration** | 5m 22s | Significantly faster than initial training |
| **Best Validation Accuracy** | **100%** | Perfect accuracy achieved at Epoch 3 |
| **Final Training Accuracy** | 99.96% | Nearly perfect after 10 epochs |
| **Trainable Parameters** | 8,409,118 | ~546× as many as in initial training (15,390) |

#### Fine-tuning Approach

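We unfroze the `layer4` stage of ResNet-18, which consists of two residual blocks containing:
- Convolutional layers
- Batch normalization layers
- A downsample (1×1 convolution) shortcut in the first block

The final fully connected layer is not part of `layer4`; it was already trainable from the initial phase and remains so during fine-tuning.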

This strategic approach allowed the model to adapt its feature extractors specifically for musical instrument recognition.
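
The trainable-parameter figure reported for this phase can be reproduced from the standard ResNet-18 architecture. The sketch below assumes torchvision's layout for `layer4` (two basic blocks, bias-free convolutions, batch norm with a learnable scale and shift) plus the 30-class head; the helper names are hypothetical.

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights of a bias-free 2D convolution."""
    return in_ch * out_ch * kernel * kernel

def bn_params(channels):
    """Learnable scale and shift of a batch-norm layer."""
    return 2 * channels

# First residual block of layer4: 256 -> 512 channels, with a 1x1 downsample shortcut.
block0 = (
    conv_params(256, 512, 3) + bn_params(512)
    + conv_params(512, 512, 3) + bn_params(512)
    + conv_params(256, 512, 1) + bn_params(512)
)
# Second residual block: 512 -> 512 channels, identity shortcut.
block1 = 2 * (conv_params(512, 512, 3) + bn_params(512))

layer4 = block0 + block1
fc = 512 * 30 + 30  # classifier head for 30 classes
total = layer4 + fc

print(f"layer4: {layer4:,}")
print(f"layer4 + fc: {total:,}")
print(f"ratio to head-only training: {total / fc:.1f}x")
```

The sum of `layer4` and the classifier head matches the 8,409,118 trainable parameters listed in the metrics table.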

#### Epoch-by-Epoch Analysis

| Epoch | Training Loss | Training Acc | Validation Loss | Validation Acc | Notes |
|-------|---------------|--------------|-----------------|----------------|-------|
| 1     | 0.2136        | 94.16%       | 0.0529          | 98.67%         | Strong start from pre-trained state |
| 2     | 0.0777        | 98.21%       | 0.0561          | 98.67%         | Major improvement in training accuracy |
| 3     | 0.0493        | 99.00%       | 0.0288          | **100%**       | **Perfect validation accuracy achieved** |
| 10    | 0.0065        | 99.96%       | 0.0083          | 100%           | Continued refinement of training performance |

#### Key Observations

1. **Rapid Convergence**: The model achieved 100% validation accuracy in just 3 epochs, demonstrating the effectiveness of our staged training approach (first the classifier, then fine-tuning deeper layers).

2. **Dramatic Loss Reduction**: The training loss decreased from 0.2136 to 0.0065 (97% reduction), showing that the model's internal representations became highly optimized for the task.

3. **Near-Perfect Training Accuracy**: By the final epoch, the model correctly classified 99.96% of training samples, indicating excellent fit without apparent overfitting (given the perfect validation accuracy).

4. **Validation Loss Trend**: The validation loss fell from 0.0529 to 0.0083 over the course of fine-tuning, with a slight uptick at epoch 2 (0.0561), suggesting the model was becoming more confident in its predictions.

5. **Training Efficiency**: Despite training over 8.4 million parameters (`layer4` plus the classifier), the model trained in just over 5 minutes with GPU acceleration.

### Conclusions

The fine-tuning phase has dramatically improved model performance, resulting in a classifier that achieves perfect validation accuracy on our musical instrument dataset. This demonstrates:

1. The effectiveness of transfer learning with a pre-trained ResNet-18 model
2. The importance of a staged approach (training classifier first, then fine-tuning deeper layers)
3. The power of GPU acceleration for deep learning tasks

These results suggest that our model is now ready for thorough evaluation on the test set to verify its generalization capabilities.
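
The test evaluation plots a normalized confusion matrix. The internals of the project's `get_confusion_matrix` and `plot_confusion_matrix` are not shown here, so the following is only a minimal sketch of the usual convention, assuming rows are true classes and normalization is per row. The labels are made up for illustration.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Rows index the true class, columns the predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def normalize_rows(matrix):
    """Divide each row by its total so entries become per-class rates."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized

# Made-up labels for three classes, purely for illustration.
y_true = [0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 0, 2]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
for row in normalize_rows(cm):
    print([round(v, 2) for v in row])
```

With row normalization, each diagonal entry is the recall for that class, so a perfect classifier produces an identity matrix.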

## 5. Test Set Evaluation

We now evaluate the fine-tuned model on the held-out test set:

In [None]:
# Evaluate the fine-tuned model (returned by train_model as model_ft) on test data
test_metrics = evaluate_model(
    model=model_ft,
    dataloader=test_loader,
    criterion=criterion,
    device=device
)

# Display test results
print(f"Test Results:")
print(f"- Loss: {test_metrics['loss']:.4f}")
print(f"- Accuracy: {test_metrics['accuracy']:.4f}")

# Get detailed metrics
class_names = list(data['class_mappings']['idx_to_class'].values())
y_true, y_pred, _ = compute_metrics(model_ft, test_loader, device, return_predictions=True)

# Plot confusion matrix
cm = get_confusion_matrix(y_true, y_pred)
plot_confusion_matrix(
    cm, 
    class_names=class_names, 
    normalize=True,
    title="Normalized Confusion Matrix (ResNet-18)"
)

## Test Set Evaluation Results

### Perfect Classification Performance

Our fine-tuned ResNet-18 model has achieved exceptional results on the unseen test data, demonstrating outstanding generalization capability:

| Metric | Value | Notes |
|--------|-------|-------|
| **Test Accuracy** | **100.00%** | Perfect classification of all test samples |
| **Evaluation Time** | 35 seconds | For the complete test dataset |

### Analysis of Results

The model has flawlessly classified all test images across all 30 musical instrument categories. This remarkable performance validates several aspects of our approach:

#### 1. Transfer Learning Effectiveness

The perfect test accuracy confirms that our transfer learning approach with ResNet-18 was highly effective. By leveraging pre-trained weights and carefully fine-tuning deeper layers, we created a model that:
- Captured essential features distinguishing different musical instruments
- Generalized extremely well to unseen examples
- Avoided overfitting despite reaching 100% validation accuracy

#### 2. Training Strategy Validation

Our two-phase training strategy proved extremely effective:
1. **Initial phase**: Training only the classifier layer to adapt the model to our dataset
2. **Fine-tuning phase**: Carefully unfreezing deeper convolutional layers (layer4) to optimize feature extraction

This approach allowed the model to first adapt its classification head to our specific classes before fine-tuning the feature extractors, which is consistent with the strong results observed in both phases.

#### 3. Dataset Considerations

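The perfect test accuracy is also consistent with:
- High quality and consistency of the dataset
- Well-balanced class distribution
- Good separation between musical instrument categories
- Effective data preprocessing and augmentation strategies

These are plausible explanations, not conclusions: accuracy alone cannot confirm class balance, and a perfect score on a finite test set does not imply a zero error rate on new images.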
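
How much a perfect test score says about true accuracy depends on the number of test images. The sketch below computes a one-sided exact lower bound for the case where all `n` samples are classified correctly. The test-set sizes used are hypothetical, since the actual size is only printed at runtime by the data loading cell.

```python
def accuracy_lower_bound(n_samples, confidence=0.95):
    """One-sided exact lower bound on true accuracy after n_samples correct out of n_samples.

    If the true accuracy is p, the chance of a perfect run is p ** n_samples;
    the bound is the p at which that chance drops to 1 - confidence.
    """
    alpha = 1.0 - confidence
    return alpha ** (1.0 / n_samples)

# Test-set sizes here are hypothetical; substitute len(data['datasets']['test']).
for n in (150, 500, 1000):
    print(f"n={n:5d}: true accuracy >= {accuracy_lower_bound(n):.4f} with 95% confidence")
```

For a test set of a few hundred images, a perfect score still leaves room for a true error rate of around one to two percent.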

## 6. Save the Trained Model

Our project structure includes utilities for saving models:

In [None]:
# Create a timestamp for the saved model
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Define the save path
save_dir = os.path.join(project_root, "experiments", f"resnet18_{timestamp}")
os.makedirs(save_dir, exist_ok=True)

# Save the fine-tuned model together with the fine-tuning history and metrics
# (the initial-phase history remains available in `history`)
model_path = save_model(
    model=model_ft,
    save_dir=save_dir,
    model_name="resnet18_transfer_learning",
    training_history=history_ft,
    metrics={
        "accuracy": test_metrics['accuracy'],
        "loss": test_metrics['loss'],
        "best_val_accuracy": training_stats_ft['best_val_acc']
    },
    class_mapping=data['class_mappings']['idx_to_class']
)

print(f"Model saved to {model_path}")