In [16]:
import torch
import torch.nn.functional as F
from torch import nn

import matplotlib.pyplot as plt
import numpy as np

In [2]:
class BiLSTM(nn.Module):
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features):
        super().__init__()
        self.rnn = nn.LSTM(input_features, recurrent_features, batch_first=True, bidirectional=True)

    def forward(self, x):
        if self.training:
            return self.rnn(x)[0]
        else:
            # evaluation mode: support for longer sequences that do not fit in memory
            batch_size, sequence_length, input_features = x.shape
            hidden_size = self.rnn.hidden_size
            num_directions = 2 if self.rnn.bidirectional else 1

            h = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
            c = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
            output = torch.zeros(batch_size, sequence_length, num_directions * hidden_size, device=x.device)

            # forward direction
            slices = range(0, sequence_length, self.inference_chunk_length)
            for start in slices:
                end = start + self.inference_chunk_length
                output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))

            # reverse direction
            if self.rnn.bidirectional:
                h.zero_()
                c.zero_()

                for start in reversed(slices):
                    end = start + self.inference_chunk_length
                    result, (h, c) = self.rnn(x[:, start:end, :], (h, c))
                    output[:, start:end, hidden_size:] = result[:, :, hidden_size:]

            return output

In [3]:
class OnsetsAndFrames(nn.Module):
    def __init__(self, input_features, output_features, model_complexity=48):
        super().__init__()

        model_size = model_complexity * 16
        sequence_model = lambda input_size, output_size: BiLSTM(input_size, output_size // 2)

        self.onset_stack = nn.Sequential(
            ConvStack(input_features, model_size), #(batch, 640, 768)
            sequence_model(model_size, model_size), #(batch, 640, 768)
            nn.Linear(model_size, output_features),
            nn.Sigmoid()
        )
        self.offset_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            sequence_model(model_size, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid()
        )
        self.frame_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid()
        )
        self.combined_stack = nn.Sequential(
            sequence_model(output_features * 3, model_size),
            nn.Linear(model_size, output_features),
            nn.Sigmoid()
        )
        self.velocity_stack = nn.Sequential(
            ConvStack(input_features, model_size),
            nn.Linear(model_size, output_features)
        )

        self.sequence_model = sequence_model(output_features * 3, model_size)
        self.Linear = nn.Linear(model_size, output_features)
    def forward(self, mel):
        onset_pred = self.onset_stack(mel)
        offset_pred = self.offset_stack(mel)
        activation_pred = self.frame_stack(mel)
#         combined_pred = torch.cat([onset_pred.detach(), offset_pred.detach(), activation_pred], dim=-1)
        combined_x = torch.cat([onset_pred.detach(), offset_pred.detach(), activation_pred], dim=-1)
        print(f'combined_x.shape = {combined_x.shape}')
        x = self.sequence_model(combined_x)
        print(f'self.sequence_model output shape = {x.shape}')
        x = self.Linear(x)
        print(f'self.Linear output shape = {x.shape}')
        frame_pred = x

#         frame_pred = self.combined_stack(combined_pred)
        velocity_pred = self.velocity_stack(mel)
        return onset_pred, offset_pred, activation_pred, frame_pred, velocity_pred

    def run_on_batch(self, batch):
        audio_label = batch['audio']
        onset_label = batch['onset']
        offset_label = batch['offset']
        frame_label = batch['frame']
        velocity_label = batch['velocity']

        mel = melspectrogram(audio_label.reshape(-1, audio_label.shape[-1])[:, :-1]) # x = torch.rand(8,229, 640)
        mel = mel.transpose(-1,-2) # swap mel bins with timesteps so that it fits LSTM later # shape (8,640,229)
        onset_pred, offset_pred, _, frame_pred, velocity_pred = self(mel)

        predictions = {
            'onset': onset_pred.reshape(*onset_label.shape),
            'offset': offset_pred.reshape(*offset_label.shape),
            'frame': frame_pred.reshape(*frame_label.shape),
            'velocity': velocity_pred.reshape(*velocity_label.shape)
        }

        losses = {
            'loss/onset': F.binary_cross_entropy(predictions['onset'], onset_label),
            'loss/offset': F.binary_cross_entropy(predictions['offset'], offset_label),
            'loss/frame': F.binary_cross_entropy(predictions['frame'], frame_label),
            'loss/velocity': self.velocity_loss(predictions['velocity'], velocity_label, onset_label)
        }

        return predictions, losses

    def velocity_loss(self, velocity_pred, velocity_label, onset_label):
        denominator = onset_label.sum()
        if denominator.item() == 0:
            return denominator
        else:
            return (onset_label * (velocity_label - velocity_pred) ** 2).sum() / denominator


