In [1]:
import os,sys,signal, copy
import math
from contextlib import contextmanager

import pickle
import numpy as np                                       # fast vectors and matrices
import matplotlib.pyplot as plt                          # plotting
import matplotlib.ticker as ticker
from matplotlib.animation import ArtistAnimation

from scipy.fftpack import fft
from scipy.signal.windows import hann

sys.path.insert(0, '../')
import musicnetRaven as musicnet

from time import time

sys.path.insert(0,'../lib/')
import config
import diagnosticsP3
# import base_model

from sklearn.metrics import average_precision_score

os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'   # see issue #152
os.environ['CUDA_VISIBLE_DEVICES']='3'

import torch
from torch.nn.functional import conv1d, mse_loss
import torch.nn.functional as F
import torch.nn as nn
from torchcontrib.optim import SWA

from tqdm import tqdm
import mir_eval

from torchsummary import summary

%matplotlib inline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Network Parameters
epochs = 20
train_size = 100000
test_size = 50000
lr = 5e-4
pitch_shift = 0
jitter = 0.
num_workers = 15

# lvl1 convolutions are shared between regions
m = 128
k = 512              # lvl1 nodes
n_fft = 4096              # lvl1 receptive field
window = 16384 # total number of audio samples?
stride = 512
batch_size = 100

In [3]:
regions = 1 + (window - n_fft)//stride

def worker_init(args):
    signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore signals so parent can handle them
    np.random.seed(os.getpid() ^ int(time())) # approximately random seed for workers
kwargs = {'num_workers': num_workers, 'pin_memory': True, 'worker_init_fn': worker_init}

In [4]:
start = time()
root = '../../data/'
test_set = musicnet.MusicNet(root=root, train=False, download=True, refresh_cache=False, window=window, epoch_size=test_size, mmap=False)
print("Time used = ", time()-start)

Time used =  0.3503999710083008


In [5]:
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**kwargs)

