In [185]:
import glob 
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
import torch.optim as optim
from music21 import converter, instrument, note, chord, stream
import time

In [160]:
class LSTM(nn.Module):
    """Single-layer LSTM token model: embedding -> LSTM -> linear -> log-softmax.

    Args:
        input_size: number of rows in the embedding table (distinct input ids).
        embedding_dim: size of each token embedding.
        batch_size: batch dimension of the zero hidden state made by
            ``init_hidden``. ``forward`` reshapes its input to batch 1, so this
            is expected to be 1.
        hidden_dim: LSTM hidden-state size.
        output_size: number of output classes per timestep.
    """

    def __init__(self, input_size, embedding_dim, batch_size, hidden_dim, output_size):
        super().__init__()

        self.batch_size = batch_size
        self.hidden_dim = hidden_dim

        # Attribute names are the state_dict key prefixes ("embeddings.*",
        # "lstm.*", "linear.*"); renaming them would break loading checkpoints.
        self.embeddings = nn.Embedding(input_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_size)

    def init_hidden(self):
        """Return a zero ``(h0, c0)`` pair, each of shape ``(1, batch_size, hidden_dim)``.

        The tensors are created with the same device and dtype as the model's
        parameters, so the model still works after ``.to(device)``.
        """
        ref = self.linear.weight
        return (ref.new_zeros(1, self.batch_size, self.hidden_dim),
                ref.new_zeros(1, self.batch_size, self.hidden_dim))

    def forward(self, inputs):
        """Score every position of a token-id sequence.

        Args:
            inputs: 1-D LongTensor of token ids, length ``L``. The reshape below
                treats the whole sequence as a single batch element.

        Returns:
            Tensor of shape ``(L, output_size)`` holding log-probabilities
            (``log_softmax`` over the class dimension) for each position.
        """
        # Use the argument's own length (the previous version read a notebook
        # global here, which broke on a fresh kernel or any other input).
        seq_len = inputs.size(0)

        hidden = self.init_hidden()
        embeds = self.embeddings(inputs)

        # nn.LSTM (batch_first=False) expects (seq, batch, features).
        lstm_out, _ = self.lstm(embeds.view(seq_len, 1, -1), hidden)

        prediction = self.linear(lstm_out.view(seq_len, -1))
        return F.log_softmax(prediction, dim=1)

In [161]:
# Load the note/chord token list and rebuild the vocabulary.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open('data/notes', 'rb') as filepath:
    notes = pickle.load(filepath)

# Sorted unique tokens -> stable integer ids. Presumably this must reproduce the
# mapping used at training time so ids line up with the saved weights - confirm.
pitchnames = sorted(set(notes))

note_to_int = {token: number for number, token in enumerate(pitchnames)}
int_to_note = {number: token for number, token in enumerate(pitchnames)}

vocab_size = len(note_to_int)

seq_len = 100  # length of the seed/context window fed to the model

input_size = vocab_size

embedding_dim = 30
batch_size = 1
hidden_dim = 36
learning_rate = 0.01  # NOTE(review): not used by the inference code visible here

model = LSTM(input_size, embedding_dim, batch_size, hidden_dim, vocab_size)

# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only box.
model.load_state_dict(torch.load('./net_50.pth', map_location='cpu'))
model.eval()

LSTM(
  (embeddings): Embedding(72, 30)
  (lstm): LSTM(30, 36)
  (linear): Linear(in_features=36, out_features=72, bias=True)
)

In [173]:
def prepare_sequence(seq, trans):
    """Translate each element of ``seq`` through the mapping ``trans``.

    Used in both directions in this notebook: note strings -> ids with
    ``note_to_int`` and ids -> note strings with ``int_to_note``.

    Args:
        seq: iterable of keys to look up.
        trans: mapping indexed as ``trans[key]``.

    Returns:
        A new list of the looked-up values, in input order.

    Raises:
        KeyError: if some element of ``seq`` is not a key of ``trans``.
    """
    lookup = trans.__getitem__
    return list(map(lookup, seq))

In [189]:
def create_midi(int_seq, fp='test_output.mid', step=0.5):
    """Convert generated note/chord tokens into a piano MIDI file.

    Despite its name, ``int_seq`` holds *string* tokens (ids already mapped
    back through ``int_to_note``). A token containing ``'.'``, or consisting
    only of digits, is treated as a chord of dot-separated integer pitches.
    Any other token is passed to ``note.Note`` as a single note name.

    Args:
        int_seq: iterable of note/chord token strings.
        fp: output MIDI path; the default keeps the original file name.
        step: offset increment between consecutive events (music21 offset
            units), so that events do not stack on top of each other.

    Returns:
        The ``stream.Stream`` that was written, for further inspection.
    """
    offset = 0
    output_notes = []

    for pattern in int_seq:
        if ('.' in pattern) or pattern.isdigit():
            # Chord token: build one Note per dot-separated integer pitch.
            chord_notes = []
            for current_note in pattern.split('.'):
                new_note = note.Note(int(current_note))
                new_note.storedInstrument = instrument.Piano()
                chord_notes.append(new_note)
            new_chord = chord.Chord(chord_notes)
            new_chord.offset = offset
            output_notes.append(new_chord)
        else:
            # Single-note token, e.g. a pitch name.
            new_note = note.Note(pattern)
            new_note.offset = offset
            new_note.storedInstrument = instrument.Piano()
            output_notes.append(new_note)

        offset += step

    midi_stream = stream.Stream(output_notes)
    midi_stream.write('midi', fp=fp)
    return midi_stream

In [190]:
# Seed with a random window of `seq_len` real tokens, then generate GENERATE_LEN
# tokens greedily (argmax of the last position), sliding the window by one.
# TODO(review): seed np.random if reproducible generations are wanted.
GENERATE_LEN = 500

assert len(notes) > seq_len, "need more than seq_len tokens to draw a seed window"
split = np.random.randint(0, len(notes) - seq_len)
start = notes[split : split + seq_len]
net_in_tensor = torch.tensor(prepare_sequence(start, note_to_int), dtype=torch.long)

generated = []
with torch.no_grad():  # inference only: don't build an autograd graph per step
    for _ in range(GENERATE_LEN):
        prediction = model(net_in_tensor)

        # Greedy pick from the last timestep's scores, kept 1-D for torch.cat.
        feed_note = prediction[-1].argmax().view(1)

        # Drop the oldest token and append the new one (window stays seq_len).
        net_in_tensor = torch.cat((net_in_tensor[1:], feed_note), 0)

        generated.append(feed_note)

# Only real generated tokens - the old version prepended a placeholder 0.
prediction_out = torch.cat(generated, 0) if generated else torch.empty(0, dtype=torch.long)

In [None]:
# Display the generated token ids (vocabulary indices; map back with int_to_note).
prediction_out

In [192]:
# Take the ids generated above instead of a hand-pasted snapshot of an earlier
# `prediction_out` display, so Restart & Run All renders this run's output.
a = prediction_out.tolist()

In [193]:
# Decode vocabulary ids back to note/chord strings, then render them to MIDI.
Generatednotes = [int_to_note[token_id] for token_id in a]
create_midi(Generatednotes)