In [None]:
import os
import re
import numpy as np
from collections import defaultdict
import miditoolkit
from miditoolkit import pianoroll
from miditoolkit.midi import parser as mid_parser
from miditoolkit.pianoroll import parser as pr_parser
from miditoolkit.pianoroll import utils

class Item(object):
    """Plain record describing one event read from a MIDI file.

    Attributes:
        name: Kind of event (this file uses 'Note' and 'Tempo').
        start: Start position of the event.
        end: End position of the event, or None when not applicable.
        velocity: Note velocity, or None when not applicable.
        pitch: Note pitch; tempo items store their tempo value here.
    """

    def __init__(self, name, start, end, velocity, pitch):
        self.name, self.start, self.end = name, start, end
        self.velocity, self.pitch = velocity, pitch

    def __repr__(self):
        # Keep the field order identical to the constructor signature.
        return (
            f'Item(name={self.name}, start={self.start}, end={self.end}, '
            f'velocity={self.velocity}, pitch={self.pitch})'
        )

# parameters for output
# Tick step used by read_items() when it expands the tempo items onto a
# regular grid (one tempo item every DEFAULT_RESOLUTION ticks).
DEFAULT_RESOLUTION = 480

def read_items(file_path):
    """Read note and tempo items from a MIDI file.

    Only the notes of the first instrument in the file are read. Tempo
    changes are re-sampled onto a regular grid running from tick 0 up to the
    tick of the last tempo change, in steps of DEFAULT_RESOLUTION. A grid
    tick that coincides with a tempo change takes that tempo; any other grid
    tick repeats the previous grid tempo (or the first tempo change when
    there is no previous grid entry).

    Args:
        file_path: Path of the MIDI file to parse.

    Returns:
        A tuple (note_items, tempo_items) of Item lists, each ordered by
        start tick. Tempo items carry the integer tempo in `pitch`.
    """
    midi_obj = miditoolkit.midi.parser.MidiFile(file_path)

    # notes (first instrument only), ordered by start tick then pitch
    first_track_notes = midi_obj.instruments[0].notes
    first_track_notes.sort(key=lambda n: (n.start, n.pitch))
    note_items = [
        Item(name='Note', start=n.start, end=n.end, velocity=n.velocity, pitch=n.pitch)
        for n in first_track_notes
    ]
    note_items.sort(key=lambda item: item.start)

    # tempo changes as stored in the file
    tempo_items = [
        Item(name='Tempo', start=change.time, end=None, velocity=None,
             pitch=int(change.tempo))
        for change in midi_obj.tempo_changes
    ]
    tempo_items.sort(key=lambda item: item.start)

    if not tempo_items:
        return note_items, tempo_items

    # expand to all beats
    last_tick = tempo_items[-1].start
    tempo_at_tick = {item.start: item.pitch for item in tempo_items}
    expanded = []
    for tick in np.arange(0, last_tick + 1, DEFAULT_RESOLUTION):
        if tick in tempo_at_tick:
            tempo_value = tempo_at_tick[tick]
        elif expanded:
            # No change on this grid tick: carry the previous tempo forward.
            tempo_value = expanded[-1].pitch
        else:
            tempo_value = tempo_items[0].pitch
        expanded.append(Item(
            name='Tempo', start=tick, end=None, velocity=None, pitch=tempo_value))
    return note_items, expanded

