In [1]:
import copy

import music21
import torch
import torch.nn as nn
import torch.nn.functional as F

In [13]:
# Path of the single-part MusicXML melody to harmonise. Surrounding whitespace
# and the double quotes that Windows "Copy as path" adds are stripped so a
# pasted path can be parsed directly.
file_to_predict = input("Input a single part musicXML file to predict the chord progression for: ").strip().strip('"')

Input a single part musicXML file to predict the chord progression for: C:\Users\Danie\PycharmProjects\ChordGenerator\data\test_data\test_inputs\Melody2.musicxml


In [4]:
# Path of the saved model state dict. Surrounding whitespace and the double
# quotes that Windows "Copy as path" adds are stripped before use.
model_fp = input("Input the file path of the model to use for chord prediction: ").strip().strip('"')

Input the file path of the model to use for chord prediction: C:\Users\Danie\PycharmProjects\ChordGenerator\data\models\model2.pt


In [14]:
# Output path for the melody-plus-chords score. Surrounding whitespace and the
# double quotes that Windows "Copy as path" adds are stripped before use.
prediction_file_name = input("Input the name of the file to save the predicted multipart piece to: ").strip().strip('"')

Input the name of the file to save the predicted multipart piece to: C:\Users\Danie\PycharmProjects\ChordGenerator\data\test_data\test_outputs\Melody2WithChords_m2.musicxml


In [15]:
# Encode the melody as one-hot time steps and build one model input per measure.
bins_per_measure = 16  # time steps per measure; must match the trained model (3 * 16 = 48 input rows)


def _encode_melody_measure(measure, step_length, bins_per_measure=16):
    """One-hot encode one melody measure into ``bins_per_measure`` time steps.

    Each step is a 13-element list of 0/1 containing a single 1: the pitch
    class (0-11) of the element sounding at the step's offset, or slot 12 for
    a rest. A measure with no notes or rests encodes as all rests; a chord in
    the melody is represented by its highest pitch.

    Args:
        measure: a ``music21.stream.Measure`` of the melody part.
        step_length: duration of one step in quarter notes.
        bins_per_measure: number of steps sampled from the measure.

    Returns:
        A list of ``bins_per_measure`` one-hot lists.
    """
    elements = list(measure.notesAndRests)
    encoded = []
    current = 0
    for step in range(bins_per_measure):
        step_offset = step * step_length
        # `while`, not `if`: several short notes may start inside one step,
        # and every one of them must be skipped so the index never lags.
        while current + 1 < len(elements) and elements[current + 1].offset <= step_offset:
            current += 1

        if not elements or elements[current].isRest:
            active = 12
        elif elements[current].isChord:
            # Chords have no `.pitch`; use the top note as the melody pitch.
            active = max(elements[current].pitches, key=lambda p: p.ps).pitchClass
        else:
            active = elements[current].pitch.pitchClass

        encoded.append([1 if j == active else 0 for j in range(13)])
    return encoded


score = music21.converter.parse(file_to_predict)

# Default bar of 4 quarter notes until a time signature says otherwise.
step_length = 4 / bins_per_measure

melody = score.parts[0]

melody_measures = melody.getElementsByClass(music21.stream.Measure)

test_measures = []
for mel_measure in melody_measures:
    if mel_measure.timeSignature is not None:
        # Bar length in quarter notes divided evenly into the fixed number of bins.
        step_length = (mel_measure.timeSignature.numerator / mel_measure.timeSignature.denominator
                       * 4 / bins_per_measure)
    test_measures.append(_encode_melody_measure(mel_measure, step_length, bins_per_measure))

# Model input for measure i: previous, current and next measure stacked
# (3 * bins_per_measure rows x 13 columns), zero-padded past either end.
empty_measure = [[0] * 13 for _ in range(bins_per_measure)]
test_data = []
for i, current_measure in enumerate(test_measures):
    previous_melody_measure = test_measures[i - 1] if i > 0 else empty_measure
    next_melody_measure = test_measures[i + 1] if i + 1 < len(test_measures) else empty_measure

    melody_vector = previous_melody_measure + current_measure + next_melody_measure

    test_data.append(torch.tensor(melody_vector, dtype=torch.float))


