In [1]:
import numpy as np
import glob
import pypianoroll as ppr
import time
import music21
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from utils.utilsPreprocessing import *
#np.set_printoptions(threshold=np.inf)
#torch.set_printoptions(threshold=50000)

In [2]:
##################################
#HYPERPARAMS
##################################
epochs = 20
learning_rate = 1e-4
batch_size= 98
log_interval = 750  #Log/show loss per batch

# Load MIDI files from npz

In [3]:
data = np.load('/media/EXTHD/niciData/YamahaPianoCompetition2002TransposedBy60.npz')

midiDatasetTrain = data['train']
midiDatasetTest = data['test']

data.close()

"""
print("Training set: ({}, {}, {}, {})".format(midiDatasetTrain.size()[0],
                                                midiDatasetTrain.size()[1],
                                                midiDatasetTrain.size()[2],
                                                midiDatasetTrain.size()[3]))
print("Test set: ({}, {}, {}, {})".format(midiDatasetTest.size()[0],
                                                midiDatasetTest.size()[1],
                                                midiDatasetTest.size()[2],
                                                midiDatasetTest.size()[3]))
"""

print("Training set: {}".format(midiDatasetTrain.shape))
print("Test set: {}".format(midiDatasetTest.shape))

Training set: (2398742, 1, 96, 60)
Test set: (289574, 1, 96, 60)


In [4]:
#print(getSlicedPianorollMatrix('WikifoniaServer/train80/Ahmad-Jamal---Poinciana.mid').shape)

In [5]:
fullPitch = 128
_, _, length, reducedPitch = midiDatasetTrain.shape

# CDVAE

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
midiDatasetTrain = torch.from_numpy(midiDatasetTrain)
trainLoader = torch.utils.data.DataLoader(midiDatasetTrain, batch_size=batch_size, shuffle=False, drop_last=True)

midiDatasetTest = torch.from_numpy(midiDatasetTest)
testLoader = torch.utils.data.DataLoader(midiDatasetTest, batch_size=batch_size, shuffle=False, drop_last=True)