In [4]:
class ConvStack(nn.Module):
    def __init__(self, input_features, output_features):
        super().__init__()

        # input is batch_size * 1 channel * frames * input_features
        self.cnn = nn.Sequential(
            # layer 0
            nn.Conv2d(1, output_features // 16, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 16),
            nn.ReLU(),
            # layer 1
            nn.Conv2d(output_features // 16, output_features // 16, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 16),
            nn.ReLU(),
            # layer 2
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
            nn.Conv2d(output_features // 16, output_features // 8, (3, 3), padding=1),
            nn.BatchNorm2d(output_features // 8),
            nn.ReLU(),
            # layer 3
            nn.MaxPool2d((1, 2)),
            nn.Dropout(0.25),
        )
        self.fc = nn.Sequential(
            nn.Linear((output_features // 8) * (input_features // 4), output_features),
            nn.Dropout(0.5)
        )

    def forward(self, mel):
        x = mel.view(mel.size(0), 1, mel.size(1), mel.size(2))
        x = self.cnn(x)
#         print(f'CNN output shape = {x.shape}')
        x = x.transpose(1, 2).flatten(-2)
#         print(f'flatten output shape = {x.shape}')
        x = self.fc(x)
#         print(f'fc output shape = {x.shape}')
        return x

In [5]:
x = torch.rand(8,640,229)

In [6]:
model=OnsetsAndFrames(229,88)

In [7]:
output = model(x)

combined_x.shape = torch.Size([8, 640, 264])
self.sequence_model output shape = torch.Size([8, 640, 768])
self.Linear output shape = torch.Size([8, 640, 88])


In [8]:
from torch.utils.data import DataLoader
from onsets_and_frames import *

STFT filter created, time used = 0.3095 seconds
Mel filter created, time used = 0.0130 seconds


In [9]:
sequence_length = 327680
validation_length = sequence_length
dataset = MAPS(groups=['AkPnBcht', 'AkPnBsdf', 'AkPnCGdD', 'AkPnStgb', 'SptkBGAm', 'SptkBGCl', 'StbgTGd2'], sequence_length=sequence_length)
validation_dataset = MAPS(groups=['ENSTDkAm', 'ENSTDkCl'], sequence_length=validation_length)

Loading group AkPnBcht: 100%|██████████| 30/30 [00:00<00:00, 219.45it/s]
Loading group AkPnBsdf:   0%|          | 0/30 [00:00<?, ?it/s]

Loading 7 groups of MAPS at data/MAPS


Loading group AkPnBsdf: 100%|██████████| 30/30 [00:00<00:00, 211.17it/s]
Loading group AkPnCGdD: 100%|██████████| 30/30 [00:00<00:00, 225.85it/s]
Loading group AkPnStgb: 100%|██████████| 30/30 [00:00<00:00, 166.44it/s]
Loading group SptkBGAm: 100%|██████████| 30/30 [00:00<00:00, 245.46it/s]
Loading group SptkBGCl: 100%|██████████| 30/30 [00:00<00:00, 210.70it/s]
Loading group StbgTGd2: 100%|██████████| 30/30 [00:00<00:00, 211.42it/s]
Loading group ENSTDkAm: 100%|██████████| 30/30 [00:00<00:00, 228.18it/s]
Loading group ENSTDkCl:   0%|          | 0/30 [00:00<?, ?it/s]

Loading 2 groups of MAPS at data/MAPS


Loading group ENSTDkCl: 100%|██████████| 30/30 [00:00<00:00, 210.08it/s]


In [10]:
loader = DataLoader(dataset,batch_size=8, shuffle=True, drop_last=True)

In [11]:
len(loader.dataset)

210

In [12]:
label = next(iter(validation_dataset))

In [19]:
p_ref, i_ref, v_ref = extract_notes(label['onset'], label['frame'], label['velocity'])

In [20]:
p_est, i_est, v_est = extract_notes(label['onset'], label['frame'], label['velocity'])

In [17]:
p_est[0] = 29
i_est[3] = [60, 130]
p_est = np.append(p_est, 12)
i_est = np.vstack((i_est, [99,111]))

In [22]:
p_ref[:10]

array([28, 35, 40, 44, 47, 45, 49, 28, 35, 40])

In [30]:
v_ref[:10]

array([0.234375  , 0.234375  , 0.25260417, 0.3125    , 0.3125    ,
       0.296875  , 0.296875  , 0.2265625 , 0.2265625 , 0.25390625])

In [23]:
i_ref[:10]

array([[ 18, 130],
       [ 18, 130],
       [ 18, 130],
       [ 18,  47],
       [ 18,  47],
       [ 46,  56],
       [ 46,  57],
       [ 55, 130],
       [ 55, 130],
       [ 55, 130]])

In [25]:
p_est[:8]

array([28, 35, 40, 44, 47, 45, 49, 28])

In [None]:
i_est.shape

In [None]:
scaling = 512/44100

In [27]:
from mir_eval.transcription import precision_recall_f1_overlap

In [28]:
precision_recall_f1_overlap(i_ref[:10], p_ref[:10], i_est[:8], p_est[:8], offset_ratio=None)

(1.0, 0.8, 0.888888888888889, 1.0)

In [341]:
i_est

array([[ 18, 130],
       [ 18, 130],
       [ 18, 130],
       [ 60, 130],
       [ 18,  47],
       [ 46,  56],
       [ 46,  57],
       [ 55, 130],
       [ 55, 130],
       [ 55, 130],
       [ 55, 130],
       [ 55,  76],
       [ 75, 130],
       [ 75, 130],
       [ 75, 130],
       [ 75, 130],
       [131, 187],
       [131, 181],
       [131, 241],
       [131, 181],
       [168, 187],
       [168, 181],
       [168, 241],
       [168, 181],
       [182, 187],
       [182, 187],
       [187, 241],
       [188, 241],
       [188, 241],
       [188, 241],
       [242, 279],
       [242, 298],
       [242, 292],
       [243, 279],
       [279, 298],
       [279, 292],
       [280, 352],
       [280, 352],
       [293, 352],
       [299, 352],
       [299, 352],
       [299, 352],
       [299, 352],
       [354, 465],
       [354, 575],
       [354, 410],
       [354, 439],
       [391, 465],
       [391, 575],
       [391, 410],
       [391, 439],
       [410, 465],
       [410,

In [360]:
counter = 0
for label in validation_dataset:
    print(label['path'])
    print(counter)
    counter+=1

data/MAPS/flac/MAPS_MUS-bk_xmas1_ENSTDkAm.flac
0
data/MAPS/flac/MAPS_MUS-chpn-p14_ENSTDkAm.flac
1
data/MAPS/flac/MAPS_MUS-chpn-p15_ENSTDkAm.flac
2
data/MAPS/flac/MAPS_MUS-chpn-p4_ENSTDkAm.flac
3
data/MAPS/flac/MAPS_MUS-chpn_op25_e3_ENSTDkAm.flac
4
data/MAPS/flac/MAPS_MUS-chpn_op25_e4_ENSTDkAm.flac
5
data/MAPS/flac/MAPS_MUS-chpn_op33_2_ENSTDkAm.flac
6
data/MAPS/flac/MAPS_MUS-chpn_op35_1_ENSTDkAm.flac
7
data/MAPS/flac/MAPS_MUS-chpn_op66_ENSTDkAm.flac
8
data/MAPS/flac/MAPS_MUS-chpn_op7_1_ENSTDkAm.flac
9
data/MAPS/flac/MAPS_MUS-grieg_butterfly_ENSTDkAm.flac
10
data/MAPS/flac/MAPS_MUS-grieg_kobold_ENSTDkAm.flac
11
data/MAPS/flac/MAPS_MUS-liz_rhap02_ENSTDkAm.flac
12
data/MAPS/flac/MAPS_MUS-liz_rhap09_ENSTDkAm.flac
13
data/MAPS/flac/MAPS_MUS-liz_rhap12_ENSTDkAm.flac
14
data/MAPS/flac/MAPS_MUS-mendel_op62_5_ENSTDkAm.flac
15
data/MAPS/flac/MAPS_MUS-muss_1_ENSTDkAm.flac
16
data/MAPS/flac/MAPS_MUS-pathetique_2_ENSTDkAm.flac
17
data/MAPS/flac/MAPS_MUS-pathetique_3_ENSTDkAm.flac
18
data/MAPS/flac/M

In [351]:
len(validation_dataset)

60

In [352]:
len(dataset)

210

In [None]:
validation_dataset.

In [348]:
label['label'].shape

torch.Size([640, 88])

In [350]:
label


{'path': 'data/MAPS/flac/MAPS_MUS-bk_xmas1_ENSTDkAm.flac',
 'audio': tensor([-0.0038, -0.0070, -0.0106,  ...,  0.0487,  0.0453,  0.0421],
        device='cuda:0'),
 'label': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0', dtype=torch.uint8),
 'velocity': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 'onset': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], devi

In [362]:
len(validation_dataset)

60

In [363]:
validation_dataset.ba

[{'path': 'data/MAPS/flac/MAPS_MUS-bk_xmas1_ENSTDkAm.flac',
  'audio': tensor([23, 23, 22,  ..., -5, -6, -6], dtype=torch.int16),
  'label': tensor([[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8),
  'velocity': tensor([[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8)},
 {'path': 'data/MAPS/flac/MAPS_MUS-chpn-p14_ENSTDkAm.flac',
  'audio': tensor([  3,   4,   3,  ..., -22, -23, -21], dtype=torch.int16),
  'label': tensor([[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
    