In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import lightning.pytorch as pl
import torchaudio

import tqdm
import pretty_midi as pm
import mir_eval
import transformers
from transformers import T5Config, T5ForConditionalGeneration

import matplotlib.pyplot as plt
import mir_eval.sonify as sonify
import mir_eval.display as display



SR = 16000
AUDIO_SEGMENT_SEC = 2.0
SEGMENT_N_FRAMES = 200
FRAME_STEP_SIZE_SEC = 0.01
FRAME_PER_SEC = 100 


pl.seed_everything(1234)

print(f'PyTorch: {torch.__version__}')
print(f'Torchaudio: {torchaudio.__version__}')
print(f'PyTorch Lightning: {pl.__version__}')
print(f'Transformers: {transformers.__version__}')

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 1234


PyTorch: 1.13.1+cu116
Torchaudio: 0.13.1+cu116
PyTorch Lightning: 1.9.4
Transformers: 4.29.2


# MIDI-like token

In [2]:
# MIDI-like token definitions

N_NOTE = 128
N_TIME = 205
N_SPECIAL = 3

voc_single_track = {
    "pad": 0,
    "eos": 1,
    "endtie": 2,
    "note": N_SPECIAL,
    "onset": N_SPECIAL+N_NOTE,
    "time": N_SPECIAL+N_NOTE+2,
    "n_voc": N_SPECIAL+N_NOTE+2+N_TIME+3,
    "keylist": ["pad", "eos", "endtie", "note", "onset", "time"]}

## Tokenizer
Given a MIDI file, the notes are first extended according to its pedal signal (`pm_apply_pedal()`). Then the notes are splitted into segments and transformed into token lists (`get_segment_tokens()`).

In [3]:
from operator import attrgetter
import dataclasses

@dataclasses.dataclass
class Event:
    prog: int
    onset: bool
    pitch: int

class MIDITokenExtractor:
    def __init__(self, midi_path, voc_dict, apply_pedal=True):
        self.pm = pm.PrettyMIDI(midi_path)
        if apply_pedal:
            self.pm_apply_pedal(self.pm)
        self.voc_dict = voc_dict
        self.multi_track = "instrument" in voc_dict

    def pm_apply_pedal(self, pm: pm.PrettyMIDI, program=0):
        """
        Apply sustain pedal by stretching the notes in the pm object.
        """
        # 1: Record the onset positions of each notes as a dictionary
        onset_dict = dict()     
        for note in pm.instruments[program].notes:
            if note.pitch in onset_dict:
                onset_dict[note.pitch].append(note.start)
            else:
                onset_dict[note.pitch] = [note.start]
        for k in onset_dict.keys():
            onset_dict[k] = np.sort(onset_dict[k])
            
        # 2: Record the pedal on/off state of each time frame
        arr_pedal = np.zeros(
            round(pm.get_end_time()*FRAME_PER_SEC)+100, dtype=bool)
        pedal_on_time = -1
        list_pedaloff_time = []
        for cc in pm.instruments[program].control_changes:
            if cc.number == 64:
                if (cc.value > 0) and (pedal_on_time < 0):
                    pedal_on_time = round(cc.time*FRAME_PER_SEC)
                elif (cc.value == 0) and (pedal_on_time >= 0):
                    pedal_off_time = round(cc.time*FRAME_PER_SEC)
                    arr_pedal[pedal_on_time:pedal_off_time] = True
                    list_pedaloff_time.append(cc.time)
                    pedal_on_time = -1
        list_pedaloff_time = np.sort(list_pedaloff_time)
        
        # 3: Stretch the notes (modify note.end)
        for note in pm.instruments[program].notes:
            # 3-1: Determine whether sustain pedal is on at note.end. If not, do nothing.
            # 3-2: Find the next note onset time and next pedal off time after note.end.
            # 3-3: Extend note.end till the minimum of next_onset and next_pedaloff.
            note_off_frame = round(note.end*FRAME_PER_SEC)
            pitch = note.pitch
            if arr_pedal[note_off_frame]:
                next_onset = np.argwhere(onset_dict[pitch] > note.end)
                next_onset = np.inf if len(
                    next_onset) == 0 else onset_dict[pitch][next_onset[0, 0]]
                next_pedaloff = np.argwhere(list_pedaloff_time > note.end)
                next_pedaloff = np.inf if len(
                    next_pedaloff) == 0 else list_pedaloff_time[next_pedaloff[0, 0]]
                new_noteoff_time = max(note.end, min(next_onset, next_pedaloff))
                new_noteoff_time = min(new_noteoff_time, pm.get_end_time())
                note.end = new_noteoff_time
    
    def get_segment_tokens(self, start, end):
        """
        Transform a segment of the MIDI file into a sequence of tokens.
        """
        dict_event = dict() # a dictionary that maps time to a list of events.

        def append_to_dict_event(time, item):
            if time in dict_event:
                dict_event[time].append(item)
            else:
                dict_event[time] = [item]

        list_events = []        # events section
        list_tie_section = []   # tie section

        for instrument in self.pm.instruments:
            prog = instrument.program
            for note in instrument.notes:
                note_end = round(note.end * FRAME_PER_SEC)
                note_start = round(note.start * FRAME_PER_SEC)
                if (note_end < start) or (note_start >= end):
                    continue
                if (note_start < start) and (note_end >= start):
                    # If the note starts before the segment, but ends in the segment
                    # it is added to the tie section.
                    list_tie_section.append(self.voc_dict["note"] + note.pitch)
                    if note_end < end:
                        append_to_dict_event(
                            note_end - start, Event(prog, False, note.pitch)
                        )
                    continue
                assert note_start >= start
                append_to_dict_event(note_start - start, Event(prog, True, note.pitch))
                if note_end < end:
                    append_to_dict_event(
                        note_end - start, Event(prog, False, note.pitch)
                    )

        cur_onset = None
        cur_prog = -1
        for time in sorted(dict_event.keys()):
            list_events.append(self.voc_dict["time"] + time)
            for event in sorted(dict_event[time], key=attrgetter("pitch", "onset")):
                if cur_onset != event.onset:
                    cur_onset = event.onset
                    list_events.append(self.voc_dict["onset"] + int(event.onset))
                list_events.append(self.voc_dict["note"] + event.pitch)
                
        # Concatenate tie section, endtie token, and event section
        list_tie_section.append(self.voc_dict["endtie"])
        list_events.append(self.voc_dict["eos"])
        tokens = np.concatenate((list_tie_section, list_events)).astype(int)
        return tokens

