In [1]:
# -*- coding: utf-8 -*-
# Author: Vi
# Created on: 2024-06-19 22:03:37
# Description: 自定义数据集，训练模型，测试模型

import datetime
import os
import random
from argparse import Namespace

import numpy as np
import torch
import tqdm
import wandb
from dotenv import load_dotenv
from torch import nn
from torch.utils.data import DataLoader

from models.panns import CNN10

from datasets import DatasetFactory, Category, Label
from datasets import SupportedSources as Sources
from datasets.SupportedSources import SupportedSourceTypes as SST

from utils import config
import utils.wlog as wlog
from utils.audio.features import get_feature_transformer, FeatureType
from utils.pytorch import Trainer, Tester

# Load variables from a .env file into os.environ; wandb.login later in the
# notebook reads WANDB_API_KEY from os.environ.
load_dotenv()

True

In [2]:
# Experiment hyper-parameters. The whole namespace is also passed to wandb.init
# as the run config, so every value here is recorded with the run.
parameters = Namespace(
    epochs=10,  # number of training epochs
    SR=22050,  # sample rate; can be lowered to speed up quick test runs
    DURATION=10,  # length of each training audio clip, in seconds
    BATCH_SIZE=32,  # batch size
    ACCURACY_THRESHOLD=0.5,  # checkpoints are saved only when test accuracy exceeds this
    FEATURE_TYPE=FeatureType.MEL_SPECTROGRAM,  # audio feature type fed to the model
    SEED=42,  # random seed for Python, NumPy and PyTorch
)

# Seed every RNG the pipeline may draw from. The training DataLoader shuffles
# and the model initialises its weights randomly; the datasets presumably also
# use randomness (TODO confirm). Seeding makes runs reproducible.
random.seed(parameters.SEED)
np.random.seed(parameters.SEED)
torch.manual_seed(parameters.SEED)

In [3]:
def get_extractor(feature_type: FeatureType):
    """Build the feature extractor for ``feature_type``.

    The extractor's keyword arguments are taken from ``config.features``,
    keyed by the enum member's name.
    """
    feature_kwargs = config.features[feature_type.name]
    return get_feature_transformer(feature_type, **feature_kwargs)

