# Music Genre Classification - Model Training

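This notebook demonstrates how to train a Convolutional Neural Network (CNN) for music genre classification on the GTZAN dataset. By default, the data-loading step substitutes randomly generated dummy spectrograms; uncomment the `load_data` call in section 3 to train on the real audio files.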

## 1. Import Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import librosa
import librosa.display
import os
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import warnings
warnings.filterwarnings('ignore')

## 2. Configuration

In [None]:
# Root of the GTZAN dataset; load_data expects one sub-folder per genre here.
DATA_PATH = '../data/raw/gtzan'

# Audio loading: librosa resamples every file to this rate.
SAMPLE_RATE = 22050
DURATION = 30  # seconds
SAMPLES_PER_TRACK = SAMPLE_RATE * DURATION  # 661,500 samples for a 30 s clip

# Mel spectrogram settings: number of mel bands, FFT window size, hop size
# (the latter two in samples).
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512

# GTZAN genre labels; a genre's position in this list is its class index.
GENRES = (
    "blues classical country disco hiphop "
    "jazz metal pop reggae rock"
).split()

## 3. Load and Process Data

In [None]:
def load_data(data_path):
    """
    Load GTZAN audio files and convert each one to a log-scaled mel spectrogram.

    Every track is trimmed or zero-padded to exactly SAMPLES_PER_TRACK samples,
    so all spectrograms share one shape and can be stacked into a single array.
    Files are visited in sorted order, which makes the output order (and hence
    any seeded train/test split) reproducible across machines. Files that
    cannot be decoded are reported and skipped.

    Args:
        data_path (str): Root directory containing one sub-folder per genre,
            named as in GENRES, each holding ``.wav`` files.

    Returns:
        tuple: (spectrograms, labels). ``spectrograms`` stacks one
        (N_MELS, n_frames) array per file, in dB relative to that file's
        peak power. ``labels`` holds each file's genre index into GENRES.
    """
    spectrograms = []
    labels = []

    # Process each genre folder
    for genre_idx, genre in enumerate(GENRES):
        genre_path = os.path.join(data_path, genre)

        # Skip if folder doesn't exist
        if not os.path.exists(genre_path):
            print(f"Warning: {genre_path} does not exist. Skipping...")
            continue

        print(f"Processing {genre} files...")

        # os.listdir returns entries in arbitrary order; sort for a
        # deterministic dataset order.
        for filename in sorted(os.listdir(genre_path)):
            if not filename.endswith('.wav'):
                continue

            file_path = os.path.join(genre_path, filename)
            try:
                y, sr = librosa.load(
                    file_path, sr=SAMPLE_RATE, mono=True, duration=DURATION
                )
            except Exception as err:
                # Best-effort loading: one unreadable/corrupt file should not
                # abort processing of the whole dataset.
                print(f"Warning: could not load {file_path} ({err}). Skipping...")
                continue

            # Files are not guaranteed to be exactly DURATION seconds long.
            # Force a fixed length so every spectrogram has the same number of
            # frames; otherwise np.array below fails on ragged input (or builds
            # an object array on older NumPy).
            y = y[:SAMPLES_PER_TRACK]
            if len(y) < SAMPLES_PER_TRACK:
                y = np.pad(y, (0, SAMPLES_PER_TRACK - len(y)))

            # Generate Mel spectrogram
            mel_spectrogram = librosa.feature.melspectrogram(
                y=y,
                sr=sr,
                n_fft=N_FFT,
                hop_length=HOP_LENGTH,
                n_mels=N_MELS
            )

            # Convert power to dB, relative to this clip's maximum
            mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

            spectrograms.append(mel_spectrogram_db)
            labels.append(genre_idx)

    return np.array(spectrograms), np.array(labels)

In [None]:
# Load data
# Note: processing every audio file with load_data takes a while, so the call
# is left commented out here.
# spectrograms, labels = load_data(DATA_PATH)

# For demonstration purposes, generate random dummy data instead.
# To train on real audio, uncomment the load_data call above and remove this part.
SAMPLES_PER_GENRE = 10
# Frame count of a centred STFT over one full-length track:
# 1 + SAMPLES_PER_TRACK // HOP_LENGTH (= 1292 with the configuration above).
N_FRAMES = 1 + SAMPLES_PER_TRACK // HOP_LENGTH

rng = np.random.default_rng(42)  # seeded so the demo run is reproducible
num_samples = SAMPLES_PER_GENRE * len(GENRES)
dummy_spectrograms = rng.random((num_samples, N_MELS, N_FRAMES))
dummy_labels = np.repeat(np.arange(len(GENRES)), SAMPLES_PER_GENRE)

spectrograms = dummy_spectrograms
labels = dummy_labels

print(f"Loaded {len(spectrograms)} spectrograms with shape {spectrograms[0].shape}")

## 4. Prepare Data for Training

In [None]:
# Add a trailing channel axis: Conv2D expects (height, width, channels) inputs.
X = spectrograms[..., np.newaxis]

# One-hot encode the integer genre indices for categorical_crossentropy.
y = to_categorical(labels, num_classes=len(GENRES))

# Hold out 20% for testing. Stratifying on the genre labels keeps each genre's
# share the same in both splits; with small datasets (e.g. 10 clips per genre)
# an unstratified split can leave some genres barely present in the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=labels
)

print(f"Training set: {X_train.shape}")
print(f"Testing set: {X_test.shape}")

## 5. Build CNN Model

In [None]:
def build_model(input_shape, num_classes=None):
    """
    Build and compile a CNN for music genre classification.

    Three Conv2D -> MaxPooling2D -> BatchNormalization blocks (32, 64 and 128
    filters, 3x3 kernels, 2x2 pooling) feed a 256-unit dense layer with 50%
    dropout and a softmax output layer.

    Args:
        input_shape (tuple): Shape of one input sample without the batch axis,
            e.g. (N_MELS, time_frames, 1).
        num_classes (int, optional): Number of output classes. Defaults to
            len(GENRES).

    Returns:
        tf.keras.Model: Model compiled with the Adam optimizer, categorical
        cross-entropy loss and an accuracy metric.
    """
    if num_classes is None:
        num_classes = len(GENRES)

    model = models.Sequential([
        # Declare the input explicitly; passing `input_shape` to the first
        # layer is deprecated in Keras 3 and emits a warning.
        layers.Input(shape=input_shape),

        # First convolutional block
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.BatchNormalization(),

        # Second convolutional block
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.BatchNormalization(),

        # Third convolutional block
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.BatchNormalization(),

        # Flatten and dense layers.
        # NOTE(review): for a (128, 1292, 1) input the three 2x2 poolings leave
        # a 16 x 161 x 128 feature map, so Flatten -> Dense(256) alone holds
        # ~84M weights. Consider GlobalAveragePooling2D if model size or
        # over-fitting becomes a problem.
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

In [None]:
# Size the network's input from the training data: every dimension except the
# leading sample axis, i.e. (mel bands, time frames, channel).
input_shape = X_train.shape[1:]
model = build_model(input_shape)

model.summary()  # print the layer-by-layer architecture and parameter counts

## 6. Train Model

In [None]:
# Both callbacks watch validation loss: training stops after 5 epochs without
# improvement (rolling back to the best weights), and the learning rate is
# halved after 2 stagnant epochs.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=5, restore_best_weights=True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=2
    ),
]

# Keras holds out the last 20% of the training arrays for validation.
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=32,
    epochs=30,
    validation_split=0.2,
    callbacks=callbacks,
)

## 7. Evaluate Model

In [None]:
# Plot training vs. validation curves side by side: accuracy in the left
# panel, loss in the right.
plt.figure(figsize=(12, 4))

panels = [
    ('accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
]
for position, (metric, title, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.history[metric], label='Training')
    plt.plot(history.history[f'val_{metric}'], label='Validation')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Score the held-out test split. evaluate returns the loss followed by the
# compiled 'accuracy' metric.
test_loss, test_accuracy = model.evaluate(x=X_test, y=y_test)
print(f"Test accuracy: {test_accuracy:.4f}")

## 8. Save Model

In [None]:
# Persist the trained network under ../model so it can be reloaded later.
model_dir = '../model'
model_path = os.path.join(model_dir, 'genre_classifier_model.h5')

os.makedirs(model_dir, exist_ok=True)  # no error if the folder already exists
# NOTE(review): Keras 3 treats the HDF5 (.h5) format as legacy and recommends
# `.keras`; the path is kept as-is in case other code loads this exact file.
model.save(model_path)
print(f"Model saved to {model_path}")

## 9. Test Prediction

In [None]:
def predict_genre(model, spectrogram, genres=None):
    """
    Predict the genre of a single mel spectrogram.

    Args:
        model (tf.keras.Model): Trained model. Any object whose
            ``predict(batch, verbose=0)`` returns an array of shape
            (batch, n_classes) works.
        spectrogram (numpy.ndarray): One mel spectrogram, either 2-D
            (n_mels, frames) or already carrying a trailing channel axis
            (n_mels, frames, 1).
        genres (list of str, optional): Class names in the model's output
            order. Defaults to GENRES.

    Returns:
        tuple: (predicted_genre, confidence_scores), where confidence_scores
        maps every genre name to its predicted probability as a Python float.
    """
    if genres is None:
        genres = GENRES

    spectrogram = np.asarray(spectrogram)
    # Add the channel axis only if it is missing (reshaping unconditionally
    # would turn an (n_mels, frames, 1) input into an invalid 5-D batch),
    # then add the batch axis.
    if spectrogram.ndim == 2:
        spectrogram = spectrogram[..., np.newaxis]
    batch = spectrogram[np.newaxis, ...]

    # verbose=0 suppresses Keras' progress bar for this single-sample call.
    prediction = model.predict(batch, verbose=0)[0]

    predicted_genre = genres[int(np.argmax(prediction))]
    confidence_scores = {
        genre: float(score) for genre, score in zip(genres, prediction)
    }

    return predicted_genre, confidence_scores

In [None]:
# Pick one random held-out example and compare the model's guess with its label.
sample_index = np.random.randint(0, len(X_test))
sample_spectrogram = X_test[sample_index, ..., 0]  # drop the channel axis -> 2-D
true_genre = GENRES[y_test[sample_index].argmax()]

predicted_genre, confidence_scores = predict_genre(model, sample_spectrogram)

print(f"True genre: {true_genre}")
print(f"Predicted genre: {predicted_genre}")
print("\nConfidence scores:")
# List genres from most to least likely.
ranked_scores = sorted(
    confidence_scores.items(), key=lambda item: item[1], reverse=True
)
for genre, score in ranked_scores:
    print(f"{genre}: {score:.4f}")

# Show the spectrogram that was classified, titled with both labels.
plt.figure(figsize=(10, 4))
librosa.display.specshow(
    sample_spectrogram,
    sr=SAMPLE_RATE,
    hop_length=HOP_LENGTH,
    x_axis='time',
    y_axis='mel',
)
plt.colorbar(format='%+2.0f dB')
plt.title(f"Mel Spectrogram - True: {true_genre}, Predicted: {predicted_genre}")
plt.tight_layout()
plt.show()