In [8]:
class CDVAE(nn.Module):
    def __init__(self, batch_size=7, tie_weights=True):
        super(CDVAE, self).__init__()
        
        self.batch_size = batch_size
        self.bias1 = torch.empty(400).to(device);torch.nn.init.normal_(self.bias1)
        self.bias2 = torch.empty(800).to(device);torch.nn.init.normal_(self.bias2)
        self.bias3 = torch.empty(2400).to(device);torch.nn.init.normal_(self.bias3)
        self.bias4 = torch.empty(400).to(device);torch.nn.init.normal_(self.bias4)
        self.bias5 = torch.empty(200).to(device);torch.nn.init.normal_(self.bias5)
        self.bias6 = torch.empty(100).to(device);torch.nn.init.normal_(self.bias6)
        self.bias7 = torch.empty(1).to(device);torch.nn.init.normal_(self.bias7)

        
        ###ENCODER###
        self.conv1 = nn.Conv2d(1,100,(16,5),stride=(16,5),padding=0)
        self.bn1 = nn.BatchNorm2d(100)
        self.elu1 = nn.ELU()
        self.conv2 = nn.Conv2d(100,200,(2,1),stride=(2,1),padding=0)
        self.bn2 = nn.BatchNorm2d(200)
        self.elu2 = nn.ELU()
        self.conv3 = nn.Conv2d(200,400,(2,2),stride=(1,2),padding=0)
        self.bn3 = nn.BatchNorm2d(400)
        self.elu3 = nn.ELU()
        self.conv4 = nn.Conv2d(400,800,(2,2),stride=(2,2),padding=0)
        self.bn4 = nn.BatchNorm2d(800)
        self.elu4 = nn.ELU()
        
        self.fc5 =  nn.Linear(2400,800)
        self.bn5 = nn.BatchNorm1d(800)
        self.elu5 = nn.ELU()
        self.fc6 = nn.Linear(800,400)
        self.bn6 = nn.BatchNorm1d(400)
        self.elu6 = nn.ELU()
        self.fc7 = nn.Linear(400,100)
        self.bn7 = nn.BatchNorm1d(100)
        self.elu7 = nn.ELU()
        
        ###LSTM###
        self.lstm1 = nn.LSTM(input_size=100, hidden_size=400, num_layers=2)
        self.lstm2 = nn.LSTM(input_size=400, hidden_size=100, num_layers=1)

    
    def encoder(self, x):
        #print("ENOCDER")
        hEnc = self.conv1(x); hEnc = self.bn1(hEnc); hEnc = self.elu1(hEnc)
        hEnc = self.conv2(hEnc); hEnc = self.bn2(hEnc); hEnc = self.elu2(hEnc)
        hEnc = self.conv3(hEnc); hEnc = self.bn3(hEnc); hEnc = self.elu3(hEnc)
        #print(hEnc.size())
        hEnc = self.conv4(hEnc); hEnc = self.bn4(hEnc); hEnc = self.elu4(hEnc)
        #print(hEnc.size())
        
        hEnc = torch.squeeze(hEnc,3).view(-1,800*3)

        hEnc = self.fc5(hEnc); hEnc = self.bn5(hEnc); hEnc = self.elu5(hEnc)
        hEnc = self.fc6(hEnc); hEnc = self.bn6(hEnc); hEnc = self.elu6(hEnc)
        hEnc = self.fc7(hEnc); hEnc = self.bn7(hEnc); hEnc = self.elu7(hEnc)
        return hEnc

    def decoder(self, z):
        #print("DECODER")
        hDec = F.linear(z,weight=self.fc7.weight.transpose(0,1),bias=self.bias1)
        hDec = F.batch_norm(hDec, running_mean=self.bn6.running_mean,
                            running_var=self.bn6.running_var, weight=self.bn6.weight)
        hDec = F.elu(hDec)
        hDec = F.linear(hDec,weight=self.fc6.weight.transpose(0,1),bias=self.bias2)
        hDec = F.batch_norm(hDec, running_mean=self.bn5.running_mean,
                            running_var=self.bn5.running_var, weight=self.bn5.weight)
        hDec = F.elu(hDec)
        hDec = F.linear(hDec,weight=self.fc5.weight.transpose(0,1),bias=self.bias3)
        ###CANNOT REUSE BATCHNORM
        #hDec = F.batch_norm(hDec, running_mean=self.bn4.running_mean,
        #                    running_var=self.bn4.running_var, weight=self.bn4.weight)
        hDec = F.elu(hDec)
        
        hDec = hDec.view(hDec.size()[0],800,-1).unsqueeze(2)
        #print(hDec.size())
        hDec = F.conv_transpose2d(hDec, weight=self.conv4.weight,
                                 bias=self.bias4,stride=(2,2),padding=0)
        hDec = F.batch_norm(hDec, running_mean=self.bn3.running_mean,
                            running_var=self.bn3.running_var, weight=self.bn3.weight)
        #print(hDec.size())
        hDec = F.elu(hDec)
        hDec = F.conv_transpose2d(hDec, weight=self.conv3.weight,
                                 bias=self.bias5,stride=(1,2),padding=0)
        hDec = F.batch_norm(hDec, running_mean=self.bn2.running_mean,
                            running_var=self.bn2.running_var, weight=self.bn2.weight)
        hDec = F.elu(hDec)
        hDec = F.conv_transpose2d(hDec, weight=self.conv2.weight,
                                 bias=self.bias6,stride=(2,1),padding=0)
        hDec = F.batch_norm(hDec, running_mean=self.bn1.running_mean,
                            running_var=self.bn1.running_var, weight=self.bn1.weight)
        hDec = F.elu(hDec)
        hDec = F.conv_transpose2d(hDec, weight=self.conv1.weight,
                                  bias=self.bias7,stride=(16,5),padding=0)
        ###CANNOT REUSE BATCHNORM 
        #hDec = F.batch_norm(hDec, running_mean=self.bn5.running_mean,
        #                    running_var=self.bn5.running_var, weight=self.bn5.weight)
        hDec = F.elu(hDec)

        return hDec


    def forward(self, x):
        embed = self.encoder(x)
        
        ####MOVE TO HIDDEN_INIT
        h_t = torch.zeros(2,int(embed.size()[0]/7),400).to(device)
        c_t = torch.zeros(2,int(embed.size()[0]/7),400).to(device)
        h2_t = torch.zeros(1,int(embed.size()[0]/7),100).to(device)
        c2_t = torch.zeros(1,int(embed.size()[0]/7),100).to(device)  
        ###HIDDEN INIT END
        
        #IF FOR TESTING UNKNOWN SEQUENCES
        if(embed.size()[0]>7):
            embedTemp = torch.chunk(embed, int(self.batch_size/7),dim=0)
            #print(len(embedTemp))
            embed7s = embedTemp[0].unsqueeze(1)
            for emb in embedTemp[1:]:
                embed7s = torch.cat((embed7s, emb.unsqueeze(1)),dim=1)
        else:
            embed7s = embed.unsqueeze(1)
        
        lstmOut, (h_t, c_t) = self.lstm1(embed7s,(h_t, c_t))
        lstmOut, (h2_t,c2_t) = self.lstm2(lstmOut,(h2_t, c2_t))
        
        return embed, lstmOut, self.decoder(lstmOut[-1,:,:])

