In [1]:
import os, signal, sys

import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import conv1d
import torchvision

from scipy.io.wavfile import read

from time import time

sys.path.insert(0, '../')
import musicnet
# from helperfunctions import get_audio_segment, get_piano_roll, export_midi
from sklearn.metrics import average_precision_score

# Pin the process to one GPU. This has to happen before CUDA is initialised
# (first CUDA call), because CUDA reads these variables only once at start-up.
# PCI_BUS_ID makes the index below refer to the bus order (see issue #152).
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '3',
})

import matplotlib.pyplot as plt

from pypianoroll import Multitrack, Track, load, parse

# Pick the compute device. Later cells call `model.to(device)` and pass `device`
# around, so it must be defined on CPU-only machines too.
if torch.cuda.is_available():
    device = "cuda:0"
    # New tensors (e.g. the CNN filter bank) are created on the GPU by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    device = "cpu"

In [2]:
# Shared hyper-parameters. The lvl1 (filter-bank) convolution is shared across
# all regions (frames) of one input window.
m, k = 128, 512      # m: output width of the final linear layer; k: number of lvl1 filters
n_fft = 4096         # lvl1 receptive field, in samples
window = 16384       # audio samples per model input
stride = 512         # hop of the lvl1 convolution, in samples
batch_size = 100     # NOTE(review): the inference loop does not pass this on (get_piano_roll_from_wav defaults to 500) — confirm intent
epsilon = 1e-8       # small constant guarding divisions by a zero norm

# lvl1 frames per window: (16384 - 4096) // 512 + 1 == 25.
regions = (window - n_fft) // stride + 1

In [3]:
class CNN(torch.nn.Module):
    """Filter-bank front-end followed by a 2-D CNN producing `m` scores per audio window.

    A fixed filter bank (`wsin`/`wcos`, presumably windowed sine/cosine kernels
    built by `musicnet.create_filters`) is convolved over the raw audio with hop
    `stride`. The squared responses are summed and log-compressed. A frequency
    convolution, a time convolution and a bias-free linear layer then map the
    result to `m` outputs.

    Args:
        avg: unused; kept so existing `CNN(avg=...)` calls keep working.
    """

    def __init__(self, avg=.9998):
        super(CNN, self).__init__()
        # Fixed (non-trainable) filter bank. These are plain tensor attributes,
        # not parameters or buffers, so they are neither in state_dict() nor
        # moved by .to(). forward() moves them to the input's device.
        wsin, wcos = musicnet.create_filters(n_fft, k, low=50, high=6000,
                                             windowing="hann", freq_scale='log')
        self.wsin = torch.Tensor(wsin)
        self.wcos = torch.Tensor(wcos)
        # Hop of the filter-bank convolution, captured at construction so the
        # model does not change behaviour if the global `stride` is rebound.
        self.stride = stride

        k_out = 128
        k2_out = 256
        freq_kernel, freq_stride = 128, 2
        # Frequency convolution: slides along the k filter-bank bins.
        self.CNN_freq = nn.Conv2d(1, k_out,
                                  kernel_size=(freq_kernel, 1), stride=(freq_stride, 1))
        # Time convolution: a kernel of 25 equals the number of lvl1 frames per
        # window for window=16384, n_fft=4096, stride=512, so time collapses to 1.
        self.CNN_time = nn.Conv2d(k_out, k2_out,
                                  kernel_size=(1, 25), stride=(1, 1))
        # Frequency positions left after CNN_freq: (512 - 128) // 2 + 1 == 193 for k=512.
        freq_out = (k - freq_kernel) // freq_stride + 1
        self.linear = torch.nn.Linear(k2_out * freq_out, m, bias=False)

    def forward(self, x):
        """Map audio `x` of shape (batch, samples) to scores of shape (batch, m)."""
        x = x[:, None, :]  # (batch, 1, samples), as conv1d expects a channel axis
        wsin = self.wsin.to(x.device)  # no-op when already on the right device
        wcos = self.wcos.to(x.device)
        # Summed squared responses: (batch, n_filters, frames); (batch, 512, 25) with the config cell.
        zx = conv1d(x, wsin, stride=self.stride).pow(2) \
           + conv1d(x, wcos, stride=self.stride).pow(2)
        zx = torch.log(zx + 1e-12)  # floor avoids log(0) on silent frames
        z2 = torch.relu(self.CNN_freq(zx.unsqueeze(1)))  # (batch, k_out, freq_out, frames)
        z3 = torch.relu(self.CNN_time(z2))               # (batch, k2_out, freq_out, 1)
        # z3 is already non-negative after relu, so it is flattened directly.
        return self.linear(torch.flatten(z3, 1))

In [None]:
# Build the network and place its parameters on the selected device
# (nn.Module.to moves the module in place and returns it).
model = CNN().to(device)

