
Add sound classification example #5303

Merged
merged 2 commits into from Apr 30, 2021
79 changes: 79 additions & 0 deletions PaddleAudio/examples/sound_classification/README.md
@@ -0,0 +1,79 @@
# Sound Classification

Sound classification and detection is a popular research direction in audio algorithms.
For sound classification tasks, a common approach in traditional machine learning is to manually extract a variety of time-domain and frequency-domain features from the audio, perform feature selection, combination, and transformation, and then classify with an SVM or a decision tree. End-to-end deep learning, by contrast, typically uses deep networks such as RNNs and CNNs to perform representation learning and classification directly on the raw waveform or on time-frequency features.
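
As an illustration of the time-frequency features mentioned above, the snippet below computes a mel spectrogram with the `paddleaudio` utilities used later in this example (a minimal sketch; `./dog.wav` is a placeholder path):

```python
from paddleaudio.backends import load as load_audio
from paddleaudio.features import mel_spect

# Load the waveform at its native sample rate.
waveform, sr = load_audio('./dog.wav', sr=None)

# Mel spectrogram, transposed to (num_frames, num_melbins),
# which is the input layout the models in this example expect.
feats = mel_spect(waveform, sample_rate=sr).transpose()
print(feats.shape)
```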

At the IEEE ICASSP 2017 conference, Google released the large-scale audio dataset [Audioset](https://research.google.com/audioset/). The dataset originally contained 632 audio classes and 2,084,320 human-labeled 10-second sound clips sourced from YouTube videos. It has since grown to 2.1 million annotated videos and 5,800 hours of audio, with 527 labeled sound classes.

`PANNs` ([PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf)) are sound classification/recognition models trained on the Audioset dataset. After pretraining, the models can be used to extract audio embeddings. This example fine-tunes a pretrained `PANNs` model to perform sound classification.
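
A minimal sketch of embedding extraction with the pretrained CNN14 backbone; the input layout follows `model.py` in this PR, the 2048-dim output follows the model notes below, and `./dog.wav` is again a placeholder path:

```python
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.features import mel_spect
from paddleaudio.models.panns import cnn14

# Pretrained CNN14 backbone in embedding-extraction mode.
backbone = cnn14(pretrained=True, extract_embedding=True)
backbone.eval()

waveform, sr = load_audio('./dog.wav', sr=None)  # placeholder path
feats = mel_spect(waveform, sample_rate=sr).transpose()

# The backbone expects input of shape (batch_size, 1, num_frames, num_melbins).
x = paddle.to_tensor(feats).unsqueeze([0, 1])
embedding = backbone(x)  # expected shape: [1, 2048] for CNN14
```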


## Model Introduction

PaddleAudio provides pretrained PANNs models CNN14, CNN10, and CNN6 for users to choose from:
- CNN14: 12 convolutional layers and 2 fully connected layers; 79.6M parameters; embedding dimension of 2048.
- CNN10: 8 convolutional layers and 2 fully connected layers; 4.9M parameters; embedding dimension of 512.
- CNN6: 4 convolutional layers and 2 fully connected layers; 4.5M parameters; embedding dimension of 512.


## Quick Start

### Model Training

Using the environmental sound classification dataset `ESC50` as an example, run the command below to fine-tune the model on the training set. Single-GPU training and single-machine multi-GPU distributed training are both supported. For how to launch distributed training with `paddle.distributed.launch`, see [PaddlePaddle 2.0 distributed training](https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/quick_start/high_level_api/high_level_api.html#danjiduoka).

```shell
$ unset CUDA_VISIBLE_DEVICES
$ python -m paddle.distributed.launch --gpus "0" train.py --device gpu --epochs 50 --batch_size 16 --num_workers 16 --checkpoint_dir ./checkpoint --save_freq 10
```

Configurable arguments:

- `device`: Device to train on, either `cpu` or `gpu`; defaults to `gpu`. When training on GPU, the `gpus` argument specifies the GPU card IDs.
- `epochs`: Number of training epochs; defaults to 50.
- `learning_rate`: Learning rate for fine-tuning; defaults to 5e-5.
- `batch_size`: Batch size; adjust it to fit your GPU memory, and lower it if you run out of memory; defaults to 16.
- `num_workers`: Number of subprocesses the DataLoader uses to fetch data; defaults to 0, in which case data loading runs in the main process.
- `checkpoint_dir`: Directory where model parameter files and optimizer parameter files are saved; defaults to `./checkpoint`.
- `save_freq`: How often (in epochs) the model is saved during training; defaults to 10.
- `log_freq`: How often (in steps) training information is printed; defaults to 10.

The example code uses the `CNN14` pretrained model. To switch to another pretrained model, proceed as follows:
```python
from model import SoundClassifier
from paddleaudio.datasets import ESC50
from paddleaudio.models.panns import cnn14, cnn10, cnn6

# CNN14
backbone = cnn14(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))

# CNN10
backbone = cnn10(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))

# CNN6
backbone = cnn6(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
```

### Model Prediction

```shell
export CUDA_VISIBLE_DEVICES=0
python -u predict.py --device gpu --wav ./dog.wav --top_k 3 --checkpoint ./checkpoint/epoch_50/model.pdparams
```

Configurable arguments:
- `device`: Device to run prediction on, either `cpu` or `gpu`; defaults to `gpu`. When using GPU, the `gpus` argument specifies the GPU card IDs.
- `wav`: Audio file to run prediction on.
- `top_k`: Show the scores of the top k predicted labels; defaults to 1.
- `checkpoint`: Checkpoint file of the model parameters.

The prediction output looks like this:
```
[/audio/dog.wav]
Dog: 0.9999538660049438
Clock tick: 1.3341237718123011e-05
Cat: 6.579841738130199e-06
```
38 changes: 38 additions & 0 deletions PaddleAudio/examples/sound_classification/model.py
@@ -0,0 +1,38 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SoundClassifier(nn.Layer):
"""
Model for sound classification which uses panns pretrained models to extract
embeddings from audio files.
"""
def __init__(self, backbone, num_class, dropout=0.1):
super(SoundClassifier, self).__init__()
self.backbone = backbone
self.dropout = nn.Dropout(dropout)
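        # Linear classification head mapping the backbone embedding to class logits.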
self.fc = nn.Linear(self.backbone.emb_size, num_class)

def forward(self, x):
# x: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins)
x = x.unsqueeze(1)
x = self.backbone(x)
x = self.dropout(x)
logits = self.fc(x)

return logits
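
A quick way to sanity-check this head is to push a fake batch through it; a minimal sketch, assuming illustrative frame and mel-bin sizes (128 and 64) that are not fixed by the PR:

```python
import paddle
from model import SoundClassifier
from paddleaudio.models.panns import cnn14

backbone = cnn14(pretrained=False, extract_embedding=True)
model = SoundClassifier(backbone, num_class=50)  # ESC50 has 50 classes

# Fake batch: 4 clips, 128 frames, 64 mel bins (assumed sizes).
x = paddle.randn([4, 128, 64])
logits = model(x)
print(logits.shape)  # expected: [4, 50]
```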
61 changes: 61 additions & 0 deletions PaddleAudio/examples/sound_classification/predict.py
@@ -0,0 +1,61 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import numpy as np
import paddle
import paddle.nn.functional as F
from model import SoundClassifier
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets import ESC50
from paddleaudio.features import mel_spect
from paddleaudio.models.panns import cnn14

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run prediction on; defaults to gpu.")
parser.add_argument("--wav", type=str, required=True, help="Audio file to infer.")
parser.add_argument("--top_k", type=int, default=1, help="Show top k predicted results")
parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint of model.")
args = parser.parse_args()
# yapf: enable


def extract_features(file: str, **kwargs):
    waveform, sr = load_audio(file, sr=None)
feats = mel_spect(waveform, sample_rate=sr, **kwargs).transpose()
return feats


if __name__ == '__main__':
paddle.set_device(args.device)

model = SoundClassifier(backbone=cnn14(pretrained=False, extract_embedding=True), num_class=len(ESC50.label_list))
model.set_state_dict(paddle.load(args.checkpoint))
model.eval()

feats = extract_features(args.wav)
feats = paddle.to_tensor(np.expand_dims(feats, 0))
logits = model(feats)
probs = F.softmax(logits, axis=1).numpy()

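    # Rank class indices by probability in descending order.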
sorted_indices = (-probs[0]).argsort()

msg = f'[{args.wav}]\n'
for idx in sorted_indices[:args.top_k]:
msg += f'{ESC50.label_list[idx]}: {probs[0][idx]}\n'
print(msg)
140 changes: 140 additions & 0 deletions PaddleAudio/examples/sound_classification/train.py
@@ -0,0 +1,140 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle
import paddle.nn.functional as F
from model import SoundClassifier
from paddleaudio.datasets import ESC50
from paddleaudio.models.panns import cnn14
from paddleaudio.utils import Timer, logger

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train the model on; defaults to gpu.")
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used for fine-tuning.")
parser.add_argument("--batch_size", type=int, default=16, help="Number of examples per batch for training.")
parser.add_argument("--num_workers", type=int, default=0, help="Number of workers in dataloader.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save model checkpoints.")
parser.add_argument("--save_freq", type=int, default=10, help="Save checkpoint every n epoch.")
parser.add_argument("--log_freq", type=int, default=10, help="Log the training infomation every n steps.")
args = parser.parse_args()
# yapf: enable

if __name__ == "__main__":
paddle.set_device(args.device)
nranks = paddle.distributed.get_world_size()
local_rank = paddle.distributed.get_rank()

backbone = cnn14(pretrained=True, extract_embedding=True)
model = SoundClassifier(backbone, num_class=len(ESC50.label_list))
optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=model.parameters())
criterion = paddle.nn.loss.CrossEntropyLoss()

train_ds = ESC50(mode='train', feat_type='mel_spect')
dev_ds = ESC50(mode='dev', feat_type='mel_spect')

train_sampler = paddle.io.DistributedBatchSampler(train_ds,
batch_size=args.batch_size,
shuffle=True,
drop_last=False)
train_loader = paddle.io.DataLoader(
train_ds,
batch_sampler=train_sampler,
num_workers=args.num_workers,
return_list=True,
use_buffer_reader=True,
)

steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * args.epochs)
timer.start()

for epoch in range(1, args.epochs + 1):
model.train()

avg_loss = 0
num_corrects = 0
num_samples = 0
for batch_idx, batch in enumerate(train_loader):
feats, labels = batch
logits = model(feats)

loss = criterion(logits, labels)
loss.backward()
optimizer.step()
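            # Manually step the LR scheduler if one is configured; optimizer.step() does not advance it.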
if isinstance(optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()

# Calculate loss
avg_loss += loss.numpy()[0]

# Calculate metrics
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]

timer.count()

if (batch_idx + 1) % args.log_freq == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= args.log_freq
avg_acc = num_corrects / num_samples

print_msg = 'Epoch={}/{}, Step={}/{}'.format(epoch, args.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(lr, timer.timing, timer.eta)
logger.train(print_msg)

avg_loss = 0
num_corrects = 0
num_samples = 0

if epoch % args.save_freq == 0 and batch_idx + 1 == steps_per_epoch and local_rank == 0:
dev_sampler = paddle.io.BatchSampler(dev_ds, batch_size=args.batch_size, shuffle=False, drop_last=False)
dev_loader = paddle.io.DataLoader(
dev_ds,
batch_sampler=dev_sampler,
num_workers=args.num_workers,
return_list=True,
)

model.eval()
num_corrects = 0
num_samples = 0
with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(dev_loader):
feats, labels = batch
logits = model(feats)

preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]

print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)

logger.eval(print_msg)

# Save model
save_dir = os.path.join(args.checkpoint_dir, 'epoch_{}'.format(epoch))
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(), os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(save_dir, 'model.pdopt'))
2 changes: 1 addition & 1 deletion PaddleAudio/paddleaudio/datasets/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
4 changes: 2 additions & 2 deletions PaddleAudio/paddleaudio/datasets/dataset.py
@@ -15,11 +15,11 @@
import os
from typing import List, Tuple

-import librosa
import numpy as np
import paddle
from tqdm import tqdm

+from ..backends import load as load_audio
from ..features import linear_spect, log_spect, mel_spect
from ..utils.log import logger

@@ -71,7 +71,7 @@ def _get_data(self, input_file: str):
def _convert_to_record(self, idx):
file, label = self.files[idx], self.labels[idx]

-        waveform, _ = librosa.load(file, sr=self.sample_rate)
+        waveform, _ = load_audio(file, sr=self.sample_rate)
normal_length = self.sample_rate * self.duration
if len(waveform) > normal_length:
waveform = waveform[:normal_length]