In [1]:
pip install pretty_midi numpy matplotlib torch


Note: you may need to restart the kernel to use updated packages.


In [2]:
import pretty_midi
import numpy as np 
import torch
import matplotlib.pyplot as plt

In [3]:
# input and output file locations. 
input_midi_path = "../../vickram_exploration/music_training_database/Pop/ABBA_-_Dancing_Queen.mid"
output_midi_path = "../../vickram_exploration/output_files"

In [4]:
input_midi_data = pretty_midi.PrettyMIDI(input_midi_path)

In [5]:
# Inspecting Instruments used in the midi files. 
print("Instruments in MIDI file:", len(input_midi_data.instruments))
for i, instrument in enumerate(input_midi_data.instruments):
    print(f"Instrument {i}: {pretty_midi.program_to_instrument_name(instrument.program)}")

Instruments in MIDI file: 10
Instrument 0: Acoustic Grand Piano
Instrument 1: Bright Acoustic Piano
Instrument 2: Electric Bass (finger)
Instrument 3: Pad 7 (halo)
Instrument 4: Choir Aahs
Instrument 5: Pad 3 (polysynth)
Instrument 6: Bright Acoustic Piano
Instrument 7: Pad 7 (halo)
Instrument 8: Banjo
Instrument 9: String Ensemble 1


In [6]:
# Extracting Notes, Velocities, and Timing 

In [7]:
# Extracting Notes 
notes = []
for instrument in input_midi_data.instruments:
    for note in instrument.notes:
        # Notes: [start_time, end_time, pitch, velocity]
        notes.append([note.start, note.end, note.pitch, note.velocity])
        
# Convert the list to a numpy array 
note = np.array(notes)

# Inspect the first few notes
print("First few notes:", notes[:5])


First few notes: [[2.39044, 2.4091153125, 42, 75], [2.98805, 3.0067253124999995, 42, 75], [3.58566, 3.6043353124999995, 42, 75], [4.183269999999999, 4.2019453124999995, 42, 75], [4.78088, 4.7995553125, 42, 75]]


In [8]:
# Normalize velocities to [0, 1]

max_velocity = 127 # Standard MIDI velocity range
note[:, 3] = note[:, 3] / max_velocity # Max velocity in MIDI is 127

In [9]:
#Verification verlocity normalization
print(note[:5]) 

[[ 2.39044     2.40911531 42.          0.59055118]
 [ 2.98805     3.00672531 42.          0.59055118]
 [ 3.58566     3.60433531 42.          0.59055118]
 [ 4.18327     4.20194531 42.          0.59055118]
 [ 4.78088     4.79955531 42.          0.59055118]]


In [10]:
# Convert to Event Sequence (Note-On, Note-Off) 
events = [] 

for note in notes: 
    # Note-On event (start time, 'Note-On', pitch, velocity)
    note_on = [note[0], 'Note-On', int(note[2]), note[3]]
    events.append(note_on)
    # Note-Off event (end time, 'Note-Off', pitch)
    note_off =  [note[1], 'Note-Off', int(note[2]),(0.0)]
    # Add a placeholder for velocity (if necessary) 
    events.append(note_off)

# Chekcing for inconsistent element length
for i, event in enumerate(events): 
    if len(event) != 4:
        print(f"Warrning: Inconsistent element found: {event}")
        events[i] = [event[0], event[2], 0.0]
# Sort events by time
events.sort(key=lambda x: x[0])

# Convert to numpy array for easier handling 
events = np.array(events)

# Verify the frist frew events 
print("First few events:", events[5])

First few events: ['3.6043353124999995' 'Note-Off' '42' '0.0']


In [12]:
# Visualizing the data using matplotlib

def plot_piano_roll(events, start_time=0, end_time=None, pitch_range=(21, 108)):
    # Create a time range for visualization (using normalized start/end times)
    if end_time is None:
        end_time = max(events[:, 0]) # Use the last note's end time as the endpoint

    time_steps = np.linspace(start_time, end_time, 500)
    piano_roll = np.zeros((len(time_steps), pitch_range[1] - pitch_range[0]))

    for event in events:
        event_time = event[0]
        if event[1] == 'Note-On': # Note-On event 
            pitch_idx = int(event[2]) - pitch_range[0]
            start_idx = np.argmax(time_steps >= event_time)
            piano_roll[start_idx:, pitch_idx] = event[3] # Mark note's velocity from start to end
        elif event[1] == 'Note-Off': # Note-Off event
            pitch_idx = np.argmax(time_steps >= event_time)
            piano_roll[:end_idx, pitch_idx] = 0 # End note 

    # Plot the piano roll
    plt.figure(figsize=(12, 8))
    plt.imshow(piano_roll.T, origin='lower', aspect='auto', cmap='Greys', extent=[start_time, end_time, pitch_range[0], pitch_range[1]])
    plt.ylable('MIDI Pitch')
    plt.xlable('Time (s)')
    plt.title("Piano Roll Visualization")
    plt.show()

# Plot the first 10 seconds of the MIDI events
plot_piano_roll(events, start_time=0, end_time=10)



UFuncTypeError: ufunc 'greater_equal' did not contain a loop with signature matching types (<class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.StrDType'>) -> None