In [None]:
# Folder containing one sub-folder per class, each holding .wav files.
validation_dir = "../dataset/train-val-random/validation/"

# Every sub-directory in here is one checkpoint; together they form the ensemble.
checkpoints_dir = "./ensemble/best/"

# Class index -> species name; label2id is used to turn the dataset's folder
# labels into class indices.
# Earlier ordering, kept for reference:
# id2label={0: 'Aedes_koreicus', 1: 'Ochlerotatus_geniculatus', 2: 'Aedes_albopictus'}
id2label = {0: 'Aedes_koreicus', 1: 'Aedes_albopictus', 2: 'Ochlerotatus_geniculatus'}
label2id = {v: k for k, v in id2label.items()}

# Derived from id2label so the number of classes can never disagree with it.
cls_num = len(id2label)

label2id


In [None]:
#from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
import numpy as np
from tqdm import tqdm
import librosa
import torch
import os
from datasets import Dataset, DatasetDict
from safetensors.torch import load_file
import torch.nn as nn
from transformers import AutoProcessor, ASTModel
import torch.nn.functional as F


In [None]:
# List the checkpoint sub-directories of `checkpoints_dir`. The list is sorted
# because os.listdir returns entries in arbitrary order, and a stable order
# keeps runs reproducible.
checkpoints = sorted(
    os.path.join(checkpoints_dir, d)
    for d in os.listdir(checkpoints_dir)
    if os.path.isdir(os.path.join(checkpoints_dir, d))
)

print("Found checkpoints:", checkpoints)


In [None]:
def load_audio_dataset_from_folders(validation_dir, everyNth=1):
    """
    Load labelled .wav files from class sub-folders into a Dataset.

    Every sub-directory of ``validation_dir`` is treated as one class, and its
    name becomes the string label of each ``.wav`` file directly inside it.
    Entries that are not directories at the top level are ignored. Class
    folders and file names are visited in sorted order, so the result (and the
    ``everyNth`` subsampling) does not depend on filesystem listing order.

    Args:
        validation_dir (str): Path to the 'validation' folder.
        everyNth (int): Keep only every n-th file of the sorted file list
            (indices 0, n, 2n, ...). 1 keeps every file. Must be >= 1.

    Returns:
        Dataset: A Dataset with columns ``file_path`` and ``label`` (the class
        folder name).

    Raises:
        ValueError: If ``everyNth`` is smaller than 1.
    """
    if everyNth < 1:
        raise ValueError(f"everyNth must be >= 1, got {everyNth}")

    def get_audio_files_with_labels(directory):
        # One record per .wav file found inside a class folder.
        data = []
        for class_name in sorted(os.listdir(directory)):
            class_path = os.path.join(directory, class_name)
            if not os.path.isdir(class_path):
                continue
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith(".wav"):  # Only WAV files
                    data.append({"file_path": os.path.join(class_path, file_name), "label": class_name})
        return data

    # Slicing keeps indices 0, everyNth, 2*everyNth, ...
    validation_data = get_audio_files_with_labels(validation_dir)[::everyNth]

    return Dataset.from_dict({
        "file_path": [d["file_path"] for d in validation_data],
        "label": [d["label"] for d in validation_data],
    })


In [None]:
# Build the validation Dataset from the class sub-folders of `validation_dir`;
# everyNth=1 keeps every file (no subsampling).
validation_dataset = load_audio_dataset_from_folders(validation_dir, everyNth=1)
print(validation_dataset)


In [None]:
# The AST processor and base encoder are not created here; evaluate_checkpoints
# loads them ("MIT/ast-finetuned-audioset-10-10-0.4593").


class AST_SSL(nn.Module):
    """Self-supervised wrapper around an AST encoder.

    The encoder's last hidden state is linearly projected to ``output_dim[-1]``
    channels and passed through a small temporal 1-D convolutional decoder.
    When ``labels`` are given, the reconstruction is linearly resampled along
    the time axis to ``labels.shape[1]`` steps and compared with ``labels``
    using a mean-squared-error loss.

    The submodule names ``encoder``, ``projector`` and ``decoder`` define the
    state-dict keys, so they must not be renamed.

    Args:
        base_model: Encoder module exposing ``config.hidden_size`` and returning
            an object with ``last_hidden_state`` when called as
            ``base_model(input_values=...)``.
        output_dim (Sequence[int]): Only the last element is used. It is the
            channel count of the projection and the decoder.

    Raises:
        ValueError: If ``output_dim`` is empty.
    """

    def __init__(self, base_model, output_dim):
        super().__init__()
        if len(output_dim) == 0:
            raise ValueError("output_dim must contain at least one dimension")

        self.encoder = base_model
        self.encoder_output_dim = base_model.config.hidden_size
        self.output_dim = output_dim
        channels = output_dim[-1]

        # Projector: linear map from the encoder hidden size to `channels`.
        self.projector = nn.Linear(self.encoder_output_dim, channels)

        # Convolutional decoder for temporal reconstruction (length-preserving: k=3, pad=1).
        self.decoder = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, input_values, labels=None):
        """Return the reconstruction, or ``(loss, reconstruction)`` when ``labels`` is given."""
        hidden = self.encoder(input_values=input_values).last_hidden_state
        projected = self.projector(hidden)

        # Conv1d expects (B, C, T): move channels in front of time, decode, move back.
        reconstructed = self.decoder(projected.permute(0, 2, 1)).permute(0, 2, 1)

        if labels is None:
            return reconstructed

        # Resample along time to labels.shape[1] steps so the reconstruction
        # can be compared element-wise with the targets.
        reconstructed = F.interpolate(
            reconstructed.permute(0, 2, 1),
            size=labels.shape[1], mode="linear", align_corners=True,
        ).permute(0, 2, 1)

        loss = F.mse_loss(reconstructed, labels)
        return loss, reconstructed


class AST_Classifier(nn.Module):
    """Sequence classifier built on the encoder of an ``AST_SSL`` model.

    The encoder output is passed through ``encoder.layernorm`` and averaged
    over the token axis. A linear head then produces ``num_classes`` logits.
    The submodule names ``encoder`` and ``classifier`` define the state-dict
    keys loaded from checkpoints, so they must not be renamed.

    Args:
        ssl_model: Object whose ``encoder`` attribute is the AST encoder. The
            encoder is shared by reference, not copied.
        num_classes (int): Number of output classes.
    """

    def __init__(self, ssl_model, num_classes):
        super().__init__()
        self.encoder = ssl_model.encoder  # same module object as ssl_model.encoder
        # Size the head from the encoder config instead of hard-coding 768;
        # fall back to 768 (the previous fixed value) if no hidden_size is exposed.
        hidden_size = getattr(getattr(self.encoder, "config", None), "hidden_size", 768)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_values, labels=None):
        """Return ``{"logits"}``, plus ``"loss"`` (cross-entropy) when ``labels`` is given."""
        outputs = self.encoder(input_values)
        x = outputs.last_hidden_state
        # NOTE(review): Hugging Face ASTModel appears to apply its own layernorm
        # before returning last_hidden_state, so this may normalize twice. It is
        # kept as-is because the checkpoints were presumably trained with this
        # forward; confirm against the transformers version in use.
        x = self.encoder.layernorm(x)
        x = x.mean(dim=1)  # average over the token axis -> (B, hidden)
        logits = self.classifier(x)  # (B, num_classes)

        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}

        return {"logits": logits}



In [None]:
import torch
import librosa
import numpy as np

def predict(audio_filepath, model, processor, device):
    """
    Load one audio file, preprocess it and classify it with ``model``.

    The file is loaded (and resampled) at 16 kHz with ``librosa.load``, then
    passed through ``processor``. The model is switched to eval mode and run
    without gradient tracking.

    :param audio_filepath: The path to the audio file
    :param model: Classifier returning a dict with a ``"logits"`` entry
        (e.g. the AST-based classifier)
    :param processor: The Transformer processor for preprocessing
    :param device: torch device the model lives on; inputs are moved there
    :return: (predicted_class, probability) - the index of the most probable
        class and its softmax probability, for the single item in the batch
    """
    sampling_rate = 16000  # used both for resampling and for the processor

    audio_waveform, _ = librosa.load(audio_filepath, sr=sampling_rate)

    inputs = processor(audio_waveform, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    input_values = inputs.input_values.to(device)  # same device as the model

    model.eval()  # turn off dropout layers (if any)
    with torch.no_grad():
        logits = model(input_values)["logits"]

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()

    return predicted_class, probabilities[0][predicted_class].item()


In [None]:

def _load_checkpoint_classifier(base_model, checkpoint_dir, cls_num, device):
    """Build an AST_Classifier on a private deep copy of ``base_model`` and load one checkpoint into it.

    Copying the base encoder means weights loaded for one checkpoint cannot leak
    into the next. Without the copy, keys that a non-strict load skips would keep
    the previous checkpoint's values.
    """
    import copy

    ssl_model = AST_SSL(copy.deepcopy(base_model), output_dim=[768])
    # Only the encoder is needed for classification.
    del ssl_model.projector
    del ssl_model.decoder
    classifier_model = AST_Classifier(ssl_model, num_classes=cls_num)
    classifier_model.to(device)

    checkpoint_path = os.path.join(checkpoint_dir, "model.safetensors")
    state_dict = load_file(checkpoint_path)
    # strict=False tolerates key mismatches. Report them instead of silently
    # evaluating a partially loaded model.
    load_result = classifier_model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"⚠️ {checkpoint_dir}: {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys in the state dict"
        )
    classifier_model.eval()
    return classifier_model


def evaluate_checkpoints(checkpoints_dir, validation_dataset, label2id, cls_num, predict):
    """
    Evaluate every checkpoint in ``checkpoints_dir`` on the validation data.

    Each sub-directory of ``checkpoints_dir`` must contain a
    ``model.safetensors`` file. Every checkpoint is loaded into its own copy of
    the pretrained AST encoder and run over all validation examples. Its
    predicted class indices are stored in a ``pred_<i>`` column. The table is
    written to ``ensemble_results.csv`` after each checkpoint so partial
    results survive an interruption.

    Args:
        checkpoints_dir (str): Directory whose sub-directories are checkpoints.
        validation_dataset: Iterable of examples with "file_path" and "label" keys.
        label2id (dict): Maps the dataset's string labels to class indices.
        cls_num (int): Number of output classes of the classifier head.
        predict (callable): ``predict(path, model, processor, device)`` returning
            ``(predicted_class, probability)``.

    Returns:
        pd.DataFrame: Columns ``file_path``, ``true_label`` and one ``pred_<i>``
        column per checkpoint. Checkpoints are taken in sorted directory-name
        order, so ``pred_<i>`` always refers to the same checkpoint.
    """
    # Sorted because os.listdir returns entries in arbitrary order.
    checkpoints = sorted(
        os.path.join(checkpoints_dir, d)
        for d in os.listdir(checkpoints_dir)
        if os.path.isdir(os.path.join(checkpoints_dir, d))
    )
    print(f"Found {len(checkpoints)} checkpoints!")

    # Check if a GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the base AST model once; each checkpoint gets a deep copy of it.
    processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    base_model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

    # File paths and true labels do not depend on the checkpoint: collect once.
    file_paths = [example["file_path"] for example in validation_dataset]
    true_labels = [label2id[example["label"]] for example in validation_dataset]

    all_predictions = pd.DataFrame({"file_path": file_paths, "true_label": true_labels})
    for idx in range(len(checkpoints)):
        all_predictions[f"pred_{idx}"] = np.nan  # filled in as each checkpoint finishes

    for idx, checkpoint_dir in enumerate(checkpoints):
        print(f"🔄 Evaluating with checkpoint {checkpoint_dir}...")
        classifier_model = _load_checkpoint_classifier(base_model, checkpoint_dir, cls_num, device)
        print("✅ Model loaded!")

        predictions = []
        for audio_filepath in tqdm(file_paths, desc=f"Checkpoint {idx+1}/{len(checkpoints)}"):
            predicted_class, _ = predict(audio_filepath, classifier_model, processor, device)
            predictions.append(predicted_class)

        all_predictions[f"pred_{idx}"] = predictions
        all_predictions.to_csv("ensemble_results.csv", index=False)

        # Release this checkpoint's weights before loading the next one.
        del classifier_model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    return all_predictions


In [None]:
# Uses `checkpoints_dir` from the configuration cell instead of re-defining it
# here, so changing the path at the top of the notebook takes effect.
results_df = evaluate_checkpoints(checkpoints_dir, validation_dataset, label2id, cls_num, predict)



In [None]:
# Preview only; the full table is written to ensemble_results.csv.
results_df.head()

In [None]:
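# Unfortunately, the class-index order got mixed up between checkpoints because the
# label list was built from a Python `set` (whose iteration order is not guaranteed),
# so each checkpoint's predicted indices must be remapped onto the true labels afterwards.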

import pandas as pd
import numpy as np
from collections import Counter

def _majority_mapping(true_labels, pred_labels):
    """Map each predicted label to the true label it co-occurs with most often (may be many-to-one)."""
    co_counts = {}
    for true, pred in zip(true_labels, pred_labels):
        co_counts.setdefault(pred, Counter())[true] += 1
    # Ties go to the true label seen first, as max() returns the first maximum.
    return {pred: max(counts, key=counts.get) for pred, counts in co_counts.items()}


def _best_permutation(true_labels, pred_labels):
    """One-to-one mapping predicted label -> true label that maximises agreement.

    Solved as a linear assignment problem (Hungarian algorithm) on the
    predicted-vs-true co-occurrence matrix.
    """
    from scipy.optimize import linear_sum_assignment

    labels = sorted(set(true_labels) | set(pred_labels))
    if not labels:
        return {}
    position = {label: i for i, label in enumerate(labels)}
    agreement = np.zeros((len(labels), len(labels)), dtype=np.int64)
    for true, pred in zip(true_labels, pred_labels):
        agreement[position[pred], position[true]] += 1
    pred_idx, true_idx = linear_sum_assignment(agreement, maximize=True)
    return {labels[p]: labels[t] for p, t in zip(pred_idx, true_idx)}


def remap_predictions_simple(df, one_to_one=True):
    """
    Remap every ``pred_*`` column onto the label ids used in ``true_label``.

    Each column is remapped independently, based on how its values co-occur
    with ``true_label``. The input frame is not modified; a remapped copy is
    returned.

    Args:
        df (pd.DataFrame): Frame with a ``true_label`` column and one or more
            ``pred_*`` columns.
        one_to_one (bool): If True (default), use the label permutation that
            maximises agreement, so distinct predicted labels stay distinct.
            If False, use the legacy rule: map each predicted label to its most
            frequent true label. That rule can map several predicted labels
            onto one class.

    Returns:
        pd.DataFrame: A copy of ``df`` with remapped ``pred_*`` columns.

    Note:
        The mapping is fitted on the same labels the predictions are later
        scored against, which flatters the metrics somewhat. The one-to-one
        mode limits that to choosing a permutation.
    """
    remapped = df.copy()
    true_labels = remapped["true_label"].to_numpy()
    pred_cols = [col for col in remapped.columns if col.startswith("pred_")]
    build_mapping = _best_permutation if one_to_one else _majority_mapping

    for col in pred_cols:
        mapping = build_mapping(true_labels, remapped[col].to_numpy())
        remapped[col] = remapped[col].map(mapping)

    return remapped

# Remap every checkpoint's predicted indices onto the dataset's label ids.
updated_df = remap_predictions_simple(results_df)

# Save under a separate name so the raw per-checkpoint predictions that
# evaluate_checkpoints writes to "ensemble_results.csv" are not overwritten.
updated_df.to_csv("ensemble_results_remapped.csv", index=False)

updated_df.head()


In [None]:
# evaluation


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import mode
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score
)

# Ensemble prediction = per-row majority vote over all pred_* columns.
# NOTE: scipy.stats.mode returns the smallest of several tied modes, so ties
# (possible with an even number of checkpoints) go to the lowest class index.
pred_cols = [col for col in updated_df.columns if col.startswith("pred_")]
ensemble_predictions = mode(updated_df[pred_cols].values, axis=1)[0].flatten()

# True labels
true_labels = updated_df["true_label"].values


def _classification_metrics(name, y_true, y_pred):
    """Return one row of summary metrics; precision/recall/F1 are support-weighted."""
    return {
        "Model": name,
        "Accuracy": accuracy_score(y_true, y_pred),
        "Balanced Accuracy": balanced_accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, average="weighted", zero_division=0),
        "Recall": recall_score(y_true, y_pred, average="weighted", zero_division=0),
        "F1 Score": f1_score(y_true, y_pred, average="weighted", zero_division=0),
    }


# One row per individual checkpoint, then one for the ensemble.
metric_rows = []
for col in pred_cols:
    predicted_labels = updated_df[col].values
    metric_rows.append(_classification_metrics(col, true_labels, predicted_labels))
metric_rows.append(_classification_metrics("Ensemble", true_labels, ensemble_predictions))

metrics_df = pd.DataFrame(
    metric_rows,
    columns=["Model", "Accuracy", "Balanced Accuracy", "Precision", "Recall", "F1 Score"],
)
print(metrics_df)

# Fixed class order, so row/column i is always class id i (even if a class
# is missing from the data) and the tick labels line up.
class_ids = sorted(id2label)
class_names = [id2label[i] for i in class_ids]
conf_matrix = confusion_matrix(true_labels, ensemble_predictions, labels=class_ids)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel("Predicted Labels")
ax.set_ylabel("True Labels")
ax.set_title("Ensemble Confusion Matrix")
plt.show()



In [None]:
# FIXME: an error was flagged here ("ERRORROR"), but its cause was not recorded.

In [None]:
#print(predicted_prob)

In [None]:
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, roc_auc_score
)
import matplotlib.pyplot as plt
import seaborn as sns

# Evaluate the majority-vote ensemble explicitly. Otherwise this cell would
# silently score whatever `predicted_labels` was last assigned (the final
# pred_* column of the per-model loop).
predicted_labels = ensemble_predictions

# Fixed class order, so confusion-matrix row/column i is always class id i and
# the tick labels below line up, even if a class is missing from the data.
class_ids = sorted(id2label)
class_names = [id2label[i] for i in class_ids]

# Summary metrics (precision/recall/F1 are support-weighted averages)
accuracy = accuracy_score(true_labels, predicted_labels)
balanced_accuracy = balanced_accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average="weighted", zero_division=0)
recall = recall_score(true_labels, predicted_labels, average="weighted", zero_division=0)
f1 = f1_score(true_labels, predicted_labels, average="weighted", zero_division=0)

print(f"Accuracy: {accuracy:.2f}")
print(f"Balanced Accuracy: {balanced_accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1:.2f}")

conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=class_ids)
print("Confusion Matrix:")
print(conf_matrix)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel("Predicted Labels")
ax.set_ylabel("True Labels")
ax.set_title("Confusion Matrix")
plt.show()



In [None]:
# Per-class metrics for the majority-vote ensemble. The predictions are chosen
# explicitly rather than taken from a `predicted_labels` variable left over
# from an earlier loop.
predicted_labels = ensemble_predictions

class_labels = sorted(id2label)  # class indices, in id order
# labels= keeps entry i aligned with class_labels[i], even when a class never
# occurs in true_labels or predicted_labels.
precision_per_class = precision_score(true_labels, predicted_labels, labels=class_labels, average=None, zero_division=0)
recall_per_class = recall_score(true_labels, predicted_labels, labels=class_labels, average=None, zero_division=0)
f1_per_class = f1_score(true_labels, predicted_labels, labels=class_labels, average=None, zero_division=0)

# One-vs-rest ROC-AUC. NOTE: computed from hard 0/1 predictions instead of
# scores, so each value equals (TPR + TNR) / 2 for that class rather than a
# curve-based AUC. Row i of np.eye assumes class ids are 0..K-1.
true_labels_binarized = np.eye(len(class_labels))[true_labels]
predicted_labels_binarized = np.eye(len(class_labels))[predicted_labels]

roc_auc_per_class = []
for i in range(len(class_labels)):
    try:
        roc_auc = roc_auc_score(true_labels_binarized[:, i], predicted_labels_binarized[:, i])
    except ValueError:  # only one class present in y_true for this split
        roc_auc = np.nan
    roc_auc_per_class.append(roc_auc)

print("\nClass-wise metrics:")
for i, label in enumerate(class_labels):
    print(f"Class {id2label[label]}:")
    print(f"  Precision: {precision_per_class[i]:.2f}")
    print(f"  Recall: {recall_per_class[i]:.2f}")
    print(f"  F1 Score: {f1_per_class[i]:.2f}")
    print(f"  ROC-AUC: {roc_auc_per_class[i]:.2f}" if not np.isnan(roc_auc_per_class[i]) else "  ROC-AUC: N/A")

# Class-wise metrics, one line per metric across the classes.
fig, ax = plt.subplots(figsize=(12, 6))
for metric, values in zip(
    ["Precision", "Recall", "F1 Score", "ROC-AUC"],
    [precision_per_class, recall_per_class, f1_per_class, roc_auc_per_class],
):
    ax.plot(class_labels, values, marker='o', label=metric)

ax.set_xticks(class_labels)
ax.set_xticklabels([id2label[label] for label in class_labels], rotation=45)
ax.set_xlabel("Class")
ax.set_ylabel("Metric Value")
ax.set_title("Class-wise Metrics")
ax.legend(title="Metrics")
ax.grid()
fig.tight_layout()
plt.show()