In [None]:
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`,
# so the checkpoint also loads on CPU-only machines.
state_dict = torch.load('../weights/translation_invariant_baseline', map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode for the transcription cells below

In [None]:
def access_full(path):
    """Read a headerless binary file of float32 values into a 1-D numpy array.

    Args:
        path: filesystem path of the raw file.

    Returns:
        np.ndarray of dtype float32 holding every value in the file.
    """
    return np.fromfile(path, dtype=np.float32)

In [None]:
def get_piano_roll_from_wav(filepath, model, device, window=16384, stride=1000, offset=44100, count=7500,
                            batch_size=500, m=128, epsilon=1e-8):
    """Run `model` over sliding, L2-normalised windows of a WAV file.

    Args:
        filepath: WAV file readable by scipy.io.wavfile.read. 2-D (multi-channel)
            data is averaged over axis 1 to mono.
        model: callable mapping a float32 tensor of shape (n, window) on `device`
            to a tensor of shape (n, m).
        device: device each input batch is moved to.
        window: samples per model input.
        stride: hop between consecutive windows, in samples. If -1, the hop is
            derived so that exactly `count` windows are taken.
        offset: number of leading samples skipped before the first window.
        count: number of windows. Only used when stride == -1; otherwise it is
            recomputed from the file length.
        batch_size: windows per model call.
        m: width of the model output.
        epsilon: added to each window's L2 norm before dividing (guards silence).

    Returns:
        torch.Tensor of shape (count, m), one row of model output per window.

    Raises:
        ValueError: if the audio has more than 2 dimensions, or is too short for
            the requested windowing.
    """
    # NOTE(review): `sf * window` samples (not just `window`) are reserved after
    # `offset`, so roughly the last (sf - 1) * window samples are never
    # transcribed. Kept to preserve the existing output length; confirm intent.
    sf = 4
    x = read(filepath)[1]
    if x.ndim == 2:
        x = x.mean(1)  # convert stereo to mono
    elif x.ndim > 2:
        raise ValueError("the audio shape {} is not correct, please check and fix it".format(x.shape))

    usable = x.shape[0] - offset - int(sf * window)
    if stride == -1:
        if count < 2:
            raise ValueError("count must be >= 2 when stride == -1, got {}".format(count))
        stride = int(usable / (count - 1))
        print("Number of stride = ", stride)
        if stride < 1:
            raise ValueError("audio in {} is too short for {} windows".format(filepath, count))
    else:
        count = usable // stride + 1
        if count < 1:
            raise ValueError("audio in {} is too short for a single window".format(filepath))

    # One normalised window per row; every window lies inside the audio because
    # the last start is offset + (count - 1) * stride <= len(x) - sf * window.
    X = np.zeros([count, window])
    for i in range(count):
        start = offset + i * stride
        segment = x[start:start + window]
        X[i, :] = segment / (np.linalg.norm(segment) + epsilon)

    # Ceil division so the final, partial batch is also predicted.
    n_batches = (count + batch_size - 1) // batch_size
    with torch.no_grad():
        Y_pred = torch.zeros([count, m])
        for i in range(n_batches):
            print(f"{i}/{n_batches} batches", end='\r')
            rows = slice(batch_size * i, batch_size * (i + 1))
            X_batch = torch.tensor(X[rows]).float().to(device)
            Y_pred[rows] = model(X_batch).cpu()

    return Y_pred

In [None]:
def export_midi(Y_pred, path):
    """Write a piano roll to `path` as a single-track MIDI file.

    Args:
        Y_pred: array with time on axis 0 and pitch on axis 1. Presumably booleans
            or values in [0, 1]; they are scaled by 127 before being handed to
            pypianoroll as the piano-roll values.
        path: output .mid path. Its parent directory is created if missing.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    track = Track(pianoroll=Y_pred*127, program=0, is_drum=False,
                  name='my awesome piano')
    # tempo=60 is one beat per second, so beat_resolution acts as piano-roll
    # frames per second. 86 ~= 44100 / 512, presumably one frame per 512-sample
    # hop at 44.1 kHz. TODO confirm against the caller's stride and sample rate.
    multitrack = Multitrack(tracks=[track], tempo=60, beat_resolution=86)
    multitrack.write(path)

In [None]:
folder = './'
files = ['BWV846.wav', 'BWV972.wav', '2.wav', '3.wav']
# Full paths of the recordings to transcribe, in the order listed above.
filepath_list = list(map(lambda fname: os.path.join(folder, fname), files))

In [None]:
# Transcribe each recording and save the thresholded piano roll as MIDI.
threshold = 0.4               # decision threshold on the raw model outputs (no sigmoid is applied)
output_dir = './midi_output'
os.makedirs(output_dir, exist_ok=True)  # writing the .mid would fail if the folder were missing
for filepath in filepath_list:
    Y_pred = get_piano_roll_from_wav(filepath, model, device,
                                     window=window, m=m, stride=512)
    Yhatpred = Y_pred.cpu().numpy() > threshold
    stem = os.path.splitext(os.path.basename(filepath))[0]
    export_midi(Yhatpred, os.path.join(output_dir, '{}_{}.mid'.format('CNN_', stem)))