# Imports

In [None]:
import os
import numpy as np
import pandas as pd
import librosa
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Define genre EQ profiles

In [None]:
# Predefine equalization profiles for each genre.
# Each profile is an array of 10 values (gains in decibels) corresponding to:
# [32, 64, 125, 250, 500, 1000, 2000, 4000, 8000, 16000 Hz]
genre_profiles = {
    'blues':     np.array([-1,  0,  2,  2,  1,  0, -1,  0,  1,  1]),
    'classical': np.array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
    'country':   np.array([ 0,  1,  1,  2,  1,  0,  0,  0,  0,  0]),
    'disco':     np.array([ 2,  3,  2,  1,  0,  0,  1,  2,  3,  2]),
    'hiphop':    np.array([ 3,  4,  2,  0, -1, -1,  0,  1,  1,  0]),
    'jazz':      np.array([ 0,  1,  1,  2,  2,  1,  0,  0,  1,  0]),
    'metal':     np.array([ 2,  3,  0, -3, -4, -3,  0,  3,  3,  2]),
    'pop':       np.array([ 0,  1,  2,  2,  1,  0,  1,  1,  2,  2]),
    'reggae':    np.array([ 0,  1,  0,  0, -1, -1,  0,  0,  1,  0]),
    'rock':      np.array([ 1,  1,  0,  0,  1,  1,  0,  0,  1,  1])
}

# Ensure this order matches the model's output order.
genres = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']

def weighted_eq_profile(predictions, genre_list=genres, profiles=genre_profiles):
    """
    Blend the EQ profiles of the three most probable genres.

    The three highest probabilities are renormalised to sum to 1 and used as
    weights for the corresponding gain profiles.

    Parameters:
      predictions: class probabilities in the order of `genre_list`. Accepts a
        1-D array/list of shape (n,) as well as a single-row batch of shape
        (1, n); the input is flattened before ranking.
      genre_list: list of genre names in the order corresponding to predictions.
      profiles: dictionary mapping genre names to their equalization profile (numpy array).

    Returns:
      weighted_profile: float numpy array with the same shape as one profile,
        holding the probability-weighted gains in dB.
    """
    # Flatten so a (1, n) batch row is ranked exactly like an (n,) vector.
    probs = np.asarray(predictions).ravel()

    # Indices of the (up to) three largest probabilities, best first.
    top3_idx = probs.argsort()[-3:][::-1]
    top3_probs = probs[top3_idx]

    # Normalize the selected probabilities so they sum to 1; if they carry no
    # mass at all, fall back to equal weights over however many were selected.
    weight_sum = np.sum(top3_probs)
    if weight_sum == 0:
        norm_weights = np.full(top3_probs.shape, 1.0 / max(top3_probs.size, 1))
    else:
        norm_weights = top3_probs / weight_sum

    # Accumulate the weighted combination of the selected profiles.
    weighted_profile = np.zeros_like(profiles[genre_list[0]], dtype=float)
    for weight, idx in zip(norm_weights, top3_idx):
        weighted_profile += weight * profiles[genre_list[idx]]

    return weighted_profile

# Define or get CNN

In [None]:
# Candidate locations of a previously trained model (native Keras archive
# first, legacy HDF5 second).
model_save_path_keras = "./saved_models/CONEqNet.keras"
model_save_path_h5 = "./saved_models/CONEqNet.h5"

if os.path.exists(model_save_path_keras):
    model = tf.keras.models.load_model(model_save_path_keras)
    print(f"Loaded existing model from {model_save_path_keras}")
elif os.path.exists(model_save_path_h5):
    model = tf.keras.models.load_model(model_save_path_h5)
    print(f"Loaded existing model from {model_save_path_h5}")
else:
    print("No saved model found, creating a new one...")

    # Input is treated like a single-channel image:
    # 600 time steps x 13 MFCC coefficients x 1 channel.
    input_shape = (600, 13, 1)

    model = Sequential([
        # Low-level features: first conv stage applied directly to the input.
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),

        # Mid-level features: second conv stage.
        Conv2D(64, kernel_size=(3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),

        # Upper-level features: dense head ending in a 10-way softmax.
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(10, activation='softmax'),
    ])

    optimizer = Adam(learning_rate=0.001)

    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    print("New model successfully compiled.")
    model.summary()

# Training

### Prepare Data

In [None]:
# Dataset locations: audio files live under <audio_dir>/<label>/<filename>,
# and the CSV lists one row per clip with its filename and label.
data_dir = './Data'
audio_dir = os.path.normpath(os.path.join(data_dir, 'genres_original'))
csv_path = os.path.normpath(os.path.join(data_dir, 'features_30_sec.csv'))

df_features = pd.read_csv(csv_path)

# Resolve every row to the full path of its audio file.
df_features['filepath'] = df_features.apply(
    lambda record: os.path.join(audio_dir, record['label'], record['filename']),
    axis=1,
)
print("\nAdded filepath to dataframe features.")

# Class order used for integer label encoding.
genres = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
label_to_index = dict(zip(genres, range(len(genres))))

def extract_mfcc(file_path, n_mfcc=13, hop_length=1103):
    """
    Load at most the first 30 seconds of an audio file and compute its MFCCs.

    `librosa.load` is called without an explicit `sr`, so librosa's default
    sample rate applies.

    Returns an array of shape (n_mfcc, n_frames).
    """
    waveform, sample_rate = librosa.load(file_path, duration=30)
    return librosa.feature.mfcc(
        y=waveform,
        sr=sample_rate,
        n_mfcc=n_mfcc,
        hop_length=hop_length,
    )

