In [1]:
import hashlib
import hydra
from pathlib import Path
import json
import os
import tqdm
import torchaudio as ta
import musdb

#library for class Wavset
from collections import OrderedDict
import math
import torch as th
import julius
from torch.nn import functional as F

#library for loader()
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset


  from .autonotebook import tqdm as notebook_tqdm


In [20]:
# musdb_p = '/workspace/MusicDataset/musdb18hq'
musdb_samplerate= 44100
# wav=  # path to custom wav dataset
segment= 11
shift= 1
# train_valid= False
# full_cv= True
samplerate= 44100
channels= 2
normalize= True
metadata= './metadata'
sources= ['drums', 'bass', 'other', 'vocals']
EXT = ".wav"
MIXTURE = "mixture"
download = False

batch_size= 6
num_workers = 10
world_size = 1
root= '/workspace/MusicDataset/musdb18hq'
# root= '/workspace/helen/demucs/'

In [3]:
def _get_musdb_valid():
    # Return musdb valid set.
    import yaml
    setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
    setup = yaml.safe_load(open(setup_path, 'r'))
    return setup['validation_tracks']

In [4]:
def _track_metadata(track, sources, normalize=True, ext=EXT):
    track_length = None
    track_samplerate = None
    mean = 0
    std = 1
    for source in sources + [MIXTURE]:
        file = track / f"{source}{ext}"
        try:
            info = ta.info(str(file))
        except RuntimeError:
            print(file)
            raise
        length = info.num_frames
        if track_length is None:
            track_length = length
            track_samplerate = info.sample_rate
        elif track_length != length:
            raise ValueError(
                f"Invalid length for file {file}: "
                f"expecting {track_length} but got {length}.")
        elif info.sample_rate != track_samplerate:
            raise ValueError(
                f"Invalid sample rate for file {file}: "
                f"expecting {track_samplerate} but got {info.sample_rate}.")
        if source == MIXTURE and normalize:
            try:
                wav, _ = ta.load(str(file))
            except RuntimeError:
                print(file)
                raise
            wav = wav.mean(0)
            mean = wav.mean().item()
            std = wav.std().item()

    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}


In [5]:
def build_metadata(path, sources, normalize=True, ext=EXT):
    """
    Build the metadata for `Wavset`.
    Args:
        path (str or Path): path to dataset.
        sources (list[str]): list of sources to look for.
        normalize (bool): if True, loads full track and store normalization
            values based on the mixture file.
        ext (str): extension of audio files (default is .wav).
    """

    meta = {}
    path = Path(path)
    pendings = []
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(8) as pool:
        for root, folders, files in os.walk(path, followlinks=True):
            root = Path(root)
            if root.name.startswith('.') or folders or root == path:
                continue
            name = str(root.relative_to(path))
            pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext)))
            # meta[name] = _track_metadata(root, sources, normalize, ext)
        for name, pending in tqdm.tqdm(pendings, ncols=120):
            meta[name] = pending.result()
    return meta

