<h1 style="text-align:center;"><a href="https://docs.wandb.ai/" target="_blank">Weights and Biases (wandb)</a></h1>
<p style="text-align:center;">Visual analysis of model training</p>
<p style="text-align:center;font-size:18px;">2024-07-09</p>

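### Features

* Logs are uploaded to the cloud and stored permanently, so they are easy to share and cannot be lost locally.

* Code, datasets, and models can be versioned so that any run can be reproduced later. (wandb.Artifact)

* Individual cases can be analyzed in interactive tables. (wandb.Table)

* Hyperparameter tuning can be automated. (wandb.sweep)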

### Core functionality

1. Experiment tracking (wandb.log)

2. Version management (wandb.log_artifact, wandb.save)

3. Case visualization (wandb.Table, wandb.Image)

4. Hyperparameter tuning (wandb.sweep)

### Example: visual analysis with wandb on the ESC-10 dataset

<p style="font-size:28px">Dataset download: <a href="https://github.com/karoldvl/ESC-50/archive/master.zip" target="_blank">ESC-50</a> (ESC-10 is a 10-class subset of ESC-50)</p>

<div style="text-align:center"><img src="image/index/1720083382984.png" width="70%" style="border-radius: 10px;border: 2px solid #ddd;"></div>

#### 1 Install wandb

In [1]:
# !pip install wandb

#### 2 <a href="https://wandb.ai/login" target="_blank">Sign up</a>, then open Quickstart in the top-right corner to get your API key
<div style="text-align:center"><img src="image/index/1719988297636.png" width="500px" style="border-radius: 10px;border: 1px solid #ddd;"></div>
<div style="text-align:center"><img src="image/index/1719988151134.png" width="500px" style="border-radius: 10px;border: 1px solid #ddd;"></div>

#### 3 Log in
* <p style="font-size:20px">Set the environment variable <strong>WANDB_API_KEY=@your_api_key</strong> (preferably in a .env file)</p>

In [2]:
import os
from dotenv import load_dotenv
load_dotenv()

import wandb
wandb.login(key=os.getenv("WANDB_API_KEY"))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: reviy (esil). Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Vi\.netrc


True
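If `WANDB_API_KEY` is missing, `os.getenv` returns `None`, and the login cell may then fall back to stored credentials or an interactive prompt instead of failing clearly. A minimal defensive sketch follows; `require_api_key` is a hypothetical helper, not part of wandb or of this notebook.

```python
import os


def require_api_key(env_var="WANDB_API_KEY"):
    """Return the API key from the environment, or raise with a clear message."""
    key = os.getenv(env_var)
    if not key:
        raise RuntimeError(
            f"{env_var} is not set; add it to your .env file or export it first."
        )
    return key


# Usage, after load_dotenv() has run:
#   wandb.login(key=require_api_key())
```

Failing early here keeps a misconfigured environment from silently logging in as a different, previously cached account.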

### 4 Add wandb to model training
#### 4-0 Import dependencies

In [3]:
import torch
from argparse import Namespace
from datetime import datetime
import pandas as pd
from sklearn.model_selection import train_test_split
from torchaudio.transforms import MelSpectrogram
from dataset import AudioDataset
from torch.utils.data import DataLoader
from models.panns import CNN10
from torch import nn
from tqdm import tqdm
from utils import df2table, log2wandb, train_an_epoch, test_an_epoch

#### 4-1 Configuration

In [4]:
config = Namespace(
    project_name = "wandb_esc10_demo", # wandb project name
    data_dir = 'ESC-50-master/audio', # audio directory
    meta_file = 'ESC-50-master/meta/esc50.csv', # metadata file
    sr = 22050, # sample rate (Hz)
    duration = 5, # clip length (seconds)
    epochs = 10, # number of training epochs
    batch_size = 32, # batch size
    lr = 0.001, # learning rate
    step_size = 4, # StepLR step size
    gamma = 0.7, # learning-rate decay factor
    random_seed = 1202, # random seed
    n_fft = 1024, # FFT window size
    hop_length = 512, # hop length between frames
    n_mels = 64, # number of mel filter banks
    dropout = 0.1, # dropout rate
)
device = 'cuda' if torch.cuda.is_available() else 'cpu' # select the device
device

'cpu'

#### Initialize the wandb project

In [5]:
ENABLE_WANDB = True  # whether to enable wandb logging and visualization
# ⭐ Initialize the wandb project
if ENABLE_WANDB:
    wandb.init(
        project=config.project_name,
        config=config.__dict__,
        name=datetime.now().strftime("%Y%m%d_%H%M%S"),
    )
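The three arguments passed to `wandb.init` can be inspected without contacting the server. The sketch below builds them with a hypothetical helper, `build_init_kwargs` (not part of wandb); it relies only on the fact that `vars(config)` returns the same dictionary as `config.__dict__`.

```python
from argparse import Namespace
from datetime import datetime


def build_init_kwargs(config, now=None):
    """Assemble the keyword arguments used for wandb.init in this notebook."""
    now = now or datetime.now()
    return {
        "project": config.project_name,
        "config": vars(config),  # same dict object as config.__dict__
        "name": now.strftime("%Y%m%d_%H%M%S"),  # timestamp used as the run name
    }


demo_config = Namespace(project_name="wandb_esc10_demo", lr=0.001)
kwargs = build_init_kwargs(demo_config, now=datetime(2024, 7, 9, 12, 0, 0))
print(kwargs["name"])  # 20240709_120000

# With wandb installed and logged in:
#   wandb.init(**kwargs)
```

A timestamp as the run name keeps runs sortable in the project view, at the cost of names that say nothing about the hyperparameters.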

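#### 4-2 Preprocessing
##### Read the metadata file

In [6]:
df = pd.read_csv(config.meta_file)
df = df[df['esc10']].copy() # keep only the ESC-10 clips; copy so the new column is not written to a slice
categoties = df['category'].unique() # the 10 classes
df['label'] = df['category'].apply(lambda x: categoties.tolist().index(x)) # assign an integer label to each clip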
label_df = pd.DataFrame({'category': categoties, "length": [len(df[df['category']==c]) for c in categoties]})
label_df.T

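|          | 0   | 1        | 2              | 3          | 4    | 5           | 6          | 7        | 8       | 9         |
|----------|-----|----------|----------------|------------|------|-------------|------------|----------|---------|-----------|
| category | dog | chainsaw | crackling_fire | helicopter | rain | crying_baby | clock_tick | sneezing | rooster | sea_waves |
| length   | 40  | 40       | 40             | 40         | 40   | 40          | 40         | 40       | 40      | 40        |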
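The label column assigns each class its position in the list of unique categories. Calling `.tolist().index(x)` per row rescans the list each time; a dictionary lookup gives the same mapping. This is a sketch in plain Python that uses the category order shown in the table above.

```python
# Category order copied from the table above.
categories = [
    "dog", "chainsaw", "crackling_fire", "helicopter", "rain",
    "crying_baby", "clock_tick", "sneezing", "rooster", "sea_waves",
]

# Map each category name to its integer label.
label_of = {name: idx for idx, name in enumerate(categories)}

print(label_of["dog"], label_of["sea_waves"])  # 0 9

# With pandas, the same mapping can be applied as:
#   df['label'] = df['category'].map(label_of)
```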


In [7]:
# ⭐ Log the class table to wandb
if ENABLE_WANDB:
    wandb.log({'Labels': df2table(label_df)})
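`df2table` comes from the local `utils` module, and its implementation is not shown here. The sketch below is a hypothetical stand-in that turns a list of row dictionaries into the `columns` and `data` arguments of `wandb.Table`. The names `rows_to_table_args` and `to_wandb_table` are illustrative; wandb is imported lazily so that the pure-Python part runs without it.

```python
def rows_to_table_args(rows):
    """Convert a list of row dicts into (columns, data) for wandb.Table."""
    columns = list(rows[0].keys())
    data = [[row[col] for col in columns] for row in rows]
    return columns, data


def to_wandb_table(rows):
    """Build a wandb.Table; assumes wandb is installed."""
    import wandb  # lazy import: only needed when a table is actually built

    columns, data = rows_to_table_args(rows)
    return wandb.Table(columns=columns, data=data)


rows = [
    {"category": "dog", "length": 40},
    {"category": "chainsaw", "length": 40},
]
columns, data = rows_to_table_args(rows)
print(columns)  # ['category', 'length']

# For a pandas DataFrame, rows can be obtained with df.to_dict("records");
# wandb.Table(dataframe=df) is a more direct alternative.
```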

##### Split into training and test sets

In [8]:
X, y = df['filename'].values, df['label'].values 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=config.random_seed)
len(X_train), len(X_test), len(y_train), len(y_test)

(320, 80, 320, 80)

#### 4-3 Build the datasets

In [9]:
dataset_params = {
    'data_dir': config.data_dir,
    'sr': config.sr,
    'duration': config.duration,
    'device': device,
    'transform': MelSpectrogram(
        sample_rate=config.sr,
        n_fft=config.n_fft,
        hop_length=config.hop_length,
        n_mels=config.n_mels
    )
}
train_dataset = AudioDataset(X=X_train, y=y_train, **dataset_params)
test_dataset = AudioDataset(X=X_test, y=y_test, **dataset_params)
train_dataset[0][0].shape

torch.Size([1, 64, 216])
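The last dimension of this shape, 216, is the number of spectrogram frames. A quick check of where it comes from, assuming `AudioDataset` pads or crops every clip to exactly `sr * duration` samples (its code is not shown) and that `MelSpectrogram` uses its default `center=True` framing:

```python
def mel_frames(sr, duration, hop_length):
    """Frame count of a centered STFT: 1 + n_samples // hop_length."""
    n_samples = int(sr * duration)
    return 1 + n_samples // hop_length


# Values taken from the config: sr=22050, duration=5, hop_length=512.
print(mel_frames(22050, 5, 512))  # 216
```

The middle dimension, 64, is `n_mels`, and the leading 1 is the single audio channel.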

#### 4-4 Create the DataLoaders

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
for _, (data, target, _) in enumerate(train_dataloader):
    data_shape, target_shape = data.shape, target.shape
    break
data_shape, target_shape

(torch.Size([32, 1, 64, 216]), torch.Size([32]))
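With 320 training clips, 80 test clips, and a batch size of 32, the number of batches per epoch follows directly. A small sketch (`num_batches` is an illustrative helper, mirroring the `DataLoader` default of `drop_last=False`):

```python
import math


def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


print(num_batches(320, 32))  # 10 training batches
print(num_batches(80, 32))   # 3 test batches (the last one holds 16 clips)
```

The three test batches appear to correspond to the `3/3` shown in the validation progress bar during training.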

#### 4-5 Create the model

In [16]:
model = CNN10(
    num_class=len(categoties), input_size=data_shape[-1], dropout=config.dropout
).to(device)
model

CNN10(num_class=10, input_size=216, dropout=0.1)

#### 4-6 Define the optimizer, LR scheduler, and loss function

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) # optimizer
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=config.step_size, gamma=config.gamma
) # learning-rate scheduler
loss_func = nn.CrossEntropyLoss().to(device) # loss function
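`StepLR` multiplies the learning rate by `gamma` once every `step_size` calls to `scheduler.step()`. The sketch below reproduces that rule in plain Python, assuming the scheduler is stepped once per epoch. How often `train_an_epoch` actually steps it is not shown here; if it steps once per batch, `n_steps` counts batches and the decay is correspondingly faster.

```python
def step_lr_schedule(base_lr, step_size, gamma, n_steps):
    """Learning rate in effect before each of n_steps scheduler steps."""
    return [base_lr * gamma ** (i // step_size) for i in range(n_steps)]


# Values from the config: lr=0.001, step_size=4, gamma=0.7, epochs=10.
schedule = step_lr_schedule(0.001, 4, 0.7, 10)
for epoch, lr in enumerate(schedule, start=1):
    print(epoch, round(lr, 6))
# Epochs 1-4 use 0.001, epochs 5-8 use 0.0007, epochs 9-10 use 0.00049.
```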

#### 4-7 Train

In [13]:
for epoch in (tqdm_bar:= tqdm(range(config.epochs))):
    training_loss = train_an_epoch(
        model=model, optimizer=optimizer, scheduler=scheduler, data_loader=train_dataloader, device=device, loss_func=loss_func, tqdm_instance=tqdm_bar
    )

    testing_loss, accuracy, bad_cases = test_an_epoch(
        categoties=categoties, model=model, data_loader=test_dataloader, device=device, loss_func=loss_func, tqdm_instance=tqdm_bar,
    )
    
    # ⭐ Log metrics to wandb
    if ENABLE_WANDB:
        log2wandb(
            epoch=epoch+1, # epoch number (1-based)
            training_loss=training_loss, # training loss
            testing_loss=testing_loss, # test loss
            accuracy=accuracy, # accuracy
            lr=optimizer.param_groups[0]["lr"], # current learning rate
            bad_cases=bad_cases, # misclassified samples
        )

[valid] Progress: 3/3: 100%|██████████| 10/10 [02:34<00:00, 15.46s/it]              
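`log2wandb` also comes from the local `utils` module and is not shown. A hypothetical stand-in might look like the sketch below: it collects the scalar metrics into one dictionary and sends them with a single `wandb.log` call per epoch. The type of `bad_cases` is an assumption (something wandb can log, such as a `wandb.Table`), and wandb is imported lazily so that the payload builder runs on its own.

```python
def build_log_payload(epoch, training_loss, testing_loss, accuracy, lr):
    """Collect the per-epoch scalar metrics into one dict."""
    return {
        "epoch": epoch,
        "training_loss": training_loss,
        "testing_loss": testing_loss,
        "accuracy": accuracy,
        "lr": lr,
    }


def log2wandb_sketch(epoch, training_loss, testing_loss, accuracy, lr, bad_cases=None):
    """Hypothetical stand-in for utils.log2wandb; assumes an active wandb run."""
    import wandb  # lazy import: only needed when actually logging

    payload = build_log_payload(epoch, training_loss, testing_loss, accuracy, lr)
    if bad_cases is not None:
        payload["bad_cases"] = bad_cases  # assumed to be a wandb-loggable object
    wandb.log(payload, step=epoch)


payload = build_log_payload(1, 2.3, 2.2, 0.1, 0.001)
print(sorted(payload))  # ['accuracy', 'epoch', 'lr', 'testing_loss', 'training_loss']
```

Logging everything for an epoch in one call keeps all metrics aligned on the same step in the wandb charts.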


#### Finish

In [14]:
# ⭐ Finish the wandb run
if ENABLE_WANDB: 
    wandb.finish()

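Run history:

| Metric | History |
|--------|---------|
| accuracy | ▁█▂▅▇█████ |
| epoch | ▁▂▃▃▄▅▆▆▇█ |
| lr | █▃▂▁▁▁▁▁▁▁ |
| testing_loss | █▆▅▃▂▁▁▁▁▁ |
| training_loss | █▅▄▃▂▂▁▁▁▁ |

Run summary:

| Metric | Last value |
|--------|------------|
| accuracy | 0.2375 |
| epoch | 10.0 |
| lr | 0.0 |
| testing_loss | 2.22404 |
| training_loss | 2.13855 |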


### Other features
* sweeps: hyperparameter search
* artifacts: store models, datasets, logs, and more
* reports: generate reports

<div style="text-align:center"><img src="image/index/1720081945391.png" width="80%" style="border-radius: 10px;border: 1px solid #ddd;">
<p style="text-align:center;font-size:16px">Hyperparameter search and sensitivity analysis</p>
</div>
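A hyperparameter search like the one pictured is driven by a sweep configuration. The sketch below is one possible configuration for this demo, not the one used to produce the figure: the search ranges, the `launch_sweep` helper, and the `train_fn` argument are all assumptions. `train_fn` is expected to call `wandb.init()` and read its hyperparameters from `wandb.config`.

```python
# Sweep definition: random search that maximizes the logged "accuracy" metric.
sweep_config = {
    "method": "random",
    "metric": {"name": "accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-2},
        "batch_size": {"values": [16, 32, 64]},
        "dropout": {"values": [0.1, 0.3, 0.5]},
    },
}


def launch_sweep(train_fn, project="wandb_esc10_demo", count=5):
    """Register the sweep and run `count` trials; assumes wandb is installed and logged in."""
    import wandb  # lazy import: the config above can be inspected without wandb

    sweep_id = wandb.sweep(sweep_config, project=project)
    wandb.agent(sweep_id, function=train_fn, count=count)
```

The metric name must match a key that the training code actually logs; here it matches the `accuracy` value recorded each epoch.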

#### References

* [W&B documentation](https://docs.wandb.ai/)
* [eat_pytorch_in_20_days](https://github.com/lyhue1991/eat_pytorch_in_20_days)