# Method 3: CNN + Multi-Head Attention (IMPROVED)

**Improvements over the original version:**
- Enhanced CNN architecture with 4 convolutional blocks
- ELU activation for smoother gradients
- Progressive dropout (0.2 → 0.4) for better regularization
- Optimized attention heads (8 heads, d_model=256)
- Two-layer classification head (256 → 128 → 10)
- Learning rate scheduler (ReduceLROnPlateau)
- Increased early stopping patience (15 epochs)

## 1. Imports

In [None]:
import numpy as np
import librosa
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, MaxPooling2D,
    GlobalAveragePooling2D, GlobalAveragePooling1D,
    Dense, Dropout, TimeDistributed, Layer
)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

# Figure appearance: matplotlib style sheet first, seaborn palette second
# (order is kept as-is because both calls touch the colour cycle)
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Seed NumPy and TensorFlow with the same fixed value
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(42)

# Report the runtime environment
print(f'TensorFlow version: {tf.__version__}')
gpu_devices = tf.config.list_physical_devices("GPU")
print(f'GPU Available: {len(gpu_devices) > 0}')

## 2. Configuration

In [None]:
# Dataset path: root directory handed to load_gtzan_segmented(), which reads
# one sub-folder per genre from it.
DATA_PATH = '../../Data/genres_original'

# Audio parameters (same values as the defaults of extract_melspectrogram)
# Sampling rate audio is loaded at.
SAMPLE_RATE = 22050
# FFT window size used for the mel-spectrogram.
N_FFT = 2048
# Hop between successive spectrogram frames, in samples.
HOP_LENGTH = 512
# Number of mel bands (height of each spectrogram).
N_MELS = 64

# Segmentation parameters
# Every spectrogram is padded/truncated to this many frames before segmenting.
TARGET_LENGTH = 1291  # frames
# 87 frames * 512 hop / 22050 Hz ≈ 2.02 s per segment.
SEGMENT_LENGTH = 87   # ~2 seconds
# Number of segments kept per clip (second axis of the model input).
NUM_SEGMENTS = 15
# Fraction of each segment shared with the next one.
OVERLAP = 0.75

# Model parameters
# Number of attention heads; D_MODEL must be divisible by it
# (MultiHeadAttention computes depth = d_model // num_heads).
NUM_HEADS = 8
# Width of the attention projections.
D_MODEL = 256
# Mini-batch size and maximum number of training epochs.
BATCH_SIZE = 32
EPOCHS = 100
# Initial learning rate for the Adam optimizer.
LEARNING_RATE = 0.001

## 3. Data Loading Functions

