In [None]:
import os
import pathlib
import shutil
from typing import Optional, Sequence

import numpy as np
from tqdm import tqdm

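## Creating a train/validation split

Link the GuitarSet hex-pickup recordings and their annotations into one directory, then randomly divide the `.wav` files (together with their annotation files) into training and validation subsets.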

In [None]:
def _copy_files_over(src_files, path_to_subset, copy, rm_existing: bool = False):
    """Populate ``path_to_subset`` with copies of, or symlinks to, ``src_files``.

    Every file is placed directly in ``path_to_subset`` under its own base
    name, so any source directory structure is flattened.

    Args:
        src_files: Iterable of paths (``pathlib.Path`` or ``str``) to place in
            the subset.
        path_to_subset: Destination directory (``pathlib.Path``). It is
            created, including parents, and must not already exist unless
            ``rm_existing`` is True.
        copy: If True, copy each file with ``shutil.copy``; otherwise create a
            symlink pointing at the source's absolute path.
        rm_existing: If True, delete ``path_to_subset`` first when it exists.

    Raises:
        FileNotFoundError: If any source file does not exist. This is checked
            before anything on disk is removed or created.
        ValueError: If two source files share a base name, i.e. they would be
            written to the same destination path.
        FileExistsError: If ``path_to_subset`` already exists and
            ``rm_existing`` is False.
    """
    sources = [pathlib.Path(p).absolute() for p in src_files]

    # Validate all inputs before touching the destination, so a bad input can
    # neither destroy an existing subset (rm_existing=True) nor leave behind a
    # half-populated one.
    missing = [src for src in sources if not src.exists()]
    if missing:
        raise FileNotFoundError(str(missing[0]))

    # Files are flattened into one directory; a repeated base name would
    # silently overwrite (copy=True) or fail mid-way (symlink).
    first_seen = {}
    for src in sources:
        if src.name in first_seen:
            raise ValueError(
                f'{first_seen[src.name]} and {src} would both be written to '
                f'{path_to_subset / src.name}'
            )
        first_seen[src.name] = src

    if rm_existing and path_to_subset.exists():
        shutil.rmtree(str(path_to_subset))
    path_to_subset.mkdir(exist_ok=False, parents=True)

    dest_dir = path_to_subset.absolute()
    for src in tqdm(sources):
        dest = dest_dir / src.name
        if copy:
            shutil.copy(str(src), str(dest))
        else:
            os.symlink(str(src), str(dest))

In [None]:
def create_train_val_splits(
    path_to_data: pathlib.Path,
    path_to_train_subset: pathlib.Path,
    path_to_val_subset: pathlib.Path,
    n_train: int, n_val: int,
    copy: bool = False,
    rm_existing: bool = False,
    annot_exts: Sequence[str] = ('.txt',),
    seed: Optional[int] = None,
) -> None:
    """Randomly split the ``.wav`` files under ``path_to_data`` into train and
    validation subsets, bringing each file's annotation files along.

    ``n_train + n_val`` files are drawn without replacement, so the two
    subsets are disjoint. Files that are not drawn go into neither subset.

    Args:
        path_to_data: Directory searched recursively for ``*.wav`` files.
        path_to_train_subset: Destination directory for the training subset.
        path_to_val_subset: Destination directory for the validation subset.
        n_train: Number of ``.wav`` files in the training subset.
        n_val: Number of ``.wav`` files in the validation subset.
        copy: Copy files if True, otherwise symlink them (forwarded to
            ``_copy_files_over``).
        rm_existing: Delete existing destination directories first
            (forwarded to ``_copy_files_over``).
        annot_exts: Extensions of the annotation files that accompany each
            ``.wav``. For ``x.wav`` and ``'.txt'``, the file ``x.txt`` in the
            same directory is included and must exist. A single string is
            treated as one extension.
        seed: Seed for a dedicated ``np.random.default_rng``. ``None`` (the
            default) draws from NumPy's global RNG, so an earlier
            ``np.random.seed(...)`` call still controls the split.

    Raises:
        ValueError: If ``path_to_data`` does not exist, or if ``n_train`` or
            ``n_val`` is negative or together they exceed the number of
            ``.wav`` files found.
        FileNotFoundError: Propagated from ``_copy_files_over`` when a
            selected file or one of its annotation files is missing.
        FileExistsError: Propagated from ``_copy_files_over`` when a
            destination directory already exists and ``rm_existing`` is False.
    """
    if not path_to_data.exists():
        raise ValueError(f'path_to_data does not exist: {path_to_data}')

    # Sort so that, for a fixed seed, the split does not depend on the
    # filesystem's directory-listing order.
    files = sorted(path_to_data.glob('**/*.wav'))
    print(f'Found {len(files)} .wav files under {path_to_data}')

    n_total = n_train + n_val
    if n_train < 0 or n_val < 0 or n_total > len(files):
        raise ValueError(
            f'Cannot draw n_train={n_train} + n_val={n_val} files from '
            f'{len(files)} .wav files under {path_to_data}'
        )

    # seed=None keeps the original behaviour of sampling from the global RNG.
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Sample indices rather than the Path list itself to avoid building an
    # object array of Paths.
    chosen = rng.choice(len(files), size=n_total, replace=False)
    files_train = [files[i] for i in chosen[:n_train]]
    files_val = [files[i] for i in chosen[n_train:]]

    # Guaranteed by sampling without replacement; kept as cheap sanity checks.
    assert len(set(files_train)) == n_train
    assert len(set(files_val)) == n_val
    assert not set(files_train) & set(files_val)

    if isinstance(annot_exts, str):
        # Guard against iterating a bare string character by character.
        annot_exts = (annot_exts,)
    files_train_annots = get_annot_files_list(files_train, annot_exts)
    files_val_annots = get_annot_files_list(files_val, annot_exts)

    _copy_files_over(src_files=files_train + files_train_annots, path_to_subset=path_to_train_subset, copy=copy, rm_existing=rm_existing)
    _copy_files_over(src_files=files_val + files_val_annots, path_to_subset=path_to_val_subset, copy=copy, rm_existing=rm_existing)

In [None]:
def get_annot_files_list(files_list, exts):
    """Build absolute annotation-file paths by swapping each file's suffix.

    For every extension in ``exts`` and every path in ``files_list``, the
    path's suffix is replaced by that extension and the result made absolute.
    The output is grouped by extension: all paths for ``exts[0]`` come first,
    in ``files_list`` order, then all paths for ``exts[1]``, and so on.

    Args:
        files_list: Iterable of ``pathlib.Path`` objects.
        exts: Iterable of suffixes such as ``'.txt'``.

    Returns:
        A new list of ``pathlib.Path`` objects. Existence is not checked.
    """
    return [
        path.with_suffix(suffix).absolute()
        for suffix in exts
        for path in files_list
    ]

In [None]:
# Gather the hex-pickup .wav recordings and the .jams annotation files that
# the next cell links into a single working directory.
hex_audio_dir = pathlib.Path('/home/anuj/data/GuitarSet/audio/audio_hex-pickup_original/')
annotation_dir = pathlib.Path('/home/anuj/data/GuitarSet/annotation/')
src_files = [*hex_audio_dir.glob('*.wav'), *annotation_dir.glob('*.jams')]

In [None]:
# Rebuild the combined directory from scratch as symlinks (no data is copied).
originalhex_dir = pathlib.Path('/home/anuj/data/GuitarSet/originalhex')
_copy_files_over(
    src_files=src_files,
    path_to_subset=originalhex_dir,
    copy=False,
    rm_existing=True,
)

In [None]:
# Randomly partition the .wav files in originalhex into 330 training and 30
# validation files. Both subsets are symlinks and are rebuilt from scratch.
# NOTE(review): create_train_val_splits expects a '.txt' annotation next to
# each .wav, but the cell above only links .wav and .jams files into
# originalhex, so this call will likely raise FileNotFoundError — confirm.
guitarset_dir = pathlib.Path('/home/anuj/data/GuitarSet')
create_train_val_splits(
    path_to_data=guitarset_dir / 'originalhex',
    path_to_train_subset=guitarset_dir / 'originalhex-train',
    path_to_val_subset=guitarset_dir / 'originalhex-val',
    n_train=330,
    n_val=30,
    copy=False,
    rm_existing=True,
)