In [109]:
from collections import defaultdict
from inspect import isclass
import os
from pathlib import Path
import sys
from typing import Callable, Union, Optional
from tqdm import tqdm

import muspy
import numpy as np
import sklearn.model_selection

In [2]:
class Groove2GrooveDataset(muspy.RemoteFolderDataset):
    """Groove2Groove Dataset.

    Serves the MIDI files found under
    ``<root>/groove2groove-data-v1.0.0/midi/<part>/fixed`` and, on read,
    merges tracks that share a name into a single track.
    """

    # NOTE(review): the key "nes" looks like a leftover from another dataset
    # definition; it is kept unchanged because the base class may refer to
    # it -- confirm before renaming.
    _sources = {
        "nes": {
            "filename": "groove2groove-data-v1.0.0.tar.gz",
            "url": "https://zenodo.org/record/3958000/files/groove2groove-data-v1.0.0.tar.gz?download=1",
            "archive": True,
            "size": 236114569,
            "md5": "c407de7b3676267660c88dc6ee351c79",
            "sha256": "b4ef60b8d0cf5c510868c1b10d560432a5412f809d47f51a092582158cb41c09",
        }
    }
    _extension = "mid"

    def __init__(
        self,
        root: Union[str, Path],
        download_and_extract: bool = False,
        cleanup: bool = False,
        convert: bool = False,
        kind: str = "json",
        n_jobs: int = 1,
        ignore_exceptions: bool = True,
        use_converted: Optional[bool] = None,
        part: str = "train"
    ):
        """Initialize the dataset for one part of the data.

        Parameters
        ----------
        root, download_and_extract, cleanup, convert, kind, n_jobs, \
ignore_exceptions, use_converted
            Forwarded unchanged to ``muspy.RemoteFolderDataset.__init__``.
        part : str
            Name of the sub-directory of ``midi/`` whose ``fixed`` folder is
            scanned for files (the notebook uses 'train', 'val' and 'test').
        """
        # Set before the base initializer runs: `converted_dir` reads it, and
        # the base class presumably accesses that property during init.
        self.part = part
        muspy.RemoteFolderDataset.__init__(
            self, root=root, download_and_extract=download_and_extract,
            cleanup=cleanup, convert=convert, kind=kind, n_jobs=n_jobs,
            ignore_exceptions=ignore_exceptions, use_converted=use_converted)

        path = self.root / 'groove2groove-data-v1.0.0' / 'midi' / part / 'fixed'
        self.raw_filenames = sorted(path.rglob("*." + self._extension))
        # NOTE(review): this replaces the active file list after the base
        # initializer has run, regardless of `use_converted` -- confirm that
        # always reading the raw MIDI files is intended.
        self._filenames = self.raw_filenames

    @property
    def converted_dir(self):
        """Path to the root directory of the converted dataset."""
        return self.root / "_converted_{}".format(self.part)

    def read(self, filename: Union[str, Path]) -> muspy.Music:
        """Read a MIDI file into a Music object, merging same-named tracks.

        For every track name, all later tracks with that name are merged
        into the first one and then dropped from ``music.tracks``.
        """
        music = muspy.inputs.read_midi(self.root / filename)

        # Merge tracks of the same name into the first track with that name.
        first_by_name = {}
        merged_ids = set()
        for track in music.tracks:
            primary = first_by_name.setdefault(track.name, track)
            if primary is not track:
                primary.merge(track, override=False, remove_duplicate=False)
                merged_ids.add(id(track))

        # Drop the merged tracks by identity.  `list.remove` compares by
        # value, so it could remove a different, equal-looking track (and it
        # performs a deep comparison for every removal).
        if merged_ids:
            music.tracks[:] = [
                track for track in music.tracks if id(track) not in merged_ids
            ]

        return music