In [6]:
def create_filtersv2(n_fft, freq_bins=None, low=50,high=6000, sr=44100, freq_scale='linear', mode="fft"):
    if freq_bins==None:
        freq_bins = n_fft//2+1
    
    s = torch.arange(0, n_fft, 1.)
    wsin = torch.empty((freq_bins,1,n_fft))
    wcos = torch.empty((freq_bins,1,n_fft))
    start_freq = low
    end_freq = high
    

    # num_cycles = start_freq*d/44000.
    # scaling_ind = np.log(end_freq/start_freq)/k
    
    if mode=="fft":
        window_mask = 1
    elif mode=="stft":
        window_mask = 0.5-0.5*torch.cos(2*math.pi*s/(n_fft)) # same as hann(n_fft, sym=False)
    else:
        raise Exception("Unknown mode, please chooes either \"stft\" or \"fft\"")
        
    if freq_scale == 'linear':
        start_bin = start_freq*n_fft/sr
        scaling_ind = (end_freq/start_freq)/freq_bins
        for k in range(freq_bins): # Only half of the bins contain useful info
            wsin[k,0,:] = window_mask*torch.sin(2*math.pi*(k*scaling_ind*start_bin)*s/n_fft)
            wcos[k,0,:] = window_mask*torch.cos(2*math.pi*(k*scaling_ind*start_bin)*s/n_fft)
    elif freq_scale == 'log':
        start_bin = start_freq*n_fft/sr
        scaling_ind = np.log(end_freq/start_freq)/freq_bins
        for k in range(freq_bins): # Only half of the bins contain useful info
            wsin[k,0,:] = window_mask*torch.sin(2*math.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
            wcos[k,0,:] = window_mask*torch.cos(2*math.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)    
    else:
        print("Please select the correct frequency scale, 'linear' or 'log'")
    
    return wsin,wcos


In [7]:
class Model(nn.Module):
    def __init__(self, avg=.9998):
        super(Model, self).__init__()
        # Create filter windows
        wsin, wcos = create_filtersv2(n_fft,k, low=50, high=6000,
                                      mode="stft", freq_scale='log')
        with torch.cuda.device(0):
            self.wsin = torch.Tensor(wsin).cuda()
            self.wcos = torch.Tensor(wcos).cuda()
            
        # Creating Layers
#         self.linear = torch.nn.Linear(regions*k, k,bias=False)
#         self.linear_output = torch.nn.Linear(k,m, bias=False)
#         wscale = 10e-5
#         torch.nn.init.normal_(self.linear.weight, std=1e-4) # initialize
#         torch.nn.init.normal_(self.linear_output.weight, std=1e-4)
#         torch.nn.init.zeros_(self.linear.weight)
#         torch.nn.init.zeros_(self.linear_output.weight)
        k_out = 128
        k2_out = 256
        stride1 = (2,1)
        self.CNN_freq = nn.Conv2d(1,k_out,
                                kernel_size=(128,1),stride=(2,1))
        self.CNN_time = nn.Conv2d(k_out,k2_out,
                                kernel_size=(1,25),stride=(1,1))
        self.Linear = nn.Linear(k2_out*1*193, m)
        self.activation = nn.ReLU()
        
        torch.nn.init.normal_(self.CNN_freq.weight, std=1e-4)
        torch.nn.init.normal_(self.CNN_time.weight, std=1e-4)
        torch.nn.init.normal_(self.Linear.weight, std=1e-4)
        
        
        self.avg = avg
        #Create a container for weight average
        self.averages = copy.deepcopy(list(parm.cuda().data for parm in self.parameters())) 
    def _get_conv_output(self, shape):
        bs = 1
        x = torch.tensor(torch.rand(bs, *shape))
        output_feat = self._forward_features(x)
        n_size = output_feat.data.view(bs, -1).size(1)
        return n_size
        
    def forward(self,x):
        zx = conv1d(x[:,None,:], self.wsin, stride=stride).pow(2) \
           + conv1d(x[:,None,:], self.wcos, stride=stride).pow(2)
        zx = torch.log(zx + 10e-8) # Log Magnitude Spectrogram
        z2 = self.CNN_freq(zx.unsqueeze(1)) # Make channel as 1 (N,C,H,W)
        z2 = F.relu(z2)
        z3 = self.CNN_time(z2)
        z3 = F.relu(z3)
        y = self.Linear(self.activation(torch.flatten(z3,1)))
        return y
    
    def average_iterates(self):
        for parm, pavg in zip(self.parameters(), self.averages):
            pavg.mul_(self.avg).add_(1.-self.avg, parm.data) # 0.9W_avg + 0.1W_this_ite
 

    
@contextmanager
def averages(model):
    orig_parms = copy.deepcopy(list(parm.data for parm in model.parameters()))
    for parm, pavg in zip(model.parameters(), model.averages):
        parm.data.copy_(pavg)
    yield
    for parm, orig in zip(model.parameters(), orig_parms):
        parm.data.copy_(orig)

In [8]:
class MultiLayerModel(torch.nn.Module):
    def __init__(self, avg=.9998):
        super(MultiLayerModel, self).__init__()
        # Create filter windows
        wsin, wcos = create_filtersv2(n_fft,k, low=50, high=6000,
                                      mode="stft", freq_scale='log')
        with torch.cuda.device(0):
            self.wsin = torch.Tensor(wsin).cuda()
            self.wcos = torch.Tensor(wcos).cuda()
            
        # Creating Layers
        self.linear = torch.nn.Linear(regions*k, k,bias=False)
        self.linear_output = torch.nn.Linear(k,m, bias=False)
        wscale = 10e-5
        torch.nn.init.normal_(self.linear.weight, std=1e-4) # initialize
        torch.nn.init.normal_(self.linear_output.weight, std=1e-4)
#         torch.nn.init.zeros_(self.linear.weight)
#         torch.nn.init.zeros_(self.linear_output.weight)
        self.activation = torch.nn.ReLU()
        
        self.avg = avg
        #Create a container for weight average
        self.averages = copy.deepcopy(list(parm.cuda().data for parm in self.parameters())) 

        
    def forward(self,x):
        zx = conv1d(x[:,None,:], self.wsin, stride=stride).pow(2) \
           + conv1d(x[:,None,:], self.wcos, stride=stride).pow(2)
        zx = self.activation(self.linear(torch.log(zx + 10e-8).view(x.data.size()[0],regions*k)))
        output = self.linear_output(zx)
        return output
    
    def average_iterates(self):
        for parm, pavg in zip(self.parameters(), self.averages):
            pavg.mul_(self.avg).add_(1.-self.avg, parm.data) # 0.9W_avg + 0.1W_this_ite

In [9]:
class CNNModel(nn.Module):
    def __init__(self, avg=.9998):
        super(CNNModel, self).__init__()
        # Create filter windows
        wsin, wcos = create_filtersv2(n_fft,k, low=50, high=6000,
                                      mode="stft", freq_scale='log')
        with torch.cuda.device(0):
            self.wsin = torch.Tensor(wsin).cuda()
            self.wcos = torch.Tensor(wcos).cuda()
            
        # Creating Layers
#         self.linear = torch.nn.Linear(regions*k, k,bias=False)
#         self.linear_output = torch.nn.Linear(k,m, bias=False)
#         wscale = 10e-5
#         torch.nn.init.normal_(self.linear.weight, std=1e-4) # initialize
#         torch.nn.init.normal_(self.linear_output.weight, std=1e-4)
#         torch.nn.init.zeros_(self.linear.weight)
#         torch.nn.init.zeros_(self.linear_output.weight)
        k_out = 128
        k2_out = 256
        stride1 = (2,1)
        self.CNN_freq = nn.Conv2d(1,k_out,
                                kernel_size=(128,1),stride=(2,1))
        self.CNN_time = nn.Conv2d(k_out,k2_out,
                                kernel_size=(1,25),stride=(1,1))
        self.Linear = nn.Linear(k2_out*1*193, m)
        self.activation = nn.ReLU()
        
        torch.nn.init.normal_(self.CNN_freq.weight, std=1e-4)
        torch.nn.init.normal_(self.CNN_time.weight, std=1e-4)
        torch.nn.init.normal_(self.Linear.weight, std=1e-4)
        
        
        self.avg = avg
        #Create a container for weight average
        self.averages = copy.deepcopy(list(parm.cuda().data for parm in self.parameters())) 
    def _get_conv_output(self, shape):
        bs = 1
        x = torch.tensor(torch.rand(bs, *shape))
        output_feat = self._forward_features(x)
        n_size = output_feat.data.view(bs, -1).size(1)
        return n_size
        
    def forward(self,x):
        zx = conv1d(x[:,None,:], self.wsin, stride=stride).pow(2) \
           + conv1d(x[:,None,:], self.wcos, stride=stride).pow(2)
        zx = torch.log(zx + 10e-8) # Log Magnitude Spectrogram
        z2 = self.CNN_freq(zx.unsqueeze(1)) # Make channel as 1 (N,C,H,W)
        z2 = F.relu(z2)
        z3 = self.CNN_time(z2)
        z3 = F.relu(z3)
        y = self.Linear(self.activation(torch.flatten(z3,1)))
        return y
    
    def average_iterates(self):
        for parm, pavg in zip(self.parameters(), self.averages):
            pavg.mul_(self.avg).add_(1.-self.avg, parm.data) # 0.9W_avg + 0.1W_this_ite

# Averaged Weights

In [10]:
model = CNNModel()
model.cuda()

result_dict = {'loss_history_train': [],
               'avgp_history_train': [],
               'loss_history_test': [],
               'avgp_history_test': [],
               'parameters': {}}

result_dict['parameters']['train_size'] = train_size
result_dict['parameters']['test_size'] = test_size
result_dict['parameters']['lr'] = lr
result_dict['parameters']['pitch_shift'] = pitch_shift
result_dict['parameters']['jitter'] = jitter

In [11]:
model.load_state_dict(torch.load('./Transistion Invariant/weights/2 Layers CNN-Tried Another Optim'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

# Mirex stats

In [12]:
from pypianoroll import Multitrack, Track, load, parse

def get_mir_accuracy(Yhat, Y_true, threshold=0.4):
    Yhatlist = []
    Ylist = []
    Yhatpred = Yhat>threshold
    for i in range(len(Yhatpred)):
        print(f"{i}/{len(Yhatpred)} batches", end = '\r')
        fhat = []
        f = []
        for note in range(m):
            if Yhatpred[i][note] == 1:
                fhat.append(440.*2**(((note)-69.)/12.))

            if Y_true[i][note] == 1:
                f.append(440.*2**(((note)-69.)/12.))

        Yhatlist.append(np.array(fhat))
        Ylist.append(np.array(f))
    avp = average_precision_score(Y_true.flatten(),Yhat.flatten())
    P,R,Acc,Esub,Emiss,Efa,Etot,cP,cR,cAcc,cEsub,cEmiss,cEfa,cEtot = \
    mir_eval.multipitch.metrics(np.arange(len(Ylist))/100.,Ylist,np.arange(len(Yhatlist))/100.,Yhatlist)
    print('{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}'.format(100*avp,100*P,100*R,Acc,Etot,Esub,Emiss,Efa))
    return avp,P,R,Acc,Etot
def get_piano_roll(rec_id, window=16384, stride=1000, offset=44100, count=7500):
    sf=4
    if stride == -1:
        stride = (test_set.records[rec_id][1] - offset - int(sf*window))/(count-1)
        stride = int(stride)
    else:
        count = (test_set.records[rec_id][1] - offset - int(sf*window))/stride + 1
        count = int(count)

    X = np.zeros([count, window])
    Y = np.zeros([count, m])    

    for i in range(count):
        X[i,:], Y[i] =  test_set.access(rec_id, offset+i*stride)

    batch_size = 500
    Y_pred = np.zeros([count,m])
    output = []
    for i in range(len(X)//500):
        X_batch = torch.tensor(X[batch_size*i:batch_size*(i+1)]).float().cuda()   
        Y_pred[i*batch_size:(i+1)*batch_size] = model(X_batch).data.cpu().numpy()
    
    return Y_pred, Y

def export_midi(Y_pred, path):
    # Create a piano-roll matrix, where the first and second axes represent time
    # and pitch, respectively, and assign a C major chord to the piano-roll
    # Create a `pypianoroll.Track` instance
    track = Track(pianoroll=Y_pred*127, program=0, is_drum=False,
                  name='my awesome piano')   
    multitrack = Multitrack(tracks=[track], tempo=120.0, beat_resolution=24)
    multitrack.write(path)    

In [13]:
with test_set:
    print('AvgP\tP\tR\tAcc\tETot\tESub\tEmiss\tEfa')
    model.eval()
    Accavg = 0
    Etotavg = 0
    for songid in test_set.rec_ids:
        Y_pred, Y_true = get_piano_roll(songid, stride=-1)
    #     Yhatpred = Y_pred > 0.4
        _,_,_,Acc,Etot = get_mir_accuracy(Y_pred, Y_true)
        Accavg += Acc
        Etotavg += Etot

AvgP	P	R	Acc	ETot	ESub	Emiss	Efa


RuntimeError: CUDA out of memory. Tried to allocate 1.15 GiB (GPU 0; 31.74 GiB total capacity; 1.28 GiB already allocated; 114.12 MiB free; 24.08 MiB cached)

In [None]:
print('Average Accuracy: \t{:2.2f}\nAverage Error: \t\t{:2.2f}'.format(Accavg*100/len(test_set.rec_ids), Etotavg*100/len(test_set.rec_ids)))

In [38]:
songids = [1759, 2106, 2382, 2556]

In [39]:
for songid in songids:
    model.eval()
    Y_pred, Y_true = get_piano_roll(songid, stride=-1)
    _,_,_,Acc,Etot = get_mir_accuracy(Y_pred, Y_true)
    
    Yhatpred = Y_pred > 0.4
    export_midi(Yhatpred, './{}_{}_Y_pred.mid'.format('CNN_MSE',str(songid)))

0/7500 batches1/7500 batches2/7500 batches3/7500 batches4/7500 batches5/7500 batches6/7500 batches7/7500 batches8/7500 batches9/7500 batches10/7500 batches11/7500 batches12/7500 batches13/7500 batches14/7500 batches15/7500 batches16/7500 batches17/7500 batches18/7500 batches19/7500 batches20/7500 batches21/7500 batches22/7500 batches23/7500 batches24/7500 batches25/7500 batches26/7500 batches27/7500 batches28/7500 batches29/7500 batches30/7500 batches31/7500 batches32/7500 batches33/7500 batches34/7500 batches35/7500 batches36/7500 batches37/7500 batches38/7500 batches39/7500 batches40/7500 batches41/7500 batches42/7500 batches43/7500 batches44/7500 batches45/7500 batches46/7500 batches47/7500 batches48/7500 batches49/7500 batches50/7500 batches51/7500 batches52/7500 batches53/7500 batches54/7500 batches55/7500 batches56/7500 batches57/7500 batches58/7500 batches59/7500 batches60/7500 batches61/7500 batches62/7500 batches63

713/7500 batches714/7500 batches715/7500 batches716/7500 batches717/7500 batches718/7500 batches719/7500 batches720/7500 batches721/7500 batches722/7500 batches723/7500 batches724/7500 batches725/7500 batches726/7500 batches727/7500 batches728/7500 batches729/7500 batches730/7500 batches731/7500 batches732/7500 batches733/7500 batches734/7500 batches735/7500 batches736/7500 batches737/7500 batches738/7500 batches739/7500 batches740/7500 batches741/7500 batches742/7500 batches743/7500 batches744/7500 batches745/7500 batches746/7500 batches747/7500 batches748/7500 batches749/7500 batches750/7500 batches751/7500 batches752/7500 batches753/7500 batches754/7500 batches755/7500 batches756/7500 batches757/7500 batches758/7500 batches759/7500 batches760/7500 batches761/7500 batches762/7500 batches763/7500 batches764/7500 batches765/7500 batches766/7500 batches767/7500 batches768/7500 batches769/7500 batches770/7500 batches771/7500 batc

1393/7500 batches1394/7500 batches1395/7500 batches1396/7500 batches1397/7500 batches1398/7500 batches1399/7500 batches1400/7500 batches1401/7500 batches1402/7500 batches1403/7500 batches1404/7500 batches1405/7500 batches1406/7500 batches1407/7500 batches1408/7500 batches1409/7500 batches1410/7500 batches1411/7500 batches1412/7500 batches1413/7500 batches1414/7500 batches1415/7500 batches1416/7500 batches1417/7500 batches1418/7500 batches1419/7500 batches1420/7500 batches1421/7500 batches1422/7500 batches1423/7500 batches1424/7500 batches1425/7500 batches1426/7500 batches1427/7500 batches1428/7500 batches1429/7500 batches1430/7500 batches1431/7500 batches1432/7500 batches1433/7500 batches1434/7500 batches1435/7500 batches1436/7500 batches1437/7500 batches1438/7500 batches1439/7500 batches1440/7500 batches1441/7500 batches1442/7500 batches1443/7500 batches1444/7500 batches1445/7500 batches1446/7500 batches1447/7500 batches1448/7500

2059/7500 batches2060/7500 batches2061/7500 batches2062/7500 batches2063/7500 batches2064/7500 batches2065/7500 batches2066/7500 batches2067/7500 batches2068/7500 batches2069/7500 batches2070/7500 batches2071/7500 batches2072/7500 batches2073/7500 batches2074/7500 batches2075/7500 batches2076/7500 batches2077/7500 batches2078/7500 batches2079/7500 batches2080/7500 batches2081/7500 batches2082/7500 batches2083/7500 batches2084/7500 batches2085/7500 batches2086/7500 batches2087/7500 batches2088/7500 batches2089/7500 batches2090/7500 batches2091/7500 batches2092/7500 batches2093/7500 batches2094/7500 batches2095/7500 batches2096/7500 batches2097/7500 batches2098/7500 batches2099/7500 batches2100/7500 batches2101/7500 batches2102/7500 batches2103/7500 batches2104/7500 batches2105/7500 batches2106/7500 batches2107/7500 batches2108/7500 batches2109/7500 batches2110/7500 batches2111/7500 batches2112/7500 batches2113/7500 batches2114/7500

2558/7500 batches2559/7500 batches2560/7500 batches2561/7500 batches2562/7500 batches2563/7500 batches2564/7500 batches2565/7500 batches2566/7500 batches2567/7500 batches2568/7500 batches2569/7500 batches2570/7500 batches2571/7500 batches2572/7500 batches2573/7500 batches2574/7500 batches2575/7500 batches2576/7500 batches2577/7500 batches2578/7500 batches2579/7500 batches2580/7500 batches2581/7500 batches2582/7500 batches2583/7500 batches2584/7500 batches2585/7500 batches2586/7500 batches2587/7500 batches2588/7500 batches2589/7500 batches2590/7500 batches2591/7500 batches2592/7500 batches2593/7500 batches2594/7500 batches2595/7500 batches2596/7500 batches2597/7500 batches2598/7500 batches2599/7500 batches2600/7500 batches2601/7500 batches2602/7500 batches2603/7500 batches2604/7500 batches2605/7500 batches2606/7500 batches2607/7500 batches2608/7500 batches2609/7500 batches2610/7500 batches2611/7500 batches2612/7500 batches2613/7500 

3058/7500 batches3059/7500 batches3060/7500 batches3061/7500 batches3062/7500 batches3063/7500 batches3064/7500 batches3065/7500 batches3066/7500 batches3067/7500 batches3068/7500 batches3069/7500 batches3070/7500 batches3071/7500 batches3072/7500 batches3073/7500 batches3074/7500 batches3075/7500 batches3076/7500 batches3077/7500 batches3078/7500 batches3079/7500 batches3080/7500 batches3081/7500 batches3082/7500 batches3083/7500 batches3084/7500 batches3085/7500 batches3086/7500 batches3087/7500 batches3088/7500 batches3089/7500 batches3090/7500 batches3091/7500 batches3092/7500 batches3093/7500 batches3094/7500 batches3095/7500 batches3096/7500 batches3097/7500 batches3098/7500 batches3099/7500 batches3100/7500 batches3101/7500 batches3102/7500 batches3103/7500 batches3104/7500 batches3105/7500 batches3106/7500 batches3107/7500 batches3108/7500 batches3109/7500 batches3110/7500 batches3111/7500 batches3112/7500 batches3113/7500

3557/7500 batches3558/7500 batches3559/7500 batches3560/7500 batches3561/7500 batches3562/7500 batches3563/7500 batches3564/7500 batches3565/7500 batches3566/7500 batches3567/7500 batches3568/7500 batches3569/7500 batches3570/7500 batches3571/7500 batches3572/7500 batches3573/7500 batches3574/7500 batches3575/7500 batches3576/7500 batches3577/7500 batches3578/7500 batches3579/7500 batches3580/7500 batches3581/7500 batches3582/7500 batches3583/7500 batches3584/7500 batches3585/7500 batches3586/7500 batches3587/7500 batches3588/7500 batches3589/7500 batches3590/7500 batches3591/7500 batches3592/7500 batches3593/7500 batches3594/7500 batches3595/7500 batches3596/7500 batches3597/7500 batches3598/7500 batches3599/7500 batches3600/7500 batches3601/7500 batches3602/7500 batches3603/7500 batches3604/7500 batches3605/7500 batches3606/7500 batches3607/7500 batches3608/7500 batches3609/7500 batches3610/7500 batches3611/7500 batches3612/7500 

74.34	69.86	74.10	0.56	0.46	0.12	0.14	0.20
70.29	72.59	63.38	0.51	0.47	0.13	0.23	0.11
48.63	55.47	47.18	0.34	0.74	0.17	0.36	0.21
77.45	74.57	74.33	0.59	0.39	0.12	0.14	0.14
