# Method 2 Improved: CNN with Enhanced Preprocessing

**Key Improvements:**
- **GroupShuffleSplit by Song**: Keeps every sample derived from the same song in a single split, preventing data leakage
- **SpecAugment**: Time and frequency masking for data augmentation
- **Multi-channel features**: Mel spectrogram stacked with its delta and delta-delta as three input channels

In [None]:
import numpy as np
import librosa
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, BatchNormalization, MaxPooling2D,
                                     GlobalAveragePooling2D, Dense, Dropout, LeakyReLU)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')

# Seed NumPy and TensorFlow with the same value so random draws are repeatable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

print(f'TensorFlow: {tf.__version__}')

## 1. Configuration

In [None]:
# Root folder of the audio files; load_dataset expects one sub-directory per genre.
DATA_PATH = '../data/gtzan/genres_original'

# Audio loading: clips are resampled to SAMPLE_RATE and limited to DURATION seconds.
SAMPLE_RATE = 22050
DURATION = 30

# Mel spectrogram parameters.
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512

# Number of spectrogram frames kept per clip (~DURATION seconds of audio).
# Derived from the settings above so it cannot drift out of sync with them;
# evaluates to 1291 for 22050 Hz, 30 s and a hop of 512 samples.
TARGET_LENGTH = (SAMPLE_RATE * DURATION) // HOP_LENGTH

# Class names; each one is also the name of a sub-directory under DATA_PATH.
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop',
          'jazz', 'metal', 'pop', 'reggae', 'rock']
NUM_CLASSES = len(GENRES)

## 2. Enhanced Feature Extraction with Multi-Resolution

In [None]:
def extract_melspec(audio, sr, n_mels=128, target_length=TARGET_LENGTH):
    """Compute a log-mel spectrogram with a fixed number of time frames.

    Args:
        audio: Waveform samples, passed to librosa as ``y``.
        sr: Sampling rate of ``audio`` in Hz.
        n_mels: Number of mel bands.
        target_length: Number of time frames in the returned array.

    Returns:
        Array of shape ``(target_length, n_mels)`` holding power in dB
        relative to the clip's own peak (so the maximum value is 0 dB).
    """
    mel = librosa.feature.melspectrogram(
        y=audio, sr=sr, n_mels=n_mels, n_fft=N_FFT, hop_length=HOP_LENGTH
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)

    # Transpose to (time, frequency)
    mel_db = mel_db.T

    n_frames = mel_db.shape[0]
    if n_frames < target_length:
        # With ref=np.max, 0 dB is the LOUDEST level in the clip, so padding
        # with zeros would append maximally loud frames. Pad with the quietest
        # level observed instead, which is the closest stand-in for silence.
        floor_db = mel_db.min() if mel_db.size else 0.0
        mel_db = np.pad(
            mel_db,
            ((0, target_length - n_frames), (0, 0)),
            mode='constant',
            constant_values=floor_db,
        )
    else:
        mel_db = mel_db[:target_length, :]

    return mel_db

def extract_multi_resolution(audio, sr):
    """
    Build a three-channel feature tensor from one audio clip.

    Channels, in order:
    - log-mel spectrogram (128 mel bands)
    - its first-order delta along time
    - its second-order delta along time

    Returns: array of shape (time, freq, 3)
    """
    log_mel = extract_melspec(audio, sr, n_mels=128)

    # librosa differentiates along the last axis, so hand it the
    # (freq, time) view and flip each result back to (time, freq).
    freq_major = log_mel.T
    channels = [log_mel]
    for order in (1, 2):
        channels.append(librosa.feature.delta(freq_major, order=order).T)

    return np.stack(channels, axis=-1)

# Smoke test on 30 seconds of random noise
test_audio = np.random.randn(SAMPLE_RATE * 30)
test_features = extract_multi_resolution(test_audio, SAMPLE_RATE)
print(f"Feature shape: {test_features.shape}")

## 3. SpecAugment Data Augmentation

