In [1]:
import os,sys,signal
import math

import pickle
import numpy as np                                       # fast vectors and matrices
import matplotlib.pyplot as plt                          # plotting
sys.path.insert(0, '../')

from time import time

from sklearn.metrics import average_precision_score


import torch
from torch.nn.functional import conv1d, mse_loss
import torch.nn.functional as F
import torch.nn as nn

# For the dataloading
from torch.utils.data import Dataset
from abc import abstractmethod
from tqdm import tqdm
from glob import glob
from nnAudio import Spectrogram
from scipy.io import wavfile
import soundfile
from time import time

In [2]:
class PianoRollAudioDataset(Dataset):
    """Base dataset that pairs raw audio with piano-roll labels.

    Subclasses implement ``available_groups`` and ``files``. Every
    (audio_path, tsv_path) pair returned by ``files`` is loaded into memory
    at construction time, and ``__getitem__`` serves (optionally randomly
    cropped) examples from that in-memory list.

    Relies on the notebook-level constants HOP_LENGTH, SAMPLE_RATE, MIN_MIDI,
    MAX_MIDI, HOPS_IN_ONSET and HOPS_IN_OFFSET being defined before the
    dataset is instantiated or indexed.
    """

    def __init__(self, path, groups=None, sequence_length=None, seed=42, refresh=False, device='cpu'):
        self.path = path
        # Fall back to every group the subclass advertises.
        self.groups = groups if groups is not None else self.available_groups()
        self.sequence_length = sequence_length
        self.device = device  # stored for callers; not used inside this class
        self.random = np.random.RandomState(seed)  # drives the random crops in __getitem__
        self.refresh = refresh

        self.data = []
        # Use self.groups (not the raw argument) so that groups=None works.
        n_groups = len(self.groups)
        print(f"Loading {n_groups} group{'s' if n_groups > 1 else ''} "
              f"of {self.__class__.__name__} at {path}")
        for group in self.groups:
            for input_files in tqdm(self.files(group), desc='Loading group %s' % group):
                # Eagerly load every track into memory.
                self.data.append(self.load(*input_files))

    def __getitem__(self, index):
        """Return one example as a dict.

        Keys:
            path: the audio file path.
            audio: float waveform, the stored int16 samples divided by 32768.
            label: piano-roll label rows (see ``load`` for the encoding).
            velocity: velocity rows aligned with ``label``.
            frame: 1.0 where ``label > 1`` (onset or sustained frame), else 0.0.

        When ``sequence_length`` is set, a random crop of exactly
        ``sequence_length`` samples is taken, starting on a HOP_LENGTH
        boundary, together with the matching ``sequence_length // HOP_LENGTH``
        label rows. When it is None, the whole recording is returned.

        Raises:
            ValueError: if the recording is shorter than ``sequence_length``.
        """
        data = self.data[index]
        result = dict(path=data['path'])

        if self.sequence_length is None:
            result['audio'] = data['audio']
            result['label'] = data['label']
            result['velocity'] = data['velocity']
        else:
            audio_length = len(data['audio'])
            max_start = audio_length - self.sequence_length
            if max_start < 0:
                raise ValueError(
                    f"Recording {data['path']} has {audio_length} samples, "
                    f"fewer than sequence_length={self.sequence_length}")
            # RandomState.randint(0) raises, so a recording that is exactly
            # sequence_length samples long always starts at sample 0.
            sample_begin = self.random.randint(max_start) if max_start > 0 else 0
            step_begin = sample_begin // HOP_LENGTH
            n_steps = self.sequence_length // HOP_LENGTH
            step_end = step_begin + n_steps

            begin = step_begin * HOP_LENGTH
            end = begin + self.sequence_length

            result['audio'] = data['audio'][begin:end]
            result['label'] = data['label'][step_begin:step_end, :]
            result['velocity'] = data['velocity'][step_begin:step_end, :]

        # load() stores int16 audio, so .float() returns a copy and the
        # in-place division does not modify self.data. Dividing by 2**15
        # maps int16 samples into [-1, 1).
        result['audio'] = result['audio'].float().div_(32768.0)
        result['frame'] = (result['label'] > 1).float()
        return result

    def __len__(self):
        return len(self.data)

    @classmethod
    @abstractmethod
    def available_groups(cls):
        """Return the names of all available groups.

        Subclasses must override this. ``@abstractmethod`` is only enforced
        for classes whose metaclass is ``ABCMeta``; the ``NotImplementedError``
        below is what signals a missing override here.
        """
        raise NotImplementedError

    @abstractmethod
    def files(self, group):
        """Return an iterable of (audio_filename, tsv_filename) pairs for this group."""
        raise NotImplementedError

    def load(self, audio_path, tsv_path):
        """
        Load an audio track and the corresponding labels.

        If a cached ``.pt`` file (same directory and stem as ``audio_path``)
        exists and ``self.refresh`` is false, that cache is returned instead.
        Writing the cache is currently disabled (see the commented-out
        ``torch.save`` below).

        Returns
        -------
            A dictionary containing the following data:
            path: str
                the path to the audio file
            audio: torch.ShortTensor, shape = [num_samples]
                the raw waveform
            label: torch.ByteTensor, shape = [num_steps, midi_bins]
                a matrix that contains the onset/offset/frame labels encoded as:
                3 = onset, 2 = frames after onset, 1 = offset, 0 = all else
            velocity: torch.ByteTensor, shape = [num_steps, midi_bins]
                a matrix that contains MIDI velocity values at the frame locations

        Notes outside [MIN_MIDI, MAX_MIDI] are skipped.
        """
        # Only swap the extension; str.replace could also hit a directory name.
        saved_data_path = os.path.splitext(audio_path)[0] + '.pt'
        if os.path.exists(saved_data_path) and not self.refresh:
            return torch.load(saved_data_path)

        audio, sr = soundfile.read(audio_path, dtype='int16')
        assert sr == SAMPLE_RATE, f"{audio_path}: expected {SAMPLE_RATE} Hz, got {sr} Hz"

        audio = torch.as_tensor(audio, dtype=torch.int16)
        audio_length = len(audio)

        n_keys = MAX_MIDI - MIN_MIDI + 1
        # Number of hops needed to cover the audio (ceil(audio_length / HOP_LENGTH)).
        n_steps = (audio_length - 1) // HOP_LENGTH + 1

        label = torch.zeros(n_steps, n_keys, dtype=torch.uint8)
        velocity = torch.zeros(n_steps, n_keys, dtype=torch.uint8)

        # ndmin=2 keeps a single-note file as a (1, n_cols) array so the row
        # unpacking below still works.
        midi = np.loadtxt(tsv_path, delimiter='\t', skiprows=1, ndmin=2)

        for onset, offset, note, vel in midi:
            f = int(note) - MIN_MIDI
            if not 0 <= f < n_keys:
                # Outside the modelled key range: a negative index would wrap
                # around and overwrite an unrelated key's column.
                continue
            left = int(round(onset * SAMPLE_RATE / HOP_LENGTH))  # seconds -> time step
            onset_right = min(n_steps, left + HOPS_IN_ONSET)
            frame_right = min(n_steps, int(round(offset * SAMPLE_RATE / HOP_LENGTH)))
            offset_right = min(n_steps, frame_right + HOPS_IN_OFFSET)

            label[left:onset_right, f] = 3
            label[onset_right:frame_right, f] = 2
            label[frame_right:offset_right, f] = 1
            velocity[left:frame_right, f] = vel

        data = dict(path=audio_path, audio=audio, label=label, velocity=velocity)
        # Cache writing is disabled; re-enable to speed up subsequent loads.
        # torch.save(data, saved_data_path)
        return data

