# Key Estimation

In tonal music, the task of key estimation refers to identifying the main tonality of a musical piece. This is a useful first step for automatic harmonic analysis, automatic pitch spelling, etc.

In symbolic music, we usually need to identify the key of a piece from MIDI-like note information (MIDI pitch, onset, duration).

In [None]:
# Let's import some stuff
import os
# This line is a workaround for kernel crashes; comment it out if you do not need it
# See https://stackoverflow.com/a/53014308
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import numpy as np
import partitura as pt
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
from scipy.stats import mode

from typing import Tuple, Union

from key_profiles import (
    build_key_profile_matrix, 
    key_prof_maj_kk,
    key_prof_min_kk,
    key_prof_maj_cbms,
    key_prof_min_cbms,
    key_prof_maj_kp,
    key_prof_min_kp,
    MAJOR_KEYS,
    MINOR_KEYS,
    KEYS,
    PITCH_CLASSES,
)

from partitura.utils.misc import PathLike

%config InlineBackend.figure_format ='retina'

Here are different pitch distributions for major and minor keys:

In [None]:
# Each entry is (major profile, minor profile, display name).
pitch_profiles = [
    (key_prof_maj_kk, key_prof_min_kk, "Krumhansl"),
    (key_prof_maj_cbms, key_prof_min_cbms, "CBMS"),
    (key_prof_maj_kp, key_prof_min_kp, "Kostka-Payne")
]
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(10, 5))

for pmaj, pmin, profile in pitch_profiles:
    # Normalize every profile by its OWN sum so that both panels show
    # proper distributions (the minor profile must not be scaled by the
    # sum of the major profile).
    ax1.plot(pmaj / pmaj.sum(), label=profile)
    ax2.plot(pmin / pmin.sum(), label=profile)
ax1.set_title("Major")
ax2.set_title("Minor")
ax1.set_xticks(ticks=np.arange(12))
ax1.set_xticklabels(MAJOR_KEYS)
ax2.set_xticks(ticks=np.arange(12))
ax2.set_xticklabels(MAJOR_KEYS)
ax1.set_ylabel("Pitch distribution")
# The legend is attached to the current axes (the right-hand panel);
# the labels are identical for both panels.
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

The key profile matrix is a $24 \times 12$ matrix that contains the key profiles for each key. The first 12 rows represent the key profiles for major keys and the last 12 rows represent the profiles for minor keys.

In [None]:
key_profile_matrix = build_key_profile_matrix("kk")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# One heatmap per mode: rows 0-11 of the matrix are drawn in the "Major"
# panel and rows 12-23 in the "Minor" panel.
panels = (
    (ax1, "Major", key_profile_matrix[:12], MAJOR_KEYS),
    (ax2, "Minor", key_profile_matrix[12:], MINOR_KEYS),
)
for panel_ax, panel_title, panel_profiles, panel_keys in panels:
    panel_ax.set_title(panel_title)
    panel_ax.imshow(panel_profiles, aspect="equal", cmap="BuPu")
    panel_ax.set_xticks(range(12))
    panel_ax.set_xticklabels(PITCH_CLASSES)
    panel_ax.set_yticks(range(12))
    panel_ax.set_yticklabels(panel_keys)
    panel_ax.set_xlabel("Pitch class")

# Only the left-hand panel carries a y-axis label.
ax1.set_ylabel("Key")

plt.show()

In [None]:
from hiddenmarkov import HMM, ConstantTransitionModel, ObservationModel

