In [None]:

# -*- coding: utf-8 -*-
# Author: Vi
# Created on: 2024-06-26 09:17:19
# Description: Experiments for tuning and optimizing the model
import os
import wandb
import datetime
import torch
from dotenv import load_dotenv
from functools import partial
from tqdm import tqdm
import utils.wlog as wlog
from utils.common import clear_jupyter

load_dotenv()  # pull variables such as WANDB_API_KEY from a .env file into os.environ

# Prefer CUDA when it is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Log in to W&B with the key taken from the environment (relogin=True forces a new login).
# NOTE(review): if WANDB_API_KEY is unset, key=None is passed — presumably wandb then
# falls back to its own credential lookup; confirm for the installed wandb version.
wandb.login(relogin=True, key=os.getenv("WANDB_API_KEY"))

In [None]:
# Project Configuration
from argparse import Namespace

# Static feature-extraction settings. n_mels is deliberately absent here:
# it is a swept hyperparameter (see sweep_config["parameters"]).
feature_config = Namespace(
    type="MEL_SPECTROGRAM",
    args=dict(
        sample_rate=32000,
        n_fft=1024,
        hop_length=512,
        win_length=None,
    ),
)

# Fixed (non-swept) project settings; passed to wandb.init as the run config.
CONFIG = Namespace(
    project_name="Noise-Model-S-Sweep-Test",
    seed=1202,
    model_type="CNN10",
    feature=vars(feature_config),
    duration=10,
    epochs=2,
    batch_size=32,
    num_workers=8,
    # ⭐ Swept hyperparameters are defined in sweep_config's parameters_dict instead:
    #    optimizer, n_mels, dropout, lr
)


In [None]:
# Sweep Configuration: https://docs.wandb.ai/guides/sweeps/define-sweep-configuration
sweep_count = 5  # number of trials the agent will run

sweep_config = {"method": "random"}  # alternatives: grid, random, bayes

# Metric the sweep optimizes; it must match a key passed to wandb.log during training.
metric = {
    "name": "accuracy",
    "goal": "maximize",  # or "minimize"
}

sweep_config["metric"] = metric

parameters_dict = {
    "optimizer": {"values": ["adam", "sgd"]},
    "n_mels": {"values": list(range(64, 257, 32))},  # 64, 96, ..., 256
    "dropout": {"distribution": "uniform", "min": 0, "max": 0.6},
    # The learning-rate range spans five orders of magnitude, so sample it
    # log-uniformly: a plain uniform draw over [1e-6, 0.1] lands above 1e-3
    # about 99% of the time and almost never explores the small values.
    "lr": {
        "distribution": "log_uniform_values",  # other options: uniform, q_uniform
        "min": 1e-6,
        "max": 0.1,
    },
}

sweep_config["parameters"] = parameters_dict

# Early termination: prune unpromising runs with Hyperband.
# With min_iter=3, eta=2 and s=3 the brackets are at iterations 3, 6 and 12.
# NOTE: if one iteration corresponds to one epoch, runs with CONFIG.epochs=2
# never reach the first bracket, so nothing is pruned in this test setup.
sweep_config['early_terminate'] = {
    'type': 'hyperband',
    'min_iter': 3,
    'eta': 2,
    's': 3,
}

In [None]:
# Create the sweep on the W&B server.
# ⭐ To continue an existing sweep instead, look up its id in the W&B web UI and
#    assign it directly (sweep_id = "xxxxxx") instead of calling wandb.sweep.
sweep_id = wandb.sweep(sweep_config, project=CONFIG.project_name)

sweep_id  # show the id in the notebook output

In [None]:
# ⭐ Customized for the Dataset
import datasets.SupportedSources as Sources
from datasets.SupportedSources import SupportedSourceTypes as SST
from datasets import Label, Category, DatasetFactory

from utils.audio.features import get_feature_transformer

from torch.utils.data import DataLoader

