In [1]:
import re
import os
import torch
import shutil
import librosa
import warnings

import mir_eval
import torch.nn as nn
import torch.optim as optim

import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt

from glob import glob
from tqdm import tqdm
from bs4 import BeautifulSoup
from os.path import join as ospj
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from bs4.builder import XMLParsedAsHTMLWarning
warnings.filterwarnings('ignore', category=XMLParsedAsHTMLWarning)

In [2]:
'''
    If things aren't running, try making the following command to copy the data locally:
        rsync -a /storage/datasets/IDMT-SMT-Drums /local/<your-user>/
'''

SAMPLING_RATE = 44100
BASEPATH = '/local/thiago.poppe/IDMT-SMT-Drums'

songnames = os.listdir(ospj(BASEPATH, 'audio/'))
songnames = list(map(lambda filename: filename.split('.')[0], songnames))

has_gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if has_gpu else 'cpu')

print('Running on device:', device)

Running on device: cuda:0


In [3]:
chunkify_hyperparameters = {
    'hop_length': 64,
    'window_size': 256
}

dataloader_hyperparameters = {
    'shuffle': True,
    'batch_size': 32
}

optimizer_hyperparameters = {
    'lr': 0.001,
    'weight_decay': 0.01
}

training_hyperparameters = {
    'num_epochs': 200
}

## Helper functions to load and preprocess data

In [4]:
def calculate_spectrogram(y):
    spec = np.abs(librosa.feature.melspectrogram(y=y, sr=SAMPLING_RATE))
    spec = librosa.amplitude_to_db(spec, ref=np.max)
    
    return spec


def create_annotation_matrix(events, num_frames):
    instrument2index = {'HH': 0, 'SD': 1, 'KD': 2}
    annotations = np.zeros((3, num_frames), dtype=np.float32)
    
    for event in events:
        onset = float(event.onsetsec.string)
        instrument = event.instrument.string
        
        index = instrument2index[instrument]
        onset = librosa.time_to_frames(onset, sr=SAMPLING_RATE)
        annotations[index, onset] = 1.0
    
    return annotations


def create_feature_and_annotation(songname):
    audiofile = ospj(BASEPATH, f'audio/{songname}.wav')
    annotationfile = ospj(BASEPATH, f'annotation/{songname}.xml')
    
    with open(annotationfile, 'r') as fp:
        soup = BeautifulSoup(fp, 'lxml')
        events = soup.find_all('event')
    
    wave, sr = librosa.load(audiofile, sr=SAMPLING_RATE)
    spec = calculate_spectrogram(wave)
    annotation = create_annotation_matrix(events, spec.shape[1])
    
    return spec, annotation

## Chunkifying sequences

#### Creating required folders

In [5]:
if os.path.exists(ospj(BASEPATH, 'chunks')):
    print('Deleting chunks folder...')
    shutil.rmtree(ospj(BASEPATH, 'chunks'))

os.makedirs(ospj(BASEPATH, 'chunks'), exist_ok=True)
os.makedirs(ospj(BASEPATH, 'chunks/train'), exist_ok=True)
os.makedirs(ospj(BASEPATH, 'chunks/validation'), exist_ok=True)

Deleting chunks folder...


#### Splitting data into train and validation

In [6]:
def get_songname_type(songname):
    pattern = re.compile(r'([a-zA-Z]+)')
    matches = pattern.search(songname)
    
    # Returns songname type (RealDrum, WaveDrum or TechnoDrum)
    return matches.group(1)

songname_types = list(map(get_songname_type, songnames))
train_songnames, validation_songnames = train_test_split(songnames, test_size=0.2, stratify=songname_types)

#### Splitting data into multiple chunks

In [7]:
def chunkify(songname, window_size=256, hop_length=64):
    spec, annotation = create_feature_and_annotation(songname)
    num_frames = spec.shape[1]
    
    for i in range(0, num_frames - window_size + 1, hop_length):
        spec_chunk = spec[:, i:i+window_size]
        annotation_chunk = annotation[:, i:i+window_size]
        
        yield spec_chunk, annotation_chunk