def init_weights(m):
    if type(m) == nn.Linear:# or nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
    elif type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model = CDVAE(batch_size=batch_size).to(device)
model.apply(init_weights).to(device)
print(model)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def loss_function(embed, lstmOut, embedNext):
    
    cosLSTM = nn.CosineSimilarity(dim=0, eps=1e-8)
    
    cosSimLSTM = 0
    #BATCHSIZE 7
    if(embed.size()[0]==7):
        #print("batchsize = 7 ?")
        cosSimLSTM = cosLSTM(lstmOut.squeeze(1)[-1], embedNext[0])
     
    #BATCHSIZE > 7
    else:
        j=0
        for i in range(7,embed.size()[0],7):
            #print(i)
            cosSimLSTM += cosLSTM(lstmOut[-1,j,:],embed[i])
            j+=1
        cosSimLSTM += cosLSTM(lstmOut[-1,j,:],embedNext[0])
    
    return -cosSimLSTM/lstmOut.size()[1]
        

def train(epoch):
    model.train()
    trainLoss = 0

    for batch_idx, data in enumerate(trainLoader):
        #print(batch_idx)
        data = data.float().to(device)
        optimizer.zero_grad()
        embedding, lstmOut, reconPrediction = model(data)
        nextBatch = next(iter(trainLoader)).float().to(device)
        embeddingNext, _, _ = model(nextBatch)
        #print(nextBatch.size())
        loss = loss_function(embedding, lstmOut, embeddingNext)
        loss.backward()
        trainLoss += loss.item()
        optimizer.step()
        if(batch_idx % log_interval == 0):
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAccuracy: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainLoader.dataset),
                100. * batch_idx / len(trainLoader),
                -loss.item() / len(data)))
        #if(batch_idx==10):
        #    break
    print('====> Epoch: {} Average accuracy: {:.4f}'.format(
          epoch, -trainLoss / len(trainLoader.dataset)))

def test(epoch):
    model.eval()
    testLoss = 0
    with torch.no_grad():
        for i, data in enumerate(testLoader):
            data = data.float().to(device)
            embedding, lstmOut, reconPrediction = model(data)
            nextBatch = next(iter(testLoader)).float().to(device)
            embeddingNext, _, _= model(nextBatch)
            testLoss += loss_function(embedding, lstmOut, embeddingNext).item()
            
            #if(i==10):
            #    break
    testLoss /= len(testLoader.dataset)

    print('====> Test set accuracy: {:.4f}'.format(-testLoss))

CDVAE(
  (conv1): Conv2d(1, 100, kernel_size=(16, 5), stride=(16, 5))
  (bn1): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu1): ELU(alpha=1.0)
  (conv2): Conv2d(100, 200, kernel_size=(2, 1), stride=(2, 1))
  (bn2): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu2): ELU(alpha=1.0)
  (conv3): Conv2d(200, 400, kernel_size=(2, 2), stride=(1, 2))
  (bn3): BatchNorm2d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu3): ELU(alpha=1.0)
  (conv4): Conv2d(400, 800, kernel_size=(2, 2), stride=(2, 2))
  (bn4): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu4): ELU(alpha=1.0)
  (fc5): Linear(in_features=2400, out_features=800, bias=True)
  (bn5): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu5): ELU(alpha=1.0)
  (fc6): Linear(in_features=800, out_features=400, bias=True)
  (bn6): BatchNorm1d(400, eps=1e-05, mome

In [9]:
"""
#LOAD MODEL
pathToModel = 'model/YamahaPianoComp2002_10Epochs_LSTM_TiedWeights.model'

try:
    #LOAD TRAINED MODEL INTO GPU
    if(torch.cuda.is_available()):
        model = torch.load(pathToModel)
        
    #LOAD MODEL TRAINED ON GPU INTO CPU
    else:
        model = torch.load(pathToModel, map_location=lambda storage, loc: storage)
    print("\n--------model restored--------\n")
except:
    print("\n--------no saved model found--------\n")
"""
print('')




In [10]:
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)

