In [None]:

# Validation split: one sub-folder per class name, each holding .wav recordings.
validation_dir = "../../dataset/train-val_mosquito_sounds_classification_3cls_25_02_14/validation/"

# Folder holding the fine-tuned classifier weights (model.safetensors is read from it).
checkpoint_dir = "./ensemble/best/_independentdata_seed42_adamw_cosine/"  # Best checkpoint directory

# Class index -> class-folder name. Indices are the classifier's output positions.
id2label = {
    0: 'Aedes_albopictus',
    1: 'Aedes_koreicus',
    2: 'Ochlerotatus_geniculatus'
}

cls_num = len(id2label)
# Inverse mapping: class-folder name -> class index.
label2id = dict((name, idx) for idx, name in id2label.items())
label2id


In [None]:
#from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
import numpy as np
from tqdm import tqdm
import librosa
import torch
import os
from datasets import Dataset, DatasetDict
from safetensors.torch import load_file
import torch.nn as nn
from transformers import AutoProcessor, ASTModel
import torch.nn.functional as F


# Sanity-check the checkpoint folder before trying to load weights from it.
checkpoint_exists = os.path.isdir(checkpoint_dir)
print("Does the checkpoint folder exist?", checkpoint_exists)
if not checkpoint_exists:
    # Fail with an explicit absolute path instead of a bare os.listdir error.
    raise FileNotFoundError(f"Checkpoint folder not found: {os.path.abspath(checkpoint_dir)}")
# Sorted so the listing is stable across filesystems.
print("Folder contents:", sorted(os.listdir(checkpoint_dir)))



In [None]:
def load_audio_dataset_from_folders(validation_dir, everyNth=1):
    """
    Load audio file paths from a class-per-folder layout into a Hugging Face Dataset.

    Expected layout: ``validation_dir/<class_name>/<file>.wav``. Each sub-folder
    name becomes the (string) label of the WAV files directly inside it;
    non-directory entries at the top level and non-WAV files are ignored.

    Args:
        validation_dir (str): Path to the 'validation' folder.
        everyNth (int): Optional downsampling: keep every nth file, counted over
            the whole (sorted) file list across all classes. Must be >= 1.

    Returns:
        Dataset: Dataset with columns ``file_path`` and ``label`` (both strings).

    Raises:
        ValueError: If ``everyNth`` is smaller than 1.
    """
    if everyNth < 1:
        raise ValueError(f"everyNth must be >= 1, got {everyNth}")

    def get_audio_files_with_labels(directory):
        # os.listdir order is arbitrary and filesystem-dependent; sort both levels
        # so the file order (and therefore the everyNth subsample) is reproducible.
        data = []
        for class_name in sorted(os.listdir(directory)):  # Class folders
            class_path = os.path.join(directory, class_name)
            if not os.path.isdir(class_path):
                continue
            for file_name in sorted(os.listdir(class_path)):
                # Case-insensitive so files such as "clip.WAV" are not silently skipped.
                if file_name.lower().endswith(".wav"):
                    file_path = os.path.join(class_path, file_name)
                    data.append({"file_path": file_path, "label": class_name})
        return data

    # Single pass for the downsampling so both columns are filtered identically.
    kept = [
        entry
        for idx, entry in enumerate(get_audio_files_with_labels(validation_dir))
        if idx % everyNth == 0
    ]

    validation_dataset = Dataset.from_dict({
        "file_path": [entry["file_path"] for entry in kept],
        "label": [entry["label"] for entry in kept],
    })
    return validation_dataset


In [None]:
# Load data
validation_dataset = load_audio_dataset_from_folders(validation_dir)

print(validation_dataset)

# Fail fast: an empty split or a class folder that is not in label2id would
# otherwise only surface later (empty metrics, or a KeyError when mapping
# class names to ids).
assert len(validation_dataset) > 0, f"No .wav files found under {validation_dir}"
unknown_classes = set(validation_dataset["label"]) - set(label2id)
assert not unknown_classes, f"Class folders not present in label2id: {sorted(unknown_classes)}"


