In [16]:
import json
import random

import numpy as np
import fortepyan as ff
from tqdm import tqdm
from datasets import Dataset, load_dataset

from data_midi_bert import augment as A
from data_midi_bert.masking import Mask, AwesomeMasks

In [17]:
from omegaconf import DictConfig


def convert_dataset(
    dataset: Dataset,
    masks: list[Mask],
    dataset_cfg: DictConfig,
) -> list[dict]:
    """Turn every long-enough piece of a dataset split into training records.

    Each row is parsed with ``ff.MidiPiece.from_huggingface``. Pieces with fewer
    than ``dataset_cfg.sequence_len`` notes are skipped; every other piece is
    handed to ``process_piece`` (which randomly samples fragment starts instead
    of using a constant-step moving window) and the resulting records are
    concatenated in dataset order.

    Args:
        dataset: Split to iterate; each row must be accepted by ``from_huggingface``.
        masks: Forwarded unchanged to ``process_piece``.
        dataset_cfg: Forwarded to ``process_piece``; ``sequence_len`` is also read here.

    Returns:
        Flat list of record dicts produced from all processed pieces.
    """
    # TODO This needs multiprocessing, consider dataset.map(convert, num_proc=32)
    all_records: list[dict] = []
    for row in tqdm(dataset):
        midi_piece = ff.MidiPiece.from_huggingface(row)

        # Only pieces that can hold at least one full fragment are sampled
        if midi_piece.size >= dataset_cfg.sequence_len:
            all_records.extend(
                process_piece(
                    piece=midi_piece,
                    masks=masks,
                    dataset_cfg=dataset_cfg,
                )
            )

    return all_records


def process_piece(
    piece: ff.MidiPiece,
    masks: list[Mask],
    dataset_cfg: DictConfig,
) -> list[dict]:
    """Sample random fixed-length fragments from one piece and build training records.

    Instead of a constant-step moving window, fragment start indices are drawn
    uniformly at random; the number of draws is
    ``n_augments * (piece.size // sequence_step)``. Fragments containing pitches
    outside the piano range [21, 108] are skipped; the rest are augmented with
    ``A.change_speed`` and ``A.pitch_shift`` and converted into record dicts.

    Side effect: when the piece is long enough to sample from, ``piece.df``
    gains ``next_start`` and ``dstart`` columns.

    Args:
        piece: Source piece; must expose ``size``, ``df`` and slice indexing.
        masks: Each mask's ``masking_space(part.df).values`` is stored under ``mask.token``.
        dataset_cfg: Provides ``sequence_len``, ``sequence_step`` and ``n_augments``.

    Returns:
        List of record dicts (possibly empty). A piece shorter than
        ``sequence_len`` yields ``[]`` instead of raising.
    """
    sequence_len = dataset_cfg.sequence_len

    # np.random.randint(0, high) raises when high <= 0, i.e. when the piece is
    # shorter than a single fragment -- there is nothing to sample in that case.
    if piece.size < sequence_len:
        return []

    # How many samples we should produce from this piece
    n_samples = dataset_cfg.n_augments * (piece.size // dataset_cfg.sequence_step)

    # Every fragment is identified by its first note index; pick N of them at random.
    # +1 because randint samples from [low, high)
    high = piece.size + 1 - sequence_len
    starts = np.random.randint(0, high, size=n_samples)

    # Inject the dstart feature (gap to the next note's start) into the whole piece
    # before slicing, so the last note of a fragment still sees the note after it.
    piece.df["next_start"] = piece.df.start.shift(-1)
    piece.df["dstart"] = piece.df.next_start - piece.df.start
    piece.df["dstart"] = piece.df.dstart.fillna(0)

    records = []
    for start in starts:
        # Plain int: numpy ints are not JSON-serializable (presumably the slice
        # bounds end up in part.source, which is json.dumps-ed below)
        start = int(start)
        finish = start + sequence_len

        part = piece[start:finish]

        # Is this a piano piece? Skip fragments outside MIDI pitches 21..108
        if part.df.pitch.min() < 21 or part.df.pitch.max() > 108:
            continue

        # Random augments.
        # NOTE(review): dstart was computed above, before change_speed; if change_speed
        # rescales start/end, dstart is not rescaled with them -- confirm against A.change_speed.
        # NOTE(review): the piano-range check runs before pitch_shift; this assumes
        # A.pitch_shift keeps pitches inside [21, 108] -- confirm.
        part.df, speedup_factor = A.change_speed(part.df)
        part.df, pitch_shift = A.pitch_shift(part.df)
        part.source |= {"pitch_shift": pitch_shift, "speedup_factor": speedup_factor}

        record = {
            "pitch": part.df.pitch.astype("int16").values,
            "start": part.df.start.values,
            "dstart": part.df.dstart.values,
            "end": part.df.end.values,
            "duration": part.df.duration.values,
            "velocity": part.df.velocity.values,
            "source": json.dumps(part.source),
            # One masking space per mask, keyed by the mask's token
            "masking_space": {
                mask.token: mask.masking_space(part.df).values for mask in masks
            },
        }
        records.append(record)

    return records

In [18]:
# Pull the dataset from the Hugging Face hub; only its "test" split is converted below.
dataset = load_dataset(path="roszcz/maestro-v1-sustain")
test_dataset = dataset["test"]

Found cached dataset parquet (C:/Users/samue/.cache/huggingface/datasets/roszcz___parquet/roszcz--maestro-v1-sustain-5350ada51983a2ef/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/3 [00:00<?, ?it/s]

In [19]:
# Fragment sampling configuration consumed by convert_dataset / process_piece.
dataset_cfg = DictConfig(
    {
        # Notes per fragment; shorter pieces are skipped
        "sequence_len": 60,
        # Samples per piece = n_augments * (piece.size // sequence_step)
        "sequence_step": 60,
        "n_augments": 5,
    }
)
midi_masks = AwesomeMasks()

# Fragment starts are drawn from np.random; the augmentations in `A` may draw from
# numpy or the stdlib `random` module, so seed both to make runs reproducible.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

In [20]:
%%timeit -n 5 -r 1

records = convert_dataset(dataset=test_dataset,
            dataset_cfg=dataset_cfg,
            masks=midi_masks.masks,)

 15%|█▌        | 27/177 [00:23<02:05,  1.19it/s]