In [1]:
pip install librosa

Note: you may need to restart the kernel to use updated packages.


In [3]:
import librosa
import numpy as np

def wav_to_piano_roll(wav_file_path, fs=100, threshold=0.1):
    """
    Convert a WAV file to a Piano Roll array.

    Args:
        wav_file_path (str): Path to the wav file.
        fs (int): Sampling frequency of the columns, i.e. each column is an interval of 1./fs seconds.
        threshold (float): Volume threshold for note on/off.

    Returns:
        np.ndarray: Piano roll of shape (128, T), where T is the number of time steps.
    """
    # Load the audio file
    y, sr = librosa.load(wav_file_path)

    # Compute the short-time Fourier transform
    stft = librosa.stft(y)

    # Convert the STFT into amplitude
    amplitude = np.abs(stft)

    # Convert to dB scale
    db = librosa.amplitude_to_db(amplitude)

    # Normalize to the range [0, 1]
    db_normalized = (db - db.min()) / (db.max() - db.min())

    # Map frequencies to MIDI notes
    freqs = librosa.fft_frequencies(sr=sr)

    # Only consider frequencies that fall within the MIDI note range (approximately 8 Hz to 12543 Hz)
    lo, hi = librosa.note_to_hz(['A0', 'G#8'])
    bins = np.where((freqs >= lo) & (freqs <= hi))

    freqs = freqs[bins]
    db_normalized = db_normalized[bins]

    # Map frequencies to MIDI notes
    midi_notes = librosa.hz_to_midi(freqs)

    # Round MIDI notes to the nearest integer
    midi_notes_rounded = np.round(midi_notes).astype(int)

    # Create an empty piano roll
    piano_roll = np.zeros((128, db_normalized.shape[1]))

    # Transfer DB values to piano roll
    for note, column in enumerate(db_normalized):
        if 0 <= midi_notes_rounded[note] < 128:
            piano_roll[midi_notes_rounded[note]] += column

    # Apply threshold to piano roll
    piano_roll = np.where(piano_roll >= threshold, piano_roll, 0)

    return piano_roll

In [9]:
badguy_recording_pianoroll = wav_to_piano_roll("Recording.wav", fs=100, threshold=0.1)
print(badguy_recording_pianoroll.shape)

(128, 1148)


In [8]:
def print_random_row(piano_roll):
    random_index = np.random.randint(piano_roll.shape[0])
    print(f"Row {random_index}: {list(piano_roll[random_index])}")
    
# Usage:
piano_roll = badguy_recording_pianoroll
print_random_row(piano_roll)

Row 43: [0.8066056370735168, 0.8742600679397583, 0.7767319679260254, 0.7172520160675049, 0.6764041781425476, 0.7255877256393433, 0.5280548930168152, 0.6995893716812134, 0.8347554206848145, 0.928173840045929, 0.9373718500137329, 0.9278669357299805, 0.9256173372268677, 0.9227226376533508, 0.9228765368461609, 0.9234737157821655, 0.9207242727279663, 0.8907787203788757, 0.966522216796875, 0.9240518808364868, 0.908482551574707, 0.7794966697692871, 0.7202707529067993, 0.72565758228302, 0.718366801738739, 0.6067366600036621, 0.6638385057449341, 0.8335555791854858, 0.9232746362686157, 0.9355373382568359, 0.9247257113456726, 0.9176847338676453, 0.9148653149604797, 0.9157147407531738, 0.9176455736160278, 0.9195352792739868, 0.8718467950820923, 0.947836697101593, 0.9546799659729004, 0.9518899917602539, 0.8928335309028625, 0.7275940775871277, 0.7931822538375854, 0.7801276445388794, 0.7243264317512512, 0.7095390558242798, 0.774975061416626, 0.926537036895752, 0.9617176055908203, 0.9606547355651855, 

In [4]:
import os

def process_directory(directory_path):
    """
    Convert all .wav files in a directory to piano roll .txt files.

    Args:
        directory_path (str): Path to the directory.

    Returns:
        None
    """
    # List all files in the directory
    for filename in os.listdir(directory_path):
        # Check if the file is a .wav file
        if filename.endswith('.wav'):
            # Get the full path to the .wav file
            wav_file_path = os.path.join(directory_path, filename)
            
            # Convert the .wav file to a piano roll
            piano_roll = wav_to_piano_roll(wav_file_path)
            
            # Save the piano roll to a .txt file
            txt_file_path = os.path.join(directory_path, filename.replace('.wav', '.txt'))
            np.savetxt(txt_file_path, piano_roll)

    print(f"Processed all .wav files in {directory_path}")

process_directory('.wav_files')

Processed all .wav files in .wav_files