In [16]:
class ConditionedFeedforward(nn.Module):
    """Feed-forward chord predictor conditioned on a window of melody measures.

    The input is ``context_measures * bins_per_measure`` one-hot melody steps of
    ``n_melody_classes`` columns each (any shape with that many elements per
    sample; it is flattened in ``forward``). The output is
    ``bins_per_measure * n_chord_classes`` raw scores per sample, one per
    (time step, chord pitch class).

    The defaults (16 bins, 13 melody classes, 12 chord classes, 3 measures)
    give exactly the original 624 -> 1024 -> 512 -> 192 layer sizes, so
    previously saved state dicts still load.
    """

    def __init__(self, bins_per_measure=16, n_melody_classes=13, n_chord_classes=12, context_measures=3):
        super().__init__()
        # Number of input features per sample after flattening.
        self.input_size = context_measures * bins_per_measure * n_melody_classes
        self.fc1 = nn.Linear(self.input_size, 1024)
        self.batch1 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.batch2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, bins_per_measure * n_chord_classes)

    def forward(self, x):
        """Return raw scores of shape (batch, bins_per_measure * n_chord_classes)."""
        x = x.view(-1, self.input_size)  # flatten each sample for the fully connected layers
        x = F.relu(self.batch1(self.fc1(x)))
        x = F.relu(self.batch2(self.fc2(x)))
        return self.fc3(x)


model = ConditionedFeedforward()

In [17]:
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the values into this CPU model either way.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load model files you trust (recent PyTorch supports weights_only=True).
model.load_state_dict(torch.load(model_fp, map_location="cpu"))
model.eval()

ConditionedFeedforward(
  (fc1): Linear(in_features=624, out_features=1024, bias=True)
  (batch1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (batch2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=512, out_features=192, bias=True)
)

In [18]:
# For each measure: sigmoid of the model output, reshaped to
# (bins_per_measure, 12) and thresholded at 0.5, giving a boolean grid of
# active chord pitch classes per time step.
chord_measure_vectors = []

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for datum in test_data:
        chord_measure_vectors.append(torch.sigmoid(model(datum)).view(bins_per_measure, 12) > 0.5)

In [19]:
def _chord_vectors_to_measure(step_vectors, step_ql):
    """Build a Measure from per-step chord predictions.

    Consecutive steps with the same set of active pitch classes are merged into
    one longer element. A step with no active pitch class becomes a rest
    instead of an empty chord.

    Args:
        step_vectors: iterable of per-step vectors whose truthy entries mark
            active pitch classes (index 0-11).
        step_ql: duration of one step in quarter notes.

    Returns:
        A ``music21.stream.Measure`` of chords and rests.
    """
    measure = music21.stream.Measure()
    previous_classes = None
    previous_element = None
    for vector in step_vectors:
        pitch_classes = [pc for pc, active in enumerate(vector) if active]
        if previous_element is not None and pitch_classes == previous_classes:
            previous_element.quarterLength += step_ql
            continue
        if pitch_classes:
            previous_element = music21.chord.Chord(pitch_classes, quarterLength=step_ql)
        else:
            previous_element = music21.note.Rest(quarterLength=step_ql)
        measure.append(previous_element)
        previous_classes = pitch_classes
    return measure


new_part = music21.stream.Part()
new_part.insert(0, music21.instrument.Piano())

# Fallback key and meter for the first measure when the melody declares none.
key = music21.key.Key("C")
time_sig = music21.meter.TimeSignature("4/4")
chord_step_length = time_sig.numerator / time_sig.denominator * 4 / bins_per_measure

for measure_index, (chord_measure_vecs, mel_measure) in enumerate(zip(chord_measure_vectors, melody_measures)):
    melody_key = mel_measure.keySignature
    melody_meter = mel_measure.timeSignature
    if melody_meter is not None:
        # Same step length the melody was encoded with, so each chord measure
        # spans the same number of quarter notes as its melody measure.
        chord_step_length = melody_meter.numerator / melody_meter.denominator * 4 / bins_per_measure

    new_measure = _chord_vectors_to_measure(chord_measure_vecs, chord_step_length)
    new_measure.number = mel_measure.number

    # State key/meter in the first measure and wherever the melody restates
    # them. Each measure gets its own copy rather than sharing one object.
    if measure_index == 0 or melody_key is not None:
        new_measure.insert(0, copy.deepcopy(melody_key if melody_key is not None else key))
    if measure_index == 0 or melody_meter is not None:
        new_measure.insert(0, copy.deepcopy(melody_meter if melody_meter is not None else time_sig))

    new_part.append(new_measure)

In [20]:
# Add the predicted chord part to the parsed melody score and export it as MusicXML.
score.insert(0, new_part)

gex = music21.musicxml.m21ToXml.GeneralObjectExporter(score)

out = gex.parse()  # serialized MusicXML as bytes
musicxml = out.decode('utf-8').strip()

# The text was decoded from UTF-8 above; write it back as UTF-8 explicitly.
# The platform default codec (e.g. cp1252 on Windows) can fail on, or corrupt,
# non-ASCII text such as titles or lyrics.
with open(prediction_file_name, "w", encoding="utf-8") as outfile:
    outfile.write(musicxml)