In [None]:
def extract_melspectrogram(audio_path, sr=22050, n_fft=2048, hop_length=512, n_mels=64):
    """
    Load an audio file and return its mel-spectrogram in decibels.

    At most the first 30 seconds of the file are read. The mel power
    spectrogram is converted to dB relative to its own maximum, so the
    loudest bin maps to 0 dB.

    Args:
        audio_path: path of the audio file to read
        sr: sampling rate requested from librosa.load
        n_fft: FFT window size
        hop_length: hop between frames, in samples
        n_mels: number of mel bands

    Returns:
        (n_mels, n_frames) array of dB values
    """
    waveform, rate = librosa.load(audio_path, sr=sr, duration=30)
    mel_kwargs = dict(sr=rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    power_spec = librosa.feature.melspectrogram(y=waveform, **mel_kwargs)
    return librosa.power_to_db(power_spec, ref=np.max)

def segment_spectrogram(melspec, segment_length=87, overlap=0.75, num_segments=15):
    """
    Cut a mel-spectrogram into a fixed number of overlapping time windows.

    Windows start at frame 0 and advance by
    ``int(segment_length * (1 - overlap))`` frames (at least 1). Only the
    first ``num_segments`` windows are kept; if the spectrogram is too short
    to yield that many, the remaining slots are left as all-zero segments.

    Note: with the defaults (hop 21, 15 windows of 87 frames) the windows
    span frames 0..380 only, i.e. the beginning of a 1291-frame clip.

    Args:
        melspec: (n_mels, n_frames) mel-spectrogram
        segment_length: window length in frames (87 frames ≈ 2 seconds)
        overlap: fraction of a window shared with the next one (0.75 = 75%)
        num_segments: number of windows to return

    Returns:
        (num_segments, segment_length, n_mels) array with the dtype of
        ``melspec``; (15, 87, 64) for the default settings.
    """
    frames = melspec.T  # (n_frames, n_mels)
    n_frames, n_mels = frames.shape

    # A very high overlap would truncate the hop to 0, which range() rejects;
    # advance by at least one frame instead.
    hop = max(1, int(segment_length * (1 - overlap)))

    # Pre-filled with zeros so that missing windows are already "padded",
    # using the real mel-bin count rather than a hard-coded 64.
    segments = np.zeros((num_segments, segment_length, n_mels), dtype=frames.dtype)

    starts = range(0, n_frames - segment_length + 1, hop)[:num_segments]
    for slot, start in enumerate(starts):
        segments[slot] = frames[start:start + segment_length, :]

    return segments

def load_gtzan_segmented(data_path, target_length=1291):
    """
    Load the GTZAN dataset and turn every clip into segmented mel-spectrograms.

    For each of the ten genre sub-folders of ``data_path`` every ``.wav`` file
    is converted to a dB mel-spectrogram, padded or truncated to
    ``target_length`` frames and cut into overlapping segments with
    ``segment_spectrogram`` (default settings). Files that fail to process
    are reported and skipped; ``jazz.00054.wav`` is always skipped.

    Args:
        data_path: directory containing one sub-folder per genre
        target_length: number of spectrogram frames every clip is fitted to

    Returns:
        features: array with one entry of segments per successfully loaded file
        labels: integer genre index per file (position in ``genres``)
        genres: list of genre names, index-aligned with the labels
    """
    genres = ['blues', 'classical', 'country', 'disco', 'hiphop',
              'jazz', 'metal', 'pop', 'reggae', 'rock']

    features = []
    labels = []

    print("Loading GTZAN dataset and creating segmented spectrograms...")
    for genre_idx, genre in enumerate(genres):
        print(f"Processing {genre}...")
        genre_path = os.path.join(data_path, genre)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order (and therefore the seeded train/val/test split)
        # identical across machines and runs.
        files = sorted(f for f in os.listdir(genre_path) if f.endswith('.wav'))

        for filename in tqdm(files, desc=f"{genre}"):
            if filename == 'jazz.00054.wav':
                print(f"Skipping corrupted file: {filename}")
                continue

            filepath = os.path.join(genre_path, filename)
            try:
                melspec = extract_melspectrogram(filepath)

                # Pad or truncate to target length
                n_frames = melspec.shape[1]
                if n_frames < target_length:
                    # The spectrogram is in dB relative to its maximum, so 0
                    # would mean "loudest". Pad with the quietest value seen
                    # so the padding represents silence.
                    fill_value = melspec.min() if melspec.size else 0.0
                    melspec = np.pad(
                        melspec,
                        ((0, 0), (0, target_length - n_frames)),
                        mode='constant',
                        constant_values=fill_value,
                    )
                else:
                    melspec = melspec[:, :target_length]

                # Segment into 15 overlapping segments
                segments = segment_spectrogram(melspec)
            except Exception as e:
                # Best effort: one unreadable file must not abort the load.
                print(f"Error processing {filepath}: {e}")
            else:
                features.append(segments)
                labels.append(genre_idx)

    return np.array(features), np.array(labels), genres

## 4. Load Data

In [None]:
# Build the feature tensor, integer labels and genre list from disk
X, y, genre_names = load_gtzan_segmented(DATA_PATH)

print("\n✓ Dataset loaded")
# Expected with the full dataset: features (999, 15, 87, 64), labels (999,)
for caption, value in (
    ("Features shape", X.shape),
    ("Labels shape", y.shape),
    ("Genres", genre_names),
):
    print(f"  {caption}: {value}")

## 5. Train/Val/Test Split

In [None]:
# 70/15/15 stratified split: hold out 30% first, then halve it into val/test
split_seed = 42
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y,
    test_size=0.3,
    random_state=split_seed,
    stratify=y,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.5,
    random_state=split_seed,
    stratify=y_temp,
)

