# Method 1 Improved: LSTM with Enhanced Features

**Key Improvements:**
- **GroupShuffleSplit by Song ID**: Prevents data leakage by keeping all segments of a song together
- **Enhanced Feature Engineering**: Delta MFCCs, spectral contrast, tonnetz, chroma CQT
- **Better Architecture**: Stacked bidirectional LSTM with global average pooling over the sequence outputs

In [None]:
import numpy as np
import pandas as pd
import librosa
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (LSTM, Bidirectional, Dense, Dropout, 
                                     BatchNormalization, Input, GlobalAveragePooling1D)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Seed both TensorFlow's and NumPy's global random generators with the same value.
tf.random.set_seed(42)
np.random.seed(42)

print(f'TensorFlow: {tf.__version__}')

## 1. Configuration

In [None]:
# Root folder of the dataset; load_dataset() looks for one sub-folder per genre here.
DATA_PATH = '../data/gtzan/genres_original'

# Sampling rate passed to librosa when loading audio.
SAMPLE_RATE = 22050
DURATION = 3  # seconds per segment
SAMPLES_PER_SEGMENT = DURATION * SAMPLE_RATE

# Genre names: used both as sub-folder names and as class labels.
GENRES = [
    'blues',
    'classical',
    'country',
    'disco',
    'hiphop',
    'jazz',
    'metal',
    'pop',
    'reggae',
    'rock',
]
NUM_CLASSES = len(GENRES)

## 2. Enhanced Feature Extraction
We extract more features than the basic CSV, including delta MFCCs and tonal features.

In [None]:
def extract_enhanced_features(audio, sr):
    """
    Summarise one audio segment as a fixed-length feature vector.

    Every frame-level feature row is reduced to its mean and standard
    deviation over time. The values are emitted in this order:
    - MFCCs (20): per coefficient, the MFCC, delta and delta-delta statistics
    - Chroma (CQT-based, 12 bins)
    - Spectral Contrast (7 bands)
    - Tonnetz (6 dimensions, computed on the harmonic component)
    - Spectral features (centroid, bandwidth, rolloff, flatness)
    - Zero crossing rate
    - RMS energy
    - Tempo estimate from beat tracking (single value, no std)

    Args:
        audio: 1-D array of audio samples.
        sr: Sampling rate of ``audio``.

    Returns:
        1-D NumPy array; 183 values with the librosa defaults used here
        (120 MFCC-based + 24 chroma + 14 contrast + 12 tonnetz
        + 8 spectral + 2 ZCR + 2 RMS + 1 tempo).
    """
    features = []

    def add_row_stats(rows):
        # Append (mean, std) over time for each 1-D row, preserving row order.
        for row in rows:
            features.extend([row.mean(), row.std()])

    # MFCCs (20 coefficients) with first- and second-order deltas.
    mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=20)
    mfcc_delta = librosa.feature.delta(mfcc)
    mfcc_delta2 = librosa.feature.delta(mfcc, order=2)

    # Keep the per-coefficient interleaving: mfcc, delta, delta-delta.
    for coeff, delta, delta2 in zip(mfcc, mfcc_delta, mfcc_delta2):
        add_row_stats((coeff, delta, delta2))

    # Chroma CQT (12 bins)
    add_row_stats(librosa.feature.chroma_cqt(y=audio, sr=sr))

    # Spectral Contrast (7 bands)
    add_row_stats(librosa.feature.spectral_contrast(y=audio, sr=sr))

    # Tonnetz (6 dimensions), computed on the harmonic part of the signal
    add_row_stats(librosa.feature.tonnetz(y=librosa.effects.harmonic(audio), sr=sr))

    # Single-row descriptors: spectral shape, zero crossing rate, RMS energy.
    single_row_features = (
        librosa.feature.spectral_centroid(y=audio, sr=sr),
        librosa.feature.spectral_bandwidth(y=audio, sr=sr),
        librosa.feature.spectral_rolloff(y=audio, sr=sr),
        librosa.feature.spectral_flatness(y=audio),
        librosa.feature.zero_crossing_rate(y=audio),
        librosa.feature.rms(y=audio),
    )
    for descriptor in single_row_features:
        features.extend([descriptor.mean(), descriptor.std()])

    # Tempo. Depending on the librosa version, beat_track may return the
    # tempo as a 1-element array instead of a scalar; pick the element
    # explicitly so the conversion does not rely on NumPy's deprecated
    # array-to-scalar coercion.
    tempo, _ = librosa.beat.beat_track(y=audio, sr=sr)
    features.append(float(np.ravel(tempo)[0]))

    return np.array(features)

# Smoke-test the extractor on one segment's worth of Gaussian noise
test_audio = np.random.randn(SAMPLES_PER_SEGMENT)
test_features = extract_enhanced_features(audio=test_audio, sr=SAMPLE_RATE)
print(f"Feature vector size: {test_features.shape[0]}")

## 3. Load Dataset with Enhanced Features

