# YAMNet Fine-tuning for Indian Classical Raga Classification

This comprehensive notebook demonstrates how to fine-tune Google's YAMNet model for Indian classical raga classification using TensorFlow Hub and HuggingFace integration.

## Overview

YAMNet is a pre-trained deep neural network that can classify audio into 521 different categories. We'll leverage its powerful audio feature extraction capabilities and add a custom classification head for raga identification.

## What You'll Learn

1. 🎵 Audio preprocessing for Indian classical music
2. 🤖 Transfer learning with pre-trained audio models
3. 📊 Advanced training techniques and data augmentation
4. 📈 Model evaluation and performance visualization
5. 🚀 Model deployment to HuggingFace Hub

## Prerequisites

- Google Colab or local environment with GPU support
- Google Drive with raga dataset
- HuggingFace account for model deployment

## 1. Environment Setup and GPU Configuration

Let's start by configuring TensorFlow for optimal GPU usage and checking available hardware resources.

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')

# Configure TensorFlow GPU settings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

# Enable GPU memory growth to avoid allocation errors
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"✅ Found {len(gpus)} GPU(s) - Memory growth enabled")
        
        # Get GPU details
        for i, gpu in enumerate(gpus):
            details = tf.config.experimental.get_device_details(gpu)
            print(f"   GPU {i}: {details.get('device_name', 'Unknown')}")
    except RuntimeError as e:
        print(f"❌ GPU configuration error: {e}")
else:
    print("⚠️  No GPU found - using CPU")

# Check TensorFlow version and backend
print(f"\n📊 TensorFlow version: {tf.__version__}")
print(f"🔧 Built with CUDA: {tf.test.is_built_with_cuda()}")
print(f"🚀 GPU Available: {tf.test.is_gpu_available()}")

# Set random seeds for reproducibility
tf.random.set_seed(42)
import numpy as np
np.random.seed(42)

## 2. Library Installation and Imports

Installing required packages and importing necessary modules for audio processing and model training.

In [None]:
# Install required packages
!pip install tensorflow-hub librosa soundfile huggingface_hub datasets transformers
!pip install scikit-learn matplotlib seaborn plotly tqdm ipywidgets
!pip install audiomentations resampy pydub

In [None]:
# Core libraries
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Audio processing
import librosa
import soundfile as sf
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift

# Machine Learning
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.utils.class_weight import compute_class_weight

# Utilities
import os
import glob
import json
import pickle
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
import logging
from typing import Tuple, List, Dict, Optional

# HuggingFace
from huggingface_hub import HfApi, create_repo, upload_file
from transformers import AutoConfig

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure matplotlib
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✅ All libraries imported successfully!")

## 3. Google Drive Integration and Dataset Loading

Mounting Google Drive and loading the raga dataset from your specified folder.

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Configure dataset paths
DATASET_PATH = "/content/drive/MyDrive/Raga_Dataset"  # Update this path
MODELS_PATH = "/content/drive/MyDrive/YAMNet_Models"
OUTPUT_PATH = "/content/drive/MyDrive/YAMNet_Output"

# Create output directories
os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"📁 Dataset path: {DATASET_PATH}")
print(f"💾 Models path: {MODELS_PATH}")
print(f"📊 Output path: {OUTPUT_PATH}")

# Check if dataset exists
if os.path.exists(DATASET_PATH):
    print(f"✅ Dataset found at {DATASET_PATH}")
    
    # List available raga folders
    raga_folders = [f for f in os.listdir(DATASET_PATH) 
                   if os.path.isdir(os.path.join(DATASET_PATH, f))]
    print(f"🎵 Found {len(raga_folders)} raga classes: {raga_folders}")
else:
    print(f"❌ Dataset not found at {DATASET_PATH}")
    print("Please update the DATASET_PATH variable with your correct path")

In [None]:
def discover_audio_files(dataset_path: str) -> pd.DataFrame:
    """
    Discover all audio files in the dataset and create a metadata DataFrame.
    """
    audio_files = []
    supported_formats = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    
    print("🔍 Discovering audio files...")
    
    for raga_folder in tqdm(os.listdir(dataset_path)):
        raga_path = os.path.join(dataset_path, raga_folder)
        
        if not os.path.isdir(raga_path):
            continue
            
        for audio_file in os.listdir(raga_path):
            file_path = os.path.join(raga_path, audio_file)
            file_ext = os.path.splitext(audio_file)[1].lower()
            
            if file_ext in supported_formats:
                try:
                    # Get audio metadata
                    info = sf.info(file_path)
                    
                    audio_files.append({
                        'file_path': file_path,
                        'filename': audio_file,
                        'raga': raga_folder,
                        'duration': info.duration,
                        'sample_rate': info.samplerate,
                        'channels': info.channels,
                        'format': file_ext[1:],
                        'file_size_mb': os.path.getsize(file_path) / (1024 * 1024)
                    })
                except Exception as e:
                    print(f"⚠️  Error reading {file_path}: {e}")
    
    df = pd.DataFrame(audio_files)
    print(f"✅ Found {len(df)} audio files")
    return df

# Discover audio files
if 'raga_folders' in locals() and raga_folders:
    metadata_df = discover_audio_files(DATASET_PATH)
    
    # Display dataset statistics
    print(f"\n📊 Dataset Statistics:")
    print(f"Total files: {len(metadata_df)}")
    print(f"Total duration: {metadata_df['duration'].sum():.2f} seconds ({metadata_df['duration'].sum()/3600:.2f} hours)")
    print(f"Average duration: {metadata_df['duration'].mean():.2f} seconds")
    print(f"Total size: {metadata_df['file_size_mb'].sum():.2f} MB")
    
    # Class distribution
    print(f"\n🎵 Raga distribution:")
    raga_counts = metadata_df['raga'].value_counts()
    for raga, count in raga_counts.items():
        percentage = (count / len(metadata_df)) * 100
        print(f"   {raga}: {count} files ({percentage:.1f}%)")
    
    # Display first few rows
    print(f"\n📋 Sample data:")
    print(metadata_df.head())

## 4. Audio Preprocessing Pipeline

Creating functions to preprocess audio files for YAMNet compatibility.

In [None]:
# YAMNet configuration
YAMNET_SAMPLE_RATE = 16000
YAMNET_DURATION = 0.975  # YAMNet processes 0.975 seconds at a time
YAMNET_SAMPLES = int(YAMNET_SAMPLE_RATE * YAMNET_DURATION)

def load_and_preprocess_audio(file_path: str, 
                            target_sr: int = YAMNET_SAMPLE_RATE,
                            duration: Optional[float] = None,
                            offset: float = 0.0) -> np.ndarray:
    """
    Load and preprocess audio file for YAMNet.
    
    Args:
        file_path: Path to audio file
        target_sr: Target sample rate (16kHz for YAMNet)
        duration: Duration to load (None for full file)
        offset: Offset to start loading from
    
    Returns:
        Preprocessed audio array
    """
    try:
        # Load audio file
        audio, sr = librosa.load(file_path, sr=target_sr, duration=duration, offset=offset)
        
        # Normalize audio to [-1, 1]
        audio = librosa.util.normalize(audio)
        
        return audio
    
    except Exception as e:
        logger.error(f"Error loading {file_path}: {e}")
        return np.array([])

def remove_silence(audio: np.ndarray, 
                   sr: int = YAMNET_SAMPLE_RATE,
                   top_db: int = 20) -> np.ndarray:
    """
    Remove silence from audio using librosa's onset detection.
    """
    try:
        # Trim silence from beginning and end
        audio_trimmed, _ = librosa.effects.trim(audio, top_db=top_db)
        
        # Remove very short audio clips
        if len(audio_trimmed) < sr * 0.1:  # Less than 0.1 seconds
            return audio  # Return original if trimmed audio is too short
            
        return audio_trimmed
    
    except Exception as e:
        logger.warning(f"Error in silence removal: {e}")
        return audio

def segment_audio(audio: np.ndarray, 
                  sr: int = YAMNET_SAMPLE_RATE,
                  segment_duration: float = YAMNET_DURATION,
                  overlap: float = 0.5) -> List[np.ndarray]:
    """
    Segment audio into fixed-length chunks for YAMNet processing.
    
    Args:
        audio: Audio array
        sr: Sample rate
        segment_duration: Duration of each segment
        overlap: Overlap between segments (0.0-1.0)
    
    Returns:
        List of audio segments
    """
    segment_samples = int(sr * segment_duration)
    hop_samples = int(segment_samples * (1 - overlap))
    
    segments = []
    start = 0
    
    while start + segment_samples <= len(audio):
        segment = audio[start:start + segment_samples]
        segments.append(segment)
        start += hop_samples
    
    # Handle the last segment if there's remaining audio
    if start < len(audio) and len(audio) - start > segment_samples * 0.5:
        # Pad the last segment to required length
        last_segment = audio[start:]
        if len(last_segment) < segment_samples:
            padding = segment_samples - len(last_segment)
            last_segment = np.pad(last_segment, (0, padding), mode='constant')
        segments.append(last_segment)
    
    return segments