In [None]:
def spec_augment(mel, time_mask_param=80, freq_mask_param=20, num_masks=2,
                 mask_value=0.0):
    """
    Apply SpecAugment-style time and frequency masking to a copy of ``mel``.

    Args:
        mel: Array whose first two axes are (time, frequency); any further
            axes (e.g. channels) are masked together.
        time_mask_param: Exclusive upper bound on the width of each time
            mask, in frames. A value <= 0 disables time masking.
        freq_mask_param: Exclusive upper bound on the width of each frequency
            mask, in bins. A value <= 0 disables frequency masking.
        num_masks: Number of masks drawn per axis.
        mask_value: Value written into masked regions. The default of 0
            matches the previous behaviour; pass a different fill when 0 is
            not a neutral value for the input's scale.

    Returns:
        A new array with the same shape as ``mel``; the input is not modified.
    """
    augmented = mel.copy()
    time_steps, freq_bins = augmented.shape[:2]

    # Time masking. The guard avoids np.random.randint(0, 0), which raises.
    if time_mask_param > 0:
        for _ in range(num_masks):
            t = np.random.randint(0, time_mask_param)
            # max(1, ...) keeps the range valid when the mask is wider than the input.
            t0 = np.random.randint(0, max(1, time_steps - t))
            augmented[t0:t0 + t, :] = mask_value

    # Frequency masking, same scheme along the second axis.
    if freq_mask_param > 0:
        for _ in range(num_masks):
            f = np.random.randint(0, freq_mask_param)
            f0 = np.random.randint(0, max(1, freq_bins - f))
            augmented[:, f0:f0 + f] = mask_value

    return augmented

## 4. Load Dataset

In [None]:
def load_dataset(data_path, augment=False):
    """Load every genre folder under ``data_path`` into feature arrays.

    Each ``.wav`` file is loaded at SAMPLE_RATE, limited to DURATION seconds,
    zero-padded to exactly that length and converted with
    ``extract_multi_resolution``. Files that fail to load or process are
    reported and skipped rather than aborting the whole run.

    Args:
        data_path: Directory containing one sub-directory per entry of GENRES.
        augment: When True, a ``spec_augment`` copy is appended directly
            after each original sample, with the same label and song id.

    Returns:
        Tuple ``(X, y, song_ids)`` of NumPy arrays: features, genre names and
        ``"<genre>.<number>"`` identifiers, aligned by index.
    """
    X, y, song_ids = [], [], []
    # Loop-invariant: number of samples in a full-length clip.
    target_len = SAMPLE_RATE * DURATION

    for genre in GENRES:
        genre_path = os.path.join(data_path, genre)
        if not os.path.isdir(genre_path):
            # Report the gap instead of silently producing a dataset with a
            # missing class.
            print(f"Warning: genre directory not found, skipping: {genre_path}")
            continue

        files = sorted([f for f in os.listdir(genre_path) if f.endswith('.wav')])

        for filename in tqdm(files, desc=f"{genre}"):
            # This particular file is deliberately excluded.
            if 'jazz.00054' in filename:
                continue

            filepath = os.path.join(genre_path, filename)
            # The second dot-separated field of the file name identifies the song.
            song_id = f"{genre}.{filename.split('.')[1]}"

            try:
                audio, sr = librosa.load(filepath, sr=SAMPLE_RATE, duration=DURATION)
                if len(audio) < target_len:
                    audio = np.pad(audio, (0, target_len - len(audio)))

                # Extract features
                features = extract_multi_resolution(audio, sr)
                X.append(features)
                y.append(genre)
                song_ids.append(song_id)

                # Augmentation: the copy shares the song id so that group-based
                # splitting keeps it together with its original.
                if augment:
                    aug_features = spec_augment(features)
                    X.append(aug_features)
                    y.append(genre)
                    song_ids.append(song_id)

            except Exception as e:
                # Best effort: one unreadable file should not stop the load.
                print(f"Error: {filename}: {e}")

    return np.array(X), np.array(y), np.array(song_ids)

print("Loading dataset...")
# NOTE: with augment=True every song contributes an original and a masked
# copy, so any downstream split of X has to account for the augmented samples
# (e.g. keep them out of evaluation sets).
X, y, song_ids = load_dataset(DATA_PATH, augment=True)
print(f"\nDataset: {X.shape}")
print(f"Unique songs: {len(np.unique(song_ids))}")

## 5. GroupShuffleSplit by Song ID

In [None]:
# Encode genre names as integers, then as one-hot rows.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
y_cat = to_categorical(y_encoded, NUM_CLASSES)

# Flag the un-augmented samples. load_dataset appends each original right
# before its augmented copy, so the first occurrence of every song id is the
# original clip. Without augmentation every sample is flagged.
_, first_occurrence = np.unique(song_ids, return_index=True)
is_original = np.zeros(len(song_ids), dtype=bool)
is_original[first_occurrence] = True

# Split by song so that all samples of a song land on the same side.
splitter = GroupShuffleSplit(test_size=0.20, n_splits=1, random_state=42)
train_idx, test_idx = next(splitter.split(X, y, song_ids))

# Augmentation is a training-time device: score the model on clean clips only,
# otherwise the test metrics are computed on masked, double-counted samples.
test_idx = test_idx[is_original[test_idx]]

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y_cat[train_idx], y_cat[test_idx]