# extract MFCCs for each audio file
def prepare_data(df, target_frames=600, n_mfcc=13, hop_length=1103):
    """
    Build the model inputs and integer labels for every row of `df`.

    For each row the MFCCs of `row['filepath']` are extracted, transposed to
    (n_frames, n_mfcc), zero-padded or truncated to exactly `target_frames`
    frames and given a trailing channel axis.

    Rows that fail (unreadable audio, unknown label, ...) are reported and
    skipped as a whole, so the two returned arrays always stay aligned.

    Parameters:
      df: dataframe with 'filepath' and 'label' columns.
      target_frames: number of time frames every sample is padded/cut to.
      n_mfcc: number of MFCC coefficients per frame.
      hop_length: hop length passed to the MFCC extraction.

    Returns:
      (X, y): X stacks samples of shape (target_frames, n_mfcc, 1);
      y holds the matching indices from `label_to_index`.
    """
    X = []
    y = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        try:
            # Resolve the label before touching X: a failing lookup must not
            # leave a feature matrix behind without its label.
            label = label_to_index[row['label']]

            mfcc = extract_mfcc(row['filepath'], n_mfcc=n_mfcc, hop_length=hop_length)
            mfcc = mfcc.T  # -> (n_frames, n_mfcc)

            # Pad with zeros at the end, or truncate, to a fixed frame count.
            n_frames = mfcc.shape[0]
            if n_frames < target_frames:
                mfcc = np.pad(mfcc, ((0, target_frames - n_frames), (0, 0)), mode='constant')
            elif n_frames > target_frames:
                mfcc = mfcc[:target_frames, :]

            mfcc = mfcc[..., np.newaxis]
        except Exception as e:
            # Best-effort: report the bad row and keep going.
            print("Error processing {}: {}".format(row['filepath'], e))
            continue

        X.append(mfcc)
        y.append(label)
    return np.array(X), np.array(y)

### Make TensorFlow use the GPU

In [None]:
# Enable on-demand GPU memory allocation on every visible GPU; report when
# none is available.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU detected. Running on CPU.")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU detected and enabled:", gpus)
    except RuntimeError as e:
        print("Error enabling GPU:", e)


### Train test split and data prep

In [None]:
# Stratified 80/20 split so every genre keeps its share in both subsets.
train_df, val_df = train_test_split(
    df_features,
    test_size=0.2,
    stratify=df_features['label'],
    random_state=42,
)
print("Train set shape:", train_df.shape)
print("Validation set shape:", val_df.shape)

# Turn both dataframes into MFCC tensors plus integer labels.
X_train, y_train = prepare_data(train_df)
X_val, y_val = prepare_data(val_df)

# One-hot encode the integer labels, one column per genre.
num_classes = len(genres)
y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_val_cat = to_categorical(y_val, num_classes=num_classes)

### Train

In [None]:
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train_cat.shape)
print("X_val shape:", X_val.shape)
print("y_val shape:", y_val_cat.shape)

# Training setup:
# - Batch size: 256
# - Epochs: upper bound of 100 000; early stopping ends training once
#   val_loss has not improved for 40 epochs and restores the best weights
# - Learning rate: scaled by 0.97 whenever val_loss plateaus for 3 epochs
# - Loss: categorical crossentropy (set when the model was compiled)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=40,
    restore_best_weights=True,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.97,
    patience=3,
    cooldown=2,
    min_delta=0.0001,
    min_lr=1e-10,
    verbose=1,
)

history = model.fit(
    X_train,
    y_train_cat,
    validation_data=(X_val, y_val_cat),
    epochs=100000,
    batch_size=256,
    callbacks=[early_stopping, reduce_lr],
)

# Report the accuracies recorded for the last epoch that was run.
final_train_acc = history.history['accuracy'][-1]
final_val_acc = history.history['val_accuracy'][-1]
print("\nFinal Training Accuracy: {:.2f}%".format(final_train_acc * 100))
print("Final Validation Accuracy: {:.2f}%".format(final_val_acc * 100))

### Save model

In [None]:
# Output locations for the trained model.
model_save_path_keras = "./saved_models/CONEqNet.keras"
model_save_path_h5 = "./saved_models/CONEqNet.h5"

# Both files share one directory; make sure it exists before writing.
save_dir = os.path.dirname(model_save_path_keras)
os.makedirs(save_dir, exist_ok=True)

# Native Keras archive (.keras).
model.save(model_save_path_keras)
print(f"Model saved successfully to {model_save_path_keras}")

# HDF5 copy (.h5) for compatibility with older Keras versions.
model.save(model_save_path_h5)
print(f"Model saved successfully to {model_save_path_h5}")

# Classify a song

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import librosa

def extract_mfcc_from_segment(segment, sr, n_mfcc=13, hop_length=1103):
    """
    Compute the MFCCs of an in-memory audio segment.

    Returns an array of shape (n_mfcc, n_frames).
    """
    return librosa.feature.mfcc(
        y=segment,
        sr=sr,
        n_mfcc=n_mfcc,
        hop_length=hop_length,
    )

def process_segment(segment, sr, target_frames=600, n_mfcc=13, hop_length=1103):
    """
    Turn an audio segment into a single-sample batch for the classifier.

    The MFCC matrix is transposed to (n_frames, n_mfcc), zero-padded at the
    end or truncated to exactly `target_frames` frames, and wrapped with a
    leading batch axis and a trailing channel axis.

    Returns an array of shape (1, target_frames, n_mfcc, 1).
    """
    frames = extract_mfcc_from_segment(segment, sr, n_mfcc, hop_length).T

    # Positive -> frames are missing (pad); negative -> too many (cut).
    missing = target_frames - frames.shape[0]
    if missing > 0:
        frames = np.pad(frames, ((0, missing), (0, 0)), mode='constant')
    elif missing < 0:
        frames = frames[:target_frames, :]

    # (target_frames, n_mfcc) -> (1, target_frames, n_mfcc, 1)
    return frames[np.newaxis, ..., np.newaxis]

# --- Load the saved model (fail fast when the file is missing) ---
model_path = "./saved_models/CONEqNet.keras"
if not os.path.exists(model_path):
    raise FileNotFoundError("Model file not found. Please check the path.")
