In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path("src").resolve()))

from audio2piano import Audio2Piano, note_matrices_to_notes

import matplotlib.pyplot as plt
import torch
import numpy as np


In [None]:
# Build the transcriber once; every later cell in this notebook reuses this object.
# NOTE(review): presumably the constructor loads a checkpoint from this relative
# path — confirm in audio2piano. The path is resolved against the kernel's
# working directory, so run the notebook from the repository root.
transcriber = Audio2Piano("weights/model_weights.pth")

In [None]:
def show_spectrogram(mat, title, y_label):
    """Plot a 2-D spectrogram as a heat map with time on the x-axis.

    Args:
        mat: 2-D array-like or ``torch.Tensor``. Tensors are detached and
            moved to the CPU first, so tensors that require grad or live on
            an accelerator can be passed directly.
        title: Figure title.
        y_label: Label for the vertical axis.

    The orientation is guessed from the shape: the image is drawn so that its
    longer axis runs horizontally and is labelled "Time frames". A clip with
    fewer frames than frequency bins is therefore drawn transposed.
    """
    if isinstance(mat, torch.Tensor):
        # matplotlib converts inputs through numpy, which fails for tensors
        # that require grad or are not on the CPU.
        mat = mat.detach().cpu().numpy()
    else:
        mat = np.asanyarray(mat)

    # Already "wide" (rows < cols): draw as is. Otherwise (tall or square)
    # draw the transpose.
    image = mat if mat.shape[0] < mat.shape[1] else mat.T

    fig, ax = plt.subplots(figsize=(14, 6))
    heatmap = ax.imshow(image, aspect="auto", origin="lower", cmap="magma")
    ax.set_xlabel("Time frames")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    fig.colorbar(heatmap, ax=ax)
    plt.show()

def show_piano_roll(mat, title="Piano Roll"):
    """Plot a piano roll with time on the x-axis and MIDI pitch on the y-axis.

    Args:
        mat: 2-D array with a ``shape`` attribute and ``.T``. If it has 88
            columns it is treated as (time, pitch) and transposed for
            display; any other shape is drawn as given, with rows as pitches.
            A square 88x88 input is therefore always read as (time, pitch).
        title: Figure title.

    Row 0 of the displayed image is labelled as MIDI pitch 21, so the tick
    values agree with the axis label instead of showing raw row indices.
    """
    roll = mat.T if mat.shape[1] == 88 else mat
    n_pitches, n_frames = roll.shape

    lowest_pitch = 21
    # Pixel centres sit on integer frame indices (x) and integer MIDI pitch
    # numbers (y); the x range matches imshow's default extent.
    extent = (
        -0.5,
        n_frames - 0.5,
        lowest_pitch - 0.5,
        lowest_pitch + n_pitches - 0.5,
    )

    fig, ax = plt.subplots(figsize=(14, 6))
    heatmap = ax.imshow(
        roll, aspect="auto", origin="lower", cmap="gray_r", extent=extent
    )
    ax.set_xlabel("Time frames")
    # Escaped en dash: the literal had been corrupted into mojibake.
    ax.set_ylabel("MIDI pitch (21\u2013108)")
    ax.set_title(title)
    fig.colorbar(heatmap, ax=ax)
    plt.show()