# Further split train into train/val (0.125 of the 80% train part = 10% overall).
songs_train = song_ids[train_idx]
original_train = is_original[train_idx]
splitter_val = GroupShuffleSplit(test_size=0.125, n_splits=1, random_state=42)
train_idx2, val_idx = next(splitter_val.split(X_train, y_train, songs_train))

# Validation drives early stopping and LR scheduling, so keep it clean as well.
val_idx = val_idx[original_train[val_idx]]

X_val = X_train[val_idx]
y_val = y_train[val_idx]
X_train = X_train[train_idx2]
y_train = y_train[train_idx2]

print(f"Train: {X_train.shape}")
print(f"Val: {X_val.shape}")
print(f"Test: {X_test.shape}")

## 6. Normalization

In [None]:
# Standardize each feature channel with statistics taken from the training
# split only; validation and test reuse the same mean and std.
def _standardize(batch, center, scale):
    """Return ``batch`` shifted by ``center`` and divided by ``scale``."""
    return (batch - center) / scale

_reduce_axes = (0, 1, 2)  # every axis except the trailing channel axis
mean = X_train.mean(axis=_reduce_axes, keepdims=True)
std = X_train.std(axis=_reduce_axes, keepdims=True) + 1e-8  # epsilon avoids division by zero

X_train = _standardize(X_train, mean, std)
X_val = _standardize(X_val, mean, std)
X_test = _standardize(X_test, mean, std)

print(f"Normalized shape: {X_train.shape}")

## 7. Build CNN Model

In [None]:
def build_cnn_model(input_shape, num_classes):
    """Assemble and compile the four-block CNN classifier.

    Every block is Conv2D(3x3, 'same') -> LeakyReLU(0.1) -> BatchNormalization.
    The first three blocks end with 2x2 max pooling, the last one with global
    average pooling. A Dense(256) + Dropout(0.5) head feeds the softmax output.

    Args:
        input_shape: Shape of one sample, without the batch axis.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A compiled Sequential model (Adam, lr=0.001, categorical
        cross-entropy, accuracy metric).
    """
    conv_filters = (64, 128, 256, 512)
    last_block = len(conv_filters) - 1

    layers = []
    for block, filters in enumerate(conv_filters):
        if block == 0:
            # Only the first layer carries the input shape.
            layers.append(Conv2D(filters, (3, 3), padding='same', input_shape=input_shape))
        else:
            layers.append(Conv2D(filters, (3, 3), padding='same'))
        layers.append(LeakyReLU(0.1))
        layers.append(BatchNormalization())
        if block == last_block:
            layers.append(GlobalAveragePooling2D())
        else:
            layers.append(MaxPooling2D((2, 2)))

    # Classifier head
    layers.append(Dense(256))
    layers.append(LeakyReLU(0.1))
    layers.append(Dropout(0.5))
    layers.append(Dense(num_classes, activation='softmax'))

    network = Sequential(layers)
    network.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return network

# The per-sample shape is whatever follows the batch axis of X_train.
model = build_cnn_model(input_shape=X_train.shape[1:], num_classes=NUM_CLASSES)
model.summary()

## 8. Training

In [None]:
# Stop once val_loss has not improved for 15 epochs and roll back to the best
# weights; halve the learning rate after 5 stagnant epochs, down to 1e-6.
early_stopping = EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
callbacks = [early_stopping, lr_on_plateau]

history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    verbose=1,
    callbacks=callbacks,
    validation_data=(X_val, y_val),
)

## 9. Evaluation

In [None]:
# Plot training and validation curves side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(history.history['accuracy'], label='Train')
axes[0].plot(history.history['val_accuracy'], label='Val')
axes[0].set_title('Accuracy')
axes[0].legend()

axes[1].plot(history.history['loss'], label='Train')
axes[1].plot(history.history['val_loss'], label='Val')
axes[1].set_title('Loss')
axes[1].legend()
plt.show()

# Evaluate on the held-out test split.
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print(f"\nTest Accuracy: {test_acc*100:.2f}%")

# Turn softmax outputs and one-hot targets back into class indices.
y_pred = model.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test, axis=1)

print("\nClassification Report:")
# Pass the full label set explicitly: without it the report derives the labels
# from the data and fails against target_names whenever a class happens to be
# absent from both the test targets and the predictions.
class_ids = np.arange(len(label_encoder.classes_))
print(classification_report(y_true_classes, y_pred_classes,
                            labels=class_ids,
                            target_names=label_encoder.classes_))

## 10. Save Model

In [None]:
# Create the output folder first so a missing directory cannot make the save
# fail after the whole training run has completed.
model_path = '../models/cnn_enhanced.keras'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model.save(model_path)
print("Model saved!")