In [None]:
# Load the base AST model. The processor and the encoder are both fetched from
# the same Hugging Face Hub checkpoint id.
ast_hub_id = "MIT/ast-finetuned-audioset-10-10-0.4593"
processor = AutoProcessor.from_pretrained(ast_hub_id)
base_model = ASTModel.from_pretrained(ast_hub_id)

# Adaptation for SSL: encoder + linear projection + small temporal conv decoder.
class AST_SSL(nn.Module):
    """Self-supervised wrapper around an AST-style encoder.

    The encoder's hidden states are linearly projected to ``output_dim[-1]``
    channels and passed through a two-layer 1-D convolutional decoder along the
    sequence axis. When ``labels`` are given, the reconstruction is linearly
    resampled to the label length and scored with mean-squared error.

    Args:
        base_model: Encoder exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state`` when called with ``input_values=``.
        output_dim (sequence of int): Only its last entry is used, as the width
            of the projection and of both decoder convolutions.
    """

    def __init__(self, base_model, output_dim):
        super().__init__()
        width = output_dim[-1]
        self.encoder = base_model
        self.encoder_output_dim = base_model.config.hidden_size
        self.output_dim = output_dim
        self.projector = nn.Linear(self.encoder_output_dim, width)
        # "same"-length convolutions: kernel 3 with padding 1 keeps the sequence length.
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.decoder = nn.Sequential(
            nn.Conv1d(width, width, **conv_kwargs),
            nn.ReLU(),
            nn.Conv1d(width, width, **conv_kwargs),
        )

    def forward(self, input_values, labels=None):
        """Return the reconstruction, or ``(mse_loss, reconstruction)`` when ``labels`` is given.

        Assumes ``last_hidden_state`` is (B, T, C); ``labels`` presumably has
        shape (B, T', output_dim[-1]) — the reconstruction is resampled to T'.
        """
        hidden = self.encoder(input_values=input_values).last_hidden_state
        channels_last = self.projector(hidden)

        # Conv1d expects channels first: (B, T, C) -> (B, C, T), then back.
        reconstruction = self.decoder(channels_last.transpose(1, 2)).transpose(1, 2)

        if labels is None:
            return reconstruction

        # Resample along the sequence axis so its length matches labels.shape[1].
        resampled = F.interpolate(
            reconstruction.transpose(1, 2),
            size=labels.shape[1],
            mode="linear",
            align_corners=True,
        ).transpose(1, 2)
        return nn.MSELoss()(resampled, labels), resampled

# Instantiate SSL base model (projection width output_dim[-1] = 768).
# Only `ssl_model.encoder` is used afterwards: AST_Classifier takes it directly.
ssl_model = AST_SSL(base_model, output_dim=[768])

# Remove projector/decoder layers, keeping only the encoder
del ssl_model.projector
del ssl_model.decoder

# Add new classification layer
class AST_Classifier(nn.Module):
    """Classifier: encoder -> layernorm -> mean over tokens -> linear head.

    Args:
        ssl_model: Object exposing an ``encoder`` attribute (here the AST encoder
            taken from ``AST_SSL``). The encoder must be callable on
            ``input_values``, return an object with ``last_hidden_state``, and
            expose a ``layernorm`` module.
        num_classes (int): Number of output classes.
        hidden_size (int, optional): Width of the encoder's hidden states. When
            omitted it is read from ``encoder.config.hidden_size``, falling back
            to 768 (the width this head was previously hard-coded to).
    """

    def __init__(self, ssl_model, num_classes, hidden_size=None):
        super().__init__()
        self.encoder = ssl_model.encoder  # Keep AST encoder part
        if hidden_size is None:
            # Derive the head width from the encoder instead of assuming 768.
            encoder_config = getattr(self.encoder, "config", None)
            hidden_size = getattr(encoder_config, "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, num_classes)  # New classification layer

    def forward(self, input_values, labels=None):
        """Return ``{"logits": (B, num_classes)}``, plus ``"loss"`` (cross-entropy) if ``labels`` is given."""
        outputs = self.encoder(input_values)
        x = outputs.last_hidden_state
        # NOTE(review): Hugging Face's ASTModel already applies its final layernorm
        # before returning last_hidden_state, so this normalizes a second time.
        # Presumably the checkpoint was trained with this exact forward pass, so it
        # is kept; do not remove it without confirming against the training code.
        x = self.encoder.layernorm(x)
        # Mean over dim 1, the token axis (for ASTModel: CLS/distillation tokens
        # plus spectrogram patches, not strictly time) -> (B, D).
        x = x.mean(dim=1)
        logits = self.classifier(x)  # Classification logits (B, num_classes)

        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}

        return {"logits": logits}
    
