In [28]:
import pandas as pd
import mido
import pretty_midi
from collections import defaultdict


In [3]:
df = pd.read_csv('data/all_more_than_20.csv')
file_path = df[["file_name"]].assign(p='data/nesmdb_midi/all/' + df['file_name'])['p'].to_list()
file_path

['data/nesmdb_midi/all/108_Famista_90_10_11Unknown2.mid',
 'data/nesmdb_midi/all/178_Ironsword_Wizards_amp_WarriorsII_12_13FireElemental.mid',
 'data/nesmdb_midi/all/387_Wizardry_ProvingGroundsofTheMadOverlord_16_17Werdna.mid',
 'data/nesmdb_midi/all/195_KonamiWaiWaiWorld_18_19FinalStageBossBGM.mid',
 'data/nesmdb_midi/all/050_ChaosWorld_20_21Chaos.mid',
 'data/nesmdb_midi/all/135_GanbareGoemonGaiden2_TenkanoZaih__45_46PosanVillageMinatoVillage.mid',
 'data/nesmdb_midi/all/100_Falsion_05_06FlyingHighStage5BGM.mid',
 'data/nesmdb_midi/all/336_TenchioKurauII_ShokatsuKoumeiDen_18_19Cave2.mid',
 'data/nesmdb_midi/all/076_DonkeyKongCountry4_00_01Theme.mid',
 'data/nesmdb_midi/all/217_MajouDensetsuII_DaimashikyouGalious_11_12Ending.mid',
 'data/nesmdb_midi/all/105_FamicomJump_HeroRetsuden_06_07AdventureJumpWorld.mid',
 'data/nesmdb_midi/all/227_MegaMan4_06_07BrightManStage.mid',
 'data/nesmdb_midi/all/314_SummerCarnival_92_Recca_10_11BostuneFinalBoss.mid',
 'data/nesmdb_midi/all/291_Shadowof

In [4]:
midi_file = mido.MidiFile(file_path[0])
midi_file

MidiFile(type=1, ticks_per_beat=22050, tracks=[
  MidiTrack([
    MetaMessage('set_tempo', tempo=500000, time=0),
    MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0),
    MetaMessage('time_signature', numerator=1, denominator=1, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=1131553),
    MetaMessage('end_of_track', time=1)]),
  MidiTrack([
    MetaMessage('track_name', name='p1', time=0),
    Message('program_change', channel=0, program=80, time=0),
    Message('control_change', channel=0, control=12, value=1, time=57),
    Message('note_on', channel=0, note=81, velocity=3, time=1),
    Message('control_change', channel=0, control=11, value=7, time=686),
    Message('control_change', channel=0, control=11, value=10, time=734),
    Message('control_change', channel=0, control=11, value=14, time=734),
    Message('control_change', channel=0, control=11, value=3, time=2220),
    Message('note_on', channel=0, 

In [35]:
def parse_midi_basic_info(midi_path):
    """
    Parses a MIDI file and extracts basic information and track details.
    """
    try:
        mid = mido.MidiFile(midi_path)
    except Exception as e:
        print(f"Error loading MIDI file {midi_path}: {e}")
        return None

    if mid.type != 1:
        print(f"Warning: Expected Type 1 MIDI, got Type {mid.type} for {midi_path}")
        # You might choose to skip or handle differently

    print(f"\n--- Processing: {midi_path} ---")
    print(f"Type: {mid.type}")
    print(f"Ticks per beat: {mid.ticks_per_beat}")
    print(f"Number of tracks: {len(mid.tracks)}")

    track_info = []
    for i, track in enumerate(mid.tracks):
        info = {
            "index": i,
            "name": "Unnamed",
            "program": 0, # Default program
            "message_count": len(track)
        }
        for msg in track: # Look for initial track name and program change
            if msg.is_meta and msg.type == 'track_name':
                info["name"] = msg.name
            if msg.type == 'program_change':
                info["program"] = msg.program
                break # Often the first program change defines the track's main instrument
        track_info.append(info)
        print(f"  Track {i}: Name='{info['name']}', Program={info['program']}, Messages={info['message_count']}")
    return mid, track_info

# --- Step 2: Part/Instrument Role Identification ---
def identify_track_roles(track_info_list, mid_file_path=""):
    """
    Identifies roles for each track based on track name and program change.
    THIS FUNCTION REQUIRES CUSTOMIZATION BASED ON YOUR DATASET.
    """
    roles = {} # track_index -> "ROLE_NAME"
    print(f"\nIdentifying roles for {mid_file_path}:")

    # --- !!! CUSTOMIZE THIS LOGIC EXTENSIVELY !!! ---
    # Example rules (very basic, you need to analyze your dataset):
    for track_info in track_info_list:
        track_idx = track_info["index"]
        track_name = track_info["name"].lower()
        program = track_info["program"]

        # These are placeholder examples.
        if "drum" in track_name or "noise" in track_name or "perc" in track_name or (program >= 112 and program <= 119): # GM Percussion patches
            roles[track_idx] = "NOISE"
        elif "bass" in track_name or (program >= 32 and program <= 39): # GM Bass patches
            roles[track_idx] = "BASS"
        elif "lead" in track_name or "pulse" in track_name or "melody" in track_name or (program >= 80 and program <= 87): # GM Lead patches
            roles[track_idx] = "LEAD_PULSE1"
        elif "harmony" in track_name or "chord" in track_name or "square" in track_name or (program >= 0 and program <= 15) : # GM Piano / Chromatic Perc / Organ
             # This is a very broad category, you might want more specific pulse/square wave roles
            roles[track_idx] = "HARMONY_PULSE2"
        else:
            roles[track_idx] = "OTHER" # Or skip tracks with role "OTHER" later

        print(f"  Track {track_idx} (Name: '{track_info['name']}', Prog: {program}) assigned role: {roles[track_idx]}")
    # --- !!! END CUSTOMIZATION SECTION !!! ---

    return roles

# --- Utility to extract notes from tracks ---
def extract_notes_from_tracks(mid, track_roles):
    """
    Extracts note events from each track, associating them with their identified role.
    Converts note_on/note_off pairs into (pitch, velocity, start_tick, end_tick, role) tuples.
    """
    notes_by_role = defaultdict(list)
    ticks_per_beat = mid.ticks_per_beat

    for track_idx, track in enumerate(mid.tracks):
        current_time_ticks = 0
        active_notes = {} # pitch -> (start_tick, velocity)
        role = track_roles.get(track_idx, "OTHER") # Get assigned role

        if role == "OTHER": # Optionally skip tracks not assigned a useful role
            # print(f"  Skipping track {track_idx} with role OTHER for note extraction.")
            continue

        for msg in track:
            current_time_ticks += msg.time # Accumulate delta time in ticks

            if msg.type == 'note_on' and msg.velocity > 0:
                active_notes[msg.note] = (current_time_ticks, msg.velocity)
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                if msg.note in active_notes:
                    start_tick, velocity = active_notes.pop(msg.note)
                    end_tick = current_time_ticks
                    if end_tick > start_tick: # Ensure duration is positive
                         notes_by_role[role].append({
                            "pitch": msg.note,
                            "velocity": velocity,
                            "start_tick": start_tick,
                            "end_tick": end_tick,
                            "role": role # Keep track of the role for this note
                        })
    return notes_by_role


# --- Step 3: Monophony Enforcement ---
def enforce_monophony(note_list):
    """
    Enforces monophony for a list of notes from a single part/role.
    Assumes note_list is sorted by start_tick.
    Modifies notes in place by adjusting end_tick.
    """
    if not note_list:
        return []

    # Sort by start_tick, then by pitch (as a tie-breaker, though less critical here)
    sorted_notes = sorted(note_list, key=lambda x: (x["start_tick"], x["pitch"]))
    
    monophonic_notes = []
    if not sorted_notes:
        return monophonic_notes

    monophonic_notes.append(dict(sorted_notes[0])) # Start with the first note

    for i in range(1, len(sorted_notes)):
        current_note = dict(sorted_notes[i])
        last_added_note = monophonic_notes[-1]

        # If the current note starts before the last added note ends,
        # truncate the last added note.
        if current_note["start_tick"] < last_added_note["end_tick"]:
            last_added_note["end_tick"] = current_note["start_tick"]
        
        # Add the current note if its duration is positive after potential truncation of previous
        if last_added_note["end_tick"] > last_added_note["start_tick"]: # Ensure previous note is still valid
            monophonic_notes.append(current_note)
        elif monophonic_notes: # If previous note became zero-duration, replace it if this one is valid
             monophonic_notes[-1] = current_note
        else: # If list was empty (should not happen if sorted_notes not empty)
            monophonic_notes.append(current_note)


    # Final pass to remove zero-duration notes that might have resulted
    return [note for note in monophonic_notes if note["end_tick"] > note["start_tick"]]


# --- Step 4: Velocity Processing & Quantization ---
def quantize_velocity(velocity, num_bins=4):
    """
    Quantizes MIDI velocity (0-127) into a smaller number of bins.
    THIS IS A BASIC EXAMPLE. You might want role-specific bins.
    """
    if velocity == 0: return 0 # Typically note_off
    # Ensure velocity is within 1-127 for binning
    velocity = max(1, min(127, velocity))
    bin_size = 127 / num_bins
    # Bin index (0 to num_bins-1)
    quantized_bin = int((velocity -1) / bin_size)
    return quantized_bin # Or return a representative value for the bin, e.g., bin_center


# --- Step 5: Tempo, Key, and Time Signature Handling ---
def extract_meta_info(mid):
    """
    Extracts tempo changes, key signatures, and time signatures.
    Decision on how to USE this info (normalize, tokenize, ignore) is separate.
    """
    tempos = [] # List of (tick, tempo_value_microseconds_per_beat)
    key_signatures = [] # List of (tick, key_name)
    time_signatures = [] # List of (tick, numerator, denominator, clocks_per_click, notated_32nd_notes_per_quarter)

    # Meta messages are often in track 0, but can be in others for Type 1.
    # We need to iterate all tracks and maintain absolute time.
    for track_idx, track in enumerate(mid.tracks):
        current_time_ticks = 0
        for msg in track:
            current_time_ticks += msg.time
            if msg.is_meta:
                if msg.type == 'set_tempo':
                    tempos.append({"tick": current_time_ticks, "tempo_us_per_beat": msg.tempo})
                elif msg.type == 'key_signature':
                    key_signatures.append({"tick": current_time_ticks, "key": msg.key})
                elif msg.type == 'time_signature':
                    time_signatures.append({
                        "tick": current_time_ticks,
                        "numerator": msg.numerator,
                        "denominator": msg.denominator,
                        "clocks_per_click": msg.clocks_per_click, # MIDI clocks per metronome click
                        "notated_32nd_notes_per_quarter": msg.notated_32nd_notes_per_quarter # Number of 32nd notes in a MIDI quarter note
                    })
    
    # Sort by tick time as messages can come from different tracks
    tempos.sort(key=lambda x: x["tick"])
    key_signatures.sort(key=lambda x: x["tick"])
    time_signatures.sort(key=lambda x: x["tick"])

    print("\nMeta Information:")
    if not tempos: print("  No tempo changes found. Assume default 120 BPM (500000 us/beat).")
    else: print(f"  Tempos: {tempos}")
    
    if not key_signatures: print("  No key signatures found.")
    else: print(f"  Key Signatures: {key_signatures}")

    if not time_signatures: print("  No time signatures found. Assume default 4/4.")
    else: print(f"  Time Signatures: {time_signatures}")
        
    return tempos, key_signatures, time_signatures


# --- Main Processing Function to Orchestrate Steps ---
def preprocess_midi_file(midi_path, num_velocity_bins=4):
    """
    Orchestrates the preprocessing steps for a single MIDI file.
    """
    # Step 1
    mid, track_info_list = parse_midi_basic_info(midi_path)
    if mid is None:
        return None

    # Step 2
    # !!! CRITICAL: You MUST customize identify_track_roles for your data !!!
    track_roles = identify_track_roles(track_info_list, midi_path)
    
    # Utility: Extract notes based on these roles
    notes_by_role_raw = extract_notes_from_tracks(mid, track_roles)

    # Step 3 & 4 (Applied per role)
    processed_notes_all_roles = []
    roles_to_make_monophonic = ["LEAD_PULSE1", "HARMONY_PULSE2", "BASS"] # Customize this list!

    print("\nProcessing notes per role:")
    for role, notes in notes_by_role_raw.items():
        print(f"  Role: {role}, Raw note count: {len(notes)}")
        if not notes:
            continue

        # Step 3: Monophony
        current_role_notes = notes # Use raw notes for the role
        if role in roles_to_make_monophonic:
            print(f"    Enforcing monophony for role: {role}")
            current_role_notes = enforce_monophony(current_role_notes) # Pass a copy if needed
            print(f"    Note count after monophony: {len(current_role_notes)}")

        # Step 4: Velocity Quantization
        for note in current_role_notes:
            note["quantized_velocity"] = quantize_velocity(note["velocity"], num_velocity_bins)
        
        processed_notes_all_roles.extend(current_role_notes)

    # Sort all processed notes globally by start time, then pitch, then role (for stable order)
    # This global list of notes is what you'd feed into a Step 6 (Tokenization)
    processed_notes_all_roles.sort(key=lambda x: (x["start_tick"], x["pitch"], x["role"]))
    print(f"\nTotal processed notes for {midi_path}: {len(processed_notes_all_roles)}")
    # for note in processed_notes_all_roles[:10]: # Print a few samples
    #     print(f"    {note}")

    # Step 5
    tempos, key_signatures, time_signatures = extract_meta_info(mid)

    # At this point, you have:
    # - processed_notes_all_roles: A list of note dicts, each with pitch, start_tick, end_tick,
    #                              original_velocity, quantized_velocity, and role.
    # - tempos, key_signatures, time_signatures: Lists of meta events.
    #
    # The next step (not covered here) would be to convert this structured data
    # into a single sequence of tokens for your Transformer. This would involve
    # creating "Note_On_Role_Pitch", "Duration", "Velocity" tokens, and potentially
    # "Tempo", "Key", "TimeSignature" tokens, all ordered correctly in time.

    return {
        "file_path": midi_path,
        "ticks_per_beat": mid.ticks_per_beat,
        "processed_notes": processed_notes_all_roles,
        "tempos": tempos,
        "key_signatures": key_signatures,
        "time_signatures": time_signatures
    }
