## Libraries

In [None]:
import os
import logging
import random
import gc
import time
import cv2
import math
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score
import librosa
import multiprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchvision.ops import sigmoid_focal_loss

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import timm
import lightgbm as lgb

# Suppress all Python warnings and configure root logging to emit only ERROR and above.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

class FocalLoss(nn.Module):
    """``nn.Module`` wrapper around ``torchvision.ops.sigmoid_focal_loss``.

    The three hyper-parameters are stored on the module and forwarded unchanged
    on every call, so the loss can be used like any other criterion object.

    Args:
        alpha: class-balancing weight passed to ``sigmoid_focal_loss``.
        gamma: focusing exponent passed to ``sigmoid_focal_loss``.
        reduction: reduction mode passed to ``sigmoid_focal_loss``.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of *inputs* (logits) against *targets*."""
        options = dict(alpha=self.alpha, gamma=self.gamma, reduction=self.reduction)
        return sigmoid_focal_loss(inputs, targets, **options)


## Configuration

In [None]:
class CFG:
    """Central configuration for the BirdCLEF 2025 feature-extraction + LightGBM pipeline.

    Values are class-level defaults; ``update_debug_settings`` overrides some
    of them on the instance when ``debug`` is True.
    """

    # Basic
    seed = 42
    debug = False
    apex = False
    print_freq = 100
    # Use ~80% of the CPU cores for DataLoader workers, but always at least one.
    num_workers = max(1, int(multiprocessing.cpu_count() * 0.8))

    # Paths
    OUTPUT_DIR = '/kaggle/working/'
    train_datadir = '/kaggle/input/birdclef-2025/train_audio'
    train_csv = '/kaggle/input/birdclef-2025/train.csv'
    test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'
    submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'
    taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'
    fabio_csv_path = '/kaggle/input/fabio-csv/fabio.csv'
    batch_spectrograms_dir = '/kaggle/input/birdclef2025-melspecs-256x256-5sec-16bit'

    # Model
    # timm backbone name; main() reads cfg.model_name when writing the inference config.
    model_name = 'convnextv2_nano.fcmae'
    in_channels = 1
    num_classes = None  # set from taxonomy.csv when the dataset is constructed

    # Data
    FS = 32000
    TARGET_DURATION = 5.0
    TARGET_SHAPE = (256, 256)

    # Audio
    N_FFT = 1024
    HOP_LENGTH = 512
    N_MELS = 128
    FMIN = 50
    FMAX = 14000

    # Training
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = 10
    batch_size = 32
    criterion = 'FocalBCE'

    # Augmentation
    aug_prob = 1.0
    mixup_prob = 0.6  # TODO(review): candidates tried: 1.0 / 0.8 / 0.6
    mixup_alpha = 0.4  # recorded in metadata; the dataset samples its own mixup alpha

    def update_debug_settings(self):
        """Shrink the run for quick debugging (currently: 2 epochs) when ``debug`` is set."""
        if self.debug:
            self.epochs = 2


cfg = CFG()
print(cfg.num_workers)

## Utilities