# Build the classifier around the encoder kept from ssl_model. Matching weights
# are then loaded from model.safetensors in checkpoint_dir (strict=False).
classifier_model = AST_Classifier(ssl_model, num_classes=cls_num)
print("New classifier model created:", classifier_model)


In [None]:
checkpoint_path = os.path.join(checkpoint_dir, "model.safetensors")
state_dict = load_file(checkpoint_path)

# strict=False tolerates key mismatches, but silently: report them, and refuse to
# continue if the classification head was not restored (predictions would then
# come from a randomly initialised layer).
load_report = classifier_model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
    print(f"⚠️ {len(load_report.missing_keys)} missing key(s), e.g.: {load_report.missing_keys[:5]}")
if load_report.unexpected_keys:
    print(f"⚠️ {len(load_report.unexpected_keys)} unexpected key(s), e.g.: {load_report.unexpected_keys[:5]}")

missing_head_keys = [key for key in load_report.missing_keys if key.startswith("classifier.")]
if missing_head_keys:
    raise RuntimeError(
        f"Classifier head not found in {checkpoint_path} (missing: {missing_head_keys})"
    )
print("✅ Model loaded from checkpoint!")


In [None]:
model = classifier_model

# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
model.to(device)



In [None]:
def predict(audio_filepath, model, processor, device):
    """
    Classify one audio file and return the top-1 class with its probability.

    The file is loaded (resampled to 16 kHz by librosa), converted to model inputs
    by ``processor``, and run through ``model`` in eval mode without gradients.

    :param audio_filepath: Path to the audio file
    :param model: Loaded AST-based classifier model (returns a dict with "logits")
    :param processor: Transformer processor for preprocessing
    :param device: 'cuda' or 'cpu'
    :return: (predicted_class_index, predicted_probability)
    """
    waveform, _sr = librosa.load(audio_filepath, sr=16000)

    features = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
    batch = features.input_values.to(device)

    model.eval()
    with torch.no_grad():
        scores = model(batch)["logits"]

    probs = torch.nn.functional.softmax(scores, dim=-1)
    top_index = torch.argmax(probs, dim=-1).item()
    top_prob = probs[0][top_index].item()
    return top_index, top_prob


In [None]:
# Display the processor; its repr shows the preprocessing configuration in use.
processor

In [None]:
# Compute predictions (file level): string ground-truth labels and top-1
# predicted class ids, collected in matching order.
true_labels = []
predicted_labels = []

for example in tqdm(validation_dataset, desc="Processing validation set"):
    top_class, _ = predict(example["file_path"], model, processor, device)
    true_labels.append(example["label"])
    predicted_labels.append(top_class)
    

In [None]:
# Map class-folder names to integer ids. Idempotent: entries that are already
# ids (e.g. when this cell is re-run) pass through unchanged instead of raising
# KeyError.
true_labels = [label if isinstance(label, int) else label2id[label] for label in true_labels]

# Show a short preview rather than the whole list.
true_labels[:10]


In [None]:
# Metrics
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score
)
import matplotlib.pyplot as plt
import seaborn as sns