## Detokenizer
Transforms a list of MIDI-like token sequences into a MIDI file.

In [4]:

def parse_id(voc_dict: dict, id: int):
    keys = voc_dict["keylist"]
    # anchors = [voc_dict[k] for k in keys]
    token_name = keys[0]
    for k in keys:
        if id < voc_dict[k]:
            break
        token_name = k
    token_id = id - voc_dict[token_name]
    return token_name, token_id


def to_second(n):
    return n * FRAME_STEP_SIZE_SEC


def find_note(list, n):
    li_elem = [a for a, _ in list]
    try:
        idx = li_elem.index(n)
    except ValueError:
        return -1
    return idx


def token_seg_list_to_midi(token_seg_list: list):
    """
    Transform a list of token sequences into a MIDI file.
    """
    midi_data = pm.PrettyMIDI()
    piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
    piano = pm.Instrument(program=piano_program)
    list_onset = []
    cur_time = 0
    for token_seg in token_seg_list:
        list_tie = []
        cur_relative_time = -1
        cur_onset = -1
        tie_end = False
        for token in token_seg:
            token_name, token_id = parse_id(voc_single_track, token)
            if token_name == "note":
                if not tie_end:
                    list_tie.append(token_id)
                elif cur_onset == 1:
                    list_onset.append((token_id, cur_time + cur_relative_time))
                elif cur_onset == 0:
                    i = find_note(list_onset, token_id)
                    if i >= 0:
                        start = list_onset[i][1]
                        end = cur_time + cur_relative_time
                        if start < end:
                            new_note = pm.Note(100, token_id, start, end)
                            piano.notes.append(new_note)
                        list_onset.pop(i)

            elif token_name == "onset":
                if tie_end:
                    if token_id == 1:
                        cur_onset = 1
                    elif token_id == 0:
                        cur_onset = 0
            elif token_name == "time":
                if tie_end:
                    cur_relative_time = to_second(token_id)
            elif token_name == "endtie":
                tie_end = True
                for note, start in list_onset:
                    if note not in list_tie:
                        if start < cur_time:
                            new_note = pm.Note(100, note, start, cur_time)
                            piano.notes.append(new_note)
                        list_onset.remove((note, start))
        cur_time += AUDIO_SEGMENT_SEC

    midi_data.instruments.append(piano)
    return midi_data

