# **BirdCLEF 2025 Inference Notebook**
This notebook runs inference on the BirdCLEF 2025 test soundscapes and writes a submission file. It supports both single-model inference and ensembling of multiple checkpoints. The preprocessing and training steps are covered in the following notebooks:

- [Transforming Audio-to-Mel Spec. | BirdCLEF'25](https://www.kaggle.com/code/kadircandrisolu/transforming-audio-to-mel-spec-birdclef-25)  
- [EfficientNet B0 Pytorch [Train] | BirdCLEF'25](https://www.kaggle.com/code/kadircandrisolu/efficientnet-b0-pytorch-train-birdclef-25)

**Features**
- Audio preprocessing (log-mel spectrograms)
- Single-model or multi-checkpoint ensemble inference
- Test-Time Augmentation (TTA)

In [65]:
import gc
import logging
import math
import os
import random
import re
import time
import warnings
from pathlib import Path

import cv2
import librosa
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

In [None]:
def set_seed(seed=42):
    """Seed every RNG this notebook may touch so runs are repeatable.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    every CUDA device), and switches cuDNN to deterministic, non-benchmarking
    kernels.

    Note: ``PYTHONHASHSEED`` is only read at interpreter start-up, so setting it
    here influences child processes launched afterwards, not the hash
    randomisation of the already-running kernel.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all GPUs, not just the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
class CFG:
    """Static inference configuration; every setting is a class attribute."""

    seed = 42

    # ---- Input data and checkpoint locations -------------------------------
    test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'
    submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'
    taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'
    model_path = '/kaggle/input/birdclef25-efficientnet-pseudlabled/pytorch/primary-only-thr-0.9/3'

    # ---- Audio --------------------------------------------------------------
    FS = 32_000             # sample rate handed to librosa.load
    WINDOW_SIZE = 5         # seconds per scored segment (kept an int)
    TARGET_DURATION = 5.0   # not referenced in the visible cells

    # ---- Mel spectrogram ----------------------------------------------------
    N_FFT = 1_024
    HOP_LENGTH = 512
    N_MELS = 128
    FMIN = 50
    FMAX = 14_000
    TARGET_SHAPE = (256, 256)

    # ---- Model --------------------------------------------------------------
    model_name = 'efficientnet_b0'
    in_channels = 1
    device = 'cpu'
    pretrained = False

    # ---- Inference ----------------------------------------------------------
    batch_size = 16
    use_tta = False
    tta_count = 3
    threshold = 0.5

    # ---- Must match the training notebook's model construction -------------
    dropout_rate = 0.2
    drop_path_rate = 0.2
    projection_dim = 0      # 0 -> plain linear head (training default)

    # ---- Checkpoint selection -----------------------------------------------
    use_specific_folds = False  # False: use every checkpoint that is found
    folds = [0, 1]              # consulted only when use_specific_folds is True

    # ---- Debugging ----------------------------------------------------------
    debug = False
    debug_count = 3
    debug_state_dict = False    # True: print missing/unexpected state-dict keys

# One shared configuration instance; seed all RNGs as soon as it exists.
cfg = CFG()
set_seed(seed=cfg.seed)

In [67]:
print(f"Using device: {cfg.device}")
print("Loading taxonomy data...")
taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
# The order of this list defines the mapping from model output index to
# submission column (create_submission pairs pred[i] with species_ids[i]).
species_ids = taxonomy_df.loc[:, 'primary_label'].tolist()
num_classes = len(species_ids)
print(f"Number of classes: {num_classes}")

Using device: cpu
Loading taxonomy data...
Number of classes: 206


In [None]:
class BirdCLEFModel(nn.Module):
    """timm backbone with its own classification head producing per-class logits.

    The backbone's built-in classifier is replaced by ``nn.Identity`` and a new
    head with ``num_classes`` outputs is attached. Attribute names (``backbone``,
    ``classifier``, ``projection``) determine the state-dict keys, so they must
    stay in sync with the checkpoints written by the training notebook.

    Args:
        cfg: configuration object. Reads ``model_name``, ``pretrained``,
            ``in_channels``, ``dropout_rate`` and optionally ``drop_path_rate``
            (default 0.2), ``projection_dim`` (default 0 = plain linear head),
            ``mixup_alpha`` (default 0.5) and ``cutmix_alpha`` (default 1.0).
        num_classes: number of output logits.
    """

    def __init__(self, cfg, num_classes):
        super().__init__()
        self.cfg = cfg

        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=cfg.pretrained,
            in_chans=cfg.in_channels,
            drop_rate=cfg.dropout_rate,
            drop_path_rate=getattr(cfg, 'drop_path_rate', 0.2),
        )

        # Remove the backbone's own classifier and remember its input width.
        if 'efficientnet' in cfg.model_name:
            backbone_out = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif 'convnext' in cfg.model_name:
            backbone_out = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()
        elif 'resnet' in cfg.model_name:
            backbone_out = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            backbone_out = self.backbone.get_classifier().in_features
            # global_pool='' -> forward() below does the pooling when the
            # backbone returns a 4-D feature map.
            self.backbone.reset_classifier(0, '')

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.feat_dim = backbone_out

        projection_dim = getattr(cfg, 'projection_dim', 0)
        if projection_dim > 0:
            # Registered under both names (projection.* and classifier.* keys)
            # to stay compatible with checkpoints saved from this layout.
            self.projection = nn.Sequential(
                nn.Linear(backbone_out, projection_dim),
                nn.BatchNorm1d(projection_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
                nn.Linear(projection_dim, num_classes)
            )
            self.classifier = self.projection
        else:
            self.classifier = nn.Linear(backbone_out, num_classes)

        # Mixup / CutMix (training-time helpers). The Beta-distribution alphas
        # used to be read but never assigned, so both helpers raised
        # AttributeError; they now default to 0.5 / 1.0 unless cfg overrides.
        self.mixup_enabled = False
        self.cutmix_enabled = False
        self.mixup_alpha = getattr(cfg, 'mixup_alpha', 0.5)
        self.cutmix_alpha = getattr(cfg, 'cutmix_alpha', 1.0)

    def forward(self, x, targets=None):
        """Return logits for batch ``x``.

        ``targets`` is accepted for training-loop compatibility and ignored.
        A dict output from the backbone is unwrapped via its ``'features'``
        entry; a 4-D feature map is average-pooled and flattened first.
        """
        features = self.backbone(x)

        if isinstance(features, dict):
            features = features['features']

        if features.dim() == 4:
            features = self.pooling(features).flatten(1)

        return self.classifier(features)

    def mixup_data(self, x, targets):
        """Blend each sample with a random partner from the same batch.

        Returns ``(mixed_x, targets, shuffled_targets, lam)`` where ``lam`` is
        drawn from Beta(mixup_alpha, mixup_alpha).
        """
        batch_size = x.size(0)
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        indices = torch.randperm(batch_size).to(x.device, non_blocking=True)
        mixed_x = lam * x + (1 - lam) * x[indices]

        return mixed_x, targets, targets[indices], lam

    def cutmix_data(self, x, targets):
        """Paste a random rectangle from a shuffled partner into each sample.

        Returns ``(x_mixed, targets, shuffled_targets, lam)``; ``lam`` is
        recomputed as the fraction of each sample left untouched.
        """
        batch_size = x.size(0)
        lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)

        indices = torch.randperm(batch_size).to(x.device)

        # Box over the last two tensor dims (named W, H here).
        W, H = x.size(2), x.size(3)
        cut_ratio = np.sqrt(1. - lam)
        cut_w = np.int_(W * cut_ratio)
        cut_h = np.int_(H * cut_ratio)

        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        x_mixed = x.clone()
        x_mixed[:, :, bbx1:bbx2, bby1:bby2] = x[indices, :, bbx1:bbx2, bby1:bby2]

        # Clipping can shrink the box, so use the actual pasted area.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))

        return x_mixed, targets, targets[indices], lam

In [70]:
def audio2melspec(audio_data, cfg):
    """Convert a waveform into a min-max normalised log-mel spectrogram.

    NaN samples are replaced by the mean of the non-NaN samples. If every
    sample is NaN there is no mean to use (``np.nanmean`` returns NaN and the
    NaNs would survive), so 0.0 (silence) is used instead.

    Args:
        audio_data: numpy array of audio samples.
        cfg: reads FS, N_FFT, HOP_LENGTH, N_MELS, FMIN, FMAX.

    Returns:
        2-D array scaled to roughly [0, 1] — presumably (N_MELS, frames), the
        layout librosa.feature.melspectrogram returns.
    """
    nan_mask = np.isnan(audio_data)
    if nan_mask.any():
        fill_value = 0.0 if nan_mask.all() else float(np.nanmean(audio_data))
        audio_data = np.nan_to_num(audio_data, nan=fill_value)

    mel_spec = librosa.feature.melspectrogram(
        y=audio_data,
        sr=cfg.FS,
        n_fft=cfg.N_FFT,
        hop_length=cfg.HOP_LENGTH,
        n_mels=cfg.N_MELS,
        fmin=cfg.FMIN,
        fmax=cfg.FMAX,
        power=2.0
    )

    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    # Epsilon keeps a constant spectrogram (max == min) from dividing by zero.
    mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)

    return mel_spec_norm

def process_audio_segment(audio_data, cfg):
    """Turn one audio window into a float32 mel image of shape ``cfg.TARGET_SHAPE``.

    Windows shorter than ``FS * WINDOW_SIZE`` samples are right-padded with
    zeros before the spectrogram is computed.

    ``TARGET_SHAPE`` is read as (rows, cols), the same convention as the
    ``mel_spec.shape`` comparison. ``cv2.resize`` takes ``dsize`` as
    (width, height), so the pair is reversed for that call. With the default
    square (256, 256) target this is identical to passing TARGET_SHAPE
    directly, but a non-square target now really yields TARGET_SHAPE.

    Args:
        audio_data: 1-D numpy array of samples.
        cfg: reads FS, WINDOW_SIZE, TARGET_SHAPE (plus the mel parameters used
            by audio2melspec).

    Returns:
        float32 numpy array.
    """
    target_len = int(cfg.FS * cfg.WINDOW_SIZE)
    if len(audio_data) < target_len:
        audio_data = np.pad(audio_data, (0, target_len - len(audio_data)), mode='constant')

    mel_spec = audio2melspec(audio_data, cfg)

    target_rows, target_cols = cfg.TARGET_SHAPE
    if mel_spec.shape != (target_rows, target_cols):
        mel_spec = cv2.resize(mel_spec, (target_cols, target_rows), interpolation=cv2.INTER_LINEAR)

    return mel_spec.astype(np.float32)

In [None]:
def find_model_files(cfg):
    """Return every ``*.pth`` file under ``cfg.model_path``, searched recursively.

    ``Path.glob`` yields entries in arbitrary filesystem order, so the paths are
    sorted. This makes the ensemble order, and ``model_files[0]`` used by the
    debugging helper, reproducible across runs.

    Returns:
        Sorted list of path strings (empty if nothing matches).
    """
    return sorted(str(path) for path in Path(cfg.model_path).glob('**/*.pth'))

def _filter_fold_files(model_files, folds):
    """Keep the checkpoint paths that belong to one of the requested folds.

    A path matches fold ``k`` when it contains ``fold<k>`` not followed by
    another digit, so ``fold1`` no longer also selects ``fold10``. Paths are
    grouped in ``folds`` order and each appears at most once.
    """
    selected = []
    for fold in folds:
        pattern = re.compile(rf"fold{fold}(?!\d)")
        for path in model_files:
            if pattern.search(path) and path not in selected:
                selected.append(path)
    return selected


def _extract_state_dict(checkpoint):
    """Return the weight mapping stored in ``checkpoint``.

    Uses ``checkpoint['model_state_dict']`` or ``checkpoint['state_dict']`` when
    present, otherwise the checkpoint itself. If any key carries the
    DataParallel ``module.`` prefix, that *leading* prefix is removed. Unlike a
    blanket ``str.replace``, this leaves inner names such as ``submodule.`` intact.
    """
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict.keys()):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    return state_dict


def load_models(cfg, num_classes):
    """Load every checkpoint found under ``cfg.model_path`` for ensembling.

    Each checkpoint is loaded onto ``cfg.device`` into a fresh
    ``BirdCLEFModel(cfg, num_classes)`` and switched to eval mode. If the key
    sets differ, the matching weights are copied and the rest keep their fresh
    initialisation (a warning is printed). A checkpoint that fails to load is
    reported with its traceback and skipped.

    Args:
        cfg: reads model_path, use_specific_folds, folds, device, debug_state_dict.
        num_classes: number of output classes for the model head.

    Returns:
        List of loaded models (possibly empty).
    """
    import traceback

    models = []

    model_files = find_model_files(cfg)
    if not model_files:
        print(f"Warning: No model files found under {cfg.model_path}!")
        return models

    print(f"Found a total of {len(model_files)} model files.")

    if cfg.use_specific_folds:
        model_files = _filter_fold_files(model_files, cfg.folds)
        print(f"Using {len(model_files)} model files for the specified folds ({cfg.folds}).")

    for model_path in model_files:
        try:
            print(f"Loading model: {model_path}")
            checkpoint = torch.load(model_path, map_location=cfg.device)

            model = BirdCLEFModel(cfg, num_classes)
            state_dict = _extract_state_dict(checkpoint)

            model_dict = model.state_dict()
            missing_keys = [k for k in model_dict.keys() if k not in state_dict]
            unexpected_keys = [k for k in state_dict.keys() if k not in model_dict]

            if missing_keys and cfg.debug_state_dict:
                print(f"Missing keys: {missing_keys}")
            if unexpected_keys and cfg.debug_state_dict:
                print(f"Unexpected keys: {unexpected_keys}")

            if missing_keys or unexpected_keys:
                # Best effort: copy the weights whose names match, keep the
                # freshly initialised values for everything else.
                print(f"Warning: {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
                filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
                model_dict.update(filtered_state_dict)
                model.load_state_dict(model_dict, strict=False)
                print("Loaded model with partial state dict")
            else:
                model.load_state_dict(state_dict)

            model = model.to(cfg.device)
            model.eval()

            models.append(model)
        except Exception as e:
            # Deliberately best-effort: report and continue with the next file.
            print(f"Error loading model {model_path}: {e}")
            traceback.print_exc()

    return models

def _ensemble_probabilities(models, specs, cfg):
    """Average sigmoid outputs of ``models`` over a stack of spectrograms.

    Args:
        models: eval-mode modules; each is fed a (B, 1, H, W) batch and is
            expected to return (B, C) logits.
        specs: contiguous float32 array of shape (N, H, W).
        cfg: reads ``device`` and ``batch_size``.

    Returns:
        float32 numpy array of shape (N, C).
    """
    batch_size = max(1, int(getattr(cfg, 'batch_size', 1) or 1))
    chunks = []
    with torch.no_grad():
        for start in range(0, len(specs), batch_size):
            batch = torch.from_numpy(specs[start:start + batch_size]).unsqueeze(1).to(cfg.device)
            probs = torch.stack([torch.sigmoid(model(batch)) for model in models]).mean(dim=0)
            chunks.append(probs.cpu().numpy())
    return np.concatenate(chunks, axis=0)


def predict_on_spectrogram(audio_path, models, cfg, species_ids):
    """Predict class probabilities for every complete window of one soundscape.

    The file is split into consecutive ``WINDOW_SIZE``-second windows; a
    trailing partial window is ignored. With ``cfg.use_tta`` each window is
    scored as ``tta_count`` augmented views (see apply_tta) and the view
    predictions are averaged. Predictions are always averaged over ``models``.
    All views are run in batches of ``cfg.batch_size``.

    Args:
        audio_path: path to the audio file.
        models: non-empty list of loaded models.
        cfg: reads FS, WINDOW_SIZE, use_tta, tta_count, batch_size, device.
        species_ids: unused; kept for interface compatibility.

    Returns:
        ``(row_ids, predictions)`` of equal length. ``row_ids[i]`` is
        ``"<soundscape>_<end second>"`` and ``predictions[i]`` is a 1-D float32
        array of class probabilities. Both are empty if processing fails.
        Previously a mid-file failure left an extra row id without a
        prediction, which broke submission assembly.
    """
    soundscape_id = Path(audio_path).stem

    try:
        print(f"Processing {soundscape_id}")
        audio_data, _ = librosa.load(audio_path, sr=cfg.FS)

        segment_len = int(cfg.FS * cfg.WINDOW_SIZE)
        total_segments = int(len(audio_data) / segment_len)
        if total_segments == 0:
            return [], []

        n_views = max(1, cfg.tta_count) if cfg.use_tta else 1
        views = []
        for segment_idx in range(total_segments):
            start = segment_idx * segment_len
            mel_spec = process_audio_segment(audio_data[start:start + segment_len], cfg)
            if cfg.use_tta:
                views.extend(apply_tta(mel_spec, tta_idx) for tta_idx in range(n_views))
            else:
                views.append(mel_spec)

        # np.stack copies, so flipped (negative-stride) TTA views become
        # contiguous and torch can consume them.
        specs = np.stack(views).astype(np.float32, copy=False)
        probs = _ensemble_probabilities(models, specs, cfg)
        # (segments * views, C) -> (segments, C): average each segment's views.
        segment_probs = probs.reshape(total_segments, n_views, -1).mean(axis=1)

        row_ids = [f"{soundscape_id}_{(i + 1) * cfg.WINDOW_SIZE}" for i in range(total_segments)]
        return row_ids, list(segment_probs)

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return [], []

In [None]:
def apply_tta(spec, tta_idx):
    """Return test-time-augmentation variant ``tta_idx`` of a 2-D spectrogram.

    * 0 -> unchanged
    * 1 -> reversed along axis 1 (time, assuming the (mels, frames) layout)
    * 2 -> reversed along axis 0 (frequency, under the same assumption)
    * anything else -> unchanged

    These are reversals (flips), not shifts. Flipped variants are returned as
    C-contiguous copies because ``np.flip`` produces a negative-stride view,
    which ``torch.tensor`` / ``torch.from_numpy`` reject.
    """
    if tta_idx == 1:
        return np.ascontiguousarray(np.flip(spec, axis=1))
    if tta_idx == 2:
        return np.ascontiguousarray(np.flip(spec, axis=0))
    return spec

def run_inference(cfg, models, species_ids):
    """Predict every ``*.ogg`` soundscape in ``cfg.test_soundscapes``.

    Files are processed in sorted path order. ``Path.glob`` order is arbitrary,
    so without sorting, the debug subset (``cfg.debug_count`` files) and the row
    order of the submission could change between runs.

    Returns:
        ``(row_ids, predictions)`` concatenated over all files.
    """
    test_files = sorted(Path(cfg.test_soundscapes).glob('*.ogg'))

    if cfg.debug:
        print(f"Debug mode enabled, using only {cfg.debug_count} files")
        test_files = test_files[:cfg.debug_count]

    print(f"Found {len(test_files)} test soundscapes")

    all_row_ids = []
    all_predictions = []

    for audio_path in tqdm(test_files):
        row_ids, predictions = predict_on_spectrogram(str(audio_path), models, cfg, species_ids)
        all_row_ids.extend(row_ids)
        all_predictions.extend(predictions)

    return all_row_ids, all_predictions

def create_submission(row_ids, predictions, species_ids, cfg):
    """Assemble the submission table in the sample submission's column layout.

    ``predictions[i][j]`` becomes the value for row ``row_ids[i]`` and column
    ``species_ids[j]``. Species columns present in ``cfg.submission_csv`` but
    absent from ``species_ids`` are added and filled with 0.0. Columns are
    ordered as in the sample file, and ``row_id`` is returned as a regular
    first column.
    """
    print("Creating submission dataframe...")

    columns = {'row_id': row_ids}
    columns.update(
        (species, [pred[idx] for pred in predictions])
        for idx, species in enumerate(species_ids)
    )
    preds_df = pd.DataFrame(columns).set_index('row_id')

    template = pd.read_csv(cfg.submission_csv, index_col='row_id')

    absent = set(template.columns) - set(preds_df.columns)
    if absent:
        print(f"Warning: Missing {len(absent)} species columns in submission")
        for col in absent:
            preds_df[col] = 0.0

    return preds_df[template.columns].reset_index()

In [None]:
def main():
    """Load the models, predict all test soundscapes and write submission.csv.

    Uses the module-level ``cfg``, ``num_classes`` and ``species_ids``. Returns
    early, without writing a file, when no model could be loaded.
    """
    t0 = time.time()
    print("Starting BirdCLEF-2025 inference...")
    tta_variations = cfg.tta_count if cfg.use_tta else 0
    print(f"TTA enabled: {cfg.use_tta} (variations: {tta_variations})")
    print(f"Using model architecture: {cfg.model_name}")

    models = load_models(cfg, num_classes)
    if not models:
        print("No models found! Please check model paths.")
        return

    usage = 'Single model' if len(models) == 1 else f'Ensemble of {len(models)} models'
    print(f"Model usage: {usage}")

    row_ids, predictions = run_inference(cfg, models, species_ids)
    submission_df = create_submission(row_ids, predictions, species_ids, cfg)

    out_path = 'submission.csv'
    submission_df.to_csv(out_path, index=False)
    print(f"Submission saved to {out_path}")

    elapsed_min = (time.time() - t0) / 60
    print(f"Inference completed in {elapsed_min:.2f} minutes")

In [None]:
def debug_model_compatibility():
    """Print a diagnostic comparison between the configured model and a checkpoint.

    Builds a fresh ``BirdCLEFModel`` from the module-level ``cfg`` and
    ``num_classes`` and prints its first and last five state-dict entries. It
    then loads the first file returned by ``find_model_files`` and reports
    missing and unexpected keys, plus naive name-based mapping guesses.
    Checkpoint unwrapping matches load_models: a ``model_state_dict`` /
    ``state_dict`` entry if present, otherwise the dict itself, with a leading
    ``module.`` prefix stripped. Load errors are printed, not raised.
    """
    import traceback

    print("\n=== Debugging Model Compatibility ===")

    print(f"Creating a test model with the current configuration...")
    test_model = BirdCLEFModel(cfg, num_classes)
    test_state_dict = test_model.state_dict()
    n_entries = len(test_state_dict)

    print(f"Model has {n_entries} parameters")

    print("\nSample parameter shapes from current model:")
    for i, (name, param) in enumerate(test_state_dict.items()):
        if i < 5 or i >= n_entries - 5:  # first 5 and last 5 entries
            print(f"{name}: {param.shape}")
        if i == 5 and n_entries > 10:
            print("...")

    model_files = find_model_files(cfg)
    if model_files:
        sample_file = model_files[0]
        print(f"\nAttempting to load a sample model: {sample_file}")

        try:
            checkpoint = torch.load(sample_file, map_location=cfg.device)

            if isinstance(checkpoint, dict):
                print(f"Checkpoint keys: {list(checkpoint.keys())}")

                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                else:
                    # A bare dict is the state dict itself (as in load_models).
                    state_dict = checkpoint
                    print("No wrapper key found; treating the checkpoint itself as the state dict")

                prefix = 'module.'
                state_dict = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()
                }

                print(f"State dict has {len(state_dict)} parameters")

                missing = [k for k in test_state_dict.keys() if k not in state_dict]
                unexpected = [k for k in state_dict.keys() if k not in test_state_dict]

                print(f"Missing keys: {len(missing)}")
                print(f"Unexpected keys: {len(unexpected)}")

                if missing:
                    print("\nSample missing keys:")
                    for k in missing[:5]:
                        print(f"  {k}")

                if unexpected:
                    print("\nSample unexpected keys:")
                    for k in unexpected[:5]:
                        print(f"  {k}")

                if unexpected and missing:
                    print("\nAttempting to map keys...")
                    successful_maps = 0
                    for unexp_key in unexpected[:10]:
                        for miss_key in missing:
                            # Heuristic: same final component (weight, bias, ...).
                            if unexp_key.split('.')[-1] == miss_key.split('.')[-1]:
                                print(f"Possible mapping: {unexp_key} -> {miss_key}")
                                successful_maps += 1
                                break

                    if successful_maps == 0:
                        print("No obvious mappings found. Model architectures may be different.")
            else:
                print(f"Checkpoint is not a dictionary (got {type(checkpoint).__name__}); "
                      "it may be a pickled model object rather than a state dict")

        except Exception as e:
            print(f"Error loading sample model: {e}")
            traceback.print_exc()

    print("\n=== End of Compatibility Debugging ===")

In [None]:
if __name__ == "__main__":
    # Optional state-dict diagnostics before the real run.
    run_compat_check = cfg.debug_state_dict
    if run_compat_check:
        debug_model_compatibility()

    main()

Starting BirdCLEF-2025 inference...
TTA enabled: False (variations: 0)
Found a total of 1 model files.
Loading model: /kaggle/input/birdclef25-efficientnet-pseudlabled/pytorch/primary-only-thr-0.9/3/model_20250501_131702_efficientnet_b0_fold0.pth
Error loading model /kaggle/input/birdclef25-efficientnet-pseudlabled/pytorch/primary-only-thr-0.9/3/model_20250501_131702_efficientnet_b0_fold0.pth: BirdCLEFModel.__init__() takes 2 positional arguments but 3 were given
No models found! Please check model paths.