class KeyProfileObservationModel(ObservationModel):
    """
    Observation model that scores a pitch class distribution against
    the key profiles of all keys.

    Parameters
    ----------
    key_profile_matrix : np.ndarray or {'kk','cbms','kp'}
        The key profile matrix. If a string is given, it needs to be
        in {'kk','cbms','kp'} and the matrix is built with
        `build_key_profile_matrix`. Otherwise, a (24, 12) array is expected.

    Raises
    ------
    ValueError
        If an array is given whose shape is not (24, 12).
    TypeError
        If `key_profile_matrix` is neither a string nor a numpy array.
    """
    def __init__(self, key_profile_matrix: Union[str, np.ndarray] = "kp") -> None:
        super().__init__()
        if isinstance(key_profile_matrix, str):
            self.key_profile_matrix = build_key_profile_matrix(key_profile_matrix)
        elif isinstance(key_profile_matrix, np.ndarray):
            # Explicit check instead of `assert`, which is stripped under -O.
            if key_profile_matrix.shape != (24, 12):
                raise ValueError(
                    "`key_profile_matrix` must have shape (24, 12), "
                    f"but has shape {key_profile_matrix.shape}"
                )
            self.key_profile_matrix = key_profile_matrix
        else:
            # Fail early: otherwise the attribute would silently stay unset
            # and the error would only surface on the first call.
            raise TypeError(
                "`key_profile_matrix` must be a string or a numpy array, "
                f"but is {type(key_profile_matrix)}"
            )

    def __call__(self, observation: np.ndarray) -> np.ndarray:
        """
        Give the likelihood of the observed pitch class distribution given the keys.

        For each key profile `kp` (a row of the key profile matrix), the
        likelihood is the product over pitch classes of
        `kp ** observation * (1 - kp) ** (1 - observation)`.

        Parameters
        ----------
        observation : np.ndarray
            A 12-dimensional vector representing the pitch class distribution

        Returns
        -------
        likelihood: np.ndarray
            A vector with one entry per key (row of the key profile matrix)
            representing the likelihood of the observed pitch class
            distribution given the keys. If `self.use_log_probabilities`
            is True, this is the log-likelihood, otherwise it is the actual
            probabilities.
        """
        profiles = np.asarray(self.key_profile_matrix)

        if self.use_log_probabilities:
            # A small constant keeps the logarithm finite for zero entries
            # in the key profiles.
            eps = 1e-10
            return np.sum(
                observation * np.log(profiles + eps)
                + np.log1p(-(profiles + eps)) * (1 - observation),
                axis=1,
            )

        return np.prod(
            (profiles ** observation) * (1 - profiles) ** (1 - observation),
            axis=1,
        )
            
            
# Observation model built with the default key profiles ("kp")
observation_model = KeyProfileObservationModel()

In [None]:
def compute_transition_probabilities(
    inertia_param: float = 0.8,
    n_keys: int = 24,
) -> np.ndarray:
    """
    Matrix of transition probabilities

    The probability of staying in the same key is `inertia_param`; the
    remaining probability mass is shared equally among all other keys.

    Parameters
    ----------
    inertia_param : float
        The probability of staying in the same key. This number must be
        between 0 and 1
    n_keys : int
        Number of keys (states). Defaults to 24 (12 major and 12 minor keys).

    Returns
    -------
    A : np.ndarray
        The modulation transition probabilities, an array of shape
        (n_keys, n_keys). The i,j-th element of this matrix correspond
        to the probability of modulating from key i to key j
        (the indices correspond to the keys in the matrix)

    Raises
    ------
    ValueError
        If `inertia_param` is not in [0, 1] or `n_keys` is smaller than 2.
    """
    if not 0 <= inertia_param <= 1:
        raise ValueError(
            f"`inertia_param` must be between 0 and 1, but is {inertia_param}"
        )
    if n_keys < 2:
        raise ValueError(f"`n_keys` must be at least 2, but is {n_keys}")

    # Probability of moving to any one of the other keys
    modulation_prob = (1 - inertia_param) / (n_keys - 1)

    A = np.full((n_keys, n_keys), modulation_prob, dtype=float)
    np.fill_diagonal(A, inertia_param)

    return A

# How likely we are to stay in the same key
inertia_param = 0.8
# Pass the parameter explicitly so that changing the value above
# actually changes the transition matrix.
transition_probabilities = compute_transition_probabilities(inertia_param=inertia_param)
key_profile_matrix = build_key_profile_matrix("kp")


# Visualize the transition matrix; both axes are labeled with the key names
fig, ax = plt.subplots(figsize=(10,10))
ax.matshow(transition_probabilities, aspect="equal")
ax.set_xticks(range(24))
ax.set_xticklabels(KEYS)
ax.set_yticks(range(24))
ax.set_yticklabels(KEYS)


transition_model = ConstantTransitionModel(transition_probabilities)

In [None]:
# Combine the observation model and the transition model into an HMM
hmm = HMM(observation_model=observation_model, transition_model=transition_model)

In [None]:
midi_fn = "example_data/03.mid"
ppart = pt.load_performance_midi(midi_fn)
note_array = ppart.note_array()

# Resolution of the piano roll (cells per second)
time_div = 16
piano_roll = pt.utils.compute_pianoroll(ppart, time_div=time_div).toarray()
# in seconds
win_size = 2

plt.imshow(piano_roll, origin="lower", aspect="auto", cmap="gray")

