In [None]:
import itertools

import music21
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Path to the single-part (melody-only) MusicXML file whose chords will be predicted.
file_to_predict = input(
    "Input a single part musicXML file to predict the chord progression for: "
)

In [None]:
# Path to the saved state_dict loaded into LSTMGenerator further below.
model_fp = input(
    "Input the file path of the model to use for chord prediction: "
)

In [None]:
# Destination path for the exported MusicXML (melody plus predicted chord part).
prediction_file_name = input(
    "Input the name of the file to save the predicted multipart piece to: "
)

In [None]:
# Sample the melody on a fixed grid of `bins_per_measure` points per measure and
# encode each sample as one embedding index: 0-11 = pitch class, 12 = rest.
bins_per_measure = 16
REST_INDEX = 12

score = music21.converter.parse(file_to_predict)

# Quarter-note length of one bin; 4/4 is assumed until a measure declares a meter.
step_length = 4 / bins_per_measure

melody = score.parts[0]

melody_measures = melody.getElementsByClass(music21.stream.Measure)


def _encode_element(element):
    """Map a note, chord or rest to its embedding index (0-11 pitch class, 12 rest)."""
    if element.isRest:
        return REST_INDEX
    # A Note has exactly one pitch; for a Chord the highest pitch is taken as the melody.
    return max(element.pitches, key=lambda p: p.ps).pitchClass


def _sample_measure(measure, step, n_bins):
    """Return the embedding index sounding at each of ``n_bins`` grid points of ``measure``.

    Grid point ``b`` sits ``b * step`` quarter lengths after the bar start. The
    sounding element is the last one whose offset is at or before the grid point;
    before the first onset the first element is used. A measure containing no
    notes or rests is encoded as silence instead of raising IndexError.
    """
    elements = list(measure.notesAndRests)
    if not elements:
        return [REST_INDEX] * n_bins

    samples = []
    current = 0
    for b in range(n_bins):
        position = b * step
        # `while`, not `if`: several short notes can start inside one bin, and the
        # index must catch up to the element actually sounding at `position`.
        while current + 1 < len(elements) and elements[current + 1].offset <= position:
            current += 1
        samples.append(_encode_element(elements[current]))
    return samples


test_data = []
for mel_measure in melody_measures:
    ts = mel_measure.timeSignature
    if ts is not None:
        # A bar lasts numerator/denominator whole notes, i.e. that * 4 quarter lengths.
        step_length = ts.numerator / ts.denominator * 4 / bins_per_measure
    test_data.extend([index] for index in _sample_measure(mel_measure, step_length, bins_per_measure))

if not test_data:
    raise ValueError(f"No measures found in the first part of {file_to_predict!r}")

# One row per bin, one token per row: shape (num_measures * bins_per_measure, 1).
melody_tensor = torch.tensor(test_data, dtype=torch.long)
assert melody_tensor.shape == (len(melody_measures) * bins_per_measure, 1)

In [None]:
class LSTMGenerator(nn.Module):
    """Predicts a set of chord pitch classes for each melody token.

    Architecture: token embedding -> stacked LSTM (batch_first) -> linear layer
    producing one logit per output class. The defaults reproduce the original
    hard-coded sizes, and the submodule names (``embedding``, ``lstm``, ``fc1``)
    are unchanged, so previously saved state_dicts still load.

    Args:
        vocab_size: number of input tokens; this notebook's preprocessing uses
            0-11 for pitch classes and 12 for a rest.
        embedding_dim: size of each token embedding.
        hidden_size: LSTM hidden/cell state size.
        num_layers: number of stacked LSTM layers.
        num_outputs: logits per time step (12 pitch classes).
    """

    def __init__(self, vocab_size=13, embedding_dim=100, hidden_size=256,
                 num_layers=2, num_outputs=12):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, num_outputs)

    def forward(self, x, hidden_in):
        """Run the network.

        Args:
            x: LongTensor of token ids, shape (batch, seq) because the LSTM is batch_first.
            hidden_in: ``(h_0, c_0)``, each shaped (num_layers, batch, hidden_size).

        Returns:
            ``(logits, (h_n, c_n))``; logits are shaped (batch, seq, num_outputs).
        """
        embedded = self.embedding(x)
        lstm_out, h_out = self.lstm(embedded, hidden_in)
        return self.fc1(lstm_out), h_out


