# Hidden Markov model 

In [23]:
from miditoolkit import MidiFile, Instrument, Note

import os
import json
from glob import glob
from collections import Counter, defaultdict

from hmmlearn import hmm
import numpy as np
import random

# MIDI file characteristics

In [24]:
# Sort the listing: glob() order depends on the filesystem, and later cells index into
# this list (e.g. mid_fps[10], mid_fps[0]), so sorting makes those examples reproducible.
mid_fps = sorted(glob('nesmdb_midi/train/*'))
print("Number of train files", len(mid_fps))

Number of train files 4502


In [25]:
# inspect one file: header summary, then every track and its first five notes
midi = MidiFile(mid_fps[10])
print(midi, '\n')
for instrument in midi.instruments:
    print(instrument)
    print(instrument.notes[:5], '\n')

ticks per beat: 22050
max tick: 528334
tempo changes: 1
time sig: 2
key sig: 0
markers: 0
lyrics: False
instruments: 4 

Instrument(program=80, is_drum=False, name=p1) - 61 notes
[Note(velocity=5, pitch=50, start=21, end=2926), Note(velocity=5, pitch=52, start=4402, end=7330), Note(velocity=5, pitch=53, start=8804, end=11719), Note(velocity=5, pitch=52, start=13195, end=16123), Note(velocity=5, pitch=43, start=17603, end=20525)] 

Instrument(program=81, is_drum=False, name=p2) - 79 notes
[Note(velocity=15, pitch=57, start=1, end=460), Note(velocity=15, pitch=50, start=460, end=919), Note(velocity=15, pitch=43, start=919, end=1379), Note(velocity=15, pitch=36, start=1379, end=1838), Note(velocity=9, pitch=64, start=17585, end=17916)] 

