## 2025.5.4
使用 EfficientNetV2-S 训练模型，仅使用给定标签；损失函数为 BCE loss 与 Focal loss 的加权和。（注意：当前 CFG.model_name 设置为 efficientnet_b0，与此说明不一致，请确认。）

## Libraries

In [None]:
import ast
import gc
import logging
import math
import os
import pickle
import random
import time
import warnings
from pathlib import Path

import cv2
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from sklearn import metrics
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Silence all Python warnings and restrict logging output to ERROR and above.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

In [None]:
""" 
    FocalLossBCE Use Example
"""
class FocalLossBCE(torch.nn.Module):
    """Weighted sum of a BCE-with-logits loss and a sigmoid focal loss.

    Args:
        alpha: ``alpha`` forwarded to ``sigmoid_focal_loss``.
        gamma: ``gamma`` forwarded to ``sigmoid_focal_loss``.
        reduction: Reduction mode shared by both loss terms.
        bce_weight: Multiplier applied to the BCE term.
        focal_weight: Multiplier applied to the focal term.
    """

    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 2,
            reduction: str = "mean",
            bce_weight: float = 0.6,
            focal_weight: float = 1.4,
    ):
        super().__init__()
        self.bce_weight, self.focal_weight = bce_weight, focal_weight
        self.alpha, self.gamma = alpha, gamma
        self.reduction = reduction
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, logits, targets):
        """Return ``bce_weight * BCE + focal_weight * focal`` for the logits."""
        weighted_focal = self.focal_weight * torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        weighted_bce = self.bce_weight * self.bce(logits, targets)
        return weighted_bce + weighted_focal


## Configuration

In [None]:
class CFG:
    """Central configuration: data paths, audio features, model and training."""

    # --- runtime ---
    seed = 2025
    debug = True  # when True, update_debug_settings() shrinks the run
    apex = False
    print_freq = 100
    num_workers = 2

    # Name of the label scheme used to build training targets.
    label = 'primary_label'

    # --- paths ---
    OUTPUT_DIR = '/kaggle/working/'

    train_datadir = '/kaggle/input/birdclef-2025/train_audio'
    train_csv = '/kaggle/input/birdclef-2025/train.csv'
    taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'
    # File holding the pre-computed mel spectrograms.
    spectrogram_npy = '/kaggle/input/generate-1-mel-spectrogram/likely_best_audio.pkl'

    # --- model ---
    model_name = 'efficientnet_b0'
    pretrained = True
    in_channels = 1

    # --- audio / spectrogram ---
    LOAD_DATA = True          # use pre-computed spectrograms instead of raw audio
    FS = 32000                # sample rate used when loading audio
    TARGET_DURATION = 5.0     # clip length in seconds
    TARGET_SHAPE = (256, 256)  # spectrogram size fed to the model

    N_FFT = 1024
    HOP_LENGTH = 512
    N_MELS = 128
    FMIN = 50
    FMAX = 14000

    # --- training ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = 10
    batch_size = 32

    n_fold = 5
    selected_folds = [0, 1, 2, 3, 4]

    optimizer = 'AdamW'
    lr = 5e-4
    weight_decay = 1e-5

    scheduler = 'CosineAnnealingLR'
    min_lr = 1e-6
    T_max = epochs

    # --- augmentation ---
    aug_prob = 0.5
    mixup_alpha = 0.5

    def update_debug_settings(self):
        """Shrink the run for debugging: 2 epochs on fold 0 only.

        ``T_max`` is re-synchronised with the shortened ``epochs`` so the
        cosine schedule still spans the whole (debug) run instead of keeping
        the value computed from the full-length configuration.
        """
        if self.debug:
            self.epochs = 2
            self.selected_folds = [0]
            self.T_max = self.epochs


cfg = CFG()



In [None]:
# path = "/kaggle/input/audio-to-mel-spec/audio.pkl"
# with open(path,"rb") as f:
#     all_audio = pickle.load(f)
    
# all_bird_data = all_audio['all_bird_data']

In [None]:
# Report the compute device, then read the taxonomy table to enumerate the
# target species (one class per ``primary_label`` entry).
print(f"Using device: {cfg.device}")
print("Loading taxonomy data...")
taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
species_ids = taxonomy_df['primary_label'].tolist()
num_classes = len(species_ids)
print(f"Number of classes: {num_classes}")

