# CNN + Multi-Head Attention for Music Genre Classification

**GTZAN Dataset - Improved & Bug-Free Implementation**

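This notebook implements a CNN + multi-head attention model with:
- Multi-head self-attention over per-segment CNN features
- SpecAugment-style time/frequency masking (implemented, but disabled in the default `load_data(..., augment=False)` run)
- A per-segment CNN feature extractor (four Conv2D → BatchNorm → ELU → MaxPool → Dropout blocks)
- Learning-rate scheduling (`ReduceLROnPlateau`)
- Data-path detection with a fallback location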

## 1. Imports and Setup

In [None]:
import os
import numpy as np
import librosa
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow.keras import layers, Model, regularizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Reproducibility: seed the NumPy RNG (also used for the SpecAugment mask positions)
# and the TensorFlow RNG with the same value.
np.random.seed(42)
tf.random.set_seed(42)

# Plot appearance shared by every figure below
# ('seaborn-v0_8-*' is the Matplotlib >= 3.6 spelling of the seaborn styles).
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Environment report
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {bool(tf.config.list_physical_devices('GPU'))}")

## 2. Configuration

In [None]:
# ==================== PATHS ====================
# Candidate locations for the GTZAN 'genres_original' folder, tried in order:
#   1. $GTZAN_DATA_PATH, if set (explicit override for unusual layouts)
#   2. <parent of the current working directory>/Data/genres_original
#   3. ../../Data/genres_original (relative to the current working directory)
# The first candidate that is an existing directory becomes DATA_PATH.
BASE_DIR = os.path.dirname(os.getcwd())
_DATA_PATH_CANDIDATES = [
    p for p in (
        os.environ.get('GTZAN_DATA_PATH'),
        os.path.join(BASE_DIR, 'Data', 'genres_original'),
        '../../Data/genres_original',
    )
    if p
]

DATA_PATH = next((p for p in _DATA_PATH_CANDIDATES if os.path.isdir(p)), None)
if DATA_PATH is None:
    # List every location we looked at so the fix is obvious from the traceback.
    raise FileNotFoundError(
        "GTZAN data not found. Tried: "
        + ", ".join(os.path.abspath(p) for p in _DATA_PATH_CANDIDATES)
        + ". Set GTZAN_DATA_PATH or update DATA_PATH."
    )

print(f"Data path: {DATA_PATH}")
print(f"Path exists: {os.path.exists(DATA_PATH)}")

# ==================== AUDIO PARAMETERS ====================
SAMPLE_RATE = 22050   # Hz; audio is resampled to this rate when loaded
DURATION = 30         # seconds kept per clip (shorter clips are zero-padded)
N_MELS = 128          # number of mel bands
N_FFT = 2048          # STFT window length, in samples
HOP_LENGTH = 512      # STFT hop, in samples

# ==================== SEGMENTATION ====================
NUM_SEGMENTS = 15  # 15 x 2-second segments (DURATION / NUM_SEGMENTS)

# ==================== MODEL ====================
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop',
          'jazz', 'metal', 'pop', 'reggae', 'rock']
# Derived from GENRES so the one-hot width and the label names can never disagree.
NUM_CLASSES = len(GENRES)

# ==================== HYPERPARAMETERS ====================
CNN_FILTERS = [32, 64, 128, 256]  # filters per conv block (one 2x2 max-pool per block)
ATTENTION_HEADS = 8
KEY_DIM = 32                      # per-head query/key size
DENSE_UNITS = 256                 # first dense layer; the second uses half of this
DROPOUT_RATE = 0.4                # dropout rate in the dense classification head
L2_REG = 0.0005
LEARNING_RATE = 0.001
BATCH_SIZE = 32
EPOCHS = 100                      # upper bound; EarlyStopping may end training earlier

## 3. Verify Dataset

In [None]:
print("\nDataset verification:")
print("=" * 50)
total_files = 0
for genre in GENRES:
    genre_path = os.path.join(DATA_PATH, genre)
    if os.path.exists(genre_path):
        files = [f for f in os.listdir(genre_path) if f.endswith('.wav')]
        count = len(files)
        total_files += count
        print(f"  {genre:12s}: {count:3d} files")
    else:
        print(f"  {genre:12s}: NOT FOUND")
print("=" * 50)
print(f"Total: {total_files} files\n")

## 4. Feature Extraction with SpecAugment

In [None]:
def spec_augment(mel, time_mask_param=10, freq_mask_param=8, augment=True,
                 mask_value=None):
    """Apply one SpecAugment-style time mask and one frequency mask.

    A band of ``time_mask_param`` consecutive frames (columns) and a band of
    ``freq_mask_param`` consecutive mel bins (rows) are overwritten with
    ``mask_value`` at random positions. Each mask is skipped when the
    spectrogram is not larger than the mask along that axis.

    Args:
        mel: 2-D array laid out as (mel bins, frames). It is not modified.
        time_mask_param: Width of the time mask, in frames.
        freq_mask_param: Height of the frequency mask, in mel bins.
        augment: When False, ``mel`` is returned unchanged (the same object).
        mask_value: Fill value for masked cells. Defaults to the mean of
            ``mel``. For dB spectrograms computed with ``ref=np.max`` all
            values are <= 0, so filling with 0 would turn the masked region
            into the loudest part of the clip instead of hiding it.

    Returns:
        A masked copy of ``mel``, or ``mel`` itself when ``augment`` is False.
    """
    if not augment:
        return mel

    mel = mel.copy()
    if mask_value is None:
        # Computed once, before any masking, so both masks share one fill value.
        mask_value = mel.mean()

    n_freq, n_time = mel.shape[0], mel.shape[1]

    # Time masking. randint's upper bound is exclusive, hence the +1: the start
    # is drawn from [0, n_time - width], so the mask can reach the final frame.
    if n_time > time_mask_param:
        t = np.random.randint(0, n_time - time_mask_param + 1)
        mel[:, t:t + time_mask_param] = mask_value

    # Frequency masking (same scheme along the mel axis).
    if n_freq > freq_mask_param:
        f = np.random.randint(0, n_freq - freq_mask_param + 1)
        mel[f:f + freq_mask_param, :] = mask_value

    return mel


def extract_melspectrogram(audio_path, augment=False):
    """Load one audio file and return its mel spectrogram in dB.

    The signal is loaded at ``SAMPLE_RATE`` and forced to exactly
    ``SAMPLE_RATE * DURATION`` samples: zero-padded at the end, or truncated.
    It is then turned into an ``N_MELS``-band mel power spectrogram and
    converted to dB relative to its own maximum. ``spec_augment`` is applied
    when ``augment`` is True.

    Returns:
        The dB spectrogram (mel bins x frames, as produced by librosa), or
        ``None`` if any step raises. The error is printed, not re-raised, so
        that one bad file does not abort loading the whole dataset.
    """
    try:
        signal, _ = librosa.load(audio_path, sr=SAMPLE_RATE, duration=DURATION)

        n_samples = SAMPLE_RATE * DURATION
        missing = n_samples - len(signal)
        signal = np.pad(signal, (0, missing)) if missing > 0 else signal[:n_samples]

        mel_power = librosa.feature.melspectrogram(
            y=signal,
            sr=SAMPLE_RATE,
            n_mels=N_MELS,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
        )
        mel_db = librosa.power_to_db(mel_power, ref=np.max)
        return spec_augment(mel_db, augment=augment)
    except Exception as err:
        print(f"Error loading {audio_path}: {str(err)}")
        return None


def create_segments(mel_spec, num_segments=None):
    """Split a spectrogram into equal, non-overlapping chunks along time.

    Each segment holds ``n_frames // num_segments`` frames. Trailing frames
    that do not fill a whole segment are dropped; for example, 1292 frames
    split into 15 segments gives 15 x 86 frames and discards the last 2.

    Args:
        mel_spec: 2-D array laid out as (mel bins, frames).
        num_segments: Number of segments. ``None`` (the default) uses the
            ``NUM_SEGMENTS`` config constant, looked up at call time. Re-running
            the configuration cell therefore takes effect without redefining
            this function.

    Returns:
        A new array (not a view of ``mel_spec``) of shape
        (num_segments, mel bins, n_frames // num_segments).

    Raises:
        ValueError: If ``num_segments`` is not positive, or exceeds the number
            of frames (which would produce zero-width segments).
    """
    if num_segments is None:
        num_segments = NUM_SEGMENTS
    if num_segments <= 0:
        raise ValueError(f"num_segments must be positive, got {num_segments}")

    n_bins, n_frames = mel_spec.shape
    seg_len = n_frames // num_segments
    if seg_len == 0:
        raise ValueError(
            f"Cannot split {n_frames} frames into {num_segments} non-empty segments"
        )

    # Drop the remainder frames, view the result as (bins, segments, seg_len),
    # then move the segment axis first. copy() yields a C-contiguous array that
    # does not alias the caller's spectrogram.
    usable = mel_spec[:, :num_segments * seg_len]
    return usable.reshape(n_bins, num_segments, seg_len).transpose(1, 0, 2).copy()


print("Feature extraction functions defined.")

## 5. Load Dataset

In [None]:
def load_data(data_path, augment=False):
    """Walk the per-genre folders under ``data_path`` and build the dataset.

    Every ``.wav`` file in each genre folder (processed in sorted order) is
    turned into a mel spectrogram and split into segments. The known corrupted
    ``jazz.00054.wav`` is skipped, as are files whose extraction returns
    ``None``. A file's label is the index of its genre in ``GENRES``.

    Note: ``augment`` is forwarded to the extractor. When True, every loaded
    clip is masked once here, before any train/validation/test split.

    Returns:
        (X, y): X stacks the per-file segment arrays; y holds the integer labels.
    """
    features, labels = [], []

    print("Loading dataset...\n")
    for label, genre in enumerate(GENRES):
        genre_dir = os.path.join(data_path, genre)
        if not os.path.exists(genre_dir):
            print(f"Warning: {genre} folder not found")
            continue

        wav_names = sorted(name for name in os.listdir(genre_dir) if name.endswith('.wav'))
        print(f"Processing {genre} ({len(wav_names)} files)...")

        for wav_name in tqdm(wav_names, desc=genre):
            # Known corrupted file in the GTZAN distribution
            if wav_name == 'jazz.00054.wav':
                print(f"  Skipping corrupted file: {wav_name}")
                continue

            mel = extract_melspectrogram(os.path.join(genre_dir, wav_name), augment=augment)
            if mel is None:
                continue
            features.append(create_segments(mel))
            labels.append(label)

    X = np.array(features)
    y = np.array(labels)

    print(f"\nLoaded {len(X)} samples")
    print(f"Shape: {X.shape}")

    return X, y


# Load data without augmentation (the train/val/test split happens later)
X, y = load_data(DATA_PATH, augment=False)

## 6. Preprocessing

In [None]:
# Encode labels.
# load_data already emits genre indices (position in GENRES). Fitting the encoder on
# the full index range, rather than on the labels actually observed, keeps it an
# identity mapping even if a genre folder was missing. Label i then always means
# GENRES[i], and the GENRES-based report and plot labels stay aligned.
label_encoder = LabelEncoder().fit(np.arange(NUM_CLASSES))
y_encoded = label_encoder.transform(y)
y_onehot = to_categorical(y_encoded, NUM_CLASSES)

# Add a trailing channel dimension for Conv2D: (N, segments, mels, frames) -> (..., 1).
# Guarded so that re-running this cell does not append a second channel axis.
if X.ndim == 4:
    X = X[..., np.newaxis]

print(f"X shape: {X.shape}")
print(f"y shape: {y_onehot.shape}")
print(f"Genre mapping: {dict(enumerate(GENRES))}")

## 7. Train/Val/Test Split

In [None]:
# 80 / 10 / 10 split, stratified by genre at both stages.
# Stage 1 holds out 10% of the data for test. Stage 2 takes 0.111 (~1/9) of the
# remaining 90% for validation, i.e. ~10% of the full set.
X_temp, X_test, y_temp, y_test = train_test_split(
    X,
    y_onehot,
    test_size=0.1,
    stratify=y_encoded,
    random_state=42,
)

y_temp_enc = y_temp.argmax(axis=1)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.111,
    stratify=y_temp_enc,
    random_state=42,
)

for split_name, split_X in (("Train", X_train), ("Val", X_val), ("Test", X_test)):
    print(f"{split_name + ':':<6} {split_X.shape[0]} samples")

## 8. Normalization

In [None]:
# Standardize with one global mean/std taken from the training split only,
# so no statistics from validation/test data influence preprocessing.
train_mean, train_std = X_train.mean(), X_train.std()
norm_denominator = train_std + 1e-8  # epsilon guards against a zero std

X_train, X_val, X_test = (
    (split - train_mean) / norm_denominator for split in (X_train, X_val, X_test)
)

print(f"Normalization - Mean: {train_mean:.2f}, Std: {train_std:.2f}")
print(f"After norm - Mean: {X_train.mean():.4f}, Std: {X_train.std():.4f}")

## 9. Build CNN + Multi-Head Attention Model

In [None]:
def build_cnn_attention_model(input_shape, residual_attention=True):
    """
    Build and compile the CNN + multi-head self-attention classifier.

    Architecture:
    1. A shared per-segment CNN applied with TimeDistributed. It has one block
       per entry of CNN_FILTERS: Conv2D(3x3) -> BatchNorm -> ELU -> MaxPool(2)
       -> Dropout(0.25). Global average pooling then yields one feature
       vector per segment.
    2. Multi-head self-attention across the segment vectors. With
       ``residual_attention=True`` (the default) the attention is wrapped
       Transformer-style as LayerNorm(features + Attention(features)), which
       keeps a direct path from the CNN features to the classifier.
    3. Global average pooling over segments -> two dense layers -> softmax.

    No positional encoding is added, and both self-attention and average
    pooling are order-agnostic. Predictions therefore do not depend on the
    order of the segments.

    Args:
        input_shape: Shape of one sample, (num_segments, mel bins, frames, channels).
        residual_attention: Set to False to reproduce the earlier variant,
            which fed the raw attention output straight into pooling.

    Returns:
        (model, segment_cnn): the compiled full model and the shared
        per-segment CNN sub-model.
    """
    segment_shape = input_shape[1:]

    # ==================== SEGMENT CNN ====================
    seg_input = layers.Input(shape=segment_shape)
    x = seg_input

    # CNN blocks with increasing filters; each MaxPooling2D(2) halves both spatial axes
    for filters in CNN_FILTERS:
        x = layers.Conv2D(
            filters, 3, padding='same',
            kernel_regularizer=regularizers.l2(L2_REG)
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('elu')(x)
        x = layers.MaxPooling2D(2)(x)
        x = layers.Dropout(0.25)(x)

    # Collapse the remaining (mel, time) grid into one vector of CNN_FILTERS[-1] values
    x = layers.GlobalAveragePooling2D()(x)

    segment_cnn = Model(seg_input, x, name='segment_cnn')

    # ==================== FULL MODEL ====================
    inputs = layers.Input(shape=input_shape, name='input')

    # Same CNN weights applied to every segment -> (batch, num_segments, features)
    features = layers.TimeDistributed(segment_cnn, name='time_distributed_cnn')(inputs)

    # Multi-head self-attention: query and value are both the segment features
    attn_output = layers.MultiHeadAttention(
        num_heads=ATTENTION_HEADS,
        key_dim=KEY_DIM,
        dropout=0.1,
        name='multi_head_attention'
    )(features, features)

    if residual_attention:
        # Post-norm residual, as in a Transformer encoder sub-layer. By default the
        # attention output is projected back to the query's last dimension, so the
        # shapes match for the addition.
        attn_output = layers.Add(name='attention_residual')([features, attn_output])
        attn_output = layers.LayerNormalization(name='attention_norm')(attn_output)

    # Global pooling over the segment (temporal) axis
    x = layers.GlobalAveragePooling1D(name='global_pool')(attn_output)

    # Classification head
    x = layers.Dense(
        DENSE_UNITS,
        kernel_regularizer=regularizers.l2(L2_REG)
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('elu')(x)
    x = layers.Dropout(DROPOUT_RATE)(x)

    x = layers.Dense(
        DENSE_UNITS // 2,
        kernel_regularizer=regularizers.l2(L2_REG)
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('elu')(x)
    x = layers.Dropout(DROPOUT_RATE)(x)

    # Output
    outputs = layers.Dense(NUM_CLASSES, activation='softmax', name='output')(x)

    # Build and compile
    model = Model(inputs, outputs, name='cnn_multihead_attention')

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model, segment_cnn


# Build model
input_shape = X_train.shape[1:]
print(f"Input shape: {input_shape}\n")

model, segment_cnn = build_cnn_attention_model(input_shape)
model.summary()

## 10. Training

In [None]:
# Callbacks:
#  - EarlyStopping: stop once val_accuracy has not improved for 20 epochs, restoring the best weights
#  - ReduceLROnPlateau: halve the learning rate after 7 epochs without val_loss improvement (floor 1e-6)
#  - ModelCheckpoint: keep the best-val_accuracy model on disk for the evaluation step
early_stop = EarlyStopping(
    monitor='val_accuracy', patience=20, restore_best_weights=True, verbose=1
)
lr_schedule = ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=7, min_lr=1e-6, verbose=1
)
best_checkpoint = ModelCheckpoint(
    'best_cnn_attention.keras', monitor='val_accuracy', save_best_only=True, verbose=1
)
callbacks = [early_stop, lr_schedule, best_checkpoint]

# Train
print("\nStarting training...\n")
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

## 11. Training History

In [None]:
# Training curves: accuracy on the left, loss on the right, train vs. validation in each
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

for ax, metric, title in ((ax1, 'accuracy', 'Model Accuracy'), (ax2, 'loss', 'Model Loss')):
    ax.plot(history.history[metric], label='Train', linewidth=2)
    ax.plot(history.history[f'val_{metric}'], label='Validation', linewidth=2)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric.capitalize())
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()

# Best validation accuracy and the (1-based) epoch that reached it first
val_acc_curve = history.history['val_accuracy']
best_val_acc = max(val_acc_curve)
best_epoch = val_acc_curve.index(best_val_acc) + 1
print(f"\nBest Validation Accuracy: {best_val_acc:.4f} (Epoch {best_epoch})")

## 12. Evaluation

In [None]:
# Reload the checkpoint with the best val_accuracy, then score it on the held-out test set
model.load_weights('best_cnn_attention.keras')
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)

rule = "=" * 60
print("\n" + rule)
print(f"TEST ACCURACY: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"TEST LOSS: {test_loss:.4f}")
print(rule)

## 13. Classification Report

In [None]:
# Predict class probabilities, then reduce predictions and one-hot targets to class indices
y_pred = model.predict(X_test, verbose=0)
y_pred_labels = np.argmax(y_pred, axis=1)
y_true_labels = np.argmax(y_test, axis=1)

# Classification report.
# Passing the full label range keeps each row paired with its GENRES name, even if
# a genre is absent from the test split or never predicted. Without it, sklearn infers
# labels from the data, and a mismatch with the 10 target_names raises an error.
print("\nClassification Report:")
print("=" * 60)
print(classification_report(
    y_true_labels,
    y_pred_labels,
    labels=np.arange(len(GENRES)),
    target_names=GENRES,
    digits=3
))

## 14. Confusion Matrix

In [None]:
# Confusion matrix over the fixed label range. It is then always
# len(GENRES) x len(GENRES), with row/column i corresponding to GENRES[i],
# even if some genre never appears in the test labels or predictions.
cm = confusion_matrix(y_true_labels, y_pred_labels, labels=np.arange(len(GENRES)))

plt.figure(figsize=(12, 10))
sns.heatmap(
    cm, annot=True, fmt='d', cmap='Blues',
    xticklabels=GENRES, yticklabels=GENRES,
    cbar_kws={'label': 'Count'}
)
plt.xlabel('Predicted Genre', fontsize=12, fontweight='bold')
plt.ylabel('True Genre', fontsize=12, fontweight='bold')
plt.title(f'Confusion Matrix (Accuracy: {test_acc:.2%})', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

## 15. Save Model and Results

In [None]:
# Persist the model (at this point it holds the best checkpoint reloaded for evaluation)
model.save('cnn_attention_final.keras')

# The training-set statistics are needed to normalize new audio the same way at inference time
np.savez('normalization_params.npz', mean=train_mean, std=train_std)

# history.history is a dict, so np.save stores it as a pickled 0-d object array;
# read it back with np.load('training_history.npy', allow_pickle=True).item()
np.save('training_history.npy', history.history)

saved_artifacts = (
    'cnn_attention_final.keras',
    'best_cnn_attention.keras',
    'normalization_params.npz',
    'training_history.npy',
    'training_history.png',
    'confusion_matrix.png',
)
print("\nSaved files:")
for artifact in saved_artifacts:
    print(f"  ✓ {artifact}")

## Summary

This notebook implements a CNN + Multi-Head Attention model for music genre classification:

**Key Features:**
- 15 temporal segments (2 seconds each)
- 4-layer CNN feature extractor per segment
- 8-head self-attention mechanism
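- SpecAugment-style masking helper. It is disabled in the default run; passing `augment=True` to `load_data` would mask every clip once, including clips that later land in the validation/test splits.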
- Learning rate scheduling
- Early stopping with best weight restoration

**Architecture:**
```
Input (15 segments) 
  → TimeDistributed CNN (32→64→128→256 filters)
  → Multi-Head Attention (8 heads)
  → Global Pool
  → Dense (256→128)
  → Output (10 genres)
```

**Expected Performance:**
- Target accuracy: 75-85% on test set
- The attention layer can re-weight segments by their content. No positional encoding is used, so the model does not take the order of the segments into account.

**Next Steps:**
- Try different numbers of segments (10, 20)
- Experiment with attention heads (4, 16)
- Add more data augmentation
- Fine-tune hyperparameters