In [None]:
%%time

# --- Install offline packages ---
try:
    import ace_tools_open
except ModuleNotFoundError:
    print('Installing ace tools...')
    !pip install -q /kaggle/input/offline-packages/itables-2.3.0-py3-none-any.whl
    !pip install -q /kaggle/input/offline-packages/ace_tools_open-0.1.0-py3-none-any.whl
    
try:
    import timm
except ModuleNotFoundError:
    print('Installing timm...')
    !pip install -q /kaggle/input/offline-packages/timm-1.0.15-py3-none-any.whl
    
# --- Core libraries ---
import os
import math
import random
import time
import numpy as np
import seaborn as sns
from collections import defaultdict
import plotly.express as px

# --- Data handling ---
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, accuracy_score, average_precision_score, precision_recall_curve

from dataclasses import dataclass, field
from sklearn.preprocessing import label_binarize

# --- PyTorch ---
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchaudio.functional import bandpass_biquad

# --- Audio processing ---
import torchaudio
import torchaudio.transforms as T

# --- Visualization ---
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Audio, display
from tqdm.notebook import tqdm
import ace_tools_open as tools
import torchvision
from torchvision.ops.focal_loss import sigmoid_focal_loss
import cv2

# --- Parallel and Custom Tools ---
from joblib import Parallel, delayed
from torch.amp import GradScaler, autocast
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional, List, Tuple
import timm
import tempfile
import gc
import itertools
from glob import glob
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Gather the top-level entries of every input dataset, then keep only the CSV files.
n = [
    entry
    for directory in (
        "/kaggle/input/eda-birdclef2025",
        "/kaggle/input/precomputing-spectrograms",
        "/kaggle/input/precomputing-spectrograms2",
        "/kaggle/input/precomputing-spectrograms3",
    )
    for entry in os.listdir(directory)
]
f = list(filter(lambda name: name.endswith(".csv"), n))
f

In [None]:
@dataclass
class CFG:
    """
    Central run configuration for this notebook.

    Plain fields are constructor arguments. Fields declared with
    ``field(init=False)`` are derived in ``__post_init__`` from ``data_path``,
    ``spectrogram_dir``, ``model_name`` and ``input_directory``.
    ``__post_init__`` reads ``taxonomy.csv`` under ``data_path`` to set
    ``num_classes``, so that file must exist when the config is created.
    """
    # General
    LOAD_DATA: bool = True
    seed: int = 69
    debug: bool = False
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'  # evaluated once, when the class is defined
    BATCH_SIZE: int = 32
    num_workers: int = 4

    ## Data paths ##
    OUTPUT_DIR: str = '/kaggle/working/'
    temporary_dir: str = field(init=False)
    spectrogram_dir: str = "/kaggle/input/precomputing-spectrograms3" # "/kaggle/input/eda-birdclef2025" # "/kaggle/input/precomputing-spectrograms" #"/kaggle/working/precomputed_spectograms"
    spectrogram_csv_filename: str = "spec_metadata.csv"
    spectrograms_metadata_path: str = field(init=False) # "filename", "num_frames", "primary_label"

    # Base path to dataset
    data_path: str = '/kaggle/input/birdclef-2025/'
    # Key file paths
    metadata_path: str = field(init=False)
    taxonomy_path: str = field(init=False)
    sample_submission_path: str = field(init=False)
    location_path: str = field(init=False)
    # Audio data directories
    train_data_path: str = field(init=False)
    test_soundscapes_path: str = field(init=False)
    unlabeled_soundscapes_path: str = field(init=False)

    # Augmentation
    # NOTE: `augment` and `mixup` carry no type annotation, so @dataclass treats
    # them as plain class attributes, not __init__ parameters; set them on the
    # instance after construction if needed.
    augment = False
    mixup = False
    aug_prob: float = 0.5
    mixup_alpha: float = 0.4

    # Model
    model_name: str = "efficientnet_b0" # 'efficientnet_b3_pruned', 'efficientnetv2_rw_m', 'efficientvit_l1', 'efficientvit_l2', 'efficientvit_m0'
    pretrained: bool = True
    input_directory: str = '/kaggle/input/offline-packages'
    input_model_filename: str = field(init=False)

    timewise_weights_path: str = '/kaggle/input/effnet28/efficientnet_b0_sed.pth'
    freqwise_weights_path: str = '/kaggle/input/effnet14/efficientnet_b0_sed.pth'

    output_model_filename: str = field(init=False)
    custom_weights_path: str = field(init=False)
    model_weights: str = field(init=False)
    num_classes: int = field(init=False)

    def __post_init__(self):
        """Derive dataset paths, model filenames, ``num_classes`` and a scratch directory."""
        self.metadata_path = os.path.join(self.data_path, 'train.csv')
        self.taxonomy_path = os.path.join(self.data_path, 'taxonomy.csv')
        self.sample_submission_path = os.path.join(self.data_path, 'sample_submission.csv')
        self.location_path = os.path.join(self.data_path, 'recording_location.txt')
        self.train_data_path = os.path.join(self.data_path, 'train_audio')
        self.test_soundscapes_path = os.path.join(self.data_path, 'test_soundscapes')
        self.unlabeled_soundscapes_path = os.path.join(self.data_path, 'train_soundscapes')
        self.spectrograms_metadata_path = os.path.join(self.spectrogram_dir, self.spectrogram_csv_filename)

        self.input_model_filename = f'{self.model_name}_pretrained.pth'
        self.output_model_filename = f'{self.model_name}_sed.pth'
        self.model_weights = os.path.join(self.input_directory, self.input_model_filename)
        self.custom_weights_path = f'/kaggle/input/effnet31/{self.output_model_filename}'
        # Row count of taxonomy.csv (every taxonomy entry, whether or not it has spectrograms).
        self.num_classes = len(pd.read_csv(self.taxonomy_path))

        # mkdtemp creates a directory that persists until it is removed explicitly.
        # `tempfile.TemporaryDirectory().name` would return the path of a directory
        # that is deleted as soon as the unreferenced TemporaryDirectory object is
        # garbage-collected (immediately), leaving a dangling path.
        self.temporary_dir = tempfile.mkdtemp()
        if self.debug:
            # NOTE: EPOCHS is only set as an attribute in debug mode.
            self.EPOCHS = 2
            print("⚠️ Debug mode is ON. Training only for 2 epochs.")

# Global config instance. Later cells read this module-level `cfg` directly
# (e.g. get_mappings reads its CSV paths, select_species_samples reads cfg.device).
cfg = CFG()

## Dataset and Dataloader

In [None]:
class PrecomputedSpectrogramDataset(Dataset):
    """
    PyTorch Dataset for loading precomputed log-mel spectrograms.

    Arguments:
    ----------
    metadata : pd.DataFrame
        DataFrame containing filenames, primary labels, and number of frames.
    spec_dir : str
        Directory where the spectrogram .npy files are saved.
    label_to_index : dict
        Mapping from class name to label index.
    augment : bool
        Whether to apply augmentation (placeholder for now).
    """
    def __init__(self, metadata, spec_dir, augment=False):
        self.metadata = metadata
        self.spectrogram_dir = spec_dir
        self.label_to_class, self.label_to_index, _, self.filename_to_secondary_label = get_mappings()
        self.augment = augment
        if self.augment:
            self.augmentor = SpectrogramAugmentor()

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        path = os.path.join(self.spectrogram_dir, "precomputed_spectrograms", row["filename"])
        
        # Load spectrogram
        spec = np.load(path, mmap_mode="r")
        spec = torch.tensor(spec, dtype=torch.float16).unsqueeze(0)  # [1, M, T]
        
        # Get the class name
        primary_label = row["primary_label"]
        class_name = self.label_to_class.get(primary_label, "Unknown")
        
        # Apply augmentations if enabled
        if self.augment and class_name in {"Aves", "Insecta", "Amphibia", "Mammalia"}:
            spec = self.augmentor.apply_augmentations(spec, class_name)
        
        # Convert to 3-channel image-like format
        spec = spec.repeat(3, 1, 1)
        
        # Multi-hot vector for primary labels
        primary_label_tensor = torch.zeros(len(self.label_to_index), dtype=torch.float16)
        if primary_label in self.label_to_index:
            primary_label_tensor[self.label_to_index[primary_label]] = 1.0
        
        # Multi-hot vector for secondary labels
        secondary_labels = self.filename_to_secondary_label.get(f"{row['filename'].split('_')[0]}.ogg", [])
        secondary_label_tensor = torch.zeros(len(self.label_to_index), dtype=torch.float16)
        for sec_label in secondary_labels:
            if sec_label in self.label_to_index:
                secondary_label_tensor[self.label_to_index[sec_label]] = 0.5
    
        # Combine them (primary gets 1.0 weight, secondary gets 0.5 weight)
        combined_labels = primary_label_tensor + secondary_label_tensor
        combined_labels = torch.clamp(combined_labels, 0, 1)  # Ensure it's only 0 or 1
        
        return {
            "spectrogram": spec, 
            "labels": combined_labels, 
            "filename": row["filename"], 
            "class_name": class_name
        }