In [None]:
def set_seed(seed=42):
    """Seed every RNG the pipeline touches and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, current CUDA
    device and all CUDA devices), and exports ``PYTHONHASHSEED``. That variable
    only affects interpreters started after this call.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

# Seed all RNGs once, at notebook import time, using the configured seed.
set_seed(cfg.seed)

In [None]:
class BirdCLEFDatasetFromNPY(Dataset):
    """Dataset serving precomputed mel-spectrogram segments for BirdCLEF 2025.

    Each item is a dict with a float32 ``'melspec'`` tensor of shape (1, H, W),
    a ``'target'`` vector over the taxonomy species, and the source
    ``'filename'``. The target is multi-hot, or mixup-blended in train mode.

    In ``mode="train"``:
      * every cached segment whose key is ``<samplename>_<segment_id>`` becomes
        its own row;
      * negative ("nocall") spectrograms are appended with all-zero targets;
      * class-adaptive masking/noise augmentation and same-class mixup are
        applied on access.

    Args:
        df: metadata dataframe with ``filename`` and ``primary_label`` columns
            and an optional ``secondary_labels`` column.
        cfg: configuration object. ``cfg.num_classes`` is set as a side effect.
        spectrograms: dict of positive segment key -> 2-D spectrogram array.
        negative_spectrograms: dict of negative key -> 2-D spectrogram array.
        mode: ``"train"`` enables expansion, negatives, augmentation and mixup.
    """

    def __init__(self, df, cfg, spectrograms=None, negative_spectrograms=None, mode="train"):
        self.df = df.copy()
        self.cfg = cfg
        self.mode = mode
        self.spectrograms = spectrograms
        self.negative_spectrograms = negative_spectrograms

        # Load Fabio (start, stop) intervals keyed by filename, if the CSV exists.
        # NOTE(review): stored but not used anywhere in this class yet.
        if os.path.exists(cfg.fabio_csv_path):
            fabio_df = pd.read_csv(cfg.fabio_csv_path)
            self.fabio_intervals = dict(
                zip(fabio_df['filename'], zip(fabio_df['start'], fabio_df['stop']))
            )
        else:
            self.fabio_intervals = {}

        # Load the taxonomy once; its row order defines the target-vector layout.
        taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
        self.species_ids = taxonomy_df['primary_label'].tolist()
        self.num_classes = len(self.species_ids)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.species_ids)}
        self.valid_species = set(self.species_ids)  # O(1) membership checks

        # Propagate the class count to the shared config.
        cfg.num_classes = self.num_classes

        # File paths
        if 'filepath' not in self.df.columns:
            self.df['filepath'] = self.cfg.train_datadir + '/' + self.df['filename']

        # Sample name "<first path component>-<file stem>": base key of the spectrogram dict.
        if 'samplename' not in self.df.columns:
            self.df['samplename'] = self.df['filename'].map(
                lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0]
            )

        # Debug mode: subsample before segment expansion.
        if cfg.debug:
            print(f"Debug mode: sampling {min(1000, len(self.df))} samples before expansion")
            self.df = self.df.sample(min(1000, len(self.df)), random_state=cfg.seed).reset_index(drop=True)

        # Expand every cached spectrogram segment into its own row.
        if spectrograms and mode == "train":
            print("Building global spectrogram index...")
            self.global_spec_index = self._build_global_index(spectrograms)

            print("Expanding dataset with optimized processing...")
            self.df = self._expand_dataframe_ultimate_optimized(self.df)
            print(f"Expanded to {len(self.df)} samples with all segments")

        # Append negative ("nocall") samples, if provided.
        if self.negative_spectrograms is not None and mode == "train":
            negative_keys = list(self.negative_spectrograms.keys())
            negative_df = pd.DataFrame({
                'filename': negative_keys,
                'primary_label': ['nocall'] * len(negative_keys),
                'samplename': negative_keys,
                'filepath': [''] * len(negative_keys),
            })
            self.df = pd.concat([self.df, negative_df], ignore_index=True)

        # Per-class sample counts drive the class-adaptive augmentation strength.
        if mode == "train":
            self.class_counts = self.df['primary_label'].value_counts().to_dict()
            self.rare_threshold = 20
            self.target_samples = 50
            print(f"Classes with < {self.rare_threshold} samples: {sum(1 for count in self.class_counts.values() if count < self.rare_threshold)}")

        # Positional indices of non-"nocall" rows; built lazily for mixup pairing.
        self._positive_indices = None

        # Report spectrogram availability.
        if self.spectrograms:
            sample_names = set(self.df['samplename'])
            found_samples = sum(1 for name in sample_names if name in self.spectrograms)
            print(f"Found {found_samples} matching positive spectrograms for {mode} dataset")

        if self.negative_spectrograms:
            found_neg_samples = len(self.negative_spectrograms)
            print(f"Found {found_neg_samples} negative spectrograms for {mode} dataset")

    def _build_global_index(self, spectrograms):
        """Group positive segment keys by base sample name, keeping known species only.

        A key such as ``"sp-file_3"`` is filed under ``"sp-file"`` when ``"sp"``
        is in the taxonomy. Keys starting with ``"negative-"`` are ignored.
        """
        from collections import defaultdict

        global_index = defaultdict(list)
        positive_keys = [k for k in spectrograms.keys() if not k.startswith('negative-')]

        for key in positive_keys:
            species_id = key.split('-', 1)[0]
            if species_id in self.valid_species:
                base_key = key.rsplit('_', 1)[0]
                global_index[base_key].append(key)

        print(f"Built global index with {len(global_index)} base keys from {len(positive_keys)} total keys")
        return global_index

    def _expand_dataframe_ultimate_optimized(self, df):
        """Return one row per cached segment of each row in *df*.

        Each output row copies its source row. ``samplename`` is replaced with
        the segment key, and a ``segment_id`` column is added from the key's
        suffix after the last ``'_'``. Rows with no cached segment are dropped.
        If nothing matches, an empty frame with *df*'s columns is returned, so
        later column access still works.
        """
        df_filtered = df[df['samplename'].isin(set(self.global_spec_index))]
        print(f"Pre-filtered: {len(df)} → {len(df_filtered)} samples")

        if df_filtered.empty:
            print("Warning: No matching samples found!")
            return df.iloc[0:0].copy()

        # Replace each base name with its list of segment keys, then emit one row per key.
        expanded = df_filtered.assign(
            samplename=df_filtered['samplename'].map(lambda name: self.global_spec_index[name])
        ).explode('samplename', ignore_index=True)
        expanded['segment_id'] = expanded['samplename'].map(lambda key: key.split('_')[-1])

        print(f"Generated {len(expanded)} expanded samples")
        return expanded

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the item at position *idx*, or None if a required field is missing.

        Raises:
            IndexError: if *idx* is outside the dataset, so iteration over the
                dataset terminates as the sequence protocol expects.
        """
        size = len(self.df)
        if not -size <= idx < size:
            raise IndexError(f"Index {idx} out of range for dataset of size {size}")

        row = self.df.iloc[idx]

        # Required columns must exist and be non-null.
        for col in ('samplename', 'filename', 'primary_label'):
            if col not in row or pd.isna(row[col]):
                print(f"Warning: Missing {col} for index {idx}")
                return None

        samplename = row['samplename']
        filename = row['filename']
        primary_label = row['primary_label']

        spec = self._get_spectrogram(samplename)
        if spec is None:
            spec = np.zeros(self.cfg.TARGET_SHAPE, dtype=np.float32)
            if self.mode == "train":
                print(f"Warning: No spectrogram found for {samplename}. Using zero spectrogram.")

        spec = self._to_tensor(spec)

        if self.mode == "train" and random.random() < self.cfg.aug_prob:
            spec = self.apply_spec_augmentations(spec, primary_label)

        secondary_labels = []
        if 'secondary_labels' in row and pd.notna(row['secondary_labels']) and row['secondary_labels'] != '':
            secondary_labels = self._parse_secondary_labels(row['secondary_labels'])

        target = self.encode_label(primary_label, secondary_labels)

        # Mixup with the configured probability (defaults to 0.6 if cfg lacks it).
        if self.mode == "train" and random.random() < getattr(self.cfg, 'mixup_prob', 0.6):
            spec, target = self._apply_mixup(spec, target, idx)

        return {
            'melspec': spec,
            'target': torch.tensor(target, dtype=torch.float32),
            'filename': filename
        }

    @staticmethod
    def _to_tensor(spec):
        """Convert a 2-D spectrogram array into a float32 tensor of shape (1, H, W).

        uint16 input is rescaled to [0, 1]. ``torch.tensor`` copies the data, so
        in-place augmentation never mutates the cached array.
        """
        if getattr(spec, 'dtype', None) == np.uint16:
            spec = spec.astype(np.float32) / 65535.0
        return torch.tensor(spec, dtype=torch.float32).unsqueeze(0)

    def _get_spectrogram(self, samplename):
        """Look up *samplename* in the positive, then negative, cache.

        Returns None when the key is absent, the array is empty, or the lookup
        raises.
        """
        try:
            if self.spectrograms and samplename in self.spectrograms:
                spec = self.spectrograms[samplename]
                if spec is None or spec.size == 0:
                    print(f"Warning: Empty spectrogram for {samplename}")
                    return None
                return spec
            elif self.negative_spectrograms and samplename in self.negative_spectrograms:
                spec = self.negative_spectrograms[samplename]
                if spec is None or spec.size == 0:
                    print(f"Warning: Empty negative spectrogram for {samplename}")
                    return None
                return spec
            else:
                return None
        except Exception as e:
            print(f"Error loading spectrogram for {samplename}: {e}")
            return None

    def _parse_secondary_labels(self, secondary_labels):
        """Parse secondary labels given as a Python-literal string or as a list.

        Uses ``ast.literal_eval`` instead of ``eval`` so a malformed or hostile
        CSV cell cannot execute code. Anything that is not a list/tuple/set
        literal yields ``[]``.
        """
        import ast

        if isinstance(secondary_labels, list):
            return secondary_labels
        if isinstance(secondary_labels, str):
            try:
                parsed = ast.literal_eval(secondary_labels)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return []
            return list(parsed) if isinstance(parsed, (list, tuple, set)) else []
        return []

    def _get_positive_indices(self):
        """Positional (``iloc``) indices of rows whose primary label is not "nocall"; cached."""
        if getattr(self, '_positive_indices', None) is None:
            labels = self.df['primary_label'].to_numpy()
            self._positive_indices = np.flatnonzero(labels != 'nocall')
        return self._positive_indices

    def _smart_mixup_pairing(self, idx, target):
        """Pick a mixup partner for row *idx*, preferring rows sharing a class with *target*.

        At most 1000 random positive candidates (excluding *idx*) are checked.
        A candidate matches when its primary label's class index is active in
        *target*. Falls back to any candidate, then to any row.
        """
        current_classes = set(np.flatnonzero(np.asarray(target) > 0).tolist())

        positives = self._get_positive_indices()
        candidates = positives[positives != idx]
        if len(candidates) > 1000:
            # Sample positions without materialising a Python list of all candidates.
            picks = random.sample(range(len(candidates)), 1000)
            candidates = candidates[picks]

        labels = self.df['primary_label'].to_numpy()
        same_class = [
            int(i) for i in candidates
            if self.label_to_idx.get(labels[i]) in current_classes
        ]

        if same_class:
            return random.choice(same_class)
        if len(candidates):
            return int(candidates[random.randrange(len(candidates))])
        return random.randint(0, len(self.df) - 1)

    def _apply_mixup(self, spec, target, idx):
        """Blend (spec, target) of row *idx* with a partner row; "nocall" rows are left untouched.

        The blend weight ``lam`` is drawn from Beta(alpha, alpha), where alpha is
        itself drawn from U(0.2, 0.8). ``cfg.mixup_alpha`` is not used here.
        """
        row = self.df.iloc[idx]

        # Do not mix negatives.
        if row['primary_label'] == 'nocall':
            return spec, target

        mix_idx = self._smart_mixup_pairing(idx, target)
        row2 = self.df.iloc[mix_idx]

        spec2 = self._get_spectrogram(row2['samplename'])
        if spec2 is None:
            return spec, target
        spec2 = self._to_tensor(spec2)

        target2 = self.encode_label(row2['primary_label'])
        if 'secondary_labels' in row2 and pd.notna(row2['secondary_labels']) and row2['secondary_labels'] != '':
            for label in self._parse_secondary_labels(row2['secondary_labels']):
                if label in self.label_to_idx:
                    target2[self.label_to_idx[label]] = 1.0

        alpha = random.uniform(0.2, 0.8)
        lam = np.random.beta(alpha, alpha)

        mixed_spec = lam * spec + (1 - lam) * spec2
        mixed_target = lam * target + (1 - lam) * target2
        return mixed_spec, mixed_target

    @staticmethod
    def _mask_bands(spec, dim, count_range, size_range):
        """Zero ``randint(*count_range)`` random bands along *dim* of *spec* in place.

        ``dim`` is 2 for time masking or 1 for frequency masking on a
        (1, H, W) tensor.
        """
        for _ in range(random.randint(*count_range)):
            size = random.randint(*size_range)
            start = random.randint(0, max(1, spec.shape[dim] - size))
            if dim == 2:
                spec[0, :, start:start + size] = 0
            else:
                spec[0, start:start + size, :] = 0

    def apply_spec_augmentations(self, spec, primary_label=None):
        """Apply class-adaptive augmentation to a (1, H, W) spectrogram tensor.

        Rare classes (fewer than ``rare_threshold`` rows) get a higher
        per-technique probability, up to 4 techniques, and wider ranges.
        Other classes get probability 0.5 and at most 3 techniques. Very quiet
        inputs (mean < 0.01) first receive additive Gaussian noise. Masking
        mutates *spec* in place; gain and noise return new tensors.
        """
        is_rare_class = (
            bool(primary_label)
            and hasattr(self, 'class_counts')
            and primary_label in self.class_counts
            and self.class_counts[primary_label] < self.rare_threshold
        )

        # Boost very quiet spectrograms with noise so they still carry a training signal.
        if torch.mean(spec) < 0.01:
            spec = spec + torch.randn_like(spec) * 0.1

        if is_rare_class:
            aug_prob, max_techniques = 0.8, 4
            mask_count, mask_size = (1, 4), (3, 25)
            gain_range, bias_range = (0.7, 1.3), (-0.15, 0.15)
        else:
            aug_prob, max_techniques = 0.5, 3
            mask_count, mask_size = (1, 3), (5, 20)
            gain_range, bias_range = (0.8, 1.2), (-0.1, 0.1)

        applied_count = 0

        # Time masking
        if random.random() < aug_prob and applied_count < max_techniques:
            self._mask_bands(spec, 2, mask_count, mask_size)
            applied_count += 1

        # Frequency masking
        if random.random() < aug_prob and applied_count < max_techniques:
            self._mask_bands(spec, 1, mask_count, mask_size)
            applied_count += 1

        # Random brightness/contrast
        if random.random() < aug_prob and applied_count < max_techniques:
            gain = random.uniform(*gain_range)
            bias = random.uniform(*bias_range)
            spec = torch.clamp(spec * gain + bias, 0, 1)
            applied_count += 1

        # Gaussian noise
        if random.random() < aug_prob and applied_count < max_techniques:
            noise_level = random.uniform(0.03, 0.08) if is_rare_class else 0.05
            spec = torch.clamp(spec + torch.randn_like(spec) * noise_level, 0, 1)
            applied_count += 1

        # Random erasing (rectangles)
        if random.random() < aug_prob and applied_count < max_techniques:
            num_erases = random.randint(1, 3) if is_rare_class else 1
            for _ in range(num_erases):
                erase_height = random.randint(*mask_size)
                erase_width = random.randint(*mask_size)
                max_x = spec.shape[2] - erase_width
                max_y = spec.shape[1] - erase_height
                if max_x > 0 and max_y > 0:
                    x = random.randint(0, max_x)
                    y = random.randint(0, max_y)
                    spec[0, y:y + erase_height, x:x + erase_width] = 0
            applied_count += 1

        return spec

    def encode_label(self, label, secondary_labels=None):
        """Return a float64 multi-hot vector of length ``num_classes``.

        "nocall" and "" give all zeros, and secondary labels are ignored for
        them. Unknown labels are skipped.
        """
        target = np.zeros(self.num_classes)
        if label == 'nocall' or label == '':
            return target
        if label in self.label_to_idx:
            target[self.label_to_idx[label]] = 1.0
        if secondary_labels:
            for sec in secondary_labels:
                if sec in self.label_to_idx:
                    target[self.label_to_idx[sec]] = 1.0
        return target