# Labels are left-aligned in a 6-character column so the counts line up
for split_name, split_features in (("Train", X_train), ("Val", X_val), ("Test", X_test)):
    print(f"{split_name + ':':<6} {split_features.shape[0]} samples")

## 6. Normalize Data

In [None]:
# Size of one flattened segment and number of classes, read from the data
# instead of being hard-coded (87*64 and 10 for the default configuration).
flat_dim = int(np.prod(X_train.shape[2:]))
num_classes = len(genre_names)

# Reshape for normalization: one row per segment, one column per
# time-frequency bin
X_train_flat = X_train.reshape(-1, flat_dim)
X_val_flat = X_val.reshape(-1, flat_dim)
X_test_flat = X_test.reshape(-1, flat_dim)

# Normalize: statistics are fitted on the training segments only and then
# applied unchanged to validation and test data
scaler = StandardScaler()
X_train_norm = scaler.fit_transform(X_train_flat).reshape(X_train.shape)
X_val_norm = scaler.transform(X_val_flat).reshape(X_val.shape)
X_test_norm = scaler.transform(X_test_flat).reshape(X_test.shape)

# Add channel dimension
X_train_norm = X_train_norm[..., np.newaxis]  # (n, segments, frames, mels, 1)
X_val_norm = X_val_norm[..., np.newaxis]
X_test_norm = X_test_norm[..., np.newaxis]

# One-hot encode labels. The variables are overwritten in place, so only
# encode while they are still 1-D integer labels; this keeps the cell safe
# to re-run without encoding the one-hot matrix a second time.
y_train = to_categorical(y_train, num_classes) if y_train.ndim == 1 else y_train
y_val = to_categorical(y_val, num_classes) if y_val.ndim == 1 else y_val
y_test = to_categorical(y_test, num_classes) if y_test.ndim == 1 else y_test

print("✓ Data normalized and encoded")
print(f"  X_train shape: {X_train_norm.shape}")
print(f"  y_train shape: {y_train.shape}")

## 7. Multi-Head Attention Layer

