# Audio Spectrogram Transformer (AST) vs VGGish for Raga Classification

This notebook compares two pretrained audio classification models, fine-tuned on the same dataset, for Indian classical raga identification:

## Models Compared

### 🎵 **Audio Spectrogram Transformer (AST)**
- **Architecture**: Vision Transformer adapted for audio spectrograms
- **Input**: Mel spectrograms (128 mel bins × 1024 time frames)
- **Framework**: PyTorch + Hugging Face Transformers
- **Strengths**: Self-attention mechanisms, long-range dependencies
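
The input size above fixes how many patch tokens the transformer attends over. The following is a minimal sketch of that arithmetic, assuming 16×16 patches with a stride of 10 on both axes and a 10 ms frame shift; these values are assumptions taken from the commonly used AST setup, not from this notebook's code.

```python
# Patch-grid arithmetic for an AST-style input.
# Assumed (not stated in this notebook): 16x16 patches, stride 10, 10 ms frame shift.

def conv_output_size(length: int, kernel: int, stride: int) -> int:
    """Number of positions a kernel visits along one axis (no padding)."""
    return (length - kernel) // stride + 1

NUM_MEL_BINS = 128    # frequency axis of the spectrogram
NUM_FRAMES = 1024     # time axis of the spectrogram
PATCH = 16            # assumed patch edge length
STRIDE = 10           # assumed stride on both axes
FRAME_SHIFT_MS = 10   # assumed hop between spectrogram frames

freq_patches = conv_output_size(NUM_MEL_BINS, PATCH, STRIDE)
time_patches = conv_output_size(NUM_FRAMES, PATCH, STRIDE)
num_patches = freq_patches * time_patches
clip_seconds = NUM_FRAMES * FRAME_SHIFT_MS / 1000

print(freq_patches, time_patches, num_patches, clip_seconds)
```

Under these assumptions each clip covers roughly ten seconds of audio, so longer raga recordings are truncated or need to be split into segments before classification.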

### 🎵 **VGGish**
- **Architecture**: CNN-based feature extractor (VGG-like)
- **Input**: Audio waveform → VGGish embeddings (128-dim)
- **Framework**: TensorFlow + TensorFlow Hub
- **Strengths**: Proven audio features, computational efficiency

## Notebook Objectives

1. 🔧 **Fair Comparison**: Same dataset, evaluation metrics, and preprocessing
2. 📊 **Comprehensive Analysis**: Training curves, confusion matrices, performance metrics
3. 🚀 **Production Ready**: Model saving, documentation, HuggingFace Hub deployment
4. 📝 **Reproducible**: Clear code structure with proper error handling

## Requirements

- GPU-enabled environment (Google Colab Pro recommended)
- Python 3.8+
- PyTorch with CUDA support
- TensorFlow 2.x
- Hugging Face transformers library

## 1. Environment Setup and Dependencies

Setting up the environment for both PyTorch (AST) and TensorFlow (VGGish) models.
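
Because the two models depend on different frameworks, it can help to check what is already installed before running the install cell. This is a small sketch using only the standard library; the `REQUIRED` list and the `missing_packages` helper are illustrative names, not part of the notebook.

```python
from importlib.util import find_spec

# Import names (not pip names) that the notebook's cells rely on.
REQUIRED = [
    "torch", "torchaudio", "transformers", "tensorflow",
    "tensorflow_hub", "librosa", "soundfile", "sklearn",
]

def missing_packages(names):
    """Return the import names that cannot be located in this environment."""
    return [name for name in names if find_spec(name) is None]

missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

The check only locates packages; it does not import them, so it says nothing about version compatibility between PyTorch, TensorFlow, and CUDA.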

In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers datasets accelerate
!pip install tensorflow tensorflow-hub
!pip install librosa soundfile audiomentations
!pip install scikit-learn matplotlib seaborn plotly
!pip install huggingface_hub wandb tensorboard
!pip install ipywidgets tqdm pandas numpy

In [None]:
import os
import warnings
import json
import pickle
import random
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import logging

# Suppress warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
import torchaudio.transforms as T

# HuggingFace transformers
from transformers import (
    ASTForAudioClassification, ASTFeatureExtractor,
    TrainingArguments, Trainer, EarlyStoppingCallback,
    AutoConfig, AutoModel, AutoFeatureExtractor
)
from datasets import Dataset as HFDataset

# TensorFlow imports
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Audio processing
import librosa
import soundfile as sf
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift

# Data science libraries
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.utils.class_weight import compute_class_weight

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Utilities
from tqdm.auto import tqdm
from huggingface_hub import HfApi, create_repo, upload_file

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("📦 All packages imported successfully!")

# GPU Configuration
def setup_device_config():
    """Configure PyTorch and TensorFlow for optimal GPU usage."""
    
    # PyTorch CUDA setup
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"🚀 PyTorch CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"   CUDA Version: {torch.version.cuda}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        
        # Release cached GPU memory left over from earlier runs
        torch.cuda.empty_cache()
    else:
        device = torch.device('cpu')
        print("⚠️  PyTorch: No CUDA available, using CPU")
    
    # TensorFlow GPU setup
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            print(f"✅ TensorFlow: Found {len(gpus)} GPU(s), memory growth enabled")
        except RuntimeError as e:
            print(f"❌ TensorFlow GPU setup error: {e}")
    else:
        print("⚠️  TensorFlow: No GPU found, using CPU")
    
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        torch.cuda.manual_seed_all(42)
    tf.random.set_seed(42)
    
    print("🎲 Random seeds set for reproducibility")
    
    return device

# Setup device
DEVICE = setup_device_config()

# Global configuration
CONFIG = {
    'batch_size': 16,
    'learning_rate_ast': 5e-5,
    'learning_rate_vggish': 1e-3,
    'epochs': 50,
    'patience': 10,
    'warmup_steps': 500,
    'weight_decay': 0.01,
    'gradient_accumulation_steps': 2,
    'max_grad_norm': 1.0,
    'label_smoothing': 0.1,
    
    # Audio configuration
    'sample_rate': 16000,
    'ast_max_length': 1024,  # AST expects 1024 time frames
    'ast_num_mel_bins': 128,
    'vggish_embedding_dim': 128,
    
    # Augmentation
    'augmentation_prob': 0.3,
    'time_stretch_range': [0.8, 1.2],
    'pitch_shift_range': [-2, 2],
    'noise_level': 0.005
}

print(f"⚙️  Configuration loaded: {CONFIG['batch_size']} batch size, {CONFIG['epochs']} epochs")

## 2. Data Loading and Preparation

Loading the raga dataset and preparing separate preprocessing pipelines for AST and VGGish models.
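
The discovery code in this section assumes one sub-folder per raga, with the folder name used as the class label. The sketch below builds a tiny throwaway directory to illustrate that layout; the raga names and the `count_files_per_class` helper are hypothetical and only serve the example.

```python
import tempfile
from collections import Counter
from pathlib import Path

AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".m4a", ".ogg"}

def count_files_per_class(root: Path) -> Counter:
    """Count audio files in each immediate sub-folder of root."""
    counts = Counter()
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for item in class_dir.iterdir():
            if item.suffix.lower() in AUDIO_EXTENSIONS:
                counts[class_dir.name] += 1
    return counts

# Build a throwaway dataset with the expected layout: <root>/<raga>/<clip>
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for raga, n_clips in [("Bhairavi", 3), ("Yaman", 2)]:  # hypothetical classes
        (root / raga).mkdir()
        for i in range(n_clips):
            (root / raga / f"clip_{i}.wav").touch()
    (root / "Yaman" / "notes.txt").touch()  # non-audio files are ignored
    demo_counts = count_files_per_class(root)

print(dict(demo_counts))
```

Running a count like this on the real dataset before training is a quick way to spot empty or badly imbalanced classes, which matter for the stratified splits used later.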

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Dataset paths
DATASET_PATH = "/content/drive/MyDrive/Raga_Dataset"  # Update this path
OUTPUT_PATH = "/content/drive/MyDrive/AST_VGGish_Output"
MODELS_PATH = "/content/drive/MyDrive/AST_VGGish_Models"

# Create output directories
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(MODELS_PATH, exist_ok=True)

print(f"📁 Dataset path: {DATASET_PATH}")
print(f"💾 Models path: {MODELS_PATH}")
print(f"📊 Output path: {OUTPUT_PATH}")

def discover_audio_dataset(dataset_path: str) -> pd.DataFrame:
    """
    Discover and catalog all audio files in the dataset.
    """
    audio_files = []
    supported_formats = ['.wav', '.mp3', '.flac', '.m4a', '.ogg']
    
    print("🔍 Discovering audio files...")
    
    for raga_folder in os.listdir(dataset_path):
        raga_path = os.path.join(dataset_path, raga_folder)
        
        if not os.path.isdir(raga_path):
            continue
            
        for audio_file in os.listdir(raga_path):
            file_path = os.path.join(raga_path, audio_file)
            file_ext = os.path.splitext(audio_file)[1].lower()
            
            if file_ext in supported_formats:
                try:
                    # Get basic audio info
                    info = sf.info(file_path)
                    
                    audio_files.append({
                        'file_path': file_path,
                        'filename': audio_file,
                        'raga': raga_folder,
                        'duration': info.duration,
                        'sample_rate': info.samplerate,
                        'channels': info.channels,
                        'format': file_ext[1:],
                        'file_size_mb': os.path.getsize(file_path) / (1024 * 1024)
                    })
                except Exception as e:
                    print(f"⚠️  Error reading {file_path}: {e}")
    
    df = pd.DataFrame(audio_files)
    
    print(f"✅ Found {len(df)} audio files")
    print(f"📊 Raga classes: {df['raga'].nunique()}")
    print(f"🎵 Total duration: {df['duration'].sum()/3600:.2f} hours")
    
    # Display class distribution
    print("\n📈 Class distribution:")
    class_counts = df['raga'].value_counts()
    for raga, count in class_counts.items():
        percentage = (count / len(df)) * 100
        print(f"   {raga}: {count} files ({percentage:.1f}%)")
    
    return df

# Discover dataset
if os.path.exists(DATASET_PATH):
    metadata_df = discover_audio_dataset(DATASET_PATH)
    
    # Create label encoder
    label_encoder = LabelEncoder()
    metadata_df['label'] = label_encoder.fit_transform(metadata_df['raga'])
    
    # Save label encoder
    label_encoder_path = os.path.join(OUTPUT_PATH, 'label_encoder.pkl')
    with open(label_encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)
    
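    print(f"\n💾 Label encoder saved: {label_encoder_path}")
    print(f"🏷️  Classes: {list(label_encoder.classes_)}")
    
    # Class names in label order; the evaluation and inference helpers rely on CLASS_NAMES
    CLASS_NAMES = list(label_encoder.classes_)
    NUM_CLASSES = len(CLASS_NAMES)
    print(f"🔢 Number of classes: {NUM_CLASSES}")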
    
else:
    print(f"❌ Dataset not found at {DATASET_PATH}")
    print("Please update the DATASET_PATH variable with your correct Google Drive path")

In [None]:
# Split dataset
def create_data_splits(df: pd.DataFrame, test_size: float = 0.2, val_size: float = 0.2, random_state: int = 42):
    """Create stratified train/validation/test splits."""
    
    # First split: train+val vs test
    train_val, test = train_test_split(
        df, test_size=test_size, stratify=df['label'], random_state=random_state
    )
    
    # Second split: train vs val
    val_size_adjusted = val_size / (1 - test_size)
    train, val = train_test_split(
        train_val, test_size=val_size_adjusted, stratify=train_val['label'], random_state=random_state
    )
    
    print(f"📊 Dataset splits:")
    print(f"   Train: {len(train)} files ({len(train)/len(df)*100:.1f}%)")
    print(f"   Validation: {len(val)} files ({len(val)/len(df)*100:.1f}%)")
    print(f"   Test: {len(test)} files ({len(test)/len(df)*100:.1f}%)")
    
    return train, val, test

# Audio preprocessing utilities
class AudioAugmentation:
    """Audio augmentation pipeline."""
    
    def __init__(self, sample_rate: int = 16000):
        self.sample_rate = sample_rate
        self.augment = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=CONFIG['noise_level'], p=0.3),
            TimeStretch(min_rate=CONFIG['time_stretch_range'][0], 
                       max_rate=CONFIG['time_stretch_range'][1], p=0.3),
            PitchShift(min_semitones=CONFIG['pitch_shift_range'][0],
                      max_semitones=CONFIG['pitch_shift_range'][1], p=0.3),
        ])
    
    def __call__(self, audio: np.ndarray, apply_augmentation: bool = True) -> np.ndarray:
        if apply_augmentation and random.random() < CONFIG['augmentation_prob']:
            try:
                return self.augment(samples=audio, sample_rate=self.sample_rate)
            except Exception as e:
                logger.warning(f"Augmentation failed: {e}")
        return audio

def load_and_preprocess_audio(file_path: str, target_sr: int = CONFIG['sample_rate']) -> np.ndarray:
    """Load and preprocess audio file."""
    try:
        audio, sr = librosa.load(file_path, sr=target_sr)
        audio = librosa.util.normalize(audio)
        return audio
    except Exception as e:
        logger.error(f"Error loading {file_path}: {e}")
        return np.array([])

def prepare_ast_spectrogram(audio: np.ndarray, 
                           feature_extractor: ASTFeatureExtractor) -> Dict[str, torch.Tensor]:
    """Prepare mel spectrogram for AST model."""
    try:
        # AST feature extractor expects audio as list or numpy array
        inputs = feature_extractor(
            audio, 
            sampling_rate=CONFIG['sample_rate'],
            return_tensors="pt",
            max_length=CONFIG['ast_max_length'],
            truncation=True,
            padding=True
        )
        return inputs
    except Exception as e:
        logger.error(f"AST preprocessing error: {e}")
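        # Fall back to an all-zero spectrogram with the shape AST expects: (1, frames, mel bins)
        return {"input_values": torch.zeros(1, CONFIG['ast_max_length'], CONFIG['ast_num_mel_bins'])}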

def extract_vggish_embeddings(audio: np.ndarray, vggish_model) -> np.ndarray:
    """Extract VGGish embeddings from audio."""
    try:
        # The TF Hub VGGish model takes a 1-D float32 waveform
        # (16 kHz mono, values in [-1, 1]) and returns one 128-dim embedding per 0.96 s
        audio = np.asarray(audio, dtype=np.float32).reshape(-1)
        
        # Extract embeddings
        embeddings = vggish_model(audio)
        
        # Average embeddings if multiple time steps
        if len(embeddings.shape) > 1 and embeddings.shape[0] > 1:
            embeddings = tf.reduce_mean(embeddings, axis=0, keepdims=True)
        
        return embeddings.numpy()
    
    except Exception as e:
        logger.error(f"VGGish preprocessing error: {e}")
        return np.zeros((1, CONFIG['vggish_embedding_dim']))

# Create data splits
if 'metadata_df' in locals() and len(metadata_df) > 0:
    train_df, val_df, test_df = create_data_splits(metadata_df)
    
    # Initialize augmentation
    audio_augmenter = AudioAugmentation()
    
    print(f"✅ Data splits and preprocessing pipeline ready!")
    print(f"📊 Augmentation probability: {CONFIG['augmentation_prob']}")
    
else:
    print("❌ No metadata available. Please run the previous cell first.")

## 3. Audio Spectrogram Transformer (AST) Implementation

Loading and fine-tuning the AST model for raga classification using PyTorch and Hugging Face transformers.
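
When the classification head is replaced for a new label set, it is convenient to keep an explicit mapping between class indices and raga names. The sketch below builds such a mapping with plain Python; the raga names are placeholders and `build_label_maps` is a hypothetical helper. Sorting mirrors the alphabetical order that scikit-learn's `LabelEncoder` assigns.

```python
def build_label_maps(class_names):
    """Return (id2label, label2id) with ids assigned in sorted name order."""
    ordered = sorted(set(class_names))
    id2label = {index: name for index, name in enumerate(ordered)}
    label2id = {name: index for index, name in id2label.items()}
    return id2label, label2id

# Placeholder raga names; duplicates are collapsed before ids are assigned.
id2label, label2id = build_label_maps(["Yaman", "Bhairavi", "Todi", "Yaman"])

print(id2label)
print(label2id)
```

Dictionaries of this form can be passed as the `id2label` and `label2id` keyword arguments of `from_pretrained`, so that a saved model reports raga names rather than bare indices.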

In [None]:
# AST Model Configuration
AST_MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"  # Pre-trained AST model

class ASTDataset(Dataset):
    """PyTorch Dataset for AST model."""
    
    def __init__(self, dataframe: pd.DataFrame, feature_extractor: ASTFeatureExtractor, 
                 augment: bool = False, augmenter: Optional[AudioAugmentation] = None):
        self.df = dataframe.reset_index(drop=True)
        self.feature_extractor = feature_extractor
        self.augment = augment
        self.augmenter = augmenter
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load audio
        audio = load_and_preprocess_audio(row['file_path'])
        
        if len(audio) == 0:
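            # Return an all-zero spectrogram (frames x mel bins) if loading failed,
            # so the sample still batches with real inputs
            return {
                'input_values': torch.zeros(CONFIG['ast_max_length'], CONFIG['ast_num_mel_bins']),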
                'labels': torch.tensor(row['label'], dtype=torch.long)
            }
        
        # Apply augmentation if enabled
        if self.augment and self.augmenter:
            audio = self.augmenter(audio, apply_augmentation=True)
        
        # Convert to AST input format
        inputs = prepare_ast_spectrogram(audio, self.feature_extractor)
        
        # Prepare output
        result = {
            'input_values': inputs['input_values'].squeeze(0),  # Remove batch dimension
            'labels': torch.tensor(row['label'], dtype=torch.long)
        }
        
        return result

def setup_ast_model(num_classes: int):
    """Setup AST model and feature extractor."""
    
    print(f"🔄 Loading AST model: {AST_MODEL_NAME}")
    
    try:
        # Load feature extractor
        feature_extractor = ASTFeatureExtractor.from_pretrained(AST_MODEL_NAME)
        
        # Load model with custom classification head
        model = ASTForAudioClassification.from_pretrained(
            AST_MODEL_NAME,
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        
        # Move to device
        model = model.to(DEVICE)
        
        print(f"✅ AST model loaded successfully!")
        print(f"   Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
        print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.1f}M")
        
        return model, feature_extractor
        
    except Exception as e:
        print(f"❌ Failed to load AST model: {e}")
        return None, None

def create_ast_data_loaders(train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: pd.DataFrame,
                           feature_extractor: ASTFeatureExtractor, batch_size: int = CONFIG['batch_size']):
    """Create PyTorch data loaders for AST training."""
    
    # Create datasets
    train_dataset = ASTDataset(train_df, feature_extractor, augment=True, augmenter=audio_augmenter)
    val_dataset = ASTDataset(val_df, feature_extractor, augment=False)
    test_dataset = ASTDataset(test_df, feature_extractor, augment=False)
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    print(f"📊 AST Data loaders created:")
    print(f"   Train: {len(train_loader)} batches ({len(train_dataset)} samples)")
    print(f"   Validation: {len(val_loader)} batches ({len(val_dataset)} samples)")
    print(f"   Test: {len(test_loader)} batches ({len(test_dataset)} samples)")
    
    return train_loader, val_loader, test_loader

# Setup AST model and data loaders
# (use globals(): locals() inside a generator expression only sees the generator's own scope)
if all(var in globals() for var in ['NUM_CLASSES', 'train_df', 'val_df', 'test_df']):
    ast_model, ast_feature_extractor = setup_ast_model(NUM_CLASSES)
    
    if ast_model is not None:
        # Create data loaders
        ast_train_loader, ast_val_loader, ast_test_loader = create_ast_data_loaders(
            train_df, val_df, test_df, ast_feature_extractor
        )
        
        # Test data loading
        print("\n🧪 Testing AST data loading...")
        try:
            batch = next(iter(ast_train_loader))
            print(f"   Input shape: {batch['input_values'].shape}")
            print(f"   Labels shape: {batch['labels'].shape}")
            print(f"   Sample input range: [{batch['input_values'].min():.3f}, {batch['input_values'].max():.3f}]")
            print("✅ AST data loading test successful!")
        except Exception as e:
            print(f"❌ AST data loading test failed: {e}")
    
else:
    print("❌ Missing required variables for AST setup. Please run previous cells first.")

In [None]:
# AST Training Configuration
def setup_ast_training(model, train_loader, val_loader, output_dir: str):
    """Setup AST training with Hugging Face Trainer."""
    
    # Create output directory for AST
    ast_output_dir = os.path.join(output_dir, "ast_model")
    os.makedirs(ast_output_dir, exist_ok=True)
    
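    # Convert PyTorch DataLoaders to HuggingFace Dataset format.
    # Note: this materializes every spectrogram in memory, and augmentation is applied
    # once during conversion rather than freshly each epoch.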
    def create_hf_dataset(loader):
        all_inputs = []
        all_labels = []
        
        print(f"Converting data loader to HuggingFace format...")
        for batch in tqdm(loader):
            all_inputs.extend(batch['input_values'].cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())
        
        return HFDataset.from_dict({
            'input_values': all_inputs,
            'labels': all_labels
        })
    
    # Convert datasets
    print("🔄 Converting train dataset...")
    hf_train_dataset = create_hf_dataset(train_loader)
    print("🔄 Converting validation dataset...")
    hf_val_dataset = create_hf_dataset(val_loader)
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=ast_output_dir,
        num_train_epochs=CONFIG['epochs'],
        per_device_train_batch_size=CONFIG['batch_size'],
        per_device_eval_batch_size=CONFIG['batch_size'],
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        learning_rate=CONFIG['learning_rate_ast'],
        weight_decay=CONFIG['weight_decay'],
        warmup_steps=CONFIG['warmup_steps'],
        max_grad_norm=CONFIG['max_grad_norm'],
        
        # Evaluation and logging
        # Note: recent transformers releases rename `evaluation_strategy` to `eval_strategy`
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        logging_dir=os.path.join(ast_output_dir, "logs"),
        logging_steps=50,
        
        # Early stopping and checkpointing
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
        save_total_limit=3,
        
        # Performance optimization
        dataloader_pin_memory=True,
        dataloader_num_workers=2,
        fp16=torch.cuda.is_available(),  # Enable mixed precision if CUDA available
        
        # Other settings
        seed=42,
        report_to=["tensorboard"],
        remove_unused_columns=False,
    )
    
    # Metrics computation
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        
        # Calculate metrics
        accuracy = (predictions == labels).mean()
        f1_macro = f1_score(labels, predictions, average='macro')
        f1_weighted = f1_score(labels, predictions, average='weighted')
        
        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted
        }
    
    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=hf_train_dataset,
        eval_dataset=hf_val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=CONFIG['patience'])]
    )
    
    print(f"✅ AST Trainer configured!")
    print(f"📁 Output directory: {ast_output_dir}")
    print(f"🎯 Training configuration:")
    print(f"   Learning rate: {CONFIG['learning_rate_ast']}")
    print(f"   Batch size: {CONFIG['batch_size']} (effective: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']})")
    print(f"   Epochs: {CONFIG['epochs']}")
    print(f"   Warmup steps: {CONFIG['warmup_steps']}")
    
    return trainer, ast_output_dir

def train_ast_model(trainer, output_dir: str):
    """Train the AST model."""
    
    print("🚀 Starting AST model training...")
    
    try:
        # Start training
        train_result = trainer.train()
        
        # Save the model
        trainer.save_model()
        trainer.save_state()
        
        # Save training metrics
        metrics_path = os.path.join(output_dir, "training_metrics.json")
        with open(metrics_path, 'w') as f:
            json.dump(train_result.metrics, f, indent=2)
        
        print(f"✅ AST training completed!")
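        # Format the loss only when it is present; applying ':.4f' to the 'N/A' string would raise
        final_loss = train_result.metrics.get('train_loss')
        loss_text = f"{final_loss:.4f}" if final_loss is not None else "N/A"
        print(f"📊 Final training loss: {loss_text}")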
        print(f"💾 Model saved to: {output_dir}")
        
        return train_result
        
    except Exception as e:
        print(f"❌ AST training failed: {e}")
        import traceback
        traceback.print_exc()
        return None

# Setup and train AST model
# (use globals(): locals() inside a generator expression only sees the generator's own scope)
if all(var in globals() for var in ['ast_model', 'ast_train_loader', 'ast_val_loader']):
    
    # Setup training
    ast_trainer, ast_output_dir = setup_ast_training(
        ast_model, ast_train_loader, ast_val_loader, MODELS_PATH
    )
    
    print("\n" + "="*50)
    print("🎯 Ready to train AST model!")
    print("⚠️  This will take significant time. Monitor progress in TensorBoard:")
    print(f"   tensorboard --logdir {ast_output_dir}/logs")
    print("="*50)
    
    # Uncomment the following line to start training
    # ast_train_result = train_ast_model(ast_trainer, ast_output_dir)
    print("\n💡 To start training, uncomment the line above and run the cell.")
    
else:
    print("❌ AST model setup incomplete. Please run previous cells first.")

In [None]:
# AST Model Evaluation and Inference
def evaluate_ast_model(trainer, test_loader, output_dir: str):
    """Evaluate the trained AST model on test data."""
    
    print("🧪 Evaluating AST model on test set...")
    
    # Convert test loader to HF dataset format
    test_inputs = []
    test_labels = []
    
    for batch in tqdm(test_loader, desc="Preparing test data"):
        test_inputs.extend(batch['input_values'].cpu().numpy())
        test_labels.extend(batch['labels'].cpu().numpy())
    
    hf_test_dataset = HFDataset.from_dict({
        'input_values': test_inputs,
        'labels': test_labels
    })
    
    # Evaluate
    eval_results = trainer.evaluate(eval_dataset=hf_test_dataset)
    
    # Generate detailed classification report
    predictions = trainer.predict(hf_test_dataset)
    y_pred = np.argmax(predictions.predictions, axis=1)
    y_true = predictions.label_ids
    
    # Classification report
    report = classification_report(
        y_true, y_pred, 
        target_names=CLASS_NAMES, 
        output_dict=True
    )
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Save evaluation results
    eval_file = os.path.join(output_dir, "ast_evaluation_results.json")
    with open(eval_file, 'w') as f:
        json.dump({
            'eval_metrics': eval_results,
            'classification_report': report,
            'confusion_matrix': cm.tolist(),
            'class_names': CLASS_NAMES
        }, f, indent=2)
    
    # Visualize results
    plt.figure(figsize=(15, 5))
    
    # Plot 1: Confusion Matrix
    plt.subplot(1, 3, 1)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.title('AST Model - Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    
    # Plot 2: Per-class F1 scores
    plt.subplot(1, 3, 2)
    f1_scores = [report[cls]['f1-score'] for cls in CLASS_NAMES]
    plt.bar(range(len(CLASS_NAMES)), f1_scores, color='skyblue')
    plt.title('AST Model - F1 Score per Class')
    plt.xlabel('Raga Classes')
    plt.ylabel('F1 Score')
    plt.xticks(range(len(CLASS_NAMES)), CLASS_NAMES, rotation=45)
    plt.ylim(0, 1)
    
    # Plot 3: Precision vs Recall
    plt.subplot(1, 3, 3)
    precisions = [report[cls]['precision'] for cls in CLASS_NAMES]
    recalls = [report[cls]['recall'] for cls in CLASS_NAMES]
    plt.scatter(recalls, precisions, color='red', alpha=0.7)
    for i, cls in enumerate(CLASS_NAMES):
        plt.annotate(cls, (recalls[i], precisions[i]), 
                    xytext=(5, 5), textcoords='offset points', fontsize=8)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('AST Model - Precision vs Recall')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'ast_evaluation_plots.png'), dpi=300, bbox_inches='tight')
    plt.show()
    
    print("\n📊 AST Model Evaluation Results:")
    print(f"   Accuracy: {eval_results['eval_accuracy']:.4f}")
    print(f"   F1 (Macro): {eval_results['eval_f1_macro']:.4f}")
    print(f"   F1 (Weighted): {eval_results['eval_f1_weighted']:.4f}")
    print(f"💾 Results saved to: {eval_file}")
    
    return eval_results, report, cm

def predict_raga_ast(model, processor, audio_path: str, device):
    """Predict raga for a single audio file using AST model."""
    
    try:
        # Load audio and peak-normalize it, matching load_and_preprocess_audio used in training
        waveform, sr = librosa.load(audio_path, sr=CONFIG['sample_rate'])
        waveform = librosa.util.normalize(waveform)
        
        # The AST feature extractor pads or truncates to its configured number of frames
        # (CONFIG['ast_max_length']), so no manual length handling is needed here
        inputs = processor(waveform, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Predict
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()
        
        return {
            'predicted_raga': CLASS_NAMES[predicted_class],
            'confidence': confidence,
            'all_probabilities': {
                CLASS_NAMES[i]: prob.item() 
                for i, prob in enumerate(probabilities[0])
            }
        }
        
    except Exception as e:
        print(f"❌ Prediction failed for {audio_path}: {e}")
        return None

# Example evaluation (when model is trained)
print("\n" + "="*50)
print("📊 AST Model Evaluation Ready")
print("="*50)
print("\n💡 After training completes, run evaluation with:")
print("ast_eval_results, ast_report, ast_cm = evaluate_ast_model(ast_trainer, ast_test_loader, ast_output_dir)")
print("\n🎯 For single audio prediction:")
print("prediction = predict_raga_ast(ast_model, ast_feature_extractor, 'path/to/audio.wav', DEVICE)")

## 5. VGGish Model Pipeline

### VGGish Overview
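VGGish is a CNN-based audio feature extractor pre-trained on a large YouTube dataset (a preliminary version of what became YouTube-8M). It operates on log-mel spectrogram patches that each cover 0.96 s of audio and outputs one 128-dimensional embedding per patch. We'll use TensorFlow Hub's VGGish model, which accepts a 16 kHz mono waveform and computes the log-mel features internally, and add a classification head for raga classification.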

**Key Features:**
- Pre-trained CNN feature extractor
- Fixed 128-dimensional embeddings
- Mel-spectrogram input preprocessing
- Suitable for transfer learning
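
To make the fixed-size embeddings concrete, the sketch below estimates how many 0.96 s examples a clip yields and averages per-example vectors into one clip-level vector. It assumes the reference framing parameters (25 ms window, 10 ms hop, 96 frames per example, no overlap); the helper names are hypothetical, and the exact count produced by a given implementation may differ by one at clip boundaries.

```python
SAMPLE_RATE = 16000
WINDOW_SAMPLES = round(SAMPLE_RATE * 0.025)   # 25 ms STFT window
HOP_SAMPLES = round(SAMPLE_RATE * 0.010)      # 10 ms STFT hop
FRAMES_PER_EXAMPLE = round(0.96 / 0.010)      # 0.96 s of frames per example

def num_vggish_examples(num_samples: int) -> int:
    """Estimate how many non-overlapping 0.96 s examples fit in a clip."""
    if num_samples < WINDOW_SAMPLES:
        return 0
    num_frames = 1 + (num_samples - WINDOW_SAMPLES) // HOP_SAMPLES
    if num_frames < FRAMES_PER_EXAMPLE:
        return 0
    return 1 + (num_frames - FRAMES_PER_EXAMPLE) // FRAMES_PER_EXAMPLE

def mean_pool(embeddings):
    """Average a list of equal-length vectors into one clip-level vector."""
    count = len(embeddings)
    return [sum(column) / count for column in zip(*embeddings)]

ten_second_examples = num_vggish_examples(10 * SAMPLE_RATE)
pooled = mean_pool([[1.0, 3.0], [3.0, 5.0]])  # toy 2-dim vectors, not real embeddings

print(ten_second_examples, pooled)
```

Mean pooling discards the order of the examples, which is a deliberate simplification; a sequence model over the per-example embeddings would be an alternative if temporal structure in the raga matters.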

In [None]:
# VGGish Configuration and Preprocessing
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# VGGish specific configuration
VGGISH_CONFIG = {
    'model_url': 'https://tfhub.dev/google/vggish/1',
    'sample_rate': 16000,
    'stft_window_seconds': 0.025,
    'stft_hop_seconds': 0.010,
    'mel_bands': 64,
    'mel_min_hz': 125,
    'mel_max_hz': 7500,
    'log_offset': 0.01,
    'example_window_seconds': 0.96,
    'example_hop_seconds': 0.96,
    'embedding_size': 128
}

print("🔧 VGGish Configuration:")
for key, value in VGGISH_CONFIG.items():
    print(f"   {key}: {value}")

def vggish_preprocess_audio(waveform, sample_rate=VGGISH_CONFIG['sample_rate']):
    """
    Preprocess audio for VGGish model.
    Converts audio to log-mel spectrograms as expected by VGGish.
    """
    # Convert to mono if multi-channel (assumes a channels-first layout)
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=0)
    
    # Convert to float32 and peak-normalize
    waveform = waveform.astype(np.float32)
    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
    
    # Parameters for STFT
    window_length_samples = int(round(sample_rate * VGGISH_CONFIG['stft_window_seconds']))
    hop_length_samples = int(round(sample_rate * VGGISH_CONFIG['stft_hop_seconds']))
    fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
    
    # Compute STFT
    stft = librosa.stft(waveform,
                       hop_length=hop_length_samples,
                       win_length=window_length_samples,
                       n_fft=fft_length,
                       center=False)
    
    # Convert to magnitude spectrogram
    magnitude_spectrogram = np.abs(stft)
    
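    # Convert to mel scale. The reference VGGish front end applies an unnormalized,
    # HTK-style mel filterbank to the magnitude (not power) spectrogram.
    mel_spectrogram = librosa.feature.melspectrogram(
        S=magnitude_spectrogram,
        sr=sample_rate,
        n_mels=VGGISH_CONFIG['mel_bands'],
        fmin=VGGISH_CONFIG['mel_min_hz'],
        fmax=VGGISH_CONFIG['mel_max_hz'],
        htk=True,
        norm=None
    )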
    
    # Convert to log scale
    log_mel_spectrogram = np.log(mel_spectrogram + VGGISH_CONFIG['log_offset'])
    
    # Extract fixed-size windows
    window_length_frames = int(round(VGGISH_CONFIG['example_window_seconds'] / VGGISH_CONFIG['stft_hop_seconds']))
    hop_length_frames = int(round(VGGISH_CONFIG['example_hop_seconds'] / VGGISH_CONFIG['stft_hop_seconds']))
    
    # Extract examples (patches)
    examples = []
    for i in range(0, log_mel_spectrogram.shape[1] - window_length_frames + 1, hop_length_frames):
        examples.append(log_mel_spectrogram[:, i:i + window_length_frames])
    
    if len(examples) == 0:
        # If audio is too short, pad the spectrogram
        padded_spectrogram = np.pad(log_mel_spectrogram, 
                                  ((0, 0), (0, window_length_frames - log_mel_spectrogram.shape[1])), 
                                  mode='constant')
        examples = [padded_spectrogram]
    
    return np.array(examples)

# VGGish Dataset Class
class VGGishDataset:
    """Dataset class for VGGish model preprocessing."""
    
    def __init__(self, audio_files, labels, is_training=True):
        self.audio_files = audio_files
        self.labels = labels
        self.is_training = is_training
        
    def __len__(self):
        return len(self.audio_files)
    
    def preprocess_audio(self, audio_path):
        """Load and preprocess audio for VGGish."""
        try:
            # Load audio
            waveform, sr = librosa.load(audio_path, sr=VGGISH_CONFIG['sample_rate'])
            
            # Apply preprocessing for VGGish
            examples = vggish_preprocess_audio(waveform, sr)
            
            # During training, pick one random 0.96 s patch per clip (a simple augmentation)
            if self.is_training and len(examples) > 1:
                idx = np.random.randint(0, len(examples))
                return examples[idx]
            else:
                # During inference, average all patches element-wise into one fixed-size input
                return np.mean(examples, axis=0)
                
        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            # Return zero spectrogram if loading fails
            return np.zeros((VGGISH_CONFIG['mel_bands'], 
                           int(round(VGGISH_CONFIG['example_window_seconds'] / VGGISH_CONFIG['stft_hop_seconds']))))
    
    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        label = self.labels[idx]
        
        # Preprocess audio
        spectrogram = self.preprocess_audio(audio_path)
        
        return {
            'spectrogram': spectrogram.astype(np.float32),
            'label': label
        }

def create_vggish_data_generators(train_files, train_labels, val_files, val_labels, test_files, test_labels, batch_size=32):
    """Create TensorFlow data generators for VGGish training."""
    
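    def data_generator(files, labels, is_training=True):
        """Generator function for TensorFlow dataset."""
        dataset = VGGishDataset(files, labels, is_training)
        
        # Visit training clips in a fresh random order each epoch; keep evaluation order fixed
        indices = np.random.permutation(len(dataset)) if is_training else range(len(dataset))
        for i in indices:
            item = dataset[i]
            yield item['spectrogram'], item['label']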
    
    # Define output signature
    output_signature = (
        tf.TensorSpec(shape=(VGGISH_CONFIG['mel_bands'], 
                           int(round(VGGISH_CONFIG['example_window_seconds'] / VGGISH_CONFIG['stft_hop_seconds']))), 
                     dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32)
    )
    
    # Create datasets
    train_dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(train_files, train_labels, True),
        output_signature=output_signature
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    val_dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(val_files, val_labels, False),
        output_signature=output_signature
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    test_dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(test_files, test_labels, False),
        output_signature=output_signature
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
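    # Derive batch counts from the file lists: iterating a dataset just to count
    # its batches would load and preprocess every audio file
    print("✅ VGGish data generators created!")
    print(f"   Train batches: {int(np.ceil(len(train_files) / batch_size))}")
    print(f"   Validation batches: {int(np.ceil(len(val_files) / batch_size))}")
    print(f"   Test batches: {int(np.ceil(len(test_files) / batch_size))}")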
    
    return train_dataset, val_dataset, test_dataset

# Create VGGish datasets
if 'train_files' in locals() and 'train_labels' in locals():
    print("\n🔄 Creating VGGish data generators...")
    vggish_train_dataset, vggish_val_dataset, vggish_test_dataset = create_vggish_data_generators(
        train_files, train_labels,
        val_files, val_labels, 
        test_files, test_labels,
        batch_size=CONFIG['batch_size']
    )
    print("✅ VGGish datasets ready!")
else:
    print("❌ Data files not available. Please run data loading cells first.")
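
The number of patches a clip yields follows directly from the framing constants in `VGGISH_CONFIG`. The sketch below works through that arithmetic without loading any audio. It assumes the `center=False` STFT framing used above, where each frame spans the full 512-sample FFT length. The function name is hypothetical, and clips shorter than one FFT frame are rejected rather than padded.

```python
# Framing constants taken from VGGISH_CONFIG in the cell above
SAMPLE_RATE = 16000
N_FFT = 512          # next power of two above the 400-sample (25 ms) window
HOP_SAMPLES = 160    # 10 ms hop
PATCH_FRAMES = 96    # 0.96 s window / 10 ms hop


def vggish_patch_count(num_samples: int) -> int:
    """Number of log-mel patches produced for a clip of `num_samples` samples."""
    if num_samples < N_FFT:
        raise ValueError("clip is shorter than one FFT frame")
    # With center=False, the STFT yields one frame per full N_FFT-sample block
    num_frames = 1 + (num_samples - N_FFT) // HOP_SAMPLES
    if num_frames < PATCH_FRAMES:
        return 1  # short clips are zero-padded up to a single patch
    # Non-overlapping patches: the patch hop equals the patch length
    return 1 + (num_frames - PATCH_FRAMES) // PATCH_FRAMES


for seconds in (0.5, 1.0, 10.0):
    n = int(seconds * SAMPLE_RATE)
    print(f"{seconds:>4} s -> {vggish_patch_count(n)} patch(es)")
```

Under these assumptions a 10-second clip gives 10 patches, while clips of one second or less give a single patch.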

In [None]:
# VGGish Model Architecture
def create_vggish_model(num_classes: int):
    """
    Create VGGish model with classification head.
    
    Architecture:
    1. VGGish feature extractor (frozen/fine-tunable)
    2. Classification head with dropout and batch norm
    """
    
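    # Load pre-trained VGGish model from TensorFlow Hub.
    # The TF Hub VGGish module is documented as taking a 1-D 16 kHz waveform and
    # returning (num_patches, 128) embeddings, so confirm that the loaded module
    # accepts the log-mel patches built above before training.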
    print("🔄 Loading VGGish model from TensorFlow Hub...")
    vggish_layer = hub.KerasLayer(
        VGGISH_CONFIG['model_url'],
        input_shape=(VGGISH_CONFIG['mel_bands'], 
                    int(round(VGGISH_CONFIG['example_window_seconds'] / VGGISH_CONFIG['stft_hop_seconds']))),
        trainable=False,  # Start with frozen features
        name='vggish'
    )
    
    # Build model
    inputs = layers.Input(shape=(VGGISH_CONFIG['mel_bands'], 
                                int(round(VGGISH_CONFIG['example_window_seconds'] / VGGISH_CONFIG['stft_hop_seconds']))),
                         name='mel_spectrogram')
    
    # VGGish feature extraction
    features = vggish_layer(inputs)  # Output: (batch_size, 128)
    
    # Classification head
    x = layers.BatchNormalization(name='bn1')(features)
    x = layers.Dropout(0.5, name='dropout1')(x)
    x = layers.Dense(256, activation='relu', name='dense1')(x)
    x = layers.BatchNormalization(name='bn2')(x)
    x = layers.Dropout(0.3, name='dropout2')(x)
    x = layers.Dense(128, activation='relu', name='dense2')(x)
    x = layers.BatchNormalization(name='bn3')(x)
    x = layers.Dropout(0.2, name='dropout3')(x)
    
    # Output layer
    outputs = layers.Dense(num_classes, activation='softmax', name='classification')(x)
    
    # Create model
    model = Model(inputs=inputs, outputs=outputs, name='VGGish_RagaClassifier')
    
    return model, vggish_layer

def setup_vggish_training(model, vggish_layer):
    """Setup VGGish model for training with different strategies."""
    
    # Compile model
    model.compile(
        optimizer=Adam(learning_rate=CONFIG['learning_rate_vggish']),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    print("🎯 VGGish Training Strategy:")
    print("   Phase 1: Frozen VGGish features + Train classification head")
    print("   Phase 2: Fine-tune VGGish features + classification head")
    
    # Training callbacks
    callbacks = [
        EarlyStopping(
            monitor='val_accuracy',
            patience=CONFIG['patience'],
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=CONFIG['patience']//2,
            min_lr=1e-7,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=os.path.join(MODELS_PATH, 'vggish_model', 'best_model.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=False,
            verbose=1
        )
    ]
    
    return callbacks

def train_vggish_two_phase(model, vggish_layer, train_dataset, val_dataset, callbacks):
    """Train VGGish model in two phases: frozen then fine-tuning."""
    
    vggish_output_dir = os.path.join(MODELS_PATH, "vggish_model")
    os.makedirs(vggish_output_dir, exist_ok=True)
    
    training_history = {'phase1': None, 'phase2': None}
    
    print("\n" + "="*60)
    print("🚀 PHASE 1: Training classification head (VGGish frozen)")
    print("="*60)
    
    # Phase 1: Train only classification head
    vggish_layer.trainable = False
    model.compile(
        optimizer=Adam(learning_rate=CONFIG['learning_rate_vggish']),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    # Train Phase 1
    history1 = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=CONFIG['epochs'] // 2,  # Half epochs for phase 1
        callbacks=callbacks,
        verbose=1
    )
    training_history['phase1'] = history1.history
    
    print("\n" + "="*60)
    print("🔥 PHASE 2: Fine-tuning entire model (VGGish unfrozen)")
    print("="*60)
    
    # Phase 2: Fine-tune entire model
    vggish_layer.trainable = True
    model.compile(
        optimizer=Adam(learning_rate=CONFIG['learning_rate_vggish'] * 0.1),  # Lower LR for fine-tuning
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
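    # Train Phase 2. Keras treats `epochs` as the index of the final epoch, so it
    # must equal initial_epoch plus the number of fine-tuning epochs to run.
    phase1_epochs_run = len(history1.history['loss'])
    phase2_epochs = CONFIG['epochs'] - CONFIG['epochs'] // 2  # Remaining epochs for phase 2
    history2 = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=phase1_epochs_run + phase2_epochs,
        callbacks=callbacks,
        verbose=1,
        initial_epoch=phase1_epochs_run  # Continue from phase 1
    )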
    training_history['phase2'] = history2.history
    
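    # Save complete training history (cast to float: Keras may log NumPy scalars,
    # which the json module cannot serialize)
    history_path = os.path.join(vggish_output_dir, "training_history.json")
    serializable_history = {
        phase: {metric: [float(v) for v in values] for metric, values in hist.items()}
        for phase, hist in training_history.items()
    }
    with open(history_path, 'w') as f:
        json.dump(serializable_history, f, indent=2)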
    
    # Save final model
    model.save(os.path.join(vggish_output_dir, "final_model.h5"))
    
    print("\n✅ VGGish training completed!")
    print(f"💾 Model saved to: {vggish_output_dir}")
    
    return training_history

# Setup VGGish model
if 'CLASS_NAMES' in locals():
    print("\n🔄 Setting up VGGish model...")
    
    # Create model
    vggish_model, vggish_feature_layer = create_vggish_model(len(CLASS_NAMES))
    
    # Setup training
    vggish_callbacks = setup_vggish_training(vggish_model, vggish_feature_layer)
    
    # Model summary
    print("\n📋 VGGish Model Architecture:")
    vggish_model.summary()
    
    print("\n✅ VGGish model ready!")
    print(f"📊 Total parameters: {vggish_model.count_params():,}")
    print(f"🎯 Output classes: {len(CLASS_NAMES)}")
    
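    # Visualize model architecture (plot_model requires pydot and Graphviz)
    try:
        tf.keras.utils.plot_model(
            vggish_model,
            to_file=os.path.join(MODELS_PATH, 'vggish_architecture.png'),
            show_shapes=True,
            show_layer_names=True,
            dpi=150
        )
    except Exception as e:
        print(f"⚠️  Could not render architecture diagram: {e}")
    
    print("\n💡 To start training:")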
    print("vggish_history = train_vggish_two_phase(vggish_model, vggish_feature_layer, vggish_train_dataset, vggish_val_dataset, vggish_callbacks)")
    
else:
    print("❌ CLASS_NAMES not defined. Please run data loading cells first.")
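
The two-phase schedule depends on how Keras reads `epochs` together with `initial_epoch`: `epochs` is the index of the final epoch, so the number of epochs actually trained is `epochs - initial_epoch`. The sketch below shows one way to derive both arguments for the fine-tuning call. The helper name and the example numbers are hypothetical.

```python
from typing import Tuple


def phase2_fit_args(total_epochs: int, phase1_epochs_run: int) -> Tuple[int, int]:
    """Return (initial_epoch, epochs) for the fine-tuning call to `fit`.

    Keras treats `epochs` as the index of the final epoch, so the number of
    epochs actually trained is `epochs - initial_epoch`.
    """
    phase2_budget = total_epochs - total_epochs // 2  # epochs reserved for fine-tuning
    initial_epoch = phase1_epochs_run
    return initial_epoch, initial_epoch + phase2_budget


# Hypothetical run: 20 epochs in total, phase 1 stopped early after 7 epochs
initial_epoch, epochs = phase2_fit_args(total_epochs=20, phase1_epochs_run=7)
print(initial_epoch, epochs, epochs - initial_epoch)
```

Passing the same value for `epochs` and `initial_epoch` would make `fit` return immediately, which is why the phase 2 budget is added on top of the epochs already run.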

In [None]:
# VGGish Model Evaluation
def evaluate_vggish_model(model, test_dataset, output_dir: str):
    """Evaluate trained VGGish model on test data."""
    
    print("🧪 Evaluating VGGish model on test set...")
    
    # Evaluate on test set
    test_results = model.evaluate(test_dataset, verbose=1)
    test_loss, test_accuracy = test_results
    
    # Get predictions for detailed analysis
    print("🔄 Generating predictions for detailed analysis...")
    y_true = []
    y_pred = []
    
    for batch_spectrograms, batch_labels in tqdm(test_dataset):
        # Get predictions
        predictions = model.predict(batch_spectrograms, verbose=0)
        predicted_classes = np.argmax(predictions, axis=1)
        
        y_true.extend(batch_labels.numpy())
        y_pred.extend(predicted_classes)
    
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    
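    # Generate classification report; passing labels keeps the report aligned with
    # CLASS_NAMES even if a class is missing from the test set
    label_ids = list(range(len(CLASS_NAMES)))
    report = classification_report(
        y_true, y_pred,
        labels=label_ids,
        target_names=CLASS_NAMES,
        output_dict=True,
        zero_division=0
    )
    
    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)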
    
    # Save evaluation results
    eval_file = os.path.join(output_dir, "vggish_evaluation_results.json")
    with open(eval_file, 'w') as f:
        json.dump({
            'test_loss': float(test_loss),
            'test_accuracy': float(test_accuracy),
            'classification_report': report,
            'confusion_matrix': cm.tolist(),
            'class_names': CLASS_NAMES
        }, f, indent=2)
    
    # Visualize results
    plt.figure(figsize=(15, 5))
    
    # Plot 1: Confusion Matrix
    plt.subplot(1, 3, 1)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Oranges',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.title('VGGish Model - Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    
    # Plot 2: Per-class F1 scores
    plt.subplot(1, 3, 2)
    f1_scores = [report[cls]['f1-score'] for cls in CLASS_NAMES]
    plt.bar(range(len(CLASS_NAMES)), f1_scores, color='orange', alpha=0.7)
    plt.title('VGGish Model - F1 Score per Class')
    plt.xlabel('Raga Classes')
    plt.ylabel('F1 Score')
    plt.xticks(range(len(CLASS_NAMES)), CLASS_NAMES, rotation=45)
    plt.ylim(0, 1)
    
    # Plot 3: Precision vs Recall
    plt.subplot(1, 3, 3)
    precisions = [report[cls]['precision'] for cls in CLASS_NAMES]
    recalls = [report[cls]['recall'] for cls in CLASS_NAMES]
    plt.scatter(recalls, precisions, color='darkorange', alpha=0.7)
    for i, cls in enumerate(CLASS_NAMES):
        plt.annotate(cls, (recalls[i], precisions[i]),
                    xytext=(5, 5), textcoords='offset points', fontsize=8)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('VGGish Model - Precision vs Recall')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'vggish_evaluation_plots.png'), dpi=300, bbox_inches='tight')
    plt.show()
    
    print("\n📊 VGGish Model Evaluation Results:")
    print(f"   Test Loss: {test_loss:.4f}")
    print(f"   Test Accuracy: {test_accuracy:.4f}")
    print(f"   F1 (Macro): {report['macro avg']['f1-score']:.4f}")
    print(f"   F1 (Weighted): {report['weighted avg']['f1-score']:.4f}")
    print(f"💾 Results saved to: {eval_file}")
    
    return {
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
        'classification_report': report,
        'confusion_matrix': cm
    }

def predict_raga_vggish(model, audio_path: str):
    """Predict raga for a single audio file using VGGish model."""
    
    try:
        # Load and preprocess audio
        waveform, sr = librosa.load(audio_path, sr=VGGISH_CONFIG['sample_rate'])
        
        # Apply VGGish preprocessing
        spectrograms = vggish_preprocess_audio(waveform, sr)
        
        # Predict all patches in one batched call instead of one call per patch
        all_predictions = model.predict(spectrograms, verbose=0)
        
        # Average the per-patch probabilities into a clip-level prediction
        avg_prediction = np.mean(all_predictions, axis=0)
        predicted_class = np.argmax(avg_prediction)
        confidence = avg_prediction[predicted_class]
        
        return {
            'predicted_raga': CLASS_NAMES[predicted_class],
            'confidence': float(confidence),
            'all_probabilities': {
                CLASS_NAMES[i]: float(prob)
                for i, prob in enumerate(avg_prediction)
            }
        }
        
    except Exception as e:
        print(f"❌ VGGish prediction failed for {audio_path}: {e}")
        return None

# Example evaluation setup
print("\n" + "="*50)
print("📊 VGGish Model Evaluation Ready")
print("="*50)
print("\n💡 After training completes, run evaluation with:")
print("vggish_eval_results = evaluate_vggish_model(vggish_model, vggish_test_dataset, os.path.join(MODELS_PATH, 'vggish_model'))")
print("\n🎯 For single audio prediction:")
print("prediction = predict_raga_vggish(vggish_model, 'path/to/audio.wav')")
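
Clip-level prediction averages the per-patch class probabilities, which can give a different answer from a majority vote over patches. The sketch below illustrates that aggregation step on its own. The class names and probabilities are made up for illustration, and the function name is hypothetical.

```python
import numpy as np


def aggregate_patch_probabilities(patch_probs, class_names):
    """Average per-patch class probabilities into one clip-level prediction."""
    patch_probs = np.asarray(patch_probs, dtype=np.float64)
    if patch_probs.ndim != 2 or patch_probs.shape[1] != len(class_names):
        raise ValueError("expected an array of shape (num_patches, num_classes)")
    clip_probs = patch_probs.mean(axis=0)
    best = int(np.argmax(clip_probs))
    return class_names[best], float(clip_probs[best])


# Hypothetical class names and softmax outputs for three patches of one clip
names = ["Yaman", "Bhairav", "Todi"]
probs = [[0.6, 0.3, 0.1],
         [0.2, 0.7, 0.1],
         [0.5, 0.4, 0.1]]
print(aggregate_patch_probabilities(probs, names))
```

In this example two of the three patches favour the first class, yet the averaged probabilities select the second, because the one dissenting patch is much more confident.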

## 6. Model Comparison and Analysis

### Comparative Analysis
This section compares the AST and VGGish models across several dimensions:

**Comparison Metrics:**
- Performance: Accuracy, F1-scores, Precision, Recall
- Computational: Training time, inference speed, model size
- Architectural: Parameter count, memory usage
- Robustness: Cross-validation results, error analysis
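
Both headline performance metrics can be recovered from a confusion matrix alone, which is handy when only the saved evaluation JSON is available. The sketch below is a minimal illustration, assuming rows are actual classes and columns are predicted classes (the scikit-learn convention). The function name is hypothetical.

```python
import numpy as np


def metrics_from_confusion(cm):
    """Accuracy and macro F1 from a confusion matrix (rows = actual, cols = predicted)."""
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as the class but actually another
    fn = cm.sum(axis=1) - tp   # actually the class but predicted as another
    denom = 2 * tp + fp + fn
    # F1 = 2TP / (2TP + FP + FN); classes with no samples and no predictions score 0
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return {"accuracy": float(tp.sum() / cm.sum()), "f1_macro": float(f1.mean())}


# Made-up two-class confusion matrix
print(metrics_from_confusion([[5, 1], [2, 4]]))
```

Macro F1 weights every raga equally, so it drops faster than accuracy when a rare class is misclassified.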

In [None]:
# Comprehensive Model Comparison
import time
from datetime import datetime

def compare_models(ast_results=None, vggish_results=None, save_path=None):
    """
    Comprehensive comparison between AST and VGGish models.
    
    Args:
        ast_results: Dictionary containing AST evaluation results
        vggish_results: Dictionary containing VGGish evaluation results
        save_path: Path to save comparison results
    """
    
    if save_path is None:
        save_path = MODELS_PATH
    
    comparison_results = {
        'timestamp': datetime.now().isoformat(),
        'models': {}
    }
    
    print("\n" + "="*60)
    print("📊 COMPREHENSIVE MODEL COMPARISON")
    print("="*60)
    
    # Model information
    models_info = {
        'AST': {
            'name': 'Audio Spectrogram Transformer',
            'type': 'Transformer-based',
            'pre_training': 'AudioSet',
            'input_type': 'Mel spectrogram (128 mel bins × 1024 frames)',
            'architecture': 'Vision Transformer adapted for audio'
        },
        'VGGish': {
            'name': 'VGGish',
            'type': 'CNN-based',
            'pre_training': 'YouTube-8M',
            'input_type': 'Log-mel spectrogram',
            'architecture': 'VGG-like CNN with classification head'
        }
    }
    
    # Performance comparison
    print("\n🎯 PERFORMANCE COMPARISON")
    print("-" * 40)
    
    performance_data = []
    model_names = []
    
    if ast_results:
        ast_acc = ast_results.get('eval_accuracy', 0)
        ast_f1 = ast_results.get('eval_f1_macro', 0)
        performance_data.append([ast_acc, ast_f1])
        model_names.append('AST')
        comparison_results['models']['AST'] = {
            'accuracy': float(ast_acc),
            'f1_macro': float(ast_f1),
            'info': models_info['AST']
        }
        print(f"AST Model:")
        print(f"  Accuracy: {ast_acc:.4f}")
        print(f"  F1 (Macro): {ast_f1:.4f}")
    
    if vggish_results:
        vggish_acc = vggish_results.get('test_accuracy', 0)
        vggish_f1 = vggish_results['classification_report']['macro avg']['f1-score']
        performance_data.append([vggish_acc, vggish_f1])
        model_names.append('VGGish')
        comparison_results['models']['VGGish'] = {
            'accuracy': float(vggish_acc),
            'f1_macro': float(vggish_f1),
            'info': models_info['VGGish']
        }
        print(f"VGGish Model:")
        print(f"  Accuracy: {vggish_acc:.4f}")
        print(f"  F1 (Macro): {vggish_f1:.4f}")
    
    # Visualization
    if len(performance_data) > 0:
        plt.figure(figsize=(20, 12))
        
        # Performance comparison
        plt.subplot(2, 4, 1)
        metrics = ['Accuracy', 'F1 Score']
        x = np.arange(len(metrics))
        width = 0.35
        
        for i, (model_name, data) in enumerate(zip(model_names, performance_data)):
            offset = (i - len(model_names)/2 + 0.5) * width
            plt.bar(x + offset, data, width, label=model_name, alpha=0.8)
        
        plt.xlabel('Metrics')
        plt.ylabel('Score')
        plt.title('Performance Comparison')
        plt.xticks(x, metrics)
        plt.legend()
        plt.ylim(0, 1)
        
        # Per-class F1 comparison (if both models available)
        if ast_results and vggish_results and 'classification_report' in ast_results:
            plt.subplot(2, 4, 2)
            
            ast_f1_per_class = [ast_results['classification_report'][cls]['f1-score'] for cls in CLASS_NAMES]
            vggish_f1_per_class = [vggish_results['classification_report'][cls]['f1-score'] for cls in CLASS_NAMES]
            
            x = np.arange(len(CLASS_NAMES))
            plt.bar(x - 0.2, ast_f1_per_class, 0.4, label='AST', alpha=0.8)
            plt.bar(x + 0.2, vggish_f1_per_class, 0.4, label='VGGish', alpha=0.8)
            
            plt.xlabel('Raga Classes')
            plt.ylabel('F1 Score')
            plt.title('Per-Class F1 Score Comparison')
            plt.xticks(x, CLASS_NAMES, rotation=45)
            plt.legend()
            plt.ylim(0, 1)
        
        # Confusion matrices side by side
        subplot_idx = 3
        if ast_results and 'confusion_matrix' in ast_results:
            plt.subplot(2, 4, subplot_idx)
            cm_ast = np.array(ast_results['confusion_matrix'])
            sns.heatmap(cm_ast, annot=True, fmt='d', cmap='Blues', 
                       xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, cbar=False)
            plt.title('AST Confusion Matrix')
            plt.xlabel('Predicted')
            plt.ylabel('Actual')
            plt.xticks(rotation=45)
            subplot_idx += 1
        
        if vggish_results and 'confusion_matrix' in vggish_results:
            plt.subplot(2, 4, subplot_idx)
            cm_vggish = np.array(vggish_results['confusion_matrix'])
            sns.heatmap(cm_vggish, annot=True, fmt='d', cmap='Oranges',
                       xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, cbar=False)
            plt.title('VGGish Confusion Matrix')
            plt.xlabel('Predicted')
            plt.ylabel('Actual')
            plt.xticks(rotation=45)
            subplot_idx += 1
        
        # Model architecture comparison
        plt.subplot(2, 4, (subplot_idx, subplot_idx + 1))
        architecture_comparison = {
            'Model': [],
            'Type': [],
            'Input': [],
            'Pre-training': []
        }
        
        for model, info in models_info.items():
            if (model == 'AST' and ast_results) or (model == 'VGGish' and vggish_results):
                architecture_comparison['Model'].append(info['name'])
                architecture_comparison['Type'].append(info['type'])
                architecture_comparison['Input'].append(info['input_type'])
                architecture_comparison['Pre-training'].append(info['pre_training'])
        
        # Create table
        table_data = list(zip(*[architecture_comparison[key] for key in architecture_comparison.keys()]))
        table = plt.table(cellText=table_data,
                         colLabels=list(architecture_comparison.keys()),
                         cellLoc='center',
                         loc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2)
        plt.axis('off')
        plt.title('Architecture Comparison', pad=20)
        
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, 'model_comparison.png'), dpi=300, bbox_inches='tight')
        plt.show()
    
    # Save comprehensive comparison
    comparison_file = os.path.join(save_path, 'comprehensive_comparison.json')
    with open(comparison_file, 'w') as f:
        json.dump(comparison_results, f, indent=2)
    
    print(f"\n💾 Comparison results saved to: {comparison_file}")
    
    return comparison_results

def benchmark_inference_speed(ast_model=None, vggish_model=None, sample_audio_path=None, num_runs=10):
    """Benchmark inference speed for both models."""
    
    if sample_audio_path is None:
        print("❌ No sample audio provided for benchmarking")
        return None
    
    print(f"\n⏱️  INFERENCE SPEED BENCHMARK ({num_runs} runs)")
    print("-" * 50)
    
    results = {}
    
    # AST benchmark
    if ast_model is not None:
        print("🔄 Benchmarking AST model...")
        ast_times = []
        
        for _ in range(num_runs):
            start_time = time.perf_counter()
            # Assumes predict_raga_ast, ast_processor and device are defined in the AST cells above
            predict_raga_ast(ast_model, ast_processor, sample_audio_path, device)
            end_time = time.perf_counter()
            ast_times.append(end_time - start_time)
        
        results['AST'] = {
            'mean_time': np.mean(ast_times),
            'std_time': np.std(ast_times),
            'min_time': np.min(ast_times),
            'max_time': np.max(ast_times)
        }
        
        print(f"AST Average: {results['AST']['mean_time']:.3f}s ± {results['AST']['std_time']:.3f}s")
    
    # VGGish benchmark
    if vggish_model is not None:
        print("🔄 Benchmarking VGGish model...")
        vggish_times = []
        
        for _ in range(num_runs):
            start_time = time.perf_counter()
            # Time the full pipeline: audio loading, preprocessing and prediction
            predict_raga_vggish(vggish_model, sample_audio_path)
            end_time = time.perf_counter()
            vggish_times.append(end_time - start_time)
        
        results['VGGish'] = {
            'mean_time': np.mean(vggish_times),
            'std_time': np.std(vggish_times),
            'min_time': np.min(vggish_times),
            'max_time': np.max(vggish_times)
        }
        
        print(f"VGGish Average: {results['VGGish']['mean_time']:.3f}s ± {results['VGGish']['std_time']:.3f}s")
    
    return results

# Example usage instructions
print("\n" + "="*60)
print("📋 MODEL COMPARISON USAGE")
print("="*60)
print("\n💡 After training both models, compare them with:")
print("comparison_results = compare_models(ast_eval_results, vggish_eval_results)")
print("\n⏱️  Benchmark inference speed with:")
print("speed_results = benchmark_inference_speed(ast_model, vggish_model, 'sample_audio.wav')")
print("\n🎯 This will generate:")
print("   • Performance comparison charts")
print("   • Confusion matrix visualizations") 
print("   • Architecture comparison table")
print("   • Comprehensive JSON report")
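
Inference timings are easier to trust when the first few calls are discarded, since they often include one-off costs such as graph tracing or cache fills. The sketch below shows a generic timer with warm-up runs that accepts any zero-argument callable. The helper name and the stand-in workload are hypothetical; in the notebook the callable would wrap a real prediction call.

```python
import statistics
import time
from typing import Callable, Dict


def time_callable(fn: Callable[[], object], num_runs: int = 10, warmup: int = 2) -> Dict[str, float]:
    """Time repeated calls to `fn`, discarding a few warm-up runs."""
    for _ in range(warmup):
        fn()  # warm-up: excludes one-off costs from the measurements
    durations = []
    for _ in range(num_runs):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        durations.append(time.perf_counter() - start)
    return {
        "mean_time": statistics.mean(durations),
        "std_time": statistics.pstdev(durations),  # population std, like np.std
        "min_time": min(durations),
        "max_time": max(durations),
    }


# Stand-in workload for illustration only
stats = time_callable(lambda: sum(range(10_000)), num_runs=5)
print(f"{stats['mean_time']:.6f}s ± {stats['std_time']:.6f}s")
```

Wrapping each model's prediction in a lambda keeps the timing loop identical for both models, so the comparison measures the models and not differences in the harness.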

## 7. Model Deployment and Sharing

### HuggingFace Hub Integration
Deploy and share trained models on HuggingFace Hub for easy access and reproducibility.

**Deployment Features:**
- Model cards with detailed documentation
- Inference API for real-time predictions
- Version control and model tracking
- Community sharing and collaboration
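
Transformers model configs carry `id2label` and `label2id` mappings, and a shared model is easier to use when these hold the raga names instead of generic label ids. The sketch below builds both mappings from a class list. The helper name and the raga names are hypothetical.

```python
import json
from typing import Dict, List, Tuple


def build_label_maps(class_names: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Build the id2label / label2id dictionaries stored in a model config."""
    id2label = {i: name for i, name in enumerate(class_names)}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id


# Hypothetical raga names, for illustration only
id2label, label2id = build_label_maps(["Yaman", "Bhairav", "Todi"])
print(json.dumps({"id2label": id2label, "label2id": label2id}, indent=2))
```

One option is to pass these as `id2label=` and `label2id=` when loading the model with `from_pretrained`, so that they are written into the saved config before upload.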

In [None]:
# HuggingFace Hub Deployment
from huggingface_hub import HfApi, create_repo, upload_folder
import tempfile
import shutil

def create_model_card(model_name: str, model_type: str, results: dict, dataset_info: dict):
    """Create a comprehensive model card for HuggingFace Hub."""
    
    accuracy = results.get('eval_accuracy', results.get('test_accuracy', 0))
    f1_macro = results.get('eval_f1_macro', results.get('classification_report', {}).get('macro avg', {}).get('f1-score', 0))
    
    model_card = f"""---
language:
- en
- hi
tags:
- audio-classification
- music
- indian-classical
- raga-classification
- {model_type.lower()}
datasets:
- custom-raga-dataset
metrics:
- accuracy
- f1
library_name: {'transformers' if model_type == 'AST' else 'tensorflow'}
pipeline_tag: audio-classification
---

# {model_name} for Indian Classical Raga Classification

## Model Description

This model is a fine-tuned {model_type} model for classifying Indian classical music ragas. The model was trained on a custom dataset of {dataset_info.get('total_samples', 'N/A')} audio samples across {len(CLASS_NAMES)} different ragas.

### Model Architecture
- **Base Model**: {model_type}
- **Task**: Multi-class audio classification
- **Classes**: {len(CLASS_NAMES)} Indian classical ragas
- **Input**: {'Raw audio waveform' if model_type == 'AST' else 'Log-mel spectrogram'}

### Raga Classes
{', '.join(CLASS_NAMES)}

## Training Data

The model was trained on a diverse dataset of Indian classical music recordings:
- **Total Samples**: {dataset_info.get('total_samples', 'N/A')}
- **Train/Val/Test Split**: {dataset_info.get('train_split', 'N/A')}/{dataset_info.get('val_split', 'N/A')}/{dataset_info.get('test_split', 'N/A')}
- **Audio Format**: {dataset_info.get('audio_format', 'N/A')}
- **Duration Range**: {dataset_info.get('duration_range', 'N/A')}

## Performance

| Metric | Score |
|--------|--------|
| Accuracy | {accuracy:.4f} |
| F1 Score (Macro) | {f1_macro:.4f} |

### Per-Class Performance
"""

    if 'classification_report' in results:
        model_card += "\n| Raga | Precision | Recall | F1-Score |\n|------|-----------|--------|----------|\n"
        for raga in CLASS_NAMES:
            if raga in results['classification_report']:
                precision = results['classification_report'][raga]['precision']
                recall = results['classification_report'][raga]['recall']
                f1 = results['classification_report'][raga]['f1-score']
                model_card += f"| {raga} | {precision:.3f} | {recall:.3f} | {f1:.3f} |\n"

    model_card += f"""

## Usage

### Using Transformers (for AST)
```python
from transformers import ASTForAudioClassification, ASTFeatureExtractor
import torch
import librosa

# Load model and feature extractor
model = ASTForAudioClassification.from_pretrained("your-username/{model_name.lower().replace(' ', '-')}")
processor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# Load and process audio
waveform, sr = librosa.load("audio_file.wav", sr=16000)
inputs = processor(waveform, sampling_rate=sr, return_tensors="pt")

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1)
```

### Using TensorFlow (for VGGish)
```python
import tensorflow as tf
import numpy as np

# Load model
model = tf.keras.models.load_model("path/to/vggish_model")

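# Preprocess audio (implement vggish_preprocess_audio function);
# it returns one log-mel patch per 0.96 s window, shape (num_patches, 64, 96)
spectrograms = vggish_preprocess_audio(audio_waveform)
# Average the per-patch probabilities into one clip-level prediction
prediction = model.predict(spectrograms).mean(axis=0)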
```

## Training Details

### Training Configuration
- **Epochs**: {CONFIG.get('epochs', 'N/A')}
- **Batch Size**: {CONFIG.get('batch_size', 'N/A')}
- **Learning Rate**: {CONFIG.get('learning_rate_' + model_type.lower(), 'N/A')}
- **Optimizer**: Adam
- **Loss Function**: Sparse Categorical Crossentropy

### Data Augmentation
- Audio time stretching
- Pitch shifting  
- Background noise addition
- Volume normalization

## Limitations and Biases

- Model trained primarily on specific recording conditions
- Performance may vary with different audio qualities
- Limited to {len(CLASS_NAMES)} specific ragas
- May not generalize to fusion or contemporary adaptations

## Citation

```bibtex
@misc{{{model_name.lower().replace(' ', '_')}_raga_classification,
  title={{{model_name} for Indian Classical Raga Classification}},
  author={{Your Name}},
  year={{2024}},
  howpublished={{\\url{{https://huggingface.co/your-username/{model_name.lower().replace(' ', '-')}}}}},
}}
```

## Acknowledgments

- MIT Audio Spectrogram Transformer (AST) team
- Google VGGish model creators
- Indian classical music community
- HuggingFace for hosting and infrastructure
"""

    return model_card

def deploy_ast_to_hub(model, processor, results: dict, repo_name: str, dataset_info: dict):
    """Deploy AST model to HuggingFace Hub."""
    
    print(f"🚀 Deploying AST model to HuggingFace Hub: {repo_name}")
    
    try:
        # Create repository
        api = HfApi()
        
        try:
            create_repo(repo_name, exist_ok=True)
            print(f"✅ Repository created/found: {repo_name}")
        except Exception as e:
            print(f"⚠️  Repository creation note: {e}")
        
        # Save model and processor temporarily
        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"💾 Saving model to temporary directory: {temp_dir}")
            
            # Save model
            model.save_pretrained(temp_dir)
            processor.save_pretrained(temp_dir)
            
            # Create model card
            model_card = create_model_card("Audio Spectrogram Transformer", "AST", results, dataset_info)
            with open(f"{temp_dir}/README.md", "w", encoding="utf-8") as f:
                f.write(model_card)
            
            # Save configuration
            config = {
                "model_type": "AST",
                "num_classes": len(CLASS_NAMES),
                "class_names": CLASS_NAMES,
                "sample_rate": CONFIG['sample_rate'],
                "max_length": CONFIG['max_length']
            }
            
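            # Use a separate file name: save_pretrained() above already wrote the
            # Transformers config.json, and overwriting it would break from_pretrained().
            with open(f"{temp_dir}/raga_config.json", "w", encoding="utf-8") as f: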
                json.dump(config, f, indent=2)
            
            # Upload to Hub
            print("🔄 Uploading to HuggingFace Hub...")
            upload_folder(
                folder_path=temp_dir,
                repo_id=repo_name,
                repo_type="model"
            )
            
        print(f"✅ AST model deployed successfully!")
        print(f"🌐 Access your model at: https://huggingface.co/{repo_name}")
        
        return f"https://huggingface.co/{repo_name}"
        
    except Exception as e:
        print(f"❌ Deployment failed: {e}")
        return None

def deploy_vggish_to_hub(model, results: dict, repo_name: str, dataset_info: dict):
    """Deploy VGGish model to HuggingFace Hub."""
    
    print(f"🚀 Deploying VGGish model to HuggingFace Hub: {repo_name}")
    
    try:
        # Create repository
        api = HfApi()
        
        try:
            create_repo(repo_name, exist_ok=True)
            print(f"✅ Repository created/found: {repo_name}")
        except Exception as e:
            print(f"⚠️  Repository creation note: {e}")
        
        # Save model temporarily
        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"💾 Saving model to temporary directory: {temp_dir}")
            
            # Save TensorFlow model
            model.save(f"{temp_dir}/vggish_model.h5")
            
            # Create model card
            model_card = create_model_card("VGGish", "VGGish", results, dataset_info)
            with open(f"{temp_dir}/README.md", "w", encoding="utf-8") as f:
                f.write(model_card)
            
            # Save configuration
            config = {
                "model_type": "VGGish",
                "num_classes": len(CLASS_NAMES),
                "class_names": CLASS_NAMES,
                "sample_rate": VGGISH_CONFIG['sample_rate'],
                "mel_bands": VGGISH_CONFIG['mel_bands'],
                "embedding_size": VGGISH_CONFIG['embedding_size']
            }
            
            with open(f"{temp_dir}/config.json", "w") as f:
                json.dump(config, f, indent=2)
            
            # Create inference script
            inference_script = '''
import tensorflow as tf
import numpy as np
import librosa
import json

def load_model_and_config(model_path):
    model = tf.keras.models.load_model(f"{model_path}/vggish_model.h5")
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)
    return model, config

def predict(model, config, audio_path):
    # Load and preprocess audio
    waveform, sr = librosa.load(audio_path, sr=config["sample_rate"])
    
    # Apply VGGish preprocessing (implement vggish_preprocess_audio)
    # spectrograms = vggish_preprocess_audio(waveform, sr)
    
    # Predict
    # predictions = model.predict(spectrograms)
    # return predictions
    
    return {"message": "Implement vggish_preprocess_audio function"}
'''
            
            with open(f"{temp_dir}/inference.py", "w") as f:
                f.write(inference_script)
            
            # Upload to Hub
            print("🔄 Uploading to HuggingFace Hub...")
            upload_folder(
                folder_path=temp_dir,
                repo_id=repo_name,
                repo_type="model"
            )
            
        print(f"✅ VGGish model deployed successfully!")
        print(f"🌐 Access your model at: https://huggingface.co/{repo_name}")
        
        return f"https://huggingface.co/{repo_name}"
        
    except Exception as e:
        print(f"❌ Deployment failed: {e}")
        return None

# Dataset information for model cards
dataset_info = {
    'total_samples': 'TBD',  # Will be filled after data loading
    'train_split': '70%',
    'val_split': '15%', 
    'test_split': '15%',
    'audio_format': 'WAV, 22050 Hz',
    'duration_range': '30-180 seconds'
}
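
# A minimal sketch (an assumption, not part of the original pipeline) of how the
# 'TBD' placeholder above could be filled in after data loading. The helper name
# fill_dataset_info and its arguments are hypothetical.
def fill_dataset_info(info: dict, n_train: int, n_val: int, n_test: int) -> dict:
    """Return a copy of `info` with the sample count and split percentages filled in."""
    total = n_train + n_val + n_test
    if total <= 0:
        raise ValueError("At least one sample is required")
    filled = dict(info)  # leave the caller's dictionary untouched
    filled['total_samples'] = total
    filled['train_split'] = f"{100 * n_train / total:.0f}%"
    filled['val_split'] = f"{100 * n_val / total:.0f}%"
    filled['test_split'] = f"{100 * n_test / total:.0f}%"
    return filled

# Example (uncomment once the splits exist; val_files and test_files are assumed names):
# dataset_info = fill_dataset_info(dataset_info, len(train_files), len(val_files), len(test_files))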

print("\n" + "="*60)
print("🚀 MODEL DEPLOYMENT READY")
print("="*60)
print("\n📋 Deployment Instructions:")
print("\n🎯 For AST model:")
print("ast_hub_url = deploy_ast_to_hub(")
print("    model=ast_model,")
print("    processor=ast_processor,") 
print("    results=ast_eval_results,")
print("    repo_name='your-username/ast-raga-classifier',")
print("    dataset_info=dataset_info")
print(")")
print("\n🎯 For VGGish model:")
print("vggish_hub_url = deploy_vggish_to_hub(")
print("    model=vggish_model,")
print("    results=vggish_eval_results,")
print("    repo_name='your-username/vggish-raga-classifier',")
print("    dataset_info=dataset_info")
print(")")
print("\n⚠️  Remember to:")
print("   • Login to HuggingFace: huggingface-cli login")
print("   • Replace 'your-username' with your actual username")
print("   • Ensure you have write permissions to the repositories")

## 8. Complete Execution Workflow

### Step-by-Step Execution Guide

This section provides a complete workflow to execute the entire notebook from start to finish.

**Execution Order:**
1. Environment setup and data loading
2. AST model training and evaluation  
3. VGGish model training and evaluation
4. Model comparison and analysis
5. Deployment to HuggingFace Hub
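
The order above matters because each stage consumes variables produced by an earlier one. As a minimal sketch, assuming the variable names used elsewhere in this notebook (`CLASS_NAMES`, `train_files`, `device`, `ast_eval_results`, `vggish_eval_results`, `dataset_info`), a small pre-flight helper could report what is still missing before a stage runs. The `STAGE_REQUIREMENTS` mapping and the `missing_prerequisites` function are hypothetical and not part of the notebook's pipeline; the class names in the demo are stand-in values.

```python
from typing import Dict, List, Mapping

# Hypothetical mapping: pipeline stage -> notebook variables it is assumed to need.
STAGE_REQUIREMENTS: Dict[str, List[str]] = {
    "training": ["CLASS_NAMES", "train_files", "device"],
    "comparison": ["ast_eval_results", "vggish_eval_results"],
    "deployment": ["ast_eval_results", "vggish_eval_results", "dataset_info"],
}

def missing_prerequisites(stage: str, namespace: Mapping[str, object]) -> List[str]:
    """Return the required variable names that are absent from `namespace`."""
    if stage not in STAGE_REQUIREMENTS:
        raise KeyError(f"Unknown stage: {stage!r}")
    return [name for name in STAGE_REQUIREMENTS[stage] if name not in namespace]

# Example with a stand-in namespace; in the notebook, pass globals() instead.
demo_namespace = {"CLASS_NAMES": ["Yaman", "Bhairavi"], "device": "cpu"}
print(missing_prerequisites("training", demo_namespace))
```

An empty list means the stage's assumed inputs are all defined; a non-empty list names the cells that still need to run.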

In [None]:
# Complete Execution Workflow
def execute_complete_pipeline():
    """
    Print an overview of the AST and VGGish raga classification pipeline.
    This function only lists the steps and a checklist; it does not run them.
    Execute the corresponding notebook cells step by step.
    """
    
    print("\n" + "="*80)
    print("🎵 COMPLETE RAGA CLASSIFICATION PIPELINE")
    print("="*80)
    
    workflow_steps = [
        "1️⃣  Environment Setup and GPU Configuration",
        "2️⃣  Data Loading and Preprocessing", 
        "3️⃣  AST Model Setup and Training",
        "4️⃣  AST Model Evaluation",
        "5️⃣  VGGish Model Setup and Training", 
        "6️⃣  VGGish Model Evaluation",
        "7️⃣  Model Comparison and Analysis",
        "8️⃣  Deployment to HuggingFace Hub"
    ]
    
    print("\n📋 Execution Steps:")
    for step in workflow_steps:
        print(f"   {step}")
    
    print("\n⚠️  IMPORTANT NOTES:")
    print("   • Execute cells sequentially, not all at once")
    print("   • Monitor GPU usage and memory")
    print("   • Training may take several hours")
    print("   • Save checkpoints regularly")
    print("   • Verify data paths before training")
    
    execution_checklist = {
        "Environment": "✅ Python packages installed, GPU configured",
        "Data": "✅ Audio files loaded, classes defined, train/val/test split",
        "AST": "✅ Model loaded, datasets created, training ready",
        "VGGish": "✅ Model created, preprocessing setup, training ready", 
        "Training": "⚠️  Execute training cells manually",
        "Evaluation": "⚠️  Run after training completion",
        "Comparison": "⚠️  Run after both models trained",
        "Deployment": "⚠️  Configure HuggingFace credentials first"
    }
    
    print("\n✅ Pre-Execution Checklist:")
    for item, status in execution_checklist.items():
        print(f"   {item}: {status}")
    
    return workflow_steps

# Training Execution Template
def training_execution_template():
    """Template for executing the training pipeline."""
    
    training_code = '''
# STEP 1: Ensure all setup is complete
print("Verifying setup...")
assert 'CLASS_NAMES' in locals(), "❌ Class names not defined"
assert 'train_files' in locals(), "❌ Training files not loaded"
assert 'device' in locals(), "❌ Device not configured"
print("✅ Setup verified")

# STEP 2: Train AST Model
print("\\n🚀 Starting AST training...")
# Uncomment to execute:
# ast_train_result = train_ast_model(ast_trainer, ast_output_dir)
# ast_eval_results, ast_report, ast_cm = evaluate_ast_model(ast_trainer, ast_test_loader, ast_output_dir)

# STEP 3: Train VGGish Model  
print("\\n🚀 Starting VGGish training...")
# Uncomment to execute:
# vggish_history = train_vggish_two_phase(vggish_model, vggish_feature_layer, vggish_train_dataset, vggish_val_dataset, vggish_callbacks)
# vggish_eval_results = evaluate_vggish_model(vggish_model, vggish_test_dataset, os.path.join(MODELS_PATH, 'vggish_model'))

# STEP 4: Compare Models
print("\\n📊 Comparing models...")
# Uncomment to execute:
# comparison_results = compare_models(ast_eval_results, vggish_eval_results)

# STEP 5: Deploy to HuggingFace
print("\\n🚀 Deploying to HuggingFace...")
# Uncomment to execute:
# ast_hub_url = deploy_ast_to_hub(ast_model, ast_processor, ast_eval_results, 'your-username/ast-raga-classifier', dataset_info)
# vggish_hub_url = deploy_vggish_to_hub(vggish_model, vggish_eval_results, 'your-username/vggish-raga-classifier', dataset_info)

print("✅ Pipeline template ready - uncomment sections to execute")
'''
    
    print("\n📝 Training Execution Template:")
    print(training_code)
    
    return training_code

# Execute workflow overview
workflow_steps = execute_complete_pipeline()
training_template = training_execution_template()

# Final Summary
print("\n" + "="*80)
print("📊 NOTEBOOK SUMMARY")
print("="*80)

summary = {
    "🎯 Objective": "Compare AST and VGGish models for Indian classical raga classification",
    "🏗️ Models": "Audio Spectrogram Transformer (AST) + VGGish CNN",
    "📁 Architecture": "Modular design with clear separation between models",
    "🔧 Features": [
        "Complete data preprocessing pipelines",
        "Two-phase VGGish training (frozen → fine-tuning)",
        "Comprehensive evaluation metrics",
        "Visual model comparison",
        "HuggingFace Hub deployment",
        "Reproducible experimental setup"
    ],
    "📈 Outputs": [
        "Trained models with performance metrics",
        "Confusion matrices and classification reports", 
        "Model comparison visualizations",
        "Deployable models on HuggingFace Hub",
        "Complete documentation and model cards"
    ]
}

for key, value in summary.items():
    if isinstance(value, list):
        print(f"\n{key}:")
        for item in value:
            print(f"   • {item}")
    else:
        print(f"\n{key}: {value}")

print("\n" + "="*80)
print("🎉 NOTEBOOK SETUP COMPLETE - Ready for Execution!")
print("="*80)
print("\n💡 Next Steps:")
print("   1. Verify your data paths and GPU setup")
print("   2. Execute cells sequentially from the beginning") 
print("   3. Monitor training progress and save checkpoints")
print("   4. Compare results and deploy best performing model")
print("   5. Share your research and findings!")

# Save notebook execution log
execution_log = {
    "notebook_created": datetime.now().isoformat(),
    "models": ["AST", "VGGish"],
    "classes": CLASS_NAMES if 'CLASS_NAMES' in locals() else [],
    "workflow_steps": workflow_steps,
    "status": "Ready for execution"
}

log_file = os.path.join(BASE_PATH, "execution_log.json")
with open(log_file, 'w') as f:
    json.dump(execution_log, f, indent=2)

print(f"\n📝 Execution log saved: {log_file}")