# Window size in piano roll cells. It is used as a slice bound, so it has
# to be an integer (and at least one cell wide).
window_size = max(1, int(round(win_size * time_div)))

n_windows = int(np.ceil(piano_roll.shape[1] / window_size))

# Pitch class of every row of the piano roll
pitch_classes = np.arange(piano_roll.shape[0]) % 12

# One pitch class distribution per (non-overlapping) window
observations = np.zeros((n_windows, 12))
for win in range(n_windows):
    idx = slice(win * window_size, (win + 1) * window_size)
    segment = piano_roll[:, idx].sum(1)
    # Accumulate the activation of all rows that share a pitch class
    dist = np.bincount(pitch_classes, weights=segment, minlength=12)
    total = dist.sum()
    # Windows without any sounding note are left as all zeros;
    # dividing by zero would fill the observation with NaN.
    if total > 0:
        dist /= total
    observations[win] = dist

path, log_lik = hmm.find_best_sequence(observations, log_probabilities=True)

# The estimated key is the most frequent state in the decoded path.
# `np.atleast_1d` makes this work whether `mode` returns a scalar or an array.
key_idx = int(np.atleast_1d(mode(path).mode)[0])

key = KEYS[key_idx]

print(f"The key is {key}")

Let's put everything together!

In [None]:
def key_identification(
    filename: PathLike, 
    key_profiles: Union[str, np.ndarray] = "kp", 
    inertia_param: float = 0.8, 
    piano_roll_resolution: int = 16, 
    win_size: float = 2) -> Tuple[str, float]:
    """
    Probabilistic Key Identification

    The piece is split into non-overlapping windows, a pitch class
    distribution is computed for each window, and an HMM is used to decode
    the most likely sequence of keys. The estimated key is the most
    frequent key in that sequence.

    Parameters
    ----------
    filename : PathLike
        MIDI file
    key_profiles: {"kp", "kk", "cbms"} or np.ndarray.
        Key profiles to use in the KeyProfileObservationModel 
        (see definition in `key_profiles.py`)
    inertia_param: float
        Parameter between 0 and 1 indicating how likely it is that we 
        will stay on the same key.
    piano_roll_resolution: int
        Resolution of the piano roll (how many cells per second).
    win_size: float
        Window size in seconds.

    Returns
    -------
    key : str
        The estimated key of the piece
    log_lik:
        The log-likelihood returned by the HMM together with the
        best sequence of keys
    """
    # build observation model
    observation_model = KeyProfileObservationModel(key_profile_matrix=key_profiles)

    # Compute transition model
    transition_probabilities = compute_transition_probabilities(inertia_param=inertia_param)
    transition_model = ConstantTransitionModel(transition_probabilities)

    hmm = HMM(observation_model=observation_model,
              transition_model=transition_model)
    # Load performance
    ppart = pt.load_performance_midi(filename)

    # Compute piano roll
    piano_roll = pt.utils.compute_pianoroll(ppart, time_div=piano_roll_resolution).toarray()

    # Window size in cells. Use the resolution passed to this function
    # (not a global variable) and make sure it is an integer of at least
    # one cell, since it is used as a slice bound.
    window_size = max(1, int(round(win_size * piano_roll_resolution)))

    # Number of windows in the piano roll
    n_windows = int(np.ceil(piano_roll.shape[1] / window_size))

    # Pitch class of every row of the piano roll
    pitch_classes = np.arange(piano_roll.shape[0]) % 12

    # Construct observations (these are non-overlapping windows, but you can test other possibilities)
    observations = np.zeros((n_windows, 12))
    for win in range(n_windows):
        idx = slice(win * window_size, (win + 1) * window_size)
        segment = piano_roll[:, idx].sum(1)
        # Accumulate the activation of all rows that share a pitch class
        dist = np.bincount(pitch_classes, weights=segment, minlength=12)
        # Normalize pitch class distribution. Windows without any sounding
        # note are left as all zeros; dividing by zero would produce NaN.
        total = dist.sum()
        if total > 0:
            dist /= total
        observations[win] = dist

    # Compute the sequence (explicitly in the log domain, since this
    # function is documented to return a log-likelihood)
    path, log_lik = hmm.find_best_sequence(observations, log_probabilities=True)

    # The estimated key is the most frequent state in the decoded path.
    # `np.atleast_1d` makes this work whether `mode` returns a scalar or an array.
    key_idx = int(np.atleast_1d(mode(path).mode)[0])

    key = KEYS[key_idx]

    return key, log_lik