model = load_model(model_path)
print(f"Loaded model from {model_path}")

def predict_long_audio(file_path, segment_duration=30, overlap=15,
                       target_frames=600, n_mfcc=13, hop_length=1103, sr=22050):
    """
    Classify a long audio file by averaging per-segment predictions.

    The file is cut into overlapping segments, each segment is converted to
    MFCC features and passed to the global `model`, and the per-segment
    outputs are averaged.

    Parameters:
      file_path: path of the audio file to classify.
      segment_duration: length of each segment in seconds.
      overlap: overlap between consecutive segments in seconds; must be
        smaller than `segment_duration`.
      target_frames, n_mfcc, hop_length: feature settings forwarded to
        `process_segment`.
      sr: sample rate the audio is resampled to on load. The default of
        22050 Hz is librosa's default rate, which is what the training-time
        `librosa.load` call (made without `sr`) uses; with hop_length=1103,
        600 frames then cover a full 30 s segment. Loading at a higher native
        rate would make the same 600 frames cover only part of the segment.

    Returns:
      (overall_index, avg_prediction, predictions): index of the winning
      class, the averaged model output, and the stacked per-segment outputs.

    Raises:
      ValueError: if the segment/overlap settings give a non-positive step,
        or the file contains no samples.
    """
    y, sr = librosa.load(file_path, sr=sr)
    segment_samples = int(segment_duration * sr)
    overlap_samples = int(overlap * sr)
    step_samples = segment_samples - overlap_samples

    if segment_samples <= 0 or overlap_samples < 0 or step_samples <= 0:
        raise ValueError(
            "segment_duration must be positive and overlap must be "
            "non-negative and smaller than segment_duration"
        )
    if len(y) == 0:
        raise ValueError(f"No audio samples could be read from {file_path}")

    # A clip shorter than one segment is still classified once;
    # process_segment zero-pads its features to the expected frame count.
    last_start = max(len(y) - segment_samples, 0)

    predictions = []
    for segment_count, start in enumerate(range(0, last_start + 1, step_samples), start=1):
        segment = y[start:start + segment_samples]
        x_input = process_segment(segment, sr, target_frames, n_mfcc, hop_length)
        predictions.append(model.predict(x_input))
        print(f"Processed segment {segment_count}")

    # Stack to (num_segments, 1, num_classes) and average over segments.
    predictions = np.array(predictions)
    avg_prediction = np.mean(predictions, axis=0)  # shape: (1, num_classes)
    overall_index = np.argmax(avg_prediction)

    return overall_index, avg_prediction, predictions

# Class names in the order of the model's outputs.
genres = ['blues', 'classical', 'country', 'disco', 'hiphop',
          'jazz', 'metal', 'pop', 'reggae', 'rock']

# Audio file to classify (mp3, wav, etc.) - replace with your own path.
audio_file = "./music/rammstein-deuschland.mp3"

# Classify the file using 30-second segments that overlap by 15 seconds.
overall_index, avg_prediction, segment_predictions = predict_long_audio(
    audio_file,
    segment_duration=30,
    overlap=15,
)

overall_genre = genres[overall_index]
print(f"\nOverall Predicted Genre: {overall_genre}")

# Rank the averaged scores and keep the three best, highest first.
top3 = np.argsort(avg_prediction[0])[-3:][::-1]
print("Top 3 genres overall:", [genres[rank_idx] for rank_idx in top3])


# Equalize a song

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import soundfile as sf
import tensorflow as tf
import IPython.display as ipd
from scipy import signal
from tqdm import tqdm  # For progress bars

# Step 1: Load the audio file
def load_audio(file_path, sr=44100):
    """
    Load an audio file and return the audio time series and sampling rate.

    Loading is best-effort: on failure the error is printed and one second of
    silence is returned instead of raising.

    Parameters:
    file_path (str): Path to the audio file
    sr (int or None): Target sampling rate; None keeps the file's native rate

    Returns:
    tuple: (audio_time_series, sampling_rate)
    """
    try:
        audio, loaded_sr = librosa.load(file_path, sr=sr)
    except Exception as e:
        print(f"Error loading audio file: {e}")
        # With sr=None there is no rate to size the fallback with, so use
        # librosa's default rate instead of crashing inside the handler.
        fallback_sr = sr if sr else 22050
        return np.zeros(int(fallback_sr)), fallback_sr

    print(f"Successfully loaded audio: {len(audio)/loaded_sr:.2f} seconds at {loaded_sr} Hz")
    return audio, loaded_sr

# Step 2: Visualize the waveform and spectrogram
def visualize_audio(audio, sr, title="Original Audio"):
    """
    Plot the waveform and a log-frequency spectrogram of an audio signal.

    Only the first 30 seconds are drawn when the signal is longer than that.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    title (str): Prefix used in both subplot titles
    """
    # Cap the plotted excerpt at 30 seconds worth of samples.
    max_samples = 30 * sr
    audio_to_viz = audio
    if len(audio) > max_samples:
        print(f"Audio is longer than 30 seconds, visualizing only the first 30 seconds")
        audio_to_viz = audio[:max_samples]

    plt.figure(figsize=(14, 6))

    # Top panel: waveform.
    plt.subplot(2, 1, 1)
    librosa.display.waveshow(audio_to_viz, sr=sr)
    plt.title(f"{title} - Waveform")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")

    # Bottom panel: magnitude spectrogram in dB relative to its peak.
    plt.subplot(2, 1, 2)
    spectrum = librosa.stft(audio_to_viz)
    D = librosa.amplitude_to_db(np.abs(spectrum), ref=np.max)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log')
    plt.colorbar(format='%+2.0f dB')
    plt.title(f"{title} - Spectrogram")

    plt.tight_layout()
    plt.show()

