In [None]:
import numpy as np
import glob
import pypianoroll as ppr
import time
import music21
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from utils.utils import *
#np.set_printoptions(threshold=np.inf)
#torch.set_printoptions(threshold=50000)

In [None]:
# ------------------------------ Hyperparameters ------------------------------
# Optimisation settings.
epochs = 100
learning_rate = 0.001
weight_decay = 0.999

# Batching. NOTE: change batch_size with care, batch_loader below depends on it.
batch_size = 1
seq_length = 8

# Show / log the loss every `log_interval` batches.
log_interval = 10

# Sizes handed to the LSTM constructor.
input_size = 100
hidden_size = 128
# -----------------------------------------------------------------------------

# Number of dataset items fetched per DataLoader step:
# `batch_size` sequences of `seq_length` items each.
batch_loader = seq_length * batch_size

In [None]:
# Read only the test split from the .npz archive. The context manager closes
# the archive even if the 'test' key is missing or reading it fails; the
# extracted array stays usable after the archive is closed.
with np.load('../../YamahaPianoCompetition2002NoTranspose.npz') as data:
    #midiDatasetTrain = data['train']
    midiDatasetTest = data['test']

# Load test set

In [None]:
#midiDatasetTrain = torch.from_numpy(midiDatasetTrain)
#trainLoader = torch.utils.data.DataLoader(midiDatasetTrain, batch_size=batch_size, shuffle=True, drop_last=True)

# torch.as_tensor accepts both a NumPy array and an existing tensor, so this
# cell can be re-run without failing once midiDatasetTest has been converted.
midiDatasetTest = torch.as_tensor(midiDatasetTest)

# Sequential (unshuffled) loader that yields `batch_loader` items per step and
# drops a trailing incomplete batch.
testLoader = torch.utils.data.DataLoader(
    midiDatasetTest,
    batch_size=batch_loader,
    shuffle=False,
    drop_last=True,
)

# Load models

In [None]:
from utils.LSTM import LSTM_Many2Many
from utils.VAE import VAE
from loadModel import loadModel

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate both networks.
lstmModel = LSTM_Many2Many(
    batch_size=batch_loader,
    seq_length=seq_length,
    input_size=input_size,
    hidden_size=hidden_size,
)
autoencoderModel = VAE()

# Restore the stored weights from disk.
lstmModel = loadModel(
    lstmModel,
    '../../new_models_and_plots/LSTM_WikifoniaTP12_128hidden_180epochs_LRe-4_Many2Many.model',
)
autoencoderModel = loadModel(
    autoencoderModel,
    '../../new_models_and_plots/YamahaPC2002_VAE_Reconstruct_NoTW_20Epochs.model',
    dataParallelModel=False,
)

# Move both models to the selected device; the LSTM is cast to float64 first.
lstmModel = lstmModel.double()
lstmModel = lstmModel.to(device)
autoencoderModel = autoencoderModel.to(device)

# Generate a new sample by feeding 4 pianoroll slices to the models

In [None]:
# Index of the first follow-up slice taken from every input file.
playSeq = 8
# Normalised decoder activations below this value are treated as rests.
ACTIVATION_THRESHOLD = 0.3

# For this interactive run the LSTM works on half of the configured batch size
# and sequence length.
lstmModel.batch_size = batch_loader // 2
lstmModel.seq_length = seq_length // 2
interact_seq_length = seq_length // 2  # 4 with seq_length = 8
# The original, misspelled name is kept as an alias in case other cells use it.
interact_seq_legnth = interact_seq_length


def _normalizeAndGate(roll, threshold=ACTIVATION_THRESHOLD):
    """Scale `roll` by the absolute value of its maximum, then gate small values.

    Args:
        roll: float NumPy array of decoder activations; modified in place.
        threshold: entries below this value are set to 0 after scaling.

    Returns:
        The same array object that was passed in.
    """
    peak = np.abs(np.max(roll))
    # Skip the division for an all-zero output; 0 / 0 would fill it with NaNs.
    if peak > 0:
        roll /= peak
    roll[roll < threshold] = 0
    return roll


