In [1]:
import tensorflow as tf  # Fixed typo (was 'as pd')
import numpy as np
import pickle
import os
from tensorflow.keras.models import load_model
from utils.embedding_extraction import extract_embedding

# --- CONFIGURATION ---
MODEL_DIR = '..\models'
MODEL_PATH = os.path.join(MODEL_DIR, 'instrument_classifier.h5')
LE_PATH = os.path.join(MODEL_DIR, 'label_encoder.pkl')

print("Loading resources...")

# 1. Load the Custom Trained Model
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found at {MODEL_PATH}. Run 02_Model_Training.ipynb first!")

custom_model = load_model(MODEL_PATH)
print("‚úÖ Custom Model loaded.")

# 2. Load the Label Encoder
if not os.path.exists(LE_PATH):
    raise FileNotFoundError(f"Label Encoder not found at {LE_PATH}. Run 02_Model_Training.ipynb first!")

with open(LE_PATH, 'rb') as f:
    label_encoder = pickle.load(f)
print("‚úÖ Label Encoder loaded.")

  from pkg_resources import parse_version


Loading resources...
‚úÖ Custom Model loaded.
‚úÖ Label Encoder loaded.


In [5]:
def predict_instrument(audio_path):
    """
    Predicts the instrument label for a given audio file.
    """
    # 1. Extract features using the shared utility
    # The model expects a batch, so we reshape (1024,) -> (1, 1024)
    embedding = extract_embedding(audio_path)

    if embedding is None:
        return "Error reading audio", 0.0

    # Reshape to (1, 1024) for the model
    features = embedding.reshape(1, -1)

    # 2. Predict
    prediction_scores = custom_model.predict(features, verbose=0)

    # 3. Process results
    predicted_index = np.argmax(prediction_scores)
    confidence = prediction_scores[0][predicted_index]

    # 4. Convert index back to Label string
    predicted_label = label_encoder.inverse_transform([predicted_index])[0]

    return predicted_label, confidence

# --- TEST USAGE ---
# Replace with a real path to a file you want to test
# flute
test_file = "C:\\Users\Hakim\Documents\ENSET\IA_Avance\music-instrument-classfication\data\\raw\\nsynth\\test\\audio\\flute_acoustic_002-098-025.wav"

# Bass
# test_file = "C:\\Users\Hakim\Documents\ENSET\IA_Avance\music-instrument-classfication\data\\raw\\nsynth\\test\\audio\\bass_synthetic_134-045-100.wav"

# Guitar
# test_file = "C:\\Users\Hakim\Documents\ENSET\IA_Avance\music-instrument-classfication\data\\raw\\nsynth\\test\\audio\\guitar_acoustic_010-106-050.wav"

# String
# test_file = "C:\\Users\Hakim\Documents\ENSET\IA_Avance\music-instrument-classfication\data\\raw\\nsynth\\test\\audio\\string_acoustic_014-046-100.wav"

# Violin
# test_file = "C:\\Users\Hakim\Downloads\G4_fla.wav"

# Saxophone
# test_file = "C:\\Users\Hakim\Downloads\TENORSA.WAV"

# Clarinet
# test_file = "C:\\Users\Hakim\Downloads\\natural2.wav"

print(f"\nTesting on: {test_file}")
if os.path.exists(test_file):
    label, conf = predict_instrument(test_file)
    print(f"üé∏ Prediction: {label}")
    print(f"üìä Confidence: {conf*100:.2f}%")
else:
    print("‚ö†Ô∏è File not found. Please update 'test_file' variable with a valid path.")


Testing on: C:\Users\Hakim\Documents\ENSET\IA_Avance\music-instrument-classfication\data\raw\nsynth\test\audio\flute_acoustic_002-098-025.wav
üé∏ Prediction: Flute
üìä Confidence: 96.78%
