In [None]:
def write_augmented_midis(midis_per_file, names_per_file,
                          output_dir="./raw/maestro-v3.0.0/selected_2018_augmented/"):
    """Write every augmented MIDI object to ``output_dir`` and return the count.

    Parameters
    ----------
    midis_per_file : list of list
        One inner list per source file; each element must expose a
        ``write(path)`` method (e.g. the objects built by the processing cell).
    names_per_file : list of list of str
        Filenames parallel to ``midis_per_file`` (same outer and inner lengths).
    output_dir : str
        Destination directory; created if missing.

    Returns
    -------
    int
        Number of files written.

    Raises
    ------
    ValueError
        If the two nested lists do not line up.
    """
    import os  # local: this cell runs before the notebook's import cell

    if len(midis_per_file) != len(names_per_file):
        raise ValueError("midis_per_file and names_per_file differ in length")
    if any(len(m) != len(n) for m, n in zip(midis_per_file, names_per_file)):
        raise ValueError("an inner list of MIDI objects and its filenames differ in length")

    # Deliberately NOT the input directory: the processing loop accepts every
    # .mid/.midi file there, so writing variants back would re-ingest them.
    os.makedirs(output_dir, exist_ok=True)

    n_written = 0
    for augmentations, names in zip(midis_per_file, names_per_file):
        for augmented_midi, name in zip(augmentations, names):
            augmented_midi.write(os.path.join(output_dir, name))
            n_written += 1
    return n_written

In [5]:
# Local helpers first, so the explicit imports below take precedence over any
# same-named re-exports from the star import.
from midi_preprocessing import * # Make sure the functions here are updated
import os

import h5py
import numpy as np
import pretty_midi

### GLOBAL VARIABLES ###
# Piano-roll time resolution: how many columns one bar is sampled into.
# The processing loop turns this into a sampling rate via COLS_PER_BAR / bar_duration.
COLS_PER_BAR = 16

# Directories scanned for .mid/.midi files; uncomment entries to add more years.
INPUT_DIR_PATH = [
    "./raw/maestro-v3.0.0/selected_2018/",
    # "./raw/maestro-v3.0.0/2017/",
    # "./raw/maestro-v3.0.0/2015/",
    # "./raw/maestro-v3.0.0/2014/",
]

# Parallel accumulators filled by the processing loop: one inner list per source
# file holding its augmented MIDI objects (midi_data) and output names (file_names).
# They start empty here, so re-running this cell resets them.
midi_data, file_names = [], []

# This helper only moves notes; it must NOT normalize. The processing loop
# calls normalize_melody_roll itself after shifting.
def shift_roll_up(melody_roll, semitones=1):
    """Shift a piano roll up along its note axis, wrapping the top rows to the bottom.

    Row ``n`` of the input becomes row ``(n + semitones) % n_rows`` of the
    output. With the default ``semitones=1``, this moves every note up one
    semitone, and the highest row wraps around to row 0, so no notes are dropped.

    Parameters
    ----------
    melody_roll : np.ndarray
        Piano roll with notes along axis 0; the other axes are left untouched.
    semitones : int, default 1
        Number of rows to shift; negative values shift downward.

    Returns
    -------
    np.ndarray
        A new array with the same shape and dtype; the input is not modified.
    """
    # np.roll returns a copy and performs the wrap-around in one vectorised call.
    # Unlike a manual row loop, it also accepts a roll with zero rows.
    return np.roll(melody_roll, semitones, axis=0)

# Number of transposed variants per file: shifts of 0..N_TRANSPOSITIONS-1 semitones.
N_TRANSPOSITIONS = 12
# lb / ub arguments passed to normalize_melody_roll (MIDI pitch numbers).
MELODY_LB, MELODY_UB = 60, 83


