# In [None]:
import os
import numpy as np
import librosa
import joblib

# Audio clip whose genre is predicted at the end of this script.
sample_audio = "data/genres/rock/rock.00055.wav"

# TODO: import `_agg_stats` and `extract_features` from the other files
# rather than defining them here.

def _agg_stats(x):
    """
    Summarise a feature array by its mean and standard deviation.

    NaN and +/-Inf entries are replaced with 0 before aggregating, so the
    result never contains non-finite values caused by bad samples.
    An empty input yields ``(0.0, 0.0)`` instead of the NaN (and
    RuntimeWarning) that ``np.mean`` / ``np.std`` produce on empty arrays.

    Parameters
    ----------
    x : array-like
        Values to aggregate (all elements are reduced together).

    Returns
    -------
    tuple
        ``(mean, std)`` of the sanitised values.
    """
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    # Guard the empty case explicitly: mean/std of nothing would be NaN,
    # which is exactly what this helper is meant to keep out of the features.
    if np.size(x) == 0:
        return 0.0, 0.0
    return np.mean(x), np.std(x)

def extract_features(path):
    """
    Extract a fixed-length vector of summary audio features from a file.

    The signal is loaded as mono at ``SR`` Hz, and one magnitude STFT
    (``N_FFT`` / ``HOP``) is shared by all spectrogram-based features.
    Every per-frame feature row is reduced to ``(mean, std)`` with
    ``_agg_stats``.

    Layout of the returned vector, in order:
        * MFCC: 2 values per coefficient (``MFCC_N`` coefficients)
        * chroma: 2 values per chroma row
        * spectral contrast: 2 values per contrast row
        * centroid, bandwidth, rolloff, zero-crossing rate, RMS: 2 values each
        * tempo (BPM): 1 value

    The original author's note puts the total at roughly 75 values; the exact
    length depends on ``MFCC_N`` and on librosa's chroma/contrast defaults.

    Parameters
    ----------
    path : str or path-like
        Audio file to analyse; passed straight to ``librosa.load``.

    Returns
    -------
    numpy.ndarray
        1-D ``float32`` feature vector.
    """
    y, sr = librosa.load(path, sr=SR, mono=True)
    S = np.abs(librosa.stft(y, n_fft=N_FFT, hop_length=HOP))

    def _row_stats(matrix):
        # Flatten (mean, std) of every row into one list: [m0, s0, m1, s1, ...].
        return [val for row in matrix for val in _agg_stats(row)]

    # --- MFCC ---
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=MFCC_N, hop_length=HOP)
    mfcc_stats = _row_stats(mfcc)

    # --- Chroma ---
    chroma = librosa.feature.chroma_stft(S=S, sr=sr, hop_length=HOP)
    chroma_stats = _row_stats(chroma)

    # --- Spectral Contrast ---
    contrast = librosa.feature.spectral_contrast(S=S, sr=sr, hop_length=HOP)
    contrast_stats = _row_stats(contrast)

    # --- Centroid, Bandwidth, Rolloff, ZCR, RMS ---
    centroid_stats = _agg_stats(librosa.feature.spectral_centroid(S=S, sr=sr)[0])
    bandwidth_stats = _agg_stats(librosa.feature.spectral_bandwidth(S=S, sr=sr)[0])
    rolloff_stats = _agg_stats(
        librosa.feature.spectral_rolloff(S=S, sr=sr, roll_percent=ROLLOFF_PERCENT)[0]
    )
    zcr_stats = _agg_stats(librosa.feature.zero_crossing_rate(y, hop_length=HOP)[0])
    # rms() needs frame_length to match the FFT size that produced S; its
    # default (2048) is only right when N_FFT happens to be 2048.
    rms_stats = _agg_stats(librosa.feature.rms(S=S, frame_length=N_FFT)[0])

    # --- Tempo (BPM) ---
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr, hop_length=HOP)
    # Depending on the librosa version, tempo is a scalar or a 1-element
    # array; pull out the scalar explicitly instead of relying on the
    # deprecated implicit array -> float conversion.
    tempo = float(np.asarray(tempo, dtype=float).ravel()[0])
    if np.isnan(tempo):
        tempo = 0.0

    features = np.array(
        [
            *mfcc_stats,
            *chroma_stats,
            *contrast_stats,
            *centroid_stats,
            *bandwidth_stats,
            *rolloff_stats,
            *zcr_stats,
            *rms_stats,
            tempo,
        ],
        dtype=np.float32,
    )
    return features


def predict_genre(audio_path, model):
    """
    Predict the genre label for a single audio file.

    The feature vector returned by ``extract_features`` is reshaped into a
    one-row 2-D array, handed to ``model.predict``, and the first element of
    the prediction result is returned.
    """
    feature_vector = extract_features(audio_path)
    single_sample = feature_vector.reshape(1, -1)
    predictions = model.predict(single_sample)
    return predictions[0]

# Load the best trained model from the 'models/' directory.
# os.listdir() returns entries in arbitrary order, so sort the matches to
# make the selected file deterministic when more than one is present.
model_files = sorted(f for f in os.listdir("models") if f.startswith("best_model"))
if not model_files:
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would turn this into an opaque IndexError below.
    raise FileNotFoundError("No trained model found in 'models' folders.")
model_path = os.path.join("models", model_files[0])
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load model files that come from a trusted source.
best_model = joblib.load(model_path)

print("Loaded best model: ", model_path)

# Classify the sample clip with the loaded model.
predicted_genre = predict_genre(sample_audio, best_model)