# Step 3: Create an equalization profile
def create_eq_profile(num_bands=10, min_freq=20, max_freq=22050, sr=44100):
    """
    Build a flat EQ profile with logarithmically spaced band centres.

    Parameters:
    num_bands (int): Number of frequency bands
    min_freq (int): Lowest centre frequency (Hz)
    max_freq (int): Highest centre frequency (Hz); capped at sr // 2 - 1
    sr (int): Sampling rate

    Returns:
    tuple: (center_frequencies, gains) where all gains are 0 dB
    """
    # Keep the top band strictly below the Nyquist frequency.
    nyquist_cap = sr // 2 - 1
    if max_freq > nyquist_cap:
        max_freq = nyquist_cap

    # Equal steps in log10(frequency) between the two limits.
    log_low = np.log10(min_freq)
    log_high = np.log10(max_freq)
    center_freqs = np.logspace(log_low, log_high, num_bands)

    # 0 dB everywhere means "no change".
    gains = np.zeros(num_bands)

    return center_freqs, gains

# Step 4: Display and modify the equalization profile
def set_eq_gains(center_freqs, gains):
    """
    Print the current band layout and hand the gains back.

    This is a hook for manual tweaking: as written it only lists every band
    with its centre frequency and gain, and returns `gains` unmodified.

    Parameters:
    center_freqs (numpy.ndarray): Center frequencies
    gains (numpy.ndarray): Gain values in dB

    Returns:
    numpy.ndarray: The gain values that were passed in
    """
    bands = zip(center_freqs, gains)
    for i, band in enumerate(bands):
        freq, gain = band
        print(f"Band {i+1}: {freq:.1f} Hz = {gain:.1f} dB")

    # Place custom per-band adjustments here if needed.
    return gains

# Step 5: Apply equalization using FFT-based approach
def apply_eq_fft(audio, sr, center_freqs, gains):
    """
    Equalize a signal by scaling its STFT magnitudes band by band.

    Every band spans one octave around its centre frequency
    (centre / sqrt(2) .. centre * sqrt(2)); bins inside a band are multiplied
    by that band's gain, and the original phase is reused for the inverse
    transform.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    center_freqs (numpy.ndarray): Center frequencies
    gains (numpy.ndarray): Gain values in dB

    Returns:
    numpy.ndarray: Equalized audio with the same length as the input
    """
    n_fft = 2048  # STFT window size (a power of two)
    spectrum = librosa.stft(audio, n_fft=n_fft)
    magnitude, phase = librosa.magphase(spectrum)
    freq_bins = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

    # Linear gain per frequency bin; 1.0 leaves a bin untouched.
    eq_curve = np.ones_like(freq_bins)
    half_octave = np.sqrt(2)

    for center_freq, gain_db in zip(center_freqs, gains):
        in_band = (freq_bins >= center_freq / half_octave) & (freq_bins <= center_freq * half_octave)
        # dB -> linear amplitude factor.
        eq_curve[in_band] *= 10 ** (gain_db / 20)

    # Scale the magnitudes, re-attach the phase and go back to the time domain.
    shaped_stft = (magnitude * eq_curve[:, np.newaxis]) * phase
    return librosa.istft(shaped_stft, length=len(audio))

# Step 6: Process audio in segments
def process_audio_in_segments(audio, sr, center_freqs, gains, segment_duration=30, overlap=3):
    """
    Equalize audio segment by segment and cross-fade the overlaps.

    Consecutive segments start `segment_duration - overlap` seconds apart.
    Inside each overlap the earlier segment fades out linearly while the later
    one fades in, and the summed result is divided by the summed window so the
    overall level is preserved.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    center_freqs (numpy.ndarray): Center frequencies for EQ
    gains (numpy.ndarray): Gain values to apply (same for all segments)
    segment_duration (int): Duration of each segment in seconds
    overlap (int): Overlap between segments in seconds; must be
        non-negative and shorter than segment_duration

    Returns:
    numpy.ndarray: Processed audio, same length as the input

    Raises:
    ValueError: If overlap is negative or not shorter than segment_duration
    """
    segment_length = int(segment_duration * sr)
    overlap_length = int(overlap * sr)
    hop_length = segment_length - overlap_length
    if overlap_length < 0 or hop_length <= 0:
        raise ValueError("overlap must be non-negative and shorter than segment_duration")

    total_samples = len(audio)
    processed_audio = np.zeros_like(audio)
    normalization = np.zeros_like(audio)
    if total_samples == 0:
        return processed_audio

    # Smallest number of segments whose last one reaches the end of the audio.
    # This guarantees that every segment after the first is longer than the
    # overlap, so the fade-in ramp always fits.
    num_segments = max(1, int(np.ceil((total_samples - overlap_length) / hop_length)))

    fade_len = overlap_length
    fade_in = np.linspace(0, 1, fade_len)
    fade_out = np.linspace(1, 0, fade_len)

    print(f"Processing audio in {num_segments} segments...")
    for i in tqdm(range(num_segments)):
        start_idx = i * hop_length
        end_idx = min(start_idx + segment_length, total_samples)
        segment = audio[start_idx:end_idx]
        processed_segment = apply_eq_fft(segment, sr, center_freqs, gains)

        window = np.ones(len(segment))
        # With no overlap there is nothing to cross-fade (and a slice such as
        # window[-0:] would address the whole window).
        if fade_len > 0:
            if i > 0:
                window[:fade_len] = fade_in
            if i < num_segments - 1:
                window[-fade_len:] = fade_out

        processed_segment *= window
        processed_audio[start_idx:end_idx] += processed_segment
        normalization[start_idx:end_idx] += window

    # Undo the window weighting wherever at least one segment contributed.
    mask = normalization > 0
    processed_audio[mask] /= normalization[mask]
    return processed_audio

