# Prediction_continual
This notebook details the pipeline for continual next-chord prediction.

In [1]:
# Notebook setup: render matplotlib figures inline, and enable autoreload so
# that edits to imported project modules are picked up without restarting
# the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Load the data 
When loading the chord dataset, we can choose whether to keep sections in major or minor key, or both.

In [2]:
from load_data import load_train_test_sentences, all_composers

In [3]:
# Composers used for training, and the composer(s) used for testing.
# NOTE(review): `composers` is the full composer list; confirm that
# load_train_test_sentences keeps `test_composers` out of the training split.
composers = all_composers
test_composers = ['Pleyel']

# key_mode='MAJOR' selects the major-key option described above; the third
# return value is not used in this notebook.
train_sentences, test_sentences, _ = load_train_test_sentences(
    composers,
    test_composers,
    key_mode='MAJOR',
)
# Sanity check: number of test sentences
print(len(test_sentences))

11


## Apply Word2Vec
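Train a Word2Vec embedding on the training sentences. The hyperparameters to choose are the minimum chord frequency, the embedding size, the context window, the architecture (CBOW or skip-gram) and the number of epochs.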

In [4]:
from gensim.models import Word2Vec
from load_data import get_chord_sentences

In [5]:
# Ignore words (chords) with a lower frequency than this
min_count = 50
# Size of the embedding space
size = 5
# Neighborhood of the focus word to study
window = 2
# 0 for CBOW, 1 for skip-gram
sg = 1
# Number of iterations (epochs). Deliberately not named `iter`: a global with
# that name would shadow Python's built-in iter() in every later cell.
w2v_epochs = 500

# The first argument has to be a list of lists of words.
# NOTE: `size` and `iter` are the gensim 3.x keyword names; gensim >= 4
# renamed them to `vector_size` and `epochs`.
w2v_model = Word2Vec(
    train_sentences,
    min_count=min_count,
    size=size,
    window=window,
    sg=sg,
    iter=w2v_epochs,
)

In [6]:
# Show the vocabulary kept by the Word2Vec model (it was trained with the
# `min_count` frequency threshold set above).
w2v_model.wv.vocab.keys()

dict_keys(['I:MAJ', 'V:MAJ', 'IV:MAJ', '#IV:DIM', 'II:MAJ', 'VI:MIN', 'bVII:MAJ', 'VII:DIM', 'III:MAJ', 'VI:MAJ', 'II:MIN', '#I:DIM', 'V:MIN', 'III:DIM', 'II:DIM', 'IV:MIN', '#V:DIM', 'VII:MAJ', 'III:MIN', 'I:MIN', 'VII:MIN', 'bIII:MAJ', 'bVI:MAJ', '#II:DIM', 'VI:DIM', 'I:AUG', 'V:AUG', 'bVII:MIN', 'bII:MAJ', '#IV:MAJ', 'V:DIM', '#I:MAJ'])

## Predict
Train the LSTM predictor on the same dataset as the Word2Vec model, then evaluate it on the test dataset.

In [7]:
from lstm_continual import LSTMPredictor
import torch
import torch.nn as nn
import torch.optim as optim

### Train the predictor

In [8]:
# Build the predictor from the trained Word2Vec model and create an Adam
# optimiser over its parameters.
lstm_predictor = LSTMPredictor(w2v_model, hidden_dim=4)
optimiser = torch.optim.Adam(params=lstm_predictor.parameters(), lr=1e-3)

# The training method is named `learn` because the name `train` was already
# taken. The final argument (2) is presumably the number of epochs -- the log
# reports epochs 0 and 1 -- TODO confirm against lstm_continual.
# Expect this cell to run for a couple of minutes.
lstm_predictor.learn(train_sentences, optimiser, 2)

Starting epoch 0


  embed = torch.tensor(self.wv[chord])


Iteration 5000 : average loss = 2.5826123503684997
Iteration 10000 : average loss = 2.659360311770439
Iteration 15000 : average loss = 2.4755226123332976
Iteration 20000 : average loss = 2.308939964568615
Iteration 25000 : average loss = 2.393791701543331
Iteration 30000 : average loss = 2.108278911948204
Iteration 35000 : average loss = 1.7450329397380353
Iteration 40000 : average loss = 1.948295321702957
Iteration 45000 : average loss = 2.162517180353403
Iteration 50000 : average loss = 2.523907002145052
Iteration 55000 : average loss = 1.9201798030078412
Iteration 60000 : average loss = 1.6977726721227169
Iteration 65000 : average loss = 2.1149688821807504
Closing epoch 0 

Starting epoch 1
Iteration 5000 : average loss = 2.299036927694082
Iteration 10000 : average loss = 2.607781600686908
Iteration 15000 : average loss = 2.3695591732859613
Iteration 20000 : average loss = 2.0840927137702705
Iteration 25000 : average loss = 2.243418509307504
Iteration 30000 : average loss = 1.932399

### Test the predictor

In [9]:
# Evaluate on the test sentences. The result unpacks into the overall
# accuracy, a per-chord accuracy mapping and a per-chord occurrence mapping.
evaluation = lstm_predictor.test(test_sentences)
accuracy_total, accuracy_by_chord, occurrences_by_chord = evaluation

# Print each metric under its heading, in a fixed order.
report = (
    ('Total accuracy:', accuracy_total),
    ('Accuracy by chord\n', accuracy_by_chord),
    ('Occurrences by chord\n', occurrences_by_chord),
)
for heading, value in report:
    print(heading, value)

Total accuracy: 0.6432160804020101
Accuracy by chord
 {'V:MAJ': 0.9685863874345549, 'I:MAJ': 0.8796296296296297, 'IV:MAJ': 0.034482758620689655, 'II:MIN': 0.0, 'VI:MIN': 0.2916666666666667, 'VII:DIM': 0.0, 'III:MIN': 0.0, 'I:MIN': 0.0, 'VI:MAJ': 0.0, '#IV:DIM': 0.0, 'II:MAJ': 0.0, 'IV:MIN': 0.0, 'III:MAJ': 0.0, '#II:DIM': 0.0, '#I:DIM': 0.0, 'I:AUG': 0.0, '#V:DIM': 0.0, 'VI:DIM': 0.0, 'III:DIM': 0.0, 'II:DIM': 0.0}
Occurrences by chord
 {'V:MAJ': 191, 'I:MAJ': 216, 'IV:MAJ': 58, 'II:MIN': 39, 'VI:MIN': 24, 'VII:DIM': 20, 'III:MIN': 2, 'I:MIN': 4, 'VI:MAJ': 7, '#IV:DIM': 6, 'II:MAJ': 5, 'IV:MIN': 1, 'III:MAJ': 4, '#II:DIM': 3, '#I:DIM': 8, 'I:AUG': 2, '#V:DIM': 4, 'VI:DIM': 1, 'III:DIM': 1, 'II:DIM': 1}