In [None]:
class MultiHeadAttention(Layer):
    """
    Multi-head self-attention layer for temporal weighting.

    Queries, keys and values are all linear projections of the same input
    sequence. The projections are split into ``num_heads`` heads of size
    ``d_model // num_heads``, scaled dot-product attention is applied per
    head, and the heads are concatenated and projected back to ``d_model``.

    Args:
        num_heads: number of attention heads (must be positive)
        d_model: width of the projections; must be divisible by ``num_heads``
        **kwargs: forwarded to the Keras ``Layer`` constructor (e.g. ``name``)

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``d_model``.

    Note: when reloading a saved model that contains this layer, pass
    ``custom_objects={'MultiHeadAttention': MultiHeadAttention}``.
    """
    def __init__(self, num_heads=8, d_model=128, **kwargs):
        super().__init__(**kwargs)
        # Fail early: otherwise the mismatch only surfaces as an opaque
        # reshape error the first time the layer is called.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'd_model': self.d_model,
        })
        return config

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Reshapes (batch, seq, d_model) to (batch, num_heads, seq, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """Apply self-attention over the sequence axis of ``inputs``.

        Args:
            inputs: tensor of shape (batch, seq, features)

        Returns:
            Tuple ``(output, attention_weights)`` where ``output`` has shape
            (batch, seq, d_model) and ``attention_weights`` has shape
            (batch, num_heads, seq, seq).
        """
        batch_size = tf.shape(inputs)[0]

        # Linear projections
        q = self.wq(inputs)
        k = self.wk(inputs)
        v = self.wv(inputs)

        # Split into multiple heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Scaled dot-product attention: divide by sqrt(depth) before softmax
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(self.depth, tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Attention weights (normalized over the key axis)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

        # Apply attention to values
        output = tf.matmul(attention_weights, v)

        # Concatenate heads: back to (batch, seq, d_model)
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        output = tf.reshape(output, (batch_size, -1, self.d_model))

        # Final linear projection
        output = self.dense(output)

        return output, attention_weights

## 8. Improved CNN Feature Extractor

In [None]:
def build_cnn_feature_extractor(input_shape=(87, 64, 1)):
    """
    Build the CNN that extracts a feature vector from one 2-second segment.

    Architecture:
    - 4 convolutional blocks (32→64→128→256 filters, 3x3 kernels)
    - ELU activation inside each convolution
    - Batch normalization after each (activated) convolution
    - 2x2 max pooling after the first three blocks, global average pooling
      after the last one
    - Progressive dropout (0.2 → 0.2 → 0.3 → 0.4)

    Args:
        input_shape: (frames, mel bins, channels) of a single segment;
            defaults to (87, 64, 1). It must be large enough to survive
            three 2x2 poolings.

    Returns:
        Keras Model named 'cnn_feature_extractor' mapping a segment to a
        (256,) feature vector.
    """
    inputs = Input(shape=input_shape)

    # Conv blocks 1-3 share one layout and differ only in width and dropout.
    x = inputs
    for filters, drop_rate in ((32, 0.2), (64, 0.2), (128, 0.3)):
        x = Conv2D(filters, (3, 3), padding='same', activation='elu')(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D((2, 2))(x)
        x = Dropout(drop_rate)(x)

    # Conv block 4 (final): collapse the spatial axes into one vector
    x = Conv2D(256, (3, 3), padding='same', activation='elu')(x)
    x = BatchNormalization()(x)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.4)(x)

    return Model(inputs, x, name='cnn_feature_extractor')

## 9. Build CNN + Attention Model

In [None]:
def build_cnn_attention_model(num_segments=NUM_SEGMENTS, num_heads=NUM_HEADS,
                              d_model=D_MODEL, num_classes=10,
                              learning_rate=LEARNING_RATE):
    """
    Build and compile the improved CNN + Multi-Head Attention model.

    Architecture:
    - CNN feature extractor applied to every segment (TimeDistributed)
    - Multi-head self-attention across the segments
    - Global average pooling over the segment axis
    - Dense classifier (256 → 128 → num_classes, softmax output)

    The defaults come from the notebook configuration, so calling the
    function without arguments builds the 15-segment, 8-head, d_model=256
    model trained with Adam at learning rate 0.001.

    Args:
        num_segments: number of segments per clip
        num_heads: number of attention heads
        d_model: width of the attention projections
        num_classes: number of output classes
        learning_rate: initial learning rate of the Adam optimizer

    Returns:
        Compiled Keras Model named 'cnn_attention_model_improved'.
    """
    cnn_extractor = build_cnn_feature_extractor()

    # Take the per-segment shape from the extractor itself so the two
    # definitions cannot drift apart (drop its leading batch axis).
    segment_shape = tuple(cnn_extractor.input_shape[1:])
    segment_input = Input(shape=(num_segments, *segment_shape), name='segment_input')

    # Apply CNN to each segment independently
    cnn_features = TimeDistributed(cnn_extractor, name='time_distributed_cnn')(segment_input)

    # Multi-head attention; the per-head attention weights are not used
    # by the classifier.
    attn_layer = MultiHeadAttention(num_heads=num_heads, d_model=d_model,
                                    name='multi_head_attention')
    attn_output, _attn_weights = attn_layer(cnn_features)

    # Global pooling over temporal dimension
    x = GlobalAveragePooling1D(name='global_avgpool')(attn_output)

    # Classification head
    x = Dense(256, activation='elu', name='dense_1')(x)
    x = Dropout(0.5, name='dropout_1')(x)
    x = Dense(128, activation='elu', name='dense_2')(x)
    x = Dropout(0.3, name='dropout_2')(x)
    outputs = Dense(num_classes, activation='softmax', name='output')(x)

    # Build model
    model = Model(segment_input, outputs, name='cnn_attention_model_improved')

    # Compile (labels are expected one-hot encoded)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Build model: instantiate and compile the CNN + attention network with the
# function's default settings.
model = build_cnn_attention_model()

# Display model architecture (layer list, output shapes, parameter counts)
model.summary()

## 10. Training

In [None]:
# Callbacks
# Stop once validation loss has not improved for 15 epochs and roll the
# model back to the weights of its best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

# Halve the learning rate after 5 epochs without validation-loss
# improvement, never going below 1e-6.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    verbose=1
)

# Train model. Batch size and epoch budget come from the configuration cell
# (BATCH_SIZE = 32, EPOCHS = 100) instead of being repeated as literals.
history = model.fit(
    X_train_norm, y_train,
    validation_data=(X_val_norm, y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stop, reduce_lr],
    verbose=1
)

## 11. Training History

In [None]:
# Plot training history: one panel per metric, train vs. validation curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (history key, human-readable label) for the left and right panel
panels = (('loss', 'Loss'), ('accuracy', 'Accuracy'))

for ax, (metric_key, metric_label) in zip(axes, panels):
    train_curve = history.history[metric_key]
    val_curve = history.history[f'val_{metric_key}']

    ax.plot(train_curve, label=f'Train {metric_label}', linewidth=2)
    ax.plot(val_curve, label=f'Val {metric_label}', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(metric_label, fontsize=12)
    ax.set_title(f'Training & Validation {metric_label}', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show() so the figure is written out while it is still current
plt.savefig('cnn_attention_improved_training_history.png', dpi=300, bbox_inches='tight')
plt.show()

## 12. Evaluation

In [None]:
# Loss and accuracy on the held-out test set
test_loss, test_acc = model.evaluate(X_test_norm, y_test, verbose=0)

banner = '=' * 50
print(f"\n{banner}")
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Test Loss: {test_loss:.4f}")
print(f"{banner}\n")

# Class probabilities -> predicted class index; one-hot targets -> true index
y_pred = model.predict(X_test_norm, verbose=0)
y_pred_labels = y_pred.argmax(axis=1)
y_true_labels = y_test.argmax(axis=1)

# Per-genre precision / recall / F1
report = classification_report(y_true_labels, y_pred_labels, target_names=genre_names)
print("\nClassification Report:")
print(report)

## 13. Confusion Matrix

In [None]:
# Confusion matrix: rows are true genres, columns are predicted genres
cm = confusion_matrix(y_true_labels, y_pred_labels)

bold = {'fontweight': 'bold'}
heatmap_options = {
    'annot': True,
    'fmt': 'd',
    'cmap': 'Blues',
    'xticklabels': genre_names,
    'yticklabels': genre_names,
    'cbar_kws': {'label': 'Count'},
}

plt.figure(figsize=(12, 10))
sns.heatmap(cm, **heatmap_options)
plt.xlabel('Predicted Genre', fontsize=12, **bold)
plt.ylabel('True Genre', fontsize=12, **bold)
plt.title('CNN + Attention (Improved) - Confusion Matrix', fontsize=14, **bold)
# Slant the column labels so long genre names do not collide
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('cnn_attention_improved_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

## 14. Save Model

In [None]:
# Persist the trained model in the native Keras format
model_path = 'cnn_attention_improved.keras'
model.save(model_path)
print(f"✓ Model saved: {model_path}")

# Persist the per-epoch metrics. history.history is a dict, so NumPy stores
# it as a pickled object array; read it back with
# np.load(path, allow_pickle=True).item().
history_path = 'cnn_attention_improved_history.npy'
np.save(history_path, history.history)
print(f"✓ Training history saved: {history_path}")

## Summary

This improved CNN + Multi-Head Attention model includes:

**Architecture Improvements:**
- 4-block CNN feature extractor (32→64→128→256 filters)
- ELU activation functions for smoother gradients
- Batch normalization after each convolution
- Progressive dropout (0.2 → 0.4) for regularization
- 8-head attention mechanism with d_model=256
- Two-layer classification head (256→128→10)

**Training Improvements:**
- Learning rate scheduler (ReduceLROnPlateau)
- Increased early stopping patience (15 epochs)
- Adam optimizer with initial LR=0.001

**Expected Performance:**
- Target accuracy: 70-85% (a significant improvement over the 64% baseline)
- Better generalization through enhanced regularization
- More stable training with learning rate scheduling