def get_labels():  # ⭐ edit here to change the dataset labels
    """Return the list of labels that make up the classification category.

    Currently a smoke-test setup: only the first two children of the NATURE
    data source are used.

    Example of a fuller configuration, kept for reference::

        nature = Sources.get_data_source(SST.NATURE)
        traffic = Sources.get_data_source(SST.TRAFFIC)

        labels = nature.get_childs("雷声,蛙声,蝉鸣声,狗叫声".split(',')) + traffic.childs

        bird_label = Label(
            name='鸟叫',
            sources=nature.get_childs(
                # ['北红尾鸲叫声', '叉尾太阳鸟叫声', '大鹰鹃叫声', '强脚树莺叫声', '普通夜鹰叫声', '棕颈钩嘴鹛叫声', '淡脚柳莺叫声']
                ['长尾缝叶莺叫声', '普通夜鹰叫声', '大鹰鹃叫声']
            )
        )

        labels.append(bird_label)
    """
    nature_source = Sources.get_data_source(SST.NATURE)
    # Testing only: keep just the first two labels.
    return nature_source.childs[:2]

def get_category(name:str):
    """Wrap the labels returned by ``get_labels()`` in a ``Category`` called ``name``."""
    return Category(name=name, labels=get_labels())

def get_extractor(feature_type:str, feature_args:dict,**kwargs):
    """Build the feature transformer for ``feature_type``.

    Both ``feature_args`` and the extra ``kwargs`` are forwarded as keyword
    arguments to ``get_feature_transformer``; a key present in both raises
    ``TypeError`` (duplicate keyword argument).
    """
    return get_feature_transformer(feature_type, **feature_args, **kwargs)
    
def get_dataset(category, target_sr, duration, feature_type, feature_args, **kwargs):
    """Create the train and test datasets for ``category``.

    The split uses ``test_ratio=0.2`` and is seeded with ``CONFIG.seed``. Both
    splits share one feature extractor built from ``feature_type``,
    ``feature_args`` and the extra ``kwargs``. ⭐ Pass whatever the extractor
    needs through ``kwargs`` (e.g. ``n_mels``).

    Returns:
        ``(train_dataset, test_dataset)``
    """
    factory = DatasetFactory(category, test_ratio=0.2, seed=CONFIG.seed)
    shared_extractor = get_extractor(feature_type, feature_args, **kwargs)
    common_args = dict(target_sr=target_sr, duration=duration, extractor=shared_extractor)
    train_split = factory.create_dataset(train=True, **common_args)
    test_split = factory.create_dataset(train=False, **common_args)
    return train_split, test_split

def get_dataloader(train_dataset, test_dataset):
    """Wrap the datasets in DataLoaders using ``CONFIG.batch_size`` and ``CONFIG.num_workers``.

    The training loader shuffles; the test loader keeps dataset order.

    Returns:
        ``(train_loader, test_loader)``
    """
    loader_args = dict(batch_size=CONFIG.batch_size, num_workers=CONFIG.num_workers)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_args)
    return train_loader, test_loader

# Helpers handed to the tester so it can record mis-predicted samples: one maps a
# test-set index to its audio file name, the other maps a label index to its name.
def get_file_path(index, test_dataset):
    """Return only the file name (no directory) of the audio clip at ``index``."""
    full_path = test_dataset._get_audio_path(index)
    _, file_name = os.path.split(full_path)
    return file_name

def get_label_name(index, category):
    """Return the ``name`` of the label at ``index`` in ``category``."""
    label = category.get_label(index)
    return label.name

In [None]:
# Model
def get_num_class(category:Category):
    """Number of output classes, i.e. the number of labels in ``category``."""
    labels = category.labels
    return len(labels)