In [None]:
def collate_fn(batch):
    """Collate dataset items into a dict of lists, stacking tensor fields when possible.

    ``None`` items are dropped, and an all-``None`` batch yields ``{}``. The
    ``'melspec'`` and ``'target'`` fields are stacked with ``torch.stack`` when
    their first entry is a tensor. If stacking fails (mismatched shapes), the
    error is printed and the field stays a list. All other fields stay lists.
    """
    items = [entry for entry in batch if entry is not None]
    if not items:
        return {}

    collated = {name: [] for name in items[0].keys()}
    for entry in items:
        for name, value in entry.items():
            collated[name].append(value)

    stackable = {'target', 'melspec'}
    for key in collated:
        values = collated[key]
        if key not in stackable or not isinstance(values[0], torch.Tensor):
            continue
        try:
            collated[key] = torch.stack(values)
        except RuntimeError as e:
            print(f"Error stacking {key}: {e}")

    return collated

In [None]:
class ConvNeXtFeatureExtractor(nn.Module):
    """Pretrained ConvNeXt-V2 backbone used as a fixed-length embedding extractor.

    The backbone is created by timm with ``num_classes=0``, so its ``forward()``
    already returns globally pooled (B, C) embeddings. By default those are
    returned unchanged, which preserves the historical feature size.

    With ``pool_feature_map=True``, the un-pooled (B, C, H, W) map from
    ``forward_features`` is average- and max-pooled and the two are
    concatenated into (B, 2C). This gives richer but different-sized features,
    so train and inference must agree on the flag.

    Args:
        cfg: config providing ``in_channels``; ``cfg.model_name`` is used when
            defined, otherwise ``'convnextv2_nano.fcmae'``.
        pool_feature_map: enable the avg+max pooled feature-map output.
    """

    def __init__(self, cfg, pool_feature_map=False):
        super().__init__()
        self.pool_feature_map = pool_feature_map
        self.backbone = timm.create_model(
            getattr(cfg, 'model_name', 'convnextv2_nano.fcmae'),
            pretrained=True,
            in_chans=cfg.in_channels,
            num_classes=0,       # remove the classification layer
            drop_rate=0.2,       # dropout
            drop_path_rate=0.1,  # stochastic depth
        )

        # Used only for 4-D feature maps (pool_feature_map=True or a backbone without a head pool).
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.global_pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Return (B, C) embeddings, or (B, 2C) avg+max pooled features for 4-D maps."""
        if self.pool_feature_map:
            features = self.backbone.forward_features(x)  # un-pooled feature map
        else:
            features = self.backbone(x)  # already pooled by timm's head
        if features.dim() == 4:
            avg_pool = self.pooling(features).flatten(1)
            max_pool = self.global_pool(features).flatten(1)
            return torch.cat([avg_pool, max_pool], dim=1)
        return features

def extract_features(model, dataloader, device):
    """Run *model* over *dataloader* and collect features and targets as NumPy arrays.

    Batches are skipped when they lack ``'melspec'``/``'target'`` (e.g. the
    ``{}`` that ``collate_fn`` returns for an all-``None`` batch) or when those
    fields are not tensors (``collate_fn`` leaves a list when stacking fails).

    Returns:
        ``(features, labels)`` concatenated along axis 0, or two empty arrays
        if no batch produced output.
    """
    model.eval()
    features, labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting features"):
            if 'melspec' not in batch or 'target' not in batch:
                continue
            melspec, target = batch['melspec'], batch['target']
            if not (isinstance(melspec, torch.Tensor) and isinstance(target, torch.Tensor)):
                continue

            x = melspec.to(device, non_blocking=True)
            feats = model(x)
            features.append(feats.cpu().numpy())
            labels.append(target.cpu().numpy())
            del x, feats

    # Release cached GPU blocks once, after the loop. A per-batch empty_cache()
    # forces the allocator to re-allocate on every batch and only slows extraction.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not features:
        return np.array([]), np.array([])
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)

def train_lightgbm(X, y):
    """Train one binary LightGBM classifier per target column.

    Args:
        X: feature matrix, one row per sample.
        y: target matrix with one column per class. Values may be soft
           (mixup-blended) and are binarized at 0.5 before training.

    Returns:
        List of trained boosters, one per column of ``y``, in column order.
    """
    models = []
    print(f"Training LightGBM for {y.shape[1]} classes...")

    base_params = {
        'objective': 'binary',
        'metric': 'auc',
        'boosting_type': 'gbdt',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': -1,
        'min_data_in_leaf': 20,
        'lambda_l1': 0.1,
        'lambda_l2': 0.1,
        'max_depth': -1,
    }
    # The 'binary' objective expects {0, 1} labels, and stratification needs discrete
    # classes. Train on the binarized targets, not the raw soft mixup weights.
    y_binary = (y > 0.5).astype(int)

    for i in tqdm(range(y_binary.shape[1]), desc="Training LightGBM"):
        labels = y_binary[:, i]

        params = base_params.copy()
        if labels.sum() < 50:  # rare class: smaller, slower-learning trees
            params.update(learning_rate=0.03, num_leaves=15, min_data_in_leaf=5)

        try:
            X_train, X_val, y_train, y_val = train_test_split(
                X, labels,
                test_size=0.2,
                random_state=42,
                stratify=labels,
            )
        except ValueError as e:
            # e.g. a class with a single positive cannot be stratified.
            print(f"Class {i}: Stratify failed ({str(e)}), using random split")
            X_train, X_val, y_train, y_val = train_test_split(
                X, labels,
                test_size=0.2,
                random_state=42,
                stratify=None,
            )

        lgb_train = lgb.Dataset(X_train, label=y_train)
        lgb_val = lgb.Dataset(X_val, label=y_val)

        model = lgb.train(
            params,
            lgb_train,
            num_boost_round=100,
            valid_sets=[lgb_val],
            valid_names=['valid'],
            callbacks=[
                lgb.early_stopping(10),
                lgb.log_evaluation(0),
            ],
        )
        models.append(model)

    return models


def load_spectrograms_batch(cfg, batch_dir):
    """Load cached spectrogram dicts from ``*melspecs_uint16_batch*.npy`` files in *batch_dir*.

    Each file must hold a pickled ``dict`` of key -> array. uint16 arrays are
    rescaled to float32 in [0, 1]. Keys starting with ``'negative-'`` go to the
    negative dict and all others to the positive dict. Files are read in
    sorted order, so a later file overwrites an earlier one on duplicate keys.

    Args:
        cfg: unused; kept for interface compatibility.
        batch_dir: directory containing the batch files.

    Returns:
        ``(positive_spectrograms, negative_spectrograms)``, or ``(None, None)``
        when no matching file exists.
    """
    import glob

    # Match on the file name only, so the directory name cannot cause false hits.
    batch_files = sorted(
        path for path in glob.glob(os.path.join(batch_dir, '*.npy'))
        if 'melspecs_uint16_batch' in os.path.basename(path)
    )

    if not batch_files:
        print(f"No batch files found in {batch_dir}")
        return None, None

    print(f"Loading spectrograms from {len(batch_files)} batch files...")

    positive_spectrograms, negative_spectrograms = {}, {}
    for batch_file in batch_files:
        print(f"Loading {os.path.basename(batch_file)}...")
        # NOTE(review): allow_pickle unpickles arbitrary objects; only load trusted files.
        batch_data = np.load(batch_file, allow_pickle=True).item()

        for key, spec in batch_data.items():
            if spec.dtype == np.uint16:
                spec = spec.astype(np.float32) / 65535.0
            bucket = negative_spectrograms if key.startswith('negative-') else positive_spectrograms
            bucket[key] = spec
        del batch_data  # free this file's originals before loading the next one

    print(f"Loaded {len(positive_spectrograms)} positive spectrograms")
    print(f"Loaded {len(negative_spectrograms)} negative spectrograms")

    return positive_spectrograms, negative_spectrograms

In [None]:
import pickle
import json
from datetime import datetime

def main():
    """Run the training pipeline and save all artifacts to ``cfg.OUTPUT_DIR``.

    The pipeline loads cached spectrograms, extracts ConvNeXt features over
    the (augmented, mixup) training set and trains one LightGBM model per
    class.

    Returns:
        A dict with the trained models, feature extractor, metadata, inference
        config and timestamp. Returns None when no spectrograms or features
        are available.
    """
    # Prefer cfg.model_name when the config defines it; fall back to the historical default.
    model_name = getattr(cfg, 'model_name', 'convnextv2_nano.fcmae')

    print("Starting BirdCLEF 2025 training pipeline...")

    # Update debug settings
    cfg.update_debug_settings()

    # Load data
    print("Loading training data...")
    train_df = pd.read_csv(cfg.train_csv)
    print(f"Loaded {len(train_df)} training samples")

    # Load spectrograms
    positive_spectrograms, negative_spectrograms = load_spectrograms_batch(cfg, cfg.batch_spectrograms_dir)
    if positive_spectrograms is None:
        # Without the cache every item would be an all-zero placeholder; training would be meaningless.
        print("No spectrograms loaded. Check cfg.batch_spectrograms_dir.")
        return None

    # Create feature extractor
    print("Initializing ConvNeXt feature extractor...")
    feature_model = ConvNeXtFeatureExtractor(cfg).to(cfg.device)

    # Create dataset and dataloader
    print("Creating dataset...")
    train_dataset = BirdCLEFDatasetFromNPY(
        train_df,
        cfg,
        spectrograms=positive_spectrograms,
        negative_spectrograms=negative_spectrograms,
        mode='train'
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        collate_fn=collate_fn
    )

    # Extract features
    print("Extracting features...")
    X_train, y_train = extract_features(feature_model, train_loader, cfg.device)

    if len(X_train) == 0:
        print("No features extracted. Check your data and model.")
        return None

    print(f"Extracted features shape: {X_train.shape}")
    print(f"Labels shape: {y_train.shape}")
    print("Positive count per class:", np.sum(y_train, axis=0))

    # Positive/negative ratio. Use the same 0.5 threshold that train_lightgbm binarizes
    # with: mixup makes many targets fractional, so an exact `== 1` test would
    # miscount mixed positives as negatives.
    positive_samples = int(np.sum(np.any(y_train > 0.5, axis=1)))
    negative_samples = len(y_train) - positive_samples
    print(f"Positive samples: {positive_samples}")
    print(f"Negative samples: {negative_samples}")
    print(f"Negative ratio: {negative_samples/len(y_train)*100:.1f}%")

    # Train LightGBM
    print("Training LightGBM models...")
    lgbm_models = train_lightgbm(X_train, y_train)

    print(f"Training completed! Trained {len(lgbm_models)} LightGBM models.")

    # Save models and metadata
    print("Saving models...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # 1. LightGBM models, one text file per class index.
    models_dir = os.path.join(cfg.OUTPUT_DIR, f'lgbm_models_{timestamp}')
    os.makedirs(models_dir, exist_ok=True)
    for i, model in enumerate(lgbm_models):
        model.save_model(os.path.join(models_dir, f'lgbm_model_class_{i}.txt'))
    print(f"Saved {len(lgbm_models)} LightGBM models to {models_dir}")

    # 2. ConvNeXt feature extractor.
    feature_model_path = os.path.join(cfg.OUTPUT_DIR, f'convnext_feature_extractor_{timestamp}.pth')
    torch.save({
        'model_state_dict': feature_model.state_dict(),
        'model_config': {
            'model_name': model_name,
            'in_chans': cfg.in_channels,
            'num_classes': 0
        }
    }, feature_model_path)
    print(f"Saved ConvNeXt feature extractor to {feature_model_path}")

    # 3. Species mapping and metadata.
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    species_mapping = {
        'species_ids': taxonomy_df['primary_label'].tolist(),
        'label_to_idx': {label: idx for idx, label in enumerate(taxonomy_df['primary_label'])},
        'num_classes': len(taxonomy_df)
    }

    metadata = {
        'timestamp': timestamp,
        'training_samples': len(train_df),
        'extracted_features_shape': X_train.shape,
        'labels_shape': y_train.shape,
        'positive_samples': positive_samples,
        'negative_samples': int(negative_samples),
        'negative_ratio': float(negative_samples/len(y_train)*100),
        'num_lgbm_models': len(lgbm_models),
        'cfg_settings': {
            'model_name': model_name,
            'batch_size': cfg.batch_size,
            'TARGET_SHAPE': cfg.TARGET_SHAPE,
            'aug_prob': cfg.aug_prob,
            'mixup_alpha': cfg.mixup_alpha,
            'debug': cfg.debug
        },
        'species_mapping': species_mapping
    }

    metadata_path = os.path.join(cfg.OUTPUT_DIR, f'training_metadata_{timestamp}.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f"Saved training metadata to {metadata_path}")

    # 4. Class counts for inference. Cast to native types: json cannot encode numpy ints.
    class_counts_path = None
    if hasattr(train_dataset, 'class_counts'):
        class_counts_path = os.path.join(cfg.OUTPUT_DIR, f'class_counts_{timestamp}.json')
        class_counts = {str(k): int(v) for k, v in train_dataset.class_counts.items()}
        with open(class_counts_path, 'w') as f:
            json.dump(class_counts, f, indent=2)
        print(f"Saved class counts to {class_counts_path}")

    # 5. Inference config.
    inference_config = {
        'models_dir': models_dir,
        'feature_extractor_path': feature_model_path,
        'metadata_path': metadata_path,
        'class_counts_path': class_counts_path,
        'num_classes': len(taxonomy_df),
        'target_shape': cfg.TARGET_SHAPE,
        'model_name': model_name,
        'timestamp': timestamp
    }

    inference_config_path = os.path.join(cfg.OUTPUT_DIR, f'inference_config_{timestamp}.json')
    with open(inference_config_path, 'w') as f:
        json.dump(inference_config, f, indent=2)
    print(f"Saved inference config to {inference_config_path}")

    print("\n" + "="*50)
    print("TRAINING SUMMARY")
    print("="*50)
    print(f"Timestamp: {timestamp}")
    print(f"Training samples: {len(train_df):,}")
    print(f"Features shape: {X_train.shape}")
    print(f"Positive samples: {positive_samples:,}")
    print(f"Negative samples: {negative_samples:,}")
    print(f"LightGBM models: {len(lgbm_models)}")
    print("\nSaved files:")
    print(f"- LightGBM models: {models_dir}")
    print(f"- Feature extractor: {feature_model_path}")
    print(f"- Metadata: {metadata_path}")
    print(f"- Inference config: {inference_config_path}")
    if class_counts_path is not None:
        print(f"- Class counts: {class_counts_path}")
    print("="*50)

    return {
        'lgbm_models': lgbm_models,
        'feature_model': feature_model,
        'metadata': metadata,
        'inference_config': inference_config,
        'timestamp': timestamp
    }

if __name__ == "__main__":
    main()