# Dataset
We use MAESTRO v3.0.0 dataset.

In [5]:
class AMTDatasetBase(data.Dataset):
    def __init__(
        self,
        flist_audio,
        flist_midi,
        sample_rate,
        voc_dict,
        apply_pedal=True,
        whole_song=False,
    ):
        super().__init__()
        self.midi_filelist = flist_midi
        self.audio_filelist = flist_audio
        self.audio_metalist = [torchaudio.info(f) for f in flist_audio]
        self.voc_dict = voc_dict
        self.midi_list = [
            MIDITokenExtractor(f, voc_dict, apply_pedal)
            for f in tqdm.tqdm(self.midi_filelist, desc="load dataset")
        ]
        self.sample_rate = sample_rate
        self.whole_song = whole_song

    def __len__(self):
        return len(self.audio_filelist)

    def __getitem__(self, index):
        """
        Return a pair of (audio, tokens) for the given index.
        On the training stage, return a random segment from the song.
        On the test stage, return the audio and MIDI of the whole song.
        """
        if not self.whole_song:
            return self.getitem_segment(index)
        else:
            return self.getitem_wholesong(index)

    def getitem_segment(self, index, start_pos=None):
        metadata = self.audio_metalist[index]
        num_frames = metadata.num_frames
        sample_rate = metadata.sample_rate
        duration_y = round(num_frames / float(sample_rate) * FRAME_PER_SEC)
        midi_item = self.midi_list[index]
        if start_pos is None:
            segment_start = np.random.randint(duration_y - SEGMENT_N_FRAMES)
        else:
            segment_start = start_pos
        segment_end = segment_start + SEGMENT_N_FRAMES
        segment_start_sample = round(
            segment_start * FRAME_STEP_SIZE_SEC * sample_rate
        )

        segment_tokens = midi_item.get_segment_tokens(segment_start, segment_end)
        segment_tokens = torch.from_numpy(segment_tokens).long()
        y_segment, _ = torchaudio.load(
            self.audio_filelist[index],
            frame_offset=segment_start_sample,
            num_frames=round(AUDIO_SEGMENT_SEC * sample_rate),
        )
        y_segment = y_segment.mean(0)
        if sample_rate != self.sample_rate:
            y_segment = torchaudio.functional.resample(
                y_segment,
                sample_rate,
                self.sample_rate,
                resampling_method="kaiser_window",
            )
        return y_segment, segment_tokens

    def getitem_wholesong(self, index):
        y, sr = torchaudio.load(self.audio_filelist[index])
        y = y.mean(0)
        if sr != self.sample_rate:
            y = torchaudio.functional.resample(
                y, sr, self.sample_rate, 
                resampling_method="kaiser_window"
            )
        midi = self.midi_list[index].pm
        return y, midi

    def collate_wholesong(self, batch):
        batch_audio = torch.stack([b[0] for b in batch], dim=0)
        midi = [b[1] for b in batch]
        return batch_audio, midi

    def collate_batch(self, batch):
        batch_audio = torch.stack([b[0] for b in batch], dim=0)
        batch_tokens = [b[1] for b in batch]
        batch_tokens_pad = torch.nn.utils.rnn.pad_sequence(
            batch_tokens, batch_first=True, padding_value=self.voc_dict["pad"]
        )
        return batch_audio, batch_tokens_pad
    
    
class Maestro(AMTDatasetBase):
    def __init__(
        self,
        n_files=-1,
        sample_rate=44100,
        split="test",
        apply_pedal=True,
        whole_song=False,
    ):
        data_path = "../../../Data2/maestro-v3.0.0/"
        df_metadata = pd.read_csv(os.path.join(data_path, "maestro-v3.0.0.csv"))
        flist_audio = []
        flist_midi = []
        list_title = []
        for row in range(len(df_metadata)):
            if df_metadata["split"][row] == split:
                f_audio = os.path.join(data_path, df_metadata["audio_filename"][row])
                f_midi = os.path.join(data_path, df_metadata["midi_filename"][row])
                assert os.path.exists(f_audio) and os.path.exists(f_midi)
                flist_audio.append(f_audio)
                flist_midi.append(f_midi)
                list_title.append(df_metadata["canonical_title"][row])
        if n_files > 0:
            flist_audio = flist_audio[:n_files]
            flist_midi = flist_midi[:n_files]
        super().__init__(
            flist_audio,
            flist_midi,
            sample_rate,
            voc_dict=voc_single_track,
            apply_pedal=apply_pedal,
            whole_song=whole_song,
        )
        self.list_title = list_title

