In [None]:
import IPython.display as ipd
import numpy as np
import pandas as pd
import librosa
import tensorflow as tf

In [None]:
# Artifact locations, resolved relative to the kernel's working directory.
# NOTE(review): the outputs saved in this notebook report './Models/best_lstm_model.keras'
# and './augmented_normal_data.csv', i.e. they came from different paths — confirm these.
MODEL_PATH = "./best_lstm_model.keras"
CSV_PATH = "../augmented_normal_data.csv"

In [None]:
# Load the trained network for inference only; compile=False means no
# training configuration (optimizer/loss) is set up, which predict() does not need.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# Width of the model's last input axis; every feature vector built below must match it.
F = model.input_shape[-1]
print(f"Loaded model: {MODEL_PATH} | feature dim F = {F}")

Loaded model: ./Models/best_lstm_model.keras | feature dim F = 128


In [None]:
# === Class order from CSV ===
# Recover the class order used at training time from the training CSV and
# split off the feature matrix used for normalization statistics below.
df_meta = pd.read_csv(CSV_PATH)
print("Loaded CSV:", CSV_PATH, "| shape:", df_meta.shape)

# UrbanSound8K class names, indexed by numeric class id.
US8K_ID2NAME = [
    'air_conditioner','car_horn','children_playing','dog_bark','drilling',
    'engine_idling','gun_shot','jackhammer','siren','street_music'
]
N_CLASSES = len(US8K_ID2NAME)

if 'label' in df_meta.columns:
    # CSV stores numeric class ids 0..N_CLASSES-1 in a 'label' column; all
    # other columns are features.
    class_names = list(US8K_ID2NAME)
    # Fail loudly if the ids don't match US8K_ID2NAME's indexing (e.g. string
    # labels or a different class set) instead of silently mislabeling outputs.
    unexpected_ids = set(df_meta['label'].unique().tolist()) - set(range(N_CLASSES))
    assert not unexpected_ids, (
        f"'label' column has ids outside 0..{N_CLASSES - 1}: {sorted(unexpected_ids, key=str)}"
    )
    X_trainspace = df_meta.drop(columns=['label']).values.astype(np.float32)
else:
    # CSV is one-hot in the last N_CLASSES columns; their headers are the class names.
    class_names = [str(c) for c in df_meta.columns[-N_CLASSES:]]
    X_trainspace = df_meta.iloc[:, :-N_CLASSES].values.astype(np.float32)

assert X_trainspace.shape[1] == F, f"CSV has {X_trainspace.shape[1]} features; model expects {F}"
# predict() maps model output positions to class_names, so the widths must agree.
assert model.output_shape[-1] == len(class_names), (
    f"model has {model.output_shape[-1]} outputs but {len(class_names)} class names"
)
print("Class order:", class_names)

Loaded CSV: ./augmented_normal_data.csv | shape: (43660, 129)
Class order: ['air_conditioner', 'car_horn', 'children_playing', 'dog_bark', 'drilling', 'engine_idling', 'gun_shot', 'jackhammer', 'siren', 'street_music']


In [None]:
# === Build CSV-space normalization ===
# Per-feature mean/std of the CSV feature matrix. Features extracted from audio
# are standardized with these so they are comparable to the CSV rows.
# NOTE(review): if the CSV features were already standardized upstream, these
# stats are ~0/~1 and the raw mel power vectors would pass through nearly
# unscaled — confirm how the CSV was produced.
FEAT_MEAN = np.mean(X_trainspace, axis=0).astype(np.float32)
FEAT_STD = np.std(X_trainspace, axis=0).astype(np.float32)
# Constant columns have zero spread; divide by 1 there instead of by 0.
FEAT_STD = np.where(FEAT_STD == 0, np.float32(1.0), FEAT_STD)
NORMALIZE_WITH_CSV = True

In [None]:
# === Audio -> vector (linear mel mean), chosen recipe ===
# Feature-extraction settings shared by the mel/extraction/prediction helpers.
SR = 22050        # sample rate (Hz) audio is resampled to on load
FMIN = 20         # lowest mel filter frequency (Hz)
FMAX = SR // 2    # highest mel filter frequency (Hz): the Nyquist frequency, 11025
NFFT = 2048       # STFT window length (samples)
HOP = 512         # STFT hop length (samples)

def mel_vector_linear(y, sr, n_mels):
    """Summarize a signal as its time-averaged mel power spectrum.

    Builds a mel spectrogram with ``n_mels`` bands (power=2.0, no log/dB
    scaling) using the module-level FMIN/FMAX/NFFT/HOP settings, then averages
    each band across all frames.

    Returns:
        float32 vector with one entry per mel band, i.e. shape (n_mels,).
    """
    mel_power = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_fft=NFFT,
        hop_length=HOP,
        center=True,
        power=2.0,
        n_mels=n_mels,
        fmin=FMIN,
        fmax=FMAX,
    )
    band_means = np.mean(mel_power, axis=1)
    return band_means.astype(np.float32)