def preprocess_audio_file(file_path: str, 
                         max_segments: int = 10,
                         remove_silence_flag: bool = True) -> Tuple[List[np.ndarray], str]:
    """
    Complete preprocessing pipeline for a single audio file.
    
    Returns:
        Tuple of (list of audio segments, error message if any)
    """
    try:
        # Load audio
        audio = load_and_preprocess_audio(file_path)
        
        if len(audio) == 0:
            return [], "Failed to load audio"
        
        # Remove silence if requested
        if remove_silence_flag:
            audio = remove_silence(audio)
        
        # Segment audio
        segments = segment_audio(audio)
        
        # Limit number of segments
        if len(segments) > max_segments:
            # Select segments from different parts of the audio
            indices = np.linspace(0, len(segments)-1, max_segments, dtype=int)
            segments = [segments[i] for i in indices]
        
        return segments, ""
    
    except Exception as e:
        return [], str(e)

# Test preprocessing with a sample file
if 'metadata_df' in locals() and len(metadata_df) > 0:
    sample_file = metadata_df.iloc[0]['file_path']
    print(f"🧪 Testing preprocessing with: {sample_file}")
    
    segments, error = preprocess_audio_file(sample_file, max_segments=3)
    
    if error:
        print(f"❌ Error: {error}")
    else:
        print(f"✅ Successfully preprocessed!")
        print(f"   Generated {len(segments)} segments")
        print(f"   Each segment shape: {segments[0].shape}")
        print(f"   Sample rate: {YAMNET_SAMPLE_RATE} Hz")
        print(f"   Duration per segment: {len(segments[0])/YAMNET_SAMPLE_RATE:.3f} seconds")

## 5. Dataset Preparation and Splitting

Organizing the dataset into train/validation/test splits with stratification.