def get_dataloader(train_data, test_data, batch_size=None):
    """Wrap the train and test datasets in DataLoaders.

    Args:
        train_data: training dataset; its loader reshuffles every epoch.
        test_data: test dataset; its loader keeps dataset order (no shuffling).
        batch_size: samples per batch. ``None`` (the default) means
            ``parameters.BATCH_SIZE``. The value is read at call time, so
            editing ``parameters`` and re-running later cells takes effect
            without re-defining this function. A default argument would be
            frozen at definition time.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    if batch_size is None:
        batch_size = parameters.BATCH_SIZE
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_dataloader, test_dataloader

def get_model(input_size, num_classes):
    """Create a CNN10 classifier for ``num_classes`` classes.

    ``input_size`` is forwarded unchanged to CNN10's ``input_size`` argument.
    """
    return CNN10(num_class=num_classes, input_size=input_size)

In [4]:
# Log in to Weights & Biases with the API key from the environment
# (load_dotenv() in the first cell can populate it from a .env file).
wandb.login(key=os.environ['WANDB_API_KEY'], force=True)

# Take a single timestamp so the run name and the checkpoint directory come
# from the same instant. Two separate now() calls could straddle midnight and
# disagree on the date.
run_started_at = datetime.datetime.now()
time_point = run_started_at.strftime("%Y%m%d_%H%M%S")

wandb.init(
    project='ESIL-Noise.Model-S-Experiment-Test',
    name=time_point,
    config=vars(parameters),
)

# Checkpoints of this run are written under training/<YYYY-MM-DD>.
save_dir = f'training/{run_started_at.strftime("%Y-%m-%d")}'
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(save_dir, exist_ok=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreviy[0m ([33mesil[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Vi\.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

In [5]:
# Custom dataset: thunder, frog calls, one merged "bird calls" label, and every
# child of the TRAFFIC source.

# Seven bird-species call recordings from the NATURE source, merged into one label.
bird_species = ['北红尾鸲叫声', '叉尾太阳鸟叫声', '大鹰鹃叫声', '强脚树莺叫声', '普通夜鹰叫声', '棕颈钩嘴鹛叫声', '淡脚柳莺叫声']
bird_label = Label(
    name='鸟叫',  # "bird calls"
    sources=Sources.get_data_source(SST.NATURE).get_childs(bird_species),
)

thunder_label = Sources.get_data_source(SST.NATURE).get_child("雷声")  # thunder
frog_label = Sources.get_data_source(SST.NATURE).get_child("蛙声")  # frog calls
traffic_labels = Sources.get_data_source(SST.TRAFFIC).childs

labels = [thunder_label, frog_label, bird_label] + traffic_labels
category = Category('Noise', labels)

In [6]:
# Simpler two-class dataset for quick smoke tests.
# NOTE: while USE_SMOKE_TEST_DATASET is True, this overrides the `category`
# built in the previous cell, so the full custom dataset is NOT used.
# Set it to False to train on the full custom dataset.
USE_SMOKE_TEST_DATASET = True
if USE_SMOKE_TEST_DATASET:
    category = Category("Noise", Sources.get_data_source(SST.NATURE).childs[:2])

In [7]:
# Show the category's labels as a DataFrame (the same information is also
# available via .labels_info).
category.labels_dataframe

Unnamed: 0,id,name,length
0,0,北红尾鸲叫声,1200
1,1,叉尾太阳鸟叫声,1200


In [8]:
# Record the label table in wandb so the run shows which classes it was trained on.
labels_table = wlog.df2table(category.labels_dataframe)
wandb.log({"Labels": labels_table})

In [9]:
# Number of target classes: one per label in the category.
num_classes = len(category.labels)
print(f"num_classes={num_classes}")

num_classes=2


In [10]:
# Build the train and test datasets from the same category with identical
# audio settings. Each split gets its own freshly built feature extractor;
# the train split is created first.
dataset_factory = DatasetFactory(category)

train_data, test_data = (
    dataset_factory.create_dataset(
        train=is_train_split,
        target_sr=parameters.SR,
        duration=parameters.DURATION,
        extractor=get_extractor(parameters.FEATURE_TYPE),
    )
    for is_train_split in (True, False)
)

In [11]:
# Helpers passed to the tester so that mispredicted test samples can be
# recorded with their audio file name and label name.
def get_file_path(index):
    """Return the base file name of the test-set audio sample at ``index``."""
    # NOTE(review): relies on the dataset's private ``_get_audio_path`` method.
    full_audio_path = test_data._get_audio_path(index)
    return os.path.basename(full_audio_path)

def get_label_name(index):
    """Return the name of the category label at ``index`` (presumably a class index)."""
    matched_label = category.get_label(index)
    return matched_label.name

In [12]:
# Wrap both datasets in DataLoaders (batch size from `parameters`).
train_dataloader, test_dataloader = get_dataloader(train_data=train_data, test_data=test_data)

# The model's input size is axis 2 of the first test sample's features
# (presumably the time-frame axis — TODO confirm against the extractor).
input_size = test_data[0][0].shape[2]
print(f"input_size={input_size}")
model = get_model(input_size=input_size, num_classes=num_classes)

# Multi-class classification loss.
loss = nn.CrossEntropyLoss()

# Trainer; using_amp=False presumably keeps mixed-precision training off.
trainer = Trainer(model, train_dataloader, loss, using_amp=False)

# To resume from a checkpoint, reload both the model and the optimizer state, e.g.:
# trainer.reload_trainer(model_path='test_model.pth', optimizer_path='test_params.pth')


input_size=431


In [13]:
# Optional: evaluate the current (untrained or reloaded) model once before training.
# Disabled by default; set RUN_PRETRAINING_EVAL = True to run it. The training
# cell below resets `metrics`, so running this does not affect training.
RUN_PRETRAINING_EVAL = False
if RUN_PRETRAINING_EVAL:
    tester = Tester.from_trainer(trainer, test_dataloader, num_classes, accuracy=0.0)
    metrics, bad_cases = tester.test_an_epoch(
        get_file_path=get_file_path,
        get_label_name=get_label_name,
        tqdm_instance=None,
    )

In [14]:
metrics = None  # evaluation metrics of the previous epoch; None before the first epoch

# Training loop: one training pass and one test pass per epoch.
tqdm_instance = tqdm.tqdm(range(parameters.epochs))
for i in tqdm_instance:
    epoch = i + 1  # 1-based epoch number, used for both wandb logging and checkpoint names
    train_loss = trainer.train_an_epoch(tqdm_instance=tqdm_instance)

    # The tester is given the previous epoch's accuracy (0.0 on the first epoch).
    previous_accuracy = 0.0 if metrics is None else metrics.accuracy
    tester = Tester.from_trainer(trainer, test_dataloader, num_classes, accuracy=previous_accuracy)
    metrics, bad_cases = tester.test_an_epoch(
        get_file_path=get_file_path,
        get_label_name=get_label_name,
        tqdm_instance=tqdm_instance,
    )

    # Save a checkpoint only for epochs whose test accuracy clears the threshold.
    if metrics.accuracy > parameters.ACCURACY_THRESHOLD:
        tester.save_model(os.path.join(save_dir, f"model_{epoch}_{metrics.accuracy:.2f}.pth"))

    # Log the epoch's metrics and its mispredicted samples in ONE wandb.log call.
    # Every wandb.log call advances wandb's step counter, so two calls per epoch
    # would put the metrics and the bad-case table on different steps.
    log_info = {"Epoch": epoch, "train_loss": train_loss}
    log_info.update(
        {key: value for key, value in metrics.model_dump().items() if value is not None}
    )
    log_info["bad_cases"] = wlog.df2table(bad_cases)
    wandb.log(log_info)

[valid] accuracy: 0.6417, loss: 0.6815: 100%|██████████| 10/10 [22:13<00:00, 133.31s/it]


In [15]:
wandb.finish() # end the wandb run started by wandb.init() earlier in this notebook

VBox(children=(Label(value='0.477 MB of 0.477 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Epoch,▁▂▃▃▄▅▆▆▇█
accuracy,███▇▁▅▅█▇▅
auc,▆▆█▄▁▄▃▆▅▄
f1_score,███▇▁▅▆█▇▅
f1_score_micro,███▇▁▅▅█▇▅
loss,█▅▇▁▂▃▃▅▂▁
precision,▇█▇▇▁▅▅▇▆▅
recall,███▇▁▅▅█▇▅
train_loss,▆▇▃▃▅▅▁▃▆█

0,1
Epoch,10.0
accuracy,0.64167
auc,0.72992
f1_score,0.60513
f1_score_micro,0.64167
loss,0.68147
precision,0.72489
recall,0.64167
train_loss,0.69156