In [8]:
print('Splitting train data...')
for songname in tqdm(train_songnames):
    for chunk_id, (spec, annotation) in enumerate(chunkify(songname, **chunkify_hyperparameters)):
        filename = songname + f'_part{chunk_id:03d}'
        filename = ospj(BASEPATH, f'chunks/train/{filename}')
        np.savez(filename, spec=spec, annotation=annotation)
        
print('Splitting validation data...')
for songname in tqdm(validation_songnames):
    for chunk_id, (spec, annotation) in enumerate(chunkify(songname, **chunkify_hyperparameters)):
        filename = songname + f'_part{chunk_id:03d}'
        filename = ospj(BASEPATH, f'chunks/validation/{filename}')
        np.savez(filename, spec=spec, annotation=annotation)

Splitting train data...


100%|██████████| 76/76 [00:06<00:00, 12.39it/s]


Splitting validation data...


100%|██████████| 19/19 [00:01<00:00, 14.13it/s]


## Training Neural Network (*Rhythm Transformer*)

#### Computing class imbalance weights in training set
- Class weights computed following the suggestions on the [PyTorch's BCEWithLogitsLoss webpage](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss).

In [9]:
total_frames = 0
presence_counter = np.zeros(3)

for filename in glob('/local/thiago.poppe/IDMT-SMT-Drums/chunks/train/*.npz'):
    data = np.load(filename)
    total_frames += data['annotation'].shape[1]
    presence_counter += (data['annotation'] == 1).sum(axis=1)

pos_weight = torch.from_numpy((total_frames - presence_counter) / presence_counter).to(device)
print('Class weights:', pos_weight)

Class weights: tensor([27.0182, 77.3636, 53.8848], device='cuda:0', dtype=torch.float64)


#### Defining DataLoaders

In [10]:
class TranscriptionDataset(Dataset):
    CHUNKS_PATH = '/local/thiago.poppe/IDMT-SMT-Drums/chunks/'
    
    def __init__(self, is_train: bool):
        split = 'train' if is_train else 'validation'
        self.filenames = glob(ospj(self.CHUNKS_PATH, split, '*.npz'))
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        data = np.load(self.filenames[idx])
        return data['spec'].T, data['annotation'].T

In [11]:
train_dataloader = DataLoader(TranscriptionDataset(is_train=True), **dataloader_hyperparameters)
validation_dataloader = DataLoader(TranscriptionDataset(is_train=False), **dataloader_hyperparameters)

In [12]:
# Checking if DataLoader output has correct shape
for X, y in train_dataloader:
    break

X.shape, y.shape

(torch.Size([32, 256, 128]), torch.Size([32, 256, 3]))

#### Defining model

In [21]:
class DrumTransformer(nn.Module):
    def __init__(self):
        super().__init__()
    
        self.transformer = nn.Transformer(
            d_model=128,
            nhead=8,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=256,
            dropout=0.5,
            batch_first=True,
        )    
        
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(), nn.Dropout(),
            nn.Linear(64, 32),
            nn.ReLU(), nn.Dropout(),
            nn.Linear(32, 3)
        )

    def forward(self, x):
        x = self.transformer(x, x)
        x = self.classifier(x)
        return x

In [22]:
model = DrumTransformer().to(device)
model

DrumTransformer(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=256, bias=True)
          (dropout): Dropout(p=0.5, inplace=False)
          (linear2): Linear(in_features=256, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.5, inplace=False)
          (dropout2): Dropout(p=0.5, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=1

In [23]:
# Checking if forward is correct
for X, y in train_dataloader:
    X = X.to(device)
    y = y.to(device)
    break
    
outputs = model(X)
assert outputs.shape == y.shape

In [24]:
# Checking if loss function works
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(outputs, y)

tensor(1.3864, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

#### Running training script

In [25]:
def evaluate(model, validation_dataloader, criterion):
    val_loss = []
    
    model.eval()
    with torch.no_grad():
        for X, y in validation_dataloader:
            X = X.to(device)
            y = y.to(device)
            
            outputs = model(X)
            loss = criterion(outputs, y)
            val_loss.append(loss.item())
            
    val_loss = np.mean(val_loss)
    return val_loss

In [None]:
num_epochs = training_hyperparameters['num_epochs']

model = DrumTransformer().to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model.parameters(), **optimizer_hyperparameters)

best_validation_loss = np.inf
for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = []
    
    for X, y in train_dataloader:
        X = X.to(device)
        y = y.to(device)
        outputs = model(X)

        optimizer.zero_grad()
        loss = criterion(outputs, y)
        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    
    if epoch == 1 or epoch % 10 == 0:
        train_loss = np.mean(epoch_loss)
        validation_loss = evaluate(model, validation_dataloader, criterion)
        
        if validation_loss < best_validation_loss:
            print('Saving new best model...')
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), 'drum_transformer.ckpt')
        
        print(f'[epoch {epoch}/{num_epochs}] --> train loss: {train_loss:.5f}, validation loss: {validation_loss:5f}')

