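Inspired by HW 3:
Markov chain that computes p(beat_length | previous_beat_length, beat_position).
The code below adapts that idea to the two hands of a piano piece: it estimates p(left-hand pitch | right-hand pitch) from onset-aligned notes.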

In [1]:
import os
from glob import glob
import numpy as np
from collections import defaultdict, Counter
from pretty_midi import PrettyMIDI
from midiutil import MIDIFile
from sklearn.model_selection import train_test_split

#######################
# DATA HELPER FUNCTIONS
#######################

def extract_aligned_notes(midi_file, right_idx=0, left_idx=1):
    """
    Pair each right-hand note with a left-hand note that starts at (almost)
    the same time.

    Args:
        midi_file: whatever ``PrettyMIDI`` accepts as its first argument.
        right_idx: index of the RH track in ``midi.instruments``.
        left_idx: index of the LH track in ``midi.instruments``.

    Returns:
        One ``(rh_pitch, lh_pitch, start, duration, velocity)`` tuple per RH
        note, in RH track order. ``start`` is the RH onset rounded to 3
        decimals; ``duration`` and ``velocity`` come from the RH note.
        ``lh_pitch`` is None when no unused LH note starts within the
        matching window. Returns ``[]`` when the file has too few
        instruments for the requested indices.
    """
    instruments = PrettyMIDI(midi_file).instruments
    if max(right_idx, left_idx) >= len(instruments):
        return []
    right_notes = instruments[right_idx].notes
    left_notes = instruments[left_idx].notes

    # Bucket the LH notes by onset rounded to 3 decimals.
    unused_left = defaultdict(list)
    for left_note in left_notes:
        unused_left[round(left_note.start, 3)].append(left_note)

    aligned = []
    for right_note in right_notes:
        onset = round(right_note.start, 3)
        length = right_note.end - right_note.start

        # Take the first still-unused LH note (buckets are visited in
        # insertion order) whose rounded onset is strictly within 0.03 of
        # this RH onset -- the "+-30ms window" of the original design.
        partner = None
        for left_onset, bucket in unused_left.items():
            if bucket and abs(left_onset - onset) < 0.03:
                partner = bucket[0]
                break

        if partner is None:
            aligned.append(
                (right_note.pitch, None, onset, length, right_note.velocity)
            )
            continue

        aligned.append(
            (right_note.pitch, partner.pitch, onset, length, right_note.velocity)
        )
        # Consume the LH note so it cannot be matched to a later RH note.
        unused_left[round(partner.start, 3)].remove(partner)
    return aligned

def extract_parallel_sequences(lh_files, rh_files):
    """
    Extract onset-aligned RH/LH sequences from corresponding MIDI file pairs.

    Files are paired positionally (``lh_files[i]`` with ``rh_files[i]``) and
    only the first instrument of each file is read. Notes are keyed by their
    start time rounded to 3 decimals, so when several notes of one hand share
    a rounded onset only the last one in track order is kept.

    Args:
        lh_files: iterable of left-hand MIDI paths.
        rh_files: iterable of right-hand MIDI paths, in matching order.

    Returns:
        A list with one ``(rh_pitches, lh_pitches, durations, velocities)``
        tuple per usable pair. The four lists are parallel, with one entry
        per distinct RH onset in increasing onset order. ``durations`` and
        ``velocities`` are taken from the RH note; ``lh_pitches[i]`` is None
        when no LH note has exactly the same rounded onset.

    Pairs in which either file has no instruments are skipped silently.
    Pairs that raise while being loaded or processed are skipped as well,
    but a ``RuntimeWarning`` naming the files is emitted.
    """
    import warnings  # local import keeps this block self-contained

    pairs = []
    for lh_file, rh_file in zip(lh_files, rh_files):
        try:
            rh_midi = PrettyMIDI(rh_file)
            if len(rh_midi.instruments) < 1:
                continue
            rh_notes = rh_midi.instruments[0].notes

            lh_midi = PrettyMIDI(lh_file)
            if len(lh_midi.instruments) < 1:
                continue
            lh_notes = lh_midi.instruments[0].notes

            # Index both hands by quantized onset; a later note with the same
            # rounded onset overwrites an earlier one.
            rh_by_start = {round(n.start, 3): n for n in rh_notes}
            lh_by_start = {round(n.start, 3): n for n in lh_notes}

            rhs, lhs, durations, velocities = [], [], [], []
            for st in sorted(rh_by_start):
                rh_note = rh_by_start[st]
                lh_note = lh_by_start.get(st)
                rhs.append(rh_note.pitch)
                durations.append(rh_note.end - rh_note.start)
                velocities.append(rh_note.velocity)
                lhs.append(lh_note.pitch if lh_note is not None else None)
            pairs.append((rhs, lhs, durations, velocities))
        except Exception as exc:
            # Best-effort batch processing: one unreadable or malformed file
            # must not abort the run, but report which pair was dropped and
            # why instead of hiding the failure.
            warnings.warn(
                f"Skipping MIDI pair ({rh_file!r}, {lh_file!r}): {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            continue
    return pairs

####################
# MODEL TRAINING
####################

def train_lh_given_rh(pairs):
    """
    Estimate P(LH pitch | RH pitch) from onset-aligned note pairs.

    Each LH pitch is conditioned on the RH pitch sounding at the same aligned
    position; positions whose LH entry is None are ignored.

    Args:
        pairs: iterable of 4-tuples ``(rh_seq, lh_seq, durations, velocities)``
            as produced by ``extract_parallel_sequences``; only the first two
            items are used.

    Returns:
        A dict ``{rh_pitch: (lh_pitches, probabilities)}`` where
        ``lh_pitches`` is a list of the LH pitches observed with that RH
        pitch (in first-seen order) and ``probabilities`` is a float NumPy
        array of the same length that sums to 1. RH pitches never observed
        together with an LH note are absent from the dict.
    """
    counts = defaultdict(Counter)
    for rh_seq, lh_seq, _, _ in pairs:
        for rh, lh in zip(rh_seq, lh_seq):
            if lh is not None:
                counts[rh][lh] += 1

    # Normalize the co-occurrence counts into one distribution per RH pitch.
    # Every counter here holds at least one positive count, so the sum is
    # never zero.
    probs = {}
    for rh, counter in counts.items():
        lh_pitches = list(counter)
        weights = np.array(list(counter.values()), dtype=float)
        probs[rh] = (lh_pitches, weights / weights.sum())
    return probs

####################
# LH GENERATION
####################

def generate_left_hand_given_rh_sequence(
    rh_seq, probs, random_state=None, default_pitch=36
):
    """
    Sample one left-hand pitch for every right-hand pitch in ``rh_seq``.

    Args:
        rh_seq: iterable of RH pitches.
        probs: mapping ``{rh_pitch: (lh_pitches, probabilities)}`` as returned
            by ``train_lh_given_rh``.
        random_state: seed passed to ``np.random.default_rng``.
        default_pitch: pitch emitted for RH pitches missing from ``probs``.

    Returns:
        A list with one LH pitch per RH pitch, in order.
    """
    generator = np.random.default_rng(random_state)
    generated = []
    for rh_pitch in rh_seq:
        if rh_pitch not in probs:
            # Unseen RH pitch: fall back to the fixed default (low C).
            generated.append(default_pitch)
            continue
        candidates, weights = probs[rh_pitch]
        generated.append(generator.choice(candidates, p=weights))
    return generated

####################
# MIDI FILE OUTPUT
####################

def write_midi(rh_seq, rh_durations, lh_seq, filename, tempo=60):
    """
    Write a two-track MIDI file: RH on track 0 / channel 0, LH on track 1 /
    channel 1.

    Notes of each hand are laid out back to back starting at time 0; the
    i-th note of BOTH hands lasts ``rh_durations[i]`` (0.5 when the
    duration list is shorter than the sequence). RH notes use velocity 100,
    LH notes velocity 90.

    Args:
        rh_seq: RH pitches (anything ``int()`` accepts).
        rh_durations: per-note durations shared by both hands.
            NOTE(review): MIDIFile presumably interprets time/duration in
            beats, so durations given in seconds are only faithful at the
            default tempo of 60 -- confirm against the MIDIUtil docs.
        lh_seq: LH pitches (anything ``int()`` accepts).
        filename: output path; missing parent directories are created.
        tempo: tempo passed to ``addTempo`` for both tracks.
    """
    midi = MIDIFile(2)  # Track 0: RH, Track 1: LH
    midi.addTempo(0, 0, tempo)
    midi.addTempo(1, 0, tempo)

    # (track/channel number, pitches, velocity) for each hand.
    for track, pitches, velocity in ((0, rh_seq, 100), (1, lh_seq, 90)):
        time = 0.0
        for i, pitch in enumerate(pitches):
            dur = rh_durations[i] if i < len(rh_durations) else 0.5
            midi.addNote(track, track, int(pitch), time, dur, velocity)
            time += dur

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError -- only create the directory when there is one.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(filename, "wb") as out:
        midi.writeFile(out)

#####################
# MAIN PIPELINE
#####################

# Collect file lists. Both globs are sorted so that the i-th LH file is paired
# with the i-th RH file below.
# NOTE(review): this assumes the two directories contain matching file names;
# zip() stops at the shorter list and would pair unrelated pieces without any
# error if they do not -- confirm against the dataset layout.
data_left_hand = sorted(glob("data_left_hand/*.mid"))
data_right_hand = sorted(glob("data_right_hand/*.mid"))

# Train/test split over (LH file, RH file) pairs, so both hands of a piece
# always land in the same split. The fixed random_state keeps it reproducible.
file_pairs = list(zip(data_left_hand, data_right_hand))
train_pairs, test_pairs = train_test_split(file_pairs, test_size=0.2, random_state=42)
# Unzip the pairs back into parallel (LH files, RH files) tuples.
lh_train, rh_train = zip(*train_pairs)
lh_test, rh_test = zip(*test_pairs)

# Align and extract (rh_pitches, lh_pitches, durations, velocities) per piece
train_sequences = extract_parallel_sequences(lh_train, rh_train)
test_sequences = extract_parallel_sequences(lh_test, rh_test)

# Train probabilistic model: P(LH pitch | RH pitch) from the training pieces
probs = train_lh_given_rh(train_sequences)

# Synthesize new LH for each test RH piece
OUTPUT_DIR = "generated_lh_by_rh"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# idx numbers the successfully extracted test sequences, not the test files:
# any pair skipped during extraction shifts the numbering. The true LH part
# and the velocities of each test piece are ignored here.
for idx, (rh_seq, _, durations, _) in enumerate(test_sequences):
    # Seeding with idx makes each piece's generated LH reproducible.
    lh_seq = generate_left_hand_given_rh_sequence(rh_seq, probs, random_state=idx)
    # Output MIDI: original RH plus generated LH, both using the RH durations
    basename = f"test_{idx:03d}_synth.mid"
    out_path = os.path.join(OUTPUT_DIR, basename)
    write_midi(rh_seq, durations, lh_seq, out_path)
    print(f"Wrote {out_path}")

print("Done.")

######################
# Tips for improvement
######################
# - You can easily expand this to train P(LH | prev_LH, curr_RH) for a higher-order Markov model.
# - Experiment with using more expressive features (chord labels, intervals, etc) for musicality.
# - Instead of quantizing by start time alone, you could try precise matching in polyphonic passages.

Wrote generated_lh_by_rh/test_000_synth.mid
Wrote generated_lh_by_rh/test_001_synth.mid
Wrote generated_lh_by_rh/test_002_synth.mid
Wrote generated_lh_by_rh/test_003_synth.mid
Wrote generated_lh_by_rh/test_004_synth.mid
Wrote generated_lh_by_rh/test_005_synth.mid
Wrote generated_lh_by_rh/test_006_synth.mid
Wrote generated_lh_by_rh/test_007_synth.mid
Wrote generated_lh_by_rh/test_008_synth.mid
Wrote generated_lh_by_rh/test_009_synth.mid
Wrote generated_lh_by_rh/test_010_synth.mid
Wrote generated_lh_by_rh/test_011_synth.mid
Wrote generated_lh_by_rh/test_012_synth.mid
Wrote generated_lh_by_rh/test_013_synth.mid
Wrote generated_lh_by_rh/test_014_synth.mid
Wrote generated_lh_by_rh/test_015_synth.mid
Wrote generated_lh_by_rh/test_016_synth.mid
Wrote generated_lh_by_rh/test_017_synth.mid
Wrote generated_lh_by_rh/test_018_synth.mid
Wrote generated_lh_by_rh/test_019_synth.mid
Wrote generated_lh_by_rh/test_020_synth.mid
Wrote generated_lh_by_rh/test_021_synth.mid
Wrote generated_lh_by_rh/test_02