# Dataset and CPC Sandbox

*April 16, 2021*

This notebook achieves two tasks:
1. Dataset curation -- we want to organize the MusicNet dataset by ensemble (instrumentation) and composer
2. Preliminary CPC training/testing -- we want to get an implementation of the CPC model working on the MusicNet dataset

**References** 
1. [Script to download raw MusicNet data (Github)](https://github.com/jthickstun/pytorch_musicnet)
2. [UMT's MusicNet data loader (Github)](https://github.com/facebookresearch/music-translation)
3. [MusicNet Documentation](https://homes.cs.washington.edu/~thickstn/musicnet.html)
4. [CPC Implementation (Github)](https://github.com/jefflai108/Contrastive-Predictive-Coding-PyTorch)

## 1. Dataset Curation

In [None]:
from __future__ import print_function

import csv
import errno
import logging
import mmap
import os
import os.path
import pickle
import random
import subprocess
import sys
import time
from datetime import timedelta
from pathlib import Path
from shutil import copy, move
from subprocess import call
from tempfile import NamedTemporaryFile

import h5py
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from intervaltree import IntervalTree
from scipy.io import wavfile
from tqdm import tqdm


### Downloading MusicNet data
We follow a similar setup process to [Universal Music Translation Network](https://github.com/facebookresearch/music-translation)
1. Download raw data from [https://homes.cs.washington.edu/~thickstn/media/](https://homes.cs.washington.edu/~thickstn/media/)
2. Extract files into `train_data`, `train_labels`, `test_data`, `test_labels` subdirectories
3. Parse the raw data and organize it by `ensemble`, by `composer`, or by (ensemble, composer) pairs
4. Split into train/test/val sets
5. (Optional) Perform preprocessing on the audio for training

In [None]:
def _check_exists(root):
    return os.path.exists(os.path.join(root, "train_data")) and \
        os.path.exists(os.path.join(root, "test_data")) and \
        os.path.exists(os.path.join(root, "train_labels")) and \
        os.path.exists(os.path.join(root, "test_labels"))

In [None]:
def _download_file(url, dest_path, chunk_size=16 * 1024):
    """Stream ``url`` to ``dest_path`` without holding the body in memory.

    The payload is written to ``dest_path + '.part'`` and renamed into place
    only after the transfer finishes, so an interrupted download never leaves
    a truncated file that a later run would mistake for a complete one.
    """
    import shutil
    from six.moves import urllib

    partial_path = dest_path + ".part"
    # `with` closes the HTTP response and the file even if the copy fails.
    with urllib.request.urlopen(url) as response, open(partial_path, 'wb') as f:
        shutil.copyfileobj(response, f, chunk_size)
    os.replace(partial_path, dest_path)


def download_data(root):
    """Download and extract MusicNet at root.
    Adapted from https://github.com/jthickstun/pytorch_musicnet

    Parameters
    ----------
    root : str, Path
        Directory to download MusicNet data. Will create train_data, train_labels,
        test_data, test_labels, and raw subdirectories, plus
        ``musicnet_metadata.csv``.

    Raises
    ------
    OSError
        If ``tar`` extraction fails or the ``raw`` directory cannot be created.
    """
    root = os.fspath(root)
    metadata_path = os.path.join(root, 'musicnet_metadata.csv')

    # Only skip when both the extracted folders AND the metadata CSV exist;
    # otherwise a run interrupted after extraction would never fetch the CSV.
    if _check_exists(root) and os.path.exists(metadata_path):
        return

    os.makedirs(os.path.join(root, "raw"), exist_ok=True)

    url = "https://homes.cs.washington.edu/~thickstn/media/musicnet.tar.gz"
    filename = url.rpartition('/')[2]
    file_path = os.path.join(root, "raw", filename)

    # Download/extract the tarball only when the extracted folders are missing.
    if not _check_exists(root):
        if not os.path.exists(file_path):
            print(f"Downloading {url}")
            _download_file(url, file_path)

        print('Extracting ' + filename)
        if call(["tar", "-xf", file_path, '-C', root, '--strip', '1']) != 0:
            raise OSError("Failed tarball extraction")

    if not os.path.exists(metadata_path):
        _download_file("https://homes.cs.washington.edu/~thickstn/media/musicnet_metadata.csv",
                       metadata_path)

    print('Download Complete')

In [None]:
# Location where MusicNet is downloaded and extracted.
root = "/content/musicnet"
download_data(root=root)

Download Complete


In [None]:
# Reuse `root` rather than repeating the absolute path, so the data location
# only has to be changed in one place.
metadata = pd.read_csv(os.path.join(root, 'musicnet_metadata.csv'))
metadata.head()

Unnamed: 0,id,composer,composition,movement,ensemble,source,transcriber,catalog_name,seconds
0,1727,Schubert,Piano Quintet in A major,2. Andante,Piano Quintet,European Archive,http://tirolmusic.blogspot.com/,OP114,447
1,1728,Schubert,Piano Quintet in A major,3. Scherzo: Presto,Piano Quintet,European Archive,http://tirolmusic.blogspot.com/,OP114,251
2,1729,Schubert,Piano Quintet in A major,4. Andantino - Allegretto,Piano Quintet,European Archive,http://tirolmusic.blogspot.com/,OP114,444
3,1730,Schubert,Piano Quintet in A major,5. Allegro giusto,Piano Quintet,European Archive,http://tirolmusic.blogspot.com/,OP114,368
4,1733,Schubert,Piano Sonata in A major,2. Andantino,Solo Piano,Museopen,Segundo G. Yogore,D959,546


In [None]:
def process_labels(root, path):
    """Build one IntervalTree of note labels per label CSV in ``root/path``.

    Every ``<id>.csv`` file becomes ``trees[int(id)]``. Each CSV row is stored
    over the half-open interval ``[start_time, end_time)`` with the payload
    ``(instrument, note, start_beat, end_beat, note_value)``. Files not ending
    in ``.csv`` are ignored.

    Parameters
    ----------
    root : str, Path
        Absolute path to root of data directory
    path : str, Path
        Subdirectory in root to parse labels from

    Returns
    -------
    trees : dict
        Maps recording id (int) to its IntervalTree.
    """
    label_dir = os.path.join(root, path)
    trees = {}
    for entry in os.listdir(label_dir):
        if not entry.endswith('.csv'):
            continue
        uid = int(entry[:-4])  # strip the ".csv" extension
        tree = IntervalTree()
        with open(os.path.join(label_dir, entry), 'r') as handle:
            for row in csv.DictReader(handle, delimiter=','):
                start, end = int(row['start_time']), int(row['end_time'])
                payload = (
                    int(row['instrument']),
                    int(row['note']),
                    float(row['start_beat']),
                    float(row['end_beat']),
                    row['note_value'],
                )
                tree[start:end] = payload
        trees[uid] = tree
    return trees

In [None]:
# Note-label interval trees keyed by recording id, one dict per split.
train_labels = process_labels(root, path="train_labels")
test_labels = process_labels(root, path="test_labels")

In [None]:
def curate_data(root, destination, metadata, groupby='composer', disable_progress_bar=True):
    """Copy MusicNet ``.wav`` files into one sub-folder per ``metadata[groupby]`` value.

    Each recording id in ``metadata`` is looked up in ``root/train_data`` and
    then ``root/test_data``, and copied to ``destination/<group>``. Here
    ``<group>`` is the group value as a string with spaces replaced by
    underscores.

    Parameters
    ----------
    root : str, Path
        MusicNet root containing ``train_data`` and ``test_data``.
    destination : str, Path
        Output directory; created (with parents) if missing.
    metadata : pandas.DataFrame
        Must contain an ``id`` column and the ``groupby`` column.
    groupby : str
        Column whose values name the output sub-folders.
    disable_progress_bar : bool
        Forwarded to ``tqdm(disable=...)``.

    Raises
    ------
    AttributeError
        If ``metadata`` has no ``columns`` attribute.
    ValueError
        If ``groupby`` is not a column of ``metadata``.
    FileNotFoundError
        If a recording is in neither ``train_data`` nor ``test_data``.
    """
    if not hasattr(metadata, "columns"):
        raise AttributeError('metadata must have a columns attribute')

    if groupby not in metadata.columns:
        raise ValueError(f'{groupby} column is not in metadata')

    # abspath leaves absolute paths unchanged apart from normalisation.
    root = Path(os.path.abspath(root))
    destination = Path(os.path.abspath(destination))
    destination.mkdir(parents=True, exist_ok=True)

    train_dir = root / "train_data"
    test_dir = root / "test_data"
    for group_name, group_df in tqdm(metadata.groupby(groupby), disable=disable_progress_bar):
        # str() so non-string group keys (e.g. an integer column) also work.
        out_dir = destination / str(group_name).replace(' ', '_')
        out_dir.mkdir(exist_ok=True)

        for fid in group_df.id.tolist():
            fname = train_dir / f"{fid}.wav"
            if not fname.exists():
                fname = test_dir / f"{fid}.wav"
            if not fname.exists():
                raise FileNotFoundError(f"{fid}.wav not found in {train_dir} or {test_dir}")
            copy(str(fname), str(out_dir))

    print(f"Curated data at {destination}")

In [None]:
# parsed_dir = Path("musicnet/parsed")
# curate_data(root, parsed_dir, metadata, 'composer')

Curated data to /content/musicnet/parsed


In [None]:
def parse_data(src, dst, domains):
    """Copy the recordings of selected (ensemble, composer) domains into ``dst``.

    For each ``(ensemble, composer)`` pair, the matching ids are read from
    ``src/musicnet_metadata.csv``. Each ``<id>.wav`` is copied from
    ``src/train_data`` (falling back to ``src/test_data``) into
    ``dst/<composer>_<ensemble with spaces replaced by underscores>``. The
    summed ``seconds`` of every domain is printed.

    Parameters
    ----------
    src : str or Path
        Path to input data (e.g. /content/musicnet), holding
        ``musicnet_metadata.csv``, ``train_data`` and ``test_data``.
    dst : str or Path
        Output directory; created with parents if missing.
    domains : iterable of (ensemble, composer) pairs
        Domains to extract; each becomes one sub-folder of ``dst``.

    Raises
    ------
    FileNotFoundError
        If a selected recording is in neither data folder (raised by the copy).
    """
    # Accept plain strings as documented; the code below relies on Path's `/`.
    src = Path(src)
    dst = Path(dst)
    dst.mkdir(exist_ok=True, parents=True)

    db = pd.read_csv(src / 'musicnet_metadata.csv')
    traindir = src / 'train_data'
    testdir = src / 'test_data'

    for (ensemble, composer) in domains:
        # Filter the metadata once and reuse it for ids and duration.
        selected = db[(db["composer"] == composer) & (db["ensemble"] == ensemble)]
        fid_list = selected.id.tolist()
        total_time = sum(selected.seconds.tolist())
        print(f"Total time for {composer} with {ensemble} is: {total_time} seconds")

        domaindir = dst / f"{composer}_{ensemble.replace(' ', '_')}"
        domaindir.mkdir(exist_ok=True)

        for fid in fid_list:
            fname = traindir / f'{fid}.wav'
            if not fname.exists():
                fname = testdir / f'{fid}.wav'
            copy(str(fname), str(domaindir))

In [None]:
# (ensemble, composer) pairs to extract; each becomes one output folder.
domains = [
    ['Accompanied Violin', 'Beethoven'],
    ['Solo Cello', 'Bach'],
    ['Solo Piano', 'Bach'],
    ['Solo Piano', 'Beethoven'],
    ['String Quartet', 'Beethoven'],
    ['Wind Quintet', 'Cambini'],
]
src_path = Path('/content/musicnet')
dst_path = src_path / 'parsed'
parse_data(src=src_path, dst=dst_path, domains=domains)

### DataLoader Class

`TODO`: Adapt data wrapper class for easy I/O with MusicNet data
- `TODO 1`: Get this working for the raw MusicNet dataset
- `TODO 2`: Get this working for the composer subdirectory structure

Silu's Notes:
- Removed augmentation, other unnecessary params
- Changed to encoded dataset
- Left TODOs below


In [None]:
# UMT: utils.py
class timeit:
    """Context manager that reports how long its body took, in milliseconds.

    The duration is printed when ``logger`` is None; otherwise it is sent to
    ``logger.debug``. Exceptions raised in the body are not suppressed.
    """

    def __init__(self, name, logger=None):
        self.name = name
        self.logger = logger

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        self.start = time.perf_counter()
        # Return self so `with timeit(...) as t:` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_ms = (time.perf_counter() - self.start) * 1000
        if self.logger is None:
            print(f'{self.name} took {elapsed_ms} ms')
        else:
            self.logger.debug('%s took %s ms', self.name, elapsed_ms)


def mu_law(x, mu=255):
    """Mu-law compand ``x`` and quantize it to integer codes in ``[0, mu]``.

    Inputs are clipped to ``[-1, 1]`` first. Returns an ``int16`` array
    (codes are truncated toward zero by ``astype``).
    """
    # Use the notebook's `np` alias; the bare name `numpy` is never imported.
    x = np.clip(x, -1, 1)
    x_mu = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    return ((x_mu + 1) / 2 * mu).astype('int16')


def inv_mu_law(x, mu=255.0):
    """Map mu-law integer codes (as produced by ``mu_law``) back to floats in ``[-1, 1]``.

    ``x`` may be any array-like; it is converted to ``float32`` first.
    """
    # Use the notebook's `np` alias; the bare name `numpy` is never imported.
    x = np.array(x).astype(np.float32)
    y = 2. * (x - (mu + 1.) / 2.) / (mu + 1.)
    return np.sign(y) * (1. / mu) * ((1. + mu) ** np.abs(y) - 1.)


class LossMeter(object):
    """Collects loss values and reports their mean or their total."""

    def __init__(self, name):
        self.name = name
        self.losses = []

    def reset(self):
        """Discard every recorded value."""
        self.losses = []

    def add(self, val):
        """Record a single loss value."""
        self.losses.append(val)

    def summarize_epoch(self):
        """Mean of the recorded values (``np.mean``), or 0 if none were recorded."""
        if not self.losses:
            return 0
        return np.mean(self.losses)

    def sum(self):
        """Total of the recorded values (builtin ``sum``)."""
        return sum(self.losses)


class LogFormatter:
    """logging formatter: ``LEVEL - <local date time> - <elapsed> - message``.

    Elapsed time is counted from construction, or from whenever
    ``start_time`` is reassigned, rounded to whole seconds. Continuation lines
    of multi-line messages are indented to line up under the message text.
    """

    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        elapsed = timedelta(seconds=round(record.created - self.start_time))
        prefix = f"{record.levelname} - {time.strftime('%x %X')} - {elapsed}"
        indent = ' ' * (len(prefix) + 3)
        body = record.getMessage().replace('\n', '\n' + indent)
        return f"{prefix} - {body}"


def create_output_dir(opt, path: Path):
    """Create the experiment directory ``path`` and configure the root logger.

    The root logger is reset to DEBUG level. A file handler using
    ``LogFormatter`` appends to ``path/main.log``, or to
    ``path/main_<rank>.log`` when ``opt`` has a ``rank`` attribute. The main
    process (rank 0, or any ``opt`` without ``rank``) also gets an INFO-level
    console handler. Other ranks have ``sys.stdout``/``sys.stderr``
    redirected to per-rank files in ``path``. ``opt`` is logged at INFO level.

    Parameters
    ----------
    opt : object
        Options namespace; ``rank`` and ``checkpoint`` attributes are optional.
    path : Path or str
        Output directory, created with parents if missing.

    Returns
    -------
    logging.Logger
        The root logger, with an added ``reset_time()`` function that restarts
        the elapsed-time counter shown in each record.
    """
    path = Path(path)
    has_rank = hasattr(opt, 'rank')
    filepath = path / (f'main_{opt.rank}.log' if has_rank else 'main.log')

    path.mkdir(parents=True, exist_ok=True)

    if has_rank and opt.rank != 0:
        # Secondary ranks keep their stdout/stderr out of the shared console.
        sys.stdout = open(path / f'stdout_{opt.rank}.log', 'w')
        sys.stderr = open(path / f'stderr_{opt.rank}.log', 'w')

    # Safety check; `checkpoint` is optional so plain namespaces work too.
    if filepath.exists() and not getattr(opt, 'checkpoint', False):
        logging.warning("Experiment already exists!")

    log_formatter = LogFormatter()

    # Reset the root logger. Close the handlers being dropped so repeated
    # calls do not leak open log files.
    logger = logging.getLogger()
    for old_handler in logger.handlers:
        old_handler.close()
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # File handler at DEBUG level.
    file_handler = logging.FileHandler(filepath, "a")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(log_formatter)
    logger.addHandler(file_handler)

    # Console handler at INFO level for the main process only. A run without a
    # `rank` attribute is single-process, so it is the main process.
    if not has_rank or opt.rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(log_formatter)
        logger.addHandler(console_handler)

    def reset_time():
        """Restart the elapsed-time counter used by the formatter."""
        log_formatter.start_time = time.time()
    logger.reset_time = reset_time

    logger.info(opt)
    return logger


def setup_logger(logger_name, filename):
    """Return logger ``logger_name`` writing to stderr and to ``filename``.

    Existing handlers are dropped, the logger level is DEBUG and propagation
    to the root logger is disabled. The file receives DEBUG and above. stderr
    receives INFO and above, or only WARNING and above when the ``RANK``
    environment variable is set to anything other than ``"0"``.
    """
    is_secondary_rank = os.environ.get("RANK", "0") != "0"
    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    logger = logging.getLogger(logger_name)
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    console = logging.StreamHandler(sys.stderr)
    console.setLevel(logging.WARNING if is_secondary_rank else logging.INFO)
    to_file = logging.FileHandler(filename)
    to_file.setLevel(logging.DEBUG)

    for handler in (console, to_file):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger


def wrap(data, **kwargs):
    """Move a tensor, or a (nested) iterable of tensors, to the GPU.

    Tensors are transferred with ``.cuda(non_blocking=True)``. Any other value
    is iterated, and its wrapped items are returned as a tuple. ``kwargs`` are
    passed unchanged down the recursion.
    """
    if not torch.is_tensor(data):
        return tuple(wrap(item, **kwargs) for item in data)
    return data.cuda(non_blocking=True)


def save_audio(x, path, rate):
    """Write samples ``x`` to the WAV file ``path`` at sample rate ``rate``.

    Parent directories are created as needed. ``path`` may be a ``str`` or a
    ``Path``. The sample dtype determines the WAV encoding, as handled by
    ``scipy.io.wavfile.write``.
    """
    path = Path(path)  # plain strings have no `.parent`
    path.parent.mkdir(parents=True, exist_ok=True)
    wavfile.write(path, rate, x)


def save_wav_image(wav, path):
    """Plot waveform ``wav`` and save the figure to ``path`` (str or Path).

    Parent directories are created as needed. The figure is always closed
    after saving, so calling this in a loop does not accumulate open figures.
    """
    # The notebook's import cell never imports matplotlib, so import it here
    # to keep this helper usable on a fresh kernel.
    import matplotlib.pyplot as plt

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(figsize=(15, 5))
    try:
        ax = fig.add_subplot()
        ax.plot(wav)
        fig.savefig(path)
    finally:
        plt.close(fig)

In [None]:
# UMT: data.py
# Module-level logger used by the dataset classes below (also writes data.log).
logger = setup_logger(logger_name=__name__, filename='data.log')


def random_of_length(seq, length):
    """Return a uniformly placed contiguous window of ``length`` items from ``seq``.

    ``seq`` must provide ``.size(0)`` and slicing (e.g. a torch tensor).
    Returns None when ``length < 1`` or when ``seq`` is shorter than
    ``length``. In the latter case ``random.randint`` would otherwise raise.
    """
    if length < 1:
        return None

    limit = seq.size(0) - length
    if limit < 0:
        return None

    start = random.randint(0, limit)
    return seq[start:start + length]


class EncodedFilesDataset(data.Dataset):
    """Random mono segments decoded on the fly from encoded audio files.

    Uses the external ``ffprobe``/``ffmpeg`` executables to read a random
    ``seq_len``-sample segment, resampled to ``WAV_FREQ`` Hz, from a randomly
    chosen file under ``top``. ``__len__`` returns the nominal ``epoch_len``,
    not the number of files. :meth:`dump_to_folder` converts every file to an
    ``.h5`` file holding a float ``'wav'`` dataset.
    """
    FILE_TYPES = ['mp3', 'ape', 'm4a', 'flac', 'mkv', 'wav']
    WAV_FREQ = 16000
    INPUT_FREQ = 44100
    FFT_SZ = 2048
    WINLEN = FFT_SZ - 1
    HOP_SZ = 80

    def __init__(self, top, seq_len=None, file_type=None, epoch_len=10000):
        self.path = Path(top)
        self.seq_len = seq_len
        self.file_types = [file_type] if file_type else self.FILE_TYPES
        self.file_paths = self.filter_paths(self.path.glob('**/*'), self.file_types)
        self.epoch_len = epoch_len

    @staticmethod
    def filter_paths(haystack, file_types):
        """Keep regular files whose name ends with one of ``file_types``,
        ignoring anything inside a ``__MACOSX`` folder."""
        return [f for f in haystack
                if (f.is_file()
                    and any(f.name.endswith(suffix) for suffix in file_types)
                    and '__MACOSX' not in f.parts)]

    def _random_file(self):
        """Pick one of ``self.file_paths`` uniformly at random."""
        return random.choice(self.file_paths)

    @staticmethod
    def _file_length(file_path):
        """Duration of ``file_path`` in seconds, as printed by ``ffprobe``.

        Raises ValueError when ffprobe prints nothing parseable.
        """
        output = subprocess.run(['ffprobe',
                                 '-show_entries', 'format=duration',
                                 '-v', 'quiet',
                                 '-print_format', 'compact=print_section=0:nokey=1:escape=csv',
                                 str(file_path)],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE).stdout
        return float(output)

    def _file_slice(self, file_path, start_time):
        """Decode up to ``seq_len`` mono samples at ``WAV_FREQ`` from ``start_time`` seconds.

        Returns a float array that may be shorter than ``seq_len`` near the end
        of the file; callers check the length.
        """
        length_sec = self.seq_len / self.WAV_FREQ
        length_sec += .01  # decode slightly more than needed; trimmed below
        with NamedTemporaryFile(suffix='.wav') as output_file:
            subprocess.run(['ffmpeg',
                            '-v', 'quiet',
                            '-y',  # overwrite the (empty) temp file
                            '-ss', str(start_time),
                            '-i', str(file_path),
                            '-t', str(length_sec),
                            '-f', 'wav',
                            '-ar', str(self.WAV_FREQ),  # audio rate
                            '-ac', '1',  # audio channels
                            output_file.name],
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE)
            # Read through the path ffmpeg wrote to (as dump_to_folder does),
            # not through this Python handle.
            rate, wav_data = wavfile.read(output_file.name)
            assert wav_data.dtype == np.int16
            return wav_data[:self.seq_len].astype('float')

    def __len__(self):
        return self.epoch_len

    def __getitem__(self, _):
        # The index is ignored: every item is a fresh random segment.
        return torch.FloatTensor(self.random_file_slice())

    def random_file_slice(self):
        """Retry random segments until one of exactly ``seq_len`` samples is decoded.

        Failures are logged and retried; note that this loops indefinitely if no
        file can ever yield such a segment.
        """
        wav_data = None
        while wav_data is None or len(wav_data) != self.seq_len:
            try:
                _, _, _, wav_data = self.try_random_file_slice()
            except Exception as e:
                logger.exception('Exception %s in random_file_slice.', e)
        return wav_data

    def try_random_file_slice(self):
        """Decode one random segment.

        Returns
        -------
        tuple
            ``(file, file_length_sec, start_time, wav_data)``.
        """
        file = self._random_file()
        file_length_sec = self._file_length(file)
        segment_length_sec = self.seq_len / self.WAV_FREQ
        if file_length_sec < segment_length_sec:
            logger.warning('File "%s" has length %s, segment length is %s',
                           file, file_length_sec, segment_length_sec)

        # Leave a one-segment margin before the end of the file. Clamp at 0 so
        # files shorter than two segments never get a negative seek offset.
        start_time = max(0.0, random.random() * (file_length_sec - segment_length_sec * 2))
        try:
            wav_data = self._file_slice(file, start_time)
        except Exception as e:
            logger.info(f'Exception in file slice: {e}. '
                        f'File: {file}, '
                        f'File length: {file_length_sec}, '
                        f'Start time: {start_time}')
            raise

        if len(wav_data) != self.seq_len:
            logger.warning('File "%s" has length %s, segment length is %s, wav data length: %s',
                           file, file_length_sec, segment_length_sec, len(wav_data))

        return file, file_length_sec, start_time, wav_data

    def dump_to_folder(self, output: Path, norm_db=False):
        """Convert every file to ``output/<relative path>.h5`` with a float ``'wav'`` dataset.

        Audio is resampled to ``WAV_FREQ`` and down-mixed to mono by ffmpeg.
        With ``norm_db=True``, each file is first converted to wav and passed
        through ``sox ... compand`` before resampling.

        Raises ValueError (re-raised from ``wavfile.read``) when a converted
        file cannot be read.
        """
        output = Path(output)
        for file_path in tqdm(self.file_paths):
            output_file_path = output / file_path.relative_to(self.path).with_suffix('.h5')
            output_file_path.parent.mkdir(parents=True, exist_ok=True)
            with NamedTemporaryFile(suffix='.wav') as output_wav_file, \
                    NamedTemporaryFile(suffix='.wav') as norm_file_path, \
                    NamedTemporaryFile(suffix='.wav') as wav_convert_file:
                if norm_db:
                    logger.debug(f'Converting {file_path} to {wav_convert_file.name}')
                    subprocess.run(['ffmpeg',
                                    '-y',
                                    '-i', file_path,
                                    wav_convert_file.name],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)

                    logger.debug(f'Companding {wav_convert_file.name} to {norm_file_path.name}')
                    subprocess.run(['sox',
                                    '-G',
                                    wav_convert_file.name,
                                    norm_file_path.name,
                                    'compand',
                                    '0.3,1',
                                    '6:-70,-60,-20',
                                    '-5',
                                    '-90',
                                    '0.2'],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
                    input_file_path = norm_file_path.name
                else:
                    input_file_path = file_path

                logger.debug(f'Converting {input_file_path} to {output_wav_file.name}')
                subprocess.run(['ffmpeg',
                                '-v', 'quiet',
                                '-y',  # overwrite
                                '-i', input_file_path,
                                '-f', 'wav',
                                '-ar', str(self.WAV_FREQ),  # audio rate
                                '-ac', '1',  # audio channels
                                output_wav_file.name],
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
                try:
                    rate, wav_data = wavfile.read(output_wav_file.name)
                except ValueError:
                    logger.info(f'Cannot read {file_path} wav conversion')
                    raise
                assert wav_data.dtype == np.int16
                wav = wav_data.astype('float')

                with h5py.File(output_file_path, 'w') as output_file:
                    # HDF5 chunk dimensions must be positive, so store an
                    # empty signal contiguously (chunks=None) instead.
                    chunk_shape = (min(10000, len(wav)),) if len(wav) else None
                    wav_dset = output_file.create_dataset('wav', wav.shape, dtype=wav.dtype,
                                                          chunks=chunk_shape)
                    wav_dset[...] = wav

                logger.debug(f'Saved input {file_path} to {output_file_path}. '
                             f'Wav length: {wav.shape}')


class H5Dataset(data.Dataset):
    """Random fixed-length windows read from ``*.h5`` files under ``top``.

    Each item is a pair ``(x, x_aug)`` of tensors built from one random window
    of ``seq_len`` samples, or the whole signal when ``whole_samples``.
    ``x_aug`` is ``augmentation(x)`` when an augmentation is given, otherwise
    ``x`` itself. When ``dataset_name == 'wav'``, both are mu-law encoded
    after dividing by ``2 ** 15``. ``__len__`` is the nominal ``epoch_len``.

    Parameters
    ----------
    top : str or Path
        Directory searched recursively for ``*.h5`` files.
    seq_len : int
        Window length in samples.
    dataset_name : str
        Name of the HDF5 dataset inside each file.
    epoch_len : int
        Value returned by ``__len__``.
    augmentation : callable or None
        Applied to a window to produce the second element of each pair.
    short : bool
        Use only the first (sorted) file, as a debugging aid.
    whole_samples : bool
        Return entire signals instead of random windows.
    cache : bool
        Load every file into memory up front.
    """
    # Consecutive failed attempts tolerated by __getitem__ before giving up.
    MAX_RETRIES = 1000

    def __init__(self, top, seq_len, dataset_name, epoch_len=10000, augmentation=None, short=False,
                 whole_samples=False, cache=False):
        self.path = Path(top)
        self.seq_len = seq_len
        self.epoch_len = epoch_len
        self.short = short
        self.whole_samples = whole_samples
        self.augmentation = augmentation
        self.dataset_name = dataset_name

        # Sorted so `short` and seeded random choices do not depend on the
        # filesystem's listing order.
        self.file_paths = sorted(self.path.glob('**/*.h5'))
        if not self.file_paths:
            logger.error(f'No files found in {self.path}')
        elif self.short:
            self.file_paths = [self.file_paths[0]]

        self.data_cache = {}
        if cache:
            for file_path in tqdm(self.file_paths,
                                  desc=f'Reading dataset {self.path.parent.name}/{self.path.name}'):
                dataset = self.read_h5_file(file_path)
                try:
                    self.data_cache[file_path] = dataset[:]
                finally:
                    dataset.file.close()

        logger.info(f'Dataset created. {len(self.file_paths)} files, '
                    f'augmentation: {self.augmentation is not None}. '
                    f'Path: {self.path}')

    def __getitem__(self, _):
        # The index is ignored: every item is a fresh random window. The pair
        # is only returned once slicing, augmentation and encoding have all
        # succeeded; a failure in any step retries with a new window.
        last_error = None
        for _attempt in range(self.MAX_RETRIES):
            try:
                clean = self.try_random_slice()
                if self.augmentation:
                    pair = [clean, self.augmentation(clean)]
                else:
                    pair = [clean, clean]

                if self.dataset_name == 'wav':
                    pair = [mu_law(x / 2 ** 15) for x in pair]
            except Exception as e:
                last_error = e
                logger.info('Exception %s in dataset __getitem__, path %s', e, self.path)
                logger.debug('Exception in H5Dataset', exc_info=True)
                continue
            return torch.tensor(pair[0]), torch.tensor(pair[1])

        raise RuntimeError(f'H5Dataset could not produce a sample after {self.MAX_RETRIES} '
                           f'attempts (path {self.path})') from last_error

    def try_random_slice(self):
        """Read one window from a randomly chosen file (cached or from disk)."""
        h5file_path = random.choice(self.file_paths)
        if h5file_path in self.data_cache:
            return self.read_wav_data(self.data_cache[h5file_path], h5file_path)

        dataset = self.read_h5_file(h5file_path)
        try:
            return self.read_wav_data(dataset, h5file_path)
        finally:
            # Slicing an h5py dataset returns a NumPy copy, so the file can be
            # closed instead of leaving one handle open per sample.
            dataset.file.close()

    def read_h5_file(self, h5file_path):
        """Open ``h5file_path`` read-only and return its ``dataset_name`` dataset.

        The returned dataset keeps the file open; callers close it through
        ``dataset.file.close()``. Errors from opening the file or looking up
        the dataset are logged and re-raised.
        """
        try:
            f = h5py.File(h5file_path, 'r')
        except Exception:
            logger.exception('Failed opening %s', h5file_path)
            raise

        try:
            return f[self.dataset_name]
        except Exception:
            logger.exception(f'No dataset named {self.dataset_name} in {h5file_path}. '
                             f'Available datasets are: {list(f.keys())}.')
            f.close()
            raise

    def read_wav_data(self, dataset, path):
        """Return the whole signal or a random ``seq_len`` window of ``dataset``, transposed.

        A signal shorter than ``seq_len`` makes ``random.randint`` raise
        ValueError, which ``__getitem__`` treats as a failed attempt.
        """
        if self.whole_samples:
            samples = dataset[:]
        else:
            length = dataset.shape[0]

            if length <= self.seq_len:
                logger.debug('Length of %s is %s', path, length)

            start_time = random.randint(0, length - self.seq_len)
            samples = dataset[start_time: start_time + self.seq_len]
            assert samples.shape[0] == self.seq_len

        return samples.T

    def __len__(self):
        return self.epoch_len

In [None]:
def copy_files(files, from_path, to_path: Path):
    """Copy every path in ``files`` to the same relative location under ``to_path``.

    Each file must lie inside ``from_path``; missing parent folders are created.
    """
    for src in files:
        target = to_path.joinpath(src.relative_to(from_path))
        target.parent.mkdir(exist_ok=True, parents=True)
        copy(src, target)
    
def move_files(files, from_path, to_path: Path):
    """Move every path in ``files`` to the same relative location under ``to_path``.

    Each file must lie inside ``from_path``; missing parent folders are created.
    """
    for src in files:
        target = to_path.joinpath(src.relative_to(from_path))
        target.parent.mkdir(exist_ok=True, parents=True)
        move(src, target)

def split(input_path, output_path, train_ratio=0.8, val_ratio=0.1, copy=False, filetype=None):
    """Randomly partition the audio files under ``input_path`` into train/val/test.

    Files keep their path relative to ``input_path`` under
    ``output_path/train``, ``output_path/val`` and ``output_path/test``.
    ``n_train = int(N * train_ratio)`` and ``n_val = max(1, int(N * val_ratio))``;
    the remaining files go to test. Files are sorted before the shuffle (which
    uses the ``random`` module), so a fixed ``random.seed`` gives the same
    split regardless of filesystem listing order.

    Parameters
    ----------
    input_path, output_path : str or Path
        Source directory and destination root.
    train_ratio, val_ratio : float
        Fractions of files assigned to train and validation.
    copy : bool
        Copy files instead of moving them.
    filetype : str, optional
        Restrict to one extension; defaults to ``EncodedFilesDataset.FILE_TYPES``.

    Raises
    ------
    ValueError
        If the split would leave the test set empty (too few files).
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    filetypes = [filetype] if filetype else EncodedFilesDataset.FILE_TYPES

    input_files = sorted(EncodedFilesDataset.filter_paths(input_path.glob('**/*'), filetypes))
    random.shuffle(input_files)

    logger.info(f"Found {len(input_files)} files")

    n_train = int(len(input_files) * train_ratio)
    n_val = max(1, int(len(input_files) * val_ratio))  # always keep one validation file
    n_test = len(input_files) - n_train - n_val

    logger.info(f'Split as follows: Train - {n_train}, Validation - {n_val}, Test - {n_test}')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if n_test <= 0:
        raise ValueError(f'Not enough files in {input_path} for a non-empty test split '
                         f'(found {len(input_files)})')

    transfer = copy_files if copy else move_files
    transfer(input_files[:n_train], input_path, output_path / 'train')
    transfer(input_files[n_train:n_train + n_val], input_path, output_path / 'val')
    transfer(input_files[n_train + n_val:], input_path, output_path / 'test')

In [None]:
random.seed(1234)
splitdir = Path('/content/musicnet/split')

# Only sub-directories are domains (before Python 3.11, `glob("*/")` also
# yields files). Sorting fixes the order in which the seeded RNG is consumed,
# so the split is reproducible across machines.
for input_path in sorted(p for p in dst_path.glob("*") if p.is_dir()):
    output_path = splitdir / input_path.name
    split(input_path, output_path, filetype='wav', copy=True)

Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1
Found 0 files
Split as follows: Train - 0, Validation - 1, Test - -1


In [None]:
def preprocess(input_path, output_path, norm_db=False):
    """Convert every audio file under ``input_path`` into an ``.h5`` file under ``output_path``.

    Thin wrapper around ``EncodedFilesDataset.dump_to_folder``; prints a
    confirmation when finished.
    """
    EncodedFilesDataset(input_path).dump_to_folder(output_path, norm_db=norm_db)
    print('Preprocessing complete')

preprocessed_dir = Path('/content/musicnet') / 'preprocessed'
preprocess(input_path=splitdir, output_path=preprocessed_dir)

100%|██████████| 330/330 [08:16<00:00,  1.51s/it]

Preprocessing complete





In [None]:
print('\n'.join(str(entry) for entry in preprocessed_dir.glob('*/')))

/content/musicnet/preprocessed/Beethoven
/content/musicnet/preprocessed/Brahms
/content/musicnet/preprocessed/Cambini
/content/musicnet/preprocessed/Mozart
/content/musicnet/preprocessed/Haydn
/content/musicnet/preprocessed/Ravel
/content/musicnet/preprocessed/Faure
/content/musicnet/preprocessed/Dvorak
/content/musicnet/preprocessed/Schubert
/content/musicnet/preprocessed/Bach


In [None]:
class WavFrequencyAugmentation:
    """Pitch-shift one random contiguous stretch of a 1-D waveform.

    The stretch is between ``length // 4`` and ``length // 2`` samples long
    and starts within the first half of the signal. It is passed to
    ``librosa.effects.pitch_shift`` with ``n_steps`` drawn uniformly from
    ``[-magnitude, magnitude)``. The rest of the signal is unchanged.
    """

    def __init__(self, wav_freq, magnitude=0.5):
        self.magnitude = magnitude
        self.wav_freq = wav_freq

    def __call__(self, wav):
        # librosa is not in the notebook's import cell; import it lazily so the
        # class can be defined (and augmentation disabled) without it.
        import librosa

        length = wav.shape[0]
        perturb_length = random.randint(length // 4, length // 2)
        perturb_start = random.randint(0, length // 2)
        perturb_end = perturb_start + perturb_length
        pitch_perturb = (np.random.rand() - 0.5) * 2 * self.magnitude

        # `sr` and `n_steps` are keyword-only in librosa >= 0.10; passing them
        # by keyword also works with older releases.
        shifted = librosa.effects.pitch_shift(wav[perturb_start:perturb_end],
                                              sr=self.wav_freq, n_steps=pitch_perturb)
        return np.concatenate([wav[:perturb_start], shifted, wav[perturb_end:]])


class DatasetSet:
    """Training and validation H5 datasets, loaders and iterators for one domain.

    ``dir`` must contain ``train`` and ``val`` sub-folders of ``.h5`` files, as
    written by ``preprocess``. Defaults follow UMT's ``train.py`` argument
    parser. Its ``--data-aug`` and ``--short`` options are boolean
    ``store_true`` flags, so ``data_aug`` and ``short`` default to False.

    Parameters
    ----------
    dir : Path or str
        Domain directory with ``train`` and ``val`` sub-folders.
    seq_len : int
        Window length in samples.
    batch_size : int
        Batch size of both loaders.
    num_workers : int
        Worker processes for the training loader; the validation loader uses
        ``num_workers // 10 + 1``.
    data_aug : bool
        Apply ``WavFrequencyAugmentation`` to produce the second item of each pair.
    magnitude : float
        Pitch-shift range for the augmentation.
    h5_dataset_name : str
        Dataset name inside each ``.h5`` file.
    short : bool
        Use only one file per split (debugging aid).
    """

    # NOTE: the string 'store_true' is an argparse *action*. As a default value
    # it is truthy and silently turns on augmentation and single-file mode.
    def __init__(self, dir: Path, seq_len, batch_size=32, num_workers=10, data_aug=False,
                 magnitude=0.5, h5_dataset_name='wav', short=False):
        dir = Path(dir)
        if data_aug:
            augmentation = WavFrequencyAugmentation(EncodedFilesDataset.WAV_FREQ, magnitude)
        else:
            augmentation = None

        # UMT's original training epoch_len is 10000000000.
        self.train_dataset = H5Dataset(dir / 'train', seq_len, epoch_len=1000000,
                                       dataset_name=h5_dataset_name, augmentation=augmentation,
                                       short=short, cache=False)
        self.train_loader = data.DataLoader(self.train_dataset,
                                            batch_size=batch_size,
                                            num_workers=num_workers,
                                            pin_memory=True)
        self.train_iter = iter(self.train_loader)

        self.valid_dataset = H5Dataset(dir / 'val', seq_len, epoch_len=100000,
                                       dataset_name=h5_dataset_name, augmentation=augmentation,
                                       short=short)
        # Validation needs far fewer workers than training.
        self.valid_loader = data.DataLoader(self.valid_dataset,
                                            batch_size=batch_size,
                                            num_workers=num_workers // 10 + 1,
                                            pin_memory=True)
        self.valid_iter = iter(self.valid_loader)

In [None]:
# UMT: train.sh arguments
kwargs = dict(seq_len=12000, batch_size=32, num_workers=2)
# One DatasetSet per domain directory. Keep directories only (before Python
# 3.11, `glob("*/")` also yields files) and sort them so each domain's list
# index is stable across runs and filesystems.
dataset = [DatasetSet(d, **kwargs) for d in sorted(preprocessed_dir.glob("*")) if d.is_dir()]

## 2. CPC Sandbox

In [None]:
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

## PyTorch implementations of the CDCK2 model and a speaker classifier
# (adapted from jefflai108/Contrastive-Predictive-Coding-PyTorch)
# CDCK2: base model from the paper 'Representation Learning with Contrastive Predictive Coding'
# SpkClassifier: a simple NN for speaker classification

class CDCK2(nn.Module):
    """Contrastive Predictive Coding (CPC) model for raw 1-D audio.

    A strided convolutional encoder turns a waveform of shape (N, 1, L) into
    latent frames z of shape (N, ~L/160, 512); the conv strides multiply to
    5 * 4 * 2 * 2 * 2 = 160. A GRU summarises z[:, :t + 1] into a context c_t,
    and one linear head per future step k = 1..timestep predicts z[:, t + k]
    from c_t. ``forward`` also computes the InfoNCE loss and accuracy.

    Args:
        timestep: number of future latent frames to predict.
        batch_size: stored as ``self.batch_size``; ``forward`` uses the batch
            size of its actual input.
        seq_len: nominal number of raw samples per input sequence.
    """

    def __init__(self, timestep, batch_size, seq_len):

        super(CDCK2, self).__init__()

        self.batch_size = batch_size
        self.seq_len = seq_len
        self.timestep = timestep
        self.encoder = nn.Sequential( # downsampling factor = 160
            nn.Conv1d(1, 512, kernel_size=10, stride=5, padding=3, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=8, stride=4, padding=2, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True)
        )
        self.gru = nn.GRU(512, 256, num_layers=1, bidirectional=False, batch_first=True)
        # One prediction head W_k per future step k.
        self.Wk  = nn.ModuleList([nn.Linear(256, 512) for i in range(timestep)])
        # Explicit dim=1 (what the deprecated implicit-dim rule picks for 2-D input).
        self.softmax  = nn.Softmax(dim=1)
        self.lsoftmax = nn.LogSoftmax(dim=1)

        def _weights_init(m):
            # Kaiming init for conv / linear weights; identity affine for batch norm.
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Initialize GRU weight matrices (biases keep PyTorch's default init).
        # named_parameters() visits the same tensors as the private _all_weights.
        for name, param in self.gru.named_parameters():
            if 'weight' in name:
                nn.init.kaiming_normal_(param, mode='fan_out', nonlinearity='relu')

        self.apply(_weights_init)

    def init_hidden(self, batch_size, use_gpu=True, device=None):
        """Return a zero GRU hidden state of shape (1, batch_size, 256).

        ``device`` takes precedence when given; otherwise the state is placed
        on the current CUDA device if ``use_gpu`` is true, else on CPU.
        """
        if device is None:
            device = 'cuda' if use_gpu else 'cpu'
        return torch.zeros(1, batch_size, 256, device=device)

    def forward(self, x, hidden):
        """Encode ``x``, predict future frames and score them with InfoNCE.

        Args:
            x: raw audio batch of shape (N, 1, L).
            hidden: initial GRU state of shape (1, N, 256).

        Returns:
            accuracy: Python float, fraction of (step, batch item) pairs whose
                positive is ranked first.
            nce: scalar loss tensor (mean negative log-probability of the positives).
            hidden: final GRU state.

        Raises:
            ValueError: if the encoded sequence has no more than ``timestep`` frames.
        """
        batch = x.size(0)

        # (N, 1, L) -> (N, 512, ~L/160) -> (N, ~L/160, 512) for the batch_first GRU.
        z = self.encoder(x).transpose(1, 2)
        n_frames = z.size(1)
        if n_frames <= self.timestep:
            raise ValueError(
                f'input encodes to {n_frames} frames; need more than timestep={self.timestep}')

        # Random anchor t drawn from the *actual* encoded length, so z[:, t + timestep] exists.
        t = int(torch.randint(n_frames - self.timestep, size=(1,)).item())

        # Targets z_{t+1} .. z_{t+timestep}: (timestep, N, 512), on x's device.
        encode_samples = z[:, t + 1:t + 1 + self.timestep, :].transpose(0, 1)
        # Context c_t from the GRU run over z_0 .. z_t: (N, 256).
        output, hidden = self.gru(z[:, :t + 1, :], hidden)
        c_t = output[:, t, :]
        # Predictions W_k c_t: (timestep, N, 512).
        pred = torch.stack([linear(c_t) for linear in self.Wk])

        # InfoNCE -- we will likely want to separate this out.
        # The loss normalises along dim=1 while accuracy takes argmax along
        # dim=0; this convention is kept from the reference implementation.
        labels = torch.arange(batch, device=x.device)
        nce = z.new_zeros(())
        correct = torch.zeros((), dtype=torch.long, device=x.device)
        for k in range(self.timestep):
            total = torch.mm(encode_samples[k], pred[k].t())  # (N, N); diagonal = positives
            correct = correct + torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), labels))
            nce = nce + torch.sum(torch.diag(self.lsoftmax(total)))
        nce = nce / (-1. * batch * self.timestep)
        accuracy = 1. * correct.item() / (batch * self.timestep)

        return accuracy, nce, hidden

    def predict(self, x, hidden):
        """Return GRU context vectors for every encoded frame.

        Args:
            x: raw audio batch of shape (N, 1, L).
            hidden: initial GRU state of shape (1, N, 256).

        Returns:
            (output, hidden) where ``output`` has shape (N, ~L/160, 256).
        """
        # (N, 1, L) -> (N, 512, ~L/160) -> (N, ~L/160, 512) for the GRU
        z = self.encoder(x).transpose(1, 2)
        output, hidden = self.gru(z, hidden)

        return output, hidden # return every frame
        #return output[:,-1,:], hidden # only return the last frame per utt

In [None]:
class InfoNCELoss(nn.Module):
    """InfoNCE loss and accuracy, separated from the CPC model.

    Mirrors the loss computed inline in ``CDCK2.forward``: for every prediction
    step k, the score matrix ``x[k] @ pred[k].T`` (N x N) compares each encoded
    target with each batch item's prediction. Matching pairs lie on the
    diagonal, and the other batch items act as negatives.
    """
    def __init__(self, *args):
        # Extra positional arguments are accepted and ignored.
        super(InfoNCELoss, self).__init__()

    def forward(self, x, pred=None):
        """Compute (accuracy, loss).

        Args:
            x: encoded targets z_{t+k}, shaped (timestep, N, C).
            pred: predictions W_k c_t, same shape as ``x``.

        Returns:
            accuracy: Python float in [0, 1].
            nce: scalar tensor, mean negative log-probability of the diagonal.

        Raises:
            TypeError: if ``pred`` is not given.
            ValueError: if ``x`` and ``pred`` are not matching non-empty 3-D tensors.
        """
        if pred is None:
            raise TypeError('InfoNCELoss.forward() requires both x (targets) and pred (predictions)')
        if x.dim() != 3 or x.shape != pred.shape or x.shape[0] == 0 or x.shape[1] == 0:
            raise ValueError(
                f'expected matching non-empty (timestep, batch, dim) tensors, '
                f'got {tuple(x.shape)} and {tuple(pred.shape)}')

        timestep, batch, _ = x.shape
        labels = torch.arange(batch, device=x.device)
        nce = x.new_zeros(())
        correct = torch.zeros((), dtype=torch.long, device=x.device)
        for k in range(timestep):
            scores = torch.mm(x[k], pred[k].t())  # (N, N); diagonal = positives
            # Same convention as CDCK2.forward: softmax along dim=1, argmax along dim=0.
            correct = correct + torch.eq(torch.argmax(F.softmax(scores, dim=1), dim=0), labels).sum()
            nce = nce + torch.diagonal(F.log_softmax(scores, dim=1)).sum()
        nce = nce / (-1.0 * batch * timestep)
        accuracy = 1.0 * correct.item() / (batch * timestep)
        return accuracy, nce

In [None]:
class SpkClassifier(nn.Module):
    """Two-layer MLP mapping 256-d features to log-probabilities over speakers.

    Args:
        spk_num: number of output classes.
    """

    def __init__(self, spk_num):
        super().__init__()

        hidden_width = 512
        layers = [
            nn.Linear(256, hidden_width),
            nn.BatchNorm1d(hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, spk_num),
            #nn.Linear(256, spk_num)
        ]
        self.classifier = nn.Sequential(*layers)

        def init_module(module):
            # Kaiming init for weight matrices; identity affine for batch norm.
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        self.apply(init_module)

    def forward(self, x):
        """Return log-softmax scores over the last dimension."""
        logits = self.classifier(x)
        return F.log_softmax(logits, dim=-1)

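### Running CPC on one sample

InfoNCE scores each prediction against the other sequences in the same batch, so a batch must contain at least two sequences. With a batch of one there are no negatives: the loss is identically zero and the accuracy is trivially 1.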

In [None]:
file_id = 2186

samplerate, audio = wavfile.read(f'/content/musicnet/train/Bach/data/{file_id}.wav')

# InfoNCE contrasts each sequence against the others in its batch, so a batch
# of one has no negatives and its loss is identically zero. Cut `batch_size`
# consecutive, non-overlapping windows of `seq_len` samples from the recording.
batch_size = 8
seq_len = 20480

if audio.ndim != 1:
    raise ValueError(f'expected mono audio, got array of shape {audio.shape}')
n_samples = batch_size * seq_len
if len(audio) < n_samples:
    raise ValueError(f'recording {file_id} has {len(audio)} samples; need {n_samples}')

# (batch_size, 1, seq_len): the encoder expects N x C x L
sample = torch.from_numpy(audio[:n_samples].copy()).view(batch_size, 1, seq_len)

In [None]:
# Note annotations for the same recording.
labels_path = f'/content/musicnet/train/Bach/labels/{file_id}.csv'
labels = pd.read_csv(labels_path)
labels.head()

Unnamed: 0,start_time,end_time,instrument,note,start_beat,end_beat,note_value
0,45534,55262,41,88,0.5,0.239583,Sixteenth
1,55774,60893,41,87,0.75,0.239583,Sixteenth
2,61406,71646,41,88,1.0,0.489583,Eighth
3,71646,83422,41,83,1.5,0.489583,Eighth
4,83934,96222,41,80,2.0,0.489583,Eighth


In [None]:
def train(model, sample, optimizer):
    """Run one optimisation step of a CPC model on a single batch.

    Args:
        model: module exposing ``init_hidden(batch_size, use_gpu=...)`` and a
            ``forward(x, hidden)`` returning ``(accuracy, loss, hidden)``
            (e.g. ``CDCK2``).
        sample: batch of raw audio shaped (N, 1, L); cast to float32.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        (accuracy, loss) for this batch, with ``loss`` as a Python float.
    """
    model.train()
    # Move the batch to whatever device the model lives on (CPU or GPU).
    device = next(model.parameters()).device
    sample = sample.to(device=device, dtype=torch.float32)

    optimizer.zero_grad()
    hidden = model.init_hidden(len(sample), use_gpu=device.type == 'cuda').to(device)
    acc, loss, _ = model(sample, hidden)

    loss.backward()
    optimizer.step()
    return acc, loss.item()

In [None]:
# Seed so weight init and CDCK2's random anchor choice are reproducible.
torch.manual_seed(0)

# Predict 12 future latent frames (each frame covers 160 raw samples); take
# the batch size from the data rather than hard-coding it.
model = CDCK2(timestep=12, batch_size=sample.size(0), seq_len=seq_len)
optimizer = torch.optim.Adam(model.parameters())

train(model, sample, optimizer)



1.0 tensor(-0., grad_fn=<DivBackward0>)