# Step 7: Play and save audio
def play_audio(audio, sr, title="Audio"):
    """
    Announce a clip and build an inline notebook player for it.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    title (str): Name printed before the player

    Returns:
    IPython.display.Audio: Player widget for the clip
    """
    print(f"Playing {title}")
    player = ipd.Audio(audio, rate=sr)
    return player

def save_audio(audio, sr, file_path):
    """
    Peak-normalize audio to 0.95 and write it to disk.

    All-zero audio is written unchanged, since it has no peak to scale by.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    file_path (str): Output file path
    """
    peak = np.max(np.abs(audio))
    if peak > 0:
        normalized_audio = audio / peak * 0.95
    else:
        normalized_audio = audio

    sf.write(file_path, normalized_audio, sr)
    print(f"Saved equalized audio to {file_path}")

# Step 8: Define some preset equalization profiles
def apply_eq_preset(preset_name):
    """
    Look up the gains of a named 10-band EQ preset.

    Parameters:
    preset_name (str): Name of the preset ('bass_boost', 'vocal_enhance', etc.)

    Returns:
    numpy.ndarray: Gain values in dB for the 10 bands; all zeros (after
        printing the available names) when the preset is unknown
    """
    presets = {
        'flat': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        'bass_boost': [7, 5, 3, 1, 0, 0, 0, 0, 0, 0],
        'bass_cut': [-7, -5, -3, -1, 0, 0, 0, 0, 0, 0],
        'treble_boost': [0, 0, 0, 0, 0, 0, 1, 3, 5, 7],
        'treble_cut': [0, 0, 0, 0, 0, 0, -1, -3, -5, -7],
        'vocal_enhance': [-2, -2, 0, 2, 4, 3, 2, 1, 0, -1],
        'v_shape': [4, 2, 0, -2, -3, -3, -2, 0, 2, 4],
        'loudness': [5, 4, 2, 0, -1, -1, 0, 2, 4, 5],
        'crazy': [20, 15, 10, 5, 2, 1, 0, 0, 0, 0]
    }

    chosen = presets.get(preset_name)
    if chosen is None:
        # Unknown name: tell the user what exists and fall back to a flat curve.
        print(f"Preset '{preset_name}' not found. Available presets: {list(presets.keys())}")
        return np.zeros(10)
    return np.array(chosen)
try:
    # Step 1: Load the audio.
    # load_audio reports its own failures and falls back to one second of
    # silence, so no extra handler is needed here; anything unexpected is
    # caught by the outer handler instead of continuing with stale data.
    file_path = "./music/wtc-cream.mp3"  # Change to your file path
    audio, sr = load_audio(file_path)

    # Step 2: Visualize original audio (first 30 seconds)
    visualize_audio(audio, sr, "Original Audio")

    # Step 3: Create equalization profile
    center_freqs, gains = create_eq_profile(num_bands=10, sr=sr)

    # Step 4: Apply an EQ preset (change preset name as desired).
    # The name is kept in one variable so the log line always matches the
    # preset that is actually applied.
    preset_name = 'crazy'
    gains = apply_eq_preset(preset_name)
    print(f"Using EQ preset '{preset_name}' with gains:", gains)

    # Step 5: Process audio in segments (30s segments with 3s overlap)
    processed_audio = process_audio_in_segments(audio, sr, center_freqs, gains, segment_duration=30, overlap=3)

    # Step 6: Visualize the equalized audio
    visualize_audio(processed_audio, sr, "Equalized Audio")

    # Step 7: Play original and equalized audio
    print("Playing Original Audio:")
    ipd.display(play_audio(audio, sr, "Original Audio"))
    print("Playing Equalized Audio:")
    ipd.display(play_audio(processed_audio, sr, "Equalized Audio"))

    # Step 8: Save the equalized audio to file
    output_path = "./music/wtc-cream_equalized.mp3"
    save_audio(processed_audio, sr, output_path)

except Exception as e:
    print("An error occurred in the processing pipeline:", e)

# EQ 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import soundfile as sf
import tensorflow as tf
import IPython.display as ipd
from scipy import signal
from tqdm import tqdm  # For progress bars

# Step 1: Load the audio file
def load_audio(file_path, sr=44100):
    """
    Load an audio file and return the audio time series and sampling rate.

    Loading is best-effort: on failure the error is printed and one second of
    silence is returned instead of raising.

    Parameters:
    file_path (str): Path to the audio file
    sr (int or None): Target sampling rate; None keeps the file's native rate

    Returns:
    tuple: (audio_time_series, sampling_rate)
    """
    try:
        audio, loaded_sr = librosa.load(file_path, sr=sr)
    except Exception as e:
        print(f"Error loading audio file: {e}")
        # With sr=None there is no rate to size the fallback with, so use
        # librosa's default rate instead of crashing inside the handler.
        fallback_sr = sr if sr else 22050
        return np.zeros(int(fallback_sr)), fallback_sr

    print(f"Successfully loaded audio: {len(audio)/loaded_sr:.2f} seconds at {loaded_sr} Hz")
    return audio, loaded_sr
    
def extract_mfcc_from_segment(segment, sr, n_mfcc=13, hop_length=1103):
    """
    Compute the MFCCs of an in-memory audio segment.

    Returns an array of shape (n_mfcc, n_frames).
    """
    return librosa.feature.mfcc(
        y=segment,
        sr=sr,
        n_mfcc=n_mfcc,
        hop_length=hop_length,
    )