def midi_to_sustain_roll(notes, total_steps, hop_sec=0.05):
    """Render note events into an ``(88, total_steps)`` float32 roll.

    Each note is drawn on row ``pitch - 21`` as a ramp that starts at 1.0 on
    its first frame and decays linearly towards 0.5 over the frames it
    occupies. Where notes overlap on the same pitch, the larger value wins.

    Args:
        notes: Iterable of mappings with ``"pitch"``, ``"start"`` and
            ``"end"`` keys. ``start``/``end`` are times in the same unit as
            ``hop_sec``.
        total_steps: Number of time frames (columns) in the roll.
        hop_sec: Duration of one frame.

    Returns:
        ``np.ndarray`` of shape ``(88, total_steps)`` and dtype ``float32``.

    Notes whose pitch falls outside the 88 rows, or that lie entirely outside
    ``[0, total_steps)``, are skipped. Notes that only partly overlap the
    roll are clipped to it, and every drawn note covers at least one frame.
    """
    roll = np.zeros((88, total_steps), dtype=np.float32)

    def to_frame(time_value):
        # Floor to a frame index. The small epsilon absorbs binary
        # floating-point error: 0.15 / 0.05 evaluates to 2.9999999999999996
        # and plain truncation would put that note one frame too early.
        return int(np.floor(time_value / hop_sec + 1e-9))

    for note in notes:
        pitch = note["pitch"] - 21
        if not (0 <= pitch < 88):
            continue

        first = to_frame(note["start"])
        last = max(first + 1, to_frame(note["end"]))  # at least one frame

        # Clip to the roll. Without the lower clamp a negative start would
        # become a negative index and wrap to the end of the roll.
        start = max(first, 0)
        end = min(last, total_steps)
        if start >= end:
            continue

        duration = end - start
        decay = 1.0 - 0.5 * np.arange(duration) / duration
        roll[pitch, start:end] = np.maximum(roll[pitch, start:end], decay)

    return roll

In [None]:
WAV_FILE = "data/musics/wav(input)/example2.wav"
THRESHOLD = 0.6  # onset threshold handed to note_matrices_to_notes
# Frame hop used to size and fill the final roll; same value as the default
# hop_sec of midi_to_sustain_roll.
# NOTE(review): assumed to match the model's frame rate — confirm in audio2piano.
HOP_SEC = 0.05

# 1) Audio -> mel spectrogram
samples, sr = transcriber.load_wav(WAV_FILE)
mels = transcriber.wav_to_mel(samples, sr)

show_spectrogram(mels, "Mel Spectrogram", "Mel bins")

# 2) Mel spectrogram -> onset / sustain probabilities
x = mels.T.detach().unsqueeze(0).float().to(transcriber.device)
# Inference only: skip autograd bookkeeping so no graph is kept in memory.
with torch.no_grad():
    onset_logits, sustain_logits = transcriber.forward(x)
onset_probs = torch.sigmoid(onset_logits)[0].cpu().detach().numpy()
sustain_probs = torch.sigmoid(sustain_logits)[0].cpu().detach().numpy()

show_piano_roll(onset_probs, "Piano roll (Onset)")
show_piano_roll(sustain_probs, "Piano roll (Sustain)")

# 3) Probabilities -> discrete notes -> roll rendered from those notes
duration = len(samples) / sr
total_steps = int(duration / HOP_SEC) + 1
notes = note_matrices_to_notes(onset_matrix=onset_probs, sustain_matrix=sustain_probs, onset_threshold=THRESHOLD)
midi_matrix = midi_to_sustain_roll(notes, total_steps=total_steps, hop_sec=HOP_SEC)

show_piano_roll(midi_matrix, "Piano roll (MIDI final)")

In [None]:
# Single-file conversion: hand one input WAV path and one output MIDI path to
# the transcriber.
# NOTE: WAV_FILE is rebound here; the analysis cell above pointed it at a
# different example file.
WAV_FILE = "data/musics/wav(input)/example0.wav"
OUTPUT_FILE = "data/musics/midi(output)/example0.mid"

# NOTE(review): what wav_to_midi_file returns, and whether it creates the
# output directory when missing, is not visible here — check audio2piano.
midi = transcriber.wav_to_midi_file(WAV_FILE, OUTPUT_FILE)

In [None]:
# Batch conversion: pass an input folder and an output folder to the transcriber.
WAV_FOLDER = "data/musics/wav(input)"
OUTPUT_FOLDER = "data/musics/midi(output)"

# NOTE: this rebinds `midi`, replacing the value from the single-file cell.
# NOTE(review): presumably processes every WAV in WAV_FOLDER; the return type
# is not visible here — check audio2piano.
midi = transcriber.wav_to_midi_folder(WAV_FOLDER, OUTPUT_FOLDER)