# CNN + Attention with Image-Based Attention Improvements

**Enhanced CNN + Multi-Head Attention Model**

This notebook improves the CNN+Attention method by treating spectrograms as images:
- **CBAM attention** in CNN blocks (channel + spatial)
- **2D spatial attention** across frequency and time
- **Multi-view spectrograms** (mel + CQT + chroma as 3 channels)
- **Temporal multi-head attention** over segments

**Architecture:** CNN (with image attention) → Temporal Segmentation → Multi-Head Attention → Classification

**Expected Performance:** 85-92% accuracy

## 1. Imports

In [None]:
import os
import numpy as np
import librosa
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow.keras import layers, Model, regularizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Reproducibility: seed both NumPy and TensorFlow RNGs with the same value
tf.random.set_seed(42)
np.random.seed(42)

# Apply the matplotlib style first, then the seaborn palette, so the palette
# is not overwritten when the style resets rcParams.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Report the runtime environment
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow: {tf.__version__}")
print(f"GPU Available: {len(gpu_devices) > 0}")

## 2. Configuration

In [None]:
# ==================== PATHS ====================
# Dataset root. Set the GTZAN_DATA_PATH environment variable to point the
# notebook at another location without editing this cell; the default is
# the original hard-coded path.
DATA_PATH = os.environ.get(
    'GTZAN_DATA_PATH',
    '/Users/narac0503/GIT/GTZAN Dataset Classification/GTZAN-Dataset-Classification/Data/genres_original',
)
print(f"Data exists: {os.path.exists(DATA_PATH)}")

# ==================== AUDIO ====================
SAMPLE_RATE = 22050   # Hz; passed to librosa.load as the target rate
DURATION = 30         # seconds of audio kept per track
N_MELS = 128          # mel bands (also used as the CQT bin count)
N_FFT = 2048
HOP_LENGTH = 512

# ==================== SEGMENTATION ====================
NUM_SEGMENTS = 15  # Still using temporal segments

# ==================== MODEL ====================
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop',
          'jazz', 'metal', 'pop', 'reggae', 'rock']
# Derived from GENRES so the two can never drift apart.
NUM_CLASSES = len(GENRES)

# ==================== HYPERPARAMETERS ====================
CNN_FILTERS = [32, 64, 128, 256]   # one Conv2D block per entry
ATTENTION_HEADS = 8
KEY_DIM = 32
DENSE_UNITS = 512
DROPOUT_RATE = 0.4
L2_REG = 0.0005
LEARNING_RATE = 0.0005
BATCH_SIZE = 16
EPOCHS = 100

## 3. Image-Based Attention Layers