In [3]:
def _obj_filter(self, func: Callable[[muspy.Base], bool], attr: str, recursive: bool):
    """Filter a single attribute of a MusPy object in place.

    Parameters
    ----------
    self : muspy.Base
        Object whose attribute is filtered.
    func : callable
        Predicate; list items for which it returns a falsy value are dropped.
    attr : str
        Name of the attribute to process.
    recursive : bool
        Whether to descend into attributes whose declared type is a
        ``muspy.Base`` subclass.
    """
    is_list_attr = attr in self._list_attributes
    if is_list_attr:
        setattr(self, attr, list(filter(func, getattr(self, attr))))

    # The recursion must not be nested under the list check above; otherwise
    # non-list `muspy.Base` attributes are never visited.  `.get` keeps an
    # unknown attribute name a silent no-op, as it was before.
    attr_type = self._attributes.get(attr)
    if not (recursive and isclass(attr_type) and issubclass(attr_type, muspy.Base)):
        return

    if is_list_attr:
        for item in getattr(self, attr):
            obj_filter(item, func, recursive=recursive)
    elif getattr(self, attr) is not None:
        obj_filter(getattr(self, attr), func, recursive=recursive)
    

def obj_filter(self, func: Callable[[muspy.Base], bool], attr: Optional[str] = None, recursive: bool = True):
    """Filter the list attributes of a MusPy object in place.

    Parameters
    ----------
    self : muspy.Base
        Object to filter.
    func : callable
        Predicate deciding which list items are kept.
    attr : str, optional
        Attribute to filter. When omitted, every attribute declared in
        ``self._attributes`` is processed.
    recursive : bool
        Passed on to ``_obj_filter`` for each processed attribute.

    Returns
    -------
    The same object that was passed in as ``self``.
    """
    attribute_names = self._attributes if attr is None else (attr,)
    for attribute in attribute_names:
        _obj_filter(self, func, attribute, recursive)
    return self

In [31]:
# Track names every saved file must contain, in sorted order (the assertion
# in the loop compares against this list after sorting tracks by name).
TRACK_NAMES = ['BB Bass', 'BB Drums', 'BB Guitar', 'BB Piano', 'BB Strings']
OUT_DIR = Path('.')

for part in ['test', 'val', 'train']:
    data = Groove2GrooveDataset('/tmp/groove2groove-data', part=part, ignore_exceptions=False)

    part_dir = OUT_DIR / part
    part_dir.mkdir(parents=True, exist_ok=True)

    for music in tqdm(data, desc=part):
        # Add an empty track for every expected name that is not present.
        present_names = {track.name for track in music.tracks}
        for track_name in TRACK_NAMES:
            if track_name not in present_names:
                music.append(muspy.Track(name=track_name))

        # Bring the tracks into a canonical order and check that the file
        # contains exactly the expected set of tracks.
        music.tracks.sort(key=lambda track: track.name)
        track_names = [track.name for track in music.tracks]
        assert track_names == TRACK_NAMES

        music.adjust_resolution(target=12)

        # Get rid of first 2 bars: shift everything back by 8 quarter notes,
        # then drop whatever ended up at a negative time.
        # NOTE(review): the predicate applies to every object that has a
        # `time` attribute, so non-note events inside the first two bars are
        # dropped as well -- confirm this is intended.
        music.adjust_time(lambda t: t - 8 * music.resolution)
        obj_filter(music, lambda obj: not hasattr(obj, 'time') or obj.time >= 0)

        # Add metadata: the title is the source file name without its
        # final extension.
        title, style, _ = music.metadata.source_filename.split('.')
        music.metadata.title = f'{title}.{style}'

        music.validate()

        out_path = (part_dir / music.metadata.source_filename).with_suffix('.json')
        music.save(out_path)

test: 100%|██████████| 1200/1200 [02:09<00:00,  9.27it/s]
val: 100%|██████████| 1200/1200 [01:57<00:00, 10.24it/s]
train: 100%|██████████| 5733/5733 [2:15:44<00:00,  1.42s/it]  


In [78]:
# Collect the distinct style names from the converted training files.
# File names have three dot-separated parts; the middle one is the sub-style,
# and the style is the sub-style without its last '_'-separated suffix.
style_names = set()
for path in (OUT_DIR / 'train').iterdir():
    _, substyle, _ = path.name.split('.')
    style_names.add(substyle.rsplit('_', maxsplit=1)[0])

styles = sorted(style_names)