# Step 2: Visualize the waveform and spectrogram
def visualize_audio(audio, sr, title="Original Audio"):
    """
    Plot the waveform and a log-frequency spectrogram of an audio signal.

    Only the first 30 seconds are drawn when the signal is longer than that.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    title (str): Prefix used in both subplot titles
    """
    # Cap the plotted excerpt at 30 seconds worth of samples.
    max_samples = 30 * sr
    audio_to_viz = audio
    if len(audio) > max_samples:
        print(f"Audio is longer than 30 seconds, visualizing only the first 30 seconds")
        audio_to_viz = audio[:max_samples]

    plt.figure(figsize=(14, 6))

    # Top panel: waveform.
    plt.subplot(2, 1, 1)
    librosa.display.waveshow(audio_to_viz, sr=sr)
    plt.title(f"{title} - Waveform")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")

    # Bottom panel: magnitude spectrogram in dB relative to its peak.
    plt.subplot(2, 1, 2)
    spectrum = librosa.stft(audio_to_viz)
    D = librosa.amplitude_to_db(np.abs(spectrum), ref=np.max)
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log')
    plt.colorbar(format='%+2.0f dB')
    plt.title(f"{title} - Spectrogram")

    plt.tight_layout()
    plt.show()

# Step 3: Create an equalization profile
def create_eq_profile(num_bands=10, min_freq=20, max_freq=22050, sr=44100):
    """
    Build a flat EQ profile with logarithmically spaced band centres.

    Parameters:
    num_bands (int): Number of frequency bands
    min_freq (int): Lowest centre frequency (Hz)
    max_freq (int): Highest centre frequency (Hz); capped at sr // 2 - 1
    sr (int): Sampling rate

    Returns:
    tuple: (center_frequencies, gains) where all gains are 0 dB
    """
    # Keep the top band strictly below the Nyquist frequency.
    nyquist_cap = sr // 2 - 1
    if max_freq > nyquist_cap:
        max_freq = nyquist_cap

    # Equal steps in log10(frequency) between the two limits.
    log_low = np.log10(min_freq)
    log_high = np.log10(max_freq)
    center_freqs = np.logspace(log_low, log_high, num_bands)

    # 0 dB everywhere means "no change".
    gains = np.zeros(num_bands)

    return center_freqs, gains

# Step 4: Display and modify the equalization profile
def set_eq_gains(center_freqs, gains):
    """
    Print the current band layout and hand the gains back.

    This is a hook for manual tweaking: as written it only lists every band
    with its centre frequency and gain, and returns `gains` unmodified.

    Parameters:
    center_freqs (numpy.ndarray): Center frequencies
    gains (numpy.ndarray): Gain values in dB

    Returns:
    numpy.ndarray: The gain values that were passed in
    """
    bands = zip(center_freqs, gains)
    for i, band in enumerate(bands):
        freq, gain = band
        print(f"Band {i+1}: {freq:.1f} Hz = {gain:.1f} dB")

    # Place custom per-band adjustments here if needed.
    return gains

# Step 5: Apply equalization using FFT-based approach
def apply_eq_fft(audio, sr, center_freqs, gains):
    """
    Equalize a signal by scaling its STFT magnitudes band by band.

    Every band spans one octave around its centre frequency
    (centre / sqrt(2) .. centre * sqrt(2)); bins inside a band are multiplied
    by that band's gain, and the original phase is reused for the inverse
    transform.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    center_freqs (numpy.ndarray): Center frequencies
    gains (numpy.ndarray): Gain values in dB

    Returns:
    numpy.ndarray: Equalized audio with the same length as the input
    """
    n_fft = 2048  # STFT window size (a power of two)
    spectrum = librosa.stft(audio, n_fft=n_fft)
    magnitude, phase = librosa.magphase(spectrum)
    freq_bins = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

    # Linear gain per frequency bin; 1.0 leaves a bin untouched.
    eq_curve = np.ones_like(freq_bins)
    half_octave = np.sqrt(2)

    for center_freq, gain_db in zip(center_freqs, gains):
        in_band = (freq_bins >= center_freq / half_octave) & (freq_bins <= center_freq * half_octave)
        # dB -> linear amplitude factor.
        eq_curve[in_band] *= 10 ** (gain_db / 20)

    # Scale the magnitudes, re-attach the phase and go back to the time domain.
    shaped_stft = (magnitude * eq_curve[:, np.newaxis]) * phase
    return librosa.istft(shaped_stft, length=len(audio))