class MusicNet(PianoRollAudioDataset):
    """MusicNet split stored as ``<path>/<group>_data/*.wav`` plus
    ``<path>/tsv_<group>_labels/*.tsv``. Defaults to the 'train' group."""

    def __init__(self, path='../IJCNN2020_music_transcription/data/', groups=None, sequence_length=None, seed=42, refresh=False, device='cpu'):
        super().__init__(path, groups if groups is not None else ['train'], sequence_length, seed, refresh, device)

    @classmethod
    def available_groups(cls):
        return ['train', 'test']

    def files(self, group):
        """Return (wav_path, tsv_path) pairs for ``group``, sorted by wav path.

        Files are paired by filename stem (e.g. ``1727.wav`` with ``1727.tsv``)
        rather than by position. This way a missing or extra file cannot
        silently shift every later label onto the wrong recording.

        Raises:
            ValueError: if the audio and label directories do not contain the
                same set of filename stems.
        """
        wavs = sorted(glob(os.path.join(self.path, f'{group}_data/*.wav')))
        tsvs = sorted(glob(os.path.join(self.path, f'tsv_{group}_labels/*.tsv')))

        def _stem(file_path):
            # Filename without directory or extension.
            return os.path.splitext(os.path.basename(file_path))[0]

        tsv_by_stem = {_stem(tsv): tsv for tsv in tsvs}
        wav_stems = [_stem(wav) for wav in wavs]

        unpaired = sorted(set(wav_stems) ^ set(tsv_by_stem))
        if unpaired:
            raise ValueError(
                f"Audio/label mismatch for group '{group}' under {self.path}: "
                f"{len(unpaired)} unpaired file stem(s), e.g. {unpaired[:10]}")

        return [(wav, tsv_by_stem[stem]) for wav, stem in zip(wavs, wav_stems)]