====> Epoch: 1 Average accuracy: 0.0098
====> Test set accuracy: 0.0102
====> Epoch: 2 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 3 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 4 Average accuracy: 0.0102
====> Test set accuracy: 0.0102


====> Epoch: 5 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 6 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 7 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 8 Average accuracy: 0.0102
====> Test set accuracy: 0.0102


====> Epoch: 9 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 10 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 11 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 12 Average accuracy: 0.0102
====> Test set accuracy: 0.0102


====> Epoch: 13 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 14 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 15 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 16 Average accuracy: 0.0102
====> Test set accuracy: 0.0102


====> Epoch: 17 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 18 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 19 Average accuracy: 0.0102
====> Test set accuracy: 0.0102
====> Epoch: 20 Average accuracy: 0.0102
====> Test set accuracy: 0.0102


In [14]:
torch.save(model,'/media/EXTHD/niciData/models/YamahaPianoComp2002_10Epochs_LSTM_TiedWeights.model')

  "type " + obj.__name__ + ". It won't be checked "


# Play Prediction by generating an 8th sequence after listening to 7

In [12]:
#np.set_printoptions(precision=2, suppress=True, threshold=np.inf)


In [13]:
"""
###PLAY WHOLE SONG IN BARS
with torch.no_grad():
    
    sampleNp1 = getSlicedPianorollMatrixNp("/Volumes/EXT/DATASETS/DougMcKenzieFiles_noDrums/samples/KingandI.mid")
    sampleNp1 = deleteZeroMatrices(sampleNp1)
    sample = np.expand_dims(sampleNp1[0,:,36:-32],axis=0)
    print(sample.shape)
    for i, sampleNp in enumerate(sampleNp1[1:7]):
        print(sampleNp.shape)
        if(np.any(sampleNp)):
            sampleNp = sampleNp[:,36:-32]
            sampleNp = np.expand_dims(sampleNp,axis=0)
            sample = np.concatenate((sample,sampleNp),axis=0)
    samplePlay = sample[0,:,:]
    for s in sample:
        samplePlay = np.concatenate((samplePlay,s),axis=0)
    samplePlay = addCuttedOctaves(samplePlay)
    print(samplePlay.shape)
    sample = torch.from_numpy(sample).float().to(device)
    sample = torch.unsqueeze(sample,1)
    print(sample.size())
    _,_, pred = model(sample)
    #reconstruction = recon.squeeze(0).squeeze(0).cpu().numpy()
    prediction = pred.squeeze(0).squeeze(0).cpu().numpy()

    #print(sampleNp[:,:])
    #print(prediction[:,:])
    #print(np.sum(sampleNp.numpy(), axis=1))

    #NORMALIZE PREDICTIONS
    #reconstruction /= np.abs(np.max(reconstruction))
    prediction /= np.abs(np.max(prediction))
    #print(prediction)

    #CHECK MIDI ACTIVATIONS IN PREDICTION TO INCLUDE RESTS
    #reconstruction[reconstruction < 0.3] = 0
    prediction[prediction < 0.3] = 0



    ###MONOPHONIC OUTPUT MATRIX POLOYPHONIC POSSIBLE WITH ACTIVATION THRESHOLD###
    #score = music21.converter.parse('WikifoniaServer/samples/The-Doors---Don\'t-you-love-her-Madly?.mid')
    #score.show()

    samplePlay = debinarizeMidi(samplePlay, prediction=False)
    samplePlay = addCuttedOctaves(samplePlay)
    #reconstruction = debinarizeMidi(reconstruction, prediction=True)
    #reconstruction = addCuttedOctaves(reconstruction)
    prediction = debinarizeMidi(prediction, prediction=True)
    prediction = addCuttedOctaves(prediction)

    #print(np.argmax(samplePlay, axis=1))
    #print('')
    #print(np.argmax(prediction, axis=1))
    print("INPUT")
    print(samplePlay.shape)
    pianorollMatrixToTempMidi(samplePlay)
    tempMidi(show=True,play=True)
    #print("RECONSTRUCTION")
    #pianorollMatrixToTempMidi(reconstruction)        
    #tempMidi(show=True,play=True)
    print("PREDICTION")
    pianorollMatrixToTempMidi(prediction, prediction=True)        
    tempMidi(show=True,play=True)
    print("\n\n")
            
"""
print('')


