# Pretrained Models Comparison for Music Classification

This notebook is an interactive walkthrough for comparing a CNN baseline trained from scratch against pretrained models fine-tuned for music genre classification.

## Models Compared
- **CNN Baseline**: Custom 4-layer CNN trained from scratch
- **ResNet-18**: Pretrained ResNet-18 with conservative 3-stage transfer learning
- **Vision Transformer**: Pretrained ViT-Base with 2-stage fine-tuning

## 1. Setup and Configuration

In [None]:
import warnings

# Silence library warnings to keep cell output readable; comment this out when debugging
warnings.filterwarnings('ignore')

# Import modular components
from models import MusicGenreCNN, OptimizedResNet18MusicClassifier, ViTMusicClassifier, is_transformers_available
from data import prepare_cv_data, print_data_summary
from training import get_cnn_config, get_resnet_config, get_vit_config, get_quick_test_config
from utils import set_seed, get_device_info

import torch

# Set random seeds
set_seed(42)

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_info = get_device_info()
print(f"Using device: {device_info['device_name']}")
print(f"CUDA available: {device_info['cuda_available']}")
print(f"Transformers available: {is_transformers_available()}")
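
`set_seed` comes from the project's `utils` module, whose implementation is not shown in this notebook. As a rough sketch of what such a helper typically does, assuming it seeds Python, NumPy, and PyTorch (the name `seed_everything` is hypothetical, and NumPy and PyTorch are seeded only if they are installed):

```python
import random


def seed_everything(seed: int) -> None:
    """Seed the common RNGs (hypothetical stand-in for utils.set_seed)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


# Re-seeding with the same value reproduces the same random draws
seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
print(first == second)
```

Seeding alone does not guarantee bit-identical GPU runs; it only fixes the random number streams.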

## 2. Data Preparation

In [None]:
# Prepare cross-validation data
print("Preparing data for cross-validation...")

data_dir = "Data/images_original/"
n_folds = 3  # Use 3 folds for quick testing

all_images, all_labels, cv_splits, genre_to_idx, idx_to_genre = prepare_cv_data(
    data_dir, n_folds=n_folds
)

print_data_summary(all_images, all_labels, cv_splits, idx_to_genre)
print(f"\nNumber of classes: {len(genre_to_idx)}")
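
How `prepare_cv_data` builds `cv_splits` is not shown here. One common approach, sketched below under the assumption that the folds are stratified by genre, deals each class's shuffled sample indices round-robin into the folds. The function `stratified_folds` and the toy labels are illustrative only.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_folds, seed=42):
    """Assign each sample index to one of n_folds folds, balancing classes across folds."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    folds = [[] for _ in range(n_folds)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal the shuffled indices of this class round-robin into the folds
        for position, idx in enumerate(indices):
            folds[position % n_folds].append(idx)
    return folds


# Toy labels: 3 genres with 6 samples each (placeholder data, not the real dataset)
toy_labels = [genre for genre in ("blues", "jazz", "rock") for _ in range(6)]
folds = stratified_folds(toy_labels, n_folds=3)

for k, held_out in enumerate(folds):
    train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
    print(f"fold {k}: {len(train_idx)} train / {len(held_out)} held out")
```

Each fold serves once as the held-out set while the remaining folds form the training set.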

## 3. Model Configuration Comparison

In [None]:
# Compare the quick-test configuration of each model
print("Model Configurations:")
print("=" * 60)

configs = {
    'CNN Baseline': get_quick_test_config('cnn'),
    'ResNet-18': get_quick_test_config('resnet'),
}

if is_transformers_available():
    configs['ViT'] = get_quick_test_config('vit')

for name, config in configs.items():
    # Match the output layer to the number of genres found in the dataset
    config.num_classes = len(genre_to_idx)
    print(f"\n{name}:")
    print(f"  Input size: {config.input_size}")
    print(f"  Batch size: {config.batch_size}")
    print(f"  Learning rate: {config.learning_rate}")
    print(f"  Epochs: {config.num_epochs}")
    print(f"  Model type: {config.model_type}")
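
The configuration objects come from the `training` module, which is not shown here. A minimal sketch of how a quick-test variant could be derived from a full configuration, assuming a dataclass with the fields printed above. The class `ExperimentConfig`, the helper `quick_test_variant`, and all default values are hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass
class ExperimentConfig:
    """Hypothetical config with the same fields the cell above prints."""
    model_type: str
    input_size: int = 224
    batch_size: int = 32
    learning_rate: float = 1e-3
    num_epochs: int = 50
    num_classes: int = 10


def quick_test_variant(config: ExperimentConfig) -> ExperimentConfig:
    """Return a copy with a shortened schedule, leaving the original untouched."""
    return replace(config, num_epochs=2, batch_size=min(config.batch_size, 16))


full = ExperimentConfig(model_type="resnet", learning_rate=1e-4)
quick = quick_test_variant(full)
print(full)
print(quick)
```

Deriving the quick variant with `dataclasses.replace` keeps the two configurations in sync on every field that is not explicitly overridden.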

## 4. Model Architecture Preview

In [None]:
# Create sample models to inspect architectures
print("Model Architecture Comparison:")
print("=" * 60)

num_classes = len(genre_to_idx)

# CNN Baseline
cnn_model = MusicGenreCNN(num_classes=num_classes)
cnn_info = cnn_model.get_model_info()
print("\nCNN Baseline:")
print(f"  Parameters: {cnn_info['num_parameters']:,}")
print(f"  Architecture: {cnn_info['architecture']}")

# ResNet-18
resnet_model = OptimizedResNet18MusicClassifier(num_classes=num_classes)
resnet_info = resnet_model.get_model_info()
print("\nResNet-18:")
print(f"  Parameters: {resnet_info['num_parameters']:,}")
print(f"  Trainable: {resnet_info['trainable_parameters']:,}")
print(f"  Architecture: {resnet_info['architecture']}")

# ViT (if available)
if is_transformers_available():
    try:
        vit_model = ViTMusicClassifier(num_classes=num_classes)
        vit_info = vit_model.get_model_info()
        print("\nVision Transformer:")
        print(f"  Parameters: {vit_info['num_parameters']:,}")
        print(f"  Trainable: {vit_info['trainable_parameters']:,}")
        print(f"  Architecture: {vit_info['architecture']}")
    except Exception as e:
        print(f"\nViT model creation failed: {e}")

# Clean up
del cnn_model, resnet_model
if 'vit_model' in locals():
    del vit_model
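
The gap between total and trainable parameters reported above reflects layers that are frozen during transfer learning. The sketch below shows one way such counts can be computed in PyTorch. `count_parameters` is a hypothetical helper, not the project's `get_model_info`, and the two-layer model is a tiny stand-in for a pretrained backbone with a new head.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> dict:
    """Return total and trainable parameter counts for a PyTorch module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"num_parameters": total, "trainable_parameters": trainable}


# Tiny stand-in for a pretrained backbone plus a new classification head
backbone = nn.Linear(8, 4)   # 8*4 weights + 4 biases = 36 parameters
head = nn.Linear(4, 2)       # 4*2 weights + 2 biases = 10 parameters
model = nn.Sequential(backbone, nn.ReLU(), head)

# Freeze the backbone, as an early transfer-learning stage might
for param in backbone.parameters():
    param.requires_grad = False

counts = count_parameters(model)
print(counts)
```

Unfreezing more layers in later stages raises the trainable count while the total stays fixed.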

## 5. Run Quick Experiment

In [None]:
# This cell does not start the experiment by itself. Launch it from the
# command line, or uncomment the two lines at the bottom of the cell.
print("Quick model comparison experiment")
print("Running it may take several minutes.")
print("\nFor full experiment, run: python main.py")
print("For quick test, run: python main.py --quick")

# Uncomment the following lines to run the quick experiment in the notebook
# from main import run_model_comparison
# results = run_model_comparison('quick')

## 6. Analysis and Visualization

After running the experiment, you can load and analyze the results:

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load results (if experiment has been run)
try:
    results_df = pd.read_csv('pretrained_comparison_quick_results.csv')
    
    print("Experiment Results:")
    print("=" * 40)
    
    # Summary statistics
    summary = results_df.groupby('Model')['Test_Accuracy'].agg(['mean', 'std']).round(2)
    print(summary)
    
    # Create visualization
    plt.figure(figsize=(10, 6))
    models = results_df['Model'].unique()
    
    for i, model in enumerate(models):
        model_data = results_df[results_df['Model'] == model]
        plt.scatter([i] * len(model_data), model_data['Test_Accuracy'], alpha=0.7, s=100)
        plt.errorbar(i, model_data['Test_Accuracy'].mean(), 
                    yerr=model_data['Test_Accuracy'].std(), 
                    fmt='o', color='red', capsize=5, markersize=8)
    
    plt.xticks(range(len(models)), models)
    plt.ylabel('Test Accuracy (%)')
    plt.title('Model Comparison Results')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
except FileNotFoundError:
    print("Results file not found. Please run the experiment first.")
    print("Run: python main.py --quick")
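
Because every model is evaluated on the same folds, a paired comparison of per-fold accuracies can be more informative than comparing means alone. The sketch below uses only the standard library. The accuracy values are invented placeholders, not experiment results, and applying this to the real results assumes they can be matched by fold.

```python
import math
import statistics

# Placeholder per-fold accuracies (%), invented for illustration; NOT experiment results
fold_accuracy = {
    "CNN Baseline": [60.0, 62.0, 61.0],
    "ResNet-18": [70.0, 71.0, 75.0],
}


def paired_t_statistic(a, b):
    """Return the mean per-fold difference (b - a) and the paired t statistic."""
    diffs = [y - x for x, y in zip(a, b)]
    mean_diff = statistics.mean(diffs)
    std_err = statistics.stdev(diffs) / math.sqrt(len(diffs))
    return mean_diff, mean_diff / std_err


n_folds = len(fold_accuracy["CNN Baseline"])
mean_diff, t_stat = paired_t_statistic(
    fold_accuracy["CNN Baseline"], fold_accuracy["ResNet-18"]
)
print(f"mean difference: {mean_diff:.2f} points, t = {t_stat:.2f} (df = {n_folds - 1})")
```

With only 3 folds the test has 2 degrees of freedom, so the statistic is best read as a rough indication rather than firm evidence.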

## 7. Next Steps

To run the experiments from the command line:

1. **Quick test**: `python main.py --quick`
2. **Full experiment**: `python main.py`
3. **Analyze results**: Load the generated CSV files for detailed analysis

The modular structure allows you to:
- Easily modify training configurations
- Add new model architectures
- Customize training strategies
- Extend analysis and visualization
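
One way to make adding architectures cheap is a small registry that maps a model name to a factory. This is a sketch of that pattern, not the project's actual mechanism; `MODEL_REGISTRY`, `register_model`, `build_model`, and `DummyClassifier` are all hypothetical names.

```python
from typing import Callable, Dict

# Hypothetical registry mapping a model name to a factory
MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_model(name: str):
    """Decorator that records a model factory under the given name."""
    def decorator(factory):
        if name in MODEL_REGISTRY:
            raise ValueError(f"Model '{name}' is already registered")
        MODEL_REGISTRY[name] = factory
        return factory
    return decorator


def build_model(name: str, **kwargs):
    """Instantiate a registered model by name."""
    try:
        factory = MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Unknown model '{name}'; available: {sorted(MODEL_REGISTRY)}"
        ) from None
    return factory(**kwargs)


@register_model("dummy")
class DummyClassifier:
    """Placeholder standing in for a real architecture."""
    def __init__(self, num_classes: int):
        self.num_classes = num_classes


model = build_model("dummy", num_classes=10)
print(type(model).__name__, model.num_classes)
```

With this pattern, a new architecture only needs the decorator to become selectable by name from a configuration.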

In [None]:
print("✅ Notebook setup complete!")
print("\nModular structure:")
print("  📁 models/ - Model architectures (CNN, ResNet, ViT)")
print("  📁 data/ - Data processing and loading")
print("  📁 training/ - Training configurations and strategies")
print("  📁 analysis/ - Evaluation and statistical analysis")
print("  📁 utils/ - Common utilities")
print("\nReady for experimentation! 🚀")