In [None]:
def load_dataset(data_path, segment_duration=3):
    """
    Load GTZAN dataset and extract enhanced features for each segment.

    For every genre in ``GENRES`` the ``.wav`` files under
    ``data_path/<genre>`` are loaded (at most the first 30 seconds, at
    ``SAMPLE_RATE``) and cut into non-overlapping segments of
    ``segment_duration`` seconds. A trailing remainder shorter than one
    segment is dropped. Files that fail to load or process are reported
    and skipped.

    Args:
        data_path: Directory containing one sub-directory per genre.
        segment_duration: Segment length in seconds (must be positive).

    Returns:
        Tuple ``(X, y, song_ids)`` of NumPy arrays with one entry per
        segment: feature vector, genre name, and the ID of the source
        song (for group-wise splitting).

    Raises:
        ValueError: If ``segment_duration`` yields no samples per segment.
    """
    X, y, song_ids = [], [], []

    # Cast to int so fractional durations (e.g. 1.5 s) still give valid slice bounds.
    samples_per_segment = int(SAMPLE_RATE * segment_duration)
    if samples_per_segment <= 0:
        raise ValueError(
            f"segment_duration must be positive, got {segment_duration!r}"
        )

    for genre in GENRES:
        genre_path = os.path.join(data_path, genre)
        # isdir rather than exists: a stray file named like a genre must be
        # skipped too, otherwise os.listdir() below would raise.
        if not os.path.isdir(genre_path):
            print(f"Warning: {genre_path} not found")
            continue

        files = sorted(f for f in os.listdir(genre_path) if f.endswith('.wav'))

        for filename in tqdm(files, desc=f"{genre}"):
            if 'jazz.00054' in filename:  # Skip corrupted file
                continue

            filepath = os.path.join(genre_path, filename)

            # GTZAN files are named '<genre>.<track>.wav' -> e.g. 'blues.00000'.
            # For a two-part name such as '00000.wav' use the stem instead;
            # indexing [1] there would give every file the ID '<genre>.wav'
            # and merge the whole genre into a single group.
            name_parts = filename.split('.')
            track = name_parts[1] if len(name_parts) > 2 else name_parts[0]
            song_id = f"{genre}.{track}"

            try:
                audio, sr = librosa.load(filepath, sr=SAMPLE_RATE, duration=30)

                # Floor division guarantees every slice below is full length.
                num_segments = len(audio) // samples_per_segment

                for seg_idx in range(num_segments):
                    start = seg_idx * samples_per_segment
                    segment = audio[start:start + samples_per_segment]

                    features = extract_enhanced_features(segment, sr)
                    X.append(features)
                    y.append(genre)
                    song_ids.append(song_id)

            except Exception as e:
                # Best effort: report the failing file and carry on with the rest.
                print(f"Error: {filename}: {e}")

    return np.array(X), np.array(y), np.array(song_ids)

print("Loading dataset with enhanced features...")
X, y, song_ids = load_dataset(DATA_PATH)

# With no segments loaded X is a 1-D empty array and X.shape[1] below would
# raise a bare IndexError; fail with an actionable message instead.
if X.ndim != 2 or X.shape[0] == 0:
    raise RuntimeError(
        f"No audio segments were loaded from {DATA_PATH!r}; "
        "check that it contains the genre folders with .wav files."
    )

print(f"\nDataset: {X.shape[0]} samples, {X.shape[1]} features")
print(f"Unique songs: {len(np.unique(song_ids))}")

## 4. GroupShuffleSplit by Song ID
This ensures ALL segments from a song stay in the same split, preventing data leakage.

In [None]:
# Map genre names to integer ids, then to one-hot rows
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
y_cat = to_categorical(y_encoded, num_classes=NUM_CLASSES)

# First cut (by song): 70% train, 30% held out
splitter_outer = GroupShuffleSplit(n_splits=1, test_size=0.30, random_state=42)
train_idx, temp_idx = next(splitter_outer.split(X, y, groups=song_ids))

X_train, y_train = X[train_idx], y_cat[train_idx]
X_temp, y_temp = X[temp_idx], y_cat[temp_idx]
songs_temp = song_ids[temp_idx]

# Second cut (by song): halve the held-out part into 15% val / 15% test
splitter_inner = GroupShuffleSplit(n_splits=1, test_size=0.50, random_state=42)
val_idx, test_idx = next(splitter_inner.split(X_temp, y_temp, groups=songs_temp))

X_val, y_val = X_temp[val_idx], y_temp[val_idx]
X_test, y_test = X_temp[test_idx], y_temp[test_idx]
songs_test = songs_temp[test_idx]

print(f"Train: {X_train.shape}")
print(f"Val: {X_val.shape}")
print(f"Test: {X_test.shape}")

