In [3]:
import pretty_midi
import torch
import os
import pickle
import numpy as np
from MelodyLSTM import MelodyLSTM
import settings as st

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)
mlstm = MelodyLSTM(st.input_size, st.hidden_size, st.output_size, st.num_layers, device)
mlstm.load_state_dict(torch.load('../models/muse_w10_v1.pth')) 
mlstm.eval()
mlstm.to(device)

Device cuda


MelodyLSTM(
  (lstm): LSTM(111, 512, batch_first=True)
  (fc_1): Linear(in_features=512, out_features=256, bias=True)
  (fc): Linear(in_features=256, out_features=60, bias=True)
  (relu): ReLU()
)

In [5]:
with open('../data/chords/CHORD_DICT.pickle', 'rb') as f: 
    CHORD_DICT = pickle.load(f)

In [6]:
with open('../data/chords_reduced/CHORD_TO_EMB.pickle', 'rb') as f:
    CHORD_TO_EMB = pickle.load(f)

In [7]:
EMB_TO_CHORD = {v: k for k, v in CHORD_TO_EMB.items()}

In [30]:
def select_note(p):
    dist = torch.nn.functional.softmax(p, dim=1).detach().cpu().numpy()
    dist = np.reshape(dist, (-1))
    return np.random.choice(range(len(dist)), p=dist)

In [31]:
def generate_melody_compass(chord, prev=None, ts=8):
    notes = []
    if prev == None:
        prev = torch.zeros(111)
        prev[(st.ub-st.lb) + chord] = 1.
        prev = torch.reshape(prev, (1, 1, -1))
    
    for _ in range(ts):
        prev = prev.to(device)
        p = mlstm(prev)
        # print(torch.nn.functional.softmax(p, dim=1).detach().numpy())
        # print(np.sum(torch.nn.functional.softmax(p, dim=1).detach().cpu().numpy()))
        notes.append(select_note(p))
        # notes.append(np.argmax(torch.nn.functional.softmax(p, dim=1).detach().cpu().numpy()))
        
        last = torch.clone(prev[:,-1,:])
        last = torch.reshape(last, (1, 1, -1))
        last[:,:,notes[-1]] = 1
        prev = torch.cat((prev, last), dim=1)
        if prev.size(1) == 10:
            prev = prev[:, 1:, :]
        print(prev.size())

    return notes, prev

In [32]:
def generate_melody_for_chords(chords):
    notes = []
    prev = None
    for chord in chords: 
        notes_i, prev = generate_melody_compass(chord, prev)
        notes += notes_i
    return notes

In [33]:
notes = generate_melody_for_chords([0, 1, 2, 3])

torch.Size([1, 2, 111])
torch.Size([1, 3, 111])
torch.Size([1, 4, 111])
torch.Size([1, 5, 111])
torch.Size([1, 6, 111])
torch.Size([1, 7, 111])
torch.Size([1, 8, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])
torch.Size([1, 9, 111])


In [34]:
notes

[48,
 24,
 28,
 8,
 17,
 14,
 9,
 57,
 41,
 31,
 12,
 26,
 16,
 26,
 59,
 36,
 48,
 32,
 7,
 35,
 12,
 9,
 37,
 31,
 6,
 12,
 9,
 25,
 56,
 6,
 40,
 45]

In [35]:
def write_melody_to_piano_roll(notes, piano_roll):
    for i, note in enumerate(notes):
        piano_roll[note + st.lb, i] = 100
    
    return piano_roll

In [36]:
def write_chords_to_piano_roll(chords, piano_roll, root=24, ts=8):
    for i, emb in enumerate(chords):
        cid = EMB_TO_CHORD[emb]
        chord = CHORD_DICT[cid]
        t1 = chord % 100
        chord //= 100
        t2 = chord % 100
        chord //= 100
        t3 = chord % 100 
        piano_roll[root + t1, i * 8: i * 8 + 8] = 100
        piano_roll[root + t2, i * 8: i * 8 + 8] = 100
        piano_roll[root + t3, i * 8: i * 8 + 8] = 100
    return piano_roll

In [37]:
piano_roll_chords = np.zeros((128, 32))
piano_roll_notes = np.zeros((128, 32))

In [38]:
piano_roll_notes = write_melody_to_piano_roll(notes, piano_roll_notes)
piano_roll_chords = write_chords_to_piano_roll([1,2,3,4], piano_roll_chords, root=36)

In [39]:
def write_piano_roll_to_midi(piano_roll, midi, pid):
    instrument = pretty_midi.Instrument(program=pid)
    
    # add notes to the instrument object
    for note_idx, time_slice in enumerate(piano_roll.T):
        note_numbers = np.nonzero(time_slice)[0]
        for note_number in note_numbers:
            note_start = note_idx / 4.0
            note_end = (note_idx + 1) / 4.0
            note_velocity = int(time_slice[note_number])
            note = pretty_midi.Note(
                velocity=note_velocity,
                pitch=note_number,
                start=note_start,
                end=note_end
            )
            instrument.notes.append(note)

    # add the instrument object to the MIDI object
    midi.instruments.append(instrument)
    
    return midi

In [40]:
def save_piano_roll_to_midi(piano_roll_notes, piano_roll_chords, filename):
    # create a PrettyMIDI object
    midi = pretty_midi.PrettyMIDI()
    midi = write_piano_roll_to_midi(piano_roll_chords, midi, 0)
    midi = write_piano_roll_to_midi(piano_roll_notes, midi, 1)
    
    # write the MIDI object to a file
    midi.write(filename)

In [41]:
save_piano_roll_to_midi(piano_roll_notes, piano_roll_chords, 'test_sep_rand.mid')