In [1]:
import os, signal, sys

import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import conv1d
import torchvision

from scipy.io.wavfile import read

from time import time

sys.path.insert(0, '../')
import musicnet
# from helperfunctions import get_audio_segment, get_piano_roll, export_midi
from sklearn.metrics import average_precision_score

# Enumerate GPUs in PCI bus order (see issue #152) and expose only
# physical GPU 3 to this process.
os.environ.update(
    CUDA_DEVICE_ORDER='PCI_BUS_ID',
    CUDA_VISIBLE_DEVICES='3',
)

import matplotlib.pyplot as plt

from pypianoroll import Multitrack, Track, load, parse

# Pick the device used by every later cell. `device` must always be defined:
# previously it was only assigned when CUDA was available, so a CPU-only
# machine failed later with a NameError at `model.to(device)`.
if torch.cuda.is_available():
    device = "cuda:0"
    # Make newly created float tensors live on the GPU by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    device = "cpu"

In [2]:
# Model / STFT hyper-parameters.
# The lvl1 convolutions are shared between regions.
n_fft = 4096       # lvl1 receptive field (length of each STFT kernel, in samples)
k = 512            # lvl1 nodes
stride = 512       # hop, in samples, between consecutive lvl1 frames
window = 16384     # number of audio samples in one segment fed to the model
m = 128            # width of the model's output layer
batch_size = 100
epsilon = 1e-8     # small constant added to a norm before dividing by it

# Number of n_fft-long frames, spaced `stride` apart, that fit in one window.
regions = (window - n_fft)//stride + 1

In [3]:
class spectrograms_stft(torch.nn.Module):
    """Linear model on top of a log-power STFT computed with ``conv1d``.

    The forward pass convolves the raw audio with fixed sine and cosine
    kernels, sums their squares, takes the log, flattens the result and maps
    it through a single bias-free linear layer with ``m`` outputs.

    Relies on the module-level constants ``n_fft``, ``k``, ``stride``,
    ``regions`` and ``m``.

    Parameters
    ----------
    avg : float
        Stored on the instance as ``self.avg``; not used by the code in this
        class.
    """

    def __init__(self, avg=.9998):
        super().__init__()
        # Fixed (non-trainable) STFT kernels. They are intentionally NOT
        # registered as buffers: that would add `wsin`/`wcos` keys to the
        # state dict, which the existing checkpoint does not contain.
        # `forward` keeps them on the same device as its input instead.
        # presumably each has shape (k, 1, n_fft) -- confirm against
        # musicnet.create_filters
        wsin, wcos = musicnet.create_filters(n_fft, k, windowing="no", freq_scale='linear')
        self.wsin = torch.tensor(wsin, dtype=torch.float)
        self.wcos = torch.tensor(wcos, dtype=torch.float)

        # The only trainable layer: flattened log-spectrogram -> m outputs,
        # with the weights initialised to zero.
        self.linear = torch.nn.Linear(regions*k, m, bias=False)
        torch.nn.init.constant_(self.linear.weight, 0)

        self.avg = avg

    def forward(self, x):
        """Return the linear layer's output for a batch of audio segments.

        ``x`` is indexed as ``x[:, None, :]``, so it must be 2-D
        (batch, samples). The flattening step requires the conv output to
        hold exactly ``regions*k`` values per item, i.e. segments of
        ``window`` samples.
        """
        # Plain tensor attributes are not moved by `Module.to(...)`, so make
        # the kernels follow the input's device (done once, then cached).
        if self.wsin.device != x.device:
            self.wsin = self.wsin.to(x.device)
            self.wcos = self.wcos.to(x.device)

        audio = x[:, None, :]  # add the channel axis expected by conv1d
        # STFT power via conv1d: squared sine response + squared cosine response.
        zx = conv1d(audio, self.wsin, stride=stride).pow(2) \
           + conv1d(audio, self.wcos, stride=stride).pow(2)
        # NOTE(review): 10e-8 is 1e-7, not the module-level `epsilon` (1e-8).
        # Left unchanged so outputs match the saved weights -- presumably the
        # value used during training; confirm before altering.
        return self.linear(torch.log(zx + 10e-8).view(x.size(0), regions*k))
    

In [4]:
# Build the network and place it on the selected device in one step
# (`Module.to` returns the module itself).
model = spectrograms_stft().to(device)

In [5]:
# Restore the trained weights. `map_location` remaps the stored tensors onto
# the device selected for this session, so a checkpoint saved on a different
# GPU (or on a GPU at all, when running on CPU) still loads.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load('../weights/spectrograms_stft', map_location=device))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [6]:
def access_full(path):
    """Read the whole binary file at `path` as a 1-D array of float32 values."""
    with open(path, 'rb') as handle:
        return np.fromfile(handle, dtype=np.float32)

