In [11]:
import re
import os
import torch
import shutil
import warnings

import numpy as np
import torch.nn as nn
import torch.optim as optim

from glob import glob
from tqdm import tqdm
from os.path import join as ospj

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from transcription.globals import BASEPATH
from transcription.dataloaders import TranscriptionDataset

from transcription.preprocessing import chunkify
from transcription.preprocessing import create_feature_and_annotation

from bs4.builder import XMLParsedAsHTMLWarning
warnings.filterwarnings('ignore', category=XMLParsedAsHTMLWarning)

In [3]:
# import re
# import os
# import torch
# import shutil
# import librosa
# import warnings

# import mir_eval
# import torch.nn as nn
# import torch.optim as optim

# import numpy as np
# import IPython.display as ipd
# import matplotlib.pyplot as plt

# from glob import glob
# from tqdm import tqdm
# from bs4 import BeautifulSoup
# from os.path import join as ospj
# from torch.utils.data import Dataset
# from torch.utils.data import DataLoader
# from sklearn.model_selection import train_test_split

# from bs4.builder import XMLParsedAsHTMLWarning
# warnings.filterwarnings('ignore', category=XMLParsedAsHTMLWarning)

In [2]:
'''
    If things aren't running, try making the following command to copy the data locally:
        rsync -a /storage/datasets/IDMT-SMT-Drums /local/<your-user>/
'''
songnames = os.listdir(ospj(BASEPATH, 'audio/'))
songnames = list(map(lambda filename: filename.split('.')[0], songnames))

has_gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if has_gpu else 'cpu')

print('Running on device:', device)

Running on device: cuda:0


## Chunkifying sequences

#### Creating required folders

In [3]:
if os.path.exists(ospj(BASEPATH, 'chunks')):
    print('Deleting chunks folder...')
    shutil.rmtree(ospj(BASEPATH, 'chunks'))

os.makedirs(ospj(BASEPATH, 'chunks'), exist_ok=True)
os.makedirs(ospj(BASEPATH, 'chunks/train'), exist_ok=True)
os.makedirs(ospj(BASEPATH, 'chunks/validation'), exist_ok=True)

Deleting chunks folder...


#### Splitting data into train and validation

In [4]:
def get_songname_type(songname):
    pattern = re.compile(r'([a-zA-Z]+)')
    matches = pattern.search(songname)
    
    # Returns songname type (RealDrum, WaveDrum or TechnoDrum)
    return matches.group(1)

songname_types = list(map(get_songname_type, songnames))
train_songnames, validation_songnames = train_test_split(songnames, test_size=0.2, stratify=songname_types)

#### Splitting data into multiple chunks

In [5]:
print('Splitting train data...')
for songname in tqdm(train_songnames):
    for chunk_id, (spec, annotation) in enumerate(chunkify(songname)):
        filename = songname + f'_part{chunk_id:03d}'
        filename = ospj(BASEPATH, f'chunks/train/{filename}')
        np.savez(filename, spec=spec, annotation=annotation)
        
print('Splitting validation data...')
for songname in tqdm(validation_songnames):
    for chunk_id, (spec, annotation) in enumerate(chunkify(songname)):
        filename = songname + f'_part{chunk_id:03d}'
        filename = ospj(BASEPATH, f'chunks/validation/{filename}')
        np.savez(filename, spec=spec, annotation=annotation)

Splitting train data...


100%|██████████| 76/76 [00:06<00:00, 12.04it/s]


Splitting validation data...


100%|██████████| 19/19 [00:01<00:00, 14.39it/s]


## Training Neural Network (*Long-Short Term Memory* - LSTM)

#### Computing class imbalance weights in training set
- Class weights computed following the suggestions on the [PyTorch's BCEWithLogitsLoss webpage](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss).

In [6]:
total_frames = 0
presence_counter = np.zeros(3)

for filename in tqdm(glob('/local/thiago.poppe/IDMT-SMT-Drums/chunks/train/*.npz')):
    data = np.load(filename)
    total_frames += data['annotation'].shape[1]
    presence_counter += (data['annotation'] == 1).sum(axis=1)

pos_weight = torch.from_numpy((total_frames - presence_counter) / presence_counter).to(device)
print('Class weights:', pos_weight)

100%|██████████| 1289/1289 [00:00<00:00, 2838.02it/s]


Class weights: tensor([27.4714, 78.2659, 55.1293], device='cuda:0', dtype=torch.float64)


#### Defining DataLoaders

In [8]:
batch_size = 32

train_dataloader = DataLoader(TranscriptionDataset(is_train=True), batch_size, shuffle=True)
validation_dataloader = DataLoader(TranscriptionDataset(is_train=False), batch_size, shuffle=False)

In [9]:
# Checking if DataLoader output has correct shape
for X, y in train_dataloader:
    break

X.shape, y.shape

(torch.Size([32, 256, 128]), torch.Size([32, 256, 3]))

#### Defining model

In [45]:
class SimpleLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2, 
                           dropout=0.5, batch_first=True, bidirectional=False)
        
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(), nn.Dropout(),
            nn.Linear(128, 64),
            nn.ReLU(), nn.Dropout(),
            nn.Linear(64, 32),
            nn.ReLU(), nn.Dropout(),
            nn.Linear(32, 3)
        )
        
    def forward(self, x: torch.tensor):
        outputs, _ = self.lstm(x)
        return self.classifier(outputs)

In [46]:
model = SimpleLSTM().to(device)
model

SimpleLSTM(
  (lstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.5)
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=64, out_features=32, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.5, inplace=False)
    (9): Linear(in_features=32, out_features=3, bias=True)
  )
)

