# 导入相关依赖库.
第一个块用于检查版本.

In [1]:
import os
import random

import numpy as np
import pandas as pd
import sklearn
import timm
import torch
import torchvision
import tqdm

from sklearn.model_selection import KFold
from torch.nn.functional import kl_div, log_softmax
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize

print(f'numpy: {np.__version__}')
print(f'pandas: {pd.__version__}')
print(f'sklearn: {sklearn.__version__}')
print(f'timm: {timm.__version__}')
print(f'torch: {torch.__version__}')
print(f'torchvision: {torchvision.__version__}')
print(f'tqdm: {tqdm.__version__}')

numpy: 1.26.4
pandas: 2.2.2
sklearn: 1.2.2
timm: 0.9.16
torch: 2.1.2
torchvision: 0.16.2
tqdm: 4.66.1


# 设置随机种子, 相关路径信息和超参数.

In [2]:
BATCH_SIZE = 16
DATA_DIR = '/kaggle/input/hms-harmful-brain-activity-classification'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 3
LABELS = ['seizure_vote',
          'lpd_vote',
          'gpd_vote',
          'lrda_vote',
          'grda_vote',
          'other_vote']
LEARNING_RATE = 1e-3
N_FOLD = 5
PRE_TRAINED = True
SEED = 2024
WEIGHT_DECAY = 1e-2

# 设置随机种子.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# 数据预处理.
1. 准备数据(加载频谱图).
2. 创建数据集,包含数据预处理和信号处理.

In [3]:
def prepare_data(data_dir):
    """准备数据."""
    dataframe = pd.read_csv(f'{data_dir}/train.csv')

    # 根据频谱图分组.
    dataframe = dataframe.groupby('spectrogram_id')[LABELS].sum()

    # 每个样本投票数不同, 对投票进行归一化.
    vote_sum = dataframe.sum(axis=1)
    for label in LABELS:
        dataframe[label] /= vote_sum

    # 添加频谱图路径信息.
    dataframe['spec_path'] = dataframe.index.map(
        lambda filename: f'{data_dir}/train_spectrograms/{filename}.parquet'
    )
    dataframe = dataframe.reset_index()

    return dataframe


class SpectrogramDataset(Dataset):
    """频谱图数据集."""
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        # 读取频谱图.
        row_elm = self.dataframe.iloc[index]
        spectrogram = pd.read_parquet(row_elm.spec_path)

        signal = spectrogram.fillna(-1)  # 填充缺失值.
        signal = signal.values[:, 1:]  # 去掉第一个索引列.
        signal = signal.transpose()  # 转换成关于时间的函数.
        signal = torch.tensor(signal, dtype=torch.float32)  # 转换成张量.
        signal = self.preprocess(signal[None, :])  # 增加batch维度再传递.
        signal = Resize([512, 512], antialias=False)(signal)  # 统一大小(不选择抗锯齿).

        label = np.asarray(row_elm.loc[LABELS].values, np.float32)
        label = torch.from_numpy(label)  # 转换成张量.

        return signal, label

    @staticmethod
    def preprocess(signal):
        """处理信息."""
        # 转换成对数, 使得数据更平滑稳定.
        signal = torch.clip(signal, np.exp(-6), np.exp(10))
        signal = torch.log(signal)

        # 进行标准化, 符合ImageNet的预训练数据格式.
        mean, std = torch.mean(signal), torch.std(signal)
        signal = (signal - mean) / (std + 1e-6)  # 常小数避免除零保证数值稳定性.

        return signal

# 创建模型.

In [4]:
def create_model(pre_trained=True,
                 device=torch.device('cpu')):
    """创建模型."""
    model = timm.create_model('tf_efficientnet_b0.ns_jft_in1k',
                              pretrained=pre_trained,
                              num_classes=6,
                              in_chans=1)
    model.to(device)

    return model

# 实现损失函数, 训练代码和验证代码.
1. 输出需要求对数已满足和数据预处理对齐.

In [5]:
def kl_divergence(y_pred, y):
    """计算KL散度."""
    y_pred = log_softmax(y_pred, dim=1)
    y_pred = kl_div(y_pred, y, reduction='batchmean')

    return y_pred


