# **BirdCLEF 2025 Training Notebook**

This is a baseline training pipeline for BirdCLEF 2025 using EfficientNet-B0 with PyTorch and Timm (for the pretrained EfficientNet). The inference and preprocessing notebooks are available at the following links:

- [EfficientNet B0 Pytorch [Inference] | BirdCLEF'25](https://www.kaggle.com/code/kadircandrisolu/efficientnet-b0-pytorch-inference-birdclef-25)
- [Transforming Audio-to-Mel Spec. | BirdCLEF'25](https://www.kaggle.com/code/kadircandrisolu/transforming-audio-to-mel-spec-birdclef-25)

Note that `debug = False` in the configuration below, so the notebook runs the full 10 epochs. Setting `debug = True` enables Debug Mode, which trains for only 2 epochs on a random subset of at most 1,000 samples. The [weight](https://www.kaggle.com/datasets/kadircandrisolu/birdclef25-effnetb0-starter-weight) I used in the inference notebook was obtained after 10 epochs of training.
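
The debug switch has one subtlety. Inside `CFG`, `T_max = epochs` is evaluated only once, when the class body runs. Overriding `epochs` on the instance therefore leaves `T_max` at 10 unless it is reset too. The sketch below shows both behaviors. It uses a hypothetical, trimmed-down `MiniCFG`, which is not part of the notebook.

```python
class MiniCFG:
    """Hypothetical, trimmed-down stand-in for CFG (illustration only)."""
    debug = True
    epochs = 10
    T_max = epochs  # evaluated once, when the class body runs

    def update_debug_settings(self, sync_t_max=True):
        if self.debug:
            self.epochs = 2  # instance attribute shadows the class attribute
            if sync_t_max:
                # Keep the cosine schedule length tied to the shortened run
                self.T_max = self.epochs


stale = MiniCFG()
stale.update_debug_settings(sync_t_max=False)
print(stale.epochs, stale.T_max)    # 2 10: the schedule still assumes 10 epochs

synced = MiniCFG()
synced.update_debug_settings()
print(synced.epochs, synced.T_max)  # 2 2
```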

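**Features**
* Implemented with PyTorch and Timm
* Flexible audio processing with both pre-computed and on-the-fly mel spectrograms
* Stratified k-fold cross-validation (e.g. 5 folds via `n_fold`) with per-fold checkpoints for ensembling; this configuration uses `n_fold = 1`, which trains and validates on the full dataset
* Pseudo-labeling of the unlabeled train soundscapes
* Mixup training for improved generalization
* Spectrogram augmentations (time/frequency masking, brightness/contrast adjustment)
* AdamW optimizer with Cosine Annealing LR scheduling
* Selectable loss: BCEWithLogitsLoss, Focal Loss, or a weighted combination of both
* Debug mode for quick experimentation with smaller datasets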
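
Cosine annealing has a simple closed form: `eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2`. The sketch below evaluates it in plain Python with this notebook's values (`lr = 5e-4`, `min_lr = 1e-6`, `T_max = 10`). It assumes that `get_scheduler`, which is not shown, builds `torch.optim.lr_scheduler.CosineAnnealingLR` with `T_max=cfg.T_max` and `eta_min=cfg.min_lr`. `cosine_annealing_lr` is a hypothetical helper.

```python
import math


def cosine_annealing_lr(epoch, base_lr=5e-4, min_lr=1e-6, t_max=10):
    """Closed-form cosine annealing (no restarts), the schedule that
    torch.optim.lr_scheduler.CosineAnnealingLR follows when stepped once per epoch."""
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * epoch / t_max)) / 2


schedule = [cosine_annealing_lr(e) for e in range(11)]
for e, lr in enumerate(schedule):
    print(f"epoch {e:2d}: lr = {lr:.2e}")
```

The loop calls `scheduler.step()` once per epoch, so a 10-epoch run trains with entries 0 to 9. The minimum of `1e-6` is only reached after the last epoch.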

**Pre-computed Spectrograms**

For faster training, you can use pre-computed mel spectrograms from [this dataset](https://www.kaggle.com/datasets/kadircandrisolu/birdclef25-mel-spectrograms) by setting `LOAD_DATA = True` (the default in the configuration below).
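
The pre-computed file is a pickled Python dictionary. It maps each `samplename` (`<primary_label>-<file stem>`) to its spectrogram, which is why `run_training` loads it with `np.load(..., allow_pickle=True).item()`. The round trip below shows the mechanics using a hypothetical two-entry toy dictionary in a temporary directory.

```python
import os
import tempfile

import numpy as np

# Hypothetical toy cache: samplename -> (n_mels, time) float32 spectrogram
cache = {
    "species_a-XC000001": np.random.rand(256, 256).astype(np.float32),
    "species_b-XC000002": np.random.rand(256, 256).astype(np.float32),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_melspec.npy")
    np.save(path, cache)  # a dict is stored as a pickled 0-d object array
    # allow_pickle=True is required for object arrays; .item() unwraps the dict
    loaded = np.load(path, allow_pickle=True).item()

print(sorted(loaded))
print(loaded["species_a-XC000001"].shape, loaded["species_a-XC000001"].dtype)
```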

## Libraries

```python
# Basic imports
import numpy as np, pandas as pd, math, os, random, warnings, json, datetime
from tqdm.auto import tqdm

# Specific imports
import logging, gc, time, cv2

# Audio processing imports
import librosa

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

# Other ML imports
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
import timm

# Custom (local) modules
from processing import audio2melspec, process_audio_file, generate_spectrograms
from utilities import set_seed, collate_fn
from training_utilities import get_optimizer, get_scheduler, get_criterion

# Suppress warnings and set logging level
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
```

## Configuration

```python
class CFG:

    seed = 42
    debug = False
    num_workers = 4

    OUTPUT_DIR = 'output/'
    train_datadir = 'birdclef-2025/train_audio'
    train_csv = 'birdclef-2025/train.csv'
    train_soundscapes = 'birdclef-2025/train_soundscapes'
    test_soundscapes = 'birdclef-2025/test_soundscapes'
    submission_csv = 'birdclef-2025/sample_submission.csv'
    taxonomy_csv = 'birdclef-2025/taxonomy.csv'
    unlabeled_sample_list = "birdclef-2025/sample_list.csv"

    spectrogram_npy = 'archive/birdclef2025_melspec_5sec_256_256.npy'
    spectrogram_npy_unlabeled = 'archive/train_soundscapes_mel_spec_5_256_256.npy'

    model_name = 'efficientnet_b0'
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    pretrained = True
    in_channels = 1

    # Pseudo-labeling on the unlabeled train soundscapes
    use_soundscapes = True
    pseudo_update_threshold = 0.95
    start_pseudo_epoch = 1
    remove_pseudo = True

    LOAD_DATA = True
    FS = 32000
    TARGET_DURATION = 5.0
    TARGET_SHAPE = (256, 256)

    N_FFT = 1024
    HOP_LENGTH = 512
    N_MELS = 128
    FMIN = 50
    FMAX = 14000

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    epochs = 10
    batch_size = 32
    criterion = 'CombinedLoss'  # Options: 'BCEWithLogitsLoss', 'FocalLoss', 'CombinedLoss'

    # Focal Loss parameters
    focal_alpha = 1.0
    focal_gamma = 3.0

    # Combined Loss weights
    bce_weight = 0.5
    focal_weight = 0.5

    n_fold = 1  # 1 = train and validate on the full dataset; >1 = StratifiedKFold

    optimizer = 'AdamW'
    lr = 5e-4
    weight_decay = 1e-5

    scheduler = 'CosineAnnealingLR'
    min_lr = 1e-6
    T_max = epochs  # evaluated once here; update_debug_settings() resets it in debug mode

    aug_prob = 0.5
    mixup_alpha = 0.5

    def update_debug_settings(self):
        if self.debug:
            self.epochs = 2
            self.T_max = self.epochs  # keep the cosine schedule in sync with the shorter run
            self.start_pseudo_epoch = 1

    def save_config(self):
        config_dict = {attr: getattr(self, attr) for attr in dir(self) if not attr.startswith('__') and not callable(getattr(self, attr))}
        filename = f"config_{self.timestamp}_{self.model_name}.json"
        with open(os.path.join(self.OUTPUT_DIR, filename), 'w') as f:
            json.dump(config_dict, f, indent=4, default=str)
        print(f"Config saved to {os.path.join(self.OUTPUT_DIR, filename)}")

cfg = CFG()
set_seed(cfg.seed)
cfg.update_debug_settings()
```

Using device: cuda


## Pre-processing
These functions come from the local `processing` module (`audio2melspec`, `process_audio_file`, `generate_spectrograms`) and turn audio files into mel spectrograms for the model. The `LOAD_DATA` parameter selects the source. With `LOAD_DATA=True`, pre-computed spectrograms are loaded from this [dataset](https://www.kaggle.com/datasets/kadircandrisolu/birdclef25-mel-spectrograms). With `LOAD_DATA=False`, they are generated on the fly from the audio files. Either way, each sample becomes a 2-D spectrogram, expected to match `TARGET_SHAPE`, and the dataset classes below add a channel dimension before it reaches the network.
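
The `processing` module is not shown, so its exact steps are not visible here. The NumPy sketch below covers two steps that a 5-second pipeline typically needs. This is an assumption, not the notebook's code:

- Fitting a waveform to exactly `FS * TARGET_DURATION` = 160,000 samples with the hypothetical `fit_to_duration` (center-crop long clips, zero-pad short ones).
- Min-max scaling a dB spectrogram to [0, 1] with the hypothetical `minmax_normalize`. This is the range the brightness augmentation later clamps to.

The mel transform in between is omitted to keep the sketch dependency-free. That transform would be, for example, `librosa.feature.melspectrogram` with `n_fft=N_FFT`, `hop_length=HOP_LENGTH`, `n_mels=N_MELS`, `fmin=FMIN`, `fmax=FMAX`, followed by `librosa.power_to_db`. `n_stft_frames` shows the frame count of a centered STFT (librosa's default).

```python
import numpy as np

FS = 32000             # sample rate (CFG.FS)
TARGET_DURATION = 5.0  # seconds per window (CFG.TARGET_DURATION)
HOP_LENGTH = 512       # CFG.HOP_LENGTH


def fit_to_duration(audio, fs=FS, duration=TARGET_DURATION):
    """Hypothetical helper: center-crop long clips and zero-pad short ones
    so every window holds exactly fs * duration samples."""
    target_len = int(fs * duration)
    if len(audio) >= target_len:
        start = (len(audio) - target_len) // 2
        return audio[start:start + target_len]
    return np.pad(audio, (0, target_len - len(audio)))


def minmax_normalize(spec, eps=1e-8):
    """Hypothetical helper: scale a (dB) spectrogram to [0, 1]."""
    return (spec - spec.min()) / (spec.max() - spec.min() + eps)


def n_stft_frames(n_samples, hop_length=HOP_LENGTH):
    """Frame count of a centered STFT (librosa's default center=True)."""
    return 1 + n_samples // hop_length


clip = np.random.randn(7 * FS).astype(np.float32)       # 7 s toy clip
window = fit_to_duration(clip)
print(window.shape)                                      # (160000,)
print(n_stft_frames(len(window)))                        # 313
short = fit_to_duration(np.ones(FS, dtype=np.float32))   # 1 s clip, zero-padded
fake_db = np.random.uniform(-80, 0, size=(128, 313))     # stand-in for a dB mel spectrogram
norm = minmax_normalize(fake_db)
```

With `N_MELS = 128`, a raw 5-second mel spectrogram would therefore be 128 x 313 before any resizing to `TARGET_SHAPE`.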

## Dataset Preparation and Data Augmentations
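Each training sample is a mel spectrogram. A sample is augmented with probability `aug_prob` (0.5). When it is, each of three spectrogram augmentations is applied independently with 50% probability: time masking, frequency masking, and brightness/contrast adjustment. This randomized approach creates diverse training samples from the same audio files. Validation samples are never augmented.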
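
The three augmentations can be sketched in NumPy as follows. `augment_spectrogram` is a hypothetical helper; the notebook's own version, `apply_spec_augmentations`, operates on a PyTorch tensor. The sketch uses the same ranges:

- 1 to 3 masks of 5 to 20 bins
- gain in [0.8, 1.2]
- bias in [-0.1, 0.1]

It assumes a `(n_mels, time)` array already scaled to [0, 1]. It works on a copy, so a spectrogram cached in a shared dictionary is never modified in place.

```python
import numpy as np


def augment_spectrogram(spec, rng, p=0.5):
    """NumPy sketch of the notebook's spectrogram augmentations.

    spec: (n_mels, time) array with values in [0, 1]; the input is left untouched.
    """
    out = spec.copy()
    n_mels, n_frames = out.shape

    if rng.random() < p:  # time masking: zero 1-3 runs of 5-20 frames (vertical stripes)
        for _ in range(rng.integers(1, 4)):
            width = rng.integers(5, 21)
            start = rng.integers(0, n_frames - width + 1)
            out[:, start:start + width] = 0.0

    if rng.random() < p:  # frequency masking: zero 1-3 bands of 5-20 mel bins (horizontal stripes)
        for _ in range(rng.integers(1, 4)):
            height = rng.integers(5, 21)
            start = rng.integers(0, n_mels - height + 1)
            out[start:start + height, :] = 0.0

    if rng.random() < p:  # brightness/contrast jitter, clipped back to [0, 1]
        gain = rng.uniform(0.8, 1.2)
        bias = rng.uniform(-0.1, 0.1)
        out = np.clip(out * gain + bias, 0.0, 1.0)

    return out


rng = np.random.default_rng(0)
original = np.random.default_rng(1).random((256, 256)).astype(np.float32)
augmented = augment_spectrogram(original, rng)
```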

```python
import ast  # safe parsing of the stringified secondary_labels lists

class BirdCLEFDatasetFromNPY(Dataset):
    def __init__(self, df, cfg, spectrograms=None, mode="train"):
        self.df = df
        self.cfg = cfg
        self.mode = mode
        self.spectrograms = spectrograms

        taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
        self.species_ids = taxonomy_df['primary_label'].tolist()
        self.num_classes = len(self.species_ids)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.species_ids)}

        if 'filepath' not in self.df.columns:
            self.df['filepath'] = self.cfg.train_datadir + '/' + self.df.filename

        if 'samplename' not in self.df.columns:
            self.df['samplename'] = self.df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])

        if self.spectrograms:
            sample_names = set(self.df['samplename'])
            found_samples = sum(1 for name in sample_names if name in self.spectrograms)
            print(f"Found {found_samples} matching spectrograms for {mode} dataset out of {len(self.df)} samples")

        if cfg.debug:
            self.df = self.df.sample(min(1000, len(self.df)), random_state=cfg.seed).reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        samplename = row['samplename']

        if self.spectrograms and samplename in self.spectrograms:
            spec = self.spectrograms[samplename]
        elif not self.cfg.LOAD_DATA:
            spec = process_audio_file(row['filepath'], self.cfg)
        else:
            spec = None

        if spec is None:
            spec = np.zeros(self.cfg.TARGET_SHAPE, dtype=np.float32)
            if self.mode == "train":  # Only print warning during training
                print(f"Warning: Spectrogram for {samplename} not found and could not be generated")

        # torch.tensor copies the array, so the in-place masking below cannot
        # modify the cached spectrogram shared with the validation dataset
        spec = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

        if self.mode == "train" and random.random() < self.cfg.aug_prob:
            spec = self.apply_spec_augmentations(spec)

        target = self.encode_label(row['primary_label'])

        # secondary_labels is stored as a stringified list, e.g. "['abc1', 'xyz2']"
        secondary_labels = row.get('secondary_labels')
        if isinstance(secondary_labels, str) and secondary_labels.strip():
            secondary_labels = ast.literal_eval(secondary_labels)
        if isinstance(secondary_labels, (list, tuple)):
            for label in secondary_labels:
                label_idx = self.label_to_idx.get(label)  # unknown or empty labels are skipped
                if label_idx is not None:
                    target[label_idx] = 1.0

        return {
            'melspec': spec,
            'target': torch.from_numpy(target).float(),
            'filename': row['filename']
        }

    def apply_spec_augmentations(self, spec):
        """Apply augmentations to a (1, n_mels, time) spectrogram tensor"""

        # Time masking: zero out random runs of time frames (vertical stripes)
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                width = random.randint(5, 20)
                start = random.randint(0, spec.shape[2] - width)
                spec[0, :, start:start+width] = 0

        # Frequency masking: zero out random mel bands (horizontal stripes)
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                height = random.randint(5, 20)
                start = random.randint(0, spec.shape[1] - height)
                spec[0, start:start+height, :] = 0

        # Random brightness/contrast (values are clamped back to [0, 1])
        if random.random() < 0.5:
            gain = random.uniform(0.8, 1.2)
            bias = random.uniform(-0.1, 0.1)
            spec = spec * gain + bias
            spec = torch.clamp(spec, 0, 1)

        return spec

    def encode_label(self, label):
        """Encode label to one-hot vector"""
        target = np.zeros(self.num_classes)
        idx = self.label_to_idx.get(label)
        if idx is not None:
            target[idx] = 1.0
        return target

    def extend(self, new_samples):
        """Extend the dataset with pseudo-labeled samples, supporting secondary labels."""
        print(f"Adding {len(new_samples)} new samples to the train dataset.")
        if not new_samples:
            return

        new_rows = []
        new_specs = {}

        for sample in new_samples:
            filename = sample['filename']
            samplename = filename.split('/')[0] + '-' + filename.split('/')[-1].split('.')[0]

            target_array = sample['target'].numpy() if isinstance(sample['target'], torch.Tensor) else np.asarray(sample['target'])
            # The highest-scoring class becomes the primary label (for a one-hot target, the single positive)
            primary_label_idx = int(target_array.argmax())
            primary_label = self.species_ids[primary_label_idx]

            # Any other positive class is kept as a secondary label
            secondary_label_indices = [i for i, val in enumerate(target_array) if val > 0 and i != primary_label_idx]
            secondary_labels = [self.species_ids[i] for i in secondary_label_indices] if secondary_label_indices else ['']

            new_row = {
                'filename': filename,
                'samplename': samplename,
                'primary_label': primary_label,
                'secondary_labels': str(secondary_labels),  # stored as a string, like train.csv
                'filepath': self.cfg.train_soundscapes + '/' + filename  # pseudo-labels come from the unlabeled train soundscapes
            }
            new_rows.append(new_row)

            # Store the melspec separately (drop the channel dim to match the cached arrays)
            new_specs[samplename] = sample['melspec'].squeeze(0).numpy()

        # Append to df
        new_df = pd.DataFrame(new_rows)
        self.df = pd.concat([self.df, new_df], ignore_index=True)

        # Update spectrograms dictionary if available
        if self.spectrograms is not None:
            self.spectrograms.update(new_specs)
        else:
            self.spectrograms = new_specs
```

```python
class SoundscapeDatasetFromNPY(Dataset):
    def __init__(self, df, cfg, spectrograms=None, mode="train"):
        self.df = df.copy()  # avoid mutating the caller's DataFrame
        self.cfg = cfg
        self.mode = mode
        self.spectrograms = spectrograms

        taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
        self.species_ids = taxonomy_df['primary_label'].tolist()
        self.num_classes = len(self.species_ids)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.species_ids)}

        # Soundscape windows are unlabeled until pseudo-labeling assigns them a class
        self.df["primary_label"] = None
        self.df["filename"] = self.df["samplename"].apply(lambda x: x + ".ogg")

        # Report how many windows have a cached spectrogram
        sample_names = set(self.df['samplename'])
        if self.spectrograms:
            found_samples = sum(1 for name in sample_names if name in self.spectrograms)
            print(f"Found {found_samples} matching spectrograms for {mode} dataset out of {len(self.df)} samples")

        if cfg.debug:
            self.df = self.df.sample(min(1000, len(self.df)), random_state=cfg.seed).reset_index(drop=True)
        self.primary_label = [None] * len(self.df)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        samplename = row['samplename']
        spec = None

        if self.spectrograms and samplename in self.spectrograms:
            spec = self.spectrograms[samplename]
        elif not self.cfg.LOAD_DATA:
            spec = process_audio_file(row['filepath'], self.cfg)

        if spec is None:
            spec = np.zeros(self.cfg.TARGET_SHAPE, dtype=np.float32)
            if self.mode == "train":  # Only print warning during training
                print(f"Warning: Spectrogram for {samplename} not found and could not be generated")

        spec = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        target = self.encode_label(row['primary_label'])

        return {
            'melspec': spec,
            'target': target,
            'filename': row['filename'],
            'index': idx
        }

    def encode_label(self, label):
        """Encode label to one-hot vector; unlabeled windows return None"""
        if label is None:
            return None
        target = np.zeros(self.num_classes)
        if label in self.label_to_idx:
            target[self.label_to_idx[label]] = 1.0
        return target

    def remove_indices(self, indices_to_remove):
        """Remove samples by their current positional indices."""
        indices_to_remove = {int(i) for i in indices_to_remove}  # accept ints or 0-d tensors
        print(f"Removing {len(indices_to_remove)} samples from pseudo dataset.")
        self.df = self.df.drop(index=list(indices_to_remove)).reset_index(drop=True)
        if hasattr(self, 'primary_label'):
            self.primary_label = [label for i, label in enumerate(self.primary_label) if i not in indices_to_remove]
```

## Model Definition

```python
class BirdCLEFModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
        cfg.num_classes = len(taxonomy_df)

        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=cfg.pretrained,
            in_chans=cfg.in_channels,
            drop_rate=0.2,
            drop_path_rate=0.2
        )
        # Strip the backbone's own classification head; a new linear head is added below
        if 'efficientnet' in cfg.model_name:
            backbone_out = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif 'resnet' in cfg.model_name:
            backbone_out = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            backbone_out = self.backbone.get_classifier().in_features
            self.backbone.reset_classifier(0, '')

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.feat_dim = backbone_out
        self.classifier = nn.Linear(backbone_out, cfg.num_classes)

        self.mixup_enabled = hasattr(cfg, 'mixup_alpha') and cfg.mixup_alpha > 0
        if self.mixup_enabled:
            self.mixup_alpha = cfg.mixup_alpha

    def forward(self, x, targets=None):

        # Mixup is only applied in training mode when targets are passed in
        if self.training and self.mixup_enabled and targets is not None:
            mixed_x, targets_a, targets_b, lam = self.mixup_data(x, targets)
            x = mixed_x
        else:
            targets_a, targets_b, lam = None, None, None

        features = self.backbone(x)

        if isinstance(features, dict):
            features = features['features']

        # Pool unpooled (B, C, H, W) feature maps to (B, C)
        if len(features.shape) == 4:
            features = self.pooling(features)
            features = features.view(features.size(0), -1)

        logits = self.classifier(features)

        if self.training and self.mixup_enabled and targets is not None:
            loss = self.mixup_criterion(F.binary_cross_entropy_with_logits, logits, targets_a, targets_b, lam)
            return logits, loss

        return logits

    def mixup_data(self, x, targets):
        """Applies mixup to the data batch"""
        batch_size = x.size(0)
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        indices = torch.randperm(batch_size, device=x.device)
        mixed_x = lam * x + (1 - lam) * x[indices]

        return mixed_x, targets, targets[indices], lam

    def mixup_criterion(self, criterion, pred, y_a, y_b, lam):
        """Applies mixup to the loss function"""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```
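
`mixup_data` and `mixup_criterion` implement standard mixup:

1. Draw `lam ~ Beta(alpha, alpha)`.
2. Blend each input with a randomly permuted partner.
3. Weight the loss against both label sets.

Note that `forward()` only applies mixup when it receives `targets`. The NumPy sketch below reproduces the same logic with hypothetical `mixup_batch` and `bce_with_logits` helpers, toy shapes, and 10 made-up classes. It also checks one property: binary cross-entropy is linear in the target, so the lam-weighted mix of two BCE losses equals BCE against the soft target `lam * y_a + (1 - lam) * y_b`.

```python
import numpy as np


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on logits (numerically stable form)."""
    return np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits))))


def mixup_batch(x, y, alpha, rng):
    """Mix a batch with a shuffled copy of itself, as BirdCLEFModel.mixup_data does."""
    lam = rng.beta(alpha, alpha)
    perm = rng.permutation(len(x))
    return lam * x + (1 - lam) * x[perm], y, y[perm], lam


rng = np.random.default_rng(42)
x = rng.random((8, 1, 32, 32))                      # toy batch of 1-channel spectrograms
y = (rng.random((8, 10)) < 0.2).astype(float)       # toy multi-label targets, 10 classes
x_mix, y_a, y_b, lam = mixup_batch(x, y, alpha=0.5, rng=rng)

logits = rng.normal(size=y.shape)                   # stand-in for model(x_mix)
loss_mixed = lam * bce_with_logits(logits, y_a) + (1 - lam) * bce_with_logits(logits, y_b)
loss_soft = bce_with_logits(logits, lam * y_a + (1 - lam) * y_b)
print(f"lam={lam:.3f}  mixed-loss={loss_mixed:.5f}  soft-target={loss_soft:.5f}")
```

This identity is specific to BCE. A focal term's modulating factor depends on the target, so with `CombinedLoss`, mixing the losses and mixing the targets need not give the same value.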

## Training Utilities
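We configure the optimization strategy with the AdamW optimizer, cosine annealing of the learning rate, and the loss selected by `cfg.criterion`. Here that is `CombinedLoss`, which weights BCEWithLogitsLoss and Focal Loss by `bce_weight` and `focal_weight`. The helpers `get_optimizer`, `get_scheduler`, and `get_criterion` are imported from the local `training_utilities` module.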
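
`training_utilities` is not shown, so the NumPy sketch below is only one plausible reading of `CombinedLoss`. The hypothetical `combined_loss` adds two terms:

- `bce_weight` times the mean BCE.
- `focal_weight` times a common focal-loss variant, `alpha * (1 - p_t)^gamma * BCE` with `p_t = exp(-BCE)`.

It uses the configured `focal_alpha = 1.0` and `focal_gamma = 3.0`. The actual `get_criterion` may differ, for example in how `alpha` is applied.

```python
import numpy as np


def bce_elementwise(logits, targets):
    """Per-element binary cross-entropy on logits (numerically stable)."""
    return np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))


def focal_loss(logits, targets, alpha=1.0, gamma=3.0):
    """One common focal-loss variant: alpha * (1 - p_t)^gamma * BCE, with p_t = exp(-BCE)."""
    bce = bce_elementwise(logits, targets)
    p_t = np.exp(-bce)  # probability the model assigns to the true label of each element
    return np.mean(alpha * (1 - p_t) ** gamma * bce)


def combined_loss(logits, targets, bce_weight=0.5, focal_weight=0.5, alpha=1.0, gamma=3.0):
    """Hypothetical CombinedLoss: weighted sum of mean BCE and focal loss."""
    return (bce_weight * np.mean(bce_elementwise(logits, targets))
            + focal_weight * focal_loss(logits, targets, alpha, gamma))


targets = np.array([[1.0, 0.0, 0.0]])
easy = np.array([[4.0, -4.0, -4.0]])   # confident and correct
hard = np.array([[-1.0, 1.0, -4.0]])   # mostly wrong
for name, z in [("easy", easy), ("hard", hard)]:
    print(name,
          round(float(np.mean(bce_elementwise(z, targets))), 4),
          round(float(focal_loss(z, targets)), 4),
          round(float(combined_loss(z, targets)), 4))
```

With `gamma = 3`, well-classified elements (`p_t` near 1) contribute almost nothing to the focal term. The BCE half of the combination keeps a gradient on them.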

## Training Loop

```python
def train_one_epoch(model, loader, optimizer, criterion, device, scheduler=None, pseudo_loader=None, pseudo_dataset=None, use_amp=True):
    # pseudo_loader / pseudo_dataset are unused here: pseudo-labels are
    # refreshed after each epoch by update_pseudo_labels()
    model.train()
    scaler = GradScaler(enabled=use_amp)
    device_type = torch.device(device).type  # autocast expects 'cuda' or 'cpu'
    base_model = getattr(model, '_orig_mod', model)  # unwrap torch.compile to reach the mixup helpers
    total_loss = 0.0
    all_targets = []
    all_outputs = []

    pbar = tqdm(enumerate(loader), total=len(loader), desc="Training")

    for step, batch in pbar:
        inputs = batch['melspec'].to(device, non_blocking=True)
        targets = batch['target'].to(device, non_blocking=True)
        optimizer.zero_grad()

        with autocast(device_type=device_type, enabled=use_amp):
            if base_model.mixup_enabled:
                # Mixup: forward() only mixes when given targets, so mix here and
                # weight the configured criterion by lam
                mixed_inputs, targets_a, targets_b, lam = base_model.mixup_data(inputs, targets)
                outputs = model(mixed_inputs)
                loss = base_model.mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if scheduler and isinstance(scheduler, lr_scheduler.OneCycleLR):
            scheduler.step()

        total_loss += loss.item()
        all_outputs.append(outputs.detach().cpu())
        all_targets.append(targets.detach().cpu())  # with mixup, train AUC is only a rough signal
        pbar.set_postfix({'train_loss': total_loss / len(all_outputs), 'lr': optimizer.param_groups[0]['lr']})

    all_outputs = torch.cat(all_outputs)
    all_targets = torch.cat(all_targets)
    auc = calculate_auc(all_targets.numpy(), all_outputs)
    avg_loss = total_loss / len(loader)

    return avg_loss, auc

def update_pseudo_labels(model, pseudo_loader, pseudo_dataset, train_dataset, device, threshold=0.9, remove_pseudo=True):
    """Move confidently predicted soundscape windows into the training set with one-hot pseudo-labels."""
    model.eval()
    to_remove = []
    new_samples = []

    with torch.no_grad():
        for batch in tqdm(pseudo_loader, desc="Updating pseudo-labels"):
            inputs = batch['melspec'].to(device, non_blocking=True)
            indices = batch['index']
            filenames = batch['filename']
            outputs = model(inputs)
            probs = torch.sigmoid(outputs)
            probs = probs.cpu().numpy()
            inputs = inputs.cpu()

            for i, prob in enumerate(probs):
                top_class = np.argmax(prob)
                top_confidence = prob[top_class]

                if top_confidence > threshold:
                    target = np.zeros_like(prob)
                    target[top_class] = 1.0

                    new_samples.append({'melspec': inputs[i], 'target': torch.tensor(target, dtype=torch.float32), 'filename': filenames[i]})
                    to_remove.append(int(indices[i]))  # plain int, whether collate_fn yields a list or a tensor

    train_dataset.extend(new_samples)

    if remove_pseudo and to_remove:
        pseudo_dataset.remove_indices(to_remove)

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_targets = []
    all_outputs = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            inputs = batch['melspec'].to(device, non_blocking=True)
            targets = batch['target'].to(device, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            all_outputs.append(outputs.detach().cpu())
            all_targets.append(targets.detach().cpu())

    all_outputs = torch.cat(all_outputs)
    all_targets = torch.cat(all_targets).numpy()

    auc = calculate_auc(all_targets, all_outputs)
    avg_loss = total_loss / len(loader)

    return avg_loss, auc

def calculate_auc(targets: np.ndarray, outputs: torch.Tensor) -> float:
    """Macro ROC-AUC over the classes that have both positive and negative samples."""
    aucs = []
    probs = torch.sigmoid(outputs.float()).numpy()  # .float() handles half-precision outputs from AMP

    for i in range(targets.shape[1]):
        n_pos = np.sum(targets[:, i])
        if 0 < n_pos < len(targets):  # roc_auc_score is undefined when only one class is present
            class_auc = roc_auc_score(targets[:, i], probs[:, i])
            aucs.append(class_auc)

    return np.mean(aucs) if aucs else 0.0
```
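
The selection rule inside `update_pseudo_labels` can be isolated as a small NumPy function. `select_pseudo_labels` is a hypothetical helper, shown here with toy logits for 3 classes. It applies a sigmoid and keeps a window only when its top-class probability exceeds the threshold; the notebook uses `pseudo_update_threshold = 0.95`. For each kept window it emits a one-hot target for that class.

```python
import numpy as np


def select_pseudo_labels(logits, threshold=0.95):
    """NumPy sketch of the selection rule in update_pseudo_labels.

    Returns the row indices whose top sigmoid probability exceeds `threshold`
    and a one-hot target for each of those rows.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    top_class = probs.argmax(axis=1)
    top_conf = probs[np.arange(len(probs)), top_class]
    keep = np.flatnonzero(top_conf > threshold)
    targets = np.zeros((len(keep), logits.shape[1]), dtype=np.float32)
    targets[np.arange(len(keep)), top_class[keep]] = 1.0
    return keep, targets


logits = np.array([
    [4.0, -2.0, -3.0],   # sigmoid(4) ~ 0.982 -> kept as class 0
    [0.5, 1.0, -1.0],    # top probability ~ 0.731 -> stays unlabeled
    [-3.0, -3.0, 3.5],   # sigmoid(3.5) ~ 0.971 -> kept as class 2
])
keep, targets = select_pseudo_labels(logits, threshold=0.95)
print(keep)  # [0 2]
print(targets)
```

Each kept window contributes a single class, even if a second species in the same window is also detected with high probability.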

## Training!

```python
def run_training(df, cfg, soundscape_df=None):
    """Training function that can either use pre-computed spectrograms or generate them on-the-fly"""

    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    species_ids = taxonomy_df['primary_label'].tolist()
    cfg.num_classes = len(species_ids)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)  # checkpoints and the config JSON are written here

    if cfg.debug: cfg.update_debug_settings()

    spectrograms = None
    soundscape_spectrograms = None  # stays None when spectrograms are generated on the fly
    if cfg.LOAD_DATA:
        print("Loading pre-computed mel spectrograms from NPY file...")
        try:
            spectrograms = np.load(cfg.spectrogram_npy, allow_pickle=True).item()
            soundscape_spectrograms = np.load(cfg.spectrogram_npy_unlabeled, allow_pickle=True).item()
            print(f"Loaded {len(spectrograms)} pre-computed mel spectrograms")
        except Exception as e:
            print(f"Error loading pre-computed spectrograms: {e}")
            print("Will generate spectrograms on-the-fly instead.")
            cfg.LOAD_DATA = False

    if not cfg.LOAD_DATA:
        print("Will generate spectrograms on-the-fly during training.")
        if 'filepath' not in df.columns:
            df['filepath'] = cfg.train_datadir + '/' + df.filename
        if 'samplename' not in df.columns:
            df['samplename'] = df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])

    if cfg.n_fold > 1:
        skf = StratifiedKFold(n_splits=cfg.n_fold, shuffle=True, random_state=cfg.seed)
        folds = skf.split(df, df['primary_label'])
    else:
        # Single "fold": train and validate on the same full dataset, so the
        # validation AUC reflects fit to the training data, not generalization
        folds = [(np.arange(len(df)), np.arange(len(df)))]

    # Pseudo-labeling needs the list of unlabeled soundscape windows
    use_soundscapes = cfg.use_soundscapes and soundscape_df is not None

    best_scores = []

    for fold, (train_idx, val_idx) in enumerate(folds):

        print(f'\n{"="*30} Fold {fold} {"="*30}')

        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)

        print(f'Training set: {len(train_df)} samples')
        print(f'Validation set: {len(val_df)} samples')

        train_dataset = BirdCLEFDatasetFromNPY(train_df, cfg, spectrograms=spectrograms, mode='train')
        val_dataset = BirdCLEFDatasetFromNPY(val_df, cfg, spectrograms=spectrograms, mode='valid')
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate_fn, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate_fn)

        soundscape_dataset, soundscape_loader = None, None
        if use_soundscapes:
            soundscape_dataset = SoundscapeDatasetFromNPY(soundscape_df, cfg, spectrograms=soundscape_spectrograms, mode='train')
            # Only used for pseudo-label inference: keep the order and every window (no shuffle, no drop_last)
            soundscape_loader = DataLoader(soundscape_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate_fn)

        model = BirdCLEFModel(cfg).to(cfg.device, non_blocking=True)
        model = torch.compile(model, backend="inductor")
        optimizer = get_optimizer(model, cfg)
        criterion = get_criterion(cfg)
        scheduler = get_scheduler(optimizer, cfg, len(train_loader))

        best_auc, best_epoch = 0, 0

        for epoch in range(cfg.epochs):
            print(f"\nEpoch {epoch+1}/{cfg.epochs}")
            use_pseudo_labels = use_soundscapes and (epoch+1) >= cfg.start_pseudo_epoch

            train_loss, train_auc = train_one_epoch(model, train_loader, optimizer, criterion, cfg.device,
                scheduler if isinstance(scheduler, lr_scheduler.OneCycleLR) else None,
                pseudo_loader=soundscape_loader if use_pseudo_labels else None,
                pseudo_dataset=soundscape_dataset if use_pseudo_labels else None
            )

            if use_pseudo_labels:
                update_pseudo_labels(model, soundscape_loader, soundscape_dataset, train_dataset, cfg.device,
                                     threshold=cfg.pseudo_update_threshold, remove_pseudo=cfg.remove_pseudo)

            val_loss, val_auc = validate(model, val_loader, criterion, cfg.device)

            if scheduler is not None and not isinstance(scheduler, lr_scheduler.OneCycleLR):
                if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

            print(f"Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")

            if val_auc > best_auc:
                best_auc = val_auc
                best_epoch = epoch + 1
                print(f"New best AUC: {best_auc:.4f} at epoch {best_epoch}")

                torch.save({
                    # Unwrap torch.compile so the keys carry no '_orig_mod.' prefix
                    'model_state_dict': getattr(model, '_orig_mod', model).state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                    'epoch': epoch,
                    'val_auc': val_auc,
                    'train_auc': train_auc,
                    'cfg': cfg
                }, os.path.join(cfg.OUTPUT_DIR, f"model_{cfg.timestamp}_{cfg.model_name}_fold{fold}.pth"))

        best_scores.append(best_auc)
        print(f"\nBest AUC for fold {fold}: {best_auc:.4f} at epoch {best_epoch}")

        # Clear memory
        del model, optimizer, scheduler, train_loader, val_loader, soundscape_loader
        torch.cuda.empty_cache()
        gc.collect()

    print("\n" + "="*60)
    print("Cross-Validation Results:")
    for fold, score in enumerate(best_scores):
        print(f"Fold {fold}: {score:.4f}")
    print(f"Mean AUC: {np.mean(best_scores):.4f}")
    print("="*60)
```

```python
if __name__ == "__main__":
    print("\nLoading training data...")
    train_df = pd.read_csv(cfg.train_csv)
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    soundscape_df = pd.read_csv(cfg.unlabeled_sample_list)

    print("\nStarting training...")
    print(f"LOAD_DATA is set to {cfg.LOAD_DATA}")
    if cfg.LOAD_DATA:
        print("Using pre-computed mel spectrograms from NPY file")
    else:
        print("Will generate spectrograms on-the-fly during training")

    run_training(train_df, cfg, soundscape_df=soundscape_df)
    print("\nTraining complete!")
    cfg.save_config()
```


Loading training data...

Starting training...
LOAD_DATA is set to True
Using pre-computed mel spectrograms from NPY file
Loading pre-computed mel spectrograms from NPY file...
Loaded 28564 pre-computed mel spectrograms

Training set: 28564 samples
Validation set: 28564 samples
Found 28564 matching spectrograms for train dataset out of 28564 samples
Found 28564 matching spectrograms for valid dataset out of 28564 samples
Found 9726 matching spectrograms for train dataset out of 9726 samples

Epoch 1/10


Training:   0%|          | 0/892 [00:00<?, ?it/s]

W0511 12:41:38.372000 1439 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


KeyboardInterrupt: 