# Seed NumPy's global generator before splitting; no `random_state` is passed
# to the calls below, and the order of the calls is kept as-is so the random
# draws are consumed in the same sequence.
np.random.seed(0)
train_test_split = sklearn.model_selection.train_test_split

# Hold out 25 styles for test and 25 for validation.
styles_train, styles_test = train_test_split(styles, test_size=25)
styles_train, styles_val = train_test_split(styles_train, test_size=25)

# Pick 50 of the remaining training styles and halve them into
# `styles_itest` and `styles_ival`.
_, styles_itest = train_test_split(styles_train, test_size=50)
styles_itest, styles_ival = train_test_split(styles_itest, test_size=25)

In [79]:
# Group the converted training files by sub-style, i.e. the middle part of
# the three dot-separated components of the file name.
files_by_substyle = defaultdict(list)

train_dir = OUT_DIR / 'train'
for path in train_dir.iterdir():
    _title, substyle, _suffix = path.name.split('.')
    files_by_substyle[substyle].append(path)

In [112]:
# Create one directory per split under OUT_DIR / 'train_split'.
# `exist_ok=True` keeps the cell re-runnable, matching the mkdir call used
# when the per-part output directories are created.
for part in ['train', 'train_lim', 'ival', 'itest', 'val', 'test']:
    (OUT_DIR / 'train_split' / part).mkdir(parents=True, exist_ok=True)

In [113]:
def _link_into_split(split_name, path):
    """Symlink `path` into OUT_DIR / 'train_split' / `split_name`."""
    link = OUT_DIR / 'train_split' / split_name / path.name
    # Express the target relative to the directory holding the link, so the
    # link resolves for any OUT_DIR (a fixed '../..' prefix is only correct
    # when OUT_DIR is the current directory).
    link.symlink_to(os.path.relpath(path, start=link.parent))


# NOTE(review): the shuffle draws below are consumed in the insertion order
# of `files_by_substyle`, which is filled from a directory listing -- that
# order may differ between machines; confirm whether it should be sorted.
np.random.seed(0)
for substyle, paths in files_by_substyle.items():
    style, _ = substyle.rsplit('_', maxsplit=1)
    if len(paths) != 2:
        print(f'Style {substyle} has {len(paths)} files, skipping', file=sys.stderr)
        continue
    # Sort first so the shuffle starts from a deterministic order; `sorted`
    # also returns a copy, leaving the list stored in the dict untouched.
    paths = sorted(paths)
    np.random.shuffle(paths)

    if style in styles_test:
        # Held-out styles contribute a single file each.
        _link_into_split('test', paths[0])
    elif style in styles_val:
        _link_into_split('val', paths[0])
    else:
        # Every file of a training style goes into 'train'.
        for path in paths:
            _link_into_split('train', path)

        # For styles selected for 'itest' / 'ival', one file is linked there
        # and excluded from the candidates for 'train_lim'.
        if style in styles_itest:
            _link_into_split('itest', paths[0])
            paths = paths[1:]
        elif style in styles_ival:
            _link_into_split('ival', paths[0])
            paths = paths[1:]

        # 'train_lim' receives exactly one file per sub-style.
        _link_into_split('train_lim', paths[-1])

Style MACY1_b has 1 files, skipping
Style THEMBA_a has 1 files, skipping
Style C_DARRYL_a has 1 files, skipping
Style R_90DISC_b has 1 files, skipping
Style JonLuce1_a has 1 files, skipping
Style Melisa2_a has 1 files, skipping
Style FUNKWALK_b has 1 files, skipping
Style SWBALLAD_a has 1 files, skipping
Style C2_LONE_a has 1 files, skipping
Style BeachCar_b has 1 files, skipping
Style REVEALED_b has 1 files, skipping


In [117]:
# Sanity check: print each split directory followed by the number of entries
# that `ls -1` lists in it.
# NOTE(review): relies on shell brace expansion and on `train_split` being
# relative to the current directory -- confirm the shell used by `!`.
!for d in train_split/{train,train_lim,test,val,itest,ival}; do echo $d | tr '\n' ' '; ls -1 $d | wc -l; done

train_split/train 5522
train_split/train_lim 2761
train_split/test 50
train_split/val 50
train_split/itest 50
train_split/ival 49