In [49]:
sum(p.numel() for p in model.lstm.parameters())

921600

In [47]:
# Checking if forward is correct
for X, y in train_dataloader:
    X = X.to(device)
    y = y.to(device)
    break
    
outputs = model(X)
assert outputs.shape == y.shape

In [48]:
# Checking if loss function works
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(outputs, y)

tensor(1.3178, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

#### Running training script

In [25]:
def evaluate(model, validation_dataloader, criterion):
    val_loss = []
    
    model.eval()
    with torch.no_grad():
        for X, y in validation_dataloader:
            X = X.to(device)
            y = y.to(device)
            
            outputs = model(X)
            loss = criterion(outputs, y)
            val_loss.append(loss.item())
            
    val_loss = np.mean(val_loss)
    return val_loss        

In [26]:
num_epochs = 50
learning_rate = 0.001

model = SimpleLSTM().to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)

best_validation_loss = np.inf
for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = []
    
    for X, y in train_dataloader:
        X = X.to(device)
        y = y.to(device)
        outputs = model(X)

        optimizer.zero_grad()
        loss = criterion(outputs, y)
        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    
    if epoch == 1 or epoch % 5 == 0:
        train_loss = np.mean(epoch_loss)
        validation_loss = evaluate(model, validation_dataloader, criterion)
        
        if validation_loss < best_validation_loss:
            print('Saving new best model...')
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), 'simple_lstm_model.ckpt')
        
        print(f'[epoch {epoch}/{num_epochs}] --> train loss: {train_loss:.5f}, validation loss: {validation_loss:5f}')

Saving new best model...
[epoch 1/50] --> train loss: 1.35282, validation loss: 1.332384
Saving new best model...
[epoch 5/50] --> train loss: 0.89896, validation loss: 0.774593
Saving new best model...
[epoch 10/50] --> train loss: 0.85648, validation loss: 0.741579
[epoch 15/50] --> train loss: 0.88468, validation loss: 0.756567
Saving new best model...
[epoch 20/50] --> train loss: 0.87111, validation loss: 0.737557
Saving new best model...
[epoch 25/50] --> train loss: 0.85660, validation loss: 0.722500
[epoch 30/50] --> train loss: 0.90553, validation loss: 0.772338
Saving new best model...
[epoch 35/50] --> train loss: 0.83509, validation loss: 0.713120
[epoch 40/50] --> train loss: 0.93167, validation loss: 0.775635
[epoch 45/50] --> train loss: 0.99257, validation loss: 0.846737
[epoch 50/50] --> train loss: 0.97869, validation loss: 0.792049


#### Checking validation metrics

In [31]:
sigmoid = nn.Sigmoid()
best_model = SimpleLSTM().to(device)
best_model.load_state_dict(torch.load('simple_lstm_model.ckpt'))

total_size = 0
recall_scores = np.zeros(3)
fmeasure_scores = np.zeros(3)
precision_scores = np.zeros(3)

best_model.eval()
with torch.no_grad():
    for X, y in validation_dataloader:
        X = X.to(device)
        y = y.to(device)

        outputs = best_model(X)
        y_numpy = y.detach().cpu().numpy()
        predictions = sigmoid(outputs).detach().cpu().numpy()

        total_size += outputs.shape[0]
        for batch_idx in range(outputs.shape[0]):        
            for instrument in range(3):
                reference_onsets = np.where(y_numpy[batch_idx, :, instrument])[0]
                reference_onsets = librosa.frames_to_time(reference_onsets, sr=SAMPLING_RATE)

                params = {'pre_max': 5, 'post_max': 5, 'pre_avg': 5, 'post_avg': 5, 'delta': 0.25, 'wait': 5}
                estimated_onsets = librosa.util.peak_pick(predictions[batch_idx, :, instrument], **params)
                estimated_onsets = librosa.frames_to_time(estimated_onsets, sr=SAMPLING_RATE)

                if len(reference_onsets) != 0 and len(estimated_onsets) != 0:
                    metrics = mir_eval.onset.evaluate(reference_onsets, estimated_onsets)
                    recall_scores[instrument] += metrics['Recall']
                    fmeasure_scores[instrument] += metrics['F-measure']
                    precision_scores[instrument] += metrics['Precision']

recall_scores /= total_size
fmeasure_scores /= total_size
precision_scores /= total_size

In [32]:
for i, instrument in enumerate(['Hi-Hat', 'Snare Drum', 'Kick Drum']):
    print(f'{instrument} metrics:')
    print(f'  - Mean Recall: {recall_scores[i]:5f}')
    print(f'  - Mean Precision: {precision_scores[i]:5f}')
    print(f'  - Mean F-Measure: {fmeasure_scores[i]:5f}')
    print()
    
print('Overall drumkit metrics:')
print(f'  - Mean Recall: {np.mean(recall_scores):5f}')
print(f'  - Mean Precision: {np.mean(precision_scores[i]):5f}')
print(f'  - Mean F-Measure: {np.mean(fmeasure_scores[i]):5f}')

Hi-Hat metrics:
  - Mean Recall: 0.978927
  - Mean Precision: 0.703111
  - Mean F-Measure: 0.809992

Snare Drum metrics:
  - Mean Recall: 0.962359
  - Mean Precision: 0.834901
  - Mean F-Measure: 0.874343

Kick Drum metrics:
  - Mean Recall: 0.973663
  - Mean Precision: 0.950588
  - Mean F-Measure: 0.960358

Overall drumkit metrics:
  - Mean Recall: 0.971650
  - Mean Precision: 0.950588
  - Mean F-Measure: 0.960358