In [3]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs end-to-end on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
# ---- Training configuration -------------------------------------------------
batch_size = 16

# Seed torch so model initialisation and DataLoader shuffling are reproducible.
# (The dataset's random crops use their own RandomState seeded in the dataset.)
SEED = 42
torch.manual_seed(SEED)

# ---- Audio / label framing --------------------------------------------------
SAMPLE_RATE = 44100      # Hz; the dataset loader asserts every file matches
HOP_LENGTH = 512         # samples per label row, also used as the STFT hop
ONSET_LENGTH = 512       # samples labelled as onset (3) from each note start
OFFSET_LENGTH = 512      # samples labelled as offset (1) after each note end
HOPS_IN_ONSET = ONSET_LENGTH // HOP_LENGTH
HOPS_IN_OFFSET = OFFSET_LENGTH // HOP_LENGTH

# ---- Key range (MIDI note numbers, inclusive) ------------------------------
MIN_MIDI = 21            # A0, lowest piano key
MAX_MIDI = 108           # C8, highest piano key

In [5]:
start = time()
train_set = MusicNet(
    path='./data/',
    groups=['train'],
    sequence_length=327680,  # 640 label rows x 512-sample hop
    refresh=True,
)
loading_time = time() - start

Loading group train: 0it [00:00, ?it/s]

Loading 1 group of MusicNet at ./data/


Loading group train: 320it [01:41,  3.40it/s]


In [6]:
# Reshuffle recordings every epoch so batches are not presented in one fixed
# order (each example is additionally randomly cropped when sequence_length is set).
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [7]:
factor = 4               # shrink factor for the STFT size and the conv kernel height
lr = 1e-4                # Adam learning rate
n_fft = 4096 // factor   # STFT window size in samples


Loss = torch.nn.BCELoss()


def L(yhatvar, y):
    """Mean binary cross-entropy of ``yhatvar`` against ``y``, scaled by 64 (= 128 / 2).

    The training loop below uses ``loss_fn`` rather than this helper.
    """
    return 64 * Loss(yhatvar, y)

