DLDC is a deep learning model built with PyTorch that timestamps and classifies drum samples from an audio file into one of 5 classes:
- Kick
- Snare
- Hi-hat
- Tom
- Cymbal
The model is trained on a combination of labeled drum samples from Kaggle and MusicRadar. It currently covers these five classes, but can be extended by adding labeled samples for new classes.
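For orientation, below is a minimal sketch of what such a classifier could look like in PyTorch. This is illustrative only, not the exact architecture used in this repository; the `DrumCNN` name and layer sizes are assumptions:

```python
import torch
import torch.nn as nn

CLASSES = ["kick", "snare", "hi-hat", "tom", "cymbal"]

class DrumCNN(nn.Module):
    """Illustrative CNN mapping a mel-spectrogram patch to one of the drum classes."""
    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(num_classes),  # infers the flattened size on the first forward pass
        )

    def forward(self, x):  # x: (batch, 1, n_mels, n_frames)
        return self.classifier(self.features(x))
```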
As an example, the file test_sequence_2.wav is a 4-bar drum sequence containing 13 drum samples. Classifying it produces the following transcription:
```
Detected 13 transients
0.000: kick (99.13%)
0.256: kick (99.64%)
0.512: snare (57.97%)
0.736: snare (75.33%)
0.992: cymbal (99.94%)
1.280: cymbal (100.00%)
1.504: cymbal (100.00%)
1.760: tom (100.00%)
2.016: tom (100.00%)
2.240: tom (95.66%)
2.496: cymbal (100.00%)
2.752: cymbal (100.00%)
3.008: cymbal (100.00%)
```
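The timestamps come from a transient (onset) detection pass over the audio. As a rough sketch of that step, librosa's onset detector could be used like this; the repository's actual detector and parameters may differ:

```python
import librosa

# Detect transients in the example file; each onset would then be
# sliced out and passed to the classifier.
y, sr = librosa.load("test_sequence_2.wav", sr=None)
onset_times = librosa.onset.onset_detect(y=y, sr=sr, units="time")

print(f"Detected {len(onset_times)} transients")
for t in onset_times:
    print(f"{t:.3f}")
```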
See the requirements.txt file for the full list of dependencies. They are listed below for convenience:
- python 3.7
- numpy 1.20.3
- torch 1.10.2
- librosa 0.9.2
- tqdm 4.65.0
- matplotlib 3.7.1
- soundfile 0.11.0
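Assuming a working Python 3.7 environment, the dependencies can typically be installed with:

```
pip install -r requirements.txt
```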
The repository includes the latest trained model, `model_v0.12.pt`. To train the model yourself, follow the steps below.
- Tweak the hyperparameters in `train.py` as desired. The default values are:

  ```python
  # hyperparams
  num_classes = 5
  num_epochs = 250
  batch_size = 64
  learning_rate = 0.001
  ```
- Run `train.py` to train the model with:

  ```
  python train.py
  ```

  This will save the new model to `model_v0.XX.pt`. (A sketch of what this training loop might look like appears after this list.)
- Run `classify.py` to classify an audio file with:

  ```
  python classify.py <input_file> [<export_path>]
  ```

  `<input_file>` is the path to the audio file to classify. `<export_path>` is the path to export the sliced and classified samples to; if not specified, the detected samples will not be exported.
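To make the training step concrete, here is a minimal sketch of what the loop in `train.py` might look like, reusing the default hyperparameters and the illustrative `DrumCNN` from earlier. The data here is a random stand-in; the real script trains on the labeled drum samples, and its actual structure may differ:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# hyperparams (defaults from train.py)
num_classes = 5
num_epochs = 250
batch_size = 64
learning_rate = 0.001

# stand-in data: random mel-spectrogram patches and labels
# (the real script uses the labeled drum samples)
specs = torch.randn(640, 1, 64, 64)
labels = torch.randint(0, num_classes, (640,))
loader = DataLoader(TensorDataset(specs, labels), batch_size=batch_size, shuffle=True)

model = DrumCNN(num_classes)  # illustrative model from the sketch above
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), "model_v0.XX.pt")  # saved under a new version number
```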
After classification, `classify.py` prints the resulting transcription to the console and plots the results in a matplotlib figure.
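A similarly rough sketch of how `classify.py` might slice, export, and plot the detected samples is shown below. The `classify_segment` helper is a hypothetical stand-in for running the trained model on each slice:

```python
import os
import sys

import librosa
import matplotlib.pyplot as plt
import soundfile as sf

def classify_segment(segment, sr):
    """Hypothetical stand-in: the real script runs the trained model here."""
    return "kick"

input_file = sys.argv[1]
export_path = sys.argv[2] if len(sys.argv) > 2 else None

y, sr = librosa.load(input_file, sr=None)
onsets = librosa.onset.onset_detect(y=y, sr=sr, units="samples")

# slice each transient up to the next onset (or the end of the file)
bounds = list(onsets) + [len(y)]
for i, start in enumerate(onsets):
    segment = y[start:bounds[i + 1]]
    label = classify_segment(segment, sr)
    print(f"{start / sr:.3f}: {label}")
    if export_path is not None:
        sf.write(os.path.join(export_path, f"{i:02d}_{label}.wav"), segment, sr)

# plot the waveform with one marker per detected transient
plt.plot(y)
plt.vlines(onsets, ymin=-1, ymax=1, color="r")
plt.show()
```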
For more information about this project, see the paper.
This project is licensed under the GNU GPLv3 license. See the LICENSE file for more information.