Load the data (MIDI files stored in per-artist folders).
Tokenize each MIDI file and prepend an LLM text prompt containing the artist's genre.

Before doing the above, make sure you have a dataset ready that associates each artist with a genre.

Keep in mind that there is a CSV file containing the artists and their respective genres; these genres will be fed into the prompt.

In [7]:
import os
import glob
import pandas as pd
import pretty_midi
import time

# Read the CSV file with artist and mapped genre information.
df = pd.read_csv("../artist_genre_mapped.csv")
# Drop rows with a missing artist or genre. A missing genre would otherwise
# survive the filter below (NaN != "unknown" evaluates to True) and later be
# formatted into the prompt as the literal text "nan".
df = df.dropna(subset=["artist", "genre"])
# Filter out artists with genre "Unknown" (case-insensitive)
df = df[df['genre'].str.lower() != "unknown"]
df

Unnamed: 0,artist,genre
0,.38 Special,rock
1,"10,000_Maniacs",pop
2,101_Strings,classical
3,10cc,rock
4,1910_Fruitgum_Company,pop
...,...,...
2190,a-ha,rock
2191,blink-182,pop
2192,dc_Talk,rock
2193,elgar,classical


In [8]:
# Root folder holding one subfolder per artist, and the single text file
# that receives every tokenized sequence.
input_base = "../data"
output_file = "all_tokenized.txt"

# Lookup table: artist name -> genre, built row by row from the DataFrame.
artist_genre = {name: label for name, label in zip(df['artist'], df['genre'])}

def tokenize_midi(midi_file):
    """
    Convert a MIDI file into one space-separated token string.

    The notes of all instruments are pooled together and ordered by start
    time. Every note then contributes two consecutive tokens:
       - NOTE_<pitch> (MIDI pitch number)
       - DUR_<duration> (end minus start, in seconds, formatted to 2 decimals)

    Returns None, after printing the error, when the file cannot be loaded.
    """
    try:
        parsed = pretty_midi.PrettyMIDI(midi_file)
    except Exception as e:
        print(f"Error processing {midi_file}: {e}")
        return None

    # Pool the notes of every instrument into one flat list.
    pooled = [note for instrument in parsed.instruments for note in instrument.notes]
    # sorted() is stable, so notes sharing a start time keep instrument order.
    ordered = sorted(pooled, key=lambda n: n.start)

    pieces = []
    for n in ordered:
        pieces += [f"NOTE_{n.pitch}", f"DUR_{n.end - n.start:.2f}"]
    return " ".join(pieces)

def _list_midi_files(folder):
    """Return the paths of all ``*.mid`` and ``*.midi`` files directly inside *folder*."""
    # glob.escape stops glob metacharacters ('[', '*', '?') that occur in an
    # artist folder name from being interpreted as a pattern.
    safe_folder = glob.escape(folder)
    return glob.glob(os.path.join(safe_folder, "*.mid")) + glob.glob(os.path.join(safe_folder, "*.midi"))


# Discover the MIDI files once per artist. The same listing feeds both the
# total used for progress tracking and the processing loop, so the two can
# never disagree.
artist_midi_files = {}
for artist in artist_genre:
    artist_folder = os.path.join(input_base, artist)
    if os.path.isdir(artist_folder):
        artist_midi_files[artist] = _list_midi_files(artist_folder)

total_files = sum(len(files) for files in artist_midi_files.values())
print(f"Total MIDI files to process: {total_files}")

start_time = time.time()
count = 0      # Files successfully tokenized and written to the output.
attempted = 0  # Files visited so far, including the ones that were skipped.

# Open the output file in write mode.
with open(output_file, "w", encoding="utf-8") as out_f:
    # Loop over each artist from the CSV mapping.
    for artist, genre in artist_genre.items():
        if artist not in artist_midi_files:
            print(f"Folder for artist '{artist}' not found in {input_base}. Skipping.")
            continue

        for midi_file in artist_midi_files[artist]:
            attempted += 1
            tokens = tokenize_midi(midi_file)

            if tokens:
                # Prepend the text prompt with the genre and write the
                # tokenized sequence to the single output file.
                out_f.write(f"[TEXT_PROMPT: {genre}] [SEP] {tokens}\n")
                count += 1
            elif tokens is not None:
                # The file parsed but yielded no tokens; writing it would add
                # a sample that contains only the prompt.
                print(f"No notes found in {midi_file}. Skipping.")
            # tokens is None: tokenize_midi already reported the error.

            # Print progress every 10 files or on the last file. Progress is
            # based on visited files; counting only successes would never
            # reach total_files once a single file had been skipped.
            if attempted % 10 == 0 or attempted == total_files:
                elapsed = time.time() - start_time
                print(f"Processed {attempted}/{total_files} files. Elapsed time: {elapsed:.2f} seconds.")

end_time = time.time()
print(f"Processing complete. {count} files processed in {end_time - start_time:.2f} seconds. Output saved to {output_file}.")
if attempted != count:
    print(f"Skipped {attempted - count} files that could not be tokenized.")

Total MIDI files to process: 14047




Error processing ../data/10cc/Dreadlock_Holiday.4.mid: data byte must be in range 0..127
Processed 10/14047 files. Elapsed time: 0.76 seconds.
Processed 20/14047 files. Elapsed time: 1.58 seconds.
Processed 30/14047 files. Elapsed time: 2.25 seconds.
Processed 40/14047 files. Elapsed time: 3.05 seconds.
Processed 50/14047 files. Elapsed time: 3.88 seconds.
Processed 60/14047 files. Elapsed time: 4.54 seconds.
Processed 70/14047 files. Elapsed time: 5.44 seconds.
Processed 80/14047 files. Elapsed time: 6.28 seconds.
Processed 90/14047 files. Elapsed time: 7.41 seconds.
Processed 100/14047 files. Elapsed time: 8.35 seconds.
Processed 110/14047 files. Elapsed time: 9.09 seconds.
Processed 120/14047 files. Elapsed time: 9.89 seconds.
Error processing ../data/ABBA/Thank_You_for_the_Music.2.mid: data byte must be in range 0..127
Processed 130/14047 files. Elapsed time: 10.65 seconds.
Error processing ../data/ABBA/Take_a_Chance_on_Me.3.mid: data byte must be in range 0..127
Processed 140/1404