def extract_vector(file_path, target_dim):
    """Load an audio file and return its feature vector in model input space.

    The file is resampled to ``SR`` and downmixed to mono, summarized by
    ``mel_vector_linear`` with ``target_dim`` mel bands, and standardized with
    the CSV statistics when ``NORMALIZE_WITH_CSV`` is set.
    """
    signal = librosa.load(file_path, sr=SR, mono=True)[0]
    features = mel_vector_linear(signal, SR, target_dim)
    if not NORMALIZE_WITH_CSV:
        return features
    return (features - FEAT_MEAN) / FEAT_STD


In [None]:
# === Prediction with TTA over 1s crops ===
def predict(file_path, class_names, tta_windows=7, crop_sec=1.0):
    """Classify an audio file, averaging model outputs over evenly spaced crops.

    The clip is loaded once at ``SR`` (mono). If it is shorter than
    ``crop_sec`` or ``tta_windows <= 1``, one feature vector is computed from
    the whole clip. Otherwise ``tta_windows`` crops of ``crop_sec`` seconds are
    taken, with start times spread evenly from 0 to ``total_sec - crop_sec``
    (test-time augmentation). Each crop is zero-padded to full length if
    needed and featurized like ``extract_vector``. The model outputs are
    averaged over the crops actually evaluated.

    Args:
        file_path: path to an audio file readable by ``librosa.load``.
        class_names: class labels in the model's output order.
        tta_windows: number of crops to average; <= 1 disables TTA.
        crop_sec: crop length in seconds; must cover at least one sample.

    Returns:
        ``(pred_name, pred_score, top3)``, where ``top3`` is a list of
        ``(name, score)`` pairs in descending score order.

    Raises:
        ValueError: if ``crop_sec`` is shorter than one sample, the audio is
            empty, or the model's output width differs from ``len(class_names)``.
    """
    frame = int(crop_sec * SR)
    if frame < 1:
        raise ValueError(f"crop_sec={crop_sec} is shorter than one sample at SR={SR}")

    y, _ = librosa.load(file_path, sr=SR, mono=True)
    if y.size == 0:
        raise ValueError(f"No audio samples in {file_path}")
    total_sec = len(y) / SR

    def to_model_space(segment):
        # Same featurization as extract_vector, applied to an in-memory signal.
        vec = mel_vector_linear(segment, SR, F)
        if NORMALIZE_WITH_CSV:
            vec = (vec - FEAT_MEAN) / FEAT_STD
        return vec

    if total_sec < crop_sec or tta_windows <= 1:
        # Too short to crop (or TTA disabled): featurize the whole clip once,
        # reusing the loaded signal instead of reading the file again.
        feats = [to_model_space(y)]
    else:
        # Tiny floor keeps the step positive when the clip is exactly crop_sec long.
        step = max((total_sec - crop_sec) / (tta_windows - 1), 1e-6)
        feats = []
        for w in range(tta_windows):
            start = int((w * step) * SR)
            seg = y[start:start + frame]
            if len(seg) < frame:
                seg = np.pad(seg, (0, frame - len(seg)))
            feats.append(to_model_space(seg))

    # One batched model call of shape (n_crops, 1, F) instead of one call per crop.
    batch = np.stack(feats).reshape(len(feats), 1, F)
    scores = model.predict(batch, verbose=0)
    if scores.shape[-1] != len(class_names):
        raise ValueError(
            f"model returns {scores.shape[-1]} scores but {len(class_names)} class names were given"
        )

    # Average over the crops actually evaluated (1 in the single-shot branch).
    # Dividing by tta_windows there would shrink every score by 1/tta_windows.
    proba = scores.mean(axis=0).astype(np.float32)
    pred_idx = int(np.argmax(proba))
    top3_idx = np.argsort(proba)[::-1][:3]
    top3 = [(class_names[i], float(proba[i])) for i in top3_idx]
    return class_names[pred_idx], float(proba[pred_idx]), top3

In [None]:
# === Quick tests ===
# Sample clips; the trailing comment on each entry is the class it should be.
test_files = [
    '../sounds/22601-8-0-51.wav',  # siren
    '../sounds/9223-2-0-17.wav',   # children_playing
    '../sounds/344-3-4-0.wav',     # dog_bark
]

for audio_path in test_files:
    print(f"\n▶︎ {audio_path}")
    display(ipd.Audio(filename=audio_path))
    try:
        name, conf, top3 = predict(audio_path, class_names, tta_windows=7, crop_sec=1.0)
        print(f"Predicted: {name} (p={conf:.3f})")
        ranked = ", ".join(f"{label} {score:.2f}" for label, score in top3)
        print("Top-3:", ranked)
    except Exception as err:
        # Best-effort demo: report the failure and move on to the next clip.
        print("Error:", err)


▶︎ ./sounds/22601-8-0-51.wav


Predicted: siren (p=0.750)
Top-3: siren 0.75, street_music 0.13, dog_bark 0.12

▶︎ ./sounds/9223-2-0-17.wav


Predicted: children_playing (p=0.256)
Top-3: children_playing 0.26, street_music 0.26, siren 0.24

▶︎ ./sounds/344-3-4-0.wav


Predicted: dog_bark (p=0.142)
Top-3: dog_bark 0.14, siren 0.00, children_playing 0.00