def get_model_input_size(dataset):
    """Return the size of dimension 2 of the first sample's feature tensor.

    Assumes ``dataset[0]`` is a ``(features, label)``-style pair whose
    ``features`` has at least three dimensions (presumably
    ``(channels, n_mels, frames)`` — TODO confirm against the dataset/model).
    """
    first_sample = dataset[0]
    features = first_sample[0]
    return features.shape[2]

def get_model(**kwargs):
    """Instantiate the model selected by ``CONFIG.model_type`` and move it to ``DEVICE``.

    Keyword arguments are forwarded to the model constructor (for CNN10:
    ``num_class``, ``input_size``, ``dropout``). The model's repr is printed.

    Raises:
        ValueError: if ``CONFIG.model_type`` is not a supported model.
    """
    model_type = CONFIG.model_type
    if model_type == 'CNN10':
        from models.panns import CNN10 as model_cls  # args: num_class, input_size, dropout
    # ⭐️ Add your custom model here as another `elif` branch ⭐️
    else:
        raise ValueError('Invalid model type: {}'.format(model_type))
    model = model_cls(**kwargs).to(DEVICE)
    print(repr(model))
    return model

In [None]:
# Optimizer
from torch.optim import Adam, SGD

def get_optimizer(type_, model, lr):
    """Create an optimizer over ``model.parameters()`` with learning rate ``lr``.

    Args:
        type_: ``'adam'`` or ``'sgd'``.
        model: module whose parameters are optimized.
        lr: learning rate.

    Raises:
        ValueError: for any other ``type_``.
    """
    if type_ == 'adam':
        optimizer_cls = Adam
    elif type_ == 'sgd':
        optimizer_cls = SGD
    else:
        raise ValueError('Invalid optimizer type')
    return optimizer_cls(model.parameters(), lr=lr)
    
def get_lr_scheduler(optimizer, type_='exp', step_size=10, gamma=0.1):
    """Create a learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: the torch optimizer whose learning rate is scheduled.
        type_: scheduler kind. ``'step'`` builds a ``StepLR``. ``'exp'`` (the
            default) is kept as a legacy alias that ALSO builds a ``StepLR`` —
            despite its name it is not ``ExponentialLR``.
        step_size: number of ``scheduler.step()`` calls between decays.
        gamma: multiplicative factor applied to the LR at every decay.

    Returns:
        torch.optim.lr_scheduler.StepLR

    Raises:
        ValueError: for an unknown ``type_``.
    """
    if type_ in ('step', 'exp'):
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    raise ValueError(f'Invalid lr scheduler type: {type_!r}')
    
# Loss
from torch.nn import CrossEntropyLoss

def get_loss_fn(type_='crossentropy'):
    """Return the training loss; only ``'crossentropy'`` is supported.

    Raises:
        ValueError: for any other ``type_``.
    """
    if type_ != 'crossentropy':
        raise ValueError('Invalid loss function type')
    return CrossEntropyLoss()

In [None]:
# sweep agent for hyperparameter tuning and optimization
from utils.pytorch import Trainer, Tester