# Step 6: Process audio in segments
def process_audio_in_segments(audio, sr, center_freqs, segment_duration=30, overlap=3, intensity=1.0,
                              model_sr=22050):
    """
    Equalize audio segment by segment with a genre-dependent EQ curve.

    Every segment is classified with the global `model`; the predicted genre
    selects the gains via `apply_eq_preset`, which are then applied with
    `apply_eq_fft`. Consecutive segments start `segment_duration - overlap`
    seconds apart; inside each overlap the earlier segment fades out linearly
    while the later one fades in, and the summed result is divided by the
    summed window so the overall level is preserved.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate of `audio`
    center_freqs (numpy.ndarray): Center frequencies for EQ
    segment_duration (int): Duration of each segment in seconds
    overlap (int): Overlap between segments in seconds; must be
        non-negative and shorter than segment_duration
    intensity (float): Scale factor forwarded to apply_eq_preset
    model_sr (int): Sample rate the classifier features are computed at.
        Segments are resampled to this rate before MFCC extraction; with
        hop_length=1103, 600 frames span 30 s only at 22050 Hz, so features
        taken at a different rate would cover a different time span than the
        600-frame input the model expects.

    Returns:
    numpy.ndarray: Processed audio, same length as the input

    Raises:
    ValueError: If overlap is negative or not shorter than segment_duration
    """
    segment_length = int(segment_duration * sr)
    overlap_length = int(overlap * sr)
    hop_length = segment_length - overlap_length
    if overlap_length < 0 or hop_length <= 0:
        raise ValueError("overlap must be non-negative and shorter than segment_duration")

    total_samples = len(audio)
    processed_audio = np.zeros_like(audio)
    normalization = np.zeros_like(audio)
    if total_samples == 0:
        return processed_audio

    # Smallest number of segments whose last one reaches the end of the audio.
    # This guarantees that every segment after the first is longer than the
    # overlap, so the fade-in ramp always fits.
    num_segments = max(1, int(np.ceil((total_samples - overlap_length) / hop_length)))

    fade_len = overlap_length
    fade_in = np.linspace(0, 1, fade_len)
    fade_out = np.linspace(1, 0, fade_len)

    # Class names in the order of the model's outputs.
    genres = ['blues', 'classical', 'country', 'disco', 'hiphop',
                 'jazz', 'metal', 'pop', 'reggae', 'rock']

    print(f"Processing audio in {num_segments} segments...")
    for i in tqdm(range(num_segments)):
        start_idx = i * hop_length
        end_idx = min(start_idx + segment_length, total_samples)
        segment = audio[start_idx:end_idx]

        # Classify a copy of the segment at the classifier's sample rate; the
        # EQ below still runs on the full-rate audio.
        if sr == model_sr:
            clf_segment = segment
        else:
            clf_segment = librosa.resample(segment, orig_sr=sr, target_sr=model_sr)
        x_input = process_segment(clf_segment, model_sr, target_frames=600, n_mfcc=13, hop_length=1103)
        pred = model.predict(x_input)
        genre_index = np.argmax(pred)
        genre_preset = genres[genre_index]
        print(f"Predicted genre for segment {i+1}: {genre_preset}")

        gains = apply_eq_preset(genre_preset, intensity=intensity)
        processed_segment = apply_eq_fft(segment, sr, center_freqs, gains)

        window = np.ones(len(segment))
        # With no overlap there is nothing to cross-fade (and a slice such as
        # window[-0:] would address the whole window).
        if fade_len > 0:
            if i > 0:
                window[:fade_len] = fade_in
            if i < num_segments - 1:
                window[-fade_len:] = fade_out

        processed_segment *= window
        processed_audio[start_idx:end_idx] += processed_segment
        normalization[start_idx:end_idx] += window

    # Undo the window weighting wherever at least one segment contributed.
    mask = normalization > 0
    processed_audio[mask] /= normalization[mask]
    return processed_audio

def process_segment(segment, sr, target_frames=600, n_mfcc=13, hop_length=1103):
    """
    Turn an audio segment into a single-sample batch for the classifier.

    The MFCC matrix is transposed to (n_frames, n_mfcc), zero-padded at the
    end or truncated to exactly `target_frames` frames, and wrapped with a
    leading batch axis and a trailing channel axis.

    Returns an array of shape (1, target_frames, n_mfcc, 1).
    """
    frames = extract_mfcc_from_segment(segment, sr, n_mfcc, hop_length).T

    # Positive -> frames are missing (pad); negative -> too many (cut).
    missing = target_frames - frames.shape[0]
    if missing > 0:
        frames = np.pad(frames, ((0, missing), (0, 0)), mode='constant')
    elif missing < 0:
        frames = frames[:target_frames, :]

    # (target_frames, n_mfcc) -> (1, target_frames, n_mfcc, 1)
    return frames[np.newaxis, ..., np.newaxis]

# Step 7: Play and save audio
def play_audio(audio, sr, title="Audio"):
    """
    Announce a clip and build an inline notebook player for it.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    title (str): Name printed before the player

    Returns:
    IPython.display.Audio: Player widget for the clip
    """
    print(f"Playing {title}")
    player = ipd.Audio(audio, rate=sr)
    return player

def save_audio(audio, sr, file_path):
    """
    Peak-normalize audio to 0.95 and write it to disk.

    All-zero audio is written unchanged, since it has no peak to scale by.

    Parameters:
    audio (numpy.ndarray): Audio time series
    sr (int): Sampling rate
    file_path (str): Output file path
    """
    peak = np.max(np.abs(audio))
    if peak > 0:
        normalized_audio = audio / peak * 0.95
    else:
        normalized_audio = audio

    sf.write(file_path, normalized_audio, sr)
    print(f"Saved equalized audio to {file_path}")

