In [1]:
from abc_utils import *
import pandas as pd
from hmmlearn import hmm
import numpy as np
from sklearn.metrics import accuracy_score

In [2]:
train_set, train_lengths, val_set, _ = load_datasets()

In [3]:
def load_song(index):
    start_pos = train_lengths.iloc[:index].sum().item()
    return train_set.iloc[start_pos : start_pos + train_lengths.iloc[index].item()]

In [12]:
def load_songs(indices):
    end_positions = np.cumsum(train_lengths)
    positions = np.insert(end_positions, 0, np.array([0]))
    songs = []
    for i in indices:
        song = train_set.iloc[positions[i] : positions[i+1]]
        songs.append(song)
    # return songs and lengths
    return pd.concat(songs), train_lengths.iloc[indices, :].values.flatten()

total_songs = 100
song_indices = list(range(total_songs))
songs, song_lengths = load_songs(song_indices)

In [13]:
# define global variables for number of chords and number of songs
NUM_CHORDS = 2
NUM_NOTES = 2
true_states, true_observations = dataframe_to_states(
    songs, 
    NUM_CHORDS, 
    NUM_NOTES
)

Processing states: 100%|██████████| 9052/9052 [00:00<00:00, 16562.60it/s]


In [14]:
transition_matrix, emission_probs, unique_states, unique_obs, states_to_index, observation_to_index = states_to_transition(true_states, true_observations)

In [15]:
observation_indices = np.array([int(observation_to_index[(o,)]) for o in true_observations]) 

In [16]:
model = hmm.CategoricalHMM(
    n_components=transition_matrix.shape[0],
    init_params=''
)

In [17]:
model.transmat_ = transition_matrix.T
model.emissionprob_ = emission_probs.T
start_probs = np.ones(transition_matrix.shape[0]) / transition_matrix.shape[0]
model.startprob_ = start_probs

print(model.transmat_.shape, model.emissionprob_.shape, model.startprob_.shape)

(3047, 3047) (3047, 39) (3047,)


Other tools:
- dynamax
  - runs on JAX
  - might have a more flexible structure
  - 

In [18]:
num_songs = 3
likelihood, pred_states = model.decode(observation_indices.reshape(-1, 1)[:np.sum(song_lengths[:num_songs])])
pred_states

# get the actual chords from unique_states
states = unique_states[pred_states, :]
states

def chord_accuracy(full_pred: np.array, true_states: np.array, num_chords: int=NUM_CHORDS, num_notes: int=NUM_NOTES):
    '''
    Given the predicted matrix of states, compute the misclassification rate compared with the true_observations.
    Could be edited in the future to also compute the accuracy of our predicted note sequence.
    '''
    # obtain the actual predicted chords 
    pred_chords = full_pred[:, num_chords-1]
    true_chords = true_states[:len(pred_chords), num_chords-1]
    # obtain the accuracy
    acc = accuracy_score(true_chords, pred_chords)

    return acc

print(chord_accuracy(states, true_states))

0.7676348547717843


In [21]:
songs, song_lengths = load_songs([total_songs + 1])
true_states, true_observations = dataframe_to_states(
    songs, 
    NUM_CHORDS, 
    NUM_NOTES
)
observation_indices = np.array([int(observation_to_index[(o,)]) for o in true_observations]) 
likelihood, pred_states = model.decode(observation_indices.reshape(-1, 1))
states = unique_states[pred_states, :]
print(np.hstack((true_states, states)))
chord_accuracy(states, true_states)

Processing states: 100%|██████████| 57/57 [00:00<00:00, 9367.42it/s]




[[ 0  0  0  0  3  3 70 72]
 [ 0  1  0 70  3  3 72 73]
 [ 1  1 70 75  3  1 73 75]
 [ 1  1 75 77  1  3 75 77]
 [ 1  1 77 75  3  8 77 75]
 [ 1  1 75 72  8  8 75 74]
 [ 1  1 72 70  8  8 74 76]
 [ 1  1 70 79  8  8 76 77]
 [ 1  1 79 79  8  8 77 79]
 [ 1  1 79 77  8  8 79 80]
 [ 1  6 77 79  8  8 80 82]
 [ 6  6 79 80  8  8 82 82]
 [ 6  6 80 80  8  1 82 80]
 [ 6  6 80 79  1  1 80 79]
 [ 6  1 79 80  1  1 79 82]
 [ 1  1 80 82  1  1 82 80]
 [ 1  1 82 79  1  1 80 79]
 [ 1  8 79 82  1  8 79 75]
 [ 8  8 82 80  8  8 75 80]
 [ 8  8 80 77  8  8 80 77]
 [ 8  8 77 74  8  8 77 74]
 [ 8  8 74 70  8  1 74 77]
 [ 8  8 70 79  1  1 77 75]
 [ 8  1 79 77  1  1 75 75]
 [ 1  1 77 75  1  1 75 74]
 [ 1  1 75 75  1  1 74 75]
 [ 1  1 75 75  1  1 75 77]
 [ 1  1 75 75  1  8 77 75]
 [ 1  1 75 74  8  8 75 74]
 [ 1  8 74 75  8  8 74 70]
 [ 8  8 75 77  8  8 70 74]
 [ 8  8 77 77  8  8 74 77]
 [ 8  8 77 77  8  8 77 70]
 [ 8  8 77 74  8  8 70 74]
 [ 8  8 74 70  8  8 74 77]
 [ 8  1 70 70  8  8 77 74]
 [ 1  1 70 70  8  8 74 76]
 

0.5964912280701754