class Model(torch.nn.Module):
    """STFT -> two strided frequency CNNs -> BiLSTM -> per-key sigmoid.

    Input:  batch of raw waveforms, shape (batch, samples).
    Output: per-frame key probabilities, shape (batch, frames, n_keys), where
            n_keys = MAX_MIDI - MIN_MIDI + 1.

    Reads the notebook globals ``factor``, ``n_fft``, ``HOP_LENGTH``,
    ``SAMPLE_RATE``, ``MIN_MIDI`` and ``MAX_MIDI``.
    """

    def __init__(self):
        super(Model, self).__init__()
        f_kernal = 128 // factor  # conv kernel height along the frequency axis
        self.STFT_layer = Spectrogram.STFT(sr=SAMPLE_RATE, n_fft=n_fft, hop_length=HOP_LENGTH,
                                           pad_mode='constant', center=True)
        # Stride 8 shrinks the frequency axis; stride 1 keeps every time frame.
        self.freq_cnn1 = torch.nn.Conv2d(1, 4, (f_kernal, 3), stride=(8, 1), padding=1)
        self.freq_cnn2 = torch.nn.Conv2d(4, 8, (f_kernal, 3), stride=(8, 1), padding=1)
        shape = self.shape_inference(f_kernal)
        # LSTM features = 8 output channels x remaining frequency rows.
        self.bilstm = torch.nn.LSTM(shape * 8, shape * 8, batch_first=True, bidirectional=True)
        self.pitch_classifier = torch.nn.Linear(shape * 8 * 2, MAX_MIDI - MIN_MIDI + 1)

    def shape_inference(self, f_kernal):
        """Return the frequency-axis size left after freq_cnn1 and freq_cnn2.

        Applies the Conv2d output-size rule
        ``(size + 2 * padding - kernel) // stride + 1`` to each layer in turn,
        starting from the ``n_fft // 2 + 1`` frequency bins of the STFT.
        Padding and stride are read from the conv layers themselves, so this
        cannot drift out of sync with the layer definitions.
        """
        size = n_fft // 2 + 1
        for conv in (self.freq_cnn1, self.freq_cnn2):
            size = (size + 2 * conv.padding[0] - f_kernal) // conv.stride[0] + 1
        return size

    def forward(self, x):
        # Drop the final sample; with center=True this is presumably meant to
        # give exactly sequence_length // HOP_LENGTH frames (one per label row).
        x = self.STFT_layer(x[:, :-1])
        x = torch.log(x + 1e-5)  # log spectrogram; epsilon avoids log(0)
        x = torch.relu(self.freq_cnn1(x.unsqueeze(1)))  # add a channel dim for Conv2d
        x = torch.relu(self.freq_cnn2(x))
        # (B, C, F', T) -> (B, C*F', T) -> (B, T, C*F') for the batch_first LSTM.
        x, _ = self.bilstm(x.view(x.size(0), x.size(1) * x.size(2), x.size(3)).transpose(1, 2))
        x = torch.sigmoid(self.pitch_classifier(x))
        return x
    
model = Model().to(device)  # nn.Module.to moves parameters and returns the module itself
loss_fn = torch.nn.BCELoss()  # frame-wise binary cross-entropy on the sigmoid outputs
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


STFT kernels created, time used = 19.1567 seconds


In [8]:
epoches = 50

times = []         # wall-clock seconds per epoch
loss_histroy = []  # mean training loss per epoch (name kept: later cells read it)
print("epoch\ttrain loss\ttime")
total_i = len(train_loader)
model.train()  # make training mode explicit even if eval() was called earlier
for e in range(epoches):
    running_loss = 0
    start = time()
    for idx, data in enumerate(train_loader):
        optimizer.zero_grad()
        y_pred = model(data['audio'].to(device))
        loss = loss_fn(y_pred, data['frame'].to(device))
        loss.backward()
        optimizer.step()

        # .item() copies the value back to the host (a sync on GPU); do it once per batch.
        batch_loss = loss.item()
        running_loss += batch_loss

        print(f"Training {idx+1}/{total_i} batches\t Loss: {batch_loss}", end='\r')
    time_used = time() - start
    times.append(time_used)
    print(' ' * 200, end='\r')  # blank out the in-place progress line
    print(f'{e+1}\t{running_loss/total_i:.6f}\t{time_used:.6f}')
    loss_histroy.append(running_loss / total_i)

epoch	train loss	time
1	0.680794	2.675504                                                                                                                                                                                     
2	0.654296	2.190073                                                                                                                                                                                     
3	0.625197	2.190599                                                                                                                                                                                     
4	0.592302	1.988118                                                                                                                                                                                     
5	0.557689	2.366848                                                                                                                                                           

42	0.156261	2.117943                                                                                                                                                                                    
43	0.157839	2.383978                                                                                                                                                                                    
44	0.159349	2.153829                                                                                                                                                                                    
45	0.154421	2.670984                                                                                                                                                                                    
46	0.154185	2.505843                                                                                                                                                                                

In [9]:
nnAudio_result = dict()  # collects training curves and timings for pickling

In [10]:
nnAudio_result.update(
    loss_histroy=loss_histroy,   # per-epoch mean training loss
    time_histroy=times,          # per-epoch wall-clock seconds
    loading_time=loading_time,   # seconds spent building train_set
)

In [12]:
# Persist the training curves/timings and the trained weights, keyed by STFT size.
with open(f'./nnAudio_result_{n_fft}', 'wb') as f:
    pickle.dump(nnAudio_result, f)

os.makedirs('./weight', exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), f'./weight/nnAudio_result_{n_fft}')