# Sanity check: no song ID may appear in more than one split
train_songs = set(song_ids[train_idx])
val_songs = set(songs_temp[val_idx])
test_songs = set(songs_test)
print("\nSong overlap check:")
print(f"  Train-Val overlap: {len(train_songs.intersection(val_songs))}")
print(f"  Train-Test overlap: {len(train_songs.intersection(test_songs))}")
print(f"  Val-Test overlap: {len(val_songs.intersection(test_songs))}")

## 5. Standardization

In [None]:
# Fit the scaling statistics on the training split only, then reuse them
# unchanged for validation and test.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled, X_test_scaled = (
    scaler.transform(split) for split in (X_val, X_test)
)

# LSTM input layout: (samples, timesteps, features).
# The feature vector is treated as the sequence, one scalar per timestep,
# so a trailing axis of length 1 is appended.
n_features = X_train_scaled.shape[1]
X_train_lstm = X_train_scaled[:, :, np.newaxis]
X_val_lstm = X_val_scaled[:, :, np.newaxis]
X_test_lstm = X_test_scaled[:, :, np.newaxis]

print(f"LSTM input shape: {X_train_lstm.shape}")

## 6. Build Improved LSTM Model

In [None]:
def build_lstm_model(input_shape, num_classes):
    """
    Build and compile the stacked bidirectional LSTM classifier.

    Two bidirectional LSTM layers (128 and 64 units per direction), each
    followed by dropout, return their full output sequence. The sequence
    is averaged over the time axis and fed to a dense head that ends in
    a softmax layer.

    Args:
        input_shape: Shape of one sample, ``(timesteps, features)``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A ``Sequential`` model compiled with Adam (learning rate 0.001),
        categorical cross-entropy loss and the accuracy metric.
    """
    model = Sequential([
        # Declare the input with an explicit Input layer; passing
        # `input_shape=` to the first layer is deprecated in recent Keras.
        Input(shape=input_shape),
        Bidirectional(LSTM(128, return_sequences=True)),
        Dropout(0.3),
        Bidirectional(LSTM(64, return_sequences=True)),
        Dropout(0.3),
        # Collapse the time axis by averaging the per-timestep outputs.
        GlobalAveragePooling1D(),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Dense(64, activation='relu'),
        Dropout(0.3),
        Dense(num_classes, activation='softmax')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',  # labels are one-hot encoded
        metrics=['accuracy']
    )
    return model

# One scalar per timestep, n_features timesteps per sample
model = build_lstm_model(input_shape=(n_features, 1), num_classes=NUM_CLASSES)
model.summary()

## 7. Training

In [None]:
# Stop once validation loss has not improved for 15 epochs and roll back to
# the best weights seen.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
)
# Halve the learning rate after 5 epochs without val_loss improvement,
# never going below 1e-6.
lr_schedule = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)
callbacks = [early_stopping, lr_schedule]

history = model.fit(
    x=X_train_lstm,
    y=y_train,
    batch_size=64,
    epochs=100,
    verbose=1,
    callbacks=callbacks,
    validation_data=(X_val_lstm, y_val),
)

## 8. Evaluation

In [None]:
# Plot training history: accuracy on the left, loss on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = (
    ('accuracy', 'val_accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Loss'),
)
for ax, (train_key, val_key, title) in zip(axes, panels):
    ax.plot(history.history[train_key], label='Train')
    ax.plot(history.history[val_key], label='Val')
    ax.set_title(title)
    ax.legend()
plt.tight_layout()
plt.show()

# Segment-level accuracy
test_loss, test_acc = model.evaluate(X_test_lstm, y_test, verbose=0)
print(f"\nSegment Test Accuracy: {test_acc*100:.2f}%")

# Song-level accuracy (majority voting)
# Convert softmax outputs and one-hot targets back to integer class ids.
y_pred_prob = model.predict(X_test_lstm)
y_pred_classes = np.argmax(y_pred_prob, axis=1)
y_true_classes = np.argmax(y_test, axis=1)

results_df = pd.DataFrame({
    'song_id': songs_test,
    'true_label': y_true_classes,
    'pred_label': y_pred_classes
})

# Per song, take the most frequent label in each column; ties resolve to the
# smallest class id because Series.mode() returns its values sorted.
song_results = results_df.groupby('song_id').agg(lambda x: x.mode()[0])
song_acc = accuracy_score(song_results['true_label'], song_results['pred_label'])
print(f"Song-Level Accuracy: {song_acc*100:.2f}%")

print("\nClassification Report:")
# Pass the full label set explicitly: the grouped split is not stratified, so
# a genre may be absent from the test segments and predictions, and
# target_names without labels then fails on a length mismatch.
report_labels = np.arange(len(label_encoder.classes_))
print(classification_report(
    y_true_classes,
    y_pred_classes,
    labels=report_labels,
    target_names=label_encoder.classes_,
))

## 9. Save Model

In [None]:
MODEL_PATH = '../models/lstm_enhanced.keras'

# Make sure the target directory exists before saving, so the cell also
# works when ../models has not been created yet.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
model.save(MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")