**HMM Logic Sequence**

- Input: MIDI file; audio file
- Separate audio into tonal & transient responses
- Move through the score sequentially
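- At each point in the score, take the corresponding block of audio samples
- Transition matrix: left-to-right, with self-transition probabilities from a geometric distribution model of state durations
- Emission probability: tonal likelihood using the Gaussian-process log marginal likelihood (GP LML), combined with a Gaussian transient likelihood
- 'Windowed' Viterbi algorithm (only states near the current score position are considered) to compute the most probable state sequence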

In [None]:
pip install pretty_midi pydub

In [None]:
from google.colab import drive
drive.mount('/content/drive')
from pretty_midi import PrettyMIDI
import matplotlib.pyplot as plt
import pretty_midi


def load_and_filter_midi_pretty(file_path, note_range=(60, 76), remove_note=None):
    """
    Read a MIDI file with pretty_midi and keep only the notes whose pitch lies
    inside ``note_range``; optionally drop one specific note as well.

    Args:
        file_path (str): Location of the MIDI file.
        note_range (tuple): Inclusive (min_note, max_note) pitch bounds.
        remove_note (tuple, optional): (pitch, start_time, time_tolerance). Every
            note with that pitch whose start lies within ``time_tolerance`` of
            ``start_time`` is discarded.

    Returns:
        PrettyMIDI: The loaded object with each instrument's note list replaced by
        the filtered one, or None if anything failed (the error is printed).
    """
    try:
        midi_data = pretty_midi.PrettyMIDI(file_path)

        for instrument in midi_data.instruments:
            # Stage 1: pitch-range filter.
            kept = []
            for note in instrument.notes:
                if note_range[0] <= note.pitch <= note_range[1]:
                    kept.append(note)

            # Stage 2: optional removal of one targeted note.
            if remove_note:
                target_pitch, target_start, tolerance = remove_note

                def is_target(note):
                    return (note.pitch == target_pitch
                            and abs(note.start - target_start) <= tolerance)

                kept = [note for note in kept if not is_target(note)]

            instrument.notes = kept

        return midi_data

    except Exception as e:
        print(f"Error loading MIDI file: {e}")
        return None



def display_midi_and_notes_pretty(midi_data):
    """
    Print every note of a PrettyMIDI object in onset order, using note names.

    Args:
        midi_data (PrettyMIDI): Object whose instruments' notes are listed.

    Returns:
        list: (start, end, pitch) tuples ordered by start time, where pitch is the
        MIDI note number. On any error the message is printed and [] is returned.
    """
    pitch_classes = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')

    def midi_to_note_name(midi_number):
        """Map a MIDI note number to a name such as 'C4' (60 -> C4)."""
        octave_index, pitch_class = divmod(midi_number, 12)
        return f"{pitch_classes[pitch_class]}{octave_index - 1}"

    try:
        notes = [
            (note.start, note.end, note.pitch)
            for instrument in midi_data.instruments
            for note in instrument.notes
        ]
        # Stable sort: notes sharing a start time keep their original order.
        notes.sort(key=lambda entry: entry[0])

        # Build all rows first so a bad pitch fails before anything is printed.
        named_rows = [(start, end, midi_to_note_name(pitch)) for start, end, pitch in notes]
        print("\nParsed Notes (start, end, pitch):")
        for row in named_rows:
            print(row)

        return notes

    except Exception as e:
        print(f"Error displaying MIDI notes: {e}")
        return []


def plot_midi_piano_roll_pretty(midi_data):
    """
    Scatter-plot the note onsets of a PrettyMIDI object (onset time vs. MIDI pitch).

    Args:
        midi_data (PrettyMIDI): The PrettyMIDI object to visualize.

    Returns:
        None. If the object holds no notes a message is printed and nothing is
        drawn; any other failure is caught and printed.
    """
    try:
        notes = []
        times = []
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                notes.append(note.pitch)
                times.append(note.start)

        # Without this guard min()/max() below raise on an empty score and the
        # user only sees "min() arg is an empty sequence".
        if not notes:
            print("No notes found in the MIDI data.")
            return

        plt.figure(figsize=(10, 6))
        plt.scatter(times, notes, marker='o', color='blue', alpha=0.7)
        plt.xlabel("Time (seconds)")
        plt.ylabel("MIDI Note Number")
        plt.title("MIDI Piano Roll Visualization (PrettyMIDI)")
        plt.yticks(range(min(notes), max(notes) + 1, 2))  # One tick every 2 semitones
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.show()

    except Exception as e:
        print(f"Error plotting piano roll: {e}")


Mounted at /content/drive


In [None]:
note_to_remove = (71, 8.7, 0.1)  # (pitch, start_time, time_tolerance): drop pitch 71 starting within 0.1 s of 8.7 s
midi = load_and_filter_midi_pretty('/content/drive/MyDrive/furelise.mid', remove_note=note_to_remove)
plot_midi_piano_roll_pretty(midi)

In [None]:
import numpy as np
from scipy.stats import geom
import librosa
import pretty_midi
from scipy.ndimage import median_filter
from pydub import AudioSegment
import librosa
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')
from IPython.display import Audio
from scipy.io import wavfile
from scipy.fft import fft, ifft
from scipy.signal.windows import hamming, hann
from scipy.linalg import cho_factor, cho_solve




# Load the recording once at module level.
# NOTE(review): extract_features() re-reads the same file itself, and no function
# visible in this notebook reads `audio_data` - confirm whether this load is needed.
audio = AudioSegment.from_file('/content/drive/MyDrive/furelise.mp3')
sample_rate = audio.frame_rate  # overwritten with 44100 further down
audio_data = np.array(audio.get_array_of_samples())

time_window_size = 2048  # Frame length (samples) of the windowed FFT in extract_features
hop_size = int(time_window_size * 0.5)  # Frame advance in extract_features: half a frame
hmm_hop_size = 3000  # Samples between successive HMM observation blocks
window_size = 19  # Size of the median filter applied to each magnitude spectrum
alpha = 0.4  # Scale applied to the median-filtered magnitude in extract_features
eta = 0.1  # Factor on the transient variance for candidates that are not at their onset
sigma_p2 = 1e1 # Transient noise variance (passed as sigma_t2 to the emission model)

sigma_f2 = 2e-5  # Decay parameter of the exp(-2*pi^2*sigma_f2*tau^2) covariance term
sigma_n2 = 3  # Noise variance added to the covariance diagonal in tonal_LML
wq = [1.0]  # Weights for fundamentals (only wq[0] is read, by covariance_function)
M = 9  # Number of harmonics

# NOTE(review): this replaces the frame rate read from the file above; if the
# recording is not 44.1 kHz every sample<->seconds conversion below is off.
sample_rate = 44100

# Parameters of harmonic_weight: the m-th harmonic is weighted 1 / (1 + T * m**v).
T = 0.5
v = 2.37




def parse_midi(midi_data):    #✅
    """
    Collect (start, end, pitch) for every note of every non-drum instrument,
    ordered by start time. The resulting list is also printed.
    """
    notes = [
        (note.start, note.end, note.pitch)
        for instrument in midi_data.instruments
        if not instrument.is_drum
        for note in instrument.notes
    ]
    # Stable sort: simultaneous notes keep their instrument order.
    notes.sort(key=lambda entry: entry[0])
    print(notes)
    return notes


def compute_transition_matrices(E_Z_list, block_size=2048):
    """
    Build a single left-to-right transition matrix for the whole score.

    Each state is one score event (note/chord). The mean event duration in
    samples is converted to a mean number of blocks, ``d_avg``, and every state
    gets the same geometric duration model:

        self-transition probability:   p = 1 - 1/d_avg   (0 when d_avg <= 1)
        transition to the next state:  1 - p

    The last state is absorbing (self-transition probability 1).

    Args:
        E_Z_list: Expected duration of each state, in samples.
        block_size: Number of samples that one HMM step represents.

    Returns:
        (n_states x n_states) numpy array of transition probabilities. Despite
        the plural name a single matrix is returned. An empty ``E_Z_list``
        yields a 0x0 matrix instead of raising.
    """
    n_states = len(E_Z_list)
    if n_states == 0:
        # Nothing to model; previously this fell through to T[-1, -1] on an
        # empty array and raised IndexError.
        return np.zeros((0, 0))

    d_avg = np.mean(E_Z_list) / block_size
    p = 1 - 1 / d_avg if d_avg > 1 else 0.0

    transitions = np.zeros((n_states, n_states))
    inner = np.arange(n_states - 1)          # every state except the last
    transitions[inner, inner] = p            # stay
    transitions[inner, inner + 1] = 1 - p    # advance to the next event
    transitions[-1, -1] = 1                  # final state is absorbing
    return transitions



def extract_features():   #✅
    """
    Split the recording into two peak-normalised int16 signals.

    The file is loaded with pydub and down-mixed to mono by averaging channels.
    Each Hann-windowed frame (``time_window_size`` samples, advanced by
    ``hop_size``) is transformed with an FFT; its magnitude is replaced by
    ``min(alpha * median_filter(magnitude, window_size), magnitude)`` while the
    phase is kept, and the frames are overlap-added (windowed again) into
    ``processed_audio``. ``tonal_audio`` is the input minus that signal.

    Returns:
        (tonal_audio, processed_audio): two int16 arrays with the same length as
        the mono input; ``main`` uses the second one as the transient signal.
    """
    audio = AudioSegment.from_file('/content/drive/MyDrive/furelise.mp3')
    samples = np.array(audio.get_array_of_samples())

    if audio.channels > 1:   # Stereo: de-interleave and average the channels
        audio_data = samples.reshape((-1, audio.channels)).mean(axis=1)
    else:
        # Work in float64: with integer samples the subtraction at the end of
        # this function would wrap around instead of producing the residual.
        audio_data = samples.astype(np.float64)

    analysis_window = hann(time_window_size)
    processed_audio = np.zeros(len(audio_data))

    for start in range(0, len(audio_data) - time_window_size, hop_size):
        segment = audio_data[start:start + time_window_size].astype(np.float64)
        segment *= analysis_window

        segment_ft = fft(segment)
        magnitude = np.abs(segment_ft)
        phase = np.angle(segment_ft)

        # Median filtering along frequency, scaled by alpha and capped at the
        # original magnitude.
        median_magnitude = alpha * median_filter(magnitude, size=(window_size,))
        filtered_magnitude = np.minimum(median_magnitude, magnitude)

        filtered_segment = ifft(filtered_magnitude * np.exp(1j * phase)).real
        processed_audio[start:start + time_window_size] += filtered_segment * analysis_window

    processed_audio = _to_int16_full_scale(processed_audio)

    # NOTE(review): processed_audio has already been rescaled to full scale
    # here, while audio_data is still in the file's sample scale, so this
    # difference mixes two scales. Behaviour kept as-is because the model
    # parameters were presumably tuned against it - confirm the intent.
    tonal_audio = _to_int16_full_scale(audio_data - processed_audio)

    return tonal_audio, processed_audio


def _to_int16_full_scale(signal):
    """Rescale *signal* so its peak magnitude maps to 32767 and cast to int16."""
    peak = np.max(np.abs(signal)) if signal.size else 0
    if peak == 0:
        # Silent or empty input: avoid 0/0 -> NaN being cast to int16.
        return np.zeros(signal.shape, dtype=np.int16)
    return np.int16(signal / peak * 32767)


def harmonic_weight(m, T, v):   #✅
    """Return the weight 1 / (1 + T * m**v) given to the m-th harmonic."""
    rolloff = T * m**v
    return 1 / (1 + rolloff)


def covariance_function(tau, wq, midi_frequencies, T, v):
    """
    Evaluate the covariance at lag(s) ``tau``.

    The value is exp(-2*pi^2*sigma_f2*tau^2) plus, for every fundamental in
    ``midi_frequencies``, ``wq[0]`` times the sum over harmonics m = 1..M of
    harmonic_weight(m, T, v) * cos(2*pi*m*fq*tau). Note that ``wq[0]`` is applied
    to every fundamental; other entries of ``wq`` are not read.
    """
    cov = np.exp(-2 * np.pi**2 * sigma_f2 * tau**2)
    harmonic_orders = range(1, M + 1)
    for fq in midi_frequencies:
        harmonic_sum = sum(
            harmonic_weight(m, T, v) * np.cos(2 * np.pi * m * fq * tau)
            for m in harmonic_orders
        )
        cov = cov + wq[0] * harmonic_sum
    return cov


def compute_covariance_matrix(midi_frequencies, block_size, sample_rate, T, v):
    """
    Build the (block_size x block_size) covariance matrix of one audio block.

    Entry (i, j) is evaluated at the lag tau = (i - j) / sample_rate and equals
    exp(-2*pi^2*sigma_f2*tau^2) plus, for every fundamental in
    ``midi_frequencies``, the sum over harmonics m = 1..M of
    harmonic_weight(m, T, v) * cos(2*pi*m*fq*tau).

    Args:
        midi_frequencies: Iterable of fundamental frequencies in Hz.
        block_size: Number of samples in the block.
        sample_rate: Samples per second.
        T, v: Parameters forwarded to harmonic_weight.

    Returns:
        numpy array of shape (block_size, block_size).
    """
    # Sample instants n / sample_rate. np.linspace(0, N / fs, N) includes the
    # end point and therefore spaces the samples N / ((N - 1) * fs) apart,
    # which slightly mis-scales every lag.
    t = np.arange(block_size) / sample_rate
    tau = t[:, None] - t[None, :]

    cov = np.exp(-2 * np.pi**2 * sigma_f2 * tau**2)

    for fq in midi_frequencies:
        for m in range(1, M + 1):
            # One harmonic at a time keeps the temporary at N x N instead of
            # materialising an (M, N, N) stack.
            cov += harmonic_weight(m, T, v) * np.cos(2 * np.pi * m * fq * tau)

    return cov


def tonal_LML(y, K, sigma_n2):
    """
    Log marginal likelihood of ``y`` under a zero-mean Gaussian with covariance
    ``K + sigma_n2 * I``, evaluated through a Cholesky factorisation:

        -0.5 * y^T (K + sigma_n2 I)^-1 y - 0.5 * log|K + sigma_n2 I| - 0.5 * N * log(2*pi)
    """
    n_samples = len(y)
    noisy_cov = K + sigma_n2 * np.eye(n_samples)

    factor = cho_factor(noisy_cov, lower=True)
    weights = cho_solve(factor, y)

    # log|A| = 2 * sum(log(diag(L))) for A = L L^T.
    triangular = factor[0]
    log_det = 2 * np.log(np.diag(triangular)).sum()

    data_fit = -0.5 * np.dot(y.T, weights)
    complexity = 0.5 * log_det
    constant = 0.5 * n_samples * np.log(2 * np.pi)
    return data_fit - complexity - constant




def compute_transient_likelihood(transient_block, sigma):
    """
    Log likelihood of ``transient_block`` under independent zero-mean Gaussian
    samples whose variance is ``sigma``.
    """
    n_samples = len(transient_block)
    normaliser = -0.5 * n_samples * np.log(2 * np.pi * sigma)
    # Cast first so squaring integer samples cannot overflow.
    energy = np.sum(transient_block.astype(np.float64)**2)
    return normaliser - 0.5 * energy / sigma

def is_onset_state(state_index, midi_notes, current_time, onset_window=hmm_hop_size/sample_rate):
    """
    Tell whether the note behind ``state_index`` is in its onset (transient) phase.

    True only when the note has already started (current_time >= its start time,
    taken from ``midi_notes[state_index][0]``) and less than ``onset_window``
    seconds have passed since then. The default window is fixed when the
    function is defined, from the module-level hmm_hop_size and sample_rate.
    """
    onset_time = midi_notes[state_index][0]
    if current_time >= onset_time:
        return (current_time - onset_time) < onset_window
    return False

def compute_emission_probabilities(tonal_audio, transient_audio, sample_rate, midi_notes,
                                   block_size=2048, hmm_hop_size=hmm_hop_size,
                                   T_val=T, v_val=v, sigma_n2_val=sigma_n2,
                                   sigma_t2=sigma_p2):
    """
    Compute state-dependent log emission values for every observation block.

    Observations are blocks of ``block_size`` samples taken every
    ``hmm_hop_size`` samples. For each block the "allowed" state is the last
    note whose start time is <= the block's start time (at least state 0).
    Emissions are evaluated only for a window of candidate states:
      - the allowed state and up to 3 states before it,
      - up to 2 states after it.
    All other states stay at -infinity.

    Each candidate's value is the tonal log marginal likelihood (tonal_LML with
    the covariance of the candidate's pitch) plus a Gaussian transient log
    likelihood whose variance is
      - ``sigma_t2`` when the candidate is the allowed state and is_onset_state
        reports it is within its onset window,
      - ``sigma_t2 * eta`` (module-level ``eta``) otherwise.

    Args:
        tonal_audio, transient_audio: Equal-length sample arrays.
        sample_rate: Samples per second, used to convert block starts to seconds.
        midi_notes: List of (start, end, pitch) tuples sorted by start time.
        block_size: Samples per observation block.
        hmm_hop_size: Samples between consecutive block starts.
        T_val, v_val: Harmonic weight parameters for the covariance.
        sigma_n2_val: Noise variance for tonal_LML.
        sigma_t2: Transient variance at onsets.

    Returns:
        (emissions, allowed_max_states): an (n_obs x n_states) array of log
        emission values and an (n_obs,) array with the allowed state index of
        each observation.
    """
    n_states = len(midi_notes)
    midi_starts = [note[0] for note in midi_notes]
    base_frequencies = [440.0 * 2 ** ((pitch - 69) / 12.0) for _, _, pitch in midi_notes]

    emissions = []
    allowed_max_states = []  # one allowed index per observation

    # The covariance matrix depends only on the candidate frequency, not on the
    # audio, so it is reused across observations. Only matrices for the
    # frequencies inside the current window are kept, which bounds memory to a
    # handful of block_size x block_size arrays.
    covariance_cache = {}

    for start in range(0, len(tonal_audio) - block_size, hmm_hop_size):
        tonal_block = tonal_audio[start:start + block_size]
        transient_block = transient_audio[start:start + block_size]
        emission_vector = np.full(n_states, -np.inf)

        current_time = start / sample_rate
        allowed_idx = np.searchsorted(midi_starts, current_time, side='right') - 1
        allowed_idx = max(allowed_idx, 0)
        allowed_max_states.append(allowed_idx)

        lower_bound = max(0, allowed_idx - 3)
        upper_bound = min(n_states, allowed_idx + 2 + 1)  # +1 because range() is exclusive at the end
        window = range(lower_bound, upper_bound)

        window_frequencies = {base_frequencies[i] for i in window}
        for frequency in list(covariance_cache):
            if frequency not in window_frequencies:
                del covariance_cache[frequency]

        for i in window:
            candidate_frequency = base_frequencies[i]
            K = covariance_cache.get(candidate_frequency)
            if K is None:
                K = compute_covariance_matrix([candidate_frequency], block_size, sample_rate, T_val, v_val)
                covariance_cache[candidate_frequency] = K
            log_likelihood_tonal = tonal_LML(tonal_block, K, sigma_n2_val)

            # is_onset_state is called with its default onset window, which is
            # derived from the module-level hmm_hop_size and sample_rate.
            if i == allowed_idx and is_onset_state(i, midi_notes, current_time):
                print(f"Transient detected for state {i} at time {current_time:.2f}s")
                transient_sigma = sigma_t2
            else:
                transient_sigma = sigma_t2 * eta

            log_likelihood_trans = compute_transient_likelihood(transient_block, transient_sigma)
            emission_vector[i] = log_likelihood_tonal + log_likelihood_trans

        print(f"Block starting at sample {start} (time {current_time:.2f}s): allowed max state = {allowed_idx}")
        print(emission_vector)
        emissions.append(emission_vector)

    # Keep the result two-dimensional even when the audio is too short to
    # produce a single observation, so callers can unpack `.shape`.
    emission_matrix = np.array(emissions) if emissions else np.empty((0, n_states))
    return emission_matrix, np.array(allowed_max_states)






def viterbi(transition_matrix, emissions, allowed_max_states):
    """
    Viterbi decoding in log space with a per-observation cap on the state index.

    At observation t only states j <= allowed_max_states[t] may be occupied;
    every other state keeps a score of -infinity. ``transition_matrix`` holds
    probabilities (entries that are not > 0 are treated as forbidden) and
    ``emissions`` holds log values of shape (n_obs, n_states).

    Returns:
        list of int: the highest-scoring state index for every observation.
    """
    n_obs, n_states = emissions.shape

    dp = np.full((n_obs, n_states), -np.inf)
    backpointer = np.zeros((n_obs, n_states), dtype=int)

    # Initialisation: the first observation's emissions, for allowed states only.
    for j in range(n_states):
        if j > allowed_max_states[0]:
            continue
        dp[0, j] = emissions[0, j]

    # Recursion.
    for t in range(1, n_obs):
        previous_scores = dp[t - 1]
        for j in range(n_states):
            if j > allowed_max_states[t]:
                continue  # dp[t, j] is already -inf

            best_score, best_prev = -np.inf, 0
            for i, score in enumerate(previous_scores):
                if score == -np.inf:
                    continue
                weight = transition_matrix[i, j]
                if not weight > 0:
                    continue  # forbidden transition can never beat best_score
                candidate = score + np.log(weight)
                if candidate > best_score:
                    best_score, best_prev = candidate, i

            dp[t, j] = best_score + emissions[t, j]
            backpointer[t, j] = best_prev

    # Backtracking from the best final state.
    state_sequence = np.zeros(n_obs, dtype=int)
    state_sequence[-1] = np.argmax(dp[-1])
    for t in reversed(range(1, n_obs)):
        state_sequence[t - 1] = backpointer[t, state_sequence[t]]
    return state_sequence.tolist()




def log_sum_exp(log_values):
    """
    Compute log(sum(exp(log_values))) in a numerically stable way.

    The maximum is factored out before exponentiating. When there is nothing to
    sum (empty input) or every term is -inf, the result is -inf, i.e. log(0),
    rather than NaN or an exception.
    """
    log_values = np.asarray(log_values, dtype=float)
    if log_values.size == 0:
        return -np.inf

    max_log = np.max(log_values)
    if not np.isfinite(max_log):
        # All terms are -inf (or the maximum is +inf / NaN): subtracting the
        # maximum would give inf - inf = NaN, so return it directly.
        return max_log

    return max_log + np.log(np.sum(np.exp(log_values - max_log)))


def forward_algorithm(transition_matrix, emissions, allowed_max_states):
    """
    Forward pass of the HMM in log space.

    Parameters:
      - transition_matrix: (n_states x n_states) array of transition probabilities.
      - emissions: (n_obs x n_states) array of log emission probabilities.
      - allowed_max_states: (n_obs,) array; at observation t only states with
        index <= allowed_max_states[t] receive a value.

    Returns:
      - alpha: (n_obs x n_states) array of log forward probabilities; entries
        that are disallowed or unreachable are -inf.
    """
    n_obs, n_states = emissions.shape
    log_forward = np.full((n_obs, n_states), -np.inf)

    # Initialisation from the first observation.
    for j in range(n_states):
        if j > allowed_max_states[0]:
            continue
        log_forward[0, j] = emissions[0, j]

    for t in range(1, n_obs):
        previous = log_forward[t - 1]
        for j in range(n_states):
            if j > allowed_max_states[t]:
                continue
            # Log-probabilities of arriving in j from every reachable predecessor.
            incoming = [
                previous[i] + np.log(transition_matrix[i, j])
                for i in range(n_states)
                if previous[i] > -np.inf and transition_matrix[i, j] > 0
            ]
            if incoming:
                log_forward[t, j] = emissions[t, j] + log_sum_exp(np.array(incoming))

    return log_forward


def compute_total_log_likelihood(alpha):
    """Total log-likelihood of the observations: log-sum-exp over the last row of ``alpha``."""
    final_row = alpha[-1, :]
    return log_sum_exp(final_row)




def main(midi_file):
    """
    Run the score follower end to end.

    1. Extract the note list from the score with parse_midi.
    2. Extract tonal and transient signals from the audio (extract_features).
    3. Build a left-to-right transition matrix from the mean note duration.
    4. Compute emission values for a window of states around the score position.
    5. Run the forward algorithm and the modified Viterbi algorithm.

    Args:
        midi_file: Despite the name, an already-loaded score object that
            parse_midi can read (it is passed to parse_midi unchanged), not a
            file path.

    Returns:
        (alpha, total_log_likelihood, state_sequence): the log forward matrix,
        the total log-likelihood of the observations and the Viterbi path.
    """
    midi_notes = parse_midi(midi_file)

    tonal, transient = extract_features()

    sample_rate = 44100

    # Expected duration of every note in samples.
    E_Z_list = [int((end - start) * sample_rate) for start, end, _ in midi_notes]

    # The HMM takes one step every hmm_hop_size samples, so the expected number
    # of steps spent in a note is duration / hmm_hop_size. Using the 2048-sample
    # block length here over-estimated the dwell time whenever the hop differs
    # from the block length.
    transition_matrix = compute_transition_matrices(E_Z_list, block_size=hmm_hop_size)

    emissions, allowed_max_states = compute_emission_probabilities(
        tonal, transient, sample_rate, midi_notes,
        block_size=2048, hmm_hop_size=hmm_hop_size,
        T_val=T, v_val=v, sigma_n2_val=sigma_n2, sigma_t2=sigma_p2
    )

    alpha = forward_algorithm(transition_matrix, emissions, allowed_max_states)
    total_log_likelihood = compute_total_log_likelihood(alpha)
    print("Alpha matrix:", alpha)
    print("Total log-likelihood of the observation sequence:", total_log_likelihood)

    state_sequence = viterbi(transition_matrix, emissions, allowed_max_states)
    print("Most probable state sequence:", state_sequence)

    return alpha, total_log_likelihood, state_sequence




# Run the HMM
# NOTE: main() parses the score itself, so this extra parse_midi call only fills
# the module-level `midi_notes` and prints the note list a second time.
midi_notes = parse_midi(midi)
alpha_matrix, total_ll, viterbi_output = main(midi)


In [None]:
def plot_viterbi_output(midi_data, state_sequence, figsize, sample_rate=44100, hop_size=hmm_hop_size):
    """
    Plot the note onsets of a score together with a decoded state sequence.

    Notes are drawn as blue dots (onset time vs. MIDI pitch). For every
    observation a small 'x' is drawn at the observation's time and at the pitch
    of the note the state refers to. States index the notes of non-drum
    instruments sorted by start time, i.e. the same ordering parse_midi
    produces, so state k is drawn at the pitch of the k-th such note. States
    outside the note list are skipped. Each run of identical states is
    labelled "S=<state>" above the plot.

    Args:
        midi_data (pretty_midi.PrettyMIDI): The score to visualize.
        state_sequence (list or np.array): State index for each observation.
        figsize (tuple): Figure size passed to matplotlib.
        sample_rate (int): Audio sample rate (samples per second).
        hop_size (int): Number of samples between observations in state_sequence.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    try:
        # Same selection and (stable) ordering as parse_midi so that state
        # indices line up with notes even when instrument.notes is not in
        # onset order.
        onsets = sorted(
            ((note.start, note.pitch)
             for instrument in midi_data.instruments
             if not instrument.is_drum
             for note in instrument.notes),
            key=lambda onset: onset[0],
        )
        times = [start for start, _ in onsets]
        notes = [pitch for _, pitch in onsets]

        print(notes)

        if not notes:
            print("No notes found in the MIDI data.")
            return

        plt.figure(figsize=figsize)
        plt.scatter(times, notes, marker='o', color='blue', alpha=0.7)
        plt.xlabel("Time (seconds)")
        plt.ylabel("MIDI Note Number")
        plt.title("MIDI Piano Roll Visualization + Viterbi States")
        plt.yticks(range(min(notes), max(notes) + 1, 2))
        plt.grid(True, linestyle='--', alpha=0.6)

        seconds_per_step = hop_size / sample_rate
        obs_times = np.arange(len(state_sequence)) * seconds_per_step

        colors = ['red', 'green', 'blue']

        for obs_time, state in zip(obs_times, state_sequence):
            if not 0 <= state < len(notes):
                continue  # state does not correspond to a plotted note
            plt.text(obs_time, notes[state], 'x', color=colors[state % len(colors)],
                     fontsize=8, horizontalalignment='center', verticalalignment='center')

        # Collect contiguous runs of the same state as (first_idx, last_idx, state).
        # An empty state_sequence simply yields no segments.
        segments = []
        seg_start = 0
        n_obs = len(state_sequence)
        for i in range(1, n_obs + 1):
            if i == n_obs or state_sequence[i] != state_sequence[seg_start]:
                segments.append((seg_start, i - 1, state_sequence[seg_start]))
                seg_start = i

        ax = plt.gca()
        top_y = max(notes) + 2
        for seg_start_idx, seg_end_idx, state in segments:
            block_start_time = obs_times[seg_start_idx]
            block_end_time = obs_times[seg_end_idx] + seconds_per_step
            mid_time = (block_start_time + block_end_time) / 2.0
            ax.text(mid_time, top_y, f"S={state}",
                    horizontalalignment='center',
                    verticalalignment='bottom',
                    fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='none'))

        plt.show()

    except Exception as e:
        print(f"Error plotting piano roll: {e}")


In [None]:
# Overlay the decoded Viterbi path on the piano roll, using figsize (22, 10).
plot_viterbi_output(midi, viterbi_output, (22,10), sample_rate, hmm_hop_size)