In [None]:
# Experiment configuration file (JSON), given relative to the working directory.
CONFIG = "../../configuration/multiclass_test/config.json"
# Model checkpoint that is loaded for evaluation further down.
MODEL_PATH = "../../logs/multiclass_test/run_8/best.pth"

In [None]:
import json
import random

from IPython.display import Audio, display
import matplotlib.pyplot as plt
import sys
import os

# Put the directory two levels up on the import path so the project-local
# packages imported in later cells can be resolved. The membership check keeps
# the cell idempotent: re-running it does not append duplicate entries.
_project_root = os.path.abspath('../../')
if _project_root not in sys.path:
    sys.path.append(_project_root)

# JSON text is UTF-8 by specification; state it explicitly instead of relying
# on the platform's locale-dependent default encoding.
with open(CONFIG, encoding='utf-8') as f:
    cfg = json.load(f)

class Config:
    """Attribute-style view over a (possibly nested) dictionary.

    Every key of ``dictionary`` becomes an attribute on the instance. Values
    that are themselves ``dict`` objects are wrapped recursively in ``Config``
    so nested settings can be read as ``cfg.section.option``. All other values
    (including lists, even lists that contain dicts) are stored unchanged.
    """

    def __init__(self, dictionary):
        for key, value in dictionary.items():
            # Only plain dicts are converted; everything else is kept as-is.
            wrapped = Config(value) if isinstance(value, dict) else value
            setattr(self, key, wrapped)

# Rebind `cfg` from the parsed dict to a Config object so settings are read
# as attributes (cfg.data.root) instead of by key.
cfg = Config(cfg)
# Point the dataset root at the `data` directory two levels up (relative path).
# This replaces any `root` value that came from the JSON file and requires the
# configuration to contain a `data` section.
cfg.data.root = os.path.join('..', '..', 'data')

In [None]:
from dataset.dataset import SpeechCommandsDataset, get_loader

# Build the dataset that the evaluation below draws samples from.
# NOTE(review): mode='testing' presumably selects the held-out test split;
# confirm against SpeechCommandsDataset.
test_dataset = SpeechCommandsDataset(
    root_dir=cfg.data.root,
    cfg=cfg,
    mode='testing'
)

In [None]:
import torch
from modeling.model import build_model

# Instantiate the model from the configuration via the project's factory.
model = build_model(cfg)

# Map every tensor to the CPU so the checkpoint loads regardless of the device
# it was saved from.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code. The result is fed
# straight into load_state_dict, so nothing beyond a plain state_dict is needed.
# (Requires a PyTorch version that supports the `weights_only` argument.)
state_dict = torch.load(
    MODEL_PATH,
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state_dict)

# Switch to evaluation mode before running inference.
model.eval()

print("Model loaded successfully!")

In [None]:
def evaluate_and_display_wrong_predictions(
    model,
    test_dataset,
    n=5,
    *,
    class_names=None,
    sample_rate=None,
    seed=None,
):
    """Find up to ``n`` misclassified samples and show them with audio playback.

    The dataset is visited in random order. Each sample is passed through
    ``model`` as a batch of one, and the arg-max over dimension 1 of the output
    is taken as the predicted class. The search stops as soon as ``n`` wrong
    predictions have been collected.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        test_dataset: Indexable dataset with ``len()`` whose items are
            ``(data, label)`` pairs. ``label`` must be convertible with
            ``int()``.
        n: Maximum number of wrong predictions to display. Values ``<= 0``
            display nothing and evaluate no samples.
        class_names: Sequence mapping class index to a display name. Defaults
            to the notebook-level ``cfg.data.target_commands``.
        sample_rate: Playback rate passed to ``Audio``. Defaults to the
            notebook-level ``cfg.data.sample_rate``.
        seed: If given, the visiting order is shuffled with a dedicated
            ``random.Random(seed)`` so the selection is reproducible. If
            ``None``, the global ``random`` state is used as before.

    Returns:
        None. Results are printed and rendered in the notebook output.
    """
    model.eval()

    # Nothing was requested: do not evaluate (or display) a single sample.
    if n <= 0:
        return

    indices = list(range(len(test_dataset)))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(indices)

    wrong_predictions = []

    with torch.no_grad():
        for idx in indices:
            data, true_label = test_dataset[idx]
            # Normalise tensor / numpy labels to a plain int for comparison
            # and for indexing into class_names.
            true_label = int(true_label)

            # as_tensor avoids the copy (and the UserWarning) that
            # torch.tensor() produces when `data` is already a tensor.
            batch = torch.as_tensor(data, dtype=torch.float32).unsqueeze(0)
            predicted_label = int(torch.argmax(model(batch), dim=1).item())

            if predicted_label != true_label:
                # Keep the exact sample the model scored so playback matches
                # the prediction and the dataset is not read a second time.
                wrong_predictions.append((idx, data, true_label, predicted_label))
                if len(wrong_predictions) >= n:
                    break

    if not wrong_predictions:
        print("No wrong predictions found.")
        return

    # Fall back to the notebook-level configuration only when it is needed.
    if class_names is None:
        class_names = cfg.data.target_commands
    if sample_rate is None:
        sample_rate = cfg.data.sample_rate

    for idx, waveform, true_label, predicted_label in wrong_predictions:
        print(f"Sample Index: {idx}")
        print(f"True Label: {class_names[true_label]}, Predicted Label: {class_names[predicted_label]}")

        # NOTE(review): assumes dataset items are raw waveforms playable at
        # `sample_rate`; confirm against the dataset implementation.
        display(Audio(waveform, rate=sample_rate))
        print("-" * 50)

In [None]:
# Show up to 10 randomly chosen test samples that the loaded model misclassifies.
evaluate_and_display_wrong_predictions(model, test_dataset, n=10)