def _augment_midi_file(file_path, filename):
    """Build the N_TRANSPOSITIONS shifted MIDI variants of a single file.

    Parameters
    ----------
    file_path : str
        Path of the MIDI file to load.
    filename : str
        Base name used to derive the output names ``f"{i}_{filename}"``.

    Returns
    -------
    (list, list)
        The MIDI objects built by ``piano_roll_to_pretty_midi`` and their
        matching filenames, ordered by shift 0..N_TRANSPOSITIONS-1.

    Raises
    ------
    Exception
        Anything raised while loading or converting the file. For example,
        a zero bar duration raises ZeroDivisionError. The caller logs the
        failure and skips the file.
    """
    midi = pretty_midi.PrettyMIDI(file_path)

    # Sampling rate chosen so that one bar_duration spans COLS_PER_BAR columns
    # (presumably bar_duration is in seconds and fs in columns per second).
    fs = COLS_PER_BAR / get_bar_duration(midi)

    # Binarise the extracted roll and normalize it once up front.
    # Each shifted variant is normalized again below.
    base_roll = extract_melody(midi, fs=fs)
    base_roll[base_roll > 0] = 1
    base_roll = base_roll.astype(bool)
    base_roll = normalize_melody_roll(base_roll, lb=MELODY_LB, ub=MELODY_UB)

    augmentations, names = [], []
    shifted_roll = base_roll
    for semitones in range(N_TRANSPOSITIONS):
        # Derive each variant from the previous one (one extra shift per step)
        # instead of re-applying `semitones` shifts to the base roll every time.
        if semitones:
            shifted_roll = shift_roll_up(shifted_roll)
        # Normalize a copy so normalize_melody_roll cannot alter the roll that
        # later shifts build on.
        normalized = normalize_melody_roll(shifted_roll.copy(), lb=MELODY_LB, ub=MELODY_UB)
        augmentations.append(piano_roll_to_pretty_midi(normalized))
        names.append(f"{semitones}_{filename}")
    return augmentations, names


failed_files = []
for dir_path in INPUT_DIR_PATH:
    print(f"Processing directory: {dir_path}")
    # os.listdir returns entries in arbitrary order. Sort them so that
    # midi_data / file_names come out in a reproducible order.
    for filename in sorted(os.listdir(dir_path)):
        # Keep only MIDI files; skip folders and any other files.
        if not filename.lower().endswith(('.mid', '.midi')):
            continue
        file_path = os.path.join(dir_path, filename)
        if not os.path.isfile(file_path):
            continue

        # Best-effort: a file that fails is reported and skipped rather than
        # aborting the whole run.
        try:
            augmentations, names = _augment_midi_file(file_path, filename)
        except Exception as e:
            failed_files.append(file_path)
            print(f"Could not process file {filename}. Reason: {e}")
            continue

        # One inner list per source file: midi_data[k][i] is file k shifted i times.
        midi_data.append(augmentations)
        file_names.append(names)

print("\nProcessing complete.")
print(f"Processed {len(midi_data)} MIDI files, each with {N_TRANSPOSITIONS} augmentations.")
if failed_files:
    print(f"Skipped {len(failed_files)} file(s) that could not be processed.")


Processing directory: ./raw/maestro-v3.0.0/selected_2018/

Processing complete.
Processed 58 MIDI files, each with 12 augmentations.


In [6]:

# --- PART 3: Save the augmented MIDI files ---
print("\nSaving augmented MIDI files...")
# Save to a separate directory. The processing loop reads every .mid/.midi file
# in its input directory, so writing the variants there would mix them with
# (and on a re-run re-read them as) originals.
output_dir = "./raw/maestro-v3.0.0/selected_2018_augmented/"
os.makedirs(output_dir, exist_ok=True)  # create the output directory if it doesn't exist

# midi_data and file_names are parallel lists of lists (one inner list per source
# file). Check that they line up before writing anything, so a mismatch is not
# silently truncated by zip().
assert len(midi_data) == len(file_names), "midi_data / file_names length mismatch"
assert all(len(m) == len(n) for m, n in zip(midi_data, file_names)), \
    "augmentation / filename count mismatch"

n_saved = 0
for augmentations, names in zip(midi_data, file_names):
    for augmented_midi, new_filename in zip(augmentations, names):
        augmented_midi.write(os.path.join(output_dir, new_filename))
        n_saved += 1

# Report the number actually written rather than assuming 12 per source file.
print(f"\nSuccessfully saved all {n_saved} augmented files to '{output_dir}'")


Saving augmented MIDI files...

Successfully saved all 696 augmented files to './raw/maestro-v3.0.0/selected_2018_augmented/'