# Inference only: put both models into evaluation mode directly, without
# testing the current mode first.
lstmModel.eval()
autoencoderModel.eval()

with torch.no_grad():
    # sorted() makes the processing order reproducible across runs.
    for pathToSampleSeq in sorted(glob.glob('../../WikifoniaServer/test/*.mid')):
        sampleNp1 = getSlicedPianorollMatrixNp(pathToSampleSeq)
        sampleNp1 = deleteZeroMatrices(sampleNp1)
        if len(sampleNp1) == 0:
            # Nothing left to feed to the models for this file.
            print("Skipping", pathToSampleSeq, "(no non-empty slices)")
            continue

        # Build the input: slice 0 plus up to (interact_seq_length - 1)
        # non-empty slices starting at index playSeq. Every slice is cropped
        # to pitch columns 36:-32.
        # NOTE(review): slice 0 is followed by slices playSeq.., not 1.. -
        # confirm that this gap is intended.
        slices = [sampleNp1[0, :, 36:-32]]
        for sampleNp in sampleNp1[playSeq:playSeq + interact_seq_length - 1]:
            if np.any(sampleNp):
                slices.append(sampleNp[:, 36:-32])
        sample = np.stack(slices, axis=0)

        # Same slices joined end to end along the first axis, for playback.
        samplePlay = np.concatenate(slices, axis=0)
        samplePlay = addCuttedOctaves(samplePlay)

        ##### Prepare the sample as model input: float tensor with an extra
        ##### axis inserted at position 1.
        sample = torch.from_numpy(sample).float().to(device)
        sample = torch.unsqueeze(sample, 1)

        ##### MODEL ##############
        # Only the first value returned by the encoder is used.
        embed, _ = autoencoderModel.encoder(sample)
        # The LSTM was cast to double, so its input is cast as well.
        embed = embed.unsqueeze(0).double()

        embed, lstmOut = lstmModel(embed, future=0)
        lstmOut = lstmOut.float().squeeze(0)
        # Decode the embeddings (reconstruction) and the LSTM output (prediction).
        recon = autoencoderModel.decoder(embed.squeeze(0).float())
        pred = autoencoderModel.decoder(lstmOut)
        ########################

        # Join the per-slice outputs end to end along the first axis.
        pred = pred.squeeze(1)
        predict = torch.cat(tuple(pred), dim=0)

        recon = recon.squeeze(1)
        reconstruction = torch.cat(tuple(recon), dim=0)

        predict = predict.cpu().numpy()
        reconstruction = reconstruction.cpu().numpy()

        # Normalise both outputs and zero weak activations so that rests
        # are included.
        reconstruction = _normalizeAndGate(reconstruction)
        prediction = _normalizeAndGate(predict)

        samplePlay = debinarizeMidi(samplePlay, prediction=False)
        # NOTE(review): addCuttedOctaves was already applied to samplePlay
        # above; confirm the helper is safe to apply twice.
        samplePlay = addCuttedOctaves(samplePlay)
        reconstruction = debinarizeMidi(reconstruction, prediction=True)
        reconstruction = addCuttedOctaves(reconstruction)
        prediction = debinarizeMidi(prediction, prediction=True)
        prediction = addCuttedOctaves(prediction)

        print("INPUT")
        print(samplePlay.shape)
        pianorollMatrixToTempMidi(samplePlay, show=True, showPlayer=True, autoplay=True)
        print("RECONSTRUCTION")
        pianorollMatrixToTempMidi(reconstruction, show=True,
                                  showPlayer=True, autoplay=True, prediction=True)
        print("PREDICTION")
        pianorollMatrixToTempMidi(prediction, prediction=True,
                                  show=True, showPlayer=True, autoplay=True)
        print("\n\n")

print('')