def train(config):
    """Run one sweep trial end to end and log it to W&B.

    Meant to be called inside an active ``wandb.init`` run (see ``main``), with
    ``config`` being ``wandb.config``. Besides the static settings copied from
    ``CONFIG`` (``feature``, ``duration``, ``epochs``), it reads the swept
    hyperparameters ``n_mels``, ``dropout``, ``optimizer`` and ``lr``.

    Side effects:
        * creates ``training/<YYYYmmddHHMMSS>/``. Whenever test accuracy improves,
          it overwrites ``best_model.pth``, ``optimizer.pth`` and ``best_acc.txt``
          in that directory;
        * logs the label table, the per-epoch metrics and the bad-case table to W&B;
        * calls ``wandb.finish()`` once training is done.
    """
    save_dir = os.path.join('training', datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(save_dir, exist_ok=True)

    category = get_category(name="test")

    # commit=False attaches the label table to the first epoch's log entry instead of
    # spending a W&B step on it. Together with the single wandb.log per epoch below,
    # this keeps exactly one W&B step per epoch.
    wandb.log({"Label": wlog.df2table(category.labels_dataframe)}, commit=False)

    train_dataset, test_dataset = get_dataset(
        category,
        target_sr=config.feature.get('args').get('sample_rate'),
        duration=config.duration,
        feature_type=config.feature.get('type'),
        feature_args=config.feature.get('args'),
        n_mels=config.n_mels,  # ⭐ swept; forwarded to the feature extractor via kwargs
    )  # ⭐ pass whatever extra extractor kwargs apply here, e.g. n_mels

    train_dataloader, test_dataloader = get_dataloader(train_dataset, test_dataset)

    num_class = get_num_class(category)
    input_size = get_model_input_size(train_dataset)
    model = get_model(
        num_class=num_class,
        input_size=input_size,
        dropout=config.dropout,  # ⭐ swept dropout
    )

    optimizer = get_optimizer(
        type_=config.optimizer,  # ⭐ swept optimizer
        model=model,
        lr=config.lr,  # ⭐ swept lr
    )
    lr_scheduler = get_lr_scheduler(optimizer=optimizer)
    loss_fn = get_loss_fn()

    trainer = Trainer(
        model=model,
        training_dataloader=train_dataloader,
        loss_func=loss_fn,
        optimizer=optimizer,
        scheduler=lr_scheduler,
        device=DEVICE,
    )

    # Best test accuracy so far. Start below any reachable value so the first epoch
    # always writes a checkpoint (starting at 0.0 skipped saving if accuracy stayed 0).
    best_acc = float("-inf")
    tqdm_instance = tqdm(range(config.epochs))
    for i in tqdm_instance:
        epoch = i + 1
        train_loss = trainer.train_an_epoch(tqdm_instance=tqdm_instance)
        tester = Tester.from_trainer(
            trainer,
            test_dataloader,
            num_class,
        )
        metrics, bad_cases = tester.test_an_epoch(
            get_file_path=partial(get_file_path, test_dataset=test_dataset),
            get_label_name=partial(get_label_name, category=category),
            tqdm_instance=tqdm_instance,
        )
        if metrics.accuracy > best_acc:
            best_acc = metrics.accuracy
            trainer.save_model(os.path.join(save_dir, "best_model.pth"))
            trainer.save_optimizer(os.path.join(save_dir, "optimizer.pth"))
            with open(os.path.join(save_dir, 'best_acc.txt'), 'w', encoding='utf-8') as f:
                f.write(f'[{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] model:{repr(model)}, Epoch: {epoch}, Best Accuracy: {best_acc:.4f}, Optimizer: {type(optimizer).__name__}')

        log_info = {"Epoch": epoch, "train_loss": train_loss, 'lr': optimizer.param_groups[0]['lr']}
        # The walrus must be parenthesized. Without parentheses, `:=` binds the boolean
        # result of `... is not None` instead of the weight-decay value itself.
        if (weight_decay := optimizer.param_groups[0].get('weight_decay', None)) is not None:
            log_info['weight_decay'] = weight_decay
        log_info.update(
            (key, value) for key, value in metrics.model_dump().items() if value is not None
        )
        # Log the bad cases in the same call so each epoch is a single W&B step.
        log_info["bad_cases"] = wlog.df2table(bad_cases)
        wandb.log(log_info)

    wandb.finish()

In [None]:
# ⭐start
def main():
    """W&B agent entry point: open a run and train with that run's config."""
    run_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    wandb.init(
        project=CONFIG.project_name,
        name=run_name,
        config=CONFIG.__dict__,
        save_code=True,
    )
    # Under the sweep agent, wandb.config is expected to hold CONFIG plus the
    # hyperparameters sampled for this trial (n_mels, dropout, optimizer, lr).
    train(wandb.config)

wandb.agent(sweep_id=sweep_id, function=main, count=sweep_count)