**HMM Logic Sequence**

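- Input: a MIDI file (the score) and an audio recording of the same piece.
- Separate the audio into tonal and transient components.
- Move through the score sequentially; each score note is one HMM state.
- At each step, take the next block of audio samples as the observation.
- Transition matrix: left-to-right, with a geometric distribution model of state (note) durations.
- Emission probability: tonal likelihood from the Gaussian-process log marginal likelihood (GP LML), combined with a transient mixture-of-Gaussians likelihood.
- Viterbi algorithm, 'windowed' so that each block may only occupy states up to the current score position, computes the most probable state sequence.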

In [None]:
pip install pretty_midi pydub

In [None]:
# Colab-only setup: mount Google Drive so the score/audio files under
# /content/drive/MyDrive are readable by the cells below.
from google.colab import drive
drive.mount('/content/drive')
from pretty_midi import PrettyMIDI
import matplotlib.pyplot as plt
import pretty_midi


def load_and_filter_midi_pretty(file_path,
                                note_range=(60, 76),
                                remove_note=None,
                                max_time=None):
    """
    Load a MIDI file and keep only the notes of interest.

    Each instrument's notes are processed in three stages:
      1. keep notes whose pitch lies in the inclusive range ``note_range``;
      2. if ``remove_note`` is given, drop notes with that pitch whose start
         lies within the tolerance of the given start time;
      3. if ``max_time`` is given, drop notes starting at or after it and clip
         the end of any kept note that runs past it (note objects are
         modified in place).

    Args:
        file_path (str): Path to the MIDI file.
        note_range (tuple): ``(min_pitch, max_pitch)``, both inclusive.
        remove_note (tuple, optional): ``(pitch, start_time, time_tolerance)``
            identifying the note(s) to delete.
        max_time (float, optional): Cutoff in seconds; ``None`` disables it.

    Returns:
        PrettyMIDI | None: The filtered object, or ``None`` if loading or
        filtering raised (the error is printed rather than propagated).
    """
    def in_range(note):
        return note_range[0] <= note.pitch <= note_range[1]

    try:
        midi_data = pretty_midi.PrettyMIDI(file_path)

        for instrument in midi_data.instruments:
            kept = list(filter(in_range, instrument.notes))

            if remove_note:
                target_pitch, target_start, tolerance = remove_note

                def is_target(note):
                    return (note.pitch == target_pitch and
                            abs(note.start - target_start) <= tolerance)

                kept = [note for note in kept if not is_target(note)]

            if max_time is not None:
                kept = [note for note in kept if note.start < max_time]
                for note in kept:
                    if note.end > max_time:
                        note.end = max_time

            instrument.notes = kept

        return midi_data

    except Exception as e:
        print(f"Error loading MIDI file: {e}")
        return None



def display_midi_and_notes_pretty(midi_data):
    """
    Print every note of a PrettyMIDI object in onset order, spelling each
    pitch as a note name (e.g. MIDI 60 -> "C4").

    Args:
        midi_data (PrettyMIDI): Object whose ``instruments[*].notes`` are listed.

    Returns:
        list: ``(start, end, pitch)`` tuples sorted by start time, with the
        numeric MIDI pitch. Returns an empty list if anything raises (the
        error is printed instead).
    """
    try:
        events = sorted(
            ((note.start, note.end, note.pitch)
             for instrument in midi_data.instruments
             for note in instrument.notes),
            key=lambda event: event[0],
        )

        pitch_classes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

        def midi_to_note_name(midi_number):
            """Spell a MIDI note number as pitch class + octave (MIDI 60 -> "C4")."""
            octave, pitch_class = divmod(midi_number, 12)
            return f"{pitch_classes[pitch_class]}{octave - 1}"

        # Build every name before printing so a failure prints nothing partial.
        named = [(start, end, midi_to_note_name(pitch)) for start, end, pitch in events]
        print("\nParsed Notes (start, end, pitch):")
        for entry in named:
            print(entry)

        return events

    except Exception as e:
        print(f"Error displaying MIDI notes: {e}")
        return []


def plot_midi_piano_roll_pretty(midi_data):
    """
    Scatter-plot every note onset of a PrettyMIDI object (time vs. pitch).

    Args:
        midi_data (PrettyMIDI | None): Object to visualise. ``None`` is
            accepted because ``load_and_filter_midi_pretty`` returns it when
            loading fails.

    Returns nothing. Missing data or an empty note list is reported with a
    short message; any other plotting error is printed, not raised.
    """
    if midi_data is None:
        print("No MIDI data to plot.")
        return

    try:
        pitches = []
        onsets = []
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                pitches.append(note.pitch)
                onsets.append(note.start)

        if not pitches:
            # min()/max() for the y-ticks would raise on an empty list after a
            # figure had already been opened, so bail out before creating one.
            print("No notes found in the MIDI data.")
            return

        plt.figure(figsize=(10, 6))
        plt.scatter(onsets, pitches, marker='o', color='blue', alpha=0.7)
        plt.xlabel("Time (seconds)")
        plt.ylabel("MIDI Note Number")
        plt.title("MIDI Piano Roll Visualization (PrettyMIDI)")
        plt.yticks(range(min(pitches), max(pitches) + 1, 2))
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.show()

    except Exception as e:
        print(f"Error plotting piano roll: {e}")


def list_instruments(midi):
    """
    Print a 1-based numbered list of the instruments in a PrettyMIDI object,
    with each one's name, program number and note count. Drum tracks are
    labelled "Drums" instead of by program name.
    """
    print("Instruments in this MIDI file:")
    for number, inst in enumerate(midi.instruments, start=1):
        if inst.is_drum:
            label = "Drums"
        else:
            label = pretty_midi.program_to_instrument_name(inst.program)
        print(f"{number}. {label}  (program={inst.program}, notes={len(inst.notes)})")



def remove_instruments_by_index(midi_data, indices_to_remove):
    """
    Drop instruments from a PrettyMIDI object by their zero-based position.

    Args:
        midi_data (PrettyMIDI): Object to modify in place.
        indices_to_remove (iterable of int): Positions to drop; positions that
            do not exist are ignored.

    Returns:
        PrettyMIDI: ``midi_data`` itself, with ``instruments`` rebound to the
        surviving instruments in their original order.
    """
    dropped = frozenset(indices_to_remove)
    survivors = []
    for position, instrument in enumerate(midi_data.instruments):
        if position in dropped:
            continue
        survivors.append(instrument)
    midi_data.instruments = survivors
    return midi_data

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
note_to_remove = (71, 8.7, 0.1)  # (pitch, start_s, tolerance_s): drop pitch 71 (B4) starting within 0.1 s of 8.7 s
# Default note_range (60, 76) also applies; the result is None if loading fails.
midi = load_and_filter_midi_pretty('/content/drive/MyDrive/furelise.mid', remove_note=note_to_remove)
plot_midi_piano_roll_pretty(midi)

In [None]:
import numpy as np
from scipy.stats import geom
import librosa
import pretty_midi
from scipy.ndimage import median_filter
from pydub import AudioSegment
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.io import wavfile
from scipy.fft import fft, ifft
from scipy.signal.windows import hamming, hann
from scipy.linalg import cho_factor, cho_solve




# Module-level audio load. NOTE(review): `audio_data` here is the raw sample
# array from pydub (not mixed down to mono) and is not referenced elsewhere in
# the visible code; extract_features() reloads and mixes the file itself.
audio = AudioSegment.from_file('/content/drive/MyDrive/furelise.mp3')
sample_rate = audio.frame_rate
audio_data = np.array(audio.get_array_of_samples())


# --- Tonal/transient separation (read by extract_features) ---
time_window_size = 2048  # analysis frame length in samples
hop_size = int(time_window_size * 0.5)  # 50% frame overlap
hmm_hop_size = 2048  # samples between successive HMM observations
window_size = 19  # median-filter length, in FFT bins
alpha = 0.4  # scale on the median-filtered magnitude; also scales default non-onset MoG variances
eta = 0.1  # not referenced in the visible code
sigma_p2 = 1e1 # Transient noise (not referenced in the visible code)

# --- GP tonal model ---
sigma_f2 = 2e-5  # Decay parameter for covariance
sigma_n2 = 3  # Noise variance
wq = [1.0]  # Weights for fundamentals
M = 10  # Number of harmonics

# NOTE(review): overrides audio.frame_rate read above with a hard-coded value;
# confirm the recording really is 44.1 kHz.
sample_rate = 44100

# Harmonic weights E_m = 1 / (1 + T * m**v) (see harmonic_weight).
T = 0.5
v = 2.37




def parse_midi(midi_data):
    """
    Flatten a PrettyMIDI object into the HMM's score events.

    Drum instruments are skipped. The remaining notes become
    ``(start, end, pitch)`` tuples, sorted by start time (stable, so notes
    with equal starts keep instrument order). The list is printed and
    returned; its index is the HMM state number.
    """
    events = sorted(
        ((note.start, note.end, note.pitch)
         for instrument in midi_data.instruments if not instrument.is_drum
         for note in instrument.notes),
        key=lambda event: event[0],
    )
    print(events)
    return events


def compute_transition_matrices(E_Z_list, block_size=2048):
    """
    Build one left-to-right transition matrix for the whole score.

    Each state is a score event (note) from the MIDI file. The mean note
    duration in samples is converted to an average number of blocks
    ``d_avg = mean(E_Z_list) / block_size``, giving a geometric duration model:

        self-transition probability:  p = 1 - 1/d_avg   (0 if d_avg <= 1)
        transition to next state:     1 - p

    The last state is absorbing (self-transition 1).

    Args:
        E_Z_list (sequence of numbers): Note durations in samples, one per state.
        block_size (int): Samples per HMM observation block.

    Returns:
        np.ndarray: ``(n_states, n_states)`` matrix; ``(0, 0)`` for an empty list.
    """
    n_states = len(E_Z_list)
    if n_states == 0:
        return np.zeros((0, 0))

    d_avg = np.mean(E_Z_list) / block_size
    p_stay = 1 - 1 / d_avg if d_avg > 1 else 0.0

    # Named `trans` so it does not shadow the module-level harmonic parameter T.
    trans = np.zeros((n_states, n_states))
    idx = np.arange(n_states)
    trans[idx, idx] = p_stay
    trans[idx[:-1], idx[:-1] + 1] = 1 - p_stay
    trans[-1, -1] = 1.0
    return trans



def _to_int16_full_scale(signal):
    """Peak-normalise ``signal`` to the int16 range; an all-zero or empty signal maps to zeros."""
    signal = np.asarray(signal, dtype=np.float64)
    peak = np.max(np.abs(signal)) if signal.size else 0.0
    if peak == 0:
        return np.zeros(signal.shape, dtype=np.int16)
    return np.int16(signal / peak * 32767)


def extract_features(audio_path='/content/drive/MyDrive/furelise.mp3', duration_s=14.3):
    """
    Load the recording and split it into tonal and transient components.

    For each Hann-windowed frame (``time_window_size`` samples, hop
    ``hop_size``), the transient spectrum is
    ``min(alpha * median_filter(|X|), |X|)`` taken along frequency (median
    length ``window_size``). It is resynthesised with the frame's original
    phase and overlap-added. The tonal part is the residual
    ``mixture - transient``, formed at the mixture's own scale before either
    output is normalised.

    Args:
        audio_path (str): Audio file readable by pydub.
        duration_s (float | None): Keep only the first ``duration_s`` seconds;
            ``None`` keeps the whole file.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(tonal_audio, processed_audio)``.
        Both are int16 and each is peak-normalised to +/-32767;
        ``processed_audio`` is the transient component. Stereo input is
        averaged to mono first.
    """
    audio = AudioSegment.from_file(audio_path)
    if duration_s is not None:
        # pydub slices AudioSegments in milliseconds, not samples.
        audio = audio[:int(duration_s * 1000)]
    samples = np.array(audio.get_array_of_samples())

    if audio.channels > 1:
        # Interleaved samples -> (n_frames, n_channels), then average to mono.
        audio_data = samples.reshape((-1, audio.channels)).mean(axis=1)
    else:
        audio_data = samples
    # Work in float64 throughout: subtracting int16 arrays would silently wrap around.
    audio_data = audio_data.astype(np.float64)

    window = hann(time_window_size)
    transient = np.zeros(len(audio_data))
    for start in range(0, len(audio_data) - time_window_size, hop_size):
        stop = start + time_window_size
        spectrum = fft(audio_data[start:stop] * window)
        magnitude = np.abs(spectrum)
        # A median filter across frequency suppresses narrow spectral peaks,
        # keeping the smoother (transient-like) part of the spectrum.
        median_magnitude = alpha * median_filter(magnitude, size=(window_size,))
        transient_magnitude = np.minimum(median_magnitude, magnitude)
        frame = ifft(transient_magnitude * np.exp(1j * np.angle(spectrum))).real
        transient[start:stop] += frame * window

    # Subtract before normalising so both parts share the mixture's scale.
    tonal = audio_data - transient
    return _to_int16_full_scale(tonal), _to_int16_full_scale(transient)


def harmonic_weight(m, T, v):
    """Relative weight of harmonic ``m``: E_m = 1 / (1 + T * m**v)."""
    denominator = 1 + T * m ** v
    return 1 / denominator


def covariance_function(tau, wq, midi_frequencies, T, v):
    """
    Evaluate the harmonic GP covariance at lag(s) ``tau``.

        k(tau) = exp(-2 pi^2 sigma_f2 tau^2)
                 + sum over f in midi_frequencies of
                   wq[0] * sum_{m=1..M} E_m cos(2 pi m f tau)

    with ``E_m = harmonic_weight(m, T, v)``. ``sigma_f2`` and ``M`` are the
    module-level settings. Every fundamental uses the same weight ``wq[0]``.
    """
    cov = np.exp(-2 * np.pi**2 * sigma_f2 * tau**2)
    for fq in midi_frequencies:
        harmonic_sum = sum(
            harmonic_weight(m, T, v) * np.cos(2 * np.pi * m * fq * tau)
            for m in range(1, M + 1)
        )
        cov += wq[0] * harmonic_sum
    return cov


def compute_covariance_matrix(midi_frequencies, block_size, sample_rate, T, v,
                              n_harmonics=None, sigma_f2_val=None):
    """
    Build the ``(block_size, block_size)`` harmonic GP covariance for one block.

    Samples are taken at ``t_n = n / sample_rate`` (n = 0..block_size-1) and
    the matrix entry for lag ``tau = t_i - t_j`` is

        exp(-2 pi^2 sigma_f2 tau^2)
        + sum over f in midi_frequencies, m = 1..M of E_m cos(2 pi m f tau)

    where ``E_m = 1 / (1 + T * m**v)`` (the harmonic_weight formula).

    Args:
        midi_frequencies (iterable of float): Fundamental frequencies in Hz.
        block_size (int): Number of samples in the block.
        sample_rate (float): Samples per second.
        T, v (float): Harmonic-weight decay parameters.
        n_harmonics (int, optional): Harmonics per fundamental; defaults to
            the module-level ``M``.
        sigma_f2_val (float, optional): Envelope decay; defaults to the
            module-level ``sigma_f2``.

    Returns:
        np.ndarray: Symmetric ``(block_size, block_size)`` covariance matrix.
    """
    if n_harmonics is None:
        n_harmonics = M
    if sigma_f2_val is None:
        sigma_f2_val = sigma_f2

    # Exact 1/sample_rate spacing; linspace(0, N/fs, N) would stretch it by N/(N-1).
    t = np.arange(block_size) / sample_rate
    tau = t[:, None] - t[None, :]

    cov = np.exp(-2 * np.pi**2 * sigma_f2_val * tau**2)

    harmonics = np.arange(1, n_harmonics + 1, dtype=float)
    weights = 1.0 / (1.0 + T * harmonics ** v)
    for fq in midi_frequencies:
        # Accumulate one harmonic at a time to avoid an (M, N, N) temporary.
        for m, weight in zip(harmonics, weights):
            cov += weight * np.cos(2 * np.pi * m * fq * tau)

    return cov


def tonal_LML(y, K, sigma_n2):
    """
    Log marginal likelihood of ``y`` under a zero-mean GP with covariance
    ``K + sigma_n2 * I``:

        -0.5 * y^T (K + sigma_n2 I)^-1 y - 0.5 * log|K + sigma_n2 I| - 0.5 * N log(2 pi)

    The solve and the log-determinant both come from one Cholesky factorisation.
    """
    n = len(y)
    chol = cho_factor(K + sigma_n2 * np.eye(n), lower=True)
    solved = cho_solve(chol, y)
    data_fit = np.dot(y.T, solved)
    log_det = 2 * np.sum(np.log(np.diag(chol[0])))
    return -0.5 * data_fit - 0.5 * log_det - 0.5 * n * np.log(2 * np.pi)






def is_onset_state(state_index, midi_notes, current_time, onset_window=hmm_hop_size/sample_rate):
    """
    Report whether ``current_time`` falls in the onset phase of a note.

    True only when the note's start time (``midi_notes[state_index][0]``)
    has already passed and ``current_time`` is less than ``onset_window``
    seconds after it. The default window is one HMM hop, evaluated once when
    the function is defined.
    """
    note_start = midi_notes[state_index][0]
    if current_time < note_start:
        return False
    elapsed = current_time - note_start
    return elapsed < onset_window







def _mog_log_terms(data, pi, sigma):
    """Return the (N, K) matrix log(pi_k) + log N(x_i | 0, sigma_k) for a zero-mean 1-D mixture."""
    return (np.log(pi)[None, :]
            - 0.5 * np.log(2 * np.pi * sigma)[None, :]
            - (data[:, None] ** 2) / (2 * sigma[None, :]))


def _mog_total_log_likelihood(log_terms):
    """Sum over samples of the log-sum-exp over components of an (N, K) log-term matrix."""
    row_max = np.max(log_terms, axis=1)
    return float(np.sum(row_max + np.log(np.sum(np.exp(log_terms - row_max[:, None]), axis=1))))


def estimate_transient_mog_parameters(data, K=2, tol=1e-6, max_iter=10, verbose=True,
                                      plot_convergence=True, n_init=10, rel_tol=1e-4):
    """
    Fit a K-component zero-mean Gaussian mixture to 1-D data with EM.

    ``n_init`` random initialisations are scored by log-likelihood and the
    best one seeds EM. EM stops when the relative log-likelihood improvement
    drops below ``rel_tol`` (default 1e-4, i.e. 0.01%) or after ``max_iter``
    iterations. Each iteration's log-likelihood is printed when ``verbose``,
    and the trace is plotted when ``plot_convergence``.

    Args:
        data (array-like): 1-D samples (converted to float64).
        K (int): Number of mixture components.
        tol (float): Unused; kept for backward compatibility (see ``rel_tol``).
        max_iter (int): Maximum number of EM iterations.
        verbose (bool): Print per-iteration log-likelihoods.
        plot_convergence (bool): Plot the log-likelihood trace.
        n_init (int): Number of random initialisations to try.
        rel_tol (float): Relative-improvement convergence threshold.

    Returns:
        tuple: ``(pi, sigma, final_log_likelihood)``. ``pi`` holds the mixing
        weights, ``sigma`` the variances, and the log-likelihood is that of
        exactly these returned parameters.

    Raises:
        ValueError: if ``data`` is empty or no initialisation has a finite
            log-likelihood (e.g. zero-variance data or ``n_init < 1``).
    """
    data = np.asarray(data, dtype=np.float64)
    N = len(data)
    if N == 0:
        raise ValueError("data must contain at least one sample")
    overall_var = np.var(data)

    # Random restarts: keep the initialisation with the highest log-likelihood.
    best_init_log_likelihood = -np.inf
    best_pi_init = None
    best_sigma_init = None
    for _ in range(n_init):
        pi_candidate = np.random.rand(K)
        pi_candidate = pi_candidate / np.sum(pi_candidate)
        sigma_candidate = overall_var * (0.5 + np.random.rand(K))
        candidate_ll = _mog_total_log_likelihood(_mog_log_terms(data, pi_candidate, sigma_candidate))
        if candidate_ll > best_init_log_likelihood:
            best_init_log_likelihood = candidate_ll
            best_pi_init = pi_candidate
            best_sigma_init = sigma_candidate

    if best_pi_init is None:
        raise ValueError("Could not initialise the mixture: no initialisation had a "
                         "finite log-likelihood (zero-variance data or n_init < 1).")

    pi = best_pi_init.copy()
    sigma = best_sigma_init.copy()
    final_log_likelihood = best_init_log_likelihood

    log_likelihood_prev = -np.inf
    log_likelihoods = []

    for iteration in range(max_iter):
        # E-step: responsibilities via a row-wise stabilised softmax.
        log_r = _mog_log_terms(data, pi, sigma)
        log_r -= np.max(log_r, axis=1, keepdims=True)
        r = np.exp(log_r)
        r /= np.sum(r, axis=1, keepdims=True)

        # M-step: weights and (zero-mean) variances.
        resp_totals = np.sum(r, axis=0)
        pi_new = resp_totals / N
        sigma_new = np.sum(r * data[:, None] ** 2, axis=0) / resp_totals

        total_log_likelihood = _mog_total_log_likelihood(_mog_log_terms(data, pi_new, sigma_new))
        log_likelihoods.append(total_log_likelihood)

        if verbose:
            print(f"Iteration {iteration+1}: Log-Likelihood = {total_log_likelihood:.6f}")

        # Adopt the new parameters before the convergence test so that the
        # returned parameters always match the returned log-likelihood.
        pi, sigma = pi_new, sigma_new
        final_log_likelihood = total_log_likelihood

        if iteration > 0:
            rel_improvement = np.abs((total_log_likelihood - log_likelihood_prev) / total_log_likelihood)
            if rel_improvement < rel_tol:
                if verbose:
                    print(f"Convergence reached (relative improvement < {rel_tol:.2%}).")
                break

        log_likelihood_prev = total_log_likelihood

    if plot_convergence:
        plt.figure(figsize=(8, 4))
        plt.plot(range(1, len(log_likelihoods) + 1), log_likelihoods, marker='o')
        plt.xlabel('Iteration')
        plt.ylabel('Log-Likelihood')
        plt.title(f'EM Convergence (K = {K})')
        plt.grid(True)
        plt.show()

    print("Final EM parameters:")
    print("Mixing weights:", pi)
    print("Variances:", sigma)
    return pi, sigma, final_log_likelihood




def select_optimal_K(data, max_K=3, tol=1e-6, verbose=True, plot_convergence=False):
    """
    Choose the number of mixture components by the Bayesian Information Criterion.

    For each K in 1..max_K, fits a zero-mean MoG with
    ``estimate_transient_mog_parameters`` and scores it with

        BIC = -2 * log_likelihood + (2K - 1) * log(N)

    where 2K - 1 counts K variances plus K - 1 free mixing weights. The
    lowest BIC wins; on ties the smaller K is kept.

    Returns:
        tuple: ``(optimal_K, optimal_pi, optimal_sigma, BIC_values)`` where
        ``BIC_values`` maps each K to its BIC.
    """
    log_n = np.log(len(data))
    BIC_values = {}
    best = (np.inf, None, None, None)  # (BIC, K, pi, sigma)

    for K in range(1, max_K + 1):
        print(f"\nRunning EM for K = {K}")
        pi, sigma, final_ll = estimate_transient_mog_parameters(
            data, K=K, tol=tol, verbose=verbose, plot_convergence=plot_convergence)

        n_params = 2 * K - 1
        bic = -2 * final_ll + n_params * log_n
        BIC_values[K] = bic
        print(f"BIC for K = {K}: {bic:.6f}")
        if bic < best[0]:
            best = (bic, K, pi, sigma)

    best_BIC, best_K, best_pi, best_sigma = best
    print(f"\nOptimal number of Gaussians selected: K = {best_K} with BIC = {best_BIC:.6f}")
    return best_K, best_pi, best_sigma, BIC_values






def compute_transient_mog_likelihood(transient_block, pi, sigma_list):
    """
    Block-level mixture log-likelihood of a transient block.

    Every sample in the block is assumed to come from the *same* zero-mean
    Gaussian component ``k`` (variance ``sigma_list[k]``), and the components
    are mixed with weights ``pi``:

        log sum_k pi_k * prod_i N(x_i | 0, sigma_k)

    NOTE(review): estimate_transient_mog_parameters fits a per-sample mixture
    (each sample picks its own component); this evaluates a different,
    block-level model with those parameters. Confirm that this is intended.

    Args:
        transient_block (np.ndarray): Transient samples (cast to float64).
        pi (sequence of float): Mixing weights.
        sigma_list (sequence of float): Component variances, indexed like ``pi``.
    """
    n = len(transient_block)
    samples = transient_block.astype(np.float64)
    energy = np.sum(samples**2)
    component_logs = np.array([
        np.log(pi[k]) + (-0.5 * n * np.log(2 * np.pi * sigma_list[k]) - 0.5 * energy / sigma_list[k])
        for k in range(len(pi))
    ])
    peak = np.max(component_logs)
    return peak + np.log(np.sum(np.exp(component_logs - peak)))


def compute_emission_probabilities(tonal_audio, transient_audio, sample_rate, midi_notes,
                                   covariance_matrices,
                                   block_size=2048, hmm_hop_size=hmm_hop_size,
                                   T_val=T, v_val=v, sigma_n2_val=sigma_n2,
                                   sigma_t2=0.5,
                                   pi_onset=None, sigma_onset_components=None,
                                   pi_non_onset=None, sigma_non_onset_components=None):
    """
    Compute per-block log emission scores for the HMM states near the score position.

    The audio is cut into blocks of ``block_size`` samples every
    ``hmm_hop_size`` samples. For each block, ``allowed_idx`` is the last MIDI
    note whose start time is <= the block's start time (clamped to 0). Only
    states in ``[allowed_idx - 3, allowed_idx + 2]`` are scored; every other
    entry stays -inf.

    A state's score is the sum of two log-likelihoods:
      * the GP tonal log marginal likelihood of the tonal block (``tonal_LML``
        with that state's covariance matrix), and
      * a transient mixture-of-Gaussians log-likelihood.
    The onset mixture is used only for ``allowed_idx`` while
    ``is_onset_state`` says it is in its onset window; every other state uses
    the non-onset mixture.

    Args:
        tonal_audio, transient_audio: Sample arrays, sliced with the same
            indices (presumably equal length, as returned by extract_features;
            confirm).
        sample_rate: Samples per second, used to convert block starts to seconds.
        midi_notes: ``(start, end, pitch)`` tuples sorted by start (as from parse_midi).
        covariance_matrices: Mapping from state index to a
            ``(block_size, block_size)`` covariance matrix.
        block_size, hmm_hop_size: Block length and hop, in samples.
        T_val, v_val: Not used in this body.
        sigma_n2_val: Noise variance passed to tonal_LML.
        sigma_t2: Only used to build the default mixture variances.
        pi_onset, sigma_onset_components, pi_non_onset, sigma_non_onset_components:
            Mixture weights and variances; each defaults to a 2-component mixture.

    Returns:
        emissions: ``(n_blocks, n_states)`` array of log scores (-inf outside the window).
        allowed_max_states: ``(n_blocks,)`` array of ``allowed_idx`` per block.
        transient_ll: Flat array with one transient log-likelihood per *scored
            state* (not per block), in evaluation order.
    """
    n_states = len(midi_notes)
    midi_starts = [note[0] for note in midi_notes]
    emissions = []
    transient_ll = []
    allowed_max_states = []

    # Default mixtures; the non-onset variances are scaled by the module-level alpha.
    if pi_onset is None:
        pi_onset = [0.6, 0.4]
    if sigma_onset_components is None:
        sigma_onset_components = [sigma_t2, sigma_t2 * 1.5]
    if pi_non_onset is None:
        pi_non_onset = [0.6, 0.4]
    if sigma_non_onset_components is None:
        sigma_non_onset_components = [sigma_t2 * alpha, sigma_t2 * alpha * 1.5]

    # The stop bound excludes a block starting exactly at len - block_size.
    for obs_idx, start in enumerate(range(0, len(tonal_audio) - block_size, hmm_hop_size)):
        tonal_block = tonal_audio[start:start + block_size]
        transient_block = transient_audio[start:start + block_size]
        emission_vector = np.full(n_states, -np.inf)
        current_time = start / sample_rate
        # Latest score note that has started by now (the score-time position).
        allowed_idx = np.searchsorted(midi_starts, current_time, side='right') - 1
        allowed_idx = max(allowed_idx, 0)
        allowed_max_states.append(allowed_idx)
        # Score a window of 3 states back and 2 ahead.
        # NOTE(review): forward_algorithm and viterbi skip j > allowed_max_states[t],
        # so the two states ahead are evaluated but never used in decoding.
        lower_bound = max(0, allowed_idx - 3)
        upper_bound = min(n_states, allowed_idx + 2 + 1)

        for i in range(lower_bound, upper_bound):
            K = covariance_matrices[i]
            log_likelihood_tonal = tonal_LML(tonal_block, K, sigma_n2_val)
            if i == allowed_idx and is_onset_state(i, midi_notes, current_time):
                print(f"TRANSIENT DETECTED FOR STATE {i}")
                log_likelihood_trans = compute_transient_mog_likelihood(transient_block, pi_onset, sigma_onset_components)
            else:
                log_likelihood_trans = compute_transient_mog_likelihood(transient_block, pi_non_onset, sigma_non_onset_components)
            combined_log_likelihood = log_likelihood_tonal + log_likelihood_trans
            emission_vector[i] = combined_log_likelihood
            # One entry per scored state, so this is not aligned with blocks.
            transient_ll.append(log_likelihood_trans)

        print(f"Block starting at sample {start} (time {current_time:.2f}s): allowed max state = {allowed_idx}")
        print(emission_vector)
        emissions.append(emission_vector)

    return np.array(emissions), np.array(allowed_max_states), np.array(transient_ll)





def run_tests(scores, true_onsets, threshold, onset_window=1):
    """
    Score thresholded transient detections against ground-truth onset frames.

    Frames with ``scores > threshold`` are predictions. Each prediction, in
    frame order, greedily claims the first not-yet-claimed true onset within
    ``onset_window`` frames; a successful claim is a true positive.

    Args:
        scores: 1-D array of per-frame transient log-likelihoods.
        true_onsets: Boolean array of the same length (ground-truth flags).
        threshold: Score above which a frame counts as a detection.
        onset_window: Frames of slack allowed when matching predictions to truths.

    Returns:
        dict: Keys ``TP``, ``FP``, ``FN``, ``precision``, ``recall``, ``f1``
        (the ratios are 0.0 when undefined).
    """
    predicted = np.nonzero(scores > threshold)[0]
    actual = np.nonzero(true_onsets)[0]

    claimed = set()
    hits = 0
    for frame in predicted:
        match = next(
            (t for t in actual if abs(t - frame) <= onset_window and t not in claimed),
            None,
        )
        if match is not None:
            hits += 1
            claimed.add(match)

    false_pos = len(predicted) - hits
    false_neg = len(actual) - hits
    precision = hits / (hits + false_pos) if hits + false_pos > 0 else 0.0
    recall = hits / (hits + false_neg) if hits + false_neg > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

    return {'TP': hits, 'FP': false_pos, 'FN': false_neg, 'precision': precision,
            'recall': recall, 'f1': f1}






def viterbi(transition_matrix, emissions, allowed_max_states):
    """
    Most probable state path of a log-space HMM with a per-step state cap.

    At observation ``t`` only states ``j <= allowed_max_states[t]`` may be
    occupied; every other state scores -inf at that step. No initial-state
    prior is added: the score at ``t = 0`` is just the emission score of each
    allowed state.

    Args:
        transition_matrix: ``(n_states, n_states)`` probabilities, where entry
            ``[i, j]`` is P(i -> j). Zero entries are forbidden transitions.
        emissions: ``(n_obs, n_states)`` log emission scores.
        allowed_max_states: Length-``n_obs`` sequence of the highest allowed
            state index per observation.

    Returns:
        list[int]: Decoded state per observation; ``[]`` when ``n_obs == 0``.
        Ties go to the lowest state index. When no predecessor is reachable,
        the back-pointer defaults to state 0.
    """
    emissions = np.asarray(emissions, dtype=float)
    n_obs, n_states = emissions.shape
    if n_obs == 0:
        return []

    trans = np.asarray(transition_matrix, dtype=float)
    # log(0) would warn; forbidden transitions are simply -inf.
    log_trans = np.full(trans.shape, -np.inf)
    positive = trans > 0
    log_trans[positive] = np.log(trans[positive])

    states = np.arange(n_states)
    dp = np.full((n_obs, n_states), -np.inf)
    backpointer = np.zeros((n_obs, n_states), dtype=int)

    allowed = states <= allowed_max_states[0]
    dp[0, allowed] = emissions[0, allowed]

    for t in range(1, n_obs):
        allowed = states <= allowed_max_states[t]
        # candidate[i, j]: best path ending in state i at t-1, then moving i -> j.
        candidate = dp[t - 1][:, None] + log_trans
        best_prev = np.argmax(candidate, axis=0)  # first maximiser; 0 if all -inf
        best_score = candidate[best_prev, states]
        dp[t, allowed] = best_score[allowed] + emissions[t, allowed]
        backpointer[t, allowed] = best_prev[allowed]

    path = np.zeros(n_obs, dtype=int)
    path[-1] = np.argmax(dp[-1])
    for t in range(n_obs - 1, 0, -1):
        path[t - 1] = backpointer[t, path[t]]
    return path.tolist()




def log_sum_exp(log_values):
    """
    Compute ``log(sum(exp(log_values)))`` in a numerically stable way.

    An empty input, or one whose entries are all -inf (zero total
    probability), returns -inf instead of the NaN produced by ``inf - inf``.
    If any entry is +inf, returns +inf. Accepts any array-like.
    """
    log_values = np.asarray(log_values, dtype=float)
    if log_values.size == 0:
        return -np.inf
    max_log = np.max(log_values)
    if not np.isfinite(max_log):
        return max_log
    return max_log + np.log(np.sum(np.exp(log_values - max_log)))


def forward_algorithm(transition_matrix, emissions, allowed_max_states):
    """
    Log-space forward pass of the HMM with a per-step state cap.

    At step ``t`` only states ``j <= allowed_max_states[t]`` receive a forward
    score; the rest stay -inf. The score at ``t = 0`` is just the emission.
    After that:

        alpha[t, j] = emissions[t, j] + log sum_i exp(alpha[t-1, i] + log A[i, j])

    The sum runs over predecessors ``i`` with a finite previous score and
    ``A[i, j] > 0``. A state with no such predecessor stays -inf.

    Args:
        transition_matrix: ``(n_states, n_states)`` transition probabilities.
        emissions: ``(n_obs, n_states)`` log emission scores.
        allowed_max_states: ``(n_obs,)`` highest allowed state index per step.

    Returns:
        np.ndarray: ``(n_obs, n_states)`` log forward scores.
    """
    n_obs, n_states = emissions.shape
    alpha = np.full((n_obs, n_states), -np.inf)

    for t in range(n_obs):
        for j in range(n_states):
            if t == 0:
                if j <= allowed_max_states[0]:
                    alpha[0, j] = emissions[0, j]
                continue
            if j > allowed_max_states[t]:
                continue
            incoming = [
                alpha[t - 1, i] + np.log(transition_matrix[i, j])
                for i in range(n_states)
                if alpha[t - 1, i] > -np.inf and transition_matrix[i, j] > 0
            ]
            if incoming:
                alpha[t, j] = emissions[t, j] + log_sum_exp(np.array(incoming))

    return alpha


def compute_total_log_likelihood(alpha):
    """
    Total log-likelihood of the observation sequence: the log-sum-exp of the
    last row of the forward matrix ``alpha``.
    """
    final_row = alpha[-1, :]
    return log_sum_exp(final_row)




def plot_transient_mog(transient, pi_onset_em, sigma_onset_em):
    """
    Plot a histogram of the transient samples with the fitted zero-mean MoG
    overlaid: the total mixture density in red and each weighted component
    dashed.

    Two figures are drawn. The first evaluates the density over the full data
    range. The second evaluates it on [-10000, 10000] and shows only
    [-5000, 5000].
    """
    component_colors = ['blue', 'green', 'orange', 'purple']

    def weighted_components(x):
        """pi_k * N(x | 0, sigma_k) for every component, as a list of arrays."""
        return [p_k * (1/np.sqrt(2*np.pi*s_k)) * np.exp(-x**2/(2*s_k))
                for p_k, s_k in zip(pi_onset_em, sigma_onset_em)]

    for zoomed in (False, True):
        plt.figure(figsize=(18, 6))
        plt.hist(transient, bins=5000, density=True, alpha=0.6, label='Transient histogram')

        if zoomed:
            x = np.linspace(-10000, 10000, 1000)
        else:
            x = np.linspace(np.min(transient), np.max(transient), 1000)

        components = weighted_components(x)
        total = np.zeros_like(x, dtype=float)
        for comp in components:
            total += comp
        plt.plot(x, total, color='red', linewidth=2, label='Mixture density (total)')

        for idx, comp in enumerate(components):
            plt.plot(
                x, comp,
                linestyle='--',
                color=component_colors[idx % len(component_colors)],
                linewidth=2,
                label=f'Component {idx+1}'
            )

        plt.xlabel('Amplitude')
        if zoomed:
            plt.xlim(-5000, 5000)
        plt.ylabel('Density')
        if zoomed:
            plt.title('Transient Samples Histogram with MoG Overlay (Zoomed In)')
        else:
            plt.title('Transient Samples Histogram with MoG Overlay (Full Range)')
        plt.legend()
        plt.show()






def main(midi_file, max_K_candidates=3):
    """
    Run the full score-following pipeline on a PrettyMIDI score.

    Steps:
      1. Parse the score into HMM states and split the audio into tonal and
         transient parts.
      2. Build one GP covariance matrix per distinct pitch.
      3. Detect onsets in the transient signal with librosa, then fit separate
         zero-mean MoG models (K chosen by BIC) to onset and non-onset samples.
      4. Build the left-to-right transition matrix and compute emissions.
      5. Run the forward algorithm and Viterbi.
      6. Evaluate a per-block onset/non-onset likelihood-ratio detector
         against onset labels derived from the score timing.

    Args:
        midi_file: A PrettyMIDI object, not a file path; it is passed
            straight to parse_midi, which reads ``.instruments``.
        max_K_candidates: Largest number of mixture components tried by
            select_optimal_K.

    Returns:
        tuple: ``(alpha_mat, total_log_likelihood, state_sequence, results,
        transient_scores)``.

    NOTE(review): ``sample_rate`` below is the module-level value (44100), not
    a rate read back from the audio that extract_features loads.
    """
    midi_notes = parse_midi(midi_file)
    tonal, transient = extract_features()

    block_size = 2048
    T_val = T
    v_val = v
    # One GP covariance matrix per distinct pitch, shared by all states with that pitch.
    covariance_matrices_unique = {}
    covariance_matrices = {}
    for i, note in enumerate(midi_notes):
        _, _, pitch = note
        if pitch not in covariance_matrices_unique:
            # Equal-tempered MIDI pitch -> Hz (MIDI 69 = A4 = 440 Hz).
            base_frequency = 440.0 * 2 ** ((pitch - 69) / 12.0)
            covariance_matrices_unique[pitch] = compute_covariance_matrix(
                [base_frequency], block_size, sample_rate, T_val, v_val)
        covariance_matrices[i] = covariance_matrices_unique[pitch]

    print("Covariance matrices computed")

    # Mark transient samples within +/-25 ms of each librosa-detected onset.
    # The marked samples and the remaining samples train the two MoG models.
    transient_float = transient.astype(np.float32) / np.max(np.abs(transient))
    onset_frames = librosa.onset.onset_detect(y=transient_float, sr=sample_rate)
    onset_times = librosa.frames_to_time(onset_frames, sr=sample_rate)
    onset_window = int(0.05 * sample_rate)  # 50 ms window
    onset_indices = []
    for t in onset_times:
        onset_idx = int(t * sample_rate)
        start_idx = max(0, onset_idx - onset_window // 2)
        end_idx = min(len(transient), onset_idx + onset_window // 2)
        onset_indices.extend(range(start_idx, end_idx))
    onset_mask = np.zeros(len(transient), dtype=bool)
    onset_mask[onset_indices] = True

    data_onset = transient[onset_mask].astype(np.float64)
    data_non_onset = transient[~onset_mask].astype(np.float64)

    print("Selecting optimal number of Gaussians for onset data:")
    best_K_onset, pi_onset_em, sigma_onset_em, BIC_onset = select_optimal_K(
        data_onset, max_K=max_K_candidates, tol=1e-6, verbose=True, plot_convergence=False)

    print("\nSelecting optimal number of Gaussians for non-onset data:")
    best_K_non_onset, pi_non_onset_em, sigma_non_onset_em, BIC_non_onset = select_optimal_K(
        data_non_onset, max_K=max_K_candidates, tol=1e-6, verbose=True, plot_convergence=False)

    print("Estimated onset parameters (EM):", pi_onset_em, sigma_onset_em)
    print("Estimated non-onset parameters (EM):", pi_non_onset_em, sigma_non_onset_em)

    # Visual check of the onset mixture against the transient histogram.
    plot_transient_mog(transient, pi_onset_em, sigma_onset_em)

    # Note durations in samples drive the geometric duration model.
    E_Z_list = [int((note[1] - note[0]) * sample_rate) for note in midi_notes]
    transition_matrix = compute_transition_matrices(E_Z_list, block_size=block_size)

    emissions, allowed_max_states, transient_scores = compute_emission_probabilities(
        tonal, transient, sample_rate, midi_notes,
        covariance_matrices,  # Pass the precomputed matrices
        block_size=block_size, hmm_hop_size=hmm_hop_size,
        T_val=T, v_val=v, sigma_n2_val=sigma_n2, sigma_t2=0.5,
        pi_onset=pi_onset_em, sigma_onset_components=sigma_onset_em,
        pi_non_onset=pi_non_onset_em, sigma_non_onset_components=sigma_non_onset_em
    )

    alpha_mat = forward_algorithm(transition_matrix, emissions, allowed_max_states)
    total_log_likelihood = compute_total_log_likelihood(alpha_mat)
    print("Alpha matrix:", alpha_mat)
    print("Total log-likelihood:", total_log_likelihood)

    state_sequence = viterbi(transition_matrix, emissions, allowed_max_states)
    print("Most probable state sequence:", state_sequence)

    # Ground-truth labels: a block is an onset if the score note active at its
    # start time began less than one HMM hop earlier (is_onset_state default window).
    frame_times = np.arange(len(allowed_max_states)) * (hmm_hop_size / sample_rate)
    true_onsets = np.array([
        is_onset_state(state_idx, midi_notes, t)
        for state_idx, t in zip(allowed_max_states, frame_times)
    ], dtype=bool)

    # Per-block transient log-likelihood under each fitted mixture; the blocks
    # match those in compute_emission_probabilities (same range arguments).
    on_ll = []
    off_ll = []
    for start in range(0, len(transient) - block_size, hmm_hop_size):
        block = transient[start:start + block_size]
        on_ll .append(compute_transient_mog_likelihood(block,
                            pi_onset_em,    sigma_onset_em))
        off_ll.append(compute_transient_mog_likelihood(block,
                            pi_non_onset_em, sigma_non_onset_em))
    on_ll  = np.array(on_ll)
    off_ll = np.array(off_ll)

    # Likelihood-ratio threshold = log prior odds of "not onset", with p_on the
    # fraction of blocks labelled as onsets. NOTE(review): p_on of 0 or 1 makes
    # this +/-inf.
    p_on = true_onsets.mean()
    tau_star = np.log((1 - p_on) / p_on)

    preds = (on_ll - off_ll) >= tau_star

    TP = np.sum( preds &  true_onsets)
    FP = np.sum( preds & ~true_onsets)
    FN = np.sum(~preds &  true_onsets)
    TN = np.sum(~preds & ~true_onsets)
    precision = TP/(TP+FP) if TP+FP>0 else 0.0
    recall    = TP/(TP+FN) if TP+FN>0 else 0.0
    f1        = 2*precision*recall/(precision+recall) if precision+recall>0 else 0.0

    results = {
      'TP':TP, 'FP':FP, 'FN':FN, 'TN':TN,
      'precision':precision, 'recall':recall, 'f1':f1,
      'tau*': tau_star
    }
    print("Bayes‐optimal detection results:", results)
    print("Transient scores:", transient_scores)


    return alpha_mat, total_log_likelihood, state_sequence, results, transient_scores






# parse_midi prints the sorted note list. main() calls parse_midi again
# internally, so this line only exposes `midi_notes` at notebook scope.
midi_notes = parse_midi(midi)
alpha_matrix, total_ll, viterbi_output, test_results, transient_scores = main(midi)



In [None]:
def plot_viterbi_output(midi_data, state_sequence, figsize, sample_rate=44100, hop_size=hmm_hop_size):
    """
    Plot the score's piano roll with the Viterbi-decoded state of every
    observation overlaid, then print frame-level alignment accuracy.

    State ``i`` is interpreted exactly as ``parse_midi`` numbers states: the
    i-th non-drum note after sorting by start time. Each observation gets a
    small 'x' at that note's pitch, and each run of identical states is
    labelled ``S=<state>`` above the roll. States outside the score are not
    marked.

    Accuracy compares each decoded state with the score state whose onset is
    the latest one at or before the observation time (state 0 before the
    first onset).

    Args:
        midi_data (pretty_midi.PrettyMIDI): Score to visualize.
        state_sequence (list or np.array): Viterbi state index per observation.
        figsize (tuple): Matplotlib figure size.
        sample_rate (int): Audio sample rate (samples per second).
        hop_size (int): Samples between consecutive observations.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # HMM states in parse_midi order: non-drum notes sorted by onset. The raw
    # per-instrument note lists are not in this order (nor drum-free).
    score_notes = sorted(
        ((note.start, note.end, note.pitch)
         for instr in midi_data.instruments if not instr.is_drum
         for note in instr.notes),
        key=lambda n: n[0],
    )
    starts = [n[0] for n in score_notes]
    state_pitches = [n[2] for n in score_notes]
    # Defined before plotting so the accuracy report still works if plotting fails.
    obs_times = np.arange(len(state_sequence)) * (hop_size / sample_rate)

    try:
        # Piano roll of every note in the file (all instruments).
        notes = []
        times = []
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                notes.append(note.pitch)
                times.append(note.start)

        print(notes)

        if not notes:
            print("No notes found in the MIDI data.")
            return

        plt.figure(figsize=figsize)
        plt.scatter(times, notes, marker='o', color='blue', alpha=0.7)
        plt.xlabel("Time (seconds)")
        plt.ylabel("MIDI Note Number")
        plt.title("MIDI Piano Roll Visualization & Viterbi Decoded States")
        plt.yticks(range(min(notes), max(notes) + 1, 2))
        plt.grid(True, linestyle='--', alpha=0.6)

        colors = ['red', 'green', 'blue']

        for obs_time, state in zip(obs_times, state_sequence):
            if not 0 <= state < len(state_pitches):
                continue
            plt.text(obs_time, state_pitches[state], 'x', color=colors[state % len(colors)],
                     fontsize=8, horizontalalignment='center', verticalalignment='center')

        if len(state_sequence) > 0:
            # Collapse runs of identical states into (first_idx, last_idx, state).
            segments = []
            seg_start = 0
            for i in range(1, len(state_sequence)):
                if state_sequence[i] != state_sequence[i - 1]:
                    segments.append((seg_start, i - 1, state_sequence[i - 1]))
                    seg_start = i
            segments.append((seg_start, len(state_sequence) - 1, state_sequence[-1]))

            ax = plt.gca()
            top_y = max(notes) + 2
            for seg_start_idx, seg_end_idx, state in segments:
                block_start_time = obs_times[seg_start_idx]
                block_end_time = obs_times[seg_end_idx] + (hop_size / sample_rate)
                mid_time = (block_start_time + block_end_time) / 2.0
                ax.text(mid_time, top_y, f"S={state}",
                        horizontalalignment='center',
                        verticalalignment='bottom',
                        fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.6, edgecolor='none'))

        plt.show()

    except Exception as e:
        print(f"Error plotting piano roll: {e}")

    seq = np.asarray(state_sequence)
    total = len(seq)
    if total == 0:
        print("Alignment: no frames to evaluate.")
        return

    # Score state active at each observation time (latest onset <= t, clamped to 0).
    true_states = np.maximum(np.searchsorted(starts, obs_times, side='right') - 1, 0)
    correct = np.sum(seq == true_states)
    acc = correct / total * 100

    print(f"Alignment: {correct}/{total} frames correct ({acc:.1f}%)")

In [None]:
# Overlay the decoded states on the score (22 x 10 figure) and print frame-level alignment accuracy.
plot_viterbi_output(midi, viterbi_output, (22,10), sample_rate, hmm_hop_size)