In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle as pkl
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from midiutil import MIDIFile

In [2]:
class MelodyRNN(nn.Module):
    """Single-layer GRU model that scores the token following a token sequence.

    Tokens are embedded, passed through dropout and a GRU, and the GRU's final
    hidden state is projected to one logit per vocabulary entry.

    The GRU state is stored on the module as ``self.hidden`` and overwritten by
    every forward pass, so by default consecutive calls continue from the
    previous call's final state. Pass ``hidden`` explicitly to start from a
    chosen state instead (e.g. zeros for an independent sequence, or a state
    with batch > 1).
    """

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: vocabulary size, used for both the embedding table and
                the output layer. Index 0 is the embedding's padding index.
            hidden_size: width of the embeddings and of the GRU state.
        """
        super(MelodyRNN, self).__init__()

        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=0)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, input_size)
        self.dropout = nn.Dropout(0.3)
        # Initial GRU state: (num_layers=1, batch=1, hidden_size).
        # `Variable` has been a no-op wrapper since PyTorch 0.4; a plain tensor
        # is equivalent.
        self.hidden = torch.zeros(1, 1, hidden_size)

    def forward(self, x, hidden=None):
        """Score the token that follows each sequence in ``x``.

        Args:
            x: LongTensor of token indices, (batch, seq_len) because the GRU
                is batch_first.
            hidden: optional initial GRU state of shape (1, batch, hidden_size).
                Defaults to the state stored by the previous call.

        Returns:
            (logits, hidden): logits of shape (batch, input_size), and the
            GRU's final hidden state, which is also stored on ``self.hidden``.
        """
        if hidden is None:
            hidden = self.hidden
        embedded = self.dropout(self.embedding(x))
        _, self.hidden = self.gru(embedded, hidden)
        # hidden[-1] is the last GRU layer's final state: (batch, hidden_size).
        logits = self.out(self.hidden[-1])
        return logits, self.hidden

In [3]:
# Load the trained model and the index -> chord-token vocabulary.
# NOTE(review): both files are unpickled (torch.load of a whole module also uses
# pickle), which can execute arbitrary code -- only load files you trust.
# On PyTorch >= 2.6 loading a whole module additionally needs weights_only=False.
# map_location='cpu' keeps a GPU-saved checkpoint usable here, since the
# generation code builds its input tensors on the CPU.
model = torch.load('Full_RNN', map_location='cpu')
with open("index2chord.pkl", "rb") as vocab_file:
    index2chord = pkl.load(vocab_file)



In [116]:
def pick_rand_note(out, k=10):
    """Pick a token index uniformly at random among the ``k`` highest-scoring ones.

    Args:
        out: score tensor whose first row (``out[0]``) holds one score per
            token, e.g. the (batch, vocab) logits returned by the model.
        k: number of top-scoring tokens to choose among (default 10). If the
            row has fewer than ``k`` entries, all of them are candidates.

    Returns:
        The chosen token index as a Python int.

    Raises:
        ValueError: if ``k`` < 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got {}".format(k))
    # argsort is ascending along the last dim, so the last k are the top-k indices.
    candidates = out.detach().argsort()[0][-k:].cpu().numpy()
    return int(np.random.choice(candidates, 1)[0])

def generate(model, random=0.5, length=96):
    """Sample a melody token by token from ``model`` and map it through ``index2chord``.

    Starting from the seed sequence [0, 0, 1], the model is run on the whole
    sequence generated so far and one new token is appended per step. With
    probability ``random`` the token is drawn by ``pick_rand_note``; otherwise
    the arg-max token is taken. Indices 0, 1 and 2 are never accepted as
    generated tokens.

    Args:
        model: module returning (logits, hidden) for a (1, seq_len) LongTensor
            and exposing ``model.gru`` and ``model.hidden`` (as MelodyRNN does).
        random: per-step probability of sampling instead of arg-max. The
            default 0.5 matches the previous fixed 50/50 coin flip.
        length: number of tokens to generate after the seed.

    Returns:
        List of ``index2chord`` entries: the 3 seed tokens, ``length``
        generated tokens, and a trailing index-2 token.
    """
    model.eval()
    all_notes = torch.tensor([[0, 0, 1]])
    # Inference only: without no_grad the autograd graph would keep growing
    # through the hidden state the model stores between calls.
    with torch.no_grad():
        for _step in range(length):
            # The full sequence is re-fed every step, so start the GRU from a
            # zero state rather than the state left over from the previous step.
            model.hidden = torch.zeros(
                model.gru.num_layers, all_notes.size(0), model.gru.hidden_size)
            out, _state = model(all_notes)
            if np.random.rand() < random:
                pred = pick_rand_note(out)
            else:
                pred = int(out.argmax())

            # Indices 0-2 are reserved here (0 is MelodyRNN's padding index;
            # 1 is part of the seed and 2 is appended as the end token), so
            # keep sampling until a real note is chosen.
            while pred in (0, 1, 2):
                pred = pick_rand_note(out)

            all_notes = torch.cat([all_notes, torch.tensor([[pred]])], 1)
    all_notes = torch.cat([all_notes[0], torch.tensor([2])])
    return [index2chord[int(note)] for note in all_notes]