def collate_fn(batch, mixup=False, alpha=0.4):
    """
    Stack a list of dataset samples into one batch, optionally applying Mixup.

    Parameters:
    -----------
    batch : list
        List of sample dicts with "spectrogram", "labels" and "filename" keys.
        All spectrograms must share one shape (they are combined with torch.stack).
    mixup : bool
        If True and the batch holds more than one sample, every sample i is
        replaced by ``lam_i * x_i + (1 - lam_i) * x_perm(i)`` (and likewise for
        its labels), where ``perm`` is a random permutation of the batch and
        ``lam_i ~ Beta(alpha, alpha)``. The batch size is preserved, so row i
        still corresponds to ``filenames[i]`` (its dominant-by-position sample).
    alpha : float
        Alpha parameter for the Beta distribution in Mixup.

    Returns:
    --------
    dict :
        {"spectrograms": Tensor [B, ...], "labels": Tensor [B, ...],
         "filenames": list}; all three values are None for an empty batch.
    """
    if len(batch) == 0:
        return {"spectrograms": None, "labels": None, "filenames": None}

    specs = torch.stack([item["spectrogram"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    filenames = [item["filename"] for item in batch]

    if mixup and len(batch) > 1:
        # Mix every row with a random partner so the output keeps B rows and
        # stays aligned with `filenames` (pairwise mixing would halve the batch).
        partner = torch.randperm(len(batch))
        lam = torch.from_numpy(np.random.beta(alpha, alpha, size=len(batch)))
        # Reshape lam to broadcast over all non-batch dims, in each tensor's dtype.
        lam_specs = lam.to(specs.dtype).view(-1, *([1] * (specs.dim() - 1)))
        lam_labels = lam.to(labels.dtype).view(-1, *([1] * (labels.dim() - 1)))
        specs = lam_specs * specs + (1 - lam_specs) * specs[partner]
        labels = lam_labels * labels + (1 - lam_labels) * labels[partner]

    return {"spectrograms": specs, "labels": labels, "filenames": filenames}

def _parse_secondary_labels(raw):
    """Parse one ``secondary_labels`` cell from train.csv into a list of non-empty label strings."""
    import ast

    # Non-string cells (e.g. NaN from pandas) and the "['']" placeholder mean "no secondary labels".
    if not isinstance(raw, str) or raw == "['']":
        return []
    # literal_eval only accepts Python literals, unlike eval() which would run
    # arbitrary expressions from the CSV; malformed cells raise ValueError/SyntaxError.
    return [label for label in ast.literal_eval(raw) if label]


def get_mappings():
    """
    Build label/class/index mappings and the filename -> secondary-labels map.

    Reads three CSVs through the global ``cfg``: ``cfg.taxonomy_path``,
    ``cfg.spectrograms_metadata_path`` and ``cfg.metadata_path``. Only taxonomy
    labels that occur as a ``primary_label`` in the spectrogram metadata are kept.
    The label maps are displayed once per session via ``tools`` (if available).

    Returns:
    --------
    - label_to_class : dict
        Maps primary labels to their respective class names.
    - label_to_index : dict
        Maps the sorted kept labels to indices 0..N-1.
    - index_to_label : dict
        Inverse of ``label_to_index``.
    - filename_to_secondary_label : dict
        Maps train.csv filenames to their (non-empty) lists of secondary labels.

    Raises:
    -------
    ValueError / SyntaxError
        If a ``secondary_labels`` cell is not a valid Python literal.
    """
    taxonomy_df = pd.read_csv(cfg.taxonomy_path)
    metadata_df = pd.read_csv(cfg.spectrograms_metadata_path)
    train_df = pd.read_csv(cfg.metadata_path)

    # Drop taxonomy entries that have no spectrograms.
    used_labels = set(metadata_df["primary_label"].unique())
    taxonomy_df = taxonomy_df[taxonomy_df["primary_label"].isin(used_labels)]

    label_to_class = taxonomy_df.set_index("primary_label")["class_name"].to_dict()

    sorted_labels = sorted(label_to_class.keys())
    label_to_index = {label: idx for idx, label in enumerate(sorted_labels)}
    index_to_label = dict(enumerate(sorted_labels))

    filename_to_secondary_label = {}
    for filename, raw in zip(train_df["filename"], train_df["secondary_labels"]):
        secondary_labels = _parse_secondary_labels(raw)
        if secondary_labels:
            filename_to_secondary_label[filename] = secondary_labels

    # Display the label maps only once (no need to repeat every call)
    if not hasattr(get_mappings, "_displayed"):
        try:
            tools.display_dataframe_to_user(name="Label to Class", dataframe=pd.DataFrame(label_to_class.items(), columns=["Animal Label", "Class Name"]))
            tools.display_dataframe_to_user(name="Label Map", dataframe=pd.DataFrame(label_to_index.items(), columns=["Animal Label", "Index"]))
            tools.display_dataframe_to_user(name="Filename to Secondary Labels", dataframe=pd.DataFrame(list(filename_to_secondary_label.items()), columns=["Filename", "Secondary Labels"]))
            get_mappings._displayed = True
        except NameError as e:
            # `tools` is optional; without it the maps are simply not displayed.
            print(e)

    return label_to_class, label_to_index, index_to_label, filename_to_secondary_label

def create_dataloader(dataset, cfg, shuffle, collate_fn):
    """
    Wrap ``dataset`` in a DataLoader whose batch size and worker count come from ``cfg``.

    Parameters:
    -----------
    dataset : Dataset
        Any map-style dataset.
    cfg : object
        Must expose ``BATCH_SIZE`` and ``num_workers``.
    shuffle : bool
        Whether to reshuffle the data every epoch.
    collate_fn : callable
        Function that merges a list of samples into a batch.

    Returns:
    --------
    DataLoader:
        Loader with pinned memory enabled.
    """
    loader_kwargs = dict(
        batch_size=cfg.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    return DataLoader(dataset, **loader_kwargs)



## SED models

In [None]:
class EfficientNetTimeSED(nn.Module):
    """
    EfficientNet backbone with a time-wise attention SED head (BirdCLEF inference).

    The backbone feature map [B, C, M', T'] is averaged over its frequency axis
    (dim 2) to build a sigmoid attention map [B, C, 1, T'], which re-weights the
    features before a 1x1-conv classifier followed by global max pooling.
    Submodule attribute names double as state_dict keys, so they must not change.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Full timm model (no weight download); its classifier is replaced below.
        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=False,
            features_only=False,
        )
        print(f"[INFO] Loading custom trained weights from {cfg.model_weights}...")
        backbone_state = torch.load(cfg.model_weights, map_location=cfg.device, weights_only=True)
        self.backbone.load_state_dict(backbone_state)

        # Keep the backbone's feature width, then strip its classification layer.
        self.feature_dim = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Identity()

        channels = self.feature_dim
        # Attention block: frequency pooling -> pointwise conv -> sigmoid.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, None))
        self.conv_att = nn.Conv2d(channels, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        # Head: per-position class scores, then max over the remaining grid.
        self.classifier = nn.Sequential(
            nn.Conv2d(channels, cfg.num_classes, kernel_size=1),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
        )

        self.attention_weights = None

    def _time_attention(self, features):
        """Sigmoid attention map [B, C, 1, T'] computed from frequency-averaged features."""
        return self.sigmoid(self.conv_att(self.avg_pool(features)))

    def forward(self, x):
        """
        Parameters:
        -----------
        x : Tensor of shape [B, 3, M, T] where M = Mel bands, T = Time.

        Returns:
        --------
        Tensor of shape [B, num_classes] (raw scores; no sigmoid applied).
        The detached attention map is stored in ``self.attention_weights``.
        """
        feature_map = self.backbone.forward_features(x)
        attention = self._time_attention(feature_map)
        self.attention_weights = attention.detach()  # kept for visualisation
        return self.classifier(feature_map * attention)

    def get_framewise_output(self, x):
        """Per-frame class probabilities [B, num_classes, T'] (conv scores averaged over frequency, then sigmoid)."""
        feature_map = self.backbone.forward_features(x)
        frame_scores = self.classifier[0](feature_map * self._time_attention(feature_map))  # [B, num_classes, F, T]
        return self.sigmoid(frame_scores.mean(dim=2))


def remap_attention_keys(state_dict):
    """
    Rename ``att_block`` Sequential keys to the attention attribute names.

    Keys starting with "att_block.1." get "att_block.1" replaced by "conv_att";
    keys starting with "att_block.2." get "att_block.2" replaced by "sigmoid"
    (every occurrence within such a key is replaced, as with ``str.replace``).
    All other keys are kept unchanged. Returns a new plain dict in input order.
    """
    renames = (("att_block.1", "conv_att"), ("att_block.2", "sigmoid"))

    def _rename(key):
        for old_prefix, new_prefix in renames:
            if key.startswith(old_prefix + "."):
                return key.replace(old_prefix, new_prefix)
        return key

    return {_rename(key): value for key, value in state_dict.items()}


In [None]:
class EfficientNetFrequencySED(nn.Module):
    """
    EfficientNet backbone with a frequency-wise attention SED head.

    - Builds the backbone with ``timm.create_model(cfg.model_name, pretrained=cfg.pretrained)``
    - Averages the feature map over its last (time) axis to get one sigmoid
      attention weight per channel and frequency row, [B, C, M', 1]
    - Re-weights the features and classifies with a 1x1 conv + global max pool

    The output is raw (pre-sigmoid) class scores of shape [B, num_classes].

    Arguments:
    ----------
    cfg : object
        Configuration object (assumes it's an instance of CFG); uses
        ``model_name``, ``pretrained``, ``num_classes`` and ``device``.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.device = torch.device(cfg.device)

        self.backbone = timm.create_model(cfg.model_name, pretrained=cfg.pretrained)
        # Remember the feature width, then drop timm's classifier in favour of our head.
        self.feature_dim = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Identity()

        channels = self.feature_dim
        # Sequential positions become state_dict keys ("att_block.1.weight", ...),
        # so the layer order must stay fixed.
        self.att_block = nn.Sequential(
            nn.AdaptiveAvgPool2d((None, 1)),  # average over time, keep frequency rows
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.Sigmoid(),
        )

        self.classifier = nn.Sequential(
            nn.Conv2d(channels, cfg.num_classes, kernel_size=1),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
        )

        self.attention_weights = None

    def forward(self, x):
        """
        Parameters:
        -----------
        x : torch.Tensor
            Input of shape [B, 3, M, T] (batch, mel bands, time frames); it is
            moved to ``self.device`` first.

        Returns:
        --------
        torch.Tensor:
            Raw class scores of shape [B, num_classes]. The detached attention
            map is stored in ``self.attention_weights``.
        """
        feature_map = self.backbone.forward_features(x.to(self.device))  # [B, C, M', T']
        freq_attention = self.att_block(feature_map)                     # [B, C, M', 1]
        self.attention_weights = freq_attention.detach()                 # kept for visualisation
        return self.classifier(feature_map * freq_attention)             # [B, num_classes]


## Model evaluation and interpretation

In [None]:
class ModelEvaluator:
    """
    Run a model over a dataloader and report per-class ROC AUC.

    Parameters:
    -----------
    model : nn.Module
        Model returning one logit per class for a batch of spectrograms.
    dataloader : iterable
        Yields dicts with "spectrograms" and "labels" tensors (the format built by ``collate_fn``).
    cfg : object
        Needs a ``device`` attribute (e.g. "cuda", "cpu").
    label_to_class : dict
        Species label -> class name.
    index_to_label : dict
        Output column index -> species label.
    """
    def __init__(self, model, dataloader, cfg, label_to_class, index_to_label):
        self.model = model
        self.dataloader = dataloader
        self.cfg = cfg
        self.label_to_class = label_to_class
        self.index_to_label = index_to_label
        self.num_classes = len(index_to_label)

    def run_inference(self):
        """
        Predict on every batch and return ``(y_true, y_pred)`` numpy arrays.

        ``y_true`` keeps only target entries exactly equal to 1 (primary labels);
        softer targets such as 0.5 become 0. Every sample must have exactly one
        such entry (asserted). ``y_pred`` holds float32 sigmoid probabilities.
        An empty dataloader yields two arrays with 0 rows and ``num_classes`` columns.
        """
        all_targets, all_preds = [], []
        self.model.eval()
        # torch.autocast expects a device *type* ("cuda"/"cpu"), not e.g. "cuda:0".
        device_type = torch.device(self.cfg.device).type

        with torch.no_grad():
            for batch in tqdm(self.dataloader, desc="Evaluating"):
                inputs = batch["spectrograms"].to(self.cfg.device)
                targets = batch["labels"].to(self.cfg.device)

                # Only keep primary labels for AUC calculation
                primary_only_targets = (targets == 1).int()
                assert primary_only_targets.sum() == primary_only_targets.shape[0]

                try:
                    with autocast(device_type=device_type):
                        logits = self.model(inputs)
                    # Sigmoid in float32: autocast logits may be float16/bfloat16;
                    # bfloat16 cannot be converted to numpy, and float16 probabilities
                    # are coarse near 0/1, which introduces ties in the ROC curves.
                    outputs = torch.sigmoid(logits.float()).cpu().numpy()
                except Exception as e:
                    # Best-effort fallback: retry the batch in full precision without autocast.
                    print(f"[WARNING] Exception during inference: {e}")
                    outputs = torch.sigmoid(self.model(inputs.float()).float()).cpu().numpy()

                all_targets.append(primary_only_targets.cpu().numpy())
                all_preds.append(outputs)

        if not all_preds:
            return (np.empty((0, self.num_classes), dtype=np.int32),
                    np.empty((0, self.num_classes), dtype=np.float32))

        y_true = np.concatenate(all_targets, axis=0)
        y_pred = np.concatenate(all_preds, axis=0)
        return y_true, y_pred

    def calculate_per_class_auc(self, y_true, y_pred):
        """
        Return a DataFrame with one row per class index: species, class_name,
        ROC AUC (NaN when the class has only positives or only negatives),
        positive_samples and total_samples. Row position equals class index.
        """
        results = []
        for idx in range(self.num_classes):
            true_class = y_true[:, idx]
            pred_class = y_pred[:, idx]
            num_pos = np.sum(true_class)

            auc = np.nan
            if 0 < num_pos < len(true_class):
                try:
                    auc = roc_auc_score(true_class, pred_class)
                except ValueError:
                    auc = np.nan

            species = self.index_to_label[idx]
            results.append({
                "species": species,
                "class_name": self.label_to_class[species],
                "auc": auc,
                "positive_samples": int(num_pos),
                "total_samples": len(true_class)
            })

        return pd.DataFrame(results)

    def plot_results(self, df_results):
        """Bar chart of per-class AUC (sorted, coloured by class) and a histogram of the AUC values."""
        plt.figure(figsize=(15, 6))
        sns.barplot(
            x="species", y="auc", hue="class_name",
            data=df_results.sort_values(by="auc"),
            dodge=False
        )
        plt.xticks(rotation=90, fontsize=6)
        plt.title("Per-Class ROC AUC Scores")
        plt.xlabel("Species")
        plt.ylabel("AUC")
        plt.legend(title="Class", loc="lower right")
        plt.tight_layout()
        plt.show()

        plt.figure(figsize=(8, 5))
        sns.histplot(df_results["auc"].dropna(), bins=20, kde=True)
        plt.title("Distribution of Per-Class AUC Scores")
        plt.xlabel("AUC Score")
        plt.ylabel("Frequency")
        plt.show()


In [None]:
class ModelAnalyzer:
    """
    Post-hoc per-class analysis of multi-label predictions.

    Parameters:
    -----------
    df_results : pd.DataFrame
        One row per class, with row position equal to the class column index
        and at least "species" and "class_name" columns (e.g. the output of
        ``ModelEvaluator.calculate_per_class_auc``). A copy is stored; the
        ``compute_*`` methods add metric columns to it and return it.
    y_true : np.ndarray, shape (n_samples, n_classes)
        Binary (one-hot) targets.
    y_pred : np.ndarray, shape (n_samples, n_classes)
        Predicted probabilities.
    do_print : bool
        If True, methods print their intermediate tables.
    """
    def __init__(self, df_results, y_true, y_pred, do_print=False):
        self.df_results = df_results.copy()
        self.y_true = y_true  # shape (n_samples, n_classes), one-hot encoded
        self.y_pred = y_pred  # shape (n_samples, n_classes), probabilities
        self.print = do_print

    @staticmethod
    def _best_f1_threshold(y_true_class, y_pred_class, default=0.5):
        """Return the score threshold that maximises F1 for one class (``default`` if the curve has no thresholds)."""
        precision, recall, thresholds = precision_recall_curve(y_true_class, y_pred_class)
        if len(thresholds) == 0:
            return default
        # precision/recall carry one extra final point (precision=1, recall=0)
        # that has no threshold; drop it so argmax always indexes `thresholds`.
        p, r = precision[:-1], recall[:-1]
        f1_scores = 2 * (p * r) / (p + r + 1e-8)
        return thresholds[np.argmax(f1_scores)]

    def summarize_best_worst(self, top_n=10):
        """Print (if enabled) the ``top_n`` classes with the highest and lowest non-NaN AUC."""
        df_valid = self.df_results.dropna(subset=["auc"])
        best = df_valid.sort_values(by="auc", ascending=False).head(top_n)
        worst = df_valid.sort_values(by="auc", ascending=True).head(top_n)

        if self.print:
            print("\n=== BEST CLASSES ===")
            print(best[["species", "class_name", "auc", "positive_samples"]])
            print("\n=== WORST CLASSES ===")
            print(worst[["species", "class_name", "auc", "positive_samples"]])

    def compute_accuracy(self):
        """
        Add an "accuracy" column: for each class, the fraction of samples whose
        argmax target is that class and whose argmax prediction matches it
        (NaN if no sample has that class as argmax target).
        """
        y_true_labels = np.argmax(self.y_true, axis=1)
        y_pred_labels = np.argmax(self.y_pred, axis=1)
        overall_acc = accuracy_score(y_true_labels, y_pred_labels)
        if self.print:
            print(f"\n=== Overall Accuracy: {overall_acc:.4f} ===")

        per_class_acc = []
        for idx in range(self.y_true.shape[1]):
            true_class_idx = np.where(y_true_labels == idx)[0]
            if len(true_class_idx) == 0:
                per_class_acc.append(np.nan)
            else:
                per_class_acc.append(np.mean(y_pred_labels[true_class_idx] == idx))

        self.df_results["accuracy"] = per_class_acc

        if self.print:
            print("\n=== Per-Class Accuracy Added ===")
            print(self.df_results[["species", "class_name", "accuracy"]].sort_values(by="accuracy", ascending=True))
        return self.df_results

    def compute_auc(self):
        """Overwrite the "auc" column with ROC AUC per class (NaN where sklearn raises ValueError)."""
        aucs = []
        for idx in range(self.y_true.shape[1]):
            true_class = self.y_true[:, idx]
            pred_class = self.y_pred[:, idx]
            try:
                auc = roc_auc_score(true_class, pred_class)
            except ValueError:
                auc = np.nan
            aucs.append(auc)
        self.df_results["auc"] = aucs

        if self.print:
            print("\n=== AUC Recomputed Without Positive/Negative Checks ===")
            print(self.df_results[["species", "class_name", "auc"]].sort_values(by="auc", ascending=True))
        return self.df_results

    def compute_precision_recall(self):
        """Add an "average_precision" column (NaN where sklearn raises ValueError)."""
        precision_results = []
        for idx in range(self.y_true.shape[1]):
            true_class = self.y_true[:, idx]
            pred_class = self.y_pred[:, idx]
            try:
                ap = average_precision_score(true_class, pred_class)
            except ValueError:
                ap = np.nan
            precision_results.append(ap)
        self.df_results["average_precision"] = precision_results

        if self.print:
            print("\n=== Precision-Recall Added ===")
            print(self.df_results[["species", "class_name", "average_precision"]].sort_values(by="average_precision", ascending=True))
        return self.df_results

    def compute_confusion_stats(self, threshold=0.5):
        """
        Add "false_positives" and "false_negatives" columns.

        Predictions are binarised with the per-class "best_threshold" column
        when it exists (see ``compute_best_thresholds``); otherwise the single
        ``threshold`` is used for every class.
        """
        n_classes = self.y_true.shape[1]
        if "best_threshold" in self.df_results.columns:
            # Positional access: row position == class index.
            class_thresholds = self.df_results["best_threshold"].to_numpy()
        else:
            class_thresholds = np.full(n_classes, threshold)

        false_pos = []
        false_neg = []
        for idx in range(n_classes):
            true_class = self.y_true[:, idx]
            pred_class = (self.y_pred[:, idx] >= class_thresholds[idx]).astype(int)
            false_pos.append(np.sum((pred_class == 1) & (true_class == 0)))
            false_neg.append(np.sum((pred_class == 0) & (true_class == 1)))
        self.df_results["false_positives"] = false_pos
        self.df_results["false_negatives"] = false_neg

        if self.print:
            print("\n=== Confusion Stats Added ===")
            print(self.df_results[["species", "class_name", "false_positives", "false_negatives"]].sort_values(by="false_negatives", ascending=False))
        return self.df_results

    def compute_best_thresholds(self):
        """Add a "best_threshold" column: the F1-maximising score threshold per class (0.5 fallback)."""
        self.df_results["best_threshold"] = [
            self._best_f1_threshold(self.y_true[:, idx], self.y_pred[:, idx])
            for idx in range(self.y_true.shape[1])
        ]

        if self.print:
            print("\n=== Best Thresholds Computed per Class ===")
            print(self.df_results[["species", "class_name", "best_threshold"]].sort_values(by="best_threshold", ascending=True))
        return self.df_results

    def get_species_info(self, primary_label):
        """
        Display and return the metric row(s) for one species, or None if absent.

        Only metric columns that have already been computed are included.
        """
        row = self.df_results[self.df_results["species"] == primary_label]
        if row.empty:
            print(f"[WARNING] Species {primary_label} not found in results.")
            return None
        wanted = [
            "species", "class_name", "auc", "positive_samples",
            "average_precision", "false_positives", "false_negatives", "accuracy", "total_samples"
        ]
        info = row[[col for col in wanted if col in row.columns]]
        tools.display_dataframe_to_user(name=f"\n=== Info for Species '{primary_label}' ===", dataframe=pd.DataFrame(info))
        return info

    def display_full_dataframe(self):
        """Show the full results table via ``tools``."""
        tools.display_dataframe_to_user(name="\n=== Full Results DataFrame ===", dataframe=self.df_results)

    def plot_stats(self, df_with_confusion, n=10):
        """
        Scatter FP vs FN and AUC vs positive count, then (if printing) list the
        ``n`` classes with most false negatives and their best-F1 thresholds.
        """
        plt.figure(figsize=(8, 6))
        sns.scatterplot(
            x="false_positives", y="false_negatives",
            hue="class_name", size="positive_samples",
            data=df_with_confusion, alpha=0.7, edgecolor='k'
        )
        plt.title("False Positives vs False Negatives per Class")
        plt.xlabel("False Positives")
        plt.ylabel("False Negatives")
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.show()

        plt.figure(figsize=(8, 6))
        sns.scatterplot(
            x="positive_samples", y="auc",
            hue="class_name",
            data=df_with_confusion, alpha=0.7, edgecolor='k'
        )
        plt.title("AUC vs Positive Sample Count")
        plt.xlabel("Positive Samples")
        plt.ylabel("AUC")
        plt.xscale('log')
        plt.tight_layout()
        plt.show()

        if not self.print:
            return

        top_fn = df_with_confusion.sort_values(by="false_negatives", ascending=False).head(n)
        print(f"\n=== Top {n} Classes with Most False Negatives ===")
        print(top_fn[["species", "class_name", "false_negatives", "false_positives"]])

        for idx, row in top_fn.iterrows():
            # NOTE(review): assumes the DataFrame index label equals the class column index.
            best_thresh = self._best_f1_threshold(self.y_true[:, idx], self.y_pred[:, idx])
            print(f"Species: {row['species']} | Best Threshold: {best_thresh:.3f}")

    def plot_performance_metrics(self):
        """Interactive plotly bar charts of per-species accuracy and average precision, coloured by class."""
        df_acc = self.df_results.sort_values(by="accuracy", ascending=True)
        df_ap = self.df_results.sort_values(by="average_precision", ascending=True)

        # Use integer indices for x-axis to prevent label clutter
        df_acc["index"] = range(len(df_acc))
        df_ap["index"] = range(len(df_ap))

        # Generate a fixed color mapping
        unique_classes = self.df_results["class_name"].unique()
        palette = px.colors.qualitative.Safe  # or use D3, Set1, etc.
        color_map = {cls: palette[i % len(palette)] for i, cls in enumerate(sorted(unique_classes))}

        # Accuracy Plot
        fig1 = px.bar(
            df_acc,
            x="index",
            y="accuracy",
            color="class_name",
            color_discrete_map=color_map,
            hover_data=["species", "class_name", "accuracy", "positive_samples"],
            title="Accuracy by Species (Hover to See Species)",
            labels={"index": "Class Index", "accuracy": "Accuracy", "positive_samples": "# samples", "class_name": "Class name"},
            height=500
        )
        fig1.update_layout(
            xaxis_tickvals=[],  # Remove tick labels
            showlegend=True,
            legend_title="Class Name"
        )
        fig1.show()

        # Average Precision Plot
        fig2 = px.bar(
            df_ap,
            x="index",
            y="average_precision",
            color="class_name",
            color_discrete_map=color_map,
            hover_data=["species", "class_name", "average_precision", "positive_samples"],
            title="Average Precision by Species (Hover to See Species)",
            labels={"index": "Class Index", "average_precision": "Average Precision", "positive_samples": "# samples", "class_name": "Class name"},
            height=500
        )
        fig2.update_layout(
            xaxis_tickvals=[],
            showlegend=True,
            legend_title="Class Name"
        )
        fig2.show()



## Utils

In [None]:
def select_species_samples(
    loader,
    species: str,
    mode: str = "random",
    num_samples: int = 5,
    target_filename: Optional[str] = None
) -> List[Tuple[torch.Tensor, str]]:
    """
    Selects samples from the loader for a given species.

    The loader is scanned batch by batch and every sample whose filename starts
    with ``species`` is collected. Scanning stops early once at least as many
    matches as the analyzer's ``positive_samples`` count for that species have
    been found.

    Parameters:
    -----------
    loader : DataLoader
        The data loader to pull samples from. Each batch must provide the keys
        "filenames" and "spectrograms".
    species : str
        The species ID to match against the start of each filename.
    mode : str
        Selection mode: "random", "consecutive", or "specific".
    num_samples : int
        Number of samples to select ("random" / "consecutive"); capped at the
        species' positive sample count.
    target_filename : str, optional
        Required when mode == "specific" (e.g. 'turvul/XC748979.ogg'); the
        first collected sample whose filename stem matches it is returned.

    Returns:
    --------
    List of tuples: (spectrogram tensor on cfg.device as float, filename).
    An empty list (after a printed warning) when nothing matches.

    Raises:
    -------
    ValueError
        If ``mode`` is unknown, or mode == "specific" without a target_filename.
        Both are checked before the (potentially long) scan of the loader.
    """
    if mode not in ("random", "consecutive", "specific"):
        raise ValueError(f"[ERROR] Unknown mode: {mode}")
    if mode == "specific" and not target_filename:
        raise ValueError("[ERROR] mode='specific' requires a target_filename")

    # Get species info from the notebook-level analyzer
    row = analyzer.get_species_info(species)
    max_n = row["positive_samples"].item()
    num_samples = min(max_n, num_samples)

    all_matches = []

    for batch in loader:
        spectrograms = batch["spectrograms"]

        for i, fname in enumerate(batch["filenames"]):
            if fname.startswith(species):
                # Move only the matching sample: moving the whole batch and storing
                # a row view would keep every such batch resident on the device.
                all_matches.append((spectrograms[i].to(cfg.device).float(), fname))

        # `>=` rather than `==`: a single batch can push the count past max_n,
        # in which case an equality check never fires and the whole loader is scanned.
        if len(all_matches) >= max_n:
            break

    if not all_matches:
        print(f"[WARNING] No samples found for species '{species}'")
        return []

    if mode == "specific":
        target_stem = target_filename.split('.')[0]
        for input_tensor, fname in all_matches:
            # Compare the recording stem with the part of the chunk name before
            # the first "_" (presumably "<recording>_<chunk>.npy" — verify naming).
            if target_stem == fname.split("_")[0]:
                return [(input_tensor, fname)]
        print(f"[WARNING] Specified file '{target_filename}' not found for species '{species}'")
        return []

    if mode == "consecutive":
        return all_matches[:num_samples]

    # mode == "random"
    return random.sample(all_matches, min(len(all_matches), num_samples))


In [None]:
def filter_data(
    df, 
    collection=None, 
    rating=None, 
    primary_label=None, 
    common_name=None, 
    filename=None,
    random_sample=True
):
    """
    Return the rows of ``df`` that satisfy every filter that was supplied.

    Parameters
    ----------
    df : pd.DataFrame
        Metadata table. Only the columns of the filters actually used are read
        ('collection', 'rating', 'primary_label', 'filename', 'common_name').
    collection, primary_label, filename : optional
        Exact-match filters; skipped when falsy.
    rating : optional
        Exact-match filter; skipped only when None, so a rating of 0 still filters.
    common_name : str, optional
        Case-insensitive ``str.contains`` match (regex by default) on
        'common_name'; rows with a missing name never match.
    random_sample : bool
        Unused; kept for backward compatibility with existing calls.

    Returns
    -------
    pd.DataFrame or None
        A new frame holding the matching rows (never ``df`` itself), or None
        after printing a message when no row matches.
    """
    subset = df.copy()

    # (column, value, apply?) — rating is the only filter tested against None.
    equality_filters = (
        ("collection", collection, bool(collection)),
        ("rating", rating, rating is not None),
        ("primary_label", primary_label, bool(primary_label)),
        ("filename", filename, bool(filename)),
    )
    for column, value, enabled in equality_filters:
        if enabled:
            subset = subset[subset[column] == value]

    if common_name:
        name_hits = subset['common_name'].str.contains(common_name, case=False, na=False)
        subset = subset[name_hits]

    if subset.empty:
        print("No matches found with the given filters.")
        return None

    return subset  # metadata rows for inspection


def play_audio(audio_dir, row):
    """
    Print a short description of one metadata row and render an inline audio player for it.

    Parameters
    ----------
    audio_dir : str
        Directory that ``row['filename']`` is relative to.
    row : pd.Series or mapping
        Must provide 'filename', 'common_name', 'primary_label', 'collection'
        and 'rating'.
    """
    audio_path = os.path.join(audio_dir, row['filename'])

    header_lines = (
        f"▶️ Playing: {row['common_name']} [{row['primary_label']}]",
        f"Collection: {row['collection']}, Rating: {row['rating']}",
        f"File: {row['filename']}",
    )
    for text in header_lines:
        print(text)

    display(Audio(audio_path))

## GradCam and Attention

In [None]:
class GradCAM:
    """
    Grad-CAM driven by a forward hook (captures the target layer's output) and a
    backward hook (captures the gradient w.r.t. that output).

    The hooks stay attached to ``target_layer`` until ``remove()`` is called.
    """

    def __init__(self, model, target_layer):
        """
        Initializes GradCAM for a specified model and layer.

        Parameters:
        -----------
        model : torch.nn.Module
            Model whose forward pass returns class scores of shape [B, num_classes].
        target_layer : torch.nn.Module
            Sub-module of ``model`` whose output feature map is explained.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Prefer the full backward hook; fall back to the legacy one when
        # registration raises RuntimeError.
        try:
            self.hook = target_layer.register_full_backward_hook(self.save_gradients)
            if cfg.debug:
                print("[INFO] Full Backward Hook Registered Successfully.")
        except RuntimeError:
            # Fallback if full hook is not available
            self.hook = target_layer.register_backward_hook(self.save_gradients)
            print("[WARNING] Full Backward Hook not available, using regular backward hook.")

        # Keep the handle so remove() can detach the forward hook as well.
        self.forward_hook = target_layer.register_forward_hook(self.save_activations)

    def remove(self):
        """Detach both hooks from the target layer."""
        self.hook.remove()
        self.forward_hook.remove()

    def save_activations(self, module, input, output):
        """ Save activations from the forward pass. """
        self.activations = output

    def save_gradients(self, module, grad_input, grad_output):
        """ Save gradients from the backward pass. """
        self.gradients = grad_output[0]

    def __call__(self, inputs, class_idx=None):
        """
        Generate a Grad-CAM heatmap for the first sample of ``inputs``.

        Parameters:
        -----------
        inputs : torch.Tensor
            Model input batch (the callers in this notebook pass a batch of one).
        class_idx : None, int, sequence of int or tensor
            Class to explain per sample. None explains each sample's arg-max
            class; a single int is used for every sample of the batch.

        Returns:
        --------
        np.ndarray [H, W] at the target layer's spatial resolution, clipped at 0
        and scaled so its maximum is 1. An all-zero map is returned as zeros
        instead of being divided by zero (which produced NaNs).
        """
        self.model.zero_grad()
        outputs = self.model(inputs)

        if class_idx is None:
            class_idx = outputs.argmax(dim=1)
        elif np.ndim(class_idx) == 0:
            # A scalar class index applies to every sample of the batch.
            class_idx = [int(class_idx)] * outputs.shape[0]

        one_hot = torch.zeros_like(outputs)
        for i, idx in enumerate(class_idx):
            one_hot[i, idx] = 1
        outputs.backward(gradient=one_hot, retain_graph=True)

        # Channel weights = spatially averaged gradients of the first sample
        # (the map is built for sample 0, so only its gradients are pooled).
        gradients = self.gradients[0].detach()        # [C, H, W]
        activations = self.activations[0].detach()    # [C, H, W]
        weights = gradients.mean(dim=(1, 2))          # [C]

        # Weighted channel average computed out-of-place: the captured
        # activations belong to the autograd graph and must not be mutated.
        cam = (activations * weights[:, None, None]).mean(dim=0)
        heatmap = np.maximum(cam.cpu().numpy(), 0)  # ReLU

        peak = heatmap.max()
        if peak > 0:
            heatmap = heatmap / peak  # Normalize
        return heatmap

    def visualize(self, spec_path, heatmap, filename):
        """
        Visualizes the spectrogram with Grad-CAM overlay, side by side.

        Both panels use origin='lower' (mel row 0 at the bottom), so the overlay
        panel has the same orientation as the plain spectrogram. The overlay is
        drawn with aspect='auto'; otherwise imshow would reset the axes to an
        equal aspect ratio.
        """
        spec = np.load(spec_path)
        cam = cv2.resize(heatmap, (spec.shape[1], spec.shape[0]))  # cv2 takes (width, height)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))

        # Original Spectrogram
        ax1.imshow(spec, aspect='auto', origin='lower', cmap='viridis')
        ax1.set_title(f"Original Spectrogram: {filename}")
        ax1.set_xlabel("Time Frames")
        ax1.set_ylabel("Mel Bands")

        # Grad-CAM Overlay
        ax2.imshow(spec, aspect='auto', origin='lower', cmap='viridis')
        ax2.imshow(cam, aspect='auto', origin='lower', alpha=0.4, cmap='jet')
        ax2.set_title(f"Grad-CAM Visualization: {filename}")
        ax2.set_xlabel("Time Frames")
        ax2.set_ylabel("Mel Bands")

        plt.tight_layout()
        plt.show()

def visualize_grad_cam_samples(loader, grad_cam, num_samples=5):
    """
    Plot Grad-CAM overlays for the first ``num_samples`` samples yielded by ``loader``.

    Each sample is explained on its own (batch of one) and the matching
    precomputed ``.npy`` spectrogram is loaded from disk for display.
    """
    shown = 0
    for batch in tqdm(loader, desc="Visualizing Grad-CAM"):
        batch_inputs = batch["spectrograms"].to(cfg.device).float()
        batch_names = batch["filenames"]

        for pos in range(len(batch_inputs)):
            if shown >= num_samples:
                return  # requested number of plots reached

            cam_map = grad_cam(batch_inputs[pos].unsqueeze(0))
            name = batch_names[pos]
            npy_path = os.path.join(cfg.spectrogram_dir, "precomputed_spectrograms", name)

            grad_cam.visualize(npy_path, cam_map, name)
            shown += 1

def visualize_grad_cam_by_species(loader, grad_cam, species, num_samples=5):
    """
    Visualizes Grad-CAM for random samples matching the specified species.

    Parameters:
    ----------
    loader : DataLoader
        The DataLoader providing batches with keys: spectrograms, labels, filenames.
    grad_cam : GradCAM
        The initialized GradCAM object.
    species : str
        The target species (matches first part of filename, e.g., '1139490').
    num_samples : int
        Number of random Grad-CAM visualizations to generate.
    """
    matching_samples = []

    # First pass: Collect all matching samples
    for batch in tqdm(loader, desc=f"Searching for species {species}"):
        filenames = batch["filenames"]
        inputs = batch["spectrograms"]

        for i, fname in enumerate(filenames):
            if fname.startswith(species):
                # clone(): a bare row view would keep the whole batch tensor alive.
                matching_samples.append((inputs[i].clone(), fname))

    if len(matching_samples) == 0:
        print(f"[WARNING] No samples found for species '{species}'")
        return

    # Randomly select up to num_samples
    selected_samples = random.sample(matching_samples, min(num_samples, len(matching_samples)))

    for input_tensor, fname in selected_samples:
        input_tensor = input_tensor.unsqueeze(0).to(cfg.device).float()
        heatmap = grad_cam(input_tensor)

        # Directory name matches the other spectrogram loaders in this notebook
        # ("precomputed_spectrograms"); a misspelled name makes np.load fail.
        spec_path = os.path.join(cfg.spectrogram_dir, "precomputed_spectrograms", fname)
        grad_cam.visualize(spec_path, heatmap, fname)


In [None]:
class AttentionVisualizer:
    """
    Reads the ``attention_weights`` tensor stored on a model and averages it
    over channels.

    The callers in this notebook run a forward pass right before calling a
    getter, so the weights correspond to that input.
    """

    def __init__(self, model):
        self.model = model

    def get_temporal_attention(self):
        """
        Returns the averaged temporal attention across channels.

        Assumes model.attention_weights shape: [B, C, 1, T]
        Returns:
            np.ndarray of shape [B, T] → temporal attention per time step
        """
        attn = self.model.attention_weights  # [B, C, 1, T]
        attn = attn.mean(dim=1)              # → [B, 1, T]
        attn = attn.squeeze(1)               # → [B, T]
        # detach(): if the latest forward ran with grad enabled (e.g. a Grad-CAM
        # pass), .numpy() would raise on a tensor that requires grad.
        return attn.detach().cpu().numpy()

    def get_frequency_attention(self):
        """
        Returns the averaged frequency attention across channels.

        Assumes model.attention_weights shape: [B, C, F, 1]
        Returns:
            np.ndarray of shape [B, F] → frequency attention per frequency bin
        """
        attn = self.model.attention_weights  # [B, C, F, 1]
        attn = attn.mean(dim=1)              # → [B, F, 1]
        attn = attn.squeeze(2)               # → [B, F]
        return attn.detach().cpu().numpy()


In [None]:
def visualize_multiple_gradcams_by_layer(selected_sample, model, gradcam_class, cols=4, layer_indices=None):
    """
    Visualizes Grad-CAM overlays from several backbone blocks on one spectrogram.

    The first panel shows the stored spectrogram; each following panel overlays
    the Grad-CAM heatmap obtained by hooking one block of ``model.backbone.blocks``.

    Parameters:
    -----------
    selected_sample : tuple of (tensor, str)
        One (input_tensor, filename) pair, e.g. an element of the list returned
        by ``select_species_samples``; the tensor has no batch dimension.
    model : torch.nn.Module
        The trained model; must expose ``model.backbone.blocks``.
    gradcam_class : type
        GradCAM class, instantiated once per layer as ``gradcam_class(model, layer)``.
    cols : int
        Number of subplot columns.
    layer_indices : iterable of int, optional
        Indices into ``model.backbone.blocks`` to visualize. Defaults to blocks 0-6.
    """
    if layer_indices is None:
        layer_indices = range(7)
    layer_indices = list(layer_indices)

    # Select layers
    layers_to_use = [model.backbone.blocks[idx] for idx in layer_indices]
    layer_names = [f"Block {idx}" for idx in layer_indices]

    input_tensor, filename = selected_sample
    spec_path = os.path.join(cfg.spectrogram_dir, "precomputed_spectrograms", filename)
    spec = np.load(spec_path)
    spec_size = (spec.shape[1], spec.shape[0])  # cv2.resize expects (width, height)

    total = len(layers_to_use) + 1  # +1 for original
    rows = (total + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 4.5 * rows))
    # atleast_1d: with rows == cols == 1 subplots returns a single Axes, not an array.
    axes = np.atleast_1d(axes).flatten()

    # Original
    axes[0].imshow(spec, aspect='auto', cmap='viridis', origin='lower')
    axes[0].set_title(f"Original Spectrogram\n{filename}")
    axes[0].set_xlabel("Time Frames")
    axes[0].set_ylabel("Mel Bands")

    for ax, layer, name in zip(axes[1:], layers_to_use, layer_names):
        grad_cam = gradcam_class(model, layer)
        try:
            heatmap = grad_cam(input_tensor.unsqueeze(0))  # shape: [H, W]
        finally:
            # Detach this layer's hooks when the Grad-CAM implementation supports
            # it, so repeated calls don't keep piling hooks onto the backbone.
            remove_hooks = getattr(grad_cam, "remove", None)
            if callable(remove_hooks):
                remove_hooks()

        ax.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        # aspect='auto' is needed here too: without it this second imshow resets
        # the axes to an equal aspect ratio and distorts the panel.
        im = ax.imshow(cv2.resize(heatmap, spec_size), aspect='auto', alpha=0.4, cmap='jet', origin='lower')
        ax.set_title(f"Grad-CAM: {name}")
        ax.set_xlabel("Time Frames")
        ax.set_ylabel("Mel Bands")

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide grid cells left over when the panel count is not a multiple of cols.
    for ax in axes[total:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
def init(model, conv_layer=-1):
    """
    Build the Grad-CAM and attention helpers for ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``backbone.blocks``.
    conv_layer : int
        Index into ``model.backbone.blocks`` of the block Grad-CAM hooks
        (default: the last block).

    Returns
    -------
    tuple
        (GradCAM hooked on that block, AttentionVisualizer wrapping ``model``).
    """
    hooked_block = model.backbone.blocks[conv_layer]
    return GradCAM(model, hooked_block), AttentionVisualizer(model)
    

In [None]:
def visualize_gradcam_and_time_wise_attention(selected_samples, grad_cam, attention_visualizer, n=10):
    """
    Visualizes Grad-CAM and temporal attention for the first ``n`` selected samples.

    For each sample, three panels are drawn:
    - the stored spectrogram;
    - its Grad-CAM overlay;
    - the model's temporal attention, drawn as vertical strips stretched over the
      whole spectrogram.

    Parameters:
    -----------
    selected_samples : list of (input_tensor, filename)
        Spectrogram tensors (no batch dimension) and filenames, e.g. as returned
        by ``select_species_samples``.
    grad_cam : GradCAM object
        Initialized with the target layer to explain.
    attention_visualizer : AttentionVisualizer object
        Wraps the model whose ``get_temporal_attention()`` weights are shown.
    n : int
        Maximum number of samples to visualize.
    """
    for input_tensor, filename in selected_samples[:n]:
        input_tensor = input_tensor.unsqueeze(0).to(cfg.device)
        spec_path = os.path.join(cfg.spectrogram_dir, "precomputed_spectrograms", filename)
        spec = np.load(spec_path)
        # Pixel-edge extent of the spectrogram image: an overlay drawn with it
        # covers exactly the spectrogram, whatever its own resolution.
        spec_extent = [-0.5, spec.shape[1] - 0.5, -0.5, spec.shape[0] - 0.5]

        # Compute GradCAM, resized onto the spectrogram grid (cv2 takes (width, height))
        heatmap = grad_cam(input_tensor)
        cam = cv2.resize(heatmap, (spec.shape[1], spec.shape[0]))

        # Inference forward pass so the model stores attention weights for this input
        with torch.no_grad():
            _ = attention_visualizer.model(input_tensor)

        # Attention collapsed over channels, clipped at 0 and scaled to [0, 1]
        attn_raw = attention_visualizer.get_temporal_attention()[0]  # [T,]
        attn_raw = np.maximum(attn_raw, 0)
        attn_raw /= attn_raw.max() if attn_raw.max() > 0 else 1

        # One identical row per mel band -> vertical strips once overlaid
        attention_heatmap = np.tile(attn_raw, (spec.shape[0], 1))  # [Freq, Time]

        # Combined Plot
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(22, 5))

        # Original
        ax1.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        ax1.set_title(f"Original Spectrogram\nFile: {filename}")
        ax1.set_xlabel("Time Frames")
        ax1.set_ylabel("Mel Bands")

        # GradCAM
        ax2.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        ax2.imshow(cam, aspect='auto', alpha=0.5, cmap='jet', origin='lower')
        ax2.set_title("Grad-CAM Overlay\nWhich regions most contributed to THIS class decision?")
        ax2.set_xlabel("Time Frames")
        ax2.set_ylabel("Mel Bands")

        # Attention overlay stretched over the full spectrogram. A fixed [0, 5]
        # x-extent would reset the axes to x in [0, 5], i.e. only the first frames.
        ax3.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        ax3.imshow(attention_heatmap, aspect='auto', alpha=0.4, cmap='hot', origin='lower', extent=spec_extent)
        ax3.set_title("Temporal Attention Strips\nWhere is the model focusing internally?")
        ax3.set_xlabel("Time Frames")
        ax3.set_ylabel("Mel Bands")

        plt.tight_layout()
        plt.show()




In [None]:
def visualize_gradcam_and_frequency_wise_attention(selected_samples, grad_cam, attention_visualizer, n=10):
    """
    Visualizes Grad-CAM and frequency-wise Attention Maps for selected samples.

    Parameters:
    -----------
    selected_samples : list of (input_tensor, filename)
        Spectrogram tensors (no batch dimension) and their associated filenames.
    grad_cam : GradCAM object
        Initialized with the correct target layer.
    attention_visualizer : AttentionVisualizer object
        Must wrap a model whose frequency attention is returned by
        ``.get_frequency_attention()``.
    n : int
        Number of samples to visualize.
    """
    for input_tensor, filename in selected_samples[:n]:
        input_tensor = input_tensor.unsqueeze(0).to(cfg.device)
        spec_path = os.path.join(cfg.spectrogram_dir, "precomputed_spectrograms", filename)
        spec = np.load(spec_path)
        # Pixel-edge extent of the spectrogram image, used to map a possibly
        # coarser attention grid onto the whole spectrogram.
        spec_extent = [-0.5, spec.shape[1] - 0.5, -0.5, spec.shape[0] - 0.5]

        # === Grad-CAM ===
        heatmap = grad_cam(input_tensor)  # shape: [H, W]
        cam = cv2.resize(heatmap, (spec.shape[1], spec.shape[0]))

        # === Trigger forward pass to capture attention ===
        with torch.no_grad():
            _ = attention_visualizer.model(input_tensor)

        # === Extract frequency-wise attention weights ===
        attn_raw = attention_visualizer.get_frequency_attention()[0]  # shape: [Freq,]
        attn_raw = np.maximum(attn_raw, 0)
        attn_raw /= attn_raw.max() if attn_raw.max() > 0 else 1

        # Repeat each frequency weight across time (horizontal bands)
        attention_heatmap = np.tile(attn_raw[:, np.newaxis], (1, spec.shape[1]))  # [Freq, Time]

        # === Plot ===
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(22, 5))

        # Original Spectrogram
        ax1.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        ax1.set_title(f"Original Spectrogram\nFile: {filename}")
        ax1.set_xlabel("Time Frames")
        ax1.set_ylabel("Mel Bands")

        # GradCAM
        ax2.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        ax2.imshow(cam, aspect='auto', alpha=0.4, cmap='jet', origin='lower')
        ax2.set_title("Grad-CAM Overlay\nImportant Regions for Prediction")
        ax2.set_xlabel("Time Frames")
        ax2.set_ylabel("Mel Bands")

        # Frequency Attention Overlay. The explicit extent keeps the axes on the
        # full spectrogram; without it, when the attention has fewer bins than
        # there are mel bands, the view would shrink to the lowest bands only.
        ax3.imshow(spec, aspect='auto', cmap='viridis', origin='lower')
        ax3.imshow(attention_heatmap, aspect='auto', alpha=0.5, cmap='hot', origin='lower', extent=spec_extent)
        ax3.set_title("Frequency-wise Attention\nWhere is the model focusing?")
        ax3.set_xlabel("Time Frames")
        ax3.set_ylabel("Mel Bands")

        plt.tight_layout()
        plt.show()


## Predictions against spectrogram

In [None]:
def plot_prediction_overview(spec, framewise_preds, output, filename, index_to_label=None, label_to_species=None, top_k=5, clip_duration=5.0):
    """
    Plots the spectrogram, the framewise (sound-event) predictions and the
    clip-level confidences, and prints the ``top_k`` clip-level predictions.

    Parameters:
    -----------
    spec : np.ndarray
        Spectrogram [mel_bins, time]
    framewise_preds : np.ndarray
        Prediction matrix [num_classes, time]
    output : np.ndarray
        Clip-level scores, drawn as an image with a fixed 0-1 color range, so
        expected to be 2-D, e.g. [num_classes, 1] sigmoid outputs. Flattened
        when ranking the top predictions.
    filename : str
        Name for title
    index_to_label : dict, optional
        Class index → label
    label_to_species : dict, optional
        Label → species name
    top_k : int
        Number of top predictions to print
    clip_duration : float
        Clip length in seconds, used as the x-range of the framewise plot.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(7, 10), gridspec_kw={'height_ratios': [4, 3, 3]})

    # Spectrogram
    ax1.imshow(spec, aspect='auto', origin='lower', cmap='viridis')
    ax1.set_title(filename)
    ax1.set_ylabel("mel bin")
    ax1.set_xlabel("frame")

    # Framewise prediction heatmap (x-axis: 0 to clip_duration seconds)
    extent = [0, clip_duration, 0, framewise_preds.shape[0]]  # [xmin, xmax, ymin, ymax]
    im = ax2.imshow(framewise_preds, aspect='auto', origin='lower', cmap='Blues', extent=extent)
    ax2.set_title("When the sound event is detected")
    ax2.set_ylabel("species")
    ax2.set_xlabel("frame / sec")

    # Colorbar in an appended axis so the main plot keeps its width
    cax2 = make_axes_locatable(ax2).append_axes("right", size="2.5%", pad=0.05)
    plt.colorbar(im, cax=cax2)

    # Clip-level predictions
    im3 = ax3.imshow(output, aspect='auto', origin='lower', cmap='Greens', vmin=0.0, vmax=1)
    ax3.set_title("Confidence in the detection per species")
    ax3.set_ylabel("species")
    ax3.set_xlabel("prediction")

    cax3 = make_axes_locatable(ax3).append_axes("right", size="2.5%", pad=0.05)
    plt.colorbar(im3, cax=cax3)

    # Print the top_k clip-level predictions, highest score first
    scores = np.asarray(output).flatten()
    top_classes = scores.argsort()[::-1][:top_k]
    print("\nTop Predictions:")
    for rank, idx in enumerate(top_classes, start=1):
        label = index_to_label.get(idx, f"Class {idx}") if index_to_label else f"Class {idx}"
        species = label_to_species.get(label, "") if label_to_species else ""
        print(f"{rank}. {label:<7} | Class index: {idx:<3} | Score: {scores[idx]:.3f} | {species}")

    plt.tight_layout()
    plt.show()

def plot_species_predictions(model, samples, n=10):
    """
    Plot framewise and clip-level predictions of ``model`` for the first ``n`` samples.

    Parameters:
    -----------
    model : torch.nn.Module
        Model exposing ``get_framewise_output`` and a forward pass returning
        logits of shape [1, num_classes] for a single-sample batch.
    samples : list of (input_tensor, filename)
        Spectrogram tensors without batch dimension, e.g. from ``select_species_samples``.
    n : int
        Maximum number of samples to plot.

    Uses the notebook-level ``cfg``, ``index_to_label`` and ``label_to_class``.
    """
    for input_tensor, filename in samples[:n]:
        # Cast once so both passes receive the same float input.
        input_tensor = input_tensor.unsqueeze(0).to(cfg.device).float()

        spec_path = os.path.join(cfg.spectrogram_dir, "precomputed_spectrograms", filename)
        spec_np = np.load(spec_path)  # [mel_bins, time]

        # Pure inference: no_grad avoids building autograd graphs for both passes.
        with torch.no_grad():
            framewise = model.get_framewise_output(input_tensor).squeeze(0).cpu().numpy()  # [num_classes, T]
            logits = model(input_tensor)
        outputs = torch.sigmoid(logits).cpu().numpy().transpose(1, 0)  # [num_classes, 1]

        plot_prediction_overview(
            spec=spec_np,
            framewise_preds=framewise,
            output=outputs,
            filename=filename,
            index_to_label=index_to_label,
            label_to_species=label_to_class,
            top_k=2,
        )


## Analysis and performance metrics

In [None]:
# === Load model ===
# NOTE(review): torch.load unpickles the checkpoint file; only load trusted
# checkpoints (consider weights_only=True on recent PyTorch versions).


def _report_incompatible_keys(model_name, incompatible):
    """Print the keys a non-strict load_state_dict skipped, if there are any."""
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"[WARNING] {model_name}: missing keys={incompatible.missing_keys}, "
              f"unexpected keys={incompatible.unexpected_keys}")


model_time = EfficientNetTimeSED(cfg).to(cfg.device)
print(f"[INFO] Loading weights from {cfg.timewise_weights_path}")
raw_state = torch.load(cfg.timewise_weights_path, map_location=cfg.device)
remapped_state = remap_attention_keys(raw_state)
# strict=False tolerates checkpoint/model mismatches; report them instead of
# silently evaluating a partially initialised model.
_report_incompatible_keys("model_time", model_time.load_state_dict(remapped_state, strict=False))
model_time.eval()

model_freq = EfficientNetFrequencySED(cfg).to(cfg.device)
print(f"[INFO] Loading weights from {cfg.freqwise_weights_path}")
raw_state = torch.load(cfg.freqwise_weights_path, map_location=cfg.device)
_report_incompatible_keys("model_freq", model_freq.load_state_dict(raw_state, strict=False))
model_freq.eval()

# === Load mappings ===
label_to_class, label_to_index, index_to_label, _ = get_mappings()
num_classes = len(label_to_index)

# === Prepare dataset and loader ===
metadata = pd.read_csv(cfg.spectrograms_metadata_path)
dataset = PrecomputedSpectrogramDataset(metadata, cfg.spectrogram_dir)
loader = create_dataloader(dataset, cfg, shuffle=False, collate_fn=lambda x: collate_fn(x, mixup=False))


In [None]:
# Evaluate the time-wise model over the whole loader and compute per-class AUC.
time_evaluator = ModelEvaluator(model_time, loader, cfg, label_to_class, index_to_label)
y_true_t, y_pred_t = time_evaluator.run_inference() # (56168, 206), (56168, 206)
df_results_t = time_evaluator.calculate_per_class_auc(y_true_t, y_pred_t)

# === Display table (worst AUC first) as a rich notebook table ===
display(df_results_t.sort_values(by="auc", ascending=True))

In [None]:
# Evaluate the frequency-wise model over the whole loader and compute per-class AUC.
evaluator = ModelEvaluator(model_freq, loader, cfg, label_to_class, index_to_label)
y_true, y_pred = evaluator.run_inference() # (56168, 206), (56168, 206)
df_results = evaluator.calculate_per_class_auc(y_true, y_pred)

# === Display table (worst AUC first) as a rich notebook table ===
display(df_results.sort_values(by="auc", ascending=True))

In [None]:
# Plot the per-class evaluation results of the frequency-wise model.
evaluator.plot_results(df_results)

Let's now analyze the predictions of our model.

In [None]:
# Per-class analysis of the frequency-wise model's results (df_results, y_true, y_pred).
# NOTE(review): these calls are stateful and presumably order-dependent
# (e.g. later metrics may use the thresholds computed first), so keep their order.
analyzer = ModelAnalyzer(df_results, y_true, y_pred)
analyzer.compute_best_thresholds()

# Optional: recompute AUC without filtering
df_with_auc = analyzer.compute_auc()

analyzer.compute_accuracy()

# Summary of best and worst classes
analyzer.summarize_best_worst()

# Add precision-recall
df_with_pr = analyzer.compute_precision_recall()

# Add confusion stats (plotted by the plot_stats cell below)
df_with_confusion = analyzer.compute_confusion_stats()

# Display everything
analyzer.display_full_dataframe()

In [None]:
# Plot the confusion statistics; n=20 presumably limits how many classes are shown — confirm in ModelAnalyzer.plot_stats.
analyzer.plot_stats(df_with_confusion, n=20)

As can be seen, despite the very high AUC, there are several false positives and false negatives, especially in the bird category. Let's now analyze other metrics, such as precision and accuracy.

In [None]:
# Per-species precision/accuracy views of the analyzer results (see ModelAnalyzer for the exact plots).
analyzer.plot_performance_metrics()

Despite the high AUC, several species have extremely low precision and accuracy. Let's analyze the worst one, `turvul`.

In [None]:
# Get details for a specific species: 'turvul', the worst-performing one per the text above
info = analyzer.get_species_info("turvul")

This bird has precision and accuracy extremely close to 0. Let's analyze why.

In [None]:
def print_stats(idx, targets=None, preds=None):
    """
    Print how one class is scored on its positive vs. negative samples.

    Parameters
    ----------
    idx : int
        Column (class) index into the target and prediction matrices.
    targets : np.ndarray, optional
        Targets of shape [num_samples, num_classes]; defaults to the
        notebook-level ``y_true``.
    preds : np.ndarray, optional
        Predicted scores, same shape as ``targets``; defaults to ``y_pred``.

    An empty sample group is reported instead of crashing in np.max/np.min.
    The ROC AUC is only computed when both groups are non-empty, because
    roc_auc_score raises ValueError when only one class is present.
    """
    targets = y_true if targets is None else targets
    preds = y_pred if preds is None else preds

    true_class = targets[:, idx]
    pred_class = preds[:, idx]
    pos_preds = pred_class[true_class == 1]
    neg_preds = pred_class[true_class == 0]

    print("Class id", index_to_label[idx])
    if pos_preds.size:
        print("Max pred for class 1 samples:", np.max(pos_preds))
        print("Min pred for class 1 samples:", np.min(pos_preds))
        print("Mean pred for class 1 samples:", np.mean(pos_preds))
    else:
        print("No class 1 samples for this class.")

    if neg_preds.size:
        print("Max pred for non-class 1 samples:", np.max(neg_preds))
        print("Mean pred for non-class 1 samples:", np.mean(neg_preds))
    else:
        print("No non-class 1 samples for this class.")

    if pos_preds.size and neg_preds.size:
        roc = roc_auc_score(true_class, pred_class)
        print("ROC:", roc)
    else:
        print("ROC: undefined (needs both positive and negative samples)")

# Look the class index up by label instead of hard-coding it (previously 182),
# so the cell stays correct if the label mapping changes.
print_stats(idx=label_to_index["turvul"])

The model is extremely confused about this bird. Even the maximum prediction for its own samples is quite low (0.15), while the highest prediction for a sample that does not belong to this class reaches 0.89. The lowest prediction is as low as 0.0001, so the model seems to have no idea what to look for with this bird. Let's look at its distribution and listen to some recordings.

In [None]:
# Metadata rows for 'turvul'. filter_data prints a message and returns None when nothing matches.
train_df = pd.read_csv(cfg.metadata_path)
sample = filter_data(train_df, primary_label="turvul") # 41778
sample.head(15)

In [None]:
# Listen to four 'turvul' recordings: rows 1, 3, 4 and 5 of the filtered metadata.
for row_pos in (1, 3, 4, 5):
    play_audio(cfg.train_data_path, sample.iloc[row_pos])

As can be heard, these recordings are extremely different from each other. Recognizing the bird seems nearly impossible, even for a human: the sounds differ, and many birds in the background are not flagged as secondary labels. The geographical location differs for almost every sample, and so does the recordist. This is why including metadata would probably have been a good idea.

In [None]:
# One specific recording per 'turvul' example; each call scans the loader.
turvul_samples_rufus, turvul_samples_hiss = (
    select_species_samples(
        loader,
        species="turvul",
        mode="specific",
        num_samples=10,
        target_filename=recording,
    )
    for recording in ('turvul/XC748979.ogg', 'turvul/XC904279.ogg')
)

In [None]:
# Visualize Grad-CAM for each backbone block of the frequency-wise model on the
# turvul/XC748979 example. NOTE(review): [0] raises IndexError if the file was not found (empty list).
visualize_multiple_gradcams_by_layer(turvul_samples_rufus[0], model_freq, GradCAM)

In the spectrogram we can distinguish two different sounds: one at low frequency and a constant one at higher frequency. The model has not yet figured out which of the two it should focus on. Let's now inspect the `hiss` call.

In [None]:
# Visualize Grad-CAM for each backbone block of the frequency-wise model on the
# turvul/XC904279 ('hiss') example. NOTE(review): [0] raises IndexError if the file was not found.
visualize_multiple_gradcams_by_layer(turvul_samples_hiss[0], model_freq, GradCAM)

Here the hiss is clearly visible in the spectrogram; however, there is only one such sample, and the model doesn't seem able to recognize it properly. The fifth Grad-CAM focuses on that area, but the last one seemingly randomly focuses on higher frequencies and the last time frames.

In [None]:
# Set up GradCAM (on model_freq.backbone.blocks[5]) and Attention Visualizer
# NOTE(review): every init() call attaches new hooks to the chosen block and nothing here detaches them.
grad_cam, attention_visualizer = init(model_freq, conv_layer=5)

visualize_gradcam_and_frequency_wise_attention(turvul_samples_hiss, grad_cam, attention_visualizer, n=5)

It seems that in the second to last convolutional layer the network managed to focus on the hiss, but the same understanding is not retained by the frequency attention model.

In [None]:
# Set up GradCAM (on model_time.backbone.blocks[5]) and Attention Visualizer
# NOTE(review): every init() call attaches new hooks to the chosen block and nothing here detaches them.
grad_cam, attention_visualizer = init(model_time, conv_layer=5)


visualize_gradcam_and_time_wise_attention(turvul_samples_hiss, grad_cam, attention_visualizer, n=5)

The temporal attention here spans basically the whole 5 seconds, so it's not as useful.

In [None]:
# Framewise and clip-level predictions of the time-wise model for the 'hiss' example
plot_species_predictions(model_time, turvul_samples_hiss, n=5)

The darkest blue lines correspond to when the sound event is detected, which roughly matches when the hiss is actually present in the spectrogram.

Let's now analyze the best predicted bird:

In [None]:
# Get details for a specific species: 'compau', the best-predicted one per the text above
info = analyzer.get_species_info("compau")

In [None]:
# Look the class index up by label (97 with the current mapping, per the text
# below) instead of hard-coding it, so the cell survives mapping changes.
print_stats(idx=label_to_index["compau"])

The model is much more confident about its predictions for this bird, as the mean prediction of 0.8 shows, but some confusion persists. There are wrong predictions with 0.9 confidence and correct predictions as low as 0.00685, which is not ideal. Let's see why.

In [None]:
# Metadata rows for 'compau'. filter_data prints a message and returns None when nothing matches.
filtered_compau = filter_data(train_df, primary_label="compau") # 41778
filtered_compau.head(15)

In [None]:
# Listen to four 'compau' recordings: rows 1, 3, 10 and 14 of the filtered metadata.
for row_pos in (1, 3, 10, 14):
    play_audio(cfg.train_data_path, filtered_compau.iloc[row_pos])

In [None]:
# One specific recording per 'compau' vocalization type (call, flight call,
# song); each call scans the loader.
compau_samples_call, compau_samples_flight_call, compau_samples_song = (
    select_species_samples(
        loader,
        species="compau",
        mode="specific",
        num_samples=10,
        target_filename=recording,
    )
    for recording in (
        'compau/XC112631.ogg',
        'compau/XC123378.ogg',
        'compau/XC121382.ogg',
    )
)

In [None]:
# Visualize per-block Grad-CAM of the frequency-wise model on the 'call' example (compau/XC112631)
visualize_multiple_gradcams_by_layer(compau_samples_call[0], model_freq, GradCAM)

In [None]:
# Visualize per-block Grad-CAM of the frequency-wise model on the 'flight call' example (compau/XC123378)
visualize_multiple_gradcams_by_layer(compau_samples_flight_call[0], model_freq, GradCAM)

In [None]:
# Visualize per-block Grad-CAM of the frequency-wise model on the 'song' example (compau/XC121382)
visualize_multiple_gradcams_by_layer(compau_samples_song[0], model_freq, GradCAM)

In [None]:
# Set up GradCAM (on model_freq.backbone.blocks[6]) and Attention Visualizer
# NOTE(review): every init() call attaches new hooks to the chosen block and nothing here detaches them.
grad_cam, attention_visualizer = init(model_freq, conv_layer=6)

visualize_gradcam_and_frequency_wise_attention(compau_samples_call, grad_cam, attention_visualizer, n=5)

The frequency attention model doesn't seem to have a clear idea of what to look for. The frequency bands where the call lies hold a high value in the attention heatmap, but apparently not as high as the highest frequencies. The model is overfitting the training dataset.

In [None]:
# Set up GradCAM (on model_time.backbone.blocks[5]) and Attention Visualizer
# NOTE(review): every init() call attaches new hooks to the chosen block and nothing here detaches them.
grad_cam, attention_visualizer = init(model_time, conv_layer=5)

visualize_gradcam_and_time_wise_attention(compau_samples_call, grad_cam, attention_visualizer, n=5)

The temporal model seems to have a better idea of where the sound events happen, even if some confusion persists. The model mostly recognizes where the two sounds are located in time.

In [None]:
# Framewise and clip-level predictions of the time-wise model for the 'call' example
plot_species_predictions(model_time, compau_samples_call, n=5)

The confidence in the prediction for `compau` is quite high. The second graph clearly shows that the species with index 97 has the best-predicted sound event locations, which, of course, coincides with the `compau` predictions.

## Chunks Tracker

In [None]:
class ChunkStatsTracker:
    """
    Tracks per-species and per-collection chunk statistics from precomputed files.

    Expected layout (only the first level of sub-directories is scanned)::

        spectrogram_dir/
            <species_id>/
                <collection>_<anything>.npy

    Attributes:
        spectrogram_dir: Root directory that is scanned.
        class_chunk_counts: species_id -> number of ``.npy`` chunks found.
        collection_counts: collection prefix -> number of ``.npy`` chunks found.
        chunk_metadata: One dict per chunk with keys ``species``, ``collection``
            and ``filename``.
    """

    # Column order of the summary DataFrame. Passing it explicitly means an
    # empty scan still yields a frame with these columns.
    SUMMARY_COLUMNS = ("species", "collection", "filename")

    def __init__(self, spectrogram_dir):
        self.spectrogram_dir = spectrogram_dir
        self._reset()

    def _reset(self):
        """Clear all statistics accumulated by a previous scan."""
        self.class_chunk_counts = defaultdict(int)
        self.collection_counts = defaultdict(int)
        self.chunk_metadata = []

    def scan_precomputed_chunks(self):
        """
        Walk ``spectrogram_dir`` and collect species and collection stats.

        Every call starts from a clean state, so re-running the scan (e.g. when
        a notebook cell is executed twice) does not double-count chunks.
        Directory entries are visited in sorted order, so the collected
        metadata (and any CSV written from it) is deterministic.

        The collection is the part of the filename before the first underscore
        (``CSA_chunk_0.npy`` -> ``CSA``). A filename without an underscore
        yields the whole filename.
        """
        print(f"[INFO] Scanning directory: {self.spectrogram_dir}")
        self._reset()

        for species_id in sorted(os.listdir(self.spectrogram_dir)):
            species_dir = os.path.join(self.spectrogram_dir, species_id)
            # Stray files at the top level are not species folders.
            if not os.path.isdir(species_dir):
                continue

            for file in sorted(os.listdir(species_dir)):
                if not file.endswith(".npy"):
                    continue
                collection = file.split("_")[0]
                self.class_chunk_counts[species_id] += 1
                self.collection_counts[collection] += 1
                self.chunk_metadata.append({
                    "species": species_id,
                    "collection": collection,
                    "filename": file,
                })

        print(f"[INFO] Completed scan: {len(self.chunk_metadata)} chunks found.")

    def get_summary_dataframe(self):
        """
        Return the scanned chunks as a DataFrame.

        Returns:
            pd.DataFrame with one row per chunk and columns ``species``,
            ``collection`` and ``filename``. The columns are present even when
            no chunk was found, so downstream ``groupby`` calls do not raise
            ``KeyError``.
        """
        return pd.DataFrame(self.chunk_metadata, columns=list(self.SUMMARY_COLUMNS))

    def _count_by(self, column):
        """Return a frame with one row per ``column`` value and its ``chunk_count``."""
        return (
            self.get_summary_dataframe()
            .groupby(column)
            .size()
            .reset_index(name="chunk_count")
        )

    def _plot_counts(self, count_df, column, figsize, title, xlabel, rotate_xticks=False):
        """Draw a bar chart of ``chunk_count`` per ``column`` value; skip if there is no data."""
        if count_df.empty:
            print(f"[WARN] No chunks to plot: {title}")
            return
        fig, ax = plt.subplots(figsize=figsize)
        sns.barplot(x=column, y="chunk_count", data=count_df, ax=ax)
        if rotate_xticks:
            # Species IDs are numerous; vertical labels keep them legible.
            ax.tick_params(axis="x", labelrotation=90)
        ax.set(title=title, xlabel=xlabel, ylabel="Chunk Count")
        fig.tight_layout()
        plt.show()

    def plot_species_distribution(self):
        """Bar chart of chunk counts per species, sorted from most to fewest chunks."""
        count_df = self._count_by("species").sort_values("chunk_count", ascending=False)
        self._plot_counts(
            count_df, "species", (16, 6),
            "Number of Chunks per Species", "Species ID",
            rotate_xticks=True,
        )

    def plot_collection_distribution(self):
        """Bar chart of chunk counts per collection."""
        count_df = self._count_by("collection")
        self._plot_counts(
            count_df, "collection", (6, 4),
            "Number of Chunks per Collection", "Collection",
        )

    def save_to_csv(self, path):
        """
        Write the chunk summary to ``path`` as CSV (without the index column).

        Missing parent directories are created. A bare filename is written to
        the current working directory; ``os.makedirs("")`` would raise there,
        so it is only called when ``path`` has a directory component.
        """
        df = self.get_summary_dataframe()
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        df.to_csv(path, index=False)
        print(f"[INFO] Chunk summary saved to {path}")


In [None]:
# Chunk inventory for the eda-birdclef2025 dataset. Note that the folder name
# in this path is spelled "precomputed_spectograms", unlike the other datasets.
spect_dir = os.path.join(
    "/kaggle/input/eda-birdclef2025",
    "precomputed_spectograms",
)
tracker = ChunkStatsTracker(spect_dir)
tracker.scan_precomputed_chunks()

# Per-species and per-collection bar charts, then persist the summary.
tracker.plot_species_distribution()
tracker.plot_collection_distribution()
tracker.save_to_csv(
    os.path.join(cfg.OUTPUT_DIR, "chunk_summary_0.csv")
)


In [None]:
# Chunk inventory for the precomputing-spectrograms dataset.
spect_dir = os.path.join("/kaggle/input/precomputing-spectrograms", "precomputed_spectrograms")
tracker = ChunkStatsTracker(spect_dir)
tracker.scan_precomputed_chunks()
tracker.plot_species_distribution()
tracker.plot_collection_distribution()
tracker.save_to_csv(os.path.join(cfg.OUTPUT_DIR, "chunk_summary_1.csv"))


In [None]:
# Chunk inventory for the precomputing-spectrograms2 dataset.
spect_dir = os.path.join("/kaggle/input/precomputing-spectrograms2", "precomputed_spectrograms")
tracker = ChunkStatsTracker(spect_dir)
tracker.scan_precomputed_chunks()
tracker.plot_species_distribution()
tracker.plot_collection_distribution()
tracker.save_to_csv(os.path.join(cfg.OUTPUT_DIR, "chunk_summary_2.csv"))


In [None]:
# Chunk inventory for the precomputing-spectrograms3 dataset.
spect_dir = os.path.join("/kaggle/input/precomputing-spectrograms3", "precomputed_spectrograms")
tracker = ChunkStatsTracker(spect_dir)
tracker.scan_precomputed_chunks()
tracker.plot_species_distribution()
tracker.plot_collection_distribution()
tracker.save_to_csv(os.path.join(cfg.OUTPUT_DIR, "chunk_summary_3.csv"))