Saving new best model...
[epoch 1/200] --> train loss: 1.35467, validation loss: 1.260717
Saving new best model...
[epoch 10/200] --> train loss: 0.74510, validation loss: 0.655588
Saving new best model...
[epoch 20/200] --> train loss: 0.74053, validation loss: 0.648601
Saving new best model...
[epoch 30/200] --> train loss: 0.73330, validation loss: 0.633292
Saving new best model...
[epoch 40/200] --> train loss: 0.70642, validation loss: 0.620975
Saving new best model...
[epoch 50/200] --> train loss: 0.69326, validation loss: 0.595422
[epoch 60/200] --> train loss: 0.69165, validation loss: 0.612897
[epoch 70/200] --> train loss: 0.69655, validation loss: 0.596512
Saving new best model...
[epoch 80/200] --> train loss: 0.66747, validation loss: 0.575930
[epoch 90/200] --> train loss: 0.67398, validation loss: 0.597754
[epoch 100/200] --> train loss: 0.68770, validation loss: 0.657561
[epoch 110/200] --> train loss: 0.67500, validation loss: 0.636732


#### Checking validation metrics

In [None]:
sigmoid = nn.Sigmoid()
best_model = DrumTransformer().to(device)
best_model.load_state_dict(torch.load('drum_transformer.ckpt'))

total_size = 0
recall_scores = np.zeros(3)
fmeasure_scores = np.zeros(3)
precision_scores = np.zeros(3)

best_model.eval()
with torch.no_grad():
    for X, y in validation_dataloader:
        X = X.to(device)
        y = y.to(device)

        outputs = best_model(X)
        y_numpy = y.detach().cpu().numpy()
        predictions = sigmoid(outputs).detach().cpu().numpy()

        total_size += outputs.shape[0]
        for batch_idx in range(outputs.shape[0]):        
            for instrument in range(3):
                reference_onsets = np.where(y_numpy[batch_idx, :, instrument])[0]
                reference_onsets = librosa.frames_to_time(reference_onsets, sr=SAMPLING_RATE)

                params = {'pre_max': 5, 'post_max': 5, 'pre_avg': 5, 'post_avg': 5, 'delta': 0.25, 'wait': 5}
                estimated_onsets = librosa.util.peak_pick(predictions[batch_idx, :, instrument], **params)
                estimated_onsets = librosa.frames_to_time(estimated_onsets, sr=SAMPLING_RATE)

                if len(reference_onsets) != 0 and len(estimated_onsets) != 0:
                    metrics = mir_eval.onset.evaluate(reference_onsets, estimated_onsets)
                    recall_scores[instrument] += metrics['Recall']
                    fmeasure_scores[instrument] += metrics['F-measure']
                    precision_scores[instrument] += metrics['Precision']

recall_scores /= total_size
fmeasure_scores /= total_size
precision_scores /= total_size

In [None]:
for i, instrument in enumerate(['Hi-Hat', 'Snare Drum', 'Kick Drum']):
    print(f'{instrument} metrics:')
    print(f'  - Mean Recall: {recall_scores[i]:5f}')
    print(f'  - Mean Precision: {precision_scores[i]:5f}')
    print(f'  - Mean F-Measure: {fmeasure_scores[i]:5f}')
    print()
    
print('Overall drumkit metrics:')
print(f'  - Mean Recall: {np.mean(recall_scores):5f}')
print(f'  - Mean Precision: {np.mean(precision_scores[i]):5f}')
print(f'  - Mean F-Measure: {np.mean(fmeasure_scores[i]):5f}')