model = LSTMGenerator()

In [None]:
# map_location="cpu": the input tensors in this notebook are created on the CPU,
# and this also lets a checkpoint saved from a GPU session load on a CPU-only machine.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints you trust (newer PyTorch accepts weights_only=True to restrict this).
model.load_state_dict(torch.load(model_fp, map_location="cpu"))
model.eval()

In [None]:
# Each bin is an independent batch element with a sequence length of 1
# (melody_tensor is (num_bins, 1) and the LSTM is batch_first), so the initial
# hidden/cell states need one batch column per bin instead of a fixed 144.
CHORD_THRESHOLD = 0.45  # sigmoid probability above which a pitch class is predicted
HIDDEN_SEED = 0  # fixes the random initial state so predictions are reproducible

num_bins = melody_tensor.size(0)
state_shape = (model.lstm.num_layers, num_bins, model.lstm.hidden_size)
# NOTE(review): random (not zero) initial states are kept from the original notebook,
# presumably to match how the model was trained -- confirm against the training code.
state_gen = torch.Generator().manual_seed(HIDDEN_SEED)
hidden_in = (torch.randn(state_shape, generator=state_gen),
             torch.randn(state_shape, generator=state_gen))

with torch.no_grad():
    logits, _ = model(melody_tensor, hidden_in)

# Multi-label decision per bin: one boolean per pitch class (0-11).
chord_tensor = (torch.sigmoid(logits) > CHORD_THRESHOLD).view(-1, 12)
print(chord_tensor)

In [None]:
# Turn the per-bin predictions into a piano part aligned bar-for-bar with the
# melody measures sampled during preprocessing. Consecutive bins with the same
# predicted pitch-class set merge into one chord (or rest) whose length is the
# number of bins times the bin length of that measure.
new_part = music21.stream.Part()
new_part.insert(0, music21.instrument.Piano())

bar_ts = music21.meter.TimeSignature("4/4")  # assumed until the melody declares a meter
written_meter = None
for mel_measure, measure_bins in zip(melody_measures, torch.split(chord_tensor, bins_per_measure)):
    if mel_measure.timeSignature is not None:
        bar_ts = mel_measure.timeSignature
    # Same bin length the preprocessing cell used to sample this measure.
    bin_length = bar_ts.numerator / bar_ts.denominator * 4 / bins_per_measure

    new_measure = music21.stream.Measure(number=mel_measure.number)
    # Write a fresh time signature only at the start and whenever the meter changes.
    if bar_ts.ratioString != written_meter:
        new_measure.insert(0, music21.meter.TimeSignature(bar_ts.ratioString))
        written_meter = bar_ts.ratioString

    pitch_classes = [row.nonzero().flatten().tolist() for row in measure_bins]
    for pitch_class_list, run in itertools.groupby(pitch_classes):
        quarterlength = sum(1 for _ in run) * bin_length
        if pitch_class_list:
            new_measure.append(music21.chord.Chord(pitch_class_list, quarterLength=quarterlength))
        else:
            new_measure.append(music21.note.Rest(quarterLength=quarterlength))

    new_part.append(new_measure)

new_part.show("text")

In [None]:
# Attach the predicted chord part to the melody score and export it as MusicXML.
# The guard keeps a re-run of this cell from inserting the same part twice.
if not score.hasElement(new_part):
    score.insert(0, new_part)

gex = music21.musicxml.m21ToXml.GeneralObjectExporter(score)

out = gex.parse()
musicxml = out.decode('utf-8').strip()

# Write UTF-8 explicitly: the exported XML declares UTF-8, while the platform's
# default text encoding (e.g. cp1252 on Windows) may differ.
with open(prediction_file_name, "w", encoding="utf-8") as outfile:
    outfile.write(musicxml)