# Step 8: Define some preset equalization profiles
def apply_eq_preset(preset_name, intensity=1.0):
    """
    Look up a genre EQ preset and scale it.

    Parameters:
    preset_name (str): One of the genre keys of the table below
        ('blues', 'classical', 'country', 'disco', 'hiphop', 'jazz',
        'metal', 'pop', 'reggae', 'rock').
    intensity (float): Multiplier applied to every gain of the preset.

    Returns:
    numpy.ndarray: The 10 preset gains multiplied by intensity. For an
    unknown preset name a message listing the available presets is printed
    and an array of 10 zeros is returned.
    """
    preset_table = {
        'blues':     [-1,  0,  2,  2,  1,  0, -1,  0,  1,  1],
        'classical': [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        'country':   [ 0,  1,  1,  2,  1,  0,  0,  0,  0,  0],
        'disco':     [ 2,  3,  2,  1,  0,  0,  1,  2,  3,  2],
        'hiphop':    [ 3,  4,  2,  0, -1, -1,  0,  1,  1,  0],
        'jazz':      [ 0,  1,  1,  2,  2,  1,  0,  0,  1,  0],
        'metal':     [ 2,  3,  0, -3, -4, -3,  0,  3,  3,  2],
        'pop':       [ 0,  1,  2,  2,  1,  0,  1,  1,  2,  2],
        'reggae':    [ 0,  1,  0,  0, -1, -1,  0,  0,  1,  0],
        'rock':      [ 1,  1,  0,  0,  1,  1,  0,  0,  1,  1],
    }

    band_gains = preset_table.get(preset_name)
    if band_gains is None:
        # Unknown name: report the valid choices and fall back to a flat EQ.
        print(f"Preset '{preset_name}' not found. Available presets: {list(preset_table)}")
        return np.zeros(10)

    return np.array(band_gains) * intensity
    
# --- Load the saved model ---
# Fail fast when the saved Keras model is not at the expected location;
# otherwise load it into the module-level name `model`.
model_path = "./saved_models/CONEqNet.keras"
if not os.path.exists(model_path):
    raise FileNotFoundError("Model file not found. Please check the path.")

model = tf.keras.models.load_model(model_path)
print(f"Loaded model from {model_path}")

# End-to-end run: load one track, equalize it segment by segment, visualize,
# play and save the result. Any failure is reported by the handler at the end.
try:
    # Step 1: Load or generate audio
    dir_path = "./music/"
    input_file_name = "rammstein-deuschland.mp3"
    # Insert the suffix in front of the extension via splitext. The previous
    # str.replace(".mp3", ...) left any non-".mp3" name unchanged, which made
    # output_path identical to file_path, so saving would overwrite the input.
    output_file_name = "{}-equalized-ai{}".format(*os.path.splitext(input_file_name))
    file_path = os.path.join(dir_path, input_file_name)
    output_path = os.path.join(dir_path, output_file_name)

    # Scaling factor forwarded to process_audio_in_segments.
    intensity = 1.0

    # A failed load must stop the run here. Previously a bare `except:` only
    # printed a message and carried on, so the next step either died with an
    # unrelated NameError or silently processed `audio`/`sr` left over from an
    # earlier cell run. The bare clause also intercepted KeyboardInterrupt.
    try:
        audio, sr = load_audio(file_path)
    except Exception as load_error:
        raise RuntimeError(f"Couldn't load {file_path}: {load_error}") from load_error

    # Step 2: Visualize original audio
    visualize_audio(audio, sr, "Original Audio")

    # Step 3: Create equalization profile
    # `gains` is not used below: only center_freqs is passed to the processing step.
    center_freqs, gains = create_eq_profile(num_bands=10, sr=sr)

    # Step 4: Process audio in segments (30s segments with 3s overlap)
    processed_audio = process_audio_in_segments(
        audio,
        sr,
        center_freqs,
        segment_duration=30,
        overlap=3,
        intensity=intensity,
    )

    # Step 5: Visualize the equalized audio
    visualize_audio(processed_audio, sr, "Equalized Audio")

    # Step 6: Play original and equalized audio
    print("Playing Original Audio:")
    ipd.display(play_audio(audio, sr, "Original Audio"))
    print("Playing Equalized Audio:")
    ipd.display(play_audio(processed_audio, sr, "Equalized Audio"))

    # Step 7: Save the equalized audio to file
    # NOTE(review): output_path keeps the input's extension (.mp3 here); whether
    # the writer used by save_audio can encode that format depends on the
    # installed audio backend - confirm.
    save_audio(processed_audio, sr, output_path)

except Exception as e:
    # Top-level boundary of the notebook cell: report the failure instead of raising.
    print("An error occurred in the processing pipeline:", e)

# Pred 2

In [None]:
import torch
import torchaudio
import numpy as np
import librosa

# Load the pre-trained PANNs CNN14 model from Torch Hub.
# This model is trained on AudioSet (527 classes)
# NOTE: this rebinds the module-level name `model`, which earlier in this file
# refers to the Keras model loaded from ./saved_models/CONEqNet.keras.
# predict_genre below reads this global, so it uses the CNN14 model.
# force_reload=True asks torch.hub to fetch the repository again on every run
# instead of reusing a previously downloaded copy.
model = torch.hub.load('qiuqiangkong/audioset_tagging_cnn', 'cnn14', pretrained=True, force_reload=True)
model.eval()  # Set model to evaluation mode

# Define a function to preprocess audio and run prediction
def predict_genre(audio_path, target_sr=32000):
    """
    Run the module-level CNN14 `model` on one audio file.

    The file is loaded with torchaudio, resampled to target_sr when its native
    rate differs, mixed down to a single channel, given a batch axis and passed
    through the model with gradient tracking disabled.

    Parameters:
    audio_path (str): Path to the audio file (can be mp3, wav, etc.)
    target_sr (int): Sampling rate expected by the model (default 32000 Hz)

    Returns:
    np.ndarray: The model's 'clipwise_output' for this clip, with the batch
    axis removed, as a NumPy array on the CPU.
    """
    # torchaudio returns a tensor of shape [channels, samples] and the file's rate.
    waveform, native_sr = torchaudio.load(audio_path)

    # Bring the audio to the rate the model was set up for.
    if native_sr != target_sr:
        to_target_rate = torchaudio.transforms.Resample(orig_freq=native_sr, new_freq=target_sr)
        waveform = to_target_rate(waveform)

    # Down-mix multi-channel audio by averaging, keeping the channel axis.
    is_multichannel = waveform.shape[0] > 1
    if is_multichannel:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Prepend a batch axis -> (1, channels, samples).
    # NOTE(review): the original comment states the model expects
    # (batch, channels, samples); the PANNs reference code documents a
    # (batch, samples) input - confirm against the hub entry point.
    batch = waveform.unsqueeze(0)

    # Inference only: no gradients needed.
    with torch.no_grad():
        clipwise = model(batch)['clipwise_output']

    # Drop the batch axis and hand back a NumPy array.
    return clipwise.squeeze(0).cpu().numpy()

# Example usage (executed as soon as this cell runs):
audio_path = "your_audio_file.mp3"  # Placeholder: replace with a real audio file path before running
preds = predict_genre(audio_path)
print("Model predictions:", preds)

# Mapping to a smaller set of music genres:
# 'preds' holds one score per AudioSet class (527 according to the model
# description above), not scores for the 10 entries of `genres`.
# NOTE(review): AudioSet tagging is multi-label, so these are presumably
# independent per-class scores rather than a distribution that sums to 1 - confirm.
# To obtain genre scores, map each genre to the AudioSet class indices that
# represent it and aggregate the scores at those indices.
# For example:
# genre_mapping = {'rock': [...], 'pop': [...], ...}
# and then aggregate the scores from the relevant indices.