In [None]:
class CBAM(layers.Layer):
    """Convolutional Block Attention Module for spectrograms.

    Applies channel attention followed by spatial attention to a 4-D
    feature map (batch, height, width, channels) and returns a tensor of
    the same shape.

    Args:
        reduction: Bottleneck ratio of the shared channel-attention MLP.
            The hidden width is ``channels // reduction``, floored at 1.
    """

    def __init__(self, reduction=16, **kwargs):
        super().__init__(**kwargs)
        self.reduction = reduction

    def build(self, input_shape):
        channels = input_shape[-1]
        # Never let the bottleneck collapse to zero units when the layer
        # is used on fewer channels than `reduction`.
        hidden = max(1, channels // self.reduction)

        # Channel attention: one MLP shared by the avg- and max-pooled descriptors
        self.fc1 = layers.Dense(hidden, activation='relu')
        self.fc2 = layers.Dense(channels)

        # Spatial attention: 7x7 conv over the [avg, max] channel summaries
        self.conv_spatial = layers.Conv2D(1, 7, padding='same')

        # Build the sub-layers here so their weights exist as soon as this
        # layer is built (needed when a saved model is rebuilt from config
        # and weights are restored before the first call).
        self.fc1.build((None, channels))
        self.fc2.build((None, hidden))
        self.conv_spatial.build((None, input_shape[1], input_shape[2], 2))
        super().build(input_shape)

    def call(self, x):
        # Channel attention: pool over the two spatial axes -> (batch, channels)
        avg_desc = tf.reduce_mean(x, axis=[1, 2])
        max_desc = tf.reduce_max(x, axis=[1, 2])

        avg_out = self.fc2(self.fc1(avg_desc))
        max_out = self.fc2(self.fc1(max_desc))

        channel_attn = tf.nn.sigmoid(avg_out + max_out)
        # Restore the two singleton spatial axes so the gate broadcasts over x
        channel_attn = channel_attn[:, tf.newaxis, tf.newaxis, :]
        x = x * channel_attn

        # Spatial attention: summarise across channels, then gate each position
        avg_spatial = tf.reduce_mean(x, axis=-1, keepdims=True)
        max_spatial = tf.reduce_max(x, axis=-1, keepdims=True)
        concat = tf.concat([avg_spatial, max_spatial], axis=-1)
        spatial_attn = tf.nn.sigmoid(self.conv_spatial(concat))

        return x * spatial_attn

    def get_config(self):
        """Return the layer config including `reduction` for serialization."""
        config = super().get_config()
        config.update({'reduction': self.reduction})
        return config


class SpatialAttention2D(layers.Layer):
    """2D attention over frequency and time dimensions.

    Expects input of shape (batch, freq, time, channels) with statically
    known freq and time sizes (they set the Dense widths in `build`).
    Returns a tensor of the same shape, scaled by a per-channel softmax
    weight along the frequency axis and another along the time axis.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        n_freq, n_time, n_channels = input_shape[1], input_shape[2], input_shape[-1]

        # Attention over frequency
        self.freq_dense = layers.Dense(n_freq)
        # Attention over time
        self.time_dense = layers.Dense(n_time)

        # Build the sub-layers here so their weights exist as soon as this
        # layer is built (needed when a saved model is rebuilt from config
        # and weights are restored before the first call). Each Dense sees
        # the pooled map transposed to (batch, channels, axis_len).
        self.freq_dense.build((None, n_channels, n_freq))
        self.time_dense.build((None, n_channels, n_time))
        super().build(input_shape)

    def call(self, x):
        # x: (batch, freq, time, channels)

        # Frequency attention
        freq_avg = tf.reduce_mean(x, axis=2)  # (B, F, C)
        freq_weights = self.freq_dense(tf.transpose(freq_avg, [0, 2, 1]))  # (B, C, F)
        freq_weights = tf.nn.softmax(freq_weights, axis=-1)
        freq_weights = tf.transpose(freq_weights, [0, 2, 1])  # (B, F, C)
        freq_weights = tf.expand_dims(freq_weights, axis=2)   # (B, F, 1, C)

        # Time attention
        time_avg = tf.reduce_mean(x, axis=1)  # (B, T, C)
        time_weights = self.time_dense(tf.transpose(time_avg, [0, 2, 1]))  # (B, C, T)
        time_weights = tf.nn.softmax(time_weights, axis=-1)
        time_weights = tf.transpose(time_weights, [0, 2, 1])  # (B, T, C)
        time_weights = tf.expand_dims(time_weights, axis=1)   # (B, 1, T, C)

        # Combine
        # NOTE(review): each softmax sums to 1 along its axis, so the product
        # shrinks activations by roughly 1/(F*T) when weights are near-uniform.
        # Consider rescaling by F*T if the downstream layers see tiny inputs.
        return x * freq_weights * time_weights

print("Image-based attention layers defined.")

## 4. Feature Extraction

In [None]:
def spec_augment(spec, time_mask=15, freq_mask=10, mask_value=None):
    """Apply SpecAugment-style time and frequency masking to a copy of `spec`.

    One contiguous band of `time_mask` columns and one of `freq_mask` rows
    are overwritten. An axis is left untouched when it is not longer than
    its mask width.

    Args:
        spec: Array whose axis 0 is frequency and axis 1 is time.
        time_mask: Width of the masked time band, in frames.
        freq_mask: Height of the masked frequency band, in bins.
        mask_value: Fill value for masked cells. Defaults to the mean of
            `spec`; a literal 0 would be the *maximum* of a dB spectrogram
            computed with ``ref=np.max`` and would inject a loud band
            instead of a neutral one.

    Returns:
        A masked copy of `spec`; the input is not modified.
    """
    spec = spec.copy()
    if mask_value is not None:
        fill = mask_value
    else:
        # Computed before any masking so both bands use the same value.
        fill = spec.mean() if spec.size else 0.0

    n_freq, n_time = spec.shape[0], spec.shape[1]

    if n_time > time_mask:
        # +1 because randint's upper bound is exclusive; without it the
        # last frame could never be covered by the mask.
        t = np.random.randint(0, n_time - time_mask + 1)
        spec[:, t:t + time_mask] = fill
    if n_freq > freq_mask:
        f = np.random.randint(0, n_freq - freq_mask + 1)
        spec[f:f + freq_mask, :] = fill
    return spec


def extract_multi_view_spectrogram(audio_path, augment=False):
    """Extract mel + CQT + chroma as 3-channel image.

    The audio is loaded at SAMPLE_RATE and padded/trimmed to exactly
    DURATION seconds, so every track yields the same number of frames.

    Args:
        audio_path: Path of the audio file to load.
        augment: When True, apply `spec_augment` independently to each view.

    Returns:
        Array of shape (N_MELS, n_frames, 3) with channels
        [mel dB, CQT dB, resized chroma], or None if anything fails
        (the error is printed).
    """
    from scipy.ndimage import zoom

    try:
        audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, duration=DURATION)

        target_len = SAMPLE_RATE * DURATION
        if len(audio) < target_len:
            audio = np.pad(audio, (0, target_len - len(audio)))
        else:
            audio = audio[:target_len]

        # Mel-spectrogram
        mel = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)
        n_frames = mel_db.shape[1]

        # CQT
        # With librosa's default 12 bins/octave starting at C1 (~32.7 Hz),
        # N_MELS=128 bins would extend far beyond the Nyquist frequency and
        # librosa rejects the request. Spread the bins over 8 octaves
        # instead (top bin ~8 kHz, below Nyquist at 22.05 kHz).
        # NOTE(review): the 8-octave budget assumes SAMPLE_RATE=22050 and
        # the default fmin; revisit if either changes.
        cqt_bins_per_octave = max(12, int(np.ceil(N_MELS / 8)))
        cqt = np.abs(librosa.cqt(
            y=audio, sr=sr, hop_length=HOP_LENGTH,
            n_bins=N_MELS, bins_per_octave=cqt_bins_per_octave
        ))
        cqt_db = librosa.amplitude_to_db(cqt, ref=np.max)

        # Match shape
        if cqt_db.shape[1] < n_frames:
            # Pad with the quietest value; 0 dB is the maximum when ref=np.max.
            cqt_db = np.pad(
                cqt_db, ((0, 0), (0, n_frames - cqt_db.shape[1])),
                constant_values=cqt_db.min()
            )
        else:
            cqt_db = cqt_db[:, :n_frames]

        # Chroma, stretched from its native bin count to the mel grid
        chroma = librosa.feature.chroma_cqt(y=audio, sr=sr, hop_length=HOP_LENGTH)
        chroma_resized = zoom(
            chroma, (N_MELS / chroma.shape[0], n_frames / chroma.shape[1])
        )

        # SpecAugment
        if augment:
            mel_db = spec_augment(mel_db)
            cqt_db = spec_augment(cqt_db)
            chroma_resized = spec_augment(chroma_resized, freq_mask=5)

        # Stack as 3-channel image
        multi_view = np.stack([mel_db, cqt_db, chroma_resized], axis=-1)
        return multi_view

    except Exception as e:
        # Best-effort loader: report the file and let the caller skip it.
        print(f"Error: {audio_path} - {e}")
        return None


def create_segments(spec, num_segments=NUM_SEGMENTS):
    """Split spectrogram into equal-length temporal segments.

    Args:
        spec: Array of shape (freq, frames, channels).
        num_segments: Number of consecutive segments to cut along axis 1.

    Returns:
        Array of shape (num_segments, freq, frames // num_segments, channels).
        Trailing frames that do not fill a whole segment are dropped.

    Raises:
        ValueError: If `num_segments` is < 1 or the spectrogram has fewer
            frames than segments (which would give zero-width segments).
    """
    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")

    n_frames = spec.shape[1]
    seg_len = n_frames // num_segments
    if seg_len == 0:
        raise ValueError(
            f"Spectrogram has {n_frames} frames; cannot form "
            f"{num_segments} non-empty segments"
        )

    # num_segments * seg_len <= n_frames by construction, so every slice is
    # full-width and no clamping or padding is needed.
    return np.stack([
        spec[:, i * seg_len:(i + 1) * seg_len, :]
        for i in range(num_segments)
    ])

print("Feature extraction ready.")

## 5. Load Data

In [None]:
def load_data(data_path):
    """Load every genre folder under `data_path` into segment arrays.

    For each genre in GENRES, every `.wav` file is converted to a multi-view
    spectrogram and cut into segments. Files whose extraction returns None
    are skipped, as are missing genre folders (with a warning).

    Returns:
        (X, y): stacked per-track segment arrays and the matching integer
        genre indices (positions in GENRES).
    """
    features, labels = [], []

    print("Loading data with multi-view spectrograms...\n")
    for genre_idx, genre in enumerate(GENRES):
        genre_path = os.path.join(data_path, genre)
        if not os.path.exists(genre_path):
            print(f"Warning: {genre} not found")
            continue

        files = sorted(
            name for name in os.listdir(genre_path) if name.endswith('.wav')
        )
        print(f"{genre}: {len(files)} files")

        for filename in tqdm(files, desc=genre):
            # Explicitly excluded file
            # NOTE(review): presumably the unreadable track in GTZAN — confirm
            if filename == 'jazz.00054.wav':
                continue

            spec = extract_multi_view_spectrogram(
                os.path.join(genre_path, filename)
            )
            if spec is None:
                continue

            features.append(create_segments(spec))
            labels.append(genre_idx)

    X = np.array(features)
    y = np.array(labels)

    print(f"\nLoaded {len(X)} samples")
    print(f"Shape: {X.shape}")
    return X, y

X, y = load_data(DATA_PATH)

## 6. Preprocessing

In [None]:
# Map labels to contiguous integer ids, then to one-hot rows
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit(y).transform(y)
y_onehot = to_categorical(y_encoded, num_classes=NUM_CLASSES)

# Report the array shapes going into the split
for tag, arr in (("X", X), ("y", y_onehot)):
    print(f"{tag}: {arr.shape}")

## 7. Train/Val/Test Split

In [None]:
# Stage 1: hold out 10% of all samples as the test set (stratified by genre)
X_temp, X_test, y_temp, y_test = train_test_split(
    X,
    y_onehot,
    test_size=0.1,
    stratify=y_encoded,
    random_state=42,
)

# Stage 2: carve validation out of the remainder. 0.111 of the remaining
# 90% is ~10% of the full dataset. Stratification needs integer labels,
# so recover them from the one-hot rows.
y_temp_enc = y_temp.argmax(axis=1)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.111,
    stratify=y_temp_enc,
    random_state=42,
)

print(f"Train: {len(X_train)}")
print(f"Val:   {len(X_val)}")
print(f"Test:  {len(X_test)}")

## 8. Normalization

In [None]:
# Statistics are taken over every axis except the last (channel) axis and
# from the training split only, so val/test never influence the scaling.
_stat_axes = (0, 1, 2, 3)
train_mean = X_train.mean(axis=_stat_axes, keepdims=True)
train_std = X_train.std(axis=_stat_axes, keepdims=True)


def _standardize(arr):
    """Standardize `arr` with the training-set mean/std (epsilon avoids /0)."""
    return (arr - train_mean) / (train_std + 1e-8)


X_train, X_val, X_test = (
    _standardize(split) for split in (X_train, X_val, X_test)
)

print(f"Normalized - Mean: {X_train.mean():.4f}, Std: {X_train.std():.4f}")

## 9. Build CNN + Attention Model with Image Attention

In [None]:
def build_cnn_with_image_attention(input_shape):
    """Build the per-segment CNN feature extractor.

    The input first passes through SpatialAttention2D, then through one
    block per entry of CNN_FILTERS
    (Conv2D -> BatchNorm -> ELU -> CBAM -> MaxPool -> Dropout), and is
    finally reduced to a feature vector by global average pooling.

    Args:
        input_shape: Shape of a single segment, without the batch axis.

    Returns:
        A Keras Model named 'cnn_image_attention'.
    """
    segment_in = layers.Input(shape=input_shape)

    # Frequency/time re-weighting happens on the raw input, before any conv
    features = SpatialAttention2D()(segment_in)

    for n_filters in CNN_FILTERS:
        features = layers.Conv2D(
            n_filters,
            3,
            padding='same',
            kernel_regularizer=regularizers.l2(L2_REG),
        )(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation('elu')(features)

        # Channel + spatial gating on the activated feature map
        features = CBAM(reduction=16)(features)

        features = layers.MaxPooling2D(2)(features)
        features = layers.Dropout(0.25)(features)

    pooled = layers.GlobalAveragePooling2D()(features)

    return Model(segment_in, pooled, name='cnn_image_attention')


def build_full_model(input_shape):
    """Assemble and compile the full segment-sequence classifier.

    Pipeline:
    1. The per-segment CNN (CBAM + 2D spatial attention) is applied to
       every segment through TimeDistributed.
    2. Multi-head self-attention runs over the sequence of segment features.
    3. The attended sequence is average-pooled and fed to a two-layer
       dense head ending in a softmax over NUM_CLASSES.

    Args:
        input_shape: (segments, *segment_shape), without the batch axis.

    Returns:
        A compiled Keras Model named 'cnn_attention_image_enhanced'.
    """
    # Everything after the leading segment axis is the shape of one segment
    segment_cnn = build_cnn_with_image_attention(input_shape[1:])

    inputs = layers.Input(shape=input_shape)

    # Same CNN weights are shared across all segments
    features = layers.TimeDistributed(segment_cnn)(inputs)

    # Self-attention over segments: query and value are the same sequence
    temporal_attention = layers.MultiHeadAttention(
        num_heads=ATTENTION_HEADS,
        key_dim=KEY_DIM,
        dropout=0.1,
    )
    attn_output = temporal_attention(features, features)

    # Collapse the segment axis
    x = layers.GlobalAveragePooling1D()(attn_output)

    # Classification head: two identical Dense -> BN -> ELU -> Dropout stages
    for units in (DENSE_UNITS, DENSE_UNITS // 2):
        x = layers.Dense(
            units,
            kernel_regularizer=regularizers.l2(L2_REG),
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('elu')(x)
        x = layers.Dropout(DROPOUT_RATE)(x)

    outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

    model = Model(inputs, outputs, name='cnn_attention_image_enhanced')
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


# Drop the sample axis: the model input is one track's stack of segments
model = build_full_model(X_train.shape[1:])
print("\nModel Summary:")
model.summary()

## 10. Training

In [None]:
# Stop when validation accuracy stalls and roll back to the best epoch
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=20,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate when validation loss plateaus
lr_schedule = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=7,
    min_lr=1e-7,
    verbose=1,
)

# Keep only the checkpoint with the best validation accuracy on disk
checkpoint = ModelCheckpoint(
    'best_cnn_attention_image.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)

callbacks = [early_stopping, lr_schedule, checkpoint]

print("\nTraining...\n")
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

## 11. Training History

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One panel per metric: (axis, train key, validation key, title / y-label)
_panels = (
    (ax1, 'accuracy', 'val_accuracy', 'Accuracy'),
    (ax2, 'loss', 'val_loss', 'Loss'),
)
for _ax, _train_key, _val_key, _title in _panels:
    _ax.plot(history.history[_train_key], label='Train', linewidth=2)
    _ax.plot(history.history[_val_key], label='Validation', linewidth=2)
    _ax.set_title(_title, fontsize=14, fontweight='bold')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_title)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): the figure may be cleared once it has been displayed
plt.savefig('cnn_attention_image_history.png', dpi=300)
plt.show()

best_val_acc = max(history.history['val_accuracy'])
print(f"\nBest Val Accuracy: {best_val_acc:.4f}")

## 12. Evaluation

In [None]:
# Evaluate the checkpoint with the best validation accuracy,
# not whatever weights the last epoch left behind.
model.load_weights('best_cnn_attention_image.keras')

test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)

_rule = "=" * 70
print("\n" + _rule)
print(f"TEST ACCURACY: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"TEST LOSS: {test_loss:.4f}")
print(_rule)

## 13. Classification Report & Confusion Matrix

In [None]:
# Class probabilities -> predicted / true class indices
y_pred = model.predict(X_test, verbose=0)
y_pred_labels = y_pred.argmax(axis=1)
y_true_labels = y_test.argmax(axis=1)

print("\nClassification Report:")
print("="*70)
report = classification_report(
    y_true_labels,
    y_pred_labels,
    target_names=GENRES,
    digits=3,
)
print(report)

# Rows are true genres, columns are predicted genres
cm = confusion_matrix(y_true_labels, y_pred_labels)

plt.figure(figsize=(12, 10))
_heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=GENRES,
    yticklabels=GENRES,
    cbar_kws={'label': 'Count'},
)
sns.heatmap(cm, **_heatmap_options)

_label_font = dict(fontsize=12, fontweight='bold')
plt.xlabel('Predicted', **_label_font)
plt.ylabel('True', **_label_font)
plt.title(
    f'CNN + Attention with Image Attention (Acc: {test_acc:.2%})',
    fontsize=14,
    fontweight='bold',
)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('cnn_attention_image_cm.png', dpi=300)
plt.show()

## 14. Save Model

In [None]:
# Final model, the normalization statistics needed at inference time,
# and the training curves.
model.save('cnn_attention_image_final.keras')
np.savez('cnn_attention_image_norm.npz', mean=train_mean, std=train_std)
# The history dict is stored as a pickled object array, so reading it back
# requires np.load(..., allow_pickle=True).
np.save('cnn_attention_image_history.npy', history.history)

print("\nSaved:")
for _artifact in (
    'cnn_attention_image_final.keras',
    'best_cnn_attention_image.keras',
    'cnn_attention_image_norm.npz',
    'cnn_attention_image_history.npy',
):
    print(f"  ✓ {_artifact}")

## Summary

This model enhances CNN + Attention with image-based techniques:

**Image-Based Attention Improvements:**
1. **CBAM**: Channel + spatial attention in each CNN block
2. **2D Spatial Attention**: Attention across frequency and time
3. **Multi-view spectrograms**: Mel + CQT + Chroma (3 channels)

**Architecture:**
```
Multi-view Spectrograms (15 segments × 128×86×3)
  ↓
Per Segment:
  → 2D Spatial Attention (freq + time)
  → CNN Block 1 (32 filters) + CBAM
  → CNN Block 2 (64 filters) + CBAM
  → CNN Block 3 (128 filters) + CBAM
  → CNN Block 4 (256 filters) + CBAM
  → Global Average Pool → Features
  ↓
Multi-Head Temporal Attention (8 heads)
  ↓
Classification (512 → 256 → 10)
```

**Key Differences from Baseline:**
- **3 channels** instead of 1 (mel + CQT + chroma)
- **CBAM** attention in every CNN block
- **2D spatial attention** before CNN
- **Multi-head attention** over segments (unchanged)

**Expected Performance:**
- Baseline CNN+Attention: 70-75%
- With image attention: **85-92%**

**This is still the CNN + Attention method, just enhanced with image-based techniques!**