class_ids = list(id2label.keys())
class_names = list(id2label.values())

# Weighted averages: per-class scores weighted by each class's support in true_labels.
accuracy = accuracy_score(true_labels, predicted_labels)
balanced_accuracy = balanced_accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average="weighted", zero_division=0)
recall = recall_score(true_labels, predicted_labels, average="weighted", zero_division=0)
f1 = f1_score(true_labels, predicted_labels, average="weighted", zero_division=0)

print(f"Accuracy: {accuracy*100:.4f}")
print(f"Balanced Accuracy: {balanced_accuracy*100:.4f}")
print(f"Precision: {precision*100:.4f}")
print(f"Recall: {recall*100:.4f}")
print(f"F1 Score: {f1*100:.4f}")

# Pin the label order so the matrix is always len(id2label) x len(id2label) and
# its rows/columns line up with the tick labels, even if a class never occurs.
conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=class_ids)
print("Confusion Matrix:")
print(conf_matrix)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel("Predicted Labels")
ax.set_ylabel("True Labels")
ax.set_title("Confusion Matrix")
plt.show()


# Class-wise metrics


In [None]:
# Class-wise metrics
class_labels = list(id2label.keys())
# labels=class_labels yields exactly one entry per class, in id order, even when a
# class is absent from both true and predicted labels; otherwise the arrays would
# shrink and the positional indexing below would attach scores to the wrong class.
precision_per_class = precision_score(true_labels, predicted_labels, labels=class_labels, average=None, zero_division=0)
recall_per_class = recall_score(true_labels, predicted_labels, labels=class_labels, average=None, zero_division=0)
f1_per_class = f1_score(true_labels, predicted_labels, labels=class_labels, average=None, zero_division=0)

# One-hot encode; assumes class ids are exactly 0..len(class_labels)-1, as in id2label.
true_labels_binarized = np.eye(len(class_labels))[true_labels]
predicted_labels_binarized = np.eye(len(class_labels))[predicted_labels]

# NOTE: the "scores" here are hard one-hot predictions, not probabilities, so each
# one-vs-rest ROC curve has a single operating point and the AUC equals
# (TPR + TNR) / 2 for that class. Classes for which roc_auc_score raises
# ValueError (e.g. only one value in the true column) are reported as NaN.
roc_auc_per_class = []
for i in range(len(class_labels)):
    try:
        roc_auc = roc_auc_score(true_labels_binarized[:, i], predicted_labels_binarized[:, i])
    except ValueError:
        roc_auc = np.nan
    roc_auc_per_class.append(roc_auc)

print("\nClass-wise metrics:")
for i, label in enumerate(class_labels):
    print(f"Class {id2label[label]}:")
    print(f"  Precision: {precision_per_class[i]:.4f}")
    print(f"  Recall: {recall_per_class[i]:.4f}")
    print(f"  F1 Score: {f1_per_class[i]:.4f}")
    print(f"  ROC-AUC: {roc_auc_per_class[i]:.4f}" if not np.isnan(roc_auc_per_class[i]) else "  ROC-AUC: N/A")

fig, ax = plt.subplots(figsize=(12, 6))
metrics_labels = ["Precision", "Recall", "F1 Score", "ROC-AUC"]
metrics_values = [precision_per_class, recall_per_class, f1_per_class, roc_auc_per_class]

for metric, values in zip(metrics_labels, metrics_values):
    ax.plot(class_labels, values, marker='o', label=metric)

ax.set_xticks(class_labels)
ax.set_xticklabels([id2label[label] for label in class_labels], rotation=45)
ax.set_xlabel("Class")
ax.set_ylabel("Metric Value")
ax.set_title("Class-wise Metrics")
ax.legend(title="Metrics")
ax.grid()
fig.tight_layout()
plt.show()


# Per-file predictions and per-test (majority-vote) metrics


In [None]:
import torch
import librosa
import numpy as np

