# Universal Music Translation
*Last modified: April 21, 2021*

This notebook sets up the UMT architecture and attempts to train it on a single batch of data.

**References** 
1. [Script to download raw MusicNet data (Github)](https://github.com/jthickstun/pytorch_musicnet)
2. [UMT's MusicNet (Github)](https://github.com/facebookresearch/music-translation)
3. [MusicNet Documentation](https://homes.cs.washington.edu/~thickstn/musicnet.html)

In [1]:
import torch
import torch.optim as optim
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_value_

torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_start_method('spawn', force=True)

import os
import sys
import random
import subprocess
from subprocess import call
import shutil
from shutil import copy, move
import errno
from itertools import chain
import numpy as np
import pandas as pd
from pathlib import Path
import tqdm
import re
import csv

from collections import deque

import time
import logging
from datetime import timedelta

# UMT: data.py
from tempfile import NamedTemporaryFile
import h5py
import librosa
import torch.utils.data as data
from scipy.io import wavfile

import matplotlib

### Downloading and Preprocessing MusicNet Data
1. Download and unzip raw MusicNet files
2. Parse files by domain and composer
3. Split files into train, test, val folders
4. Add pitch augmentation 

In [2]:
def _check_exists(root):
    return os.path.exists(os.path.join(root, "train_data")) and \
        os.path.exists(os.path.join(root, "test_data")) and \
        os.path.exists(os.path.join(root, "train_labels")) and \
        os.path.exists(os.path.join(root, "test_labels"))

In [3]:
def download_data(root):
    """Download MusicNet data at root.
    Adapted from https://github.com/jthickstun/pytorch_musicnet

    Parameters
    ----------
    root : str, Path
        Directory to download MusicNet data. Will create train_data, train_labels,
        test_data, test_labels, and raw subdirectories.
    """
    from six.moves import urllib

    if not _check_exists(root):
        try:
            os.makedirs(os.path.join(root, "raw"))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise
        
        # Download musicnet.tar.gz
        url = "https://homes.cs.washington.edu/~thickstn/media/musicnet.tar.gz"
        filename = url.rpartition('/')[2]
        file_path = os.path.join(root, "raw", filename)
        if not os.path.exists(file_path):
            print(f"Downloading {url}")
            data = urllib.request.urlopen(url)
            with open(file_path, 'wb') as f:
                # stream the download to disk (it might not fit in memory!)
                while True:
                    chunk = data.read(16*1024)
                    if not chunk:
                        break
                    f.write(chunk)

        # Unpack musicnet.tar.gz
        extracted_folders = ["train_data", "train_labels", "test_data", "test_labels"]
        if not all(map(lambda f: os.path.exists(os.path.join(root, f)), extracted_folders)):
            print('Extracting ' + filename)
            if call(["tar", "-xf", file_path, '-C', root, '--strip', '1']) != 0:
                raise OSError("Failed tarball extraction")

    # Download musicnet_metadata.csv
    url = "https://homes.cs.washington.edu/~thickstn/media/musicnet_metadata.csv"
    metadata = urllib.request.urlopen(url)
    with open(os.path.join(root, 'musicnet_metadata.csv'), 'wb') as f:
        while True:
            chunk = metadata.read(16*1024)
            if not chunk:
                break
            f.write(chunk)

    print('Download Complete')

In [34]:
download = True
if download:
    root = '/content/musicnet'
    download_data(root)
else:
    print("Using data from Google Drive")

Download Complete


In [4]:
def parse_data(src, dst, domains):
    """
    Extract the desired domains from the raw MusicNet files

    Parameters
    ----------
    src: str
        Path to input data (e.g. /content/musicnet)
        
    """

    dst.mkdir(exist_ok=True, parents=True)
    
    db = pd.read_csv( src / 'musicnet_metadata.csv')
    traindir = src / 'train_data'
    testdir = src /'test_data'

    for (ensemble, composer) in domains:
        fid_list = db[(db["composer"] == composer) & (db["ensemble"] == ensemble)].id.tolist()
        total_time = sum(db[(db["composer"] == composer) & (db["ensemble"] == ensemble)].seconds.tolist())
        print(f"Total time for {composer} with {ensemble} is: {total_time} seconds")


        domaindir = dst / f"{composer}_{ensemble.replace(' ', '_')}"
        if not os.path.exists(domaindir):
            os.mkdir(domaindir)

        for fid in fid_list:
            fname = traindir / f'{fid}.wav'
            if not fname.exists():
                fname = testdir / f'{fid}.wav'

            copy(str(fname), str(domaindir))

In [5]:
domains = [
        ['Accompanied Violin', 'Beethoven'],
        ['Solo Cello', 'Bach'],
        ['Solo Piano', 'Bach'],
        ['Solo Piano', 'Beethoven'],
        ['String Quartet', 'Beethoven'],
        ['Wind Quintet', 'Cambini'],
    ]
if download:
    src_path = Path('/content/musicnet')
else:
    src_path = Path('/content/gdrive/MyDrive/College/Spring 2021/DL Final Project/musicnet')

dst_path = Path('/content/musicnet/parsed')

parse_data(src_path, dst_path, domains)

NameError: ignored

In [6]:
# UMT: utils.py
class timeit:
    def __init__(self, name, logger=None):
        self.name = name
        self.logger = logger

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.logger is None:
            print(f'{self.name} took {(time.time() - self.start) * 1000} ms')
        else:
            self.logger.debug('%s took %s ms', self.name, (time.time() - self.start) * 1000)


def mu_law(x, mu=255):
    x = np.clip(x, -1, 1)
    x_mu = np.sign(x) * np.log(1 + mu*np.abs(x))/np.log(1 + mu)
    return ((x_mu + 1)/2 * mu).astype('int16')


def inv_mu_law(x, mu=255.0):
    x = np.array(x).astype(np.float32)
    y = 2. * (x - (mu+1.)/2.) / (mu+1.)
    return np.sign(y) * (1./mu) * ((1. + mu)**np.abs(y) - 1.)


class LossMeter(object):
    def __init__(self, name):
        self.name = name
        self.losses = []

    def reset(self):
        self.losses = []

    def add(self, val):
        self.losses.append(val)

    def summarize_epoch(self):
        if self.losses:
            return np.mean(self.losses)
        else:
            return 0

    def sum(self):
        return sum(self.losses)


class LogFormatter:
    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        elapsed_seconds = round(record.created - self.start_time)

        prefix = "%s - %s - %s" % (
            record.levelname,
            time.strftime('%x %X'),
            timedelta(seconds=elapsed_seconds)
        )
        message = record.getMessage()
        message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
        return "%s - %s" % (prefix, message)


def create_output_dir(opt, path: Path):
    if hasattr(opt, 'rank'):
        filepath = path / f'main_{opt.rank}.log'
    else:
        filepath = path / 'main.log'

    print(path)
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)

    if hasattr(opt, 'rank') and opt.rank != 0:
        sys.stdout = open(path / f'stdout_{opt.rank}.log', 'w')
        sys.stderr = open(path / f'stderr_{opt.rank}.log', 'w')

    # Safety check
    if filepath.exists() and not opt.checkpoint:
        logging.warning("Experiment already exists!")

    # Create log formatter
    log_formatter = LogFormatter()

    # Create logger and set level to debug
    logger = logging.getLogger()
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # create file handler and set level to debug
    file_handler = logging.FileHandler(filepath, "a")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(log_formatter)
    logger.addHandler(file_handler)

    # create console handler and set level to info
    if hasattr(opt, 'rank') and opt.rank == 0:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(log_formatter)
        logger.addHandler(console_handler)

    # reset logger elapsed time
    def reset_time():
        log_formatter.start_time = time.time()
    logger.reset_time = reset_time

    logger.info(opt)
    return logger


def setup_logger(logger_name, filename):
    logger = logging.getLogger(logger_name)
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    stderr_handler = logging.StreamHandler(sys.stderr)
    file_handler = logging.FileHandler(filename)
    file_handler.setLevel(logging.DEBUG)
    if "RANK" in os.environ and os.environ["RANK"] != "0":
        stderr_handler.setLevel(logging.WARNING)
    else:
        stderr_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    stderr_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    logger.addHandler(stderr_handler)
    logger.addHandler(file_handler)
    return logger


def wrap(data, **kwargs):
    if torch.is_tensor(data):
        var = data.cuda(non_blocking=True)
        return var
    else:
        return tuple([wrap(x, **kwargs) for x in data])


def save_audio(x, path, rate):
    path.parent.mkdir(parents=True, exist_ok=True)
    wavfile.write(path, rate, x)


def save_wav_image(wav, path):
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.figure(figsize=(15, 5))
    plt.plot(wav)
    plt.savefig(path)

In [7]:
# UMT: data.py
logger = setup_logger(__name__, 'data.log')


def random_of_length(seq, length):
    limit = seq.size(0) - length
    if length < 1:
        # logging.warning("%d %s" % (length, path))
        return None

    start = random.randint(0, limit)
    end = start + length
    return seq[start: end]


class EncodedFilesDataset(data.Dataset):
    """
    Uses ffmpeg to read a random short segment from the middle of an encoded file
    """
    FILE_TYPES = ['mp3', 'ape', 'm4a', 'flac', 'mkv', 'wav']
    WAV_FREQ = 16000
    INPUT_FREQ = 44100
    FFT_SZ = 2048
    WINLEN = FFT_SZ - 1
    HOP_SZ = 80

    def __init__(self, top, seq_len=None, file_type=None, epoch_len=10000):
        self.path = Path(top)
        self.seq_len = seq_len
        self.file_types = [file_type] if file_type else self.FILE_TYPES
        self.file_paths = self.filter_paths(self.path.glob('**/*'), self.file_types)
        self.epoch_len = epoch_len

    @staticmethod
    def filter_paths(haystack, file_types):
        return [f for f in haystack
                if (f.is_file()
                    and any(f.name.endswith(suffix) for suffix in file_types)
                    and '__MACOSX' not in f.parts)]

    def _random_file(self):
        # return np.random.choice(self.file_paths, p=self.probs)
        return random.choice(self.file_paths)

    @staticmethod
    def _file_length(file_path):
        output = subprocess.run(['ffprobe',
                                 '-show_entries', 'format=duration',
                                 '-v', 'quiet',
                                 '-print_format', 'compact=print_section=0:nokey=1:escape=csv',
                                 str(file_path)],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE).stdout
        duration = float(output)

        return duration

    def _file_slice(self, file_path, start_time):
        length_sec = self.seq_len / self.WAV_FREQ
        length_sec += .01  # just in case
        with NamedTemporaryFile() as output_file:
            output = subprocess.run(['ffmpeg',
                                     '-v', 'quiet',
                                     '-y',  # overwrite
                                     '-ss', str(start_time),
                                     '-i', str(file_path),
                                     '-t', str(length_sec),
                                     '-f', 'wav',
                                     # '-af', 'dynaudnorm',
                                     '-ar', str(self.WAV_FREQ),  # audio rate
                                     '-ac', '1',  # audio channels
                                     output_file.name
                                     ],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE).stdout
            rate, wav_data = wavfile.read(output_file)
            assert wav_data.dtype == np.int16
            wav = wav_data[:self.seq_len].astype('float')

            return wav

    def __len__(self):
        return self.epoch_len

    def __getitem__(self, _):
        wav = self.random_file_slice()
        return torch.FloatTensor(wav)

    def random_file_slice(self):
        wav_data = None

        while wav_data is None or len(wav_data) != self.seq_len:
            try:
                file, file_length_sec, start_time, wav_data = self.try_random_file_slice()
            except Exception as e:
                logger.exception('Exception %s in random_file_slice.', e)

        # logger.debug('Sample: File: %s, File length: %s, Start time: %s',
        #              file, file_length_sec, start_time)

        return wav_data

    def try_random_file_slice(self):
        file = self._random_file()
        file_length_sec = self._file_length(file)
        segment_length_sec = self.seq_len / self.WAV_FREQ
        if file_length_sec < segment_length_sec:
            logger.warn('File "%s" has length %s, segment length is %s',
                        file, file_length_sec, segment_length_sec)

        start_time = random.random() * (file_length_sec - segment_length_sec * 2)  # just in case
        try:
            wav_data = self._file_slice(file, start_time)
        except Exception as e:
            logger.info(f'Exception in file slice: {e}. '
                        f'File: {file}, '
                        f'File length: {file_length_sec}, '
                        f'Start time: {start_time}')
            raise

        if len(wav_data) != self.seq_len:
            logger.warn('File "%s" has length %s, segment length is %s, wav data length: %s',
                        file, file_length_sec, segment_length_sec, len(wav_data))

        return file, file_length_sec, start_time, wav_data

    def dump_to_folder(self, output: Path, norm_db=False):
        for file_path in tqdm.tqdm(self.file_paths):
            output_file_path = output / file_path.relative_to(self.path).with_suffix('.h5')
            output_file_path.parent.mkdir(parents=True, exist_ok=True)
            with NamedTemporaryFile(suffix='.wav') as output_wav_file, \
                    NamedTemporaryFile(suffix='.wav') as norm_file_path, \
                    NamedTemporaryFile(suffix='.wav') as wav_convert_file:
                if norm_db:
                    logger.debug(f'Converting {file_path} to {wav_convert_file.name}')
                    subprocess.run(['ffmpeg',
                                    '-y',
                                    '-i', file_path,
                                    wav_convert_file.name],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)

                    logger.debug(f'Companding {wav_convert_file.name} to {norm_file_path.name}')
                    subprocess.run(['sox',
                                    '-G',
                                    wav_convert_file.name,
                                    norm_file_path.name,
                                    'compand',
                                    '0.3,1',
                                    '6:-70,-60,-20',
                                    '-5',
                                    '-90',
                                    '0.2'],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
                    input_file_path = norm_file_path.name
                else:
                    input_file_path = file_path

                logger.debug(f'Converting {input_file_path} to {output_wav_file.name}')
                subprocess.run(['ffmpeg',
                                '-v', 'quiet',
                                '-y',  # overwrite
                                '-i', input_file_path,
                                # '-af', 'dynaudnorm',
                                '-f', 'wav',
                                '-ar', str(self.WAV_FREQ),  # audio rate
                                '-ac', '1',  # audio channels,
                                output_wav_file.name
                                ],
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
                try:
                    rate, wav_data = wavfile.read(output_wav_file.name)
                except ValueError:
                    logger.info(f'Cannot read {file_path} wav conversion')
                    raise
                    # raise
                assert wav_data.dtype == np.int16
                wav = wav_data.astype('float')

                with h5py.File(output_file_path, 'w') as output_file:
                    chunk_shape = (min(10000, len(wav)),)
                    wav_dset = output_file.create_dataset('wav', wav.shape, dtype=wav.dtype,
                                                          chunks=chunk_shape)
                    wav_dset[...] = wav

                logger.debug(f'Saved input {file_path} to {output_file_path}. '
                             f'Wav length: {wav.shape}')


class H5Dataset(data.Dataset):
    def __init__(self, top, seq_len, dataset_name, epoch_len=10000, augmentation=None, short=False,
                 whole_samples=False, cache=False):
        self.path = Path(top)
        self.seq_len = seq_len
        self.epoch_len = epoch_len
        self.short = short
        self.whole_samples = whole_samples
        self.augmentation = augmentation
        self.dataset_name = dataset_name

        self.file_paths = list(self.path.glob('**/*.h5'))
        if self.short:
            self.file_paths = [self.file_paths[0]]

        self.data_cache = {}
        if cache:
            for file_path in tqdm.tqdm(self.file_paths,
                                       desc=f'Reading dataset {top.parent.name}/{top.name}'):
                dataset = self.read_h5_file(file_path)
                self.data_cache[file_path] = dataset[:]

        if not self.file_paths:
            logger.error(f'No files found in {self.path}')

        logger.info(f'Dataset created. {len(self.file_paths)} files, '
                    f'augmentation: {self.augmentation is not None}. '
                    f'Path: {self.path}')

    def __getitem__(self, _):
        ret = None
        while ret is None:
            try:
                ret = self.try_random_slice()
                if self.augmentation:
                    ret = [ret, self.augmentation(ret)]
                else:
                    ret = [ret, ret]

                if self.dataset_name == 'wav':
                    ret = [mu_law(x / 2 ** 15) for x in ret]
            except Exception as e:
                logger.info('Exception %s in dataset __getitem__, path %s', e, self.path)
                logger.debug('Exception in H5Dataset', exc_info=True)

        return torch.tensor(ret[0]), torch.tensor(ret[1])

    def try_random_slice(self):
        h5file_path = random.choice(self.file_paths)
        if h5file_path in self.data_cache:
            dataset = self.data_cache[h5file_path]
        else:
            dataset = self.read_h5_file(h5file_path)
        return self.read_wav_data(dataset, h5file_path)

    def read_h5_file(self, h5file_path):
        try:
            f = h5py.File(h5file_path, 'r')
        except Exception as e:
            logger.exception('Failed opening %s', h5file_path)
            raise

        try:
            dataset = f[self.dataset_name]
        except Exception:
            logger.exception(f'No dataset named {self.dataset_name} in {file_path}. '
                             f'Available datasets are: {list(f.keys())}.')

        return dataset

    def read_wav_data(self, dataset, path):
        if self.whole_samples:
            data = dataset[:]
        else:
            length = dataset.shape[0]

            if length <= self.seq_len:
                logger.debug('Length of %s is %s', path, length)

            start_time = random.randint(0, length - self.seq_len)
            data = dataset[start_time: start_time + self.seq_len]
            assert data.shape[0] == self.seq_len

        return data.T

    def __len__(self):
        return self.epoch_len


class WavFrequencyAugmentation:
    def __init__(self, wav_freq, magnitude=0.5):
        self.magnitude = magnitude
        self.wav_freq = wav_freq

    def __call__(self, wav):
        length = wav.shape[0]
        perturb_length = random.randint(length // 4, length // 2)
        perturb_start = random.randint(0, length // 2)
        perturb_end = perturb_start + perturb_length
        pitch_perturb = (np.random.rand() - 0.5) * 2 * self.magnitude

        ret = np.concatenate([wav[:perturb_start],
                              librosa.effects.pitch_shift(wav[perturb_start:perturb_end],
                                                          self.wav_freq, pitch_perturb),
                              wav[perturb_end:]])

        return ret


class DatasetSet:
    def __init__(self, dir: Path, seq_len, args):
        if args.data_aug:
            augmentation = WavFrequencyAugmentation(EncodedFilesDataset.WAV_FREQ, args.magnitude)
        else:
            augmentation = None

        # Original epoch_len = 10000000000
        self.train_dataset = H5Dataset(dir / 'train', seq_len, epoch_len=1000000,
                                       dataset_name=args.h5_dataset_name, augmentation=augmentation,
                                       short=args.short, cache=False)
        self.train_loader = data.DataLoader(self.train_dataset,
                                            batch_size=args.batch_size,
                                            num_workers=args.num_workers,
                                            pin_memory=True)

        self.train_iter = iter(self.train_loader)

        # Original epoch_len = 1000000000
        # num_workers=args.num_workers // 10 + 1
        self.valid_dataset = H5Dataset(dir / 'val', seq_len, epoch_len=100000,
                                       dataset_name=args.h5_dataset_name, augmentation=augmentation,
                                       short=args.short)
        self.valid_loader = data.DataLoader(self.valid_dataset,
                                            batch_size=args.batch_size,
                                            num_workers=args.num_workers,
                                            pin_memory=True)

        self.valid_iter = iter(self.valid_loader)

In [8]:
# UMT: split_dir.py
def copy_files(files, from_path, to_path: Path):
    for f in files:
        out_file_path = to_path / f.relative_to(from_path)
        out_file_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(f, out_file_path)


def split(input_path: Path, output_path: Path, train_ratio, val_ratio, filetype):
    if filetype:
        filetypes = [filetype]
    else:
        filetypes = EncodedFilesDataset.FILE_TYPES

    input_files = EncodedFilesDataset.filter_paths(input_path.glob('**/*'), filetypes)
    random.shuffle(input_files)

    logger.info(f'Found {len(input_files)} files')

    n_train = int(len(input_files) * train_ratio)
    n_val = int(len(input_files) * val_ratio)
    if n_val == 0:
        n_val = 1
    n_test = len(input_files) - n_train - n_val

    logger.info('Split as follows: Train - %s, Validation - %s, Test - %s', n_train, n_val, n_test)
    assert n_test > 0

    copy_files(input_files[:n_train], input_path, output_path / 'train')
    copy_files(input_files[n_train:n_train + n_val], input_path, output_path / 'val')
    copy_files(input_files[n_train + n_val:], input_path, output_path / 'test')

def split_domains(root='/content/musicnet/parsed', 
              splitdir='/content/musicnet/split', 
              train_ratio=0.8, val_ratio=0.1):

    for subdir in os.scandir(root):
        if subdir.is_dir():
            input = Path(subdir.path)
            output = Path(os.path.join(splitdir, os.path.basename(input)))
            random.seed(0)
            split(input, output, train_ratio, val_ratio, None)

In [None]:
split_domains()

2021-05-04 23:31:31,591 - INFO - Found 12 files
2021-05-04 23:31:31,601 - INFO - Split as follows: Train - 9, Validation - 1, Test - 2
2021-05-04 23:31:38,831 - INFO - Found 28 files
2021-05-04 23:31:38,832 - INFO - Split as follows: Train - 22, Validation - 2, Test - 4
2021-05-04 23:31:59,663 - INFO - Found 22 files
2021-05-04 23:31:59,680 - INFO - Split as follows: Train - 17, Validation - 2, Test - 3
2021-05-04 23:32:21,487 - INFO - Found 9 files
2021-05-04 23:32:21,494 - INFO - Split as follows: Train - 7, Validation - 1, Test - 1
2021-05-04 23:32:28,708 - INFO - Found 39 files
2021-05-04 23:32:28,709 - INFO - Split as follows: Train - 31, Validation - 3, Test - 5
2021-05-04 23:32:41,826 - INFO - Found 93 files
2021-05-04 23:32:41,828 - INFO - Split as follows: Train - 74, Validation - 9, Test - 10


In [None]:
# UMT: preprocess.py
dataset = EncodedFilesDataset('/content/musicnet/split')
preprocdir = '/content/musicnet/preprocessed'
dataset.dump_to_folder(preprocdir)

100%|██████████| 203/203 [03:31<00:00,  1.04s/it]


### WaveNet Autoencoder Architecture

In [9]:
class CausalConv1d(nn.Conv1d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 dilation=1,
                 **kwargs):
        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            padding=dilation * (kernel_size - 1),
            dilation=dilation,
            **kwargs)

    def forward(self, input):
        out = super(CausalConv1d, self).forward(input)
        return out[:, :, :-self.padding[0]]


class WavenetLayer(nn.Module):
    def __init__(self, residual_channels, skip_channels, cond_channels,
                 kernel_size=2, dilation=1):
        super(WavenetLayer, self).__init__()

        self.causal = CausalConv1d(residual_channels, 2 * residual_channels,
                                   kernel_size, dilation=dilation, bias=True)
        self.condition = nn.Conv1d(cond_channels, 2 * residual_channels,
                                   kernel_size=1, bias=True)
        self.residual = nn.Conv1d(residual_channels, residual_channels,
                                  kernel_size=1, bias=True)
        self.skip = nn.Conv1d(residual_channels, skip_channels,
                              kernel_size=1, bias=True)

    def _condition(self, x, c, f):
        c = f(c)
        x = x + c
        return x

    def forward(self, x, c=None):
        x = self.causal(x)
        if c is not None:
            x = self._condition(x, c, self.condition)

        assert x.size(1) % 2 == 0
        gate, output = x.chunk(2, 1)
        gate = torch.sigmoid(gate)
        output = torch.tanh(output)
        x = gate * output

        residual = self.residual(x)
        skip = self.skip(x)

        return residual, skip


class WaveNet(nn.Module):
    def __init__(self, args, create_layers=True, shift_input=True):
        super().__init__()

        self.blocks = args.blocks
        self.layer_num = args.layers
        self.kernel_size = args.kernel_size
        self.skip_channels = args.skip_channels
        self.residual_channels = args.residual_channels
        self.cond_channels = args.latent_d
        self.classes = 256
        self.shift_input = shift_input

        if create_layers:
            layers = []
            for _ in range(self.blocks):
                for i in range(self.layer_num):
                    dilation = 2 ** i
                    layers.append(WavenetLayer(self.residual_channels, self.skip_channels, self.cond_channels,
                                               self.kernel_size, dilation))
            self.layers = nn.ModuleList(layers)

        self.first_conv = CausalConv1d(1, self.residual_channels, kernel_size=self.kernel_size)
        self.skip_conv = nn.Conv1d(self.residual_channels, self.skip_channels, kernel_size=1)
        self.condition = nn.Conv1d(self.cond_channels, self.skip_channels, kernel_size=1)
        self.fc = nn.Conv1d(self.skip_channels, self.skip_channels, kernel_size=1)
        self.logits = nn.Conv1d(self.skip_channels, self.classes, kernel_size=1)

    def _condition(self, x, c, f):
        c = f(c)
        x = x + c
        return x

    @staticmethod
    def _upsample_cond(x, c):
        bsz, channels, length = x.size()
        cond_bsz, cond_channels, cond_length = c.size()
        assert bsz == cond_bsz

        if c.size(2) != 1:
            # c = c.unsqueeze(3).repeat(1, 1, 1, length // cond_length)
            # c = c.view(bsz, cond_channels, length)
            upsample = nn.Upsample(size=length)
            c = upsample(c)

        return c

    @staticmethod
    def shift_right(x):
        x = F.pad(x, (1, 0))
        return x[:, :, :-1]

    def forward(self, x, c=None):
        if x.dim() < 3:
            x = x.unsqueeze(1)
        if (not 'Half' in x.type()) and (not 'Float' in x.type()):
            x = x.float()

        x = x / 255 - 0.5

        if self.shift_input:
            x = self.shift_right(x)

        if c is not None:
            c = self._upsample_cond(x, c)

        residual = self.first_conv(x)
        skip = self.skip_conv(residual)

        for layer in self.layers:
            r, s = layer(residual, c)
            residual = residual + r
            skip = skip + s

        skip = F.relu(skip)
        skip = self.fc(skip)
        if c is not None:
            skip = self._condition(skip, c, self.condition)
        skip = F.relu(skip)
        skip = self.logits(skip)

        return skip

    ### Weights ###
    def export_layer_weights(self):
        Wdilated, Bdilated = [], []
        Wres, Bres = [], []
        Wskip, Bskip = [], []

        for l in self.layers:
            Wdilated.append(l.causal.weight)
            Bdilated.append(l.causal.bias)

            Wres.append(l.residual.weight)
            Bres.append(l.residual.bias)

            Wskip.append(l.skip.weight)
            Bskip.append(l.skip.bias)

        return Wdilated, Bdilated, Wres, Bres, Wskip, Bskip

    def export_embed_weights(self):
        inp = torch.range(0, 255) / 255 - 0.5
        prev = self.first_conv.weight[:, :, 0].cpu().contiguous()
        prev = inp.unsqueeze(1) @ prev.transpose(0, 1)
        prev = prev + self.first_conv.bias.cpu() / 2

        curr = self.first_conv.weight[:, :, 1].cpu().contiguous()
        curr = inp.unsqueeze(1) @ curr.transpose(0, 1)
        curr = curr + self.first_conv.bias.cpu() / 2

        return prev, curr

    def export_final_weights(self):
        Wzi = self.skip_conv.weight
        Bzi = self.skip_conv.bias
        Wzs = self.fc.weight
        Bzs = self.fc.bias
        Wza = self.logits.weight
        Bza = self.logits.bias

        return Wzi, Bzi, Wzs, Bzs, Wza, Bza

In [10]:
class QueuedConv1d(nn.Module):
    def __init__(self, conv, data):
        super().__init__()
        if isinstance(conv, nn.Conv1d):
            self.inner_conv = nn.Conv1d(conv.in_channels,
                                        conv.out_channels,
                                        conv.kernel_size)
            self.init_len = conv.padding[0]
            self.inner_conv.weight.data.copy_(conv.weight.data)
            self.inner_conv.bias.data.copy_(conv.bias.data)

        elif isinstance(conv, QueuedConv1d):
            self.inner_conv = nn.Conv1d(conv.inner_conv.in_channels,
                                        conv.inner_conv.out_channels,
                                        conv.inner_conv.kernel_size)
            self.init_len = conv.init_len
            self.inner_conv.weight.data.copy_(conv.inner_conv.weight.data)
            self.inner_conv.bias.data.copy_(conv.inner_conv.bias.data)

        self.init_queue(data)

    def init_queue(self, data):
        self.queue = deque([data[:, :, 0:1]]*self.init_len,
                           maxlen=self.init_len)

    def forward(self, x):
        y = x
        x = torch.cat([self.queue[0], x], dim=2)
        # import pdb; pdb.set_trace()
        self.queue.append(y)

        return self.inner_conv(x)


class WavenetGenerator(nn.Module):
    Q_ZERO = 128

    def __init__(self, wavenet: WaveNet, batch_size=1, cond_repeat=800, wav_freq=16000):
        super().__init__()
        self.wavenet = wavenet
        self.wavenet.shift_input = False
        self.cond_repeat = cond_repeat
        self.wav_freq = wav_freq
        self.batch_size = batch_size
        self.was_cuda = next(self.wavenet.parameters()).is_cuda

        x = torch.zeros(self.batch_size, 1, 1)
        x = x.cuda() if self.was_cuda else x
        self.wavenet.first_conv = QueuedConv1d(self.wavenet.first_conv, x)

        x = torch.zeros(self.batch_size, self.wavenet.residual_channels, 1)
        x = x.cuda() if self.was_cuda else x
        for layer in self.wavenet.layers:
            layer.causal = QueuedConv1d(layer.causal, x)

        if self.was_cuda:
            self.wavenet.cuda()
        self.wavenet.eval()

    def forward(self, x, c=None):
        return self.wavenet(x, c)

    def reset(self):
        return self.init()

    def init(self, batch_size=None):
        if batch_size is not None:
            self.batch_size = batch_size

        x = torch.zeros(self.batch_size, 1, 1)
        x = x.cuda() if self.was_cuda else x
        self.wavenet.first_conv.init_queue(x)

        x = torch.zeros(self.batch_size, self.wavenet.residual_channels, 1)
        x = x.cuda() if self.was_cuda else x
        for layer in self.wavenet.layers:
            layer.causal.init_queue(x)

        if self.was_cuda:
            self.wavenet.cuda()

    @staticmethod
    def softmax_and_sample(prediction, method='sample'):
        if method == 'sample':
            probabilities = F.softmax(prediction)
            samples = torch.multinomial(probabilities, 1)
        elif method == 'max':
            _, samples = torch.max(F.softmax(prediction), dim=1)
        else:
            assert False, "Method not supported."

        return samples

    def generate(self, encodings, init=True, method='sample'):
        if init:
            self.init(encodings.size(0))

        samples = torch.zeros(encodings.size(0), 1, encodings.size(2)*self.cond_repeat + 1)
        samples.fill_(self.Q_ZERO)
        samples = samples.long()
        samples = samples.cuda() if encodings.is_cuda else samples

        with torch.no_grad():
            t0 = time.time()
            for t1 in tqdm.tqdm(range(encodings.size(2)), desc='Generating'):
                for t2 in range(self.cond_repeat):
                    t = t1 * self.cond_repeat + t2
                    x = samples[:, :, t:t + 1].clone()
                    c = encodings[:, :, t1:t1+1]

                    prediction = self(x, c)[:, :, 0]

                    argmax = self.softmax_and_sample(prediction, method)

                    samples[:, :, t+1] = argmax

            logging.info(f'{encodings.size(0)} samples of {encodings.size(2)*self.cond_repeat/self.wav_freq} seconds length '
                         f'generated in {time.time() - t0} seconds.')

        return samples[:, :, 1:]


In [11]:
class DilatedResConv(nn.Module):
    def __init__(self, channels, dilation=1, activation='relu', padding=1, kernel_size=3, left_pad=0):
        super().__init__()
        in_channels = channels

        if activation == 'relu':
            self.activation = lambda *args, **kwargs: F.relu(*args, **kwargs, inplace=True)
        elif activation == 'tanh':
            self.activation = F.tanh
        elif activation == 'glu':
            self.activation = F.glu
            in_channels = channels // 2

        self.left_pad = left_pad
        self.dilated_conv = nn.Conv1d(in_channels, channels, kernel_size=kernel_size, stride=1,
                                      padding=dilation * padding, dilation=dilation, bias=True)
        self.conv_1x1 = nn.Conv1d(in_channels, channels,
                                  kernel_size=1, bias=True)

    def forward(self, input):
        x = input

        if self.left_pad > 0:
            x = F.pad(x, (self.left_pad, 0))
        x = self.dilated_conv(x)
        x = self.activation(x)
        x = self.conv_1x1(x)

        return input + x


class Encoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.n_blocks = args.encoder_blocks
        self.n_layers = args.encoder_layers
        self.channels = args.encoder_channels
        self.latent_channels = args.latent_d
        self.activation = args.encoder_func

        try:
            self.encoder_pool = args.encoder_pool
        except AttributeError:
            self.encoder_pool = 800

        layers = []
        for _ in range(self.n_blocks):
            for i in range(self.n_layers):
                dilation = 2 ** i
                layers.append(DilatedResConv(self.channels, dilation, self.activation))
        self.dilated_convs = nn.Sequential(*layers)

        self.start = nn.Conv1d(1, self.channels, kernel_size=3, stride=1,
                               padding=1)
        self.conv_1x1 = nn.Conv1d(self.channels, self.latent_channels, 1)
        self.pool = nn.AvgPool1d(self.encoder_pool)

    def forward(self, x):
        x = x / 255 - .5
        if x.dim() < 3:
            x = x.unsqueeze(1)

        x = self.start(x)
        x = self.dilated_convs(x)
        x = self.conv_1x1(x)
        x = self.pool(x)

        return x


class ZDiscriminator(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_classes = args.n_datasets

        convs = []
        for i in range(args.d_layers):
            in_channels = args.latent_d if i == 0 else args.d_channels
            convs.append(nn.Conv1d(in_channels, args.d_channels, 1))
            convs.append(nn.ELU())
        convs.append(nn.Conv1d(args.d_channels, self.n_classes, 1))

        self.convs = nn.Sequential(*convs)
        self.dropout = nn.Dropout(p=args.p_dropout_discriminator)

    def forward(self, z):
        z = self.dropout(z)
        logits = self.convs(z)  # (N, n_classes, L)

        mean = logits.mean(2)
        return mean


def cross_entropy_loss(input, target):
    # input:  (batch, 256, len)
    # target: (batch, len)

    batch, channel, seq = input.size()

    input = input.transpose(1, 2).contiguous()
    input = input.view(-1, 256)  # (batch * seq, 256)
    target = target.view(-1).long()  # (batch * seq)

    cross_entropy = F.cross_entropy(input, target, reduction='none')  # (batch * seq)
    return cross_entropy.reshape(batch, seq).mean(dim=1)  # (batch)


### Training

In [12]:
class Trainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('/content/checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']), "Number of datasets must match number of nodes"

        self.losses_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        #self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)

        z_classification = torch.max(z_logits, dim=1)[1]

        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        # discriminator_right = F.cross_entropy(z_logits, dset_num).mean()
        discriminator_right = F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item()

        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
        discriminator_right = F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        z = self.encoder(x_aug)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)
        discriminator_wrong = - F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)

        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        n_batches = self.args.epoch_len

        with tqdm.tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm.tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs):
            self.logger.info(f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}')
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                             epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            # TODO: args.rank is always 0, need to change to prevent overwriting
            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args,
                            epoch],
                           '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        print("save model called")
        print(self.expPath)
        print(filename)
        save_path = self.expPath / filename

        # TODO: verify whether we should be using args.rank here
        torch.save({'encoder_state': self.encoder.module.state_dict(),
                    'decoder_state': self.decoder.module.state_dict(),
                    'discriminator_state': self.discriminator.module.state_dict(),
                    'model_optimizer_state': self.model_optimizer.state_dict(),
                    'dataset': self.args.rank,
                    'd_optimizer_state': self.d_optimizer.state_dict()
                    },
                   save_path)

        self.logger.debug(f'Saved model to {save_path}')

In [13]:
class TrainerArgs:
    def __init__(self, data, epochs=10000, seed=1, expName='musicnet', checkpoint='', load_optimizer=None, per_epoch=False,
                 dist_url='env://', dist_backend='nccl', local_rank=0,
                 seq_len=16000, epoch_len=10000, batch_size=32, num_workers=10, data_aug=True, magnitude=0.5, lr=1e-4, lr_decay=0.98, short=False, h5_dataset_name='wav',
                 latent_d=128, repeat_num=6, encoder_channels=128, encoder_blocks=3, encoder_pool=800, encoder_final_kernel_size=1, encoder_layers=10, encoder_func='relu',
                 blocks=4, layers=10, kernel_size=2, residual_channels=128, skip_channels=128,
                 d_layers=3, d_channels=100, d_cond=1024, d_lambda=1e-2, p_dropout_discriminator=0.0, grad_clip=None, timestep=1):
        
        # Env options:
        self.epochs = epochs
        self.seed = seed
        self.expName = expName
        self.data = data
        self.checkpoint = checkpoint
        self.load_optimizer = load_optimizer
        self.per_epoch = per_epoch

        # Distributed
        self.dist_url = dist_url
        self.dist_backend = dist_backend
        self.local_rank = local_rank

        # Data options
        self.seq_len = seq_len
        self.epoch_len = epoch_len
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_aug = data_aug
        self.magnitude = magnitude
        self.lr = lr
        self.lr_decay = lr_decay
        self.short = short
        self.h5_dataset_name = h5_dataset_name

        # Encoder options
        self.latent_d = latent_d
        self.repeat_num = repeat_num
        self.encoder_channels = encoder_channels
        self.encoder_blocks = encoder_blocks
        self.encoder_pool = encoder_pool
        self.encoder_final_kernel_size = encoder_final_kernel_size
        self.encoder_layers = encoder_layers
        self.encoder_func = encoder_func

        # Decoder options
        self.blocks = blocks
        self.layers = layers
        self.kernel_size = kernel_size
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels

        # Z discriminator options
        self.d_layers = d_layers
        self.d_channels = d_channels
        self.d_cond = d_cond
        self.d_lambda = d_lambda
        self.p_dropout_discriminator = p_dropout_discriminator
        self.grad_clip = grad_clip

        # CPC options
        self.timestep = timestep
        
        self.distributed = False
        if 'WORLD_SIZE' in os.environ:
            self.distributed = int(os.environ['WORLD_SIZE']) > 1

        if self.distributed:
            if int(os.environ['RANK']) == 0:
                self.is_master = True
            else:
                self.is_master = False
            self.rank = int(os.environ['RANK'])

            print('Before init_process_group')
            dist.init_process_group(backend=self.dist_backend,
                                    init_method=self.dist_url)
        else:
            self.rank = 0
            self.is_master = True

In [None]:
# num_workers must be set to 0, otherwise DataLoader exits with an unexpected error
# CUDA runs out of memory when batch_size=9

data_paths = [Path('musicnet/preprocessed/Bach_Solo_Cello'), 
        Path('musicnet/preprocessed/Beethoven_Solo_Piano'),
        Path('musicnet/preprocessed/Bach_Solo_Piano')]

# data_paths = [Path('musicnet/preprocessed/Solo_Cello'), 
#         Path('musicnet/preprocessed/Solo_Piano'),
#         Path('musicnet/preprocessed/Solo_Flute')]
args = TrainerArgs(data_paths, epochs=1, batch_size=8, lr_decay=0.995, epoch_len=1,
                  num_workers=0, lr=1e-3, seq_len=12000, d_lambda=1e-2, expName='musicnet',
                  latent_d=64, layers=14, blocks=4, data_aug=True, grad_clip=1)

torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_start_method('spawn', force=True)

trainer = Trainer(args)
trainer.train()

In [None]:
!ls -l checkpoints/musicnet

total 201996
-rw-r--r-- 1 root root      1519 Apr 24 14:01 args.pth
-rw-r--r-- 1 root root 103055365 Apr 24 14:01 bestmodel_0.pth
-rw-r--r-- 1 root root 103055365 Apr 24 14:01 lastmodel_0.pth
-rw-r--r-- 1 root root    718972 Apr 24 14:01 main_0.log


### Inference

In [14]:
class QueuedConv1d(nn.Module):
    def __init__(self, conv, data):
        super().__init__()
        if isinstance(conv, nn.Conv1d):
            self.inner_conv = nn.Conv1d(conv.in_channels,
                                        conv.out_channels,
                                        conv.kernel_size)
            self.init_len = conv.padding[0]
            self.inner_conv.weight.data.copy_(conv.weight.data)
            self.inner_conv.bias.data.copy_(conv.bias.data)

        elif isinstance(conv, QueuedConv1d):
            self.inner_conv = nn.Conv1d(conv.inner_conv.in_channels,
                                        conv.inner_conv.out_channels,
                                        conv.inner_conv.kernel_size)
            self.init_len = conv.init_len
            self.inner_conv.weight.data.copy_(conv.inner_conv.weight.data)
            self.inner_conv.bias.data.copy_(conv.inner_conv.bias.data)

        self.init_queue(data)

    def init_queue(self, data):
        self.queue = deque([data[:, :, 0:1]]*self.init_len,
                           maxlen=self.init_len)

    def forward(self, x):
        y = x
        x = torch.cat([self.queue[0], x], dim=2)
        # import pdb; pdb.set_trace()
        self.queue.append(y)

        return self.inner_conv(x)


class WavenetGenerator(nn.Module):
    Q_ZERO = 128

    def __init__(self, wavenet: WaveNet, batch_size=1, cond_repeat=800, wav_freq=16000):
        super().__init__()
        self.wavenet = wavenet
        self.wavenet.shift_input = False
        self.cond_repeat = cond_repeat
        self.wav_freq = wav_freq
        self.batch_size = batch_size
        self.was_cuda = next(self.wavenet.parameters()).is_cuda

        x = torch.zeros(self.batch_size, 1, 1)
        x = x.cuda() if self.was_cuda else x
        self.wavenet.first_conv = QueuedConv1d(self.wavenet.first_conv, x)

        x = torch.zeros(self.batch_size, self.wavenet.residual_channels, 1)
        x = x.cuda() if self.was_cuda else x
        for layer in self.wavenet.layers:
            layer.causal = QueuedConv1d(layer.causal, x)

        if self.was_cuda:
            self.wavenet.cuda()
        self.wavenet.eval()

    def forward(self, x, c=None):
        return self.wavenet(x, c)

    def reset(self):
        return self.init()

    def init(self, batch_size=None):
        if batch_size is not None:
            self.batch_size = batch_size

        x = torch.zeros(self.batch_size, 1, 1)
        x = x.cuda() if self.was_cuda else x
        self.wavenet.first_conv.init_queue(x)

        x = torch.zeros(self.batch_size, self.wavenet.residual_channels, 1)
        x = x.cuda() if self.was_cuda else x
        for layer in self.wavenet.layers:
            layer.causal.init_queue(x)

        if self.was_cuda:
            self.wavenet.cuda()

    @staticmethod
    def softmax_and_sample(prediction, method='sample'):
        if method == 'sample':
            probabilities = F.softmax(prediction)
            samples = torch.multinomial(probabilities, 1)
        elif method == 'max':
            _, samples = torch.max(F.softmax(prediction), dim=1)
        else:
            assert False, "Method not supported."

        return samples

    def generate(self, encodings, init=True, method='sample'):
        if init:
            self.init(encodings.size(0))

        samples = torch.zeros(encodings.size(0), 1, encodings.size(2)*self.cond_repeat + 1)
        samples.fill_(self.Q_ZERO)
        samples = samples.long()
        samples = samples.cuda() if encodings.is_cuda else samples

        with torch.no_grad():
            t0 = time.time()
            for t1 in tqdm.tqdm(range(encodings.size(2)), desc='Generating'):
                for t2 in range(self.cond_repeat):
                    t = t1 * self.cond_repeat + t2
                    x = samples[:, :, t:t + 1].clone()
                    c = encodings[:, :, t1:t1+1]

                    prediction = self(x, c)[:, :, 0]

                    argmax = self.softmax_and_sample(prediction, method)

                    samples[:, :, t+1] = argmax

            logging.info(f'{encodings.size(0)} samples of {encodings.size(2)*self.cond_repeat/self.wav_freq} seconds length '
                         f'generated in {time.time() - t0} seconds.')

        return samples[:, :, 1:]

In [15]:
def extract_id(path):
    decoder_id = str(path)[:-4].split('_')[-1]
    return int(decoder_id)

In [16]:
def data_samples(data_path, data_from_args, output, n, seq_len=80000):
    if data_path:
        dataset_paths = data_path
    elif data_from_args:
        input_args, _ = torch.load(data_from_args)
        dataset_paths = input_args.data
    else:
        print('Please supply either --data or --data-from-args')
        return

    if dataset_paths[0].is_file():
        datasets = [H5Dataset(dataset_paths[0], seq_len, 'wav')]
    else:
        datasets = [H5Dataset(p / 'test', seq_len, 'wav')
                    for p in dataset_paths]

    for dataset_id, dataset in enumerate(datasets):
        for i in tqdm.trange(n):
            wav_data, _ = dataset[0]
            wav_data = inv_mu_law(wav_data.numpy())
            save_audio(wav_data, output / f'{dataset_id}/{i}.wav', rate=EncodedFilesDataset.WAV_FREQ)

In [17]:
def run_on_files(files, output, checkpoint, decoders=[], rate=16000, batch_size=6, sample_len=None, split_size=20, output_next_to_orig=False,
                 skip_filter=False, py=True):

    print('Starting')
    matplotlib.use('agg')

    checkpoints = checkpoint.parent.glob(checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(checkpoint.parent / 'args.pth')[0]
    encoder = Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if py:
            decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)
        else:
            decoder = NVWavenetGenerator(decoder, rate * (split_size // 20), batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert output_next_to_orig ^ (output is not None)

    if len(files) == 1 and files[0].is_dir():
        top = files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = files

    if not skip_filter:
        file_paths = [f for f in file_paths if not '_' in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = mu_law(data)
        elif file_path.suffix == '.h5':
            data = mu_law(h5py.File(file_path, 'r')['wav'][:] / (2 ** 15))
            if data.shape[-1] % rate != 0:
                data = data[:-(data.shape[-1] % rate)]
            assert data.shape[-1] % rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if sample_len:
            data = data[:sample_len]
        else:
            sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if output_next_to_orig:
            save_audio(wav.squeeze(), filepath.parent / f'{filepath.stem}_{decoder_ix}.wav', rate=rate)
        else:
            save_audio(wav.squeeze(), output / str(decoder_ix) / filepath.with_suffix('.wav').name, rate=rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)

In [None]:
checkpoint_args = Path('/content/checkpoints/musicnet/args.pth')
output = Path('/content/results')
checkpoint_lastmodel = Path('/content/checkpoints/musicnet/lastmodel')

# Extract data samples to use as input for translation
data_samples(None, checkpoint_args, output, 4, 80000)

files = [Path('/content/results')]
run_on_files(files, None, checkpoint_lastmodel, decoders=[0, 1, 2, 3, 4, 5], output_next_to_orig=True)

# UMT-CPC

In [18]:
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

## PyTorch implementation of CDCK2 speaker classifier models
# CDCK2: base model from the paper 'Representation Learning with Contrastive Predictive Coding'
# SpkClassifier: a simple NN for speaker classification

class CPC(nn.Module):
    def __init__(self, args):

        timestep = args.timestep  
        seq_len = args.seq_len
        batch_size = args.batch_size

        super().__init__()

        # TODO: is it better to change #input channels for WaveNet Encoder to 512? 
        # TODO: Wavenet Encoder data gets too small after pooling. How do we adjust the downsampling factor here?
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.timestep = timestep
        self.encoder = nn.Sequential( # downsampling factor = 160
            nn.Conv1d(1, 512, kernel_size=10, stride=5, padding=3, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=8, stride=4, padding=2, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False), 
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True)
        )
        self.gru = nn.GRU(512, 64, num_layers=1, bidirectional=False, batch_first=True)
        self.Wk  = nn.ModuleList([nn.Linear(64, 512) for i in range(timestep)])
        self.softmax  = nn.Softmax()
        self.lsoftmax = nn.LogSoftmax()

        def _weights_init(m):
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # initialize gru
        for layer_p in self.gru._all_weights:
            for p in layer_p:
                if 'weight' in p:
                    nn.init.kaiming_normal_(self.gru.__getattr__(p), mode='fan_out', nonlinearity='relu')

        self.apply(_weights_init)

    def forward(self, x, hidden):
        batch = x.size()[0]

        # Encoder downsamples x by 160
        #t_samples is a random index into the encoded sequence z
        t_samples = torch.randint(self.seq_len // 160 - self.timestep, size=(1,)).long() # randomly pick a time stamp

        # input sequence is N*C*L, e.g. 8*1*20480
        x = x / 255 - .5
        if x.dim() < 3:
            x = x.unsqueeze(1)
        z = self.encoder(x)
        # encoded sequence is N*C*L, e.g. 8*512*128
        # reshape to N*L*C for GRU, e.g. 8*128*512
        z = z.transpose(1,2)
        encode_samples = torch.empty((self.timestep, batch, 512)).float() # e.g. size 12*8*512
        for k in np.arange(1, self.timestep+1):
            encode_samples[k-1] = z[:,t_samples+k,:].view(batch, 512) # z_t+k e.g. size 8*512
        forward_seq = z[:,:t_samples+1,:] # e.g. size 8*100*512

        output, hidden = self.gru(forward_seq, hidden) # output size e.g. 8*100*256
        c_t = output[:,t_samples,:].view(batch, args.latent_d) # c_t e.g. size 8*256
        pred = torch.empty((self.timestep, batch, 512)).float() # e.g. size 12*8*512
        for i in np.arange(self.timestep):
            linear = self.Wk[i]
            pred[i] = linear(c_t) # Wk*c_t e.g. size 8*512

        output = output.transpose(1,2)

        return output, hidden, encode_samples, pred 

def InfoNCELoss(encode_samples, pred, timestep):
    batch = encode_samples.shape[1]

    # TODO: figure out which dimension to take softmax over
    nce = 0 # average over timestep and batch
    correct = 0
    for i in np.arange(timestep):
        total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1)) # e.g. size 8*8
        correct += torch.sum(torch.eq(torch.argmax(F.softmax(total, dim=0), dim=0), torch.arange(0, batch))) # correct is a tensor
        nce += torch.sum(torch.diag(F.log_softmax(total, dim=0))) # nce is a tensor
    nce /= -1.*batch*timestep
    accuracy = 1.*correct.item()/(batch * timestep)

    return nce

In [19]:
class UMTCPCTrainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('/content/checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']), "Number of datasets must match number of nodes"

        self.losses_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
        self.losses_nce = [LossMeter(f'nce {i}') for i in range(self.args.n_datasets)]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
        self.evals_nce = [LossMeter(f'nce {i}') for i in range(self.args.n_datasets)]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.cpc_encoder = CPC(args)
        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.distributed:
            self.decoder = WaveNet(args)
        else:
            self.decoders = torch.nn.ModuleList([WaveNet(args) for _ in range(self.args.n_datasets)])

        if args.checkpoint:
            checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            if args.distributed:
                states = torch.load(args.checkpoint)
            else:
                states = [torch.load(args.checkpoint + f'_{i}.pth')
                          for i in range(self.args.n_datasets)]
            if args.distributed:
                #self.encoder.load_state_dict(states['encoder_state'])
                self.cpc_encoder.load_state_dict(states['cpc_encoder_state'])
                self.decoder.load_state_dict(states['decoder_state'])
                self.discriminator.load_state_dict(states['discriminator_state'])
            else:
                #self.encoder.load_state_dict(states[0]['encoder_state'])
                self.cpc_encoder.load_state_dict(states['cpc_encoder_state'])
                for i in range(self.args.n_datasets):
                    self.decoders[i].load_state_dict(states[i]['decoder_state'])
                self.discriminator.load_state_dict(states[0]['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        ## BUGFIX Data loading ##
        if args.distributed:
            self.cpc_encoder = torch.nn.parallel.DistributedDataParallel(self.cpc_encoder)
            self.cpc_encoder.cuda()
            #self.encoder.cuda()
            # self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(self.discriminator)
            self.decoder = torch.nn.DataParallel(self.decoder).cuda()
            self.logger.info('Created DistributedDataParallel')
            self.model_optimizer = optim.Adam(chain(self.cpc_encoder.parameters(),
                                                    self.decoder.parameters()),
                                              lr=args.lr)
        else:
            self.cpc_encoder = torch.nn.DataParallel(self.cpc_encoder).cuda()
            #self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
            ## BUGFIX -- IMPLEMENTED Separate optim / decoder ##
            self.model_optimizers = []
            for i, decoder in enumerate(self.decoders):
                self.decoders[i] = torch.nn.DataParallel(decoder).cuda()
            self.model_optimizers = [optim.Adam(chain(self.cpc_encoder.parameters(),
                                                      decoder.parameters()),
                                                lr=args.lr)
                                     for decoder in self.decoders]

        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        ## BUGFIX Data loading ##
        if args.checkpoint and args.load_optimizer:
            if args.distributed:
                self.model_optimizer.load_state_dict(states['model_optimizer_state'])
                self.d_optimizer.load_state_dict(states['d_optimizer_state'])
            else:
                for i in range(self.args.n_datasets):
                    self.model_optimizers[i].load_state_dict(states[i]['model_optimizer_state'])
                self.d_optimizer.load_state_dict(states[0]['d_optimizer_state'])

        if args.distributed:
            self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(self.model_optimizer, args.lr_decay)
            self.lr_manager.last_epoch = self.start_epoch
            self.lr_manager.step()
        else:
            self.lr_managers = []
            for i in range(self.args.n_datasets):
                self.lr_managers.append(torch.optim.lr_scheduler.ExponentialLR(self.model_optimizers[i], args.lr_decay))
                self.lr_managers[i].last_epoch = self.start_epoch
                self.lr_managers[i].step()

    def init_hidden(self, use_gpu=True):
        if use_gpu: return torch.zeros(1, self.args.batch_size, self.args.latent_d).cuda()
        else: return torch.zeros(1, self.args.batch_size, self.args.latent_d)

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        hidden = self.init_hidden()

        c, hidden, encode_samples, pred = self.cpc_encoder(x, hidden)
        ## BUGFIX decoder ##
        if self.args.distributed:
            y = self.decoder(x, c)
        else:
            y = self.decoders[dset_num](x, c)

        c_logits = self.discriminator(c)

        c_classification = torch.max(c_logits, dim=1)[1]

        c_accuracy = (c_classification == dset_num).float().mean()

        self.eval_d_right.add(c_accuracy.data.item())

        # discriminator_right = F.cross_entropy(c_logits, dset_num).mean()
        discriminator_right = F.cross_entropy(c_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        nce_loss = InfoNCELoss(encode_samples, pred, self.args.timestep)
        self.evals_nce[dset_num].add(nce_loss.data.cpu().numpy())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item() + nce_loss.data.item()

        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        hidden = self.init_hidden()

        # Optimize D - discriminator right
        c, _, encoder_samples, pred = self.cpc_encoder(x, hidden)
        c_logits = self.discriminator(c)
        discriminator_right = F.cross_entropy(c_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        self.loss_d_right.add(discriminator_right.data.cpu())

        # Get c_t for computing InfoNCE Loss
        nce_loss = InfoNCELoss(encoder_samples, pred, self.args.timestep)
        loss = discriminator_right * self.args.d_lambda + nce_loss
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        c, _, encoder_samples, pred = self.cpc_encoder(x_aug, hidden)
        if self.args.distributed:
            y = self.decoder(x, c)
        else:
            y = self.decoders[dset_num](x, c)
        c_logits = self.discriminator(c)
        discriminator_wrong = - F.cross_entropy(c_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'c_logits: {c_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        nce_loss = InfoNCELoss(encoder_samples, pred, self.args.timestep)
        self.losses_nce[dset_num].add(nce_loss.data.cpu().numpy())

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong) + nce_loss

        if self.args.distributed:
            self.model_optimizer.zero_grad()
        else:
            self.model_optimizers[dset_num].zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            if self.args.distributed:
                clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
            else:
                for decoder in self.decoders:
                    clip_grad_value_(decoder.parameters(), self.args.grad_clip)
        ## BUGFIX model optimizer ##
        if self.args.distributed:
            self.model_optimizer.step()
        else:
            self.model_optimizers[dset_num].step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.cpc_encoder.train()
        if self.args.distributed:
            self.decoder.train()
        else:
            for decoder in self.decoders:
                decoder.train()
        self.discriminator.train()

        n_batches = self.args.epoch_len

        with tqdm.tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.cpc_encoder.eval()
        if self.args.distributed:
            self.decoder.eval()
        else:
            for decoder in self.decoders:
                decoder.eval()
        self.discriminator.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm.tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon, *self.losses_nce, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon, *self.evals_nce, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs):
            self.logger.info(f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}')
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                             epoch, self.train_losses(), self.eval_losses())
            if self.args.distributed:
                self.lr_manager.step()
            else:
                for i in range(self.args.n_datasets):
                    self.lr_managers[i].step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args,
                            epoch],
                           '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        if self.args.distributed:
            save_path = self.expPath / filename
            torch.save({'encoder_state': self.cpc_encoder.module.state_dict(),
                        'decoder_state': self.decoder.module.state_dict(),
                        'discriminator_state': self.discriminator.module.state_dict(),
                        'model_optimizer_state': self.model_optimizer.state_dict(),
                        'dataset': self.args.rank,
                        'd_optimizer_state': self.d_optimizer.state_dict()
                        },
                    save_path)
            self.logger.debug(f'Saved model to {save_path}')
        else:
            filename = re.sub('_\d.pth$', '', filename)
            for i in range(self.args.n_datasets):
                save_path = self.expPath / f'{filename}_{i}.pth'
                torch.save({'encoder_state': self.cpc_encoder.module.state_dict(),
                            'decoder_state': self.decoders[i].module.state_dict(),
                            'discriminator_state': self.discriminator.module.state_dict(),
                            'model_optimizer_state': self.model_optimizers[i].state_dict(),
                            'dataset': self.args.rank,
                            'd_optimizer_state': self.d_optimizer.state_dict()
                            },
                        save_path)
                self.logger.debug(f'Saved model to {save_path}')

In [None]:
data_paths = [Path('musicnet/preprocessed/Bach_Solo_Cello'),
        Path('musicnet/preprocessed/Bach_Solo_Piano')]
args = TrainerArgs(data_paths, epochs=5, batch_size=8, lr_decay=0.995, epoch_len=20,
                  num_workers=0, lr=1e-3, seq_len=12000, d_lambda=1e-2, expName='musicnet_umtcpc',
                  latent_d=64, layers=14, blocks=4, data_aug=True, grad_clip=1, encoder_pool=100)


In [None]:
model = UMTCPCTrainer(args)
model.train()

INFO - 05/03/21 20:30:32 - 0:00:00 - <__main__.TrainerArgs object at 0x7fbf3292f7d0>
2021-05-03 20:30:32,792 - INFO - Dataset created. 9 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Cello/train
2021-05-03 20:30:32,795 - INFO - Dataset created. 1 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Cello/val
2021-05-03 20:30:32,801 - INFO - Dataset created. 31 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Piano/train
2021-05-03 20:30:32,804 - INFO - Dataset created. 3 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Piano/val


/content/checkpoints/musicnet_umtcpc


INFO - 05/03/21 20:30:33 - 0:00:00 - Starting epoch, Rank 0, Dataset: musicnet/preprocessed/Bach_Solo_Cello
Train (loss: 9.30) epoch 0: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it]
Test (loss: 7.51) epoch 0: 100%|██████████| 2/2 [00:01<00:00,  1.69it/s]
INFO - 05/03/21 20:31:03 - 0:00:31 - Epoch 0 Rank 0 - Train loss: (5.4124, 5.3460, 4.7616, 6.4367, 0.6993), Test loss (5.0875, 5.1031, 2.4324, 2.3999, 0.5000)
INFO - 05/03/21 20:31:05 - 0:00:33 - Starting epoch, Rank 0, Dataset: musicnet/preprocessed/Bach_Solo_Cello
Train (loss: 6.55) epoch 1: 100%|██████████| 20/20 [00:29<00:00,  1.50s/it]
Test (loss: 6.72) epoch 1: 100%|██████████| 2/2 [00:01<00:00,  1.71it/s]
INFO - 05/03/21 20:31:37 - 0:01:04 - Epoch 1 Rank 0 - Train loss: (4.8585, 4.7416, 3.7682, 4.6251, 0.6964), Test loss (4.5277, 4.5389, 2.4487, 2.2881, 0.6875)
INFO - 05/03/21 20:31:39 - 0:01:06 - Starting epoch, Rank 0, Dataset: musicnet/preprocessed/Bach_Solo_Cello
Train (loss: 4.93) epoch 2: 100%|██████████| 20/20 [00:31<00

In [None]:
def run_on_files_umtcpc(files, output, checkpoint, decoders=[], rate=16000, batch_size=6, sample_len=None, split_size=20, output_next_to_orig=False,
                 skip_filter=False, py=True):

    print('Starting')
    matplotlib.use('agg')

    checkpoints = checkpoint.parent.glob(checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(checkpoint.parent / 'args.pth')[0]
    cpc_encoder = CPC(model_args)
    cpc_encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    cpc_encoder.eval()
    cpc_encoder = cpc_encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if py:
            decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)
        else:
            decoder = NVWavenetGenerator(decoder, rate * (split_size // 20), batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert output_next_to_orig ^ (output is not None)

    if len(files) == 1 and files[0].is_dir():
        top = files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = files

    if not skip_filter:
        file_paths = [f for f in file_paths if not '_' in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = mu_law(data)
        elif file_path.suffix == '.h5':
            data = mu_law(h5py.File(file_path, 'r')['wav'][:] / (2 ** 15))
            if data.shape[-1] % rate != 0:
                data = data[:-(data.shape[-1] % rate)]
            assert data.shape[-1] % rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if sample_len:
            data = data[:sample_len]
        else:
            sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if output_next_to_orig:
            save_audio(wav.squeeze(), filepath.parent / f'{filepath.stem}_{decoder_ix}.wav', rate=rate)
        else:
            save_audio(wav.squeeze(), output / str(decoder_ix) / filepath.with_suffix('.wav').name, rate=rate)

    def init_hidden(size, use_gpu=True):
        if use_gpu: return torch.zeros(1, size, model_args.latent_d).cuda()
        else: return torch.zeros(1, size, model_args.latent_d)
    
    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, batch_size):
            hidden = init_hidden(len(xs_batch))
            output, _, _, _ = cpc_encoder(xs_batch, hidden)
            print(output.shape)
            zz += [output]
        #zz = torch.cat(zz, dim=0)

        xx = torch.split(xs, batch_size)

        with timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                #for zz_batch in torch.split(zz, batch_size):
                for zz_batch, xx_batch in zip(zz, xx):
                    print(f"zz_batch.shape: {zz_batch.shape}")
                    print(f"xx_batch.shape: {xx_batch.shape}")
                    audio_data = decoder(xx_batch, zz_batch).cpu()
                    print(f"audio_data.shape: {audio_data.shape}")
                    # splits = torch.split(zz_batch, split_size, -1)
                    # xx_splits = torch.split(xx_batch, split_size, -1)
                    # audio_data = []
                    # decoder.reset()
                    # for x, cond in tqdm.tqdm(zip(xx_splits, splits)):
                    #     #audio_data += [decoder.generate(cond).cpu()]
                    #     print(f"cond.shape: {cond.shape}")
                    #     print(f"x.shape: {x.shape}")
                    #     audio_data += [decoder(x, cond).cpu()]
                    # audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    # for decoder_ix, decoder_result in enumerate(yy):
    #     for sample_result, filepath in zip(decoder_result, file_paths):
    #         save(sample_result, decoder_ix, filepath)

In [None]:
checkpoint_args = Path('/content/checkpoints/musicnet_umtcpc/args.pth')
output = Path('/content/results')
checkpoint_lastmodel = Path('/content/checkpoints/musicnet_umtcpc/lastmodel')

# Extract data samples to use as input for translation
data_samples(None, checkpoint_args, output, 4, 80000)

files = [Path('/content/results')]
run_on_files_umtcpc(files, None, checkpoint_lastmodel, decoders=[0, 1, 2], output_next_to_orig=True)

2021-05-03 17:49:19,437 - INFO - Dataset created. 2 files, augmentation: False. Path: musicnet/preprocessed/Bach_Solo_Cello/test
2021-05-03 17:49:19,442 - INFO - Dataset created. 5 files, augmentation: False. Path: musicnet/preprocessed/Bach_Solo_Piano/test


100%|██████████| 4/4 [00:00<00:00, 68.74it/s]


100%|██████████| 4/4 [00:00<00:00, 86.79it/s]


Starting
xs size: torch.Size([8, 1, 80000])
torch.Size([6, 64, 50])
torch.Size([2, 64, 68])
zz_batch.shape: torch.Size([6, 64, 50])
xx_batch.shape: torch.Size([6, 1, 80000])
Generation timer took 2167.018413543701 ms


RuntimeError: ignored

# UMT-CPC New

In [20]:
class CPC(nn.Module):
    """
    Creates a contrastive predictive coding model with a strided convolutional 
    encoder and GRU RNN autoregressor as described by [1] and implemented in [2].

    References
    ----------
    [1] van der Oord et al., "Representation Learning with Contrastive 
        Predictive Coding", arXiv, 2019.
        https://arxiv.org/abs/1807.03748
    [2] Lai, "Contrastive-Predictive-Coding-PyTorch", GitHub.
        https://github.com/jefflai108/Contrastive-Predictive-Coding-PyTorch
    """
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential( # downsampling factor = 160
            nn.Conv1d(1, 512, kernel_size=10, stride=5, padding=3, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=8, stride=4, padding=2, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 1, kernel_size=1, stride=1, bias=False),
        )
        #self.ar = nn.GRU(512, args.latent_d, num_layers=1, bidirectional=False, batch_first=True)
        self.ar = Encoder(args)

    def forward(self, x):
        """
        Parameters
        ----------
            x : B x 1 x L torch.Tensor
                Input batch of audio sequence with B samples and length L.

        Returns
        -------
            z : B x (L // 160) x 512 torch.Tensor
                Encoded representation of audio sequence with 512 channels.
            c : B x (L // 160) x 256 torch.Tensor
                Context-encoded representation of audio sequence with 256 channels.
        """
        x = x / 255 - .5
        if x.dim() < 3:
            x = x.unsqueeze(1)

        # Use encoder to get sequence of latent representations z_t
        z = self.encoder(x)
        #z = z.transpose(1,2)

        # Use autoregressive model to compute context latent representation c_t
        #c, _ = self.ar(z)
        c = self.ar(z)
        #c = c.transpose(1, 2)

        z = z.transpose(1, 2)
        return z, c

In [21]:
class InfoNCELoss(nn.Module):
    """
    Creates a criterion that computes the InfoNCELoss as described in [1].
    
    Parameters
    ----------
        prediction_step : int
            Number of steps to predict into the future using context vector c

    References
    ----------
    [1] van der Oord et al., "Representation Learning with Contrastive 
        Predictive Coding", arXiv, 2019.
        https://arxiv.org/abs/1807.03748
    """
    def __init__(self, args):
        super().__init__()
        self.prediction_step = args.timestep
        self.Wk = nn.ModuleList(
            nn.Linear(args.latent_d, 1) for _ in range(self.prediction_step)
        )

    def get_neg_z(self, z, k, t, n_replicates):
        """
        Parameters
        ----------
            z : B x L x 512 torch.Tensor
                Encoded representation of audio sequence.
            k : int
                Number of time steps in the future for prediction
            t : B torch.Tensor
                Current time step for each sample in the batch
            n_replicates : int
                Number of repetitions of the negative sampling procedure

        Returns
        -------
            neg_samples : B x L-1 x N_rep x 512 torch.Tensor
                Batch-wise average InfoNCE loss
        """
        cur_device = z.get_device() if z.get_device() != -1 else "cpu"

        neg_idx = torch.vstack([torch.cat([
            torch.arange(0, t_i + k),             # indices before t+k
            torch.arange(t_i + k + 1, z.size(1))  # indices after t+k
        ]) for t_i in t])

        neg_samples = torch.vstack([z[i, neg_idx[i]].unsqueeze(0) for i in range(len(t))])
        neg_samples = torch.stack(
            [
                torch.index_select(neg_samples, 1, torch.randint(neg_samples.size(1), 
                                                                 (neg_samples.size(1), )).to(cur_device))
                for i in range(n_replicates)
            ],
            2,
        )
        return neg_samples
        
    def forward(self, z, c, n_replicates):
        """
        Parameters
        ----------
            z : B x L x 512 torch.Tensor
                Encoded representation of audio sequence.
            c : B x L x 256 torch.Tensor
                Context-encoded representation of audio sequence.
            n_replicates : int
                Number of times to make a set of negative samples.
        
        Returns
        -------
            loss : float Tensor
                Batch-wise average InfoNCE loss
        """
        loss = 0

        n_batches = z.size(0)

        # Sample random t for each batch
        cur_device = z.get_device() if z.get_device() != -1 else "cpu"
        t = torch.randint(z.size(1) - self.prediction_step - 1, (n_batches,)).to(cur_device)

        # Get context vector c_t
        c = c.transpose(1, 2)
        c_t = c[torch.arange(n_batches), t] # B x 256

        self.Wk.to(cur_device)
        for k in range(1, self.prediction_step + 1):
            # Perform negative sampling
            neg_samples = self.get_neg_z(z, k, t, n_replicates)  # B x L-1 x N_rep x C

            # Compute W_k * c_t
            linear = self.Wk[k - 1]  # 256 x C
            pred = linear(c_t) # B x C

            # Get positive z_t+k sample
            pos_sample = z[torch.arange(n_batches), t+k]

            # Positive sample: compute f_k(x_t+k, c_t)
            # Only take diagonal elements to get product between matched batches
            fk_pos = torch.diag(torch.matmul(pos_sample, pred.T)) # B (1-D tensor)
            fk_pos_rep = fk_pos.repeat(n_replicates).view(1, 1, n_replicates, fk_pos.size(0)) # 1 x 1 x N_rep x B

            # Negative samples: compute f_k(x_j, c_t)
            # Only take diagonal elements to get products between matched batches
            fk_neg = torch.matmul(neg_samples, pred.T) # B x L-1 x N_rep x B
            fk_neg = torch.diagonal(fk_neg, dim1=0, dim2=-1).unsqueeze(0) # 1 x L-1 x N_rep x B

            # Concatenate fk for positive and negative samples
            fk = torch.hstack([fk_pos_rep, fk_neg]) # 1 x L x N_rep x B

            # Compute log softmax over all fk 
            log_sm_fk = torch.nn.LogSoftmax(dim=1)(fk)  # 1 x L x N_rep x B

            # Compute expected value of log softmaxes over replicates
            exp_log_sm_fk = torch.mean(log_sm_fk, dim=2)  # 1 x L x B    

            # Update loss with log softmax element corresponding to positive sample
            loss -= exp_log_sm_fk[:, 0] # 1 x B
            
        # Divide by number of predicted steps
        loss /= self.prediction_step

        # Average over batches
        loss = loss.sum() / n_batches

        return loss

In [22]:
class UMTCPCNewTrainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('/content/checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']), "Number of datasets must match number of nodes"

        self.losses_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
        self.losses_nce = [LossMeter(f'nce {i}') for i in range(self.args.n_datasets)]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [LossMeter(f'recon {i}') for i in range(self.args.n_datasets)]
        self.evals_nce = [LossMeter(f'nce {i}') for i in range(self.args.n_datasets)]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = CPC(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.distributed:
            self.decoder = WaveNet(args)
        else:
            self.decoders = torch.nn.ModuleList([WaveNet(args) for _ in range(self.args.n_datasets)])

        if args.checkpoint:
            checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            if args.distributed:
                states = torch.load(args.checkpoint)
            else:
                states = [torch.load(args.checkpoint + f'_{i}.pth')
                          for i in range(self.args.n_datasets)]
            if args.distributed:
                self.encoder.load_state_dict(states['encoder_state'])
                self.decoder.load_state_dict(states['decoder_state'])
                self.discriminator.load_state_dict(states['discriminator_state'])
            else:
                self.encoder.load_state_dict(states['encoder_state'])
                for i in range(self.args.n_datasets):
                    self.decoders[i].load_state_dict(states[i]['decoder_state'])
                self.discriminator.load_state_dict(states[0]['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        ## BUGFIX Data loading ##
        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(self.discriminator)
            self.decoder = torch.nn.DataParallel(self.decoder).cuda()
            self.logger.info('Created DistributedDataParallel')
            self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                    self.decoder.parameters()),
                                              lr=args.lr)
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
            ## BUGFIX -- IMPLEMENTED Separate optim / decoder ##
            self.model_optimizers = []
            for i, decoder in enumerate(self.decoders):
                self.decoders[i] = torch.nn.DataParallel(decoder).cuda()
            self.model_optimizers = [optim.Adam(chain(self.encoder.parameters(),
                                                      decoder.parameters()),
                                                lr=args.lr)
                                     for decoder in self.decoders]
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        ## BUGFIX Data loading ##
        if args.checkpoint and args.load_optimizer:
            if args.distributed:
                self.model_optimizer.load_state_dict(states['model_optimizer_state'])
                self.d_optimizer.load_state_dict(states['d_optimizer_state'])
            else:
                for i in range(self.args.n_datasets):
                    self.model_optimizers[i].load_state_dict(states[i]['model_optimizer_state'])
                self.d_optimizer.load_state_dict(states[0]['d_optimizer_state'])

        if args.distributed:
            self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(self.model_optimizer, args.lr_decay)
            self.lr_manager.last_epoch = self.start_epoch
            self.lr_manager.step()
        else:
            self.lr_managers = []
            for i in range(self.args.n_datasets):
                self.lr_managers.append(torch.optim.lr_scheduler.ExponentialLR(self.model_optimizers[i], args.lr_decay))
                self.lr_managers[i].last_epoch = self.start_epoch
                self.lr_managers[i].step()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z, c = self.encoder(x)
        ## BUGFIX decoder ##
        if self.args.distributed:
            y = self.decoder(x, c)
        else:
            y = self.decoders[dset_num](x, c)

        c_logits = self.discriminator(c)

        c_classification = torch.max(c_logits, dim=1)[1]

        c_accuracy = (c_classification == dset_num).float().mean()

        self.eval_d_right.add(c_accuracy.data.item())

        # discriminator_right = F.cross_entropy(c_logits, dset_num).mean()
        discriminator_right = F.cross_entropy(c_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        nce_loss = InfoNCELoss(args)
        nce_loss_val = nce_loss(z, c, n_replicates=5)
        self.evals_nce[dset_num].add(nce_loss_val.data.cpu().numpy())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item() + nce_loss_val.data.item()

        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - discriminator right
        z, c = self.encoder(x)
        c_logits = self.discriminator(c)
        discriminator_right = F.cross_entropy(c_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        self.loss_d_right.add(discriminator_right.data.cpu())

        # Get c_t for computing InfoNCE Loss
        nce_loss = InfoNCELoss(args)
        nce_loss_val = nce_loss(z, c, n_replicates=5)
        loss = discriminator_right * self.args.d_lambda + nce_loss_val
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        z, c = self.encoder(x_aug)
        if self.args.distributed:
            y = self.decoder(x, c)
        else:
            y = self.decoders[dset_num](x, c)
        c_logits = self.discriminator(c)
        discriminator_wrong = - F.cross_entropy(c_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'c_logits: {c_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        nce_loss = InfoNCELoss(args)
        nce_loss_val = nce_loss(z, c, n_replicates=5)
        self.losses_nce[dset_num].add(nce_loss_val.data.cpu().numpy())

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong) + nce_loss_val

        if self.args.distributed:
            self.model_optimizer.zero_grad()
        else:
            self.model_optimizers[dset_num].zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            if self.args.distributed:
                clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
            else:
                for decoder in self.decoders:
                    clip_grad_value_(decoder.parameters(), self.args.grad_clip)
        ## BUGFIX model optimizer ##
        if self.args.distributed:
            self.model_optimizer.step()
        else:
            self.model_optimizers[dset_num].step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.encoder.train()
        if self.args.distributed:
            self.decoder.train()
        else:
            for decoder in self.decoders:
                decoder.train()
        self.discriminator.train()

        n_batches = self.args.epoch_len

        with tqdm.tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.encoder.eval()
        if self.args.distributed:
            self.decoder.eval()
        else:
            for decoder in self.decoders:
                decoder.eval()
        self.discriminator.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm.tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon, *self.losses_nce, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon, *self.evals_nce, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch, self.start_epoch + self.args.epochs):
            self.logger.info(f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}')
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                             epoch, self.train_losses(), self.eval_losses())
            if self.args.distributed:
                self.lr_manager.step()
            else:
                for i in range(self.args.n_datasets):
                    self.lr_managers[i].step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args,
                            epoch],
                           '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        if self.args.distributed:
            save_path = self.expPath / filename
            torch.save({'encoder_state': self.encoder.module.state_dict(),
                        'decoder_state': self.decoder.module.state_dict(),
                        'discriminator_state': self.discriminator.module.state_dict(),
                        'model_optimizer_state': self.model_optimizer.state_dict(),
                        'dataset': self.args.rank,
                        'd_optimizer_state': self.d_optimizer.state_dict()
                        },
                    save_path)
            self.logger.debug(f'Saved model to {save_path}')
        else:
            filename = re.sub('_\d.pth$', '', filename)
            for i in range(self.args.n_datasets):
                save_path = self.expPath / f'{filename}_{i}.pth'
                torch.save({'encoder_state': self.encoder.module.state_dict(),
                            'decoder_state': self.decoders[i].module.state_dict(),
                            'discriminator_state': self.discriminator.module.state_dict(),
                            'model_optimizer_state': self.model_optimizers[i].state_dict(),
                            'dataset': self.args.rank,
                            'd_optimizer_state': self.d_optimizer.state_dict()
                            },
                        save_path)
                self.logger.debug(f'Saved model to {save_path}')

In [23]:
data_paths = [
              Path('musicnet/preprocessed/Bach_Solo_Cello'),
        Path('musicnet/preprocessed/Bach_Solo_Piano')]
#pool = 100
args = TrainerArgs(data_paths, epochs=20, batch_size=8, lr_decay=0.995, epoch_len=20,
                  num_workers=0, lr=1e-3, seq_len=10000, d_lambda=1e-2, expName='musicnet_umtcpc',
                  latent_d=64, layers=14, blocks=4, data_aug=True, grad_clip=1, encoder_pool=1, timestep=5)


In [24]:
model = UMTCPCNewTrainer(args)
model.train()

INFO - 05/05/21 03:24:20 - 0:00:00 - <__main__.TrainerArgs object at 0x7f636a0fdd10>
2021-05-05 03:24:20,394 - INFO - Dataset created. 9 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Cello/train
2021-05-05 03:24:20,433 - INFO - Dataset created. 1 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Cello/val
2021-05-05 03:24:20,437 - INFO - Dataset created. 31 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Piano/train
2021-05-05 03:24:20,439 - INFO - Dataset created. 3 files, augmentation: True. Path: musicnet/preprocessed/Bach_Solo_Piano/val


/content/checkpoints/musicnet_umtcpc


INFO - 05/05/21 03:24:24 - 0:00:04 - Starting epoch, Rank 0, Dataset: musicnet/preprocessed/Bach_Solo_Cello
Train (loss: 9.41) epoch 0: 100%|██████████| 20/20 [00:16<00:00,  1.20it/s]
Test (loss: 9.37) epoch 0: 100%|██████████| 2/2 [00:00<00:00,  3.16it/s]
INFO - 05/05/21 03:24:41 - 0:00:21 - Epoch 0 Rank 0 - Train loss: (5.3855, 5.4266, 4.2339, 4.3137, 0.6909), Test loss (5.4017, 5.1848, 4.2284, 4.1787, 0.5000)
INFO - 05/05/21 03:24:43 - 0:00:23 - Starting epoch, Rank 0, Dataset: musicnet/preprocessed/Bach_Solo_Cello
Train (loss: 8.97) epoch 1: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s]
Test (loss: 8.84) epoch 1: 100%|██████████| 2/2 [00:00<00:00,  3.20it/s]
INFO - 05/05/21 03:24:59 - 0:00:39 - Epoch 1 Rank 0 - Train loss: (4.9224, 4.9729, 4.1866, 4.2243, 0.6846), Test loss (4.7551, 4.6926, 4.1859, 4.1594, 0.5000)
INFO - 05/05/21 03:25:02 - 0:00:42 - Starting epoch, Rank 0, Dataset: musicnet/preprocessed/Bach_Solo_Cello
Train (loss: 8.27) epoch 2: 100%|██████████| 20/20 [00:15<00

In [28]:
def run_on_files_umtcpc(files, output, checkpoint, decoders=[], rate=16000, batch_size=6, sample_len=None, split_size=20, output_next_to_orig=False,
                 skip_filter=False, py=True):

    print('Starting')
    matplotlib.use('agg')

    checkpoints = checkpoint.parent.glob(checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(checkpoint.parent / 'args.pth')[0]
    encoder = CPC(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if py:
            decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)
        else:
            decoder = NVWavenetGenerator(decoder, rate * (split_size // 20), batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert output_next_to_orig ^ (output is not None)

    if len(files) == 1 and files[0].is_dir():
        top = files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = files

    if not skip_filter:
        file_paths = [f for f in file_paths if not '_' in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = mu_law(data)
        elif file_path.suffix == '.h5':
            data = mu_law(h5py.File(file_path, 'r')['wav'][:] / (2 ** 15))
            if data.shape[-1] % rate != 0:
                data = data[:-(data.shape[-1] % rate)]
            assert data.shape[-1] % rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if sample_len:
            data = data[:sample_len]
        else:
            sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if output_next_to_orig:
            save_audio(wav.squeeze(), filepath.parent / f'{filepath.stem}_{decoder_ix}.wav', rate=rate)
        else:
            save_audio(wav.squeeze(), output / str(decoder_ix) / filepath.with_suffix('.wav').name, rate=rate)

    yy = {}
    with torch.no_grad():
        cc = []
        for xs_batch in torch.split(xs, batch_size):
            z, c = encoder(xs_batch)
            cc += [c]
        cc = torch.cat(cc, dim=0)

        with timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for cc_batch in torch.split(cc, batch_size):
                    print(cc_batch.shape)
                    splits = torch.split(cc_batch, split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)

In [32]:
checkpoint_args = Path('/content/checkpoints/musicnet_umtcpc/args.pth')
output = Path('/content/results')
checkpoint_lastmodel = Path('/content/checkpoints/musicnet_umtcpc/lastmodel')

# Extract data samples to use as input for translation
data_samples(None, checkpoint_args, output, 1, 5000)

files = [Path('/content/results')]
run_on_files_umtcpc(files, None, checkpoint_lastmodel, decoders=[0, 1, 2], batch_size=2, output_next_to_orig=True)

2021-05-05 04:11:48,586 - INFO - Dataset created. 2 files, augmentation: False. Path: musicnet/preprocessed/Bach_Solo_Cello/test
2021-05-05 04:11:48,587 - INFO - Dataset created. 5 files, augmentation: False. Path: musicnet/preprocessed/Bach_Solo_Piano/test






100%|██████████| 1/1 [00:00<00:00, 92.51it/s]






100%|██████████| 1/1 [00:00<00:00, 73.81it/s]


Starting








  0%|          | 0/2 [00:00<?, ?it/s][A[A[A[A[A[A






Generating:   0%|          | 0/20 [00:00<?, ?it/s][A[A[A[A[A[A[A

xs size: torch.Size([2, 1, 5000])
torch.Size([2, 64, 31])









Generating:   5%|▌         | 1/20 [00:17<05:31, 17.44s/it][A[A[A[A[A[A[A






Generating:  10%|█         | 2/20 [00:34<05:10, 17.26s/it][A[A[A[A[A[A[A






Generating:  15%|█▌        | 3/20 [00:50<04:46, 16.83s/it][A[A[A[A[A[A[A






Generating:  20%|██        | 4/20 [01:05<04:24, 16.52s/it][A[A[A[A[A[A[A






Generating:  25%|██▌       | 5/20 [01:21<04:04, 16.28s/it][A[A[A[A[A[A[A






Generating:  30%|███       | 6/20 [01:37<03:46, 16.17s/it][A[A[A[A[A[A[A






Generating:  35%|███▌      | 7/20 [01:53<03:29, 16.10s/it][A[A[A[A[A[A[A






Generating:  40%|████      | 8/20 [02:09<03:12, 16.07s/it][A[A[A[A[A[A[A






Generating:  45%|████▌     | 9/20 [02:25<02:56, 16.07s/it][A[A[A[A[A[A[A






Generating:  50%|█████     | 10/20 [02:41<02:40, 16.02s/it][A[A[A[A[A[A[A






Generating:  55%|█████▌    | 11/20 [02:57<02:25, 16.11s/it][A[A[A[A[A[A[A






Generating:  60%|██████    | 12/20 [03:13<02:

torch.Size([2, 64, 31])









Generating:   5%|▌         | 1/20 [00:15<04:58, 15.70s/it][A[A[A[A[A[A[A






Generating:  10%|█         | 2/20 [00:31<04:41, 15.65s/it][A[A[A[A[A[A[A






Generating:  15%|█▌        | 3/20 [00:46<04:26, 15.67s/it][A[A[A[A[A[A[A






Generating:  20%|██        | 4/20 [01:02<04:10, 15.66s/it][A[A[A[A[A[A[A






Generating:  25%|██▌       | 5/20 [01:18<03:54, 15.63s/it][A[A[A[A[A[A[A






Generating:  30%|███       | 6/20 [01:33<03:37, 15.57s/it][A[A[A[A[A[A[A






Generating:  35%|███▌      | 7/20 [01:49<03:21, 15.53s/it][A[A[A[A[A[A[A






Generating:  40%|████      | 8/20 [02:04<03:07, 15.60s/it][A[A[A[A[A[A[A






Generating:  45%|████▌     | 9/20 [02:20<02:51, 15.60s/it][A[A[A[A[A[A[A






Generating:  50%|█████     | 10/20 [02:36<02:37, 15.71s/it][A[A[A[A[A[A[A






Generating:  55%|█████▌    | 11/20 [02:52<02:21, 15.70s/it][A[A[A[A[A[A[A






Generating:  60%|██████    | 12/20 [03:07<02:

Generation timer took 978345.251083374 ms
X size: torch.Size([1, 24800])
X min: 8, max: 244
X size: torch.Size([1, 24800])
X min: 1, max: 254
X size: torch.Size([1, 24800])
X min: 26, max: 227
X size: torch.Size([1, 24800])
X min: 23, max: 226