# Note name -> MIDI pitch for the octave starting at middle C (MIDI 60),
# spelled with flats. 'R' maps to None, which generate_melody handles as a rest.
note_mapper = {
    name: 60 + semitone
    for semitone, name in enumerate(
        ['C', 'Db', 'D', 'Eb', 'E', 'F', 'Gb', 'G', 'Ab', 'A', 'Bb', 'B'])
}
note_mapper['R'] = None

In [117]:
def _tokens_to_notes(pitches, lengths, step=0.5):
    """Convert per-step (pitch, length-flag) tokens into (start, pitch, duration) notes.

    Each token occupies ``step`` beats, and a pitch of None is a rest.
    A note flagged 'S' starts a held note. Following tokens of the same pitch
    extend it until one flagged 'L' ends it, or until a rest or a different
    pitch interrupts it. A note with any other flag (expected: 'L') that does
    not continue a held note is a single ``step``-long note. A note still held
    when the tokens run out is emitted rather than dropped.

    Args:
        pitches: MIDI pitch numbers, or None for a rest, one per step.
        lengths: length flags ('S' / 'L'), aligned with ``pitches``.
        step: duration of one token in beats.

    Returns:
        List of (start_beat, pitch, duration_beats) tuples in time order.
    """
    notes = []
    time = 0.0
    held = None  # (start, pitch, duration) of the note currently being sustained

    for pitch, flag in zip(pitches, lengths):
        if held is not None:
            start, held_pitch, held_duration = held
            if pitch == held_pitch:
                held = (start, held_pitch, held_duration + step)
                if flag == 'L':
                    notes.append(held)
                    time = start + held[2]
                    held = None
                continue
            # A rest or a different pitch ends the held note; the current
            # token is then handled below as a fresh token.
            notes.append(held)
            time = start + held_duration
            held = None

        if pitch is None:
            time += step
        elif flag == 'S':
            held = (time, pitch, step)
        else:
            notes.append((time, pitch, step))
            time += step

    if held is not None:
        notes.append(held)
    return notes


def generate_melody(output_file):
    """Sample a melody from the global ``model`` and write it as a one-track MIDI file.

    Tokens returned by ``generate`` look like '<note>_<flag>' (e.g. 'C_S'):
    the note name is mapped through ``note_mapper`` and the flag controls
    whether the note is held (see ``_tokens_to_notes``).

    Args:
        output_file: path of the MIDI file to write.
    """
    # Drop the 3 seed tokens at the start and the end token at the end.
    melody = generate(model)[3:-1]
    parsed = [token.split('_') for token in melody]
    pitches = [note_mapper[parts[0]] for parts in parsed]
    lengths = [parts[1] for parts in parsed]

    track = 0
    channel = 0
    tempo = 120   # In BPM
    volume = 100  # 0-127, as per the MIDI standard

    # One track; defaults to format 1 (tempo track automatically created).
    midi = MIDIFile(1)
    midi.addTempo(track, 0, tempo)
    for start, pitch, duration in _tokens_to_notes(pitches, lengths):
        midi.addNote(track, channel, pitch, start, duration, volume)

    with open(output_file, "wb") as handle:
        midi.writeFile(handle)

In [119]:
# Sample one melody and write it to out.midi in the working directory.
generate_melody(output_file='out.midi')