In [7]:
def get_piano_roll_from_wav(filepath, model, device, window=16384, stride=1000, offset=44100, count=7500, batch_size=500, m=128):
    """Run `model` over sliding windows of a WAV file and return its outputs.

    The audio is cut into `count` segments of `window` samples, each divided
    by its own L2 norm, and fed to `model` in batches.

    Parameters
    ----------
    filepath : str
        WAV file, read with ``scipy.io.wavfile.read``.
    model : callable
        Maps a float tensor of shape (batch, window) to (batch, m).
    device : str or torch.device
        Device each input batch is moved to before calling `model`.
    window : int
        Number of samples per segment.
    stride : int
        Hop, in samples, between segment starts. Pass -1 to derive it from
        `count` instead.
    offset : int
        Number of samples skipped at the beginning of the file.
    count : int
        Number of segments. Only honoured when ``stride == -1``; otherwise it
        is recomputed from the audio length and `stride`.
    batch_size : int
        Number of segments per forward pass.
    m : int
        Number of model outputs per segment.

    Returns
    -------
    torch.Tensor
        Shape (count, m): the raw model output for every segment.

    Raises
    ------
    ValueError
        If the audio array has more than two dimensions.
    """
    sf = 4  # the last `sf` windows' worth of samples is never used as a segment start
    x = read(filepath)[1]
    if x.ndim == 2:
        x = x.mean(1)  # convert stereo to mono
    elif x.ndim > 2:
        # Previously this only printed and then failed later with an
        # unrelated shape error; fail here with the real reason.
        raise ValueError("the audio shape {} is not correct, please check and fix it".format(x.shape))

    usable = x.shape[0] - offset - int(sf*window)
    if stride == -1:
        # max(..., 1) avoids a ZeroDivisionError when a single segment is requested.
        stride = int(usable/max(count - 1, 1))
        print("Number of stride = ", stride)
    else:
        count = int(usable/stride + 1)

    X = np.zeros([count, window])
    for i in range(count):
        start = offset + i*stride
        temp = x[start:start+window]
        # `epsilon` (module-level) keeps an all-zero segment from dividing by zero.
        X[i,:] = temp / (np.linalg.norm(temp) + epsilon)

    # Ceiling division: the final, possibly shorter, batch must be processed
    # too. Floor division left the trailing rows of Y_pred as zeros.
    n_batches = -(-count // batch_size)
    with torch.no_grad():
        Y_pred = torch.zeros([count,m])
        for i in range(n_batches):
            print(f"{i}/{n_batches} batches", end = '\r')
            lo, hi = batch_size*i, batch_size*(i+1)  # slicing clamps `hi` to count
            X_batch = torch.tensor(X[lo:hi]).float().to(device)
            Y_pred[lo:hi] = model(X_batch).cpu()

    return Y_pred

In [8]:
def export_midi(Y_pred, path):
    """Wrap a piano-roll matrix in a single-track pypianoroll Multitrack and write it to `path`.

    `Y_pred` is multiplied by 127 before being handed to `Track`, so a
    boolean roll becomes values of 0 or 127.
    presumably the axes are (time, pitch), as pypianoroll expects -- confirm.
    """
    scaled_roll = Y_pred*127
    piano = Track(pianoroll=scaled_roll, program=0, is_drum=False,
                  name='my awesome piano')
    # NOTE(review): beat_resolution=86 at tempo=60 gives 86 steps per second,
    # presumably chosen to match 44100/512 ~ 86 frames per second -- confirm.
    song = Multitrack(tracks=[piano], tempo=60, beat_resolution=86)
    song.write(path)

In [19]:
# Audio files to transcribe, given relative to `folder`.
folder = './'
files = ['BWV846.wav']
filepath_list = [os.path.join(folder, wav_name) for wav_name in files]

In [20]:
# Transcribe every input file and write one MIDI file per input.
for filepath in filepath_list:
    Y_pred = get_piano_roll_from_wav(filepath, model, device,
                                window=window, m=m, stride=512)
    # Binarise the model output: values above 0.4 are treated as active.
    Yhatpred = Y_pred.cpu().numpy() > 0.4
    # Derive the output name from the input file. The previous
    # './BWV846.mid'.format(...) had no placeholders, so every input was
    # written to the same path and later files overwrote earlier ones.
    stem = os.path.splitext(os.path.basename(filepath))[0]
    export_midi(Yhatpred, './{}.mid'.format(stem))

0/2 batches1/2 batches