def train(model,
          dataloader,
          loss_fn,
          optimizer,
          epochs=1,
          device=torch.device('cpu')):
    """训练模型."""
    for epoch in tqdm.tqdm(range(epochs)):
        model.train()

        total_loss = 0.0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)

            # 计算损失并反向传播.
            loss = loss_fn(y_pred, y)
            loss.backward()
            total_loss += loss.item()

            # 更新参数.
            optimizer.step()
            optimizer.zero_grad()

        print(f'Epoch: {epoch + 1}, Loss: {total_loss / len(dataloader)}')


def validation(model, dataloader, loss_fn, device=torch.device('cpu')):
    """验证模型."""
    model.eval()

    total_loss = 0.0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)

        # 计算损失.
        loss = loss_fn(y_pred, y)
        total_loss += loss.item()

    print(f'Val Loss: {total_loss / len(dataloader)}')

In [6]:
df = prepare_data(DATA_DIR)

# 使用交叉验证训练.
kfold = KFold(n_splits=N_FOLD, shuffle=True, random_state=SEED)
for fold, (train_set, valid_set) in enumerate(kfold.split(df)):
    print(f'Fold {fold + 1}:')

    # 创建DataLoader.
    train_ds = SpectrogramDataset(df.iloc[train_set])
    train_dl = DataLoader(train_ds,
                          batch_size=BATCH_SIZE,
                          num_workers=os.cpu_count(),
                          drop_last=True)
    valid_ds = SpectrogramDataset(df.iloc[valid_set])
    valid_dl = DataLoader(valid_ds,
                          batch_size=BATCH_SIZE,
                          num_workers=os.cpu_count(),
                          drop_last=True)

    # 创建模型.
    model = create_model(PRE_TRAINED, DEVICE)
    optimizer = AdamW(model.parameters(),
                      lr=LEARNING_RATE,
                      weight_decay=WEIGHT_DECAY)
    train(model, train_dl, kl_divergence, optimizer, EPOCHS, DEVICE)

    # 保存模型.
    torch.save(model.state_dict(), f'fold{fold + 1:02d}-model.pth')

    # 验证模型.
    validation(model, valid_dl, kl_divergence, DEVICE)

Fold 1:


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

 33%|███▎      | 1/3 [03:18<06:36, 198.02s/it]

Epoch: 1, Loss: 0.9158835526850584


 67%|██████▋   | 2/3 [06:20<03:09, 189.07s/it]

Epoch: 2, Loss: 0.6201312914758706


100%|██████████| 3/3 [09:22<00:00, 187.62s/it]

Epoch: 3, Loss: 0.504804401973383





Val Loss: 0.718278951996522
Fold 2:


 33%|███▎      | 1/3 [03:07<06:14, 187.31s/it]

Epoch: 1, Loss: 0.9049592796418306


 67%|██████▋   | 2/3 [06:13<03:06, 186.41s/it]

Epoch: 2, Loss: 0.6147540080890381


100%|██████████| 3/3 [09:18<00:00, 186.33s/it]

Epoch: 3, Loss: 0.4917825357787472





Val Loss: 0.7879917137056803
Fold 3:


 33%|███▎      | 1/3 [03:07<06:15, 187.59s/it]

Epoch: 1, Loss: 0.8960190440574995


 67%|██████▋   | 2/3 [06:14<03:07, 187.14s/it]

Epoch: 2, Loss: 0.6090584173262548


100%|██████████| 3/3 [09:20<00:00, 186.91s/it]

Epoch: 3, Loss: 0.4929387055927043





Val Loss: 0.716630942958722
Fold 4:


 33%|███▎      | 1/3 [03:09<06:18, 189.22s/it]

Epoch: 1, Loss: 0.9015311781022188


 67%|██████▋   | 2/3 [06:15<03:07, 187.47s/it]

Epoch: 2, Loss: 0.6226945290569779


100%|██████████| 3/3 [09:21<00:00, 187.07s/it]

Epoch: 3, Loss: 0.5096829648551752





Val Loss: 0.8221211845068623
Fold 5:


 33%|███▎      | 1/3 [03:08<06:16, 188.39s/it]

Epoch: 1, Loss: 0.904100743129099


 67%|██████▋   | 2/3 [06:13<03:06, 186.59s/it]

Epoch: 2, Loss: 0.6099935338812338


100%|██████████| 3/3 [09:18<00:00, 186.25s/it]

Epoch: 3, Loss: 0.48294984270557223





Val Loss: 0.7360910494550527