In [6]:
def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels."""
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked 1-channel audio, but the stream have multiple
        # channels, downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file have
        # one single channel, replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file have
        # more channels than requested. In that case return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has less channels than requested but is not mono.')
    return wav


metadata_train: 86 songs\
metadata_valid: 14 songs 

In [9]:
# import wget
# wget.download('https://zenodo.org/record/3338373/files/musdb18hq.zip?download=1')

In [21]:
class MusdbHQ:
    def __init__(
            self,root, subset,segment=None, shift=None, normalize=True,
            samplerate=44100, channels=2, ext=EXT):
        """
        Waveset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `path`.
        The folder should contain files named `{source}.{ext}`.
        Args:
            root (Path or str): root folder for the dataset.
            subset (str): training or validation
            download (bool): Whether to download the dataset if it is not found at root path. (default: ``False``).
            segment (None or float): segment length in seconds. If `None`, returns entire tracks.
            shift (None or float): stride in seconds bewteen samples.
            normalize (bool): normalizes input audio, **based on the metadata content**,
                i.e. the entire track is normalized, not individual extracts.
            samplerate (int): target sample rate. if the file sample rate
                is different, it will be resampled on the fly.
            channels (int): target nb of channels. if different, will be
                changed onthe fly.
            ext (str): extension for audio files (default is .wav).
        samplerate and channels are converted on the fly.
        """
#         url = 'https://zenodo.org/record/3338373/files/musdb18hq.zip?download=1'
        
#         download_path = root
#         self.download_path = download_path
        
          
#         if download:
#             if os.path.isdir(download_path) and os.path.isdir(os.path.join(download_path, 'data')):
#                 print(f'Dataset folder exists, skipping download...')
#                 decision = input(f"Do you want to extract {archive_name+ext_archive} again? "
#                                  f"To avoid this prompt, set `download=False`\n"
#                                  f"This action will overwrite exsiting files, do you still want to continue? [yes/no]") 
#                 if decision.lower()=='yes':
#                     print(f'extracting...')
#                     extract_archive(os.path.join(download_path, archive_name+ext_archive))                
#             elif os.path.isfile(os.path.join(download_path, 'timit.zip')):
#                 print(f'timit.zip exists, extracting...')
#                 check_md5(os.path.join(download_path, archive_name+ext_archive), checksum)
#                 extract_archive(os.path.join(download_path, archive_name+ext_archive))
#             else:
#                 decision='yes'       
#                 if not os.path.isdir(download_path):
#                     print(f'Creating download path = {root}')
#                     os.makedirs(os.path.join(download_path))
# #                 if os.path.isfile(download_path+ext_archive):
# #                     print(f'.tar.gz file exists, proceed to extraction...')
# #                 else:
#                 if os.path.isfile(os.path.join(download_path, archive_name+ext_archive)):
#                     print(f'{download_path+ext_archive} already exists, proceed to extraction...')
#                 else:
#                     print(f'downloading...')
#                     try:
#                         download_url(url, download_path, hash_value=checksum, hash_type='md5')
#                     except:
#                         raise Exception('Auto download fails. '+
#                                         'You may want to download it manually from:\n'+
#                                         url+ '\n' +
#                                         f'Then, put it inside {download_path}')
                      
        
        
# get_musdb_wav_datasets
        sig = hashlib.sha1(str(root).encode()).hexdigest()[:8]
        metadata_file = Path('./metadata') / ('musdb_' + sig + ".json")
        root = Path(root) / "train"
    #     if not metadata_file.is_file() and distrib.rank == 0:
        if not metadata_file.is_file():
            metadata_file.parent.mkdir(exist_ok=True, parents=True)
            metadata = build_metadata(root, sources)
            json.dump(metadata, open(metadata_file, "w"))
    #     if distrib.world_size > 1:
    #         distributed.barrier()
        metadata = json.load(open(metadata_file))

        valid_tracks = _get_musdb_valid()
        
        if subset == 'training':
            metadata = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
            self.sources = sources
            
        elif subset=='validation':
            metadata = {name: meta for name, meta in metadata.items() if name in valid_tracks}
            self.sources = [MIXTURE] + list(sources)
# metadata (dict): output from `build_metadata`.
# sources (list[str]): list of source names.
        
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.segment = segment
        self.shift = shift or segment
        self.normalize = normalize
        self.channels = channels
        self.samplerate = samplerate
        self.ext = ext
        self.num_examples = []
        for name, meta in self.metadata.items():
            track_duration = meta['length'] / meta['samplerate']
            if segment is None or track_duration < segment:
                examples = 1
            else:
                examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
            self.num_examples.append(examples)
# samplerate = number of sample per second
# length = number of sample
# track_duration is second, cut song into segment

    def __len__(self):
        return sum(self.num_examples)

    def get_file(self, name, source):
        return self.root / name / f"{source}{self.ext}"

    def __getitem__(self, index):
        print(len(self.metadata))
        for name, examples in zip(self.metadata, self.num_examples):           
            if index >= examples:
                index -= examples
                continue
            meta = self.metadata[name]
            num_frames = -1
            offset = 0
            if self.segment is not None:
                offset = int(meta['samplerate'] * self.shift * index)
                num_frames = int(math.ceil(meta['samplerate'] * self.segment))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.segment:
                length = int(self.segment * self.samplerate)
                example = example[..., :length]
                example = F.pad(example, (0, length - example.shape[-1]))
            return example

In [22]:
train_set = MusdbHQ(root,'training',samplerate=samplerate, channels=channels,normalize=normalize,
                    segment=segment, shift=shift)


100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.07it/s]


In [23]:
valid_set = MusdbHQ(root,'validation', samplerate=samplerate, channels=channels,normalize=normalize)

In [24]:
train_set[1].shape

86


torch.Size([4, 2, 485100])

In [25]:
valid_set[3].shape

14


torch.Size([5, 2, 11301609])

In [None]:
def loader(dataset, batch_size , shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.
    """
    if world_size == 1:
        return klass(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, batch_size=batch_size, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler otherwise replicate some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

In [None]:
train_loader = loader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True)

In [None]:
train_loader