# Evaluation function

In [6]:
def extract_midi(midi: pm.PrettyMIDI, program=0):
    intervals = []
    pitches = []
    pm_notes = midi.instruments[program].notes
    for note in pm_notes:
        intervals.append((note.start, note.end))
        pitches.append(note.pitch)

    return np.array(intervals), np.array(pitches)


def evaluate_midi(est_midi: pm.PrettyMIDI, ref_midi: pm.PrettyMIDI, program=0):
    est_intervals, est_pitches = extract_midi(est_midi, program)
    ref_intervals, ref_pitches = extract_midi(ref_midi, program)

    dict_eval = mir_eval.transcription.evaluate(
        ref_intervals, ref_pitches, est_intervals, est_pitches)

    return dict_eval


# Sequence-to-sequence transcriber

The transformer model is built with the configuration of "T5v1.1-small" model released by Google.

In [7]:
def split_audio_into_segments(y: torch.Tensor, sr: int):
    audio_segment_samples = round(AUDIO_SEGMENT_SEC * sr)
    pad_size = audio_segment_samples - (y.shape[-1] % audio_segment_samples)
    y = F.pad(y, (0, pad_size))
    assert (y.shape[-1] % audio_segment_samples) == 0
    n_chunks = y.shape[-1] // audio_segment_samples
    y_segments = torch.chunk(y, chunks=n_chunks, dim=-1)
    return torch.stack(y_segments, dim=0)


def unpack_sequence(x: torch.Tensor, eos_id: int=1):
    seqs = []
    max_length = x.shape[-1]
    for seq in x:
        start_pos = 0
        pos = 0
        while (pos < max_length) and (seq[pos] != eos_id):
            pos += 1
        end_pos = pos+1
        seqs.append(seq[start_pos:end_pos])
    return seqs

class LogMelspec(nn.Module):
    def __init__(self, sample_rate, n_fft, n_mels, hop_length):
        super().__init__()
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            f_min=20.0,
            n_mels=n_mels,
            mel_scale="slaney",
            norm="slaney",
            power=1,
        )
        self.eps = 1e-5

    def forward(self, x):
        spec = self.melspec(x)
        safe_spec = torch.clamp(spec, min=self.eps)
        log_spec = torch.log(safe_spec)
        return log_spec

class Seq2SeqTranscriber(nn.Module):
    def __init__(
        self, n_mels: int, sample_rate: int, n_fft: int, hop_length: int, voc_dict: dict
    ):
        super().__init__()
        self.infer_max_len = 200
        self.voc_dict = voc_dict
        self.n_voc_token = voc_dict["n_voc"]
        self.t5config = T5Config.from_pretrained("google/t5-v1_1-small")
        custom_configs = {
            "vocab_size": self.n_voc_token,
            "pad_token_id": voc_dict["pad"],
            "d_model": n_mels,
        }

        for k, v in custom_configs.items():
            self.t5config.__setattr__(k, v)

        self.transformer = T5ForConditionalGeneration(self.t5config)
        self.melspec = LogMelspec(sample_rate, n_fft, n_mels, hop_length)
        self.sr = sample_rate

    def forward(self, wav, labels):
        spec = self.melspec(wav).transpose(-1, -2)
        outs = self.transformer.forward(
            inputs_embeds=spec, return_dict=True, labels=labels
        )
        return outs

    def infer(self, wav):
        """
        Infer the transcription of a single audio file.
        The input audio file is split into segments of 2 seconds
        before passing to the transformer.
        """
        wav_segs = split_audio_into_segments(wav, self.sr)
        spec = self.melspec(wav_segs).transpose(-1, -2)
        outs = self.transformer.generate(
            inputs_embeds=spec,
            max_length=self.infer_max_len,
            num_beams=1,
            do_sample=False,
            return_dict_in_generate=False,
        )
        return outs

# Trainer

In [8]:

class LitTranscriber(pl.LightningModule):
    def __init__(
        self,
        transcriber_args: dict,
        lr: float,
        lr_decay: float = 1.0,
        lr_decay_interval: int = 1,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.voc_dict = voc_single_track
        self.n_voc = self.voc_dict["n_voc"]
        self.transcriber = Seq2SeqTranscriber(
            **transcriber_args, voc_dict=self.voc_dict
        )
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_interval = lr_decay_interval

    def forward(self, y: torch.Tensor):
        transcriber_infer = self.transcriber.infer(y)
        return transcriber_infer

    def training_step(self, batch, batch_idx):
        y, t = batch
        tf_out = self.transcriber(y, t)
        loss = tf_out.loss
        t = t.detach()
        mask = t != self.voc_dict["pad"]
        accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
        self.log("train/loss", loss)
        self.log("train/accr", accr)
        return loss

    def validation_step(self, batch, batch_idx):
        assert not self.transcriber.training
        y, t = batch
        tf_out = self.transcriber(y, t)
        loss = tf_out.loss
        t = t.detach()
        mask = t != self.voc_dict["pad"]
        accr = (tf_out.logits.argmax(-1)[mask] == t[mask]).sum() / mask.sum()
        self.log("vali/loss", loss)
        self.log("vali/accr", accr)
        return loss

    def test_step(self, batch, batch_idx):
        y, ref_midi = batch
        y = y[0]
        ref_midi = ref_midi[0]
        with torch.no_grad():
            est_tokens = self.forward(y)
            unpadded_tokens = unpack_sequence(est_tokens.cpu().numpy())
            unpadded_tokens = [t[1:] for t in unpadded_tokens]
            est_midi = token_seg_list_to_midi(unpadded_tokens)
        dict_eval = evaluate_midi(est_midi, ref_midi)
        dict_log = {}
        for key in dict_eval:
            dict_log["test/" + key] = dict_eval[key]
        self.log_dict(dict_log, batch_size=1)

    def train_dataloader(self):
        dset = Maestro(
            sample_rate=SR,
            split="train",
            n_files=20,
        )
        return data.DataLoader(
            dataset=dset,
            collate_fn=dset.collate_batch,
            batch_size=4,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
        )
        
    def test_dataloader(self):
        dset = Maestro(
            sample_rate=SR,
            split="test",
            whole_song=True,
            n_files=10,
        )
        return data.DataLoader(
            dataset=dset,
            collate_fn=dset.collate_wholesong,
            batch_size=1,
            shuffle=False,
            pin_memory=True,
        )
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [9]:
args = {
    "n_mels":512,
    "sample_rate":16000,
    "n_fft":1024,
    "hop_length":256,
}

lightning_module = LitTranscriber(
    transcriber_args=args,
    lr=1e-4,
    lr_decay=0.99
)

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    accelerator="gpu",
    devices="1,",
    max_epochs=1000,
)

trainer.fit(lightning_module)
    

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name        | Type               | Params
---------------------------------------------------
0 | transcriber | Seq2SeqTranscriber | 44.4 M
---------------------------------------------------
44.4 M    Trainable params
0         Non-trainable params
44.4 M    Total params
177.645   Total estimated model params size (MB)
load dataset: 100%|██████████| 20/20 [01:52<00:00,  5.64s/it]


Epoch 999: 100%|██████████| 5/5 [00:01<00:00,  3.28it/s, loss=2.63]

`Trainer.fit` stopped: `max_epochs=1000` reached.


Epoch 999: 100%|██████████| 5/5 [00:01<00:00,  3.28it/s, loss=2.63]


In [10]:
# The scores are low since the model is trained on a small dataset.

trainer.test(lightning_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
load dataset: 100%|██████████| 10/10 [00:20<00:00,  2.04s/it]
  rank_zero_warn(


Testing DataLoader 0:  20%|██        | 2/10 [00:04<00:16,  2.07s/it]



Testing DataLoader 0: 100%|██████████| 10/10 [00:38<00:00,  3.84s/it]


[{'test/Precision': 0.00045526440953835845,
  'test/Recall': 0.0003843010345008224,
  'test/F-measure': 0.0004150962922722101,
  'test/Average_Overlap_Ratio': 0.44270918821881045,
  'test/Precision_no_offset': 0.00696998555213213,
  'test/Recall_no_offset': 0.004659607540816069,
  'test/F-measure_no_offset': 0.005259971134364605,
  'test/Average_Overlap_Ratio_no_offset': 0.23869087929446692,
  'test/Onset_Precision': 0.03701071813702583,
  'test/Onset_Recall': 0.02266177348792553,
  'test/Onset_F-measure': 0.02616998553276062,
  'test/Offset_Precision': 0.1667836457490921,
  'test/Offset_Recall': 0.10361303389072418,
  'test/Offset_F-measure': 0.11814618110656738}]