In [None]:
def prepare_dataset(metadata_df: pd.DataFrame, 
                   test_size: float = 0.2, 
                   val_size: float = 0.2,
                   random_state: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split dataset into train/validation/test sets with stratification.
    """
    # Ensure we have enough samples per class
    min_samples_per_class = metadata_df['raga'].value_counts().min()
    print(f"📊 Minimum samples per class: {min_samples_per_class}")
    
    if min_samples_per_class < 3:
        print("⚠️  Warning: Some classes have very few samples. Consider collecting more data.")
    
    # Create label encoder
    le = LabelEncoder()
    metadata_df['raga_encoded'] = le.fit_transform(metadata_df['raga'])
    
    # First split: train+val vs test
    train_val, test = train_test_split(
        metadata_df, 
        test_size=test_size, 
        stratify=metadata_df['raga_encoded'],
        random_state=random_state
    )
    
    # Second split: train vs val
    val_size_adjusted = val_size / (1 - test_size)  # Adjust for the reduced dataset
    train, val = train_test_split(
        train_val,
        test_size=val_size_adjusted,
        stratify=train_val['raga_encoded'],
        random_state=random_state
    )
    
    print(f"📊 Dataset splits:")
    print(f"   Train: {len(train)} files ({len(train)/len(metadata_df)*100:.1f}%)")
    print(f"   Validation: {len(val)} files ({len(val)/len(metadata_df)*100:.1f}%)")
    print(f"   Test: {len(test)} files ({len(test)/len(metadata_df)*100:.1f}%)")
    
    # Check class distribution in each split
    print(f"\\n🎵 Class distribution:")
    for split_name, split_data in [('Train', train), ('Val', val), ('Test', test)]:
        print(f"   {split_name}:")
        for raga in le.classes_:
            count = len(split_data[split_data['raga'] == raga])
            print(f"      {raga}: {count} files")
    
    return train, val, test, le

def create_data_generators(train_df: pd.DataFrame, 
                          val_df: pd.DataFrame, 
                          test_df: pd.DataFrame,
                          batch_size: int = 32,
                          max_segments_per_file: int = 5) -> Tuple:
    """
    Create TensorFlow data generators for training.
    """
    
    def process_file_batch(file_paths, labels):
        """Process a batch of files and return segments with labels."""
        batch_segments = []
        batch_labels = []
        
        for file_path, label in zip(file_paths, labels):
            segments, error = preprocess_audio_file(
                file_path.numpy().decode('utf-8'), 
                max_segments=max_segments_per_file
            )
            
            if not error and segments:
                for segment in segments:
                    batch_segments.append(segment)
                    batch_labels.append(label)
        
        if batch_segments:
            return np.array(batch_segments), np.array(batch_labels)
        else:
            # Return empty arrays with correct shape
            return np.empty((0, YAMNET_SAMPLES)), np.empty((0,))
    
    def create_tf_dataset(df, shuffle=True):
        """Create TensorFlow dataset from DataFrame."""
        file_paths = df['file_path'].values
        labels = df['raga_encoded'].values
        
        dataset = tf.data.Dataset.from_tensor_slices((file_paths, labels))
        
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(df))
        
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(
            lambda paths, lbls: tf.py_function(
                process_file_batch, 
                [paths, lbls], 
                [tf.float32, tf.int64]
            ),
            num_parallel_calls=tf.data.AUTOTUNE
        )
        
        # Filter out empty batches
        dataset = dataset.filter(lambda x, y: tf.shape(x)[0] > 0)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        
        return dataset
    
    train_dataset = create_tf_dataset(train_df, shuffle=True)
    val_dataset = create_tf_dataset(val_df, shuffle=False)
    test_dataset = create_tf_dataset(test_df, shuffle=False)
    
    return train_dataset, val_dataset, test_dataset

# Prepare dataset splits
if 'metadata_df' in locals() and len(metadata_df) > 0:
    train_df, val_df, test_df, label_encoder = prepare_dataset(metadata_df)
    
    # Save label encoder
    label_encoder_path = os.path.join(OUTPUT_PATH, 'label_encoder.pkl')
    with open(label_encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)
    
    print(f"💾 Label encoder saved to: {label_encoder_path}")
    print(f"🏷️  Class labels: {list(label_encoder.classes_)}")
    
    # Store number of classes for model building
    NUM_CLASSES = len(label_encoder.classes_)
    print(f"🔢 Number of classes: {NUM_CLASSES}")
else:
    print("❌ No metadata available. Please run the previous cells first.")

## 6. YAMNet Model Loading and Feature Extraction

Loading the pre-trained YAMNet model from TensorFlow Hub and understanding its architecture.

In [None]:
# YAMNet model URL
YAMNET_MODEL_URL = "https://tfhub.dev/google/yamnet/1"

def load_yamnet_model():
    """
    Load YAMNet model from TensorFlow Hub.
    """
    print("🔄 Loading YAMNet model from TensorFlow Hub...")
    
    try:
        # Load the model
        yamnet_model = hub.load(YAMNET_MODEL_URL)
        print("✅ YAMNet model loaded successfully!")
        
        # Test the model with a sample input
        test_input = tf.random.normal([YAMNET_SAMPLES])
        scores, embeddings, spectrogram = yamnet_model(test_input)
        
        print(f"📊 Model output shapes:")
        print(f"   Scores (classifications): {scores.shape}")
        print(f"   Embeddings (features): {embeddings.shape}")
        print(f"   Spectrogram: {spectrogram.shape}")
        
        return yamnet_model
        
    except Exception as e:
        print(f"❌ Error loading YAMNet: {e}")
        return None

def extract_yamnet_embeddings(yamnet_model, audio_segments: List[np.ndarray]) -> np.ndarray:
    """
    Extract YAMNet embeddings from audio segments.
    
    Args:
        yamnet_model: Loaded YAMNet model
        audio_segments: List of audio segments
    
    Returns:
        Array of embeddings
    """
    embeddings_list = []
    
    for segment in audio_segments:
        # Ensure segment is the right length
        if len(segment) != YAMNET_SAMPLES:
            if len(segment) < YAMNET_SAMPLES:
                # Pad with zeros
                segment = np.pad(segment, (0, YAMNET_SAMPLES - len(segment)), mode='constant')
            else:
                # Truncate
                segment = segment[:YAMNET_SAMPLES]
        
        # Convert to tensor
        segment_tensor = tf.convert_to_tensor(segment, dtype=tf.float32)
        
        # Extract features
        _, embeddings, _ = yamnet_model(segment_tensor)
        
        # Average embeddings across time (YAMNet outputs multiple frames)
        avg_embedding = tf.reduce_mean(embeddings, axis=0)
        embeddings_list.append(avg_embedding.numpy())
    
    return np.array(embeddings_list)

# Load YAMNet model
yamnet_model = load_yamnet_model()

if yamnet_model is not None:
    # Test feature extraction with sample data
    if 'metadata_df' in locals() and len(metadata_df) > 0:
        print("\\n🧪 Testing feature extraction...")
        
        sample_file = metadata_df.iloc[0]['file_path']
        sample_segments, error = preprocess_audio_file(sample_file, max_segments=2)
        
        if not error and sample_segments:
            sample_embeddings = extract_yamnet_embeddings(yamnet_model, sample_segments)
            print(f"✅ Feature extraction test successful!")
            print(f"   Input: {len(sample_segments)} audio segments")
            print(f"   Output: {sample_embeddings.shape} embeddings")
            print(f"   Embedding dimension: {sample_embeddings.shape[1]}")
            
            # Store embedding dimension for model building
            EMBEDDING_DIM = sample_embeddings.shape[1]
        else:
            print(f"❌ Feature extraction test failed: {error}")
else:
    print("❌ Cannot proceed without YAMNet model")

## 7. Custom Classification Head Architecture

Building a custom neural network head on top of YAMNet embeddings for raga classification.

In [None]:
def create_raga_classifier(embedding_dim: int, 
                          num_classes: int,
                          hidden_units: List[int] = [512, 256, 128],
                          dropout_rate: float = 0.3,
                          l2_reg: float = 0.01) -> tf.keras.Model:
    """
    Create a custom classification head for raga classification.
    
    Args:
        embedding_dim: YAMNet embedding dimension (1024)
        num_classes: Number of raga classes
        hidden_units: List of hidden layer sizes
        dropout_rate: Dropout rate for regularization
        l2_reg: L2 regularization strength
    
    Returns:
        Compiled Keras model
    """
    
    # Input layer for YAMNet embeddings
    inputs = tf.keras.Input(shape=(embedding_dim,), name='yamnet_embeddings')
    
    # Start with the embeddings
    x = inputs
    
    # Add hidden layers with batch normalization and dropout
    for i, units in enumerate(hidden_units):
        x = tf.keras.layers.Dense(
            units, 
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
            name=f'dense_{i+1}'
        )(x)
        
        x = tf.keras.layers.BatchNormalization(name=f'batch_norm_{i+1}')(x)
        x = tf.keras.layers.Dropout(dropout_rate, name=f'dropout_{i+1}')(x)
    
    # Output layer
    outputs = tf.keras.layers.Dense(
        num_classes, 
        activation='softmax',
        kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
        name='raga_predictions'
    )(x)
    
    # Create model
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='RagaClassifier')
    
    return model

def create_end_to_end_model(yamnet_model, 
                           classifier_model,
                           trainable_yamnet: bool = False) -> tf.keras.Model:
    """
    Create an end-to-end model combining YAMNet and the classifier.
    
    Args:
        yamnet_model: Pre-trained YAMNet model
        classifier_model: Custom classification head
        trainable_yamnet: Whether to fine-tune YAMNet weights
    
    Returns:
        End-to-end Keras model
    """
    
    # Audio input
    audio_input = tf.keras.Input(shape=(YAMNET_SAMPLES,), name='audio_input')
    
    # YAMNet feature extraction
    # We only need the embeddings (index 1) from YAMNet output
    yamnet_outputs = yamnet_model(audio_input)
    embeddings = yamnet_outputs[1]  # Shape: (num_frames, 1024)
    
    # Average embeddings across time frames
    avg_embeddings = tf.reduce_mean(embeddings, axis=0, keepdims=True)
    
    # Classification head
    predictions = classifier_model(avg_embeddings)
    
    # Create end-to-end model
    model = tf.keras.Model(inputs=audio_input, outputs=predictions, name='YAMNet_RagaClassifier')
    
    # Set YAMNet trainability
    # Note: YAMNet is a hub.KerasLayer, so we need to handle it differently
    if not trainable_yamnet:
        # Freeze YAMNet weights for transfer learning
        model.layers[1].trainable = False
        print("🔒 YAMNet weights frozen for transfer learning")
    else:
        model.layers[1].trainable = True
        print("🔓 YAMNet weights trainable for fine-tuning")
    
    return model

# Create models
if 'EMBEDDING_DIM' in locals() and 'NUM_CLASSES' in locals():
    print("🏗️  Building raga classification model...")
    
    # Create classifier head
    classifier = create_raga_classifier(
        embedding_dim=EMBEDDING_DIM,
        num_classes=NUM_CLASSES,
        hidden_units=[512, 256, 128],
        dropout_rate=0.3,
        l2_reg=0.01
    )
    
    print("✅ Classification head created!")
    print(classifier.summary())
    
    # Create end-to-end model
    model = create_end_to_end_model(
        yamnet_model=yamnet_model,
        classifier_model=classifier,
        trainable_yamnet=False  # Start with frozen YAMNet
    )
    
    print("\\n✅ End-to-end model created!")
    print(f"📊 Model summary:")
    print(model.summary())
    
    # Save model architecture
    model_config = {
        'embedding_dim': EMBEDDING_DIM,
        'num_classes': NUM_CLASSES,
        'hidden_units': [512, 256, 128],
        'dropout_rate': 0.3,
        'l2_reg': 0.01,
        'yamnet_url': YAMNET_MODEL_URL
    }
    
    config_path = os.path.join(OUTPUT_PATH, 'model_config.json')
    with open(config_path, 'w') as f:
        json.dump(model_config, f, indent=2)
    
    print(f"💾 Model configuration saved to: {config_path}")
    
else:
    print("❌ Missing required variables. Please run previous cells first.")

## 8. Data Augmentation and Generators

Implementing audio augmentation techniques and creating efficient data generators.

In [None]:
# Data augmentation configuration
AUGMENTATION_CONFIG = {
    'time_stretch_rate': [0.8, 1.2],
    'pitch_shift_semitones': [-2, 2],
    'noise_level': 0.005,
    'augmentation_probability': 0.3
}

def create_augmentation_pipeline():
    """
    Create audio augmentation pipeline using audiomentations.
    """
    return Compose([
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=AUGMENTATION_CONFIG['noise_level'], p=0.3),
        TimeStretch(min_rate=AUGMENTATION_CONFIG['time_stretch_rate'][0], 
                   max_rate=AUGMENTATION_CONFIG['time_stretch_rate'][1], p=0.3),
        PitchShift(min_semitones=AUGMENTATION_CONFIG['pitch_shift_semitones'][0],
                  max_semitones=AUGMENTATION_CONFIG['pitch_shift_semitones'][1], p=0.3),
        Shift(min_fraction=-0.5, max_fraction=0.5, p=0.3)
    ])

def augment_audio_segment(audio: np.ndarray, augment_pipeline, apply_augmentation: bool = True) -> np.ndarray:
    """
    Apply augmentation to an audio segment.
    """
    if apply_augmentation and np.random.random() < AUGMENTATION_CONFIG['augmentation_probability']:
        try:
            augmented = augment_pipeline(samples=audio, sample_rate=YAMNET_SAMPLE_RATE)
            return augmented
        except Exception as e:
            logger.warning(f"Augmentation failed: {e}")
            return audio
    return audio

def create_efficient_data_generator(df: pd.DataFrame, 
                                  batch_size: int = 32,
                                  max_segments_per_file: int = 5,
                                  shuffle: bool = True,
                                  augment: bool = True) -> tf.data.Dataset:
    """
    Create an efficient TensorFlow data generator with augmentation.
    """
    augment_pipeline = create_augmentation_pipeline() if augment else None
    
    def generator():
        """Python generator function."""
        indices = np.arange(len(df))
        if shuffle:
            np.random.shuffle(indices)
        
        for idx in indices:
            row = df.iloc[idx]
            file_path = row['file_path']
            label = row['raga_encoded']
            
            # Load and preprocess audio
            segments, error = preprocess_audio_file(file_path, max_segments=max_segments_per_file)
            
            if not error and segments:
                for segment in segments:
                    # Apply augmentation if requested
                    if augment and augment_pipeline:
                        segment = augment_audio_segment(segment, augment_pipeline, apply_augmentation=True)
                    
                    # Ensure segment is correct length
                    if len(segment) != YAMNET_SAMPLES:
                        if len(segment) < YAMNET_SAMPLES:
                            segment = np.pad(segment, (0, YAMNET_SAMPLES - len(segment)), mode='constant')
                        else:
                            segment = segment[:YAMNET_SAMPLES]
                    
                    yield segment.astype(np.float32), label
    
    # Create TensorFlow dataset
    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(YAMNET_SAMPLES,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int64)
        )
    )
    
    # Batch and optimize
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset

def calculate_class_weights(train_df: pd.DataFrame) -> dict:
    """
    Calculate class weights for handling imbalanced datasets.
    """
    classes = train_df['raga_encoded'].values
    unique_classes = np.unique(classes)
    
    weights = compute_class_weight(
        class_weight='balanced',
        classes=unique_classes,
        y=classes
    )
    
    class_weight_dict = dict(zip(unique_classes, weights))
    
    print("⚖️  Class weights for balanced training:")
    for class_idx, weight in class_weight_dict.items():
        raga_name = label_encoder.inverse_transform([class_idx])[0]
        print(f"   {raga_name}: {weight:.3f}")
    
    return class_weight_dict

# Create data generators
if all(var in locals() for var in ['train_df', 'val_df', 'test_df']):
    print("🔄 Creating data generators...")
    
    # Training configuration
    BATCH_SIZE = 16  # Adjust based on your GPU memory
    MAX_SEGMENTS_PER_FILE = 8
    
    # Create datasets
    train_dataset = create_efficient_data_generator(
        train_df, 
        batch_size=BATCH_SIZE,
        max_segments_per_file=MAX_SEGMENTS_PER_FILE,
        shuffle=True,
        augment=True
    )
    
    val_dataset = create_efficient_data_generator(
        val_df,
        batch_size=BATCH_SIZE,
        max_segments_per_file=MAX_SEGMENTS_PER_FILE,
        shuffle=False,
        augment=False
    )
    
    test_dataset = create_efficient_data_generator(
        test_df,
        batch_size=BATCH_SIZE,
        max_segments_per_file=MAX_SEGMENTS_PER_FILE,
        shuffle=False,
        augment=False
    )
    
    # Calculate class weights
    class_weights = calculate_class_weights(train_df)
    
    print("✅ Data generators created successfully!")
    print(f"📊 Configuration:")
    print(f"   Batch size: {BATCH_SIZE}")
    print(f"   Max segments per file: {MAX_SEGMENTS_PER_FILE}")
    print(f"   Augmentation: Enabled for training")
    
    # Test the generators
    print("\\n🧪 Testing data generators...")
    for batch_audio, batch_labels in train_dataset.take(1):
        print(f"   Train batch: {batch_audio.shape} audio, {batch_labels.shape} labels")
        break
    
    for batch_audio, batch_labels in val_dataset.take(1):
        print(f"   Val batch: {batch_audio.shape} audio, {batch_labels.shape} labels")
        break

else:
    print("❌ Missing dataset splits. Please run previous cells first.")

## 9. Training Configuration and Callbacks

Setting up optimizer, loss function, learning rate scheduling, and training callbacks.

In [None]:
# Training configuration
TRAINING_CONFIG = {
    'initial_learning_rate': 0.001,
    'min_learning_rate': 1e-7,
    'epochs': 50,
    'patience': 10,
    'reduce_lr_patience': 5,
    'reduce_lr_factor': 0.5,
    'monitor_metric': 'val_accuracy'
}

def create_learning_rate_schedule():
    """
    Create learning rate schedule with exponential decay.
    """
    initial_learning_rate = TRAINING_CONFIG['initial_learning_rate']
    decay_steps = 1000
    decay_rate = 0.96
    
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=True
    )
    
    return lr_schedule

def setup_callbacks(model_checkpoint_path: str, log_dir: str):
    """
    Setup training callbacks for monitoring and control.
    """
    callbacks = []
    
    # Model checkpoint - save best model
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=model_checkpoint_path,
        monitor=TRAINING_CONFIG['monitor_metric'],
        save_best_only=True,
        save_weights_only=False,
        mode='max',
        verbose=1
    )
    callbacks.append(checkpoint_callback)
    
    # Early stopping
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor=TRAINING_CONFIG['monitor_metric'],
        patience=TRAINING_CONFIG['patience'],
        restore_best_weights=True,
        mode='max',
        verbose=1
    )
    callbacks.append(early_stopping)
    
    # Reduce learning rate on plateau
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=TRAINING_CONFIG['monitor_metric'],
        factor=TRAINING_CONFIG['reduce_lr_factor'],
        patience=TRAINING_CONFIG['reduce_lr_patience'],
        min_lr=TRAINING_CONFIG['min_learning_rate'],
        mode='max',
        verbose=1
    )
    callbacks.append(reduce_lr)
    
    # TensorBoard logging
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        write_graph=True,
        write_images=True,
        update_freq='epoch'
    )
    callbacks.append(tensorboard)
    
    # Custom callback for live plotting
    class LivePlotCallback(tf.keras.callbacks.Callback):
        def __init__(self):
            self.losses = []
            self.val_losses = []
            self.accuracies = []
            self.val_accuracies = []
            
        def on_epoch_end(self, epoch, logs=None):
            self.losses.append(logs.get('loss'))
            self.val_losses.append(logs.get('val_loss'))
            self.accuracies.append(logs.get('accuracy'))
            self.val_accuracies.append(logs.get('val_accuracy'))
            
            # Clear output and plot
            if epoch % 5 == 0 or epoch == 0:  # Plot every 5 epochs
                self.plot_training_progress()
        
        def plot_training_progress(self):
            plt.figure(figsize=(15, 5))
            
            # Loss plot
            plt.subplot(1, 2, 1)
            plt.plot(self.losses, label='Training Loss', color='blue')
            plt.plot(self.val_losses, label='Validation Loss', color='red')
            plt.title('Model Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)
            
            # Accuracy plot
            plt.subplot(1, 2, 2)
            plt.plot(self.accuracies, label='Training Accuracy', color='blue')
            plt.plot(self.val_accuracies, label='Validation Accuracy', color='red')
            plt.title('Model Accuracy')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.legend()
            plt.grid(True)
            
            plt.tight_layout()
            plt.show()
    
    live_plot = LivePlotCallback()
    callbacks.append(live_plot)
    
    return callbacks

def compile_model(model, learning_rate=None):
    """
    Compile the model with optimizer, loss, and metrics.
    """
    if learning_rate is None:
        learning_rate = create_learning_rate_schedule()
    
    # Optimizer
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=learning_rate,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7
    )
    
    # Loss function
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    
    # Metrics
    metrics = [
        'accuracy',
        tf.keras.metrics.SparseCategoricalCrossentropy(name='crossentropy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top_3_accuracy')
    ]
    
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics
    )
    
    print("✅ Model compiled successfully!")
    print(f"📊 Optimizer: Adam with learning rate schedule")
    print(f"📊 Loss: Sparse Categorical Crossentropy")
    print(f"📊 Metrics: Accuracy, Top-3 Accuracy, Crossentropy")
    
    return model

# Setup training environment
if 'model' in locals():
    print("🔧 Setting up training environment...")
    
    # Create directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_name = f"yamnet_raga_classifier_{timestamp}"
    
    experiment_dir = os.path.join(OUTPUT_PATH, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)
    
    model_checkpoint_path = os.path.join(experiment_dir, "best_model.h5")
    log_dir = os.path.join(experiment_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    
    print(f"📁 Experiment directory: {experiment_dir}")
    print(f"💾 Model checkpoint: {model_checkpoint_path}")
    print(f"📊 TensorBoard logs: {log_dir}")
    
    # Compile model
    model = compile_model(model)
    
    # Setup callbacks
    callbacks = setup_callbacks(model_checkpoint_path, log_dir)
    
    print(f"✅ Training setup complete!")
    print(f"📋 Callbacks configured: {len(callbacks)} callbacks")
    print(f"🎯 Training configuration:")
    for key, value in TRAINING_CONFIG.items():
        print(f"   {key}: {value}")
        
else:
    print("❌ Model not available. Please run previous cells first.")

## 10. Model Training Loop

Executing the training process with proper validation monitoring and progress tracking.

In [None]:
# Calculate steps per epoch
def calculate_dataset_size(dataset, max_batches=100):
    """
    Estimate dataset size by sampling batches.
    """
    total_samples = 0
    for i, (batch_audio, batch_labels) in enumerate(dataset.take(max_batches)):
        total_samples += len(batch_labels)
        if i % 10 == 0:
            print(f"   Processed {i+1} batches, {total_samples} samples so far...")
    return total_samples

# Calculate dataset sizes
if all(var in locals() for var in ['train_dataset', 'val_dataset', 'model', 'callbacks']):
    print("📊 Calculating dataset sizes...")
    
    # This might take a moment for large datasets
    train_size = calculate_dataset_size(train_dataset, max_batches=50)
    val_size = calculate_dataset_size(val_dataset, max_batches=20)
    
    steps_per_epoch = max(1, train_size // BATCH_SIZE)
    validation_steps = max(1, val_size // BATCH_SIZE)
    
    print(f"📊 Training dataset: ~{train_size} samples, {steps_per_epoch} steps per epoch")
    print(f"📊 Validation dataset: ~{val_size} samples, {validation_steps} validation steps")
    
    # Start training
    print(f"\\n🚀 Starting training for {TRAINING_CONFIG['epochs']} epochs...")
    print(f"💾 Best model will be saved to: {model_checkpoint_path}")
    print(f"📊 Monitor TensorBoard: tensorboard --logdir {log_dir}")
    
    try:
        # Train the model
        history = model.fit(
            train_dataset,
            epochs=TRAINING_CONFIG['epochs'],
            validation_data=val_dataset,
            callbacks=callbacks,
            class_weight=class_weights,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            verbose=1
        )
        
        print("✅ Training completed successfully!")
        
        # Save training history
        history_path = os.path.join(experiment_dir, "training_history.json")
        
        # Convert numpy arrays to lists for JSON serialization
        history_dict = {}
        for key, values in history.history.items():
            history_dict[key] = [float(v) for v in values]
        
        with open(history_path, 'w') as f:
            json.dump(history_dict, f, indent=2)
        
        print(f"💾 Training history saved to: {history_path}")
        
        # Load best model
        if os.path.exists(model_checkpoint_path):
            print(f"📥 Loading best model from: {model_checkpoint_path}")
            best_model = tf.keras.models.load_model(model_checkpoint_path)
            print("✅ Best model loaded successfully!")
        else:
            print("⚠️  Using final model (checkpoint not found)")
            best_model = model
        
        # Final training summary
        print(f"\\n📊 Training Summary:")
        print(f"   Total epochs: {len(history.history['loss'])}")
        print(f"   Best validation accuracy: {max(history.history['val_accuracy']):.4f}")
        print(f"   Final training accuracy: {history.history['accuracy'][-1]:.4f}")
        print(f"   Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
        
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        
else:
    print("❌ Missing required components. Please run all previous cells first.")
    missing = []
    if 'train_dataset' not in locals():
        missing.append('train_dataset')
    if 'val_dataset' not in locals():
        missing.append('val_dataset')
    if 'model' not in locals():
        missing.append('model')
    if 'callbacks' not in locals():
        missing.append('callbacks')
    print(f"Missing: {missing}")

## 11. Model Evaluation and Metrics

Evaluating the trained model on test data and calculating comprehensive metrics.

In [None]:
def evaluate_model(model, test_dataset, label_encoder, class_names=None):
    """
    Comprehensive model evaluation with metrics and analysis.
    """
    print("🔄 Evaluating model on test data...")
    
    # Collect predictions and true labels
    y_true = []
    y_pred = []
    y_pred_proba = []
    
    for batch_audio, batch_labels in tqdm(test_dataset, desc="Evaluating"):
        # Get predictions
        predictions = model.predict(batch_audio, verbose=0)
        
        # Store results
        y_true.extend(batch_labels.numpy())
        y_pred.extend(np.argmax(predictions, axis=1))
        y_pred_proba.extend(predictions)
    
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_pred_proba = np.array(y_pred_proba)
    
    # Get class names
    if class_names is None:
        class_names = label_encoder.classes_
    
    # Calculate metrics
    accuracy = np.mean(y_true == y_pred)
    
    # Classification report
    report = classification_report(
        y_true, y_pred,
        target_names=class_names,
        output_dict=True
    )
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Per-class metrics
    per_class_accuracy = cm.diagonal() / cm.sum(axis=1)
    
    # Top-k accuracy
    top_3_accuracy = 0
    for i, true_label in enumerate(y_true):
        top_3_preds = np.argsort(y_pred_proba[i])[-3:]
        if true_label in top_3_preds:
            top_3_accuracy += 1
    top_3_accuracy /= len(y_true)
    
    print(f"✅ Evaluation completed!")
    print(f"📊 Overall Accuracy: {accuracy:.4f}")
    print(f"📊 Top-3 Accuracy: {top_3_accuracy:.4f}")
    print(f"📊 Macro F1-Score: {report['macro avg']['f1-score']:.4f}")
    print(f"📊 Weighted F1-Score: {report['weighted avg']['f1-score']:.4f}")
    
    # Per-class performance
    print(f"\\n🎵 Per-class Performance:")
    for i, class_name in enumerate(class_names):
        precision = report[class_name]['precision']
        recall = report[class_name]['recall']
        f1 = report[class_name]['f1-score']
        class_accuracy = per_class_accuracy[i]
        
        print(f"   {class_name}:")
        print(f"      Accuracy: {class_accuracy:.4f}")
        print(f"      Precision: {precision:.4f}")
        print(f"      Recall: {recall:.4f}")
        print(f"      F1-Score: {f1:.4f}")
    
    return {
        'accuracy': accuracy,
        'top_3_accuracy': top_3_accuracy,
        'classification_report': report,
        'confusion_matrix': cm,
        'per_class_accuracy': per_class_accuracy,
        'y_true': y_true,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba,
        'class_names': class_names
    }

def analyze_misclassifications(evaluation_results, top_n=5):
    """
    Analyze the most common misclassifications.
    """
    y_true = evaluation_results['y_true']
    y_pred = evaluation_results['y_pred']
    class_names = evaluation_results['class_names']
    cm = evaluation_results['confusion_matrix']
    
    print(f"🔍 Analyzing misclassifications...")
    
    # Find most common misclassifications
    misclassifications = []
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            if i != j and cm[i, j] > 0:
                misclassifications.append({
                    'true_class': class_names[i],
                    'predicted_class': class_names[j],
                    'count': cm[i, j],
                    'percentage': cm[i, j] / cm[i].sum() * 100
                })
    
    # Sort by count
    misclassifications.sort(key=lambda x: x['count'], reverse=True)
    
    print(f"\\n❌ Top {top_n} Misclassifications:")
    for i, misc in enumerate(misclassifications[:top_n]):
        print(f"   {i+1}. {misc['true_class']} → {misc['predicted_class']}: "
              f"{misc['count']} times ({misc['percentage']:.1f}%)")
    
    return misclassifications

def save_evaluation_results(evaluation_results, save_path):
    """
    Save evaluation results to file.
    """
    # Prepare data for saving (convert numpy arrays to lists)
    save_data = {
        'accuracy': float(evaluation_results['accuracy']),
        'top_3_accuracy': float(evaluation_results['top_3_accuracy']),
        'classification_report': evaluation_results['classification_report'],
        'confusion_matrix': evaluation_results['confusion_matrix'].tolist(),
        'per_class_accuracy': evaluation_results['per_class_accuracy'].tolist(),
        'class_names': evaluation_results['class_names'].tolist(),
        'evaluation_timestamp': datetime.now().isoformat()
    }
    
    with open(save_path, 'w') as f:
        json.dump(save_data, f, indent=2)
    
    print(f"💾 Evaluation results saved to: {save_path}")

# Evaluate the model
if all(var in locals() for var in ['best_model', 'test_dataset', 'label_encoder', 'experiment_dir']):
    evaluation_results = evaluate_model(best_model, test_dataset, label_encoder)
    
    # Analyze misclassifications
    misclassifications = analyze_misclassifications(evaluation_results)
    
    # Save results
    eval_results_path = os.path.join(experiment_dir, "evaluation_results.json")
    save_evaluation_results(evaluation_results, eval_results_path)
    
    print(f"\\n✅ Model evaluation completed!")
    
else:
    print("❌ Missing required components for evaluation.")
    missing = []
    if 'best_model' not in locals():
        missing.append('best_model')
    if 'test_dataset' not in locals():
        missing.append('test_dataset')
    if 'label_encoder' not in locals():
        missing.append('label_encoder')
    if 'experiment_dir' not in locals():
        missing.append('experiment_dir')
    print(f"Missing: {missing}")

## 12. Performance Visualization

Creating comprehensive visualizations for training history and model performance.

In [None]:
def plot_training_history(history, save_path=None):
    """
    Plot training history with loss and accuracy curves.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Training and validation loss
    axes[0, 0].plot(history['loss'], label='Training Loss', color='blue', linewidth=2)
    axes[0, 0].plot(history['val_loss'], label='Validation Loss', color='red', linewidth=2)
    axes[0, 0].set_title('Model Loss', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Training and validation accuracy
    axes[0, 1].plot(history['accuracy'], label='Training Accuracy', color='blue', linewidth=2)
    axes[0, 1].plot(history['val_accuracy'], label='Validation Accuracy', color='red', linewidth=2)
    axes[0, 1].set_title('Model Accuracy', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Learning rate (if available)
    if 'lr' in history:
        axes[1, 0].plot(history['lr'], label='Learning Rate', color='green', linewidth=2)
        axes[1, 0].set_title('Learning Rate', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_yscale('log')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(0.5, 0.5, 'Learning Rate\\nNot Available', 
                       ha='center', va='center', transform=axes[1, 0].transAxes,
                       fontsize=12)
        axes[1, 0].set_title('Learning Rate', fontsize=14, fontweight='bold')
    
    # Top-3 accuracy (if available)
    if 'top_3_accuracy' in history and 'val_top_3_accuracy' in history:
        axes[1, 1].plot(history['top_3_accuracy'], label='Training Top-3', color='blue', linewidth=2)
        axes[1, 1].plot(history['val_top_3_accuracy'], label='Validation Top-3', color='red', linewidth=2)
        axes[1, 1].set_title('Top-3 Accuracy', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Top-3 Accuracy')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
    else:
        # Show final metrics
        final_train_acc = history['accuracy'][-1]
        final_val_acc = history['val_accuracy'][-1]
        axes[1, 1].bar(['Train', 'Validation'], [final_train_acc, final_val_acc], 
                      color=['blue', 'red'], alpha=0.7)
        axes[1, 1].set_title('Final Accuracy', fontsize=14, fontweight='bold')
        axes[1, 1].set_ylabel('Accuracy')
        axes[1, 1].set_ylim(0, 1)
        for i, v in enumerate([final_train_acc, final_val_acc]):
            axes[1, 1].text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📊 Training history plot saved to: {save_path}")
    
    plt.show()

def plot_confusion_matrix(cm, class_names, save_path=None):
    """
    Plot confusion matrix as a heatmap.
    """
    plt.figure(figsize=(12, 10))
    
    # Normalize confusion matrix
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    # Create heatmap
    sns.heatmap(cm_normalized, 
                annot=True, 
                fmt='.2f', 
                cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                cbar_kws={'label': 'Normalized Count'})
    
    plt.title('Confusion Matrix (Normalized)', fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📊 Confusion matrix plot saved to: {save_path}")
    
    plt.show()

def plot_per_class_metrics(evaluation_results, save_path=None):
    """
    Plot per-class performance metrics.
    """
    class_names = evaluation_results['class_names']
    report = evaluation_results['classification_report']
    
    # Extract metrics for each class
    metrics = ['precision', 'recall', 'f1-score']
    metric_values = {metric: [] for metric in metrics}
    
    for class_name in class_names:
        for metric in metrics:
            metric_values[metric].append(report[class_name][metric])
    
    # Create plot
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    x = np.arange(len(class_names))
    width = 0.25
    
    colors = ['skyblue', 'lightcoral', 'lightgreen']
    
    for i, (metric, color) in enumerate(zip(metrics, colors)):
        axes[i].bar(x, metric_values[metric], color=color, alpha=0.8, width=0.6)
        axes[i].set_title(f'{metric.capitalize()} per Class', fontsize=14, fontweight='bold')
        axes[i].set_xlabel('Raga Class')
        axes[i].set_ylabel(metric.capitalize())
        axes[i].set_xticks(x)
        axes[i].set_xticklabels(class_names, rotation=45, ha='right')
        axes[i].grid(True, alpha=0.3)
        axes[i].set_ylim(0, 1)
        
        # Add value labels on bars
        for j, v in enumerate(metric_values[metric]):
            axes[i].text(j, v + 0.01, f'{v:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📊 Per-class metrics plot saved to: {save_path}")
    
    plt.show()

def create_interactive_plots(evaluation_results, history):
    """
    Create interactive plots using Plotly.
    """
    # Training history
    fig_history = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Loss', 'Accuracy', 'Learning Rate', 'Final Metrics'),
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"type": "bar"}]]
    )
    
    epochs = list(range(1, len(history['loss']) + 1))
    
    # Loss plot
    fig_history.add_trace(
        go.Scatter(x=epochs, y=history['loss'], name='Training Loss', line=dict(color='blue')),
        row=1, col=1
    )
    fig_history.add_trace(
        go.Scatter(x=epochs, y=history['val_loss'], name='Validation Loss', line=dict(color='red')),
        row=1, col=1
    )
    
    # Accuracy plot
    fig_history.add_trace(
        go.Scatter(x=epochs, y=history['accuracy'], name='Training Accuracy', line=dict(color='blue')),
        row=1, col=2
    )
    fig_history.add_trace(
        go.Scatter(x=epochs, y=history['val_accuracy'], name='Validation Accuracy', line=dict(color='red')),
        row=1, col=2
    )
    
    # Learning rate (if available)
    if 'lr' in history:
        fig_history.add_trace(
            go.Scatter(x=epochs, y=history['lr'], name='Learning Rate', line=dict(color='green')),
            row=2, col=1
        )
    
    # Final metrics bar chart
    final_metrics = ['Train Acc', 'Val Acc', 'Test Acc']
    final_values = [
        history['accuracy'][-1], 
        history['val_accuracy'][-1],
        evaluation_results['accuracy']
    ]
    
    fig_history.add_trace(
        go.Bar(x=final_metrics, y=final_values, name='Final Metrics'),
        row=2, col=2
    )
    
    fig_history.update_layout(height=800, title_text="Training History Dashboard")
    fig_history.show()
    
    # Confusion matrix heatmap
    cm = evaluation_results['confusion_matrix']
    class_names = evaluation_results['class_names']
    
    fig_cm = go.Figure(data=go.Heatmap(
        z=cm,
        x=class_names,
        y=class_names,
        colorscale='Blues',
        text=cm,
        texttemplate="%{text}",
        textfont={"size": 12},
    ))
    
    fig_cm.update_layout(
        title="Confusion Matrix",
        xaxis_title="Predicted Label",
        yaxis_title="True Label",
        width=600,
        height=600
    )
    fig_cm.show()

# Create visualizations
if all(var in locals() for var in ['history', 'evaluation_results', 'experiment_dir']):
    print("📊 Creating performance visualizations...")
    
    # Training history plot
    history_plot_path = os.path.join(experiment_dir, "training_history.png")
    plot_training_history(history.history, history_plot_path)
    
    # Confusion matrix plot
    cm_plot_path = os.path.join(experiment_dir, "confusion_matrix.png")
    plot_confusion_matrix(
        evaluation_results['confusion_matrix'], 
        evaluation_results['class_names'], 
        cm_plot_path
    )
    
    # Per-class metrics plot
    metrics_plot_path = os.path.join(experiment_dir, "per_class_metrics.png")
    plot_per_class_metrics(evaluation_results, metrics_plot_path)
    
    # Interactive plots
    print("\\n📊 Creating interactive plots...")
    create_interactive_plots(evaluation_results, history.history)
    
    print("✅ All visualizations created successfully!")
    
else:
    print("❌ Missing data for visualization.")
    missing = []
    if 'history' not in locals():
        missing.append('history')
    if 'evaluation_results' not in locals():
        missing.append('evaluation_results')
    if 'experiment_dir' not in locals():
        missing.append('experiment_dir')
    print(f"Missing: {missing}")

## 13. Model Export and Saving

Saving the trained model in TensorFlow SavedModel format and creating deployment files.

In [None]:
def save_model_for_deployment(model, save_dir, label_encoder, model_config):
    """
    Save model and all necessary files for deployment.
    """
    print(f"💾 Saving model for deployment to: {save_dir}")
    os.makedirs(save_dir, exist_ok=True)
    
    # Save the model in SavedModel format
    model_path = os.path.join(save_dir, "saved_model")
    model.save(model_path, save_format='tf')
    print(f"✅ Model saved to: {model_path}")
    
    # Save label encoder
    label_encoder_path = os.path.join(save_dir, "label_encoder.pkl")
    with open(label_encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)
    print(f"✅ Label encoder saved to: {label_encoder_path}")
    
    # Save model configuration
    config_path = os.path.join(save_dir, "config.json")
    full_config = {
        **model_config,
        'class_names': label_encoder.classes_.tolist(),
        'num_classes': len(label_encoder.classes_),
        'yamnet_sample_rate': YAMNET_SAMPLE_RATE,
        'yamnet_duration': YAMNET_DURATION,
        'yamnet_samples': YAMNET_SAMPLES,
        'model_version': '1.0.0',
        'creation_date': datetime.now().isoformat(),
        'framework': 'tensorflow',
        'base_model': 'yamnet'
    }
    
    with open(config_path, 'w') as f:
        json.dump(full_config, f, indent=2)
    print(f"✅ Configuration saved to: {config_path}")
    
    # Create inference script
    inference_script = f'''
import tensorflow as tf
import numpy as np
import pickle
import json
import librosa

class RagaClassifier:
    def __init__(self, model_path):
        self.model = tf.saved_model.load(model_path)
        
        # Load configuration
        with open(f"{{model_path}}/../config.json", 'r') as f:
            self.config = json.load(f)
        
        # Load label encoder
        with open(f"{{model_path}}/../label_encoder.pkl", 'rb') as f:
            self.label_encoder = pickle.load(f)
        
        self.sample_rate = self.config['yamnet_sample_rate']
        self.duration = self.config['yamnet_duration']
        self.samples = self.config['yamnet_samples']
        
    def preprocess_audio(self, audio_path):
        """Preprocess audio file for prediction."""
        # Load audio
        audio, sr = librosa.load(audio_path, sr=self.sample_rate)
        
        # Normalize
        audio = librosa.util.normalize(audio)
        
        # Ensure correct length
        if len(audio) < self.samples:
            audio = np.pad(audio, (0, self.samples - len(audio)), mode='constant')
        else:
            audio = audio[:self.samples]
        
        return audio.astype(np.float32)
    
    def predict(self, audio_path):
        """Predict raga from audio file."""
        # Preprocess audio
        audio = self.preprocess_audio(audio_path)
        
        # Add batch dimension
        audio_batch = np.expand_dims(audio, axis=0)
        
        # Get prediction
        predictions = self.model(audio_batch)
        probabilities = tf.nn.softmax(predictions).numpy()[0]
        
        # Get class predictions
        predicted_class = np.argmax(probabilities)
        predicted_raga = self.label_encoder.inverse_transform([predicted_class])[0]
        confidence = probabilities[predicted_class]
        
        # Get top-3 predictions
        top_3_indices = np.argsort(probabilities)[-3:][::-1]
        top_3_predictions = []
        
        for idx in top_3_indices:
            raga = self.label_encoder.inverse_transform([idx])[0]
            prob = probabilities[idx]
            top_3_predictions.append({{'raga': raga, 'confidence': float(prob)}})
        
        return {{
            'predicted_raga': predicted_raga,
            'confidence': float(confidence),
            'top_3_predictions': top_3_predictions,
            'all_probabilities': {{
                self.label_encoder.inverse_transform([i])[0]: float(prob) 
                for i, prob in enumerate(probabilities)
            }}
        }}

# Example usage:
# classifier = RagaClassifier("path/to/saved_model")
# result = classifier.predict("path/to/audio.wav")
# print(f"Predicted raga: {{result['predicted_raga']}} ({{result['confidence']:.3f}})")
'''
    
    inference_script_path = os.path.join(save_dir, "inference.py")
    with open(inference_script_path, 'w') as f:
        f.write(inference_script)
    print(f"✅ Inference script saved to: {inference_script_path}")
    
    # Create requirements.txt
    requirements = '''tensorflow>=2.8.0
numpy>=1.21.0
librosa>=0.9.0
scikit-learn>=1.0.0
'''
    
    requirements_path = os.path.join(save_dir, "requirements.txt")
    with open(requirements_path, 'w') as f:
        f.write(requirements)
    print(f"✅ Requirements file saved to: {requirements_path}")
    
    # Create README
    readme_content = f'''# YAMNet Raga Classifier

This model classifies Indian classical music ragas using a fine-tuned YAMNet architecture.

## Model Information
- Base Model: YAMNet (Google)
- Classes: {len(label_encoder.classes_)} ragas
- Input: 16kHz audio, {YAMNET_DURATION}s duration
- Framework: TensorFlow {tf.__version__}

## Classes
{chr(10).join([f"- {cls}" for cls in label_encoder.classes_])}

## Usage

```python
from inference import RagaClassifier

# Load the classifier
classifier = RagaClassifier("saved_model")

# Predict raga
result = classifier.predict("audio_file.wav")
print(f"Predicted: {{result['predicted_raga']}} ({{result['confidence']:.3f}})")
```

## Files
- `saved_model/`: TensorFlow SavedModel
- `config.json`: Model configuration
- `label_encoder.pkl`: Label encoder for class mapping
- `inference.py`: Inference script
- `requirements.txt`: Python dependencies

## Training Details
- Training Date: {datetime.now().strftime("%Y-%m-%d")}
- Model Version: 1.0.0
- Base Architecture: YAMNet + Custom Classification Head
'''
    
    readme_path = os.path.join(save_dir, "README.md")
    with open(readme_path, 'w') as f:
        f.write(readme_content)
    print(f"✅ README saved to: {readme_path}")
    
    return save_dir

def test_saved_model(save_dir, test_audio_path=None):
    """
    Test the saved model to ensure it works correctly.
    """
    print(f"🧪 Testing saved model...")
    
    try:
        # Load the saved model
        model_path = os.path.join(save_dir, "saved_model")
        loaded_model = tf.saved_model.load(model_path)
        
        # Test with dummy input
        dummy_input = tf.random.normal([1, YAMNET_SAMPLES])
        output = loaded_model(dummy_input)
        
        print(f"✅ Model loading test passed!")
        print(f"   Input shape: {dummy_input.shape}")
        print(f"   Output shape: {output.shape}")
        
        # Test inference script if test audio is provided
        if test_audio_path and os.path.exists(test_audio_path):
            print(f"🧪 Testing inference script with: {test_audio_path}")
            
            # This would require the inference script to be imported
            # For now, just verify the files exist
            required_files = ["config.json", "label_encoder.pkl", "inference.py", "README.md"]
            for file in required_files:
                file_path = os.path.join(save_dir, file)
                if os.path.exists(file_path):
                    print(f"   ✅ {file} exists")
                else:
                    print(f"   ❌ {file} missing")
        
        return True
        
    except Exception as e:
        print(f"❌ Model testing failed: {e}")
        return False

# Save the model for deployment
if all(var in locals() for var in ['best_model', 'label_encoder', 'model_config', 'experiment_dir']):
    deployment_dir = os.path.join(experiment_dir, "deployment")
    
    saved_model_dir = save_model_for_deployment(
        best_model, 
        deployment_dir, 
        label_encoder, 
        model_config
    )
    
    # Test the saved model
    test_success = test_saved_model(deployment_dir)
    
    if test_success:
        print(f"\\n🎉 Model successfully saved for deployment!")
        print(f"📁 Deployment directory: {deployment_dir}")
        print(f"📦 Model size: {sum(os.path.getsize(os.path.join(dirpath, filename)) for dirpath, dirnames, filenames in os.walk(deployment_dir) for filename in filenames) / (1024*1024):.1f} MB")
        
        # Create a zip file for easy sharing
        import shutil
        zip_path = os.path.join(experiment_dir, "raga_classifier_model")
        shutil.make_archive(zip_path, 'zip', deployment_dir)
        print(f"📦 Model package created: {zip_path}.zip")
    else:
        print(f"❌ Model saving failed validation")
        
else:
    print("❌ Missing required components for model saving.")
    missing = []
    if 'best_model' not in locals():
        missing.append('best_model')
    if 'label_encoder' not in locals():
        missing.append('label_encoder')
    if 'model_config' not in locals():
        missing.append('model_config')
    if 'experiment_dir' not in locals():
        missing.append('experiment_dir')
    print(f"Missing: {missing}")

## 14. HuggingFace Hub Integration

Uploading the trained model to HuggingFace Hub with proper documentation and metadata.

In [None]:
# HuggingFace Hub configuration
HF_USERNAME = "your-username"  # Replace with your HuggingFace username
HF_MODEL_NAME = "yamnet-raga-classifier"
HF_TOKEN = None  # Will be set during login

def create_model_card(evaluation_results, model_config, training_config):
    """
    Create a comprehensive model card for HuggingFace Hub.
    """
    
    model_card = f'''---
language: en
tags:
- audio-classification
- music
- indian-classical
- raga
- yamnet
- tensorflow
license: mit
datasets:
- custom
metrics:
- accuracy
- f1
model-index:
- name: {HF_MODEL_NAME}
  results:
  - task:
      type: audio-classification
      name: Audio Classification
    dataset:
      type: custom
      name: Indian Classical Raga Dataset
    metrics:
    - type: accuracy
      value: {evaluation_results['accuracy']:.4f}
    - type: f1
      value: {evaluation_results['classification_report']['macro avg']['f1-score']:.4f}
---

# YAMNet Fine-tuned for Indian Classical Raga Classification

This model is a fine-tuned version of [YAMNet](https://tfhub.dev/google/yamnet/1) for classifying Indian classical music ragas.

## Model Description

This model uses Google's YAMNet as a feature extractor and adds a custom classification head for identifying {model_config['num_classes']} different Indian classical ragas.

### Model Architecture
- **Base Model**: YAMNet (Google)
- **Input**: 16kHz audio, {YAMNET_DURATION}s duration ({YAMNET_SAMPLES} samples)
- **Output**: Softmax probabilities over {model_config['num_classes']} raga classes
- **Framework**: TensorFlow {tf.__version__}

### Supported Ragas
{chr(10).join([f"- {cls}" for cls in evaluation_results['class_names']])}

## Training Details

### Training Data
- **Dataset**: Custom Indian Classical Raga Dataset
- **Classes**: {model_config['num_classes']} ragas
- **Audio Format**: 16kHz WAV files
- **Augmentation**: Time stretching, pitch shifting, noise addition

### Training Configuration
- **Optimizer**: Adam
- **Learning Rate**: {training_config['initial_learning_rate']}
- **Batch Size**: Variable (adaptive based on audio segments)
- **Epochs**: {training_config['epochs']}
- **Early Stopping**: Patience of {training_config['patience']} epochs

### Training Results
- **Best Validation Accuracy**: {max([v for v in evaluation_results.get('val_accuracies', [evaluation_results['accuracy']])]):.4f}
- **Test Accuracy**: {evaluation_results['accuracy']:.4f}
- **Test F1-Score (Macro)**: {evaluation_results['classification_report']['macro avg']['f1-score']:.4f}
- **Test F1-Score (Weighted)**: {evaluation_results['classification_report']['weighted avg']['f1-score']:.4f}

## Usage

### Using the Model

```python
import tensorflow as tf
import numpy as np
import librosa

# Load the model
model = tf.saved_model.load("path/to/model")

# Preprocess audio
def preprocess_audio(audio_path):
    audio, sr = librosa.load(audio_path, sr=16000)
    audio = librosa.util.normalize(audio)
    
    # Ensure correct length
    target_samples = {YAMNET_SAMPLES}
    if len(audio) < target_samples:
        audio = np.pad(audio, (0, target_samples - len(audio)), mode='constant')
    else:
        audio = audio[:target_samples]
    
    return audio.astype(np.float32)

# Make prediction
audio = preprocess_audio("your_audio.wav")
audio_batch = np.expand_dims(audio, axis=0)
predictions = model(audio_batch)
probabilities = tf.nn.softmax(predictions).numpy()[0]

# Get predicted class
predicted_class = np.argmax(probabilities)
confidence = probabilities[predicted_class]

print(f"Predicted raga index: {{predicted_class}} (confidence: {{confidence:.3f}})")
```

### Class Mapping
The model outputs integer class indices. Here's the mapping to raga names:

```python
class_names = {evaluation_results['class_names']}
predicted_raga = class_names[predicted_class]
```

## Performance

### Overall Metrics
- **Accuracy**: {evaluation_results['accuracy']:.4f}
- **Macro F1-Score**: {evaluation_results['classification_report']['macro avg']['f1-score']:.4f}
- **Weighted F1-Score**: {evaluation_results['classification_report']['weighted avg']['f1-score']:.4f}

### Per-Class Performance
| Raga | Precision | Recall | F1-Score |
|------|-----------|--------|----------|'''

    # Add per-class metrics to the table
    for class_name in evaluation_results['class_names']:
        metrics = evaluation_results['classification_report'][class_name]
        model_card += f'''
| {class_name} | {metrics['precision']:.3f} | {metrics['recall']:.3f} | {metrics['f1-score']:.3f} |'''

    model_card += f'''

## Limitations

- The model is trained on a specific dataset and may not generalize to all styles of Indian classical music
- Performance may vary with different recording qualities, instruments, or vocal styles
- The model requires audio to be exactly {YAMNET_DURATION}s long and sampled at 16kHz

## Bias and Fairness

This model may exhibit bias based on:
- The composition of the training dataset
- Recording quality and conditions
- Specific musical instruments or vocal styles present in training data
- Regional variations in raga interpretation

## Citation

If you use this model, please cite:

```bibtex
@misc{{yamnet-raga-classifier,
  title={{YAMNet Fine-tuned for Indian Classical Raga Classification}},
  author={{Your Name}},
  year={{2024}},
  url={{https://huggingface.co/{HF_USERNAME}/{HF_MODEL_NAME}}}
}}
```

## License

This model is released under the MIT License.

## Contact

For questions or issues, please contact [your-email@example.com](mailto:your-email@example.com).
'''

    return model_card

def upload_to_huggingface(model_dir, model_card_content, evaluation_results):
    """
    Upload model to HuggingFace Hub with proper documentation.
    """
    global HF_TOKEN
    
    try:
        # Login to HuggingFace (this will prompt for token if not provided)
        if HF_TOKEN is None:
            from huggingface_hub import notebook_login
            notebook_login()
            # Get token after login
            from huggingface_hub import HfFolder
            HF_TOKEN = HfFolder.get_token()
        
        # Initialize HF API
        api = HfApi(token=HF_TOKEN)
        
        # Create repository
        repo_id = f"{HF_USERNAME}/{HF_MODEL_NAME}"
        
        try:
            create_repo(repo_id, token=HF_TOKEN, exist_ok=True)
            print(f"✅ Repository created/verified: {repo_id}")
        except Exception as e:
            print(f"⚠️  Repository creation warning: {e}")
        
        # Upload model files
        model_files = []
        
        # Find all files in the model directory
        for root, dirs, files in os.walk(model_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, model_dir)
                model_files.append((file_path, relative_path))
        
        print(f"📤 Uploading {len(model_files)} files...")
        
        # Upload each file
        for file_path, relative_path in tqdm(model_files, desc="Uploading files"):
            try:
                upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=relative_path,
                    repo_id=repo_id,
                    token=HF_TOKEN
                )
            except Exception as e:
                print(f"⚠️  Failed to upload {relative_path}: {e}")
        
        # Create and upload model card
        model_card_path = os.path.join(model_dir, "README.md")
        with open(model_card_path, 'w', encoding='utf-8') as f:
            f.write(model_card_content)
        
        upload_file(
            path_or_fileobj=model_card_path,
            path_in_repo="README.md",
            repo_id=repo_id,
            token=HF_TOKEN
        )
        
        print(f"✅ Model uploaded successfully!")
        print(f"🔗 Model URL: https://huggingface.co/{repo_id}")
        
        return repo_id
        
    except Exception as e:
        print(f"❌ Upload failed: {e}")
        import traceback
        traceback.print_exc()
        return None

def create_usage_example(repo_id, class_names):
    """
    Create a usage example notebook for the uploaded model.
    """
    
    example_code = f'''# YAMNet Raga Classifier - Usage Example

## Installation

```bash
pip install tensorflow huggingface_hub librosa numpy
```

## Quick Start

```python
import tensorflow as tf
import numpy as np
import librosa
from huggingface_hub import snapshot_download

# Download model from HuggingFace Hub
model_path = snapshot_download(repo_id="{repo_id}")

# Load the model
model = tf.saved_model.load(model_path + "/saved_model")

# Class names
class_names = {class_names}

def predict_raga(audio_path):
    # Load and preprocess audio
    audio, sr = librosa.load(audio_path, sr=16000)
    audio = librosa.util.normalize(audio)
    
    # Ensure correct length
    target_samples = {YAMNET_SAMPLES}
    if len(audio) < target_samples:
        audio = np.pad(audio, (0, target_samples - len(audio)), mode='constant')
    else:
        audio = audio[:target_samples]
    
    # Add batch dimension
    audio_batch = np.expand_dims(audio.astype(np.float32), axis=0)
    
    # Get prediction
    predictions = model(audio_batch)
    probabilities = tf.nn.softmax(predictions).numpy()[0]
    
    # Get top predictions
    top_indices = np.argsort(probabilities)[-3:][::-1]
    
    results = []
    for idx in top_indices:
        results.append({{
            'raga': class_names[idx],
            'confidence': float(probabilities[idx])
        }})
    
    return results

# Example usage
# results = predict_raga("path/to/your/audio.wav")
# print(f"Top prediction: {{results[0]['raga']}} ({{results[0]['confidence']:.3f}})")
```

## Batch Processing

```python
import os
from tqdm import tqdm

def process_audio_folder(folder_path):
    results = {{}}
    
    for filename in tqdm(os.listdir(folder_path)):
        if filename.endswith(('.wav', '.mp3', '.flac')):
            file_path = os.path.join(folder_path, filename)
            try:
                predictions = predict_raga(file_path)
                results[filename] = predictions[0]  # Top prediction
            except Exception as e:
                print(f"Error processing {{filename}}: {{e}}")
    
    return results

# Process a folder of audio files
# results = process_audio_folder("path/to/audio/folder")
```
'''
    
    return example_code

# Upload to HuggingFace Hub
if all(var in locals() for var in ['evaluation_results', 'model_config', 'TRAINING_CONFIG', 'deployment_dir']):
    print("🤗 Preparing to upload to HuggingFace Hub...")
    
    # Update HuggingFace configuration
    print("📝 Please update the HuggingFace configuration:")
    print(f"   HF_USERNAME: {HF_USERNAME}")
    print(f"   HF_MODEL_NAME: {HF_MODEL_NAME}")
    print("\\nIf these are correct, proceed. Otherwise, update the variables above.")
    
    # Create model card
    print("📄 Creating model card...")
    model_card = create_model_card(evaluation_results, model_config, TRAINING_CONFIG)
    
    # Save model card locally
    model_card_path = os.path.join(deployment_dir, "MODEL_CARD.md")
    with open(model_card_path, 'w', encoding='utf-8') as f:
        f.write(model_card)
    print(f"💾 Model card saved locally: {model_card_path}")
    
    # Create usage example
    usage_example = create_usage_example(
        f"{HF_USERNAME}/{HF_MODEL_NAME}", 
        evaluation_results['class_names'].tolist()
    )
    
    usage_example_path = os.path.join(deployment_dir, "USAGE_EXAMPLE.md")
    with open(usage_example_path, 'w', encoding='utf-8') as f:
        f.write(usage_example)
    print(f"💾 Usage example saved: {usage_example_path}")
    
    # Ask for confirmation before uploading
    print("\\n🚀 Ready to upload to HuggingFace Hub!")
    print("⚠️  This will make your model publicly available.")
    print("\\nTo proceed with upload, uncomment and run the following line:")
    print("# repo_id = upload_to_huggingface(deployment_dir, model_card, evaluation_results)")
    
    # Uncomment the line below to actually upload
    # repo_id = upload_to_huggingface(deployment_dir, model_card, evaluation_results)
    
    print("\\n📋 Summary of what's ready for upload:")
    print(f"   Model files: {len([f for f in os.listdir(deployment_dir) if os.path.isfile(os.path.join(deployment_dir, f))])} files")
    print(f"   Model card: ✅")
    print(f"   Usage example: ✅")
    print(f"   Configuration: ✅")
    
else:
    print("❌ Missing required components for HuggingFace upload.")
    missing = []
    if 'evaluation_results' not in locals():
        missing.append('evaluation_results')
    if 'model_config' not in locals():
        missing.append('model_config')
    if 'TRAINING_CONFIG' not in locals():
        missing.append('TRAINING_CONFIG')
    if 'deployment_dir' not in locals():
        missing.append('deployment_dir')
    print(f"Missing: {missing}")

print("\\n🎉 Notebook execution complete!")
print("\\n📊 Final Summary:")
print("✅ Environment setup and GPU configuration")
print("✅ Library installation and imports")
print("✅ Google Drive integration and dataset loading")
print("✅ Audio preprocessing pipeline")
print("✅ Dataset preparation and splitting")
print("✅ YAMNet model loading and feature extraction")
print("✅ Custom classification head architecture")
print("✅ Data augmentation and generators")
print("✅ Training configuration and callbacks")
print("✅ Model training loop")
print("✅ Model evaluation and metrics")
print("✅ Performance visualization")
print("✅ Model export and saving")
print("✅ HuggingFace Hub integration preparation")

print("\\n🚀 Your YAMNet raga classifier is ready for deployment!")
print("📁 Check the experiment directory for all outputs and saved models.")