class CustomRemiPreprocessing:
    def __init__(self, input_directory, output_directory, check_tempo_change=True):
        self.input_directory = input_directory
        self.output_directory = output_directory
        self.check_tempo_change = check_tempo_change  # Whether to check for tempo changes
        self.ticks_per_beat = 480  # Default number of ticks used in MIDI files
        self.pitch_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']  # List of note names

    def process_directory(self):
        """Process every 3-digit sub-directory and write its note and REMI files.

        For each sub-directory `NNN` of `self.input_directory` this reads
        `NNN/NNN.mid` and `NNN/beat_midi.txt` (the chord and key helpers read
        `chord_midi.txt` and `key_audio.txt`), writes `NNN_note.txt` to
        `self.output_directory`, and then writes `NNN_remi.txt` containing the
        tempo block followed by the bar, key, chord and note events ordered
        by tick.

        A sub-directory is skipped (after its note file has been written)
        when `self.check_tempo_change` is set and more than one tempo item is
        found, or when no note events were produced.
        """
        # Sorted so the processing order (and the log) is deterministic.
        for subdir in sorted(os.listdir(self.input_directory)):
            # Only process directories with names that are 3-digit numbers
            if not (len(subdir) == 3 and subdir.isdigit()):
                continue
            subdir_path = os.path.join(self.input_directory, subdir)
            if not os.path.isdir(subdir_path):
                continue

            midi_file_path = os.path.join(subdir_path, f'{subdir}.mid')
            beat_file_path = os.path.join(subdir_path, 'beat_midi.txt')

            # Extract BPM from the beat annotations
            bpm = self.calculate_bpm_from_beat_file(beat_file_path)

            # Extract and save note information
            notes = self.extract_notes_from_midi(midi_file_path, bpm)
            output_file_path = os.path.join(self.output_directory, f'{subdir}_note.txt')
            self.save_notes_to_txt(notes, output_file_path)
            print(f"Processed and saved: {output_file_path}")

            # Extract tempo information
            _, tempo_items = read_items(midi_file_path)

            # Check if there are multiple tempo changes and decide whether to skip
            if self.check_tempo_change and len(tempo_items) > 1:
                print(f"Skipping {subdir} due to multiple tempo changes.")
                continue

            bpm = self.get_tempo_from_midi(midi_file_path, beat_file_path)
            grid, t0, dt = self.create_grid(beat_file_path)

            # Create Tempo block at the top
            tempo_block = self.create_tempo_block(bpm)

            note_dict = self.create_note_blocks(subdir, grid, t0, dt, bpm)
            if not note_dict:
                # max() below would raise ValueError on an empty mapping and
                # abort the whole batch; skip just this entry instead.
                print(f"Skipping {subdir} because no notes were found.")
                continue
            bar_dict = self.create_bar_blocks(grid, t0, dt, max(note_dict))
            chord_dict = self.create_chord_blocks(subdir, grid, t0, dt)
            key_dict = self.create_key_blocks(subdir, grid, t0, dt)

            # Merge the events tick by tick in the order bar -> key -> chord -> note.
            all_ticks = sorted(
                set(bar_dict) | set(key_dict) | set(chord_dict) | set(note_dict))
            lines = tempo_block.splitlines()
            for tick in all_ticks:
                for blocks in (bar_dict, key_dict, chord_dict, note_dict):
                    # Membership test (not indexing) so the defaultdicts are
                    # not populated with empty entries.
                    if tick in blocks:
                        lines.extend(blocks[tick].splitlines())

            # Drop empty lines left over from the block separators.
            output_content = "\n".join(line for line in lines if line.strip() != "")

            # Write to REMI file (Tempo block first)
            output_file_path = os.path.join(self.output_directory, f'{subdir}_remi.txt')
            with open(output_file_path, 'w') as output_file:
                output_file.write(output_content)

            print(f"Processed and saved: {output_file_path}")

    def convert_note_name(self, note):
        """
        Map an enharmonic spelling (e.g., B#, Db) to its sharp-based name (e.g., C, C#).

        Names without an entry in the table are returned unchanged.
        """
        enharmonic_equivalents = {
            'B#': 'C',
            'Db': 'C#',
            'Eb': 'D#',
            'Fb': 'E',
            'E#': 'F',
            'Gb': 'F#',
            'Ab': 'G#',
            'Bb': 'A#',
            'Cb': 'B',
        }
        if note in enharmonic_equivalents:
            return enharmonic_equivalents[note]
        return note

    def get_tempo_from_midi(self, midi_file_path, beat_file_path=None):
        """Return the tempo for a MIDI file, with fallbacks.

        The first tempo item read from the MIDI file wins. Without any tempo
        item the BPM is estimated from `beat_file_path` when one is given,
        and otherwise 120 is returned.
        """
        _, tempo_items = read_items(midi_file_path)
        if not tempo_items:
            # No tempo item in the MIDI file: estimate from beats or default to 120 BPM.
            if beat_file_path:
                return self.calculate_bpm_from_beat_file(beat_file_path)
            return 120
        return tempo_items[0].pitch

    def calculate_bpm_from_beat_file(self, beat_file_path):
        """Estimate the BPM from the beat times listed in a beat file.

        Only the first whitespace-separated column of each line is used and
        is read as a beat time in seconds; blank lines are ignored. The BPM
        is 60 divided by the mean interval between consecutive beat times.

        Args:
            beat_file_path: Path of the beat annotation file.

        Returns:
            The estimated BPM (a NumPy float).

        Raises:
            ValueError: If the file holds fewer than two beat times (no
                interval can be formed) or a time cannot be parsed as float.
        """
        beat_times = []
        with open(beat_file_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines such as a trailing empty line.
                    continue
                beat_times.append(float(fields[0]))

        if len(beat_times) < 2:
            # The mean of zero intervals is NaN, which would only surface
            # later as an obscure conversion error; fail here instead.
            raise ValueError(
                f"Need at least two beat times to estimate BPM: {beat_file_path}")

        average_interval = np.mean(np.diff(beat_times))  # One interval is one beat
        return 60 / average_interval  # Seconds per beat -> beats per minute

    def create_grid(self, beat_midi_path):
        """Create quantization grid based on beat_midi.txt."""
        with open(beat_midi_path, 'r') as f:
            lines = f.readlines()
        t0 = float(lines[0].split()[0])
        dt = 60 / self.calculate_bpm_from_beat_file(beat_midi_path) * 0.25

        grid = []
        for i in range(2048):  # Sufficient grid length
            grid.append(t0 + i * dt)
        return grid, t0, dt

    def find_nearest_quantized_time(self, target_time, grid):
        """Return the grid time with the smallest absolute distance to target_time.

        On a tie the earliest such entry in `grid` is returned.
        """
        def distance_to_target(grid_time):
            return abs(grid_time - target_time)

        return min(grid, key=distance_to_target)

    def calculate_position_value(self, quantized_time, t0, dt):
        """Calculate the position value in n/16 format for a quantized time.

        Args:
            quantized_time: A time lying on the grid `t0 + i * dt`.
            t0: Time of the first grid point.
            dt: Grid step.

        Returns:
            The string "k/16" with k cycling through 1..16 as the grid index
            increases.
        """
        # Round rather than truncate: floating-point error can put the
        # quotient a hair below the true grid index (e.g. 0.9999999999999998
        # instead of 1), and int() alone would then land one step early.
        n = int(round((quantized_time - t0) / dt))
        return f"{n % 16 + 1}/16"

    def convert_time_to_ticks(self, quantized_time, t0, dt):
        """Convert quantized time to ticks based on quantization."""
        n = int((quantized_time - t0) / dt)
        ticks = int((n + 1) * self.ticks_per_beat * 0.125)  # Quantized time (n+1)/16 * 60 ticks
        return ticks

    # Tempo block creation method
    def create_tempo_block(self, bpm):
        """Build the two-line tempo header for the REMI output.

        The tempo class is "slow" below 60 BPM, "mid" from 60 up to (but not
        including) 120 BPM, and "fast" otherwise. The tempo value is the BPM
        truncated to an integer. Both events are placed at time 0.
        """
        tempo_class = "slow" if bpm < 60 else ("mid" if bpm < 120 else "fast")

        header_lines = [
            f"Event(name=Tempo Class, time=0, value={tempo_class}, text=None)",
            f"Event(name=Tempo Value, time=0, value={int(bpm)}, text=None)",
        ]
        return "\n".join(header_lines)

    def create_note_blocks(self, subdir, grid, t0, dt, bpm):
        """Function to create Note blocks"""
        note_file_path = os.path.join(self.output_directory, f'{subdir}_note.txt')
        note_dict = defaultdict(str)
        
        with open(note_file_path, 'r') as note_file:
            for line in note_file:
                # First split by comma and remove spaces
                parts = [part.strip() for part in line.strip().split(',')]
                
                # Then split each part by whitespace
                if len(parts) != 4:
                    print(f"Skipping line in {note_file_path} due to incorrect format: {line.strip()}")
                    continue  # Skip the line
                
                velocity, start_time, duration_in_sec = parts[:3]
                duration_in_sec = float(duration_in_sec)
                duration = round(duration_in_sec * (bpm / 60), 3) # in beats
                pitch_note = parts[3].split()

                # pitch is a number, note_name is the note name
                pitch = pitch_note[0]
                note_name = pitch_note[1] if len(pitch_note) > 1 else pitch
                note_name = note_name[1:-1] # strip ()

                start_time = float(start_time)  # Convert start time to seconds
                quantized_time = self.find_nearest_quantized_time(start_time, grid)
                tick_time = self.convert_time_to_ticks(quantized_time, t0, dt)
                position_value = self.calculate_position_value(quantized_time, t0, dt)

                # Create Note block
                note_block = (
                    f"Event(name=Position, time={tick_time}, value={position_value}, text={tick_time})\n"
                    f"Event(name=Note Velocity, time={tick_time}, value={velocity}, text={velocity})\n"
                    f"Event(name=Note On, time={tick_time}, value={pitch}, text={note_name})\n"
                    f"Event(name=Note Duration, time={tick_time}, value={duration}, text={duration})"
                )

                # Save the Note block in the dictionary with the corresponding time as the key
                note_dict[tick_time] += "\n" + note_block
        
        return note_dict

    def create_bar_blocks(self, grid, t0, dt, max_note_time):
        """Function to create Bar blocks"""
        bar_dict = defaultdict(str)
        N = 1
        while True:
            bar_time = t0 + (N - 1) * (16 * dt)
            tick_time = self.convert_time_to_ticks(bar_time, t0, dt)
            
            # Stop if a time greater than the last note time appears
            if tick_time > max_note_time:
                break
            
            bar_block = f"Event(name=Bar, time={tick_time}, value={tick_time}, text={N})"
            bar_dict[tick_time] = bar_block
            N += 1
        return bar_dict

    def create_chord_blocks(self, subdir, grid, t0, dt):
        """Function to create Chord blocks"""
        chord_file_path = os.path.join(self.input_directory, subdir, 'chord_midi.txt')
        chord_dict = defaultdict(str)
        
        with open(chord_file_path, 'r') as f:
            for line in f:
                start_time, _, chord = line.strip().split()
                if chord == 'N':
                    continue
                
                # Separate and standardize the root
                root = chord.split(':')[0]
                root_standardized = self.convert_note_name(root)  # Apply conversion
    
                # Standardize chord type (handle cases like 7 -> dom, hdim -> dim, etc.)
                chord_type = chord.split(':')[1]
                if chord_type.startswith('7'):
                    chord_type = 'dom'
                elif chord_type.startswith('hdim'):
                    chord_type = 'dim'
                elif chord_type.startswith('sus'):
                    chord_type = 'maj'
                else:
                    chord_type = chord_type[:3]  # Only take the first 3 characters for chord type
                
                # Recreate chord name with the converted root
                chord_name = root_standardized + ':' + chord_type
    
                start_time = float(start_time)
                quantized_time = self.find_nearest_quantized_time(start_time, grid)
                tick_time = self.convert_time_to_ticks(quantized_time, t0, dt)
                position_value = self.calculate_position_value(quantized_time, t0, dt)
    
                # Create Chord block
                chord_block = (
                    f"Event(name=Position, time={tick_time}, value={position_value}, text={tick_time})\n"
                    f"Event(name=Chord, time={tick_time}, value={chord_name}, text={chord_name})"
                )
                chord_dict[tick_time] += chord_block
        return chord_dict


    def create_key_blocks(self, subdir, grid, t0, dt):
        """Function to create Key blocks"""
        key_file_path = os.path.join(self.input_directory, subdir, 'key_audio.txt')
        key_dict = defaultdict(str)
        
        with open(key_file_path, 'r') as f:
            for line in f:
                start_time, _, key = line.strip().split()
                
                # Separate and standardize the root
                root = key.split(':')[0]
                root_standardized = self.convert_note_name(root)  # Apply conversion
                key_type = key.split(':')[1]  # Major, Minor, etc.
    
                # Recreate key name with the converted root
                key_name = root_standardized + ':' + key_type
    
                start_time = float(start_time)
                quantized_time = self.find_nearest_quantized_time(start_time, grid)
                tick_time = self.convert_time_to_ticks(quantized_time, t0, dt)
                position_value = self.calculate_position_value(quantized_time, t0, dt)
    
                # Create Key block
                key_block = (
                    f"Event(name=Position, time={tick_time}, value={position_value}, text={tick_time})\n"
                    f"Event(name=Key, time={tick_time}, value={key_name}, text={key_name})"
                )
                key_dict[tick_time] += key_block
        return key_dict


    def midi_to_note_name(self, pitch):
        """Convert MIDI pitch value to actual note name and octave."""
        note_name = self.pitch_names[pitch % 12]
        octave = (pitch // 12) - 1  # Calculate octave, C4 = MIDI pitch 60
        return f"{note_name}{octave}"

    def extract_notes_from_midi(self, midi_file_path, bpm):
        """Extract the notes of every instrument in a MIDI file.

        Tick positions are converted to seconds assuming the constant tempo
        `bpm`, using the tick resolution declared by the MIDI file itself
        (falling back to `self.ticks_per_beat` when the parsed object does
        not provide one).

        Args:
            midi_file_path: Path of the MIDI file.
            bpm: Tempo used for the tick -> second conversion.

        Returns:
            A list of tuples (velocity, start_seconds, duration_seconds,
            pitch, pitch_name) sorted by start time.
        """
        midi_obj = miditoolkit.midi.parser.MidiFile(midi_file_path)

        # A hard-coded resolution mis-times every note of a file written
        # with a different ticks-per-beat value, so prefer the file's own.
        ticks_per_beat = getattr(midi_obj, 'ticks_per_beat', None) or self.ticks_per_beat
        seconds_per_beat = 60 / bpm

        notes = []
        for instrument in midi_obj.instruments:
            for note in instrument.notes:
                start_time_sec = note.start / ticks_per_beat * seconds_per_beat
                duration_sec = (note.end - note.start) / ticks_per_beat * seconds_per_beat
                pitch_name = self.midi_to_note_name(note.pitch)  # Convert to note name
                notes.append((note.velocity, start_time_sec, duration_sec, note.pitch, pitch_name))

        # Sort by time
        notes.sort(key=lambda x: x[1])
        return notes

    def save_notes_to_txt(self, notes, output_file_path):
        """Write one `velocity, start, duration, pitch (name)` line per note.

        Start and duration are written with six decimal places. An existing
        file at `output_file_path` is overwritten.
        """
        with open(output_file_path, 'w') as f:
            f.writelines(
                f"{velocity}, {start_time:.6f}, {duration:.6f}, {pitch} ({pitch_name})\n"
                for velocity, start_time, duration, pitch, pitch_name in notes
            )


In [None]:
# usage example
# Both paths below are empty placeholders; set them to real directories
# before running this cell.
input_directory = r'' # your valid path
output_directory = r''
os.makedirs(output_directory, exist_ok=True)

# REMI preprocess
# Walks the 3-digit sub-directories of input_directory and writes the
# '<name>_note.txt' and '<name>_remi.txt' files into output_directory.
remi_processor = CustomRemiPreprocessing(input_directory, output_directory, check_tempo_change=True)
remi_processor.process_directory()


In [None]:
def remove_note_files(directory): # remove all temporary note txt files
    """Delete every `NNN_note.txt` file (three digits + suffix) found in `directory`.

    Only names that are exactly 12 characters long, start with three digits
    and end with '_note.txt' are removed; each removal is printed.
    """
    for filename in os.listdir(directory):
        is_note_file = (
            len(filename) == 12
            and filename[:3].isdigit()
            and filename.endswith('_note.txt')
        )
        if not is_note_file:
            continue
        file_path = os.path.join(directory, filename)
        os.remove(file_path)
        print(f"Removed: {file_path}")

# Clean up the intermediate '<name>_note.txt' files left in the output directory.
remove_note_files(output_directory)