def predict_top(audio_filepath, model, processor, device):
    """
    Classify one audio file and rank every class by predicted probability.

    Parameters
    ----------
    audio_filepath : str
        Path to the audio file (.wav); librosa resamples it to 16 kHz on load.
    model : nn.Module
        Loaded AST-based classifier model; must return a dict with "logits".
    processor : AutoProcessor
        Hugging Face processor that turns the waveform into ``input_values``.
    device : str or torch.device
        Device to run inference on ("cpu" or "cuda"). The model is moved there.

    Returns
    -------
    sorted_predictions : list of int
        Class indices sorted by descending probability.
    sorted_probs : list of float
        Probability of each class in ``sorted_predictions``, same order.
    """
    model.to(device)  # Make sure the model is on the correct device

    waveform, _sr = librosa.load(audio_filepath, sr=16000)
    batch = processor(
        waveform, sampling_rate=16000, return_tensors="pt", padding=True
    ).input_values.to(device)

    model.eval()  # Disable dropout for inference
    with torch.no_grad():
        scores = model(batch)["logits"]

    probs = torch.nn.functional.softmax(scores, dim=-1).cpu().numpy().flatten()

    # Ascending argsort reversed -> highest probability first.
    ranking = np.argsort(probs)[::-1]
    return ranking.tolist(), probs[ranking].tolist()



In [None]:
import pandas as pd
import os

