# LSTM Harmonizer

##### Max Brynolf 2023
The code in this notebook trains an LSTM-network to find appropriate chords for melody sequences.

### Import Packages

In [None]:
import music21
import numpy as np
import random
import matplotlib.pyplot as plt
import math

from music21 import note, chord, converter, stream, midi
import xml.etree.ElementTree as ElementTree

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

### File Handling

The following functions are related to translating MIDI-files into tensors that are used to train the model.

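The following constants define the pitch range, as MIDI note numbers, of the output chords. With the defaults, 48 to 72 corresponds to C3 to C5. Notes above `pitch_max` are moved down into the octave just below `pitch_max` (MIDI `pitch_max - 12` to `pitch_max - 1`), and notes below `pitch_min` are moved up into the octave starting at `pitch_min`. This preserves each note's pitch class only because both bounds are multiples of 12 (both are C's).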

In [None]:
pitch_max = 72
pitch_min = 48

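The function `chord_filter` takes an input chord and returns a new chord within the interval defined by `pitch_min` and `pitch_max`, with duplicate pitches removed both before and after the notes are moved into the interval. It also acts as a filter: returning `[]` excludes the chord from the data. As written, chords with more than 8 notes, or with only a single note left after folding, are excluded; this condition can be replaced by a custom one.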

In [None]:
def chord_filter(c):
    c = c.removeRedundantPitches()
    for n in c.notes:
        p = n.pitch.midi
        note_name = p % 12
        if p < pitch_min:
            n.pitch.midi = pitch_min + note_name
        elif p > pitch_max:
            n.pitch.midi = pitch_max - 12 + note_name
    c = c.removeRedundantPitches()
    if len(c.notes) > 8 or len(c.notes) == 1:
        return []
    return c
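The folding and filtering rules can be tried out on plain MIDI numbers. The following sketch mirrors `chord_filter` with integers instead of music21 chords; `fold_chord` is a hypothetical name, and duplicates are removed by MIDI number, so enharmonic spellings (which music21 may treat as distinct) are merged here.

```python
PITCH_MIN, PITCH_MAX = 48, 72  # same defaults as pitch_min / pitch_max above

def fold_chord(midi_pitches, pitch_min=PITCH_MIN, pitch_max=PITCH_MAX):
    """Hypothetical integer-only mirror of chord_filter.

    Returns a sorted list of folded pitches, or [] if the chord is excluded.
    """
    folded = set()
    for p in midi_pitches:
        pitch_class = p % 12
        if p < pitch_min:
            p = pitch_min + pitch_class         # lift into the lowest octave of the range
        elif p > pitch_max:
            p = pitch_max - 12 + pitch_class    # drop into the octave below pitch_max
        folded.add(p)                           # the set removes duplicate pitches
    # Same exclusion rule as chord_filter: more than 8 notes, or a single note
    if len(folded) > 8 or len(folded) == 1:
        return []
    return sorted(folded)

print(fold_chord([36, 40, 43, 84]))  # [48, 52, 55, 60]
```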

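The function `part_to_chord_vectors` takes the notes in the part `chord_part` that start within consecutive windows of `chord_length` quarter notes, beginning at offset `start_offset`, and groups the notes of each window into a chord. In other words, for a chord length $\Delta$ and start offset $t_0$, the notes are sorted into the following half-open intervals, each corresponding to a chord:

$$
[t_0, t_0 + \Delta), [t_0 + \Delta, t_0 + 2\Delta), \dots , [t_0 + (n - 1)\Delta, t_0 + n\Delta)
$$

where $n$ is the largest integer such that $t_0 + n\Delta$ does not exceed the end of the part. Windows without any notes, and windows whose chord is rejected by `chord_filter`, are skipped.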
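The loop in `part_to_chord_vectors` walks through these windows until the next window would end after `highestTime`. A minimal sketch of that enumeration on plain numbers, with the hypothetical helper name `chord_windows`:

```python
def chord_windows(start_offset, chord_length, highest_time):
    """Hypothetical helper: the half-open windows visited by part_to_chord_vectors.

    A window [t, t + chord_length) is only visited if it ends at or before highest_time.
    """
    windows = []
    t = start_offset
    while t + chord_length <= highest_time:
        windows.append((t, t + chord_length))
        t += chord_length
    return windows

print(chord_windows(2.0, 1.0, 5.5))  # [(2.0, 3.0), (3.0, 4.0), (4.0, 5.0)]
```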

The function returns two lists describing the extracted chords. The first contains the offset of each chord, taken as the offset of the first note found in its window (which need not coincide with the window start), and the second contains the chords themselves, encoded as vectors. Each vector has $P_{max} - P_{min} + 1$ components (where $P_{max}$ is `pitch_max` and $P_{min}$ is `pitch_min`), one per note in the interval: a component set to $1$ means that the note is part of the chord, and a component set to $0$ means that it is not.
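For example, with the default range of MIDI notes 48 to 72 (25 components), a C major triad C4, E4, G4 sets components 12, 16 and 19. A sketch of the encoding and its inverse, assuming the pitches have already been moved into range by `chord_filter` (the helper names are hypothetical):

```python
def chord_to_vector(midi_pitches, pitch_min=48, pitch_max=72):
    """Multi-hot encoding: component k is 1 iff MIDI pitch pitch_min + k is in the chord."""
    vector = [0] * (pitch_max - pitch_min + 1)
    for p in midi_pitches:
        vector[p - pitch_min] = 1
    return vector

def vector_to_chord(vector, pitch_min=48):
    """Inverse mapping, back to sorted MIDI pitches."""
    return [pitch_min + k for k, v in enumerate(vector) if v == 1]

c_major = chord_to_vector([60, 64, 67])          # C4, E4, G4
print(len(c_major), vector_to_chord(c_major))    # 25 [60, 64, 67]
```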

In [None]:
def part_to_chord_vectors(chord_part, start_offset, chord_length):
    chord_offsets = []
    chords = []
    current_offset = start_offset
    while current_offset + chord_length <= chord_part.stream().highestTime:
        chord_notes = chord_part.getElementsByOffset(current_offset,
                                                     current_offset + chord_length,
                                                     includeEndBoundary = False)
        current_offset += chord_length
        if not chord_notes:
            continue # no notes found in the specified interval
        c = chord.Chord(list(chord_notes))
        c = chord_filter(c)
        if not c:
            continue # chord filter excludes the chord
        chord_offsets.append(chord_notes[0].offset)
        chord_vector = [0] * (pitch_max - pitch_min + 1)
        vector_indices = [n.pitch.midi - pitch_min for n in c.notes]
        for ind in vector_indices:
            chord_vector[ind] = 1
        chords.append(chord_vector)
    return chord_offsets, chords

The function `create_training_data` creates the training data. The LSTM-model takes a sequence of melody notes together with their chords (henceforth referred to as melody-chord-vectors) and predicts the chord for the next melody note. Hence, for melody notes $v_i$ and corresponding chords $w_i$, melody-chord-vectors $u_i = (w_i, v_i)$ can be constructed. The chords are those extracted by `part_to_chord_vectors`: every melody note played within `chord_length` quarter notes after a chord's offset is associated with that chord, and otherwise with a zero vector. Given a sequence of melody-chord-vectors, the task of the model $f$ is to predict the next chord $w_{i+1}$, hence:

$$
f(u_i, u_{i - 1}, u_{i - 2}, \dots , u_{i - k}) = w_{i + 1}
$$

Therefore, the training data consists of sequences of melody-chord-vectors and the corresponding output chords. The function `create_training_data` translates a melody part `melody_part` into a list of melody-chord-sequences, one for each offset in `chord_offsets`, each made up of the `sequence_length` melody notes (so $k + 1$ equals `sequence_length`) that start strictly before that chord's offset. Both the melody-chord-sequences, corresponding to the $X$-data, and the chords, corresponding to the $y$-data, are returned. Chords that are not preceded by any new melody note since the previous chord are dropped before returning; note that this modifies the `chords` list passed in.
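A simplified sketch of how inputs and targets line up, assuming (unlike the notebook) that every melody note already has a chord vector and that one training pair is produced per note. `make_pairs` is a hypothetical helper; `create_training_data` instead builds one sequence per extracted chord by walking backwards from the chord's offset.

```python
def make_pairs(chord_vectors, melody_vectors, sequence_length):
    """Simplified sketch: one (X, y) pair per position, assuming one chord per melody note.

    chord_vectors[i] is w_i, melody_vectors[i] is v_i, and u_i = w_i + v_i (concatenation).
    The input is u_{i-k}, ..., u_i with k = sequence_length - 1, and the target is w_{i+1}.
    """
    u = [w + v for w, v in zip(chord_vectors, melody_vectors)]
    X, y = [], []
    for i in range(sequence_length - 1, len(u) - 1):
        X.append(u[i - sequence_length + 1 : i + 1])
        y.append(chord_vectors[i + 1])   # the melody note of w_{i+1} is not part of the input
    return X, y
```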

In [None]:
def create_training_data(melody_part, chord_offsets, chords, chord_length, sequence_length):
    
    # Create melody stream and initialize variables
    melody_stream = melody_part.stream()
    sequences = [] # melody-chord-sequences
    removed_chord_indices = [] # a list pointing to chords that will be removed from the training data
    
    # Loop through the chords
    for i in range(len(chord_offsets)):
        
        # Initialize a new sequence of sequence_length melody-chord-vectors
        sequence = []
        previous_offset = chord_offsets[i] # keeps track of the previous offset in the loop below
        two_succ_chords = False # whether there has been two successive chords without new melody notes
        
        # Starting at the chord, go backwards until sequence_length melody-chord-vectors have been added
        for j in range(sequence_length):
            
            # Find the first melody note preceding the current offset
            current_offset = melody_stream.getElementBeforeOffset(previous_offset).offset
            
            # Make sure that there are not two chords in a row
            if j == 0 and i != 0 and current_offset < chord_offsets[i - 1]:
                two_succ_chords = True
                removed_chord_indices.append(i) # if there are, remember to remove the chord later
                break
            
            # Extract the melody note, and in the case of a chord, choose the top note
            melody_notes = list(melody_stream.getElementsByOffset(current_offset))
            if len(melody_notes) > 1 or type(melody_notes[0]) == chord.Chord:
                ch = chord.Chord(melody_notes)
                n = ch.sortFrequencyAscending()[-1]
                n.offset = current_offset
            else:
                n = melody_notes[0]
            melody_vector = [0] * 88
            melody_vector[n.pitch.midi - 21] = 1
            
            # Find the chord corresponding to the current melody note
            k = i
            while k >= 0:
                if chord_offsets[k] <= n.offset: # chord precedes melody note
                    if n.offset - chord_offsets[k] < chord_length: # melody note is played within chord_length
                        chord_vector = chords[k]
                        break
                    else:
                        chord_vector = [0] * (pitch_max - pitch_min + 1) # no associated chord
                        break
                elif k == 0:
                    chord_vector = [0] * (pitch_max - pitch_min + 1) # no associated chord
                k -= 1
            
            # Add the total, concatenated vector to the sequence
            sequence.append(chord_vector + melody_vector)
            previous_offset = current_offset
        
        # Go to the next chord if there are no new melody notes preceding the chord
        if two_succ_chords:
            continue
        
        # Reverse the melody-chord-sequence and add it to the total sequences list
        sequence = sequence[::-1]
        sequences.append(sequence)
    
    # Remove all chords where there are no new preceding melody notes
    for i in removed_chord_indices[::-1]:
        del chords[i]
    
    # Return the melody-chord-sequences and corresponding chords
    return sequences, chords

The function `midi_to_tensors` converts the MIDI-file at `path` into tensors that can be used to train the LSTM-network. `X` holds the melody-chord-sequences of length `sequence_length`, with shape (number of chords, `sequence_length`, $P_{max} - P_{min} + 1 + 88$), and `y` holds the chord that follows each sequence.

The parameters `mel_chs` and `acc_chs` specify the channels, i.e. the indices into the parts of the parsed score, that contain the melody and the accompaniment respectively. For example, if the MIDI-file has one melody channel $0$ and two accompaniment channels $1$ and $2$, then `mel_chs = [0]` and `acc_chs = [1, 2]`. The notes of all melody channels are merged into a single melody stream, and likewise for the accompaniment, before the tensors are extracted.

The parameter `chord_length` is passed to `part_to_chord_vectors` and hence sets the window length, in quarter notes, used when grouping notes into chords. The parameter `chord_start_offset` is a reference offset on the chord grid, e.g. a downbeat, so that chords are grouped on the correct beats. It does not have to be preceded by `sequence_length` melody notes: the function itself moves the first chord to the first grid point `chord_start_offset + m * chord_length` (for an integer `m`) at or after the (`sequence_length` + 1)-th distinct melody onset, so that every chord is preceded by a full melody sequence. `chord_start_offset` only marks where the grid lies, which saves time when processing MIDI-files by hand. Setting `chord_start_offset` to $-1$ uses the offset of the first note in the accompaniment.
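The start-offset rule can be sketched on a plain, sorted list of onset times. `first_chord_offset` is a hypothetical helper that mirrors the loop in `midi_to_tensors`; it assumes the melody has more than `sequence_length` distinct onsets.

```python
import math

def first_chord_offset(melody_onsets, chord_start_offset, chord_length, sequence_length):
    """Mirror of the start-offset computation in midi_to_tensors.

    Finds the (sequence_length + 1)-th distinct onset in the sorted list melody_onsets
    and rounds it up to the grid chord_start_offset + m * chord_length.
    """
    previous = None
    count = 0
    onset = None
    for onset in melody_onsets:
        if onset != previous:
            count += 1          # a new distinct onset
        if count > sequence_length:
            break
        previous = onset
    steps = math.ceil((onset - chord_start_offset) / chord_length)
    return chord_start_offset + steps * chord_length

print(first_chord_offset([0, 0, 1, 2, 3.5, 4, 5], 0, 2, 3))  # 4
```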

In [None]:
def midi_to_tensors(path, mel_chs, acc_chs, chord_length, sequence_length, chord_start_offset):
    
    # Parse MIDI-file
    score = converter.parse(path)
    melody = score.parts[mel_chs[0]]
    for i in mel_chs[1:]:
        for n in score.parts[i].flatten().notes:
            melody.insert(n.offset, n)
    melody = melody.flatten().notes
    accompaniment = score.parts[acc_chs[0]]
    for i in acc_chs[1:]:
        for n in score.parts[i].flatten().notes:
            accompaniment.insert(n.offset, n)
    accompaniment = accompaniment.flatten().notes
    if chord_start_offset == -1:
        chord_start_offset = accompaniment[0].offset # let the first chord be the first beat
    
    # Extract chords
    o = -1
    counter = 0
    for lower_bound_offset in [float(n.offset) for n in list(melody)]:
        if lower_bound_offset != o:
            counter += 1
        if counter > sequence_length:
            break
        o = lower_bound_offset
    start_offset = chord_start_offset + math.ceil((lower_bound_offset - chord_start_offset)/chord_length) * chord_length
    chord_offsets, chords = part_to_chord_vectors(accompaniment, start_offset, chord_length)
    
    # Extract chord-melody-sequences
    chord_mel_sequences, chords = create_training_data(melody,
                                                       chord_offsets,
                                                       chords,
                                                       chord_length,
                                                       sequence_length)
    
    # Transform into tensors
    X = torch.FloatTensor(chord_mel_sequences)
    y = torch.FloatTensor(chords)
    return X, y

### Playback

The following functions generate chords for a melody using a trained model.

The function `melody_from_midi` extracts the melody notes from a MIDI-file and converts them into melody vectors. The channel used is specified by `melody_channel`. When several notes start at the same offset, only the highest one is kept. Besides the list of melody vectors, the function returns a corresponding offset list holding the offset and length (in quarter notes) of each melody note. This information is needed if the playback should respect the original rhythms.
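The top-note rule can be sketched on plain `(offset, length, midi_pitch)` tuples, which stand in for music21 notes here. `top_notes` is a hypothetical helper; like `melody_from_midi`, it records the length of the first note at each onset.

```python
from itertools import groupby

def top_notes(events):
    """Keep only the highest pitch among notes that start at the same offset.

    events: list of (offset, quarter_length, midi_pitch) tuples, sorted by offset.
    Returns (offset_list, pitches), where offset_list holds [offset, length] pairs.
    """
    offset_list, pitches = [], []
    for offset, group in groupby(events, key=lambda e: e[0]):
        group = list(group)
        offset_list.append([offset, group[0][1]])   # length of the first note at this onset
        pitches.append(max(e[2] for e in group))    # top note
    return offset_list, pitches
```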

In [None]:
def melody_from_midi(path, melody_channel):
    score = converter.parse(path)
    melody = score.parts[melody_channel].flatten().notes
    current_note = list(melody)[0]
    melody_vectors = []
    offset_list  = []
    while current_note is not None:
    
        # Extract the melody note, and in the case of a chord, choose the top note
        melody_notes = list(melody.getElementsByOffset(current_note.offset))
        if len(melody_notes) > 1 or type(melody_notes[0]) == chord.Chord:
            ch = chord.Chord(melody_notes)
            n = ch.sortFrequencyAscending()[-1]
            n.offset = current_note.offset
        else:
            n = melody_notes[0]
        melody_vector = [0] * 88
        melody_vector[n.pitch.midi - 21] = 1
        melody_vectors.append(melody_vector)
        offset_list.append([current_note.offset, current_note.quarterLength])
        
        # Proceed to next melody note
        next_note = melody.stream().getElementAfterElement(current_note)
        while next_note is not None and next_note.offset == current_note.offset:
            current_note = next_note
            next_note = melody.stream().getElementAfterElement(current_note)
        current_note = next_note
    
    return offset_list, melody_vectors

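The function `generate_chords_from_melody` generates chords for a melody using the model `net` and returns a score that can be written to a MIDI-file. As inputs, it expects an LSTM-network `net` in testing mode (so that its outputs are probabilities), a list of melody vectors `melody_notes`, an offset list `offset_list` with the offsets and lengths of the melody notes, and the sequence length `sequence_length` of each input fed to the network. The first `sequence_length` melody notes receive no chords. For each later note, the network is given the preceding `sequence_length` melody vectors together with their chords, and every pitch whose output probability exceeds $0.5$ is played for the duration of the note. The chord part of each input holds the network's raw output probabilities for the preceding notes (zero vectors for the first ones), not thresholded binary vectors as in the training data.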
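The feedback loop can be sketched without music21 or PyTorch by replacing the network with any function that maps an input sequence to a list of probabilities. In this sketch `harmonize` and `predict` are hypothetical names, and it returns chord indices (offsets from `pitch_min`) instead of a music21 score.

```python
def harmonize(predict, melody_vectors, sequence_length, chord_size, threshold=0.5):
    """Sketch of the autoregressive loop in generate_chords_from_melody.

    predict: hypothetical stand-in for the network; maps a list of sequence_length
    chord+melody vectors to a list of chord_size probabilities.
    Returns one list of chord indices per melody note (empty for the first
    sequence_length notes, which receive no chord).
    """
    latest_chords = [[0.0] * chord_size for _ in range(sequence_length)]
    generated = [[] for _ in range(sequence_length)]
    for i in range(sequence_length, len(melody_vectors)):
        window = melody_vectors[i - sequence_length : i]      # note i itself is not included
        inputs = [list(c) + list(m) for c, m in zip(latest_chords, window)]
        probs = predict(inputs)
        generated.append([k for k, p in enumerate(probs) if p > threshold])
        latest_chords = latest_chords[1:] + [probs]           # feed raw probabilities back
    return generated
```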

In [None]:
def generate_chords_from_melody(net, melody_notes, offset_list, sequence_length):
    accompaniment = stream.Part()
    accompaniment.insert(music21.instrument.Piano())
    melody = stream.Part()
    total_offset = 0
    def create_note(pitch, length, offset, volume = 127):
        n = note.Note()
        n.pitch.midi = pitch
        n.quarterLength = length
        n.offset = offset
        n.volume = volume
        return n
    
    # Add chords and melody notes
    latest_chords = []
    for i in range(sequence_length):
        total_offset, note_length = offset_list[i]
        current_note = create_note(melody_notes[i].index(1) + 21, note_length, total_offset)
        melody.insert(current_note)
        latest_chords.append([0] * (pitch_max - pitch_min + 1))
    for i in range(sequence_length, len(melody_notes)):
        total_offset, note_length = offset_list[i]
        current_note = create_note(melody_notes[i].index(1) + 21, note_length, total_offset)
        melody.insert(current_note)
        input_vector = [list(l) + list(m) for l, m in zip(latest_chords, melody_notes[i - sequence_length : i])]
        input = torch.FloatTensor(input_vector).unsqueeze(0)
        output = net(input).squeeze().detach().cpu().numpy()
        pitches = [i + pitch_min for i, v in enumerate(output) if v > 0.5]
        for pitch in pitches:
            current_note = create_note(pitch, note_length, total_offset, 80)
            accompaniment.insert(current_note)
        latest_chords.append(output)
        latest_chords.pop(0)
    
    output_stream = stream.Score()
    meta_data = music21.metadata.Metadata()
    meta_data.title = "Harmonized melody"
    output_stream.insert(0, meta_data)
    output_stream.insert(0, melody)
    output_stream.insert(0, accompaniment)
    return output_stream

### LSTM-Network

The class `RNN` defines the architecture of the LSTM-network. It takes a structural argument `layers` upon initialization, defining the network structure. The network is divided into two branches, one that handles the chords and one that handles the melody notes, and each branch can contain several stacked LSTM layers.

The argument `layers` is a list containing two inner lists along with the output layer size. The first list corresponds to the layers in the melody branch and the second corresponds to the layers in the accompaniment branch. The output layer size is the same as the chord vector size, given by $P_{max} - P_{min} + 1$. For example, if:

```
layers = [[100, 10], [50, 20, 10], pitch_max - pitch_min + 1]
```

then the melody notes will be fed through two stacked LSTM layers with $100$ and $10$ hidden units respectively, and the chords through three stacked LSTM layers with $50$, $20$ and $10$ hidden units. The outputs of both branches at the last time step are then fed to a fully connected output layer, in this case from $10 + 10 = 20$ nodes to $P_{max} - P_{min} + 1$ output nodes.

Using this structure, it is possible to balance the influence that the chords and the melody have on the output separately. For instance, changing to:

```
layers = [[100, 10], [50, 20, 5], pitch_max - pitch_min + 1]
```

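reduces the number of features that the accompaniment branch passes to the output layer from $10$ to $5$, which tends to give the chords a smaller influence on the output.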
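The branch structure can be checked with a stripped-down sketch that has one LSTM layer per branch (corresponding to `layers = [[100], [30], 25]`, the configuration used for training below) and verifies the tensor shapes. `TwoBranchSketch` is a hypothetical name, not the notebook's `RNN` class, and it always returns logits.

```python
import torch
import torch.nn as nn

CHORD_SIZE, MELODY_SIZE = 25, 88  # pitch_max - pitch_min + 1 with the defaults, and 88 piano keys

class TwoBranchSketch(nn.Module):
    """Hypothetical single-layer-per-branch version of RNN, i.e. layers = [[100], [30], 25]."""

    def __init__(self, mel_hidden=100, acc_hidden=30):
        super().__init__()
        self.mel_lstm = nn.LSTM(MELODY_SIZE, mel_hidden, batch_first=True)
        self.acc_lstm = nn.LSTM(CHORD_SIZE, acc_hidden, batch_first=True)
        self.fc = nn.Linear(acc_hidden + mel_hidden, CHORD_SIZE)

    def forward(self, x):                    # x: (batch, sequence_length, 25 + 88)
        chords = x[:, :, :CHORD_SIZE]        # the chord part comes first
        melody = x[:, :, CHORD_SIZE:]
        acc_out, _ = self.acc_lstm(chords)   # (batch, seq, acc_hidden)
        mel_out, _ = self.mel_lstm(melody)   # (batch, seq, mel_hidden)
        # Only the last time step of each branch reaches the output layer
        features = torch.cat((acc_out[:, -1, :], mel_out[:, -1, :]), dim=1)
        return self.fc(features)             # logits, (batch, 25)

x = torch.zeros(4, 8, CHORD_SIZE + MELODY_SIZE)  # a batch of 4 sequences of length 8
print(TwoBranchSketch()(x).shape)                 # torch.Size([4, 25])
```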

Lastly, the methods `train` and `test` switch between training and testing modes. In training mode, the raw logits are returned, so that `BCEWithLogitsLoss` can be used as the loss function. In testing mode, which is the default after construction, the logits are passed through a sigmoid before being returned, so that the outputs are probabilities.

In [None]:
class RNN(nn.Module):
    
    def __init__(self, layers):
        super(RNN, self).__init__()
        self.acc_size = layers[2]
        self.rnn_mel_layers = nn.ModuleList()
        self.rnn_acc_layers = nn.ModuleList()
        self.rnn_mel_layers.append(nn.LSTM(88, layers[0][0], 1, batch_first = True))
        for i, layer in enumerate(layers[0][0:-1]):
            self.rnn_mel_layers.append(nn.LSTM(layer, layers[0][i + 1], 1, batch_first = True))
        self.rnn_acc_layers.append(nn.LSTM(layers[2], layers[1][0], 1, batch_first = True))
        for i, layer in enumerate(layers[1][0:-1]):
            self.rnn_acc_layers.append(nn.LSTM(layer, layers[1][i + 1], 1, batch_first = True))
        self.fc = nn.Linear(layers[0][-1] + layers[1][-1], layers[2])
        self.output_activation = nn.Sigmoid()
        self.training_mode = False
    
    def forward(self, x):
        y_acc = x[:, :, : self.acc_size]
        y_mel = x[:, :, self.acc_size :]
        for layer in self.rnn_mel_layers:
            y_mel, (h, c) = layer(y_mel)
        for layer in self.rnn_acc_layers:
            y_acc, (h, c) = layer(y_acc)
        y = self.fc(torch.cat((y_acc[:, -1, :], y_mel[:, -1, :]), 1))
        if not self.training_mode:
            y = self.output_activation(y)
        return y
    
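    def train(self, mode = True):
        # Keep nn.Module's signature, since nn.Module.eval() calls self.train(False)
        super().train(mode)
        self.training_mode = mode
        return self
    
    def test(self):
        return self.train(False)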

### File Processing

The function `data_plot` gives an overview of the dataset: a pie chart of how many chords come from each data family, and a stacked histogram of the number of notes per chord.

In [None]:
def data_plot(chord_list, data_families):
    # chord_list holds one tensor of chord vectors per data family, in the order of data_families
    total_chords = sum(t.shape[0] for t in chord_list)
    def disp_text(percentage):
        return f"{(percentage * total_chords / 100):.0f}\n({percentage:.0f}%)"
    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_figwidth(12)
    ax1.pie([t.shape[0] for t in chord_list], labels = data_families, autopct = disp_text)
    ax1.set_title("Data origin distribution")
    chord_sizes = [torch.sum(chord_tensor, dim=1).numpy() for chord_tensor in chord_list]
    min_notes = min([size.min() for size in chord_sizes])
    max_notes = max([size.max() for size in chord_sizes])
    ax2.hist(chord_sizes, np.arange(min_notes - 0.5, max_notes + 1.5, 1), stacked = True)
    ax2.set(xlabel = "Number of notes")
    ax2.set(ylabel = "Number of chords")
    ax2.legend(labels = data_families)
    ax2.set_title("Chord size distribution")

The following code processes a set of MIDI-files and converts them into tensors using the `midi_to_tensors`-function. The files are assumed to have the following paths:

> MIDIs/`data_family`/`file_identifier`_1.mid
> 
> MIDIs/`data_family`/`file_identifier`_2.mid
>
> ...
>
> MIDIs/`data_family`/`file_identifier`_`n_files`.mid

Processing can take some time, depending on the files. The resulting lists of `X` and `y` tensors (one entry per file) are therefore saved to "data/X_`file_identifier`.pt" and "data/y_`file_identifier`.pt", so that they can be reused without processing the MIDI-files again. Note that the training section below loads these files by data-family name, so `file_identifier` should be the same as `data_family`.

In order to properly read the files, the XML file `training_info.xml` should contain information about the melody channels, accompaniment channels, chord length and chord start offset for each file. For example:

```xml
<data-families>
    <data-family name="file_identifier">
        <file number="1">
            <mel>0</mel>
            <mel>2</mel>
            ...
            <acc>1</acc>
            <acc>3</acc>
            ...
            <chord-length>1</chord-length>
            <start-offset>-1</start-offset>
        </file>
        <file number="2">
            ...
        </file>
        ...
    </data-family>
</data-families>
```

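The `number` attribute of the `file` elements is not read by the code: files are matched to `file` elements by their order, so the i-th `file` element describes file number i. Specifying it is still recommended for readability.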
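To check that a `training_info.xml` is read as intended, the parsing loop in the cell below can be factored into a function and run on an inline example. This is a sketch: `read_training_info` and the family name `bach` are hypothetical, and the defaults (chord length 1, start offset -1) are copied from the cell below.

```python
import xml.etree.ElementTree as ElementTree

EXAMPLE_XML = """
<data-families>
    <data-family name="bach">
        <file number="1">
            <mel>0</mel>
            <acc>1</acc>
            <acc>2</acc>
            <chord-length>2</chord-length>
        </file>
        <file number="2">
            <mel>0</mel>
            <acc>1</acc>
        </file>
    </data-family>
</data-families>
"""

def read_training_info(root, family):
    """Return one dict per <file>, in document order, with the same defaults as the cell below."""
    matches = root.findall(f".//data-family[@name='{family}']")
    assert len(matches) == 1, f"expected exactly one data-family named {family}"
    info = []
    for file in matches[0]:
        entry = {"mel": [], "acc": [], "chord-length": 1.0, "start-offset": -1.0}
        for field in file:
            if field.tag in ("mel", "acc"):
                entry[field.tag].append(int(field.text))
            elif field.tag in ("chord-length", "start-offset"):
                entry[field.tag] = float(field.text)
        info.append(entry)
    return info

info = read_training_info(ElementTree.fromstring(EXAMPLE_XML), "bach")
```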

In [None]:
# Specify which data to process
data_family = ""
file_identifier = ""
sequence_length = 8
excluded_files = [] # file numbers (as in file_identifier_<number>.mid) to skip manually

# Extract training information about the data from training_info.xml
training_info_xml = ElementTree.parse("training_info.xml")
data_family_xml = training_info_xml.getroot().findall(f".//data-family[@name='{data_family}']")
assert len(data_family_xml) == 1, f"training_info.xml should contain exactly one data-family element with name {data_family}"
training_info = []
for file in data_family_xml[0]:
    mels = []
    accs = []
    chord_length = 1
    start_offset = -1
    for file_data in file:
        if file_data.tag == "mel":
            mels.append(int(file_data.text))
        if file_data.tag == "acc":
            accs.append(int(file_data.text))
        if file_data.tag == "chord-length":
            chord_length = float(file_data.text)
        if file_data.tag == "start-offset":
            start_offset = float(file_data.text)
    training_info.append({"mel": mels, "acc": accs, "chord-length": chord_length, "start-offset": start_offset})

# Process the files
n_files = len(training_info)
trainloaders = []
X_list = []
y_list = []
print("Processing files...")
for i in range(1, n_files + 1):
    if i in excluded_files:
        continue
    filename = f"MIDIs/{data_family}/{file_identifier}_{i}.mid"
    mel_ch = training_info[i - 1]["mel"]
    acc_ch = training_info[i - 1]["acc"]
    c_length = training_info[i - 1]["chord-length"]
    c_offset = training_info[i - 1]["start-offset"]
    X, y = midi_to_tensors(filename, mel_ch, acc_ch, c_length, sequence_length, c_offset)
    X_list.append(X)
    y_list.append(y)
    print(f"File {filename} has been processed.")
torch.save(X_list, f"data/X_{file_identifier}.pt")
torch.save(y_list, f"data/y_{file_identifier}.pt")

### Training

The code below decides which data to include and plots the distribution of chords.

In [None]:
data_families = [] # data families to include
exclude_sizes = [] # list of chord sizes to exclude from the data

X_families = []
y_families = []
for data_family in data_families:
    X_list = torch.load(f"data/X_{data_family}.pt")
    y_list = torch.load(f"data/y_{data_family}.pt")
    X_families.append(torch.cat(X_list))
    y_families.append(torch.cat(y_list))
for i in range(len(y_families)):
    num_of_notes = torch.sum(y_families[i], dim = 1)
    if len(exclude_sizes) > 0:
        condition = num_of_notes != exclude_sizes[0]
        for j in exclude_sizes[1:]:
            condition = torch.logical_and(condition, num_of_notes != j)
        y_families[i] = y_families[i][condition]
        X_families[i] = X_families[i][condition]
X = torch.cat(X_families)
y = torch.cat(y_families)

data_plot(y_families, data_families)
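The size filter above chains `!=` comparisons with `torch.logical_and`. An equivalent, more compact formulation uses `torch.isin` (available in PyTorch 1.10 and later); `drop_chord_sizes` is a hypothetical helper, and this is a sketch rather than a drop-in replacement for the cell.

```python
import torch

def drop_chord_sizes(X, y, exclude_sizes):
    """Remove samples whose target chord has a number of notes listed in exclude_sizes."""
    if not exclude_sizes:
        return X, y
    num_of_notes = y.sum(dim=1)
    excluded = torch.tensor(exclude_sizes, dtype=num_of_notes.dtype)
    keep = ~torch.isin(num_of_notes, excluded)   # boolean mask over samples
    return X[keep], y[keep]
```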

The following code trains the LSTM network with the data from `X` and `y`. A number of hyperparameters can be tuned: the number of epochs, the LSTM structure, the batch size and the learning rate. Moreover, `split_ratio` is the fraction of the data used for training, with the remainder used as the test set. As the model is trained, the average training and test loss of each epoch are appended to `train_losses` and `test_losses` respectively, and the network with the lowest test loss so far is kept as `best_net`.

Each output chord is represented by a vector with $P_{max} - P_{min} + 1$ components, where each component corresponds to a note. For example:

$$
y_i = [0, 0, 0, \dots, 0, 1, 0, 0, 1, 0, 1, \dots 0, 0, 0]
$$

where $y_i$ is one chord in a batch $y$ of several chords. Each melody note is also represented by a vector, except that exactly one component is $1$ (a one-hot encoding). The melody vectors have $88$ components and cover the entire piano, from A0 (MIDI note 21) to C8 (MIDI note 108). The inputs to the model consist of melody-chord-vectors, with the chord constituting the first $P_{max} - P_{min} + 1$ components and the melody note constituting the last $88$ components.

$$
X_i = [\underbrace{1, 0, 1, \dots , 0, 1, 0}_\text{chord}, \underbrace{0, 0, 0, 0, 0, \dots , 0, 1, 0, \dots , 0, 0, 0, 0}_\text{melody}]
$$
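A sketch of how one melody-chord-vector can be assembled and decoded, assuming the default pitch range; the helper names are hypothetical. Decoding the melody part mirrors `melody_notes[i].index(1) + 21` in `generate_chords_from_melody`.

```python
PITCH_MIN, PITCH_MAX = 48, 72
CHORD_SIZE = PITCH_MAX - PITCH_MIN + 1   # 25
LOWEST_KEY = 21                          # MIDI number of A0; C8 is 108 = 21 + 87

def melody_vector(midi_pitch):
    """One-hot encoding over the 88 piano keys."""
    v = [0] * 88
    v[midi_pitch - LOWEST_KEY] = 1
    return v

def melody_chord_vector(chord_pitches, midi_pitch):
    """Chord part first (multi-hot), melody part last (one-hot), as in X_i above."""
    chord = [0] * CHORD_SIZE
    for p in chord_pitches:
        chord[p - PITCH_MIN] = 1
    return chord + melody_vector(midi_pitch)

x = melody_chord_vector([60, 64, 67], 76)   # C major triad under E5
print(len(x), x.index(1, CHORD_SIZE) - CHORD_SIZE + LOWEST_KEY)  # 113 76
```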

The training data consists of sequences of melody-chord-vectors $X$ along with the corresponding chords $y$. The model is optimized with RMSprop using binary cross-entropy loss. For a batch of $n$ chords, each with $m = P_{max} - P_{min} + 1$ components, the loss is the mean of the binary cross-entropy over all notes, hence:

$$
J(y, \hat{y}) = -\frac{1}{mn}\sum_{i = 1}^{n} \sum_{j = 1}^m \left( y_{i,j} \log \sigma (\hat{y}_{i,j}) + \left( 1 - y_{i,j} \right) \log \left( 1 - \sigma (\hat{y}_{i,j}) \right) \right)
$$

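where $\hat{y} = f(X)$ are the logits returned by the network in training mode and $\sigma$ is the sigmoid function. This is what `BCEWithLogitsLoss` computes with its default mean reduction.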
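As a quick numerical check of the formula, the following sketch compares it with `nn.BCEWithLogitsLoss` on a small hand-made batch (the logits and targets are arbitrary illustrative values).

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5], [-0.3, 0.0, 3.0]])  # y-hat, shape (n, m) = (2, 3)
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])   # y

# J(y, y-hat): mean over all n * m entries of the per-note binary cross-entropy
manual = -(targets * torch.log(torch.sigmoid(logits))
           + (1 - targets) * torch.log(1 - torch.sigmoid(logits))).mean()
builtin = nn.BCEWithLogitsLoss()(logits, targets)
print(torch.allclose(manual, builtin))  # True
```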

In [None]:
# Set hyperparameters
epochs = 500
network_layers = [[100], [30], pitch_max - pitch_min + 1]
batch_size = 128
lr = 1e-3
split_ratio = 0.9
update_freq = 10 # print loss information every update_freq epochs

# Optimizer and loss function
device = torch.device("mps" if torch.backends.mps.is_available()
                      else "cuda" if torch.cuda.is_available() else "cpu")
net = RNN(network_layers).to(device)
net.train()
best_net = RNN(network_layers).to(device)
bce_loss = nn.BCEWithLogitsLoss()
optimizer = optim.RMSprop(net.parameters(), lr = lr)

# Prepare dataset
dataset = data.TensorDataset(X, y)
training_dataset, test_dataset = data.random_split(dataset, [split_ratio, 1 - split_ratio])
trainloader = data.DataLoader(training_dataset, shuffle = True, batch_size = batch_size)
testloader = data.DataLoader(test_dataset, shuffle = True, batch_size = batch_size)

train_losses = []
test_losses = []
for epoch in range(epochs):
    train_loss = 0
    test_loss = 0
    for i, data_list in enumerate(trainloader):
        X_batch = data_list[0]
        y_batch = data_list[1]
        optimizer.zero_grad()
        output = net(X_batch.to(device))
        target = y_batch.to(device)
        loss = bce_loss(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= i + 1
    with torch.no_grad(): # no gradients are needed for evaluation
        for i, data_list in enumerate(testloader):
            X_batch = data_list[0]
            y_batch = data_list[1]
            output = net(X_batch.to(device))
            target = y_batch.to(device)
            loss = bce_loss(output, target)
            test_loss += loss.item()
    test_loss /= i + 1
    if epoch == 0 or test_loss < best_test_loss:
        # keep a copy of the best network so far, including after the first epoch
        best_test_loss = test_loss
        best_net_epoch = epoch
        best_net.load_state_dict(net.state_dict())
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    if epoch % update_freq == 0:
        m = train_losses[-1 * update_freq :]
        n = test_losses[-1 * update_freq :]
        print(f"Epoch {epoch}, average train loss: {sum(m)/len(m):.4f}, average test loss: {sum(n)/len(n):.4f}")
print(f"Training finished! Best performance found at epoch {best_net_epoch} with loss {best_test_loss:.4f}.")

### Model Evaluation

The cell below plots the training and test losses of the model, with a marker at the epoch of the best network (lowest test loss).

In [None]:
plt.plot(train_losses, label = "Training loss")
plt.plot(test_losses, label = "Test loss")
plt.plot(test_losses, label = "Best net", markevery=[best_net_epoch], ls="", marker="o")
plt.legend()
plt.show()

The cell below saves the model. Note that this saved model can be immediately accessed by the Flask application and hence used when testing the model interactively.

In [None]:
torch.save({"structure": network_layers,
            "model": best_net.cpu().state_dict(),
            "range": [pitch_min, pitch_max]},
            'model.pt')

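The cell below loads an existing model, e.g. to generate chords with it. It can also serve as a checkpoint for further training, but note that the training cell above creates a fresh network and optimizer, so it would have to be adapted to start from the loaded weights.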

In [None]:
loaded_model = torch.load("model.pt")
net = RNN(loaded_model["structure"])
net.load_state_dict(loaded_model["model"])

The cell below extracts the melody from a specified MIDI-file and applies the model to it. The generated chords along with the corresponding melody notes are written to a new file.

In [None]:
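file_path = "" # MIDI-file containing the melody to harmonize
net = net.to(torch.device("cpu"))
net.test() # return probabilities, which are thresholded at 0.5
print("Extracting melody...")
offsets, melody_vectors = melody_from_midi(file_path, 0)
print("Generating chords...")
harmonized = generate_chords_from_melody(net, melody_vectors, offsets, 8) # 8 must match the sequence length used in training
print("Writing to file...")
harmonized.write("midi", "harmonized_melody.mid")
print("Finished! File saved as \"harmonized_melody.mid\".")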