In [None]:
def set_seed(seed=2025):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every random number generator; it is
            also exported as ``PYTHONHASHSEED``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Enable deterministic cuDNN mode and disable its benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs with the configured seed.
set_seed(cfg.seed)

In [None]:
def audio2melspec(audio_data, cfg):
    """Turn a waveform into a min-max normalised log-mel spectrogram.

    Args:
        audio_data: Waveform array. NaN samples are replaced by the mean of
            the non-NaN samples before the transform.
        cfg: Config providing FS, N_FFT, HOP_LENGTH, N_MELS, FMIN and FMAX.

    Returns:
        The dB-scaled mel spectrogram, shifted and scaled by its own min/max.
    """
    if np.isnan(audio_data).any():
        fill_value = np.nanmean(audio_data)
        audio_data = np.nan_to_num(audio_data, nan=fill_value)

    mel_params = dict(
        sr=cfg.FS,
        n_fft=cfg.N_FFT,
        hop_length=cfg.HOP_LENGTH,
        n_mels=cfg.N_MELS,
        fmin=cfg.FMIN,
        fmax=cfg.FMAX,
        power=2.0,
    )
    power_spec = librosa.feature.melspectrogram(y=audio_data, **mel_params)

    # Convert power to decibels relative to the spectrogram's peak.
    db_spec = librosa.power_to_db(power_spec, ref=np.max)

    # The small epsilon keeps the division finite for a constant spectrogram.
    lowest = db_spec.min()
    value_range = db_spec.max() - lowest + 1e-8
    return (db_spec - lowest) / value_range

def process_audio_file(audio_path, cfg):
    """Load an audio file and return the mel spectrogram of its central clip.

    The waveform is loaded at ``cfg.FS``. Recordings shorter than
    ``cfg.TARGET_DURATION`` seconds are tiled until long enough; the central
    window of that duration is then cut out (zero-padded if still short),
    converted with ``audio2melspec`` and resized to ``cfg.TARGET_SHAPE``.

    Args:
        audio_path: Path of the audio file to load.
        cfg: Config providing FS, TARGET_DURATION, TARGET_SHAPE and the mel
            parameters read by ``audio2melspec``.

    Returns:
        A float32 array of shape ``cfg.TARGET_SHAPE`` (rows, cols), or
        ``None`` if any step failed (the error is printed, not raised).
    """
    try:
        audio_data, _ = librosa.load(audio_path, sr=cfg.FS)

        target_samples = int(cfg.TARGET_DURATION * cfg.FS)

        if len(audio_data) == 0:
            # Fail with a clear message instead of a division-by-zero below.
            raise ValueError("audio file contains no samples")

        if len(audio_data) < target_samples:
            # Tile short recordings so the central window is fully populated.
            n_copy = math.ceil(target_samples / len(audio_data))
            audio_data = np.concatenate([audio_data] * n_copy)

        # Extract the central window of `target_samples` samples.
        start_idx = max(0, int(len(audio_data) / 2 - target_samples / 2))
        end_idx = min(len(audio_data), start_idx + target_samples)
        center_audio = audio_data[start_idx:end_idx]

        if len(center_audio) < target_samples:
            center_audio = np.pad(
                center_audio,
                (0, target_samples - len(center_audio)),
                mode='constant',
            )

        mel_spec = audio2melspec(center_audio, cfg)

        # TARGET_SHAPE is compared against ndarray.shape, i.e. (rows, cols),
        # whereas cv2.resize expects dsize as (width, height) = (cols, rows).
        target_rows, target_cols = cfg.TARGET_SHAPE
        if mel_spec.shape != (target_rows, target_cols):
            mel_spec = cv2.resize(
                mel_spec,
                (target_cols, target_rows),
                interpolation=cv2.INTER_LINEAR,
            )

        return mel_spec.astype(np.float32)

    except Exception as e:
        # Deliberate best-effort: callers treat None as "could not process".
        print(f"Error processing {audio_path}: {e}")
        return None

def generate_spectrograms(df, cfg):
    """Generate mel spectrograms for every row of ``df``.

    Args:
        df: DataFrame with ``samplename`` and ``filepath`` columns.
        cfg: Config passed through to ``process_audio_file``; when
            ``cfg.debug`` is true only the first 1000 rows are processed.

    Returns:
        Dict mapping ``samplename`` to its spectrogram. Rows that failed are
        omitted and counted in the printed summary.
    """
    print("Generating mel spectrograms from audio files...")
    start_time = time.time()

    all_bird_data = {}
    errors = []

    # Count rows by position: the DataFrame index is not guaranteed to be a
    # 0..n-1 range, so it cannot be used for the debug cap.
    for position, (_, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
        if cfg.debug and position >= 1000:
            break

        # Read defensively so the error path cannot raise a second time.
        filepath = row.get('filepath')

        try:
            samplename = row['samplename']
            mel_spec = process_audio_file(row['filepath'], cfg)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
            errors.append((filepath, str(e)))
            continue

        if mel_spec is None:
            # process_audio_file already printed the reason; record the
            # failure so the summary below is accurate.
            errors.append((filepath, "spectrogram could not be generated"))
        else:
            all_bird_data[samplename] = mel_spec

    end_time = time.time()
    print(f"Processing completed in {end_time - start_time:.2f} seconds")
    print(f"Successfully processed {len(all_bird_data)} files out of {len(df)}")
    print(f"Failed to process {len(errors)} files")

    return all_bird_data

In [None]:
class BirdCLEFDatasetFromNPY(Dataset):
    """Dataset yielding a mel spectrogram, a multi-hot target and the filename.

    Spectrograms are looked up by ``samplename`` in the ``spectrograms``
    mapping; when ``cfg.LOAD_DATA`` is false, missing ones are computed from
    the audio file instead. A sample with no spectrogram yields an all-zero
    image.

    Args:
        df: Rows of the training CSV. ``filepath`` and ``samplename`` columns
            are added in place when absent.
        cfg: Configuration object. ``cfg.num_class`` is set here as a side
            effect so the model can size its output layer.
        spectrograms: Optional mapping ``samplename -> 2-D array``.
        mode: ``"train"`` enables augmentation and missing-data warnings.
    """

    def __init__(self, df, cfg, spectrograms=None, mode="train"):
        self.df = df
        self.cfg = cfg
        self.mode = mode

        self.spectrograms = spectrograms

        taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
        self.primary_to_class = dict(zip(taxonomy_df['primary_label'], taxonomy_df['class_name']))

        if cfg.label == "primary_label":
            self.species_ids = taxonomy_df['primary_label'].tolist()
        else:
            # Class-level mode: __getitem__ encodes the taxonomy class of the
            # primary label, so the label space is the distinct class names
            # (kept in order of first appearance in the taxonomy file).
            self.species_ids = taxonomy_df['class_name'].drop_duplicates().tolist()

        cfg.num_class = len(self.species_ids)
        self.num_class = len(self.species_ids)

        self.label_to_idx = {label: idx for idx, label in enumerate(self.species_ids)}

        if 'filepath' not in self.df.columns:
            self.df['filepath'] = self.cfg.train_datadir + '/' + self.df.filename

        if 'samplename' not in self.df.columns:
            # "<folder>/<name>.<ext>" -> "<folder>-<name>"
            self.df['samplename'] = self.df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])

        sample_names = set(self.df['samplename'])
        if self.spectrograms:
            found_samples = sum(1 for name in sample_names if name in self.spectrograms)
            print(f"Found {found_samples} matching spectrograms for {mode} dataset out of {len(self.df)} samples")

        if cfg.debug:
            self.df = self.df.sample(min(1000, len(self.df)), random_state=cfg.seed).reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'melspec', 'target', 'filename'}`` for row ``idx``."""
        row = self.df.iloc[idx]
        samplename = row['samplename']
        spec = None

        if self.spectrograms and samplename in self.spectrograms:
            spec = self.spectrograms[samplename]
        elif not self.cfg.LOAD_DATA:
            spec = process_audio_file(row['filepath'], self.cfg)

        if spec is None:
            spec = np.zeros(self.cfg.TARGET_SHAPE, dtype=np.float32)
            if self.mode == "train":  # Only print warning during training
                print(f"Warning: Spectrogram for {samplename} not found and could not be generated")

        spec = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

        if self.mode == "train" and random.random() < self.cfg.aug_prob:
            spec = self.apply_spec_augmentations(spec)

        # Use the dataset's own config rather than a module-level global.
        if self.cfg.label == "primary_label":
            target = self.encode_label(row['primary_label'])
            # Some recordings also carry secondary labels; mark those too.
            if 'secondary_labels' in row:
                for label in self._parse_secondary_labels(row['secondary_labels']):
                    if label in self.label_to_idx:
                        target[self.label_to_idx[label]] = 1.0
        else:
            target = self.encode_label(self.primary_to_class[row['primary_label']])

        return {
            'melspec': spec,
            'target': torch.tensor(target, dtype=torch.float32),
            'filename': row['filename']
        }

    @staticmethod
    def _parse_secondary_labels(raw):
        """Normalise a ``secondary_labels`` cell to an iterable of labels.

        Accepts a Python-literal string (e.g. ``"['a', 'b']"``), an already
        parsed sequence, or a missing value (``None`` / NaN / empty string),
        for which an empty list is returned. Strings are parsed with
        ``ast.literal_eval`` so CSV content is never executed as code.
        """
        if raw is None:
            return []
        if isinstance(raw, float) and np.isnan(raw):
            return []
        if isinstance(raw, str):
            text = raw.strip()
            if not text:
                return []
            parsed = ast.literal_eval(text)
            # A bare quoted label is one label, not a sequence of characters.
            return [parsed] if isinstance(parsed, str) else parsed
        return raw

    def apply_spec_augmentations(self, spec):
        """Apply random masking and gain/bias jitter to a (1, H, W) spectrogram."""

        # Time masking: zero a band of columns (last axis).
        if random.random() < 0.5:
            num_masks = random.randint(1, 3)
            for _ in range(num_masks):
                width = random.randint(5, 20)
                start = random.randint(0, spec.shape[2] - width)
                spec[0, :, start:start+width] = 0

        # Frequency masking: zero a band of rows (middle axis).
        if random.random() < 0.5:
            num_masks = random.randint(1, 3)
            for _ in range(num_masks):
                height = random.randint(5, 20)
                start = random.randint(0, spec.shape[1] - height)
                spec[0, start:start+height, :] = 0

        # Random brightness/contrast, clamped back to [0, 1].
        if random.random() < 0.5:
            gain = random.uniform(0.8, 1.2)
            bias = random.uniform(-0.1, 0.1)
            spec = spec * gain + bias
            spec = torch.clamp(spec, 0, 1)

        return spec

    def encode_label(self, label):
        """Encode ``label`` as a one-hot vector (all zeros if it is unknown)."""
        target = np.zeros(self.num_class)
        if label in self.label_to_idx:
            target[self.label_to_idx[label]] = 1.0
        return target

In [None]:
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    ``None`` samples are dropped. Every key becomes a list of the per-sample
    values; ``target`` tensors are stacked, and ``melspec`` tensors are
    stacked only when all of them share one shape (otherwise they stay a
    list). An empty batch yields ``{}``.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return {}

    # The first sample defines the set of keys collected for the batch.
    merged = {key: [] for key in samples[0]}
    for sample in samples:
        for key, value in sample.items():
            merged[key].append(value)

    targets = merged.get('target')
    if targets is not None and isinstance(targets[0], torch.Tensor):
        merged['target'] = torch.stack(targets)

    specs = merged.get('melspec')
    if specs is not None and isinstance(specs[0], torch.Tensor):
        distinct_shapes = {str(spec.shape) for spec in specs}
        if len(distinct_shapes) == 1:
            merged['melspec'] = torch.stack(specs)

    return merged

## Model Definition

In [None]:
class BirdCLEFModel(nn.Module):
    """timm backbone followed by a linear multi-label classification head.

    Args:
        cfg: Configuration providing ``model_name``, ``pretrained``,
            ``in_channels`` and ``num_class``; an optional positive
            ``mixup_alpha`` enables mixup inside ``forward``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Pretrained feature extractor created through timm.
        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=cfg.pretrained,
            in_chans=cfg.in_channels,
            drop_rate=0.2,
            drop_path_rate=0.2
        )

        # Strip the backbone's own classifier, remembering its input width.
        if 'efficientnet' in cfg.model_name:
            backbone_out = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif 'resnet' in cfg.model_name:
            backbone_out = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            backbone_out = self.backbone.get_classifier().in_features
            self.backbone.reset_classifier(0, '')

        # Used only when the backbone returns a 4-D feature map.
        self.pooling = nn.AdaptiveAvgPool2d(1)

        self.feat_dim = backbone_out

        # Head mapping backbone features to one logit per class.
        self.classifier = nn.Linear(backbone_out, cfg.num_class)

        self.mixup_enabled = hasattr(cfg, 'mixup_alpha') and cfg.mixup_alpha > 0
        if self.mixup_enabled:
            self.mixup_alpha = cfg.mixup_alpha

    def forward(self, x, targets=None):
        """Compute logits, optionally applying mixup.

        Mixup is applied only in training mode, when it is enabled and
        ``targets`` is given. In that case the return value is
        ``(logits, loss)`` where ``loss`` is the mixup-weighted
        BCE-with-logits loss; otherwise only ``logits`` is returned.
        """
        use_mixup = self.training and self.mixup_enabled and targets is not None

        if use_mixup:
            x, targets_a, targets_b, lam = self.mixup_data(x, targets)
        else:
            targets_a, targets_b, lam = None, None, None

        features = self.backbone(x)

        # Some backbones return a dict; take its 'features' entry.
        if isinstance(features, dict):
            features = features['features']

        # Pool and flatten an un-pooled 4-D feature map.
        if len(features.shape) == 4:
            features = self.pooling(features)
            features = features.view(features.size(0), -1)

        logits = self.classifier(features)

        if use_mixup:
            loss = self.mixup_criterion(F.binary_cross_entropy_with_logits,
                                        logits, targets_a, targets_b, lam)
            return logits, loss

        return logits

    def mixup_data(self, x, targets):
        """Blend each sample with a randomly chosen partner from the batch.

        Returns:
            ``(mixed_x, targets, targets[perm], lam)`` where ``lam`` is drawn
            from ``Beta(mixup_alpha, mixup_alpha)``.
        """
        batch_size = x.size(0)

        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

        # Random pairing of the samples within the batch.
        indices = torch.randperm(batch_size).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[indices]

        return mixed_x, targets, targets[indices], lam

    def mixup_criterion(self, criterion, pred, y_a, y_b, lam):
        """Return ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [None]:
def get_optimizer(model, cfg):
    """Build the optimizer named by ``cfg.optimizer`` for ``model``.

    Supported names are ``'Adam'``, ``'AdamW'`` and ``'SGD'`` (momentum 0.9);
    all use ``cfg.lr`` and ``cfg.weight_decay``.

    Raises:
        NotImplementedError: If ``cfg.optimizer`` is any other name.
    """
    # name -> (optimizer class, extra keyword arguments)
    choices = {
        'Adam': (optim.Adam, {}),
        'AdamW': (optim.AdamW, {}),
        'SGD': (optim.SGD, {'momentum': 0.9}),
    }

    if cfg.optimizer not in choices:
        raise NotImplementedError(f"Optimizer {cfg.optimizer} not implemented")

    optimizer_cls, extra_kwargs = choices[cfg.optimizer]
    return optimizer_cls(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        **extra_kwargs,
    )

def get_scheduler(optimizer, cfg):
    """Build the epoch-level LR scheduler named by ``cfg.scheduler``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        cfg: Config providing ``scheduler`` plus ``T_max`` / ``min_lr`` /
            ``epochs`` as required by the chosen scheduler.

    Returns:
        The scheduler, or ``None`` for ``'OneCycleLR'`` (which needs the
        number of steps per epoch and is therefore built by the caller) and
        for any unrecognised name.
    """
    if cfg.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.T_max,
            eta_min=cfg.min_lr
        )
    elif cfg.scheduler == 'ReduceLROnPlateau':
        # `verbose` is deprecated in recent PyTorch releases and is therefore
        # not passed; read the current rate from optimizer.param_groups.
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2,
            min_lr=cfg.min_lr
        )
    elif cfg.scheduler == 'StepLR':
        # Guard against step_size == 0 when fewer than 3 epochs are run.
        scheduler = lr_scheduler.StepLR(
            optimizer,
            step_size=max(1, cfg.epochs // 3),
            gamma=0.5
        )
    else:
        # Includes 'OneCycleLR', which the caller constructs itself.
        scheduler = None

    return scheduler

def get_criterion(cfg):
    """Return the training loss: a ``FocalLossBCE`` with its default settings.

    Args:
        cfg: Not read by this function.
    """
    criterion = FocalLossBCE()
    return criterion

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device, scheduler=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train. If it returns a ``(logits, loss)`` tuple the
            model's own loss is used instead of ``criterion``.
        loader: Yields dicts with ``'melspec'`` and ``'target'`` entries;
            ``'melspec'`` is a stacked tensor, or a list of tensors when the
            batch could not be stacked.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to ``(logits, targets)``.
        device: Device the batch tensors are moved to.
        scheduler: Stepped after every batch only if it is a ``OneCycleLR``.

    Returns:
        ``(mean loss, macro AUC)`` over the epoch.
    """
    model.train()
    losses = []
    all_targets = []
    all_outputs = []

    pbar = tqdm(loader, total=len(loader), desc="Training")

    for batch in pbar:
        optimizer.zero_grad()

        if isinstance(batch['melspec'], list):
            # Un-stacked batch: run the samples one at a time and accumulate
            # their gradients. Gradients are cleared once per batch (above),
            # not per sample, so every sample contributes to the update; each
            # loss is scaled so the sum equals the gradient of the mean loss.
            n_items = len(batch['melspec'])
            batch_outputs = []
            batch_losses = []

            for spec, item_target in zip(batch['melspec'], batch['target']):
                inputs = spec.unsqueeze(0).to(device)
                target = item_target.unsqueeze(0).to(device)

                output = model(inputs)
                loss = criterion(output, target)
                (loss / n_items).backward()

                batch_outputs.append(output.detach().cpu())
                batch_losses.append(loss.item())

            optimizer.step()
            outputs = torch.cat(batch_outputs, dim=0).numpy()
            loss = np.mean(batch_losses)
            targets = batch['target'].numpy()

        else:
            inputs = batch['melspec'].to(device)
            targets = batch['target'].to(device)

            outputs = model(inputs)

            if isinstance(outputs, tuple):
                # The model computed its own loss.
                outputs, loss = outputs
            else:
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            outputs = outputs.detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()

        # OneCycleLR is a per-step schedule; other schedulers are stepped
        # once per epoch by the caller.
        if scheduler is not None and isinstance(scheduler, lr_scheduler.OneCycleLR):
            scheduler.step()

        all_outputs.append(outputs)
        all_targets.append(targets)
        losses.append(loss if isinstance(loss, float) else loss.item())

        pbar.set_postfix({
            'train_loss': np.mean(losses[-10:]) if losses else 0,
            'lr': optimizer.param_groups[0]['lr']
        })

    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)
    auc = calculate_auc(all_targets, all_outputs)
    avg_loss = np.mean(losses)

    return avg_loss, auc
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Yields dicts with ``'melspec'`` and ``'target'`` entries;
            ``'melspec'`` may be a stacked tensor or a list of tensors.
        criterion: Loss applied to ``(logits, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean loss, macro AUC)`` over the whole loader.
    """
    model.eval()
    epoch_losses, collected_targets, collected_outputs = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            specs = batch['melspec']

            if not isinstance(specs, list):
                # Regular stacked batch.
                inputs = specs.to(device)
                targets = batch['target'].to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)

                outputs = outputs.detach().cpu().numpy()
                targets = targets.detach().cpu().numpy()
            else:
                # Un-stacked batch: score each sample separately.
                item_outputs, item_losses = [], []

                for position, spec in enumerate(specs):
                    inputs = spec.unsqueeze(0).to(device)
                    target = batch['target'][position].unsqueeze(0).to(device)

                    output = model(inputs)
                    item_loss = criterion(output, target)

                    item_outputs.append(output.detach().cpu())
                    item_losses.append(item_loss.item())

                outputs = torch.cat(item_outputs, dim=0).numpy()
                loss = np.mean(item_losses)
                targets = batch['target'].numpy()

            collected_outputs.append(outputs)
            collected_targets.append(targets)
            epoch_losses.append(loss if isinstance(loss, float) else loss.item())

    stacked_outputs = np.concatenate(collected_outputs)
    stacked_targets = np.concatenate(collected_targets)

    auc = calculate_auc(stacked_targets, stacked_outputs)
    avg_loss = np.mean(epoch_losses)

    return avg_loss, auc

def calculate_auc(targets, outputs):
    """Macro-averaged ROC AUC over the classes that can be scored.

    Args:
        targets: Array of shape (n_samples, n_classes) with 0/1 labels.
        outputs: Array of the same shape holding raw logits.

    Returns:
        Mean per-class AUC. Classes with no positive sample, or with only
        positive samples, are skipped because ROC AUC is undefined for them;
        ``0.0`` is returned when no class can be scored.
    """
    n_samples, num_classes = targets.shape
    aucs = []

    # Sigmoid turns logits into probabilities.
    probs = 1 / (1 + np.exp(-outputs))

    positives_per_class = targets.sum(axis=0)

    for i in range(num_classes):
        # roc_auc_score needs both a positive and a negative sample.
        if 0 < positives_per_class[i] < n_samples:
            class_auc = roc_auc_score(targets[:, i], probs[:, i])
            aucs.append(class_auc)

    return np.mean(aucs) if aucs else 0.0


## Training!

In [None]:
def run_training(df, cfg):
    """Run stratified K-fold training and persist the per-fold metrics.

    For every fold listed in ``cfg.selected_folds`` a fresh model is trained
    for ``cfg.epochs`` epochs; the checkpoint with the best validation AUC is
    written to ``model_fold{fold}.pth``. The metric histories and best scores
    are pickled as ``(histories, best_scores)`` into the output directory.

    Args:
        df: Training metadata; a ``class_name`` column (and, when spectrograms
            are generated on the fly, ``filepath`` / ``samplename``) is added
            in place.
        cfg: Run configuration. ``cfg.LOAD_DATA`` is set to ``False`` when the
            pre-computed spectrograms cannot be loaded.

    Notes:
        In ``primary_label`` mode the module-level names ``result_primary``
        and ``pkl_name_1`` are used for the history list and the pickle file
        name when they exist; otherwise a fresh list and
        ``results_<label>.pkl`` are used.
    """
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    primary_to_class = dict(zip(taxonomy_df['primary_label'], taxonomy_df['class_name']))

    df["class_name"] = df['primary_label'].apply(lambda x: primary_to_class.get(x))

    # Resolve where the histories are collected and how the pickle is named,
    # without failing when the optional module-level names are absent.
    result, pkl_name = None, None
    if cfg.label == "primary_label":
        module_globals = globals()
        result = module_globals.get("result_primary")
        pkl_name = module_globals.get("pkl_name_1")
    if result is None:
        result = []
    if pkl_name is None:
        pkl_name = f"results_{cfg.label}.pkl"

    if cfg.debug:
        cfg.update_debug_settings()

    spectrograms = None
    if cfg.LOAD_DATA:

        print("run_training：Loading pre-computed mel spectrograms from NPY file...")
        try:
            # NOTE(security): allow_pickle=True unpickles the file; only load
            # spectrogram files from a trusted source.
            loaded = np.load(cfg.spectrogram_npy, allow_pickle=True)
            # A dict saved with np.save comes back wrapped in a 0-d object
            # array; a plain pickle file comes back as the object itself.
            spectrograms = loaded.item() if isinstance(loaded, np.ndarray) else loaded
            print(f"Loaded {len(spectrograms)} pre-computed mel spectrograms")

        except Exception as e:
            print(f"run_training：Error loading pre-computed spectrograms: {e}")
            print("run_training：Will generate spectrograms on-the-fly instead.")
            spectrograms = None
            cfg.LOAD_DATA = False

    if not cfg.LOAD_DATA:
        print("run_training：Will generate spectrograms on-the-fly during training.")
        if 'filepath' not in df.columns:
            df['filepath'] = cfg.train_datadir + '/' + df.filename
        if 'samplename' not in df.columns:
            df['samplename'] = df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])

    skf = StratifiedKFold(n_splits=cfg.n_fold, shuffle=True, random_state=cfg.seed)

    best_scores = []
    trained_folds = []  # fold index belonging to each entry of best_scores

    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df[cfg.label])):
        if fold not in cfg.selected_folds:
            continue

        print(f'\n{"="*30} Fold {fold} {"="*30}')

        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)

        print(f'run_training：Training set: {len(train_df)} samples')
        print(f'run_training：Validation set: {len(val_df)} samples')

        # With spectrograms=None the datasets compute spectrograms from audio.
        train_dataset = BirdCLEFDatasetFromNPY(train_df, cfg, spectrograms=spectrograms, mode='train')
        val_dataset = BirdCLEFDatasetFromNPY(val_df, cfg, spectrograms=spectrograms, mode='valid')

        train_loader = DataLoader(
            train_dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            drop_last=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,  # pinned host memory speeds up host-to-GPU copies
            collate_fn=collate_fn
        )

        model = BirdCLEFModel(cfg).to(cfg.device)
        optimizer = get_optimizer(model, cfg)
        criterion = get_criterion(cfg)

        if cfg.scheduler == 'OneCycleLR':
            # Needs the number of steps per epoch, so it is built here.
            scheduler = lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=cfg.lr,
                steps_per_epoch=len(train_loader),
                epochs=cfg.epochs,
                pct_start=0.1
            )
        else:
            scheduler = get_scheduler(optimizer, cfg)

        best_auc = 0
        best_epoch = 0
        temp = {
            "train_loss": [],
            "train_auc": [],
            "val_loss": [],
            "val_auc": []
        }

        for epoch in range(cfg.epochs):
            print(f"\nEpoch {epoch+1}/{cfg.epochs}")

            train_loss, train_auc = train_one_epoch(
                model,
                train_loader,
                optimizer,
                criterion,
                cfg.device,
                scheduler if isinstance(scheduler, lr_scheduler.OneCycleLR) else None
            )

            val_loss, val_auc = validate(model, val_loader, criterion, cfg.device)

            # Epoch-level schedulers are stepped here; OneCycleLR is stepped
            # per batch inside train_one_epoch.
            if scheduler is not None and not isinstance(scheduler, lr_scheduler.OneCycleLR):
                if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

            print(f"Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")
            temp["train_loss"].append(train_loss)
            temp["train_auc"].append(train_auc)
            temp["val_loss"].append(val_loss)
            temp["val_auc"].append(val_auc)

            if val_auc > best_auc:
                best_auc = val_auc
                best_epoch = epoch + 1
                print(f"New best AUC: {best_auc:.4f} at epoch {best_epoch}")

                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                    'epoch': epoch,
                    'val_auc': val_auc,
                    'train_auc': train_auc,
                    'cfg': cfg
                }, f"model_fold{fold}.pth")

        best_scores.append(best_auc)
        trained_folds.append(fold)
        result.append(temp)
        print(f"\nBest AUC for fold {fold}: {best_auc:.4f} at epoch {best_epoch}")

        # Drop the references, then release cached CUDA memory and trigger a
        # garbage-collection pass before the next fold.
        del model, optimizer, scheduler, train_loader, val_loader

        torch.cuda.empty_cache()
        gc.collect()

    final_result = (result, best_scores)
    path = os.path.join(getattr(cfg, "OUTPUT_DIR", "/kaggle/working/"), pkl_name)
    with open(path, "wb") as f:
        pickle.dump(final_result, f)
        print("Save %s." % path)

    print("\n" + "="*60)
    print("Cross-Validation Results:")
    # Pair each score with the fold that produced it rather than relying on
    # the ordering of cfg.selected_folds.
    for fold, score in zip(trained_folds, best_scores):
        print(f"Fold {fold}: {score:.4f}")
    print(f"Mean AUC: {np.mean(best_scores):.4f}")
    print("="*60)

In [None]:
def main_primary():
    """Train on primary labels: read the train CSV, set the label mode, train."""
    print("\n加载train_csv")
    train_df = pd.read_csv(cfg.train_csv)
    cfg.label = "primary_label"

    print("\n开始训练")
    print(f"LOAD_DATA is set to {cfg.LOAD_DATA}")
    # Tell the user whether pre-computed spectrograms will be used.
    mode_message = (
        "使用已经预处理好的音频数据进行训练"
        if cfg.LOAD_DATA
        else "先进行预处理，在进行训练"
    )
    print(mode_message)

    run_training(train_df, cfg)

    print("\n训练完成")

In [None]:
if __name__ == "__main__":
    # Module-level names read by run_training in primary-label mode: the list
    # that collects per-fold metric histories and the results pickle name.
    result_primary=[]
    pkl_name_1="results_primary_effnetv2.pkl"
    main_primary()