# Collect per-file results
def _test_id_from_path(path):
    """Return the test id of an audio file: its file stem up to (not including) the second underscore.

    Stems with fewer than two underscores are returned whole, so each such file
    forms its own group. Previously they all shared one "N/A" group, which
    would majority-vote unrelated files together under a single true label.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    parts = stem.split("_")
    return "_".join(parts[:2]) if len(parts) > 2 else stem


TOP_K_COLUMNS = 4  # number of topN_prediction columns kept in df_results
per_file_records = []

for example in tqdm(validation_dataset, desc="Per-file predictions"):
    audio_filepath = example["file_path"]
    true_label = label2id[example["label"]]

    # All classes, ranked by descending probability
    sorted_predictions, _ = predict_top(audio_filepath, model, processor, device)

    record = {
        "file_path": audio_filepath,
        "test_id": _test_id_from_path(audio_filepath),
        "true_label": true_label,
    }
    # Columns top1..topK; None when the model has fewer classes than K.
    for rank in range(TOP_K_COLUMNS):
        record[f"top{rank + 1}_prediction"] = (
            sorted_predictions[rank] if rank < len(sorted_predictions) else None
        )
    per_file_records.append(record)

# Convert to DataFrame
df_results = pd.DataFrame(per_file_records)
df_results.tail()


In [None]:
# First rows of the per-file prediction table.
df_results.head()

In [None]:
# Aggregate results per test_id to get majority-vote predictions
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# The aggregation keeps only the first true label per test_id; surface any test
# whose files disagree, since its ground truth would otherwise be arbitrary.
labels_per_test = df_results.groupby("test_id")["true_label"].nunique()
mixed_label_tests = labels_per_test.index[labels_per_test > 1].tolist()
if mixed_label_tests:
    print(f"⚠️ test_ids with mixed true labels (first label kept): {mixed_label_tests}")

grouped_results = df_results.groupby("test_id").agg(
    true_label=("true_label", "first"),
    # Majority vote of per-file top-1 predictions. Series.mode() returns its values
    # sorted, so ties are broken in favour of the smallest class id.
    predicted_label=("top1_prediction", lambda preds: preds.mode().iloc[0]),
).reset_index()

# Macro averages over the classes present in y_true or y_pred (sklearn's default
# label set): classes that are never predicted count as precision 0, and F1 is
# the mean of per-class F1 scores (not the harmonic mean of macro P and macro R).
y_true_test = grouped_results["true_label"].to_numpy()
y_pred_test = grouped_results["predicted_label"].to_numpy()

accuracy = accuracy_score(y_true_test, y_pred_test)
precision = precision_score(y_true_test, y_pred_test, average="macro", zero_division=0)
recall = recall_score(y_true_test, y_pred_test, average="macro", zero_division=0)
f1 = f1_score(y_true_test, y_pred_test, average="macro", zero_division=0)

metrics = {
    "Accuracy": accuracy,
    "Precision": precision,
    "Recall": recall,
    "F1 Score": f1
}
print(metrics)


In [None]:
# Peek at the per-test majority-vote results.
grouped_results.head(10)

In [None]:
from sklearn.metrics import (
    precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
)
import matplotlib.pyplot as plt
import seaborn as sns

# Test-level (majority-vote) labels. Stored under distinct names so the
# file-level `true_labels` / `predicted_labels` lists are not overwritten;
# otherwise re-running a file-level metrics cell would silently use test-level data.
test_true_labels = grouped_results["true_label"].to_numpy(dtype=int)
test_predicted_labels = grouped_results["predicted_label"].to_numpy(dtype=int)

class_labels = list(id2label.keys())
class_names = [id2label[label] for label in class_labels]
# labels=class_labels yields one entry per class in id order, even for classes
# absent from both arrays, so positional indexing stays aligned with id2label.
precision_per_class = precision_score(test_true_labels, test_predicted_labels, labels=class_labels, average=None, zero_division=0)
recall_per_class = recall_score(test_true_labels, test_predicted_labels, labels=class_labels, average=None, zero_division=0)
f1_per_class = f1_score(test_true_labels, test_predicted_labels, labels=class_labels, average=None, zero_division=0)

# One-hot encode; assumes class ids are exactly 0..len(class_labels)-1, as in id2label.
true_labels_binarized = np.eye(len(class_labels))[test_true_labels]
predicted_labels_binarized = np.eye(len(class_labels))[test_predicted_labels]

# NOTE: hard one-hot predictions, not probabilities: each one-vs-rest ROC curve has
# a single operating point, so the AUC equals (TPR + TNR) / 2 for that class.
# Classes for which roc_auc_score raises ValueError are reported as NaN.
roc_auc_per_class = []
for i in range(len(class_labels)):
    try:
        roc_auc = roc_auc_score(true_labels_binarized[:, i], predicted_labels_binarized[:, i])
    except ValueError:
        roc_auc = np.nan
    roc_auc_per_class.append(roc_auc)

print("\nClass-wise metrics:")
for i, label in enumerate(class_labels):
    print(f"Class {id2label[label]}:")
    print(f"  Precision: {precision_per_class[i]:.4f}")
    print(f"  Recall: {recall_per_class[i]:.4f}")
    print(f"  F1 Score: {f1_per_class[i]:.4f}")
    print(f"  ROC-AUC: {roc_auc_per_class[i]:.4f}" if not np.isnan(roc_auc_per_class[i]) else "  ROC-AUC: N/A")

# Confusion matrix; label order pinned so it is always square and matches the tick labels.
conf_matrix = confusion_matrix(test_true_labels, test_predicted_labels, labels=class_labels)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel("Predicted Labels")
ax.set_ylabel("True Labels")
ax.set_title("Confusion Matrix")
plt.show()

# Plot per-class metrics
fig, ax = plt.subplots(figsize=(12, 6))
metrics_labels = ["Precision", "Recall", "F1 Score", "ROC-AUC"]
metrics_values = [precision_per_class, recall_per_class, f1_per_class, roc_auc_per_class]

for metric, values in zip(metrics_labels, metrics_values):
    ax.plot(class_labels, values, marker='o', label=metric)

ax.set_xticks(class_labels)
ax.set_xticklabels(class_names, rotation=45)
ax.set_xlabel("Class")
ax.set_ylabel("Metric Value")
ax.set_title("Class-wise Metrics")
ax.legend(title="Metrics")
ax.grid()
fig.tight_layout()
plt.show()