Instrument(program=38, is_drum=False, name=tr) - 62 notes
[Note(velocity=1, pitch=62, start=8, end=2921), Note(velocity=1, pitch=64, start=4393, end=7325), Note(velocity=1, pitch=65, start=8795, end=11715), Note(velocity=1, pitch=64, start=13186, end=1611

In [26]:
# tpb = []
# ts = []
# for fp in mid_fps:
#     midi = MidiFile(fp)
#     tpb.append(midi.ticks_per_beat)

In [27]:
# np.unique(tpb)
# array([22050])

In [28]:
# velocities = []
# for fp in mid_fps:
#     midi = MidiFile(fp)
#     for instrument in midi.instruments:
#         velocities.append([note.velocity for note in instrument.notes])

In [29]:
# vflat = [v for vv in velocities for v in vv]

In [30]:
# np.unique(vflat)
# array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

# Hidden Markov Model

In [None]:
# NES-MDB channel name -> General MIDI program number (0-indexed, as miditoolkit stores it):
#   p1 -> 80  Lead 1 (square)    : NES pulse 1
#   p2 -> 81  Lead 2 (sawtooth)  : NES pulse 2
#   tr -> 38  Synth Bass 1       : NES triangle
#   no -> 121 Breath Noise       : NES noise
NAME_TO_PROGRAM = dict(p1=80, p2=81, tr=38, no=121)

# canonical channel order, taken from the mapping above so the two never drift apart
INSTRUMENTS = list(NAME_TO_PROGRAM)

In [None]:
# model variables
# every training file uses this resolution (see the np.unique(tpb) check above -> array([22050]))
TICKS_PER_BEAT = 22050
# Quantization bins for note and rest durations, used to shrink the observation space.
# Values are in beats: MIDI ticks_per_beat counts ticks per quarter note, so
# 0.25 = sixteenth note, 0.5 = eighth, 1.0 = quarter, 4.0 = whole note (in 4/4), etc.
# The bin set was chosen by hand.
DURATION_BEATS = [0.25, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0]
# int() truncates, e.g. 0.25 beat -> 5512 ticks
DURATION_VALUES = [int(TICKS_PER_BEAT * d) for d in DURATION_BEATS]
# hidden state count
HMM_COMPONENTS = 16
# number of tokens (notes and rests) sampled per instrument
GEN_LENGTH = 400

In [None]:
# quantize note durations as specified bins in DURATION_VALUES
def quantize_duration(duration):
    """Snap a duration in ticks to the closest entry of DURATION_VALUES.

    On a tie the earliest entry of DURATION_VALUES wins. An empty
    DURATION_VALUES raises ValueError.
    """
    gaps = [abs(duration - bin_ticks) for bin_ticks in DURATION_VALUES]
    return DURATION_VALUES[gaps.index(min(gaps))]

In [None]:
# quantize note velocities to downsize observation space
# also reduces jarring transitions between volume levels
# the original velocities are within: array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
def quantize_velocity(velocity):
    """Map a velocity onto a coarse level.

    0 stays 0 (used for rests). Other values map by band:
    below 4 -> 2, [4, 8) -> 4, [8, 12) -> 6, [12, 16) -> 8, 16 and above -> 10.
    NOTE(review): with the dataset's 1-15 range the output never exceeds 8,
    so regenerated MIDI is quieter than the source material.
    """
    if velocity == 0:
        return 0
    for upper_bound, level in ((4, 2), (8, 4), (12, 6), (16, 8)):
        if velocity < upper_bound:
            return level
    return 10

In [None]:
# note in format (pitch, duration, velocity)
# return a dictionary of sequences per instrument
# accounts for REST durations
def extract_quantized_notes(midi_path):
    """Tokenize one MIDI file into quantized (pitch, duration, velocity) tuples per track.

    Each track is walked in onset order. A gap between the end of the material
    played so far and the next onset becomes a REST token ``(0, quantized_gap, 0)``.
    Every note becomes ``(pitch, quantized_duration, quantized_velocity)``.
    NOTE(review): pitch 0 is also a legal MIDI pitch and would be indistinguishable
    from the REST sentinel.

    Args:
        midi_path: path to a MIDI file readable by ``MidiFile``.

    Returns:
        dict mapping track name -> list of token tuples. Every name in
        INSTRUMENTS is present (an empty list means the file has no such track).
        A track whose name is not in INSTRUMENTS gets its own key instead of
        raising KeyError.
    """
    midi = MidiFile(midi_path)
    notes_by_instr = {instr: [] for instr in INSTRUMENTS}
    for instrument in midi.instruments:
        tokens = notes_by_instr.setdefault(instrument.name, [])
        current_time = 0  # latest end tick among the notes seen so far on this track
        for note in sorted(instrument.notes, key=lambda n: n.start):
            if note.start > current_time:
                # quantize rest and insert
                tokens.append((0, quantize_duration(note.start - current_time), 0))  # REST
            tokens.append((
                note.pitch,
                quantize_duration(note.end - note.start),
                quantize_velocity(note.velocity),
            ))
            # An overlapping note may end before an earlier one does. Never move the
            # cursor backwards, or a spurious rest would be emitted for time that is
            # still covered by the earlier note.
            current_time = max(current_time, note.end)
    return notes_by_instr

In [None]:
# example output of quantized format: (pitch, duration_ticks, velocity) tuples for the
# first file; tuples with pitch 0 and velocity 0 are rests
extract_quantized_notes(midi_path=mid_fps[0])

{'p1': [(0, 5512, 0),
  (65, 16537, 8),
  (69, 5512, 6),
  (65, 5512, 8),
  (69, 5512, 8),
  (70, 16537, 8),
  (74, 5512, 6),
  (70, 5512, 8),
  (74, 5512, 8),
  (77, 22050, 8)],
 'p2': [],
 'tr': [(0, 5512, 0),
  (62, 16537, 2),
  (65, 5512, 2),
  (62, 5512, 2),
  (65, 5512, 2),
  (67, 16537, 2),
  (70, 5512, 2),
  (67, 5512, 2),
  (70, 5512, 2),
  (74, 22050, 2)],
 'no': []}

In [None]:
# shows that the correct instruments are identified
# unused instruments are indicated by an empty sequence
# (compare with the extracted tokens above: this file only has p1 and tr tracks,
#  and those are exactly the two non-empty sequences)
midi = MidiFile(mid_fps[0])
midi.instruments

[Instrument(program=80, is_drum=False, name=p1) - 9 notes,
 Instrument(program=38, is_drum=False, name=tr) - 9 notes]

In [None]:
# enumerate observations
def build_vocab(flat_notes):
    """Assign each distinct token a dense integer id, in sorted token order.

    Returns:
        (vocab, inv_vocab): token -> id, and id -> token.
    """
    ordered_tokens = sorted(set(flat_notes))
    inv_vocab = dict(enumerate(ordered_tokens))
    vocab = {token: idx for idx, token in inv_vocab.items()}
    return vocab, inv_vocab

# convert token to value
def encode(notes, vocab):
    """Map tokens to vocabulary ids as an integer column vector.

    Tokens missing from ``vocab`` are silently dropped, so the result can be
    shorter than ``notes``.

    Returns:
        int array of shape (n_known_tokens, 1), the single-feature column layout
        that hmmlearn's CategoricalHMM consumes.
    """
    # dtype=int keeps an empty result integral; a bare np.array([]) is float64 and
    # would upcast the concatenated training matrix.
    return np.array([vocab[n] for n in notes if n in vocab], dtype=int).reshape(-1, 1)

# convert value to token
def decode(indices, inv_vocab):
    """Turn rows of ids (e.g. an (n, 1) array from model.sample) back into tokens.

    Only the first entry of each row is looked up.
    """
    tokens = []
    for row in indices:
        tokens.append(inv_vocab[row[0]])
    return tokens

# run the HMM
def train_hmm(encoded_sequences, n_components=HMM_COMPONENTS, n_iter=100, random_state=17):
    """Fit a categorical HMM on a collection of encoded songs.

    Args:
        encoded_sequences: list of integer arrays shaped (n_i, 1), one per song,
            as produced by ``encode``.
        n_components: number of hidden states.
        n_iter: maximum number of EM iterations.
        random_state: seed handed to hmmlearn for reproducible initialization.

    Returns:
        The fitted ``hmm.CategoricalHMM``.

    Raises:
        ValueError: if ``encoded_sequences`` is empty.
    """
    if len(encoded_sequences) == 0:
        raise ValueError("train_hmm needs at least one encoded sequence")
    # hmmlearn takes all sequences stacked into one array plus the length of each
    lengths = [len(seq) for seq in encoded_sequences]
    X = np.concatenate(encoded_sequences)
    model = hmm.CategoricalHMM(n_components=n_components, n_iter=n_iter, random_state=random_state)
    model.fit(X, lengths)
    return model


In [39]:
# write the decoded multi-instrument sequences to a MIDI file at the given output path
def save_midi(note_sequences, output_path):
    """Write decoded per-instrument token sequences to a MIDI file.

    Args:
        note_sequences: dict mapping an instrument name (a key of NAME_TO_PROGRAM)
            to a list of (pitch, duration, velocity) tuples; pitch 0 marks a rest.
        output_path: destination .mid path. Missing parent directories are created.

    Returns:
        The ``MidiFile`` that was written.
    """
    midi = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
    for instr, notes in note_sequences.items():
        inst = Instrument(program=NAME_TO_PROGRAM[instr], is_drum=False, name=instr)
        time = 0
        for pitch, duration, velocity in notes:
            # Tokens are laid out back to back: every token (note or rest) advances
            # the cursor by its duration, and a rest simply emits no Note.
            if pitch != 0:
                inst.notes.append(Note(velocity=velocity, pitch=pitch, start=time, end=time + duration))
            time += duration
        midi.instruments.append(inst)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        # writing into a directory that does not exist yet would raise
        os.makedirs(out_dir, exist_ok=True)
    midi.dump(output_path)
    return midi

In [None]:
# load dataset
# Half of the training files are used. The listing is sorted (glob order depends on
# the filesystem) and sampled with a dedicated seeded RNG, so the subset is
# reproducible and the global `random` state is left untouched.
DATA_SAMPLE_SEED = 17
all_instr_notes = {instr: [] for instr in INSTRUMENTS}

mid_fps = sorted(glob('nesmdb_midi/train/*'))
mid_fps = random.Random(DATA_SAMPLE_SEED).sample(mid_fps, len(mid_fps) // 2)

# extract all sequences per midi file; instruments absent from a file are skipped
for file in mid_fps:
    notes = extract_quantized_notes(file)
    for instr in INSTRUMENTS:
        if len(notes[instr]) > 0:
            all_instr_notes[instr].append(notes[instr])

In [41]:
# Persist the tokenized training sequences as JSON.
# JSON has no tuple type, so each (pitch, duration, velocity) token is stored as a list.
with open("train_instr.json", mode="w") as fp:
    json.dump(obj=all_instr_notes, fp=fp)

In [None]:
# Train HMMs

# plain dict: defaultdict() without a default_factory behaves like a dict but suggests otherwise
models = {}

for instr in INSTRUMENTS:
    print(f"Training HMM for {instr}")

    # flatten the list of sequences per instrument to build its observation vocabulary
    flat_notes = [note for song in all_instr_notes[instr] for note in song]
    vocab, inv_vocab = build_vocab(flat_notes)

    # encode each song as an (n_tokens, 1) column of vocabulary ids
    encoded_sequences = [encode(song, vocab) for song in all_instr_notes[instr] if len(song) > 0]
    model = train_hmm(encoded_sequences)
    models[instr] = model

Training HMM for p1
Training HMM for p2
Training HMM for tr
Training HMM for no


# Sequence generation and log-likelihood

In [55]:
# generate notes from trained models

# plain dict: defaultdict() without a default_factory behaves like a dict but suggests otherwise
generated_tracks = {}

for instr in INSTRUMENTS:
    print(f"Generating notes for {instr}")
    # Rebuild the vocabulary used for training. build_vocab sorts the token set, so
    # the same all_instr_notes yields the same id assignment.
    flat_notes = [note for song in all_instr_notes[instr] for note in song]
    vocab, inv_vocab = build_vocab(flat_notes)
    model = models[instr]
    # sample() returns (observations, hidden_states); keep the observation column
    generated_encoded = model.sample(GEN_LENGTH)[0]
    # NOTE: this scores the sample under the same model that produced it, so it shows
    # how typical the sample is for the model, not how close it is to real music.
    log_prob = model.score(generated_encoded)
    print("Normalized log-likelihood:", log_prob / len(generated_encoded))
    generated_notes = decode(generated_encoded, inv_vocab)
    generated_tracks[instr] = generated_notes

Generating notes for p1
Normalized log-likelihood: -4.697758443595988
Generating notes for p2
Normalized log-likelihood: -4.227946109643358
Generating notes for tr
Normalized log-likelihood: -2.230846940285682
Generating notes for no
Normalized log-likelihood: -2.3744022937124085


In [56]:
# Write the generated tracks. Bump the index to keep earlier samples: the KL section
# below aggregates every outputs/hmm_*.mid file.
output_path = 'outputs/hmm_14.mid'
out = save_midi(note_sequences=generated_tracks, output_path=output_path)
print("MIDI saved to: " + output_path)

MIDI saved to: outputs/hmm_14.mid


# KL Divergence (pitch)

In [125]:
from scipy.stats import entropy

In [146]:
def extract_pitch_histograms_per_instrument(midi_paths):
    """Aggregate a normalized pitch distribution per track name over many MIDI files.

    For each file, every non-empty track's note pitches are counted into 128 integer
    bins (edges 0..128). Counts are summed across files per track name. A tiny
    epsilon (1e-8) is added to every bin before normalizing, so no bin is exactly
    zero and the KL divergence stays finite.

    Args:
        midi_paths: iterable of MIDI file paths.

    Returns:
        dict mapping track name -> float array of 128 probabilities, keyed in order
        of first appearance. Names whose tracks never contain a note are absent.
    """
    edges = np.arange(129)
    counts_by_track = defaultdict(list)

    for midi_path in midi_paths:
        for track in MidiFile(midi_path).instruments:
            track_pitches = [n.pitch for n in track.notes]
            if not track_pitches:
                continue
            counts, _ = np.histogram(track_pitches, bins=edges, density=False)
            counts_by_track[track.name].append(counts)

    distributions = {}
    for name, count_list in counts_by_track.items():
        smoothed = np.sum(count_list, axis=0) + 1e-8
        distributions[name] = smoothed / smoothed.sum()
    return distributions

In [147]:
def compute_kl_per_instrument(gen_histograms, ref_histograms):
    """KL(reference || generated) for every instrument present in the reference.

    ``scipy.stats.entropy(p, q)`` returns sum(p * log(p / q)), so the reference is
    ``p``. An instrument missing from ``gen_histograms`` scores ``np.inf``.
    Instruments found only in ``gen_histograms`` are ignored.
    """
    return {
        name: entropy(ref_dist, gen_histograms[name]) if name in gen_histograms else np.inf
        for name, ref_dist in ref_histograms.items()
    }

In [176]:
# reference pitch distributions over the full training split
training_paths = glob("nesmdb_midi/train/*")
ref_hist = extract_pitch_histograms_per_instrument(midi_paths=training_paths)

In [177]:
# pool every HMM-generated file into one pitch distribution per instrument, then compare to training
hmm_paths = glob("outputs/hmm_*.mid")
hmm_hist = extract_pitch_histograms_per_instrument(midi_paths=hmm_paths)
hmm_kl = compute_kl_per_instrument(gen_histograms=hmm_hist, ref_histograms=ref_hist)

In [178]:
# KL(training || HMM output) per instrument; 0 means identical pitch usage, larger is further apart
hmm_kl

{'p1': 0.039340996622072376,
 'tr': 0.17909712939797354,
 'p2': 0.06244394502672241,
 'no': 0.00539675125873548}

In [182]:
# same comparison for the Markov-chain outputs (presumably produced elsewhere; not generated in this notebook)
chain_paths = glob("outputs/markov_*.mid")
chain_hist = extract_pitch_histograms_per_instrument(midi_paths=chain_paths)
chain_kl = compute_kl_per_instrument(gen_histograms=chain_hist, ref_histograms=ref_hist)

In [183]:
# KL(training || Markov-chain output) per instrument, directly comparable with hmm_kl above
chain_kl

{'p1': 0.08232277652595449,
 'tr': 0.10465714078364963,
 'p2': 0.09822218660567458,
 'no': 0.007510883348865858}