# WaveNet - Generate a Sample

In [1]:
import sys
#sys.path.append('../../src/')
sys.path.append('../../network/')

In [31]:
import torch
import librosa
import librosa.output
import datetime
import numpy as np
from types import SimpleNamespace
torch.cuda.empty_cache()

In [15]:
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name, so use whichever name this installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [3]:
from models.wavenet.model import WaveNet
import models.wavenet.utils.data as utils

In [8]:
# Model shape and sampling settings read by Generator
# (`length` is multiplied by `sample_rate` to get the number of samples to generate).
params = SimpleNamespace(**{
    'layer_size': 10,
    'stack_size': 5,
    'in_channels': 256,
    'res_channels': 512,
    'sample_size': 10000,
    'sample_rate': 22050,
    'length': 15,
})

In [9]:
class Generator:
    """Generate audio from a trained WaveNet, seeded with an existing audio clip.

    ``args`` is a namespace; the attributes read here are ``layer_size``,
    ``stack_size``, ``in_channels``, ``res_channels`` (model construction),
    ``model_dir``, ``model_name``, ``step`` (checkpoint to load), ``seed``
    (path of the seed audio file), ``sample_size``, ``sample_rate``,
    ``length`` (multiplied by ``sample_rate`` to get the generation target)
    and ``out`` (path of the wav file to write).
    """

    def __init__(self, args):
        self.args = args

        self.wavenet = WaveNet(args.layer_size, args.stack_size,
                               args.in_channels, args.res_channels)

        self.wavenet.load(args.model_dir, args.model_name, args.step)

    @staticmethod
    def _variable(data):
        """Convert a NumPy array to a float32 tensor, moved to the GPU if one is available."""
        tensor = torch.from_numpy(data).float()
        # torch.autograd.Variable has only wrapped (and returned) a plain Tensor
        # since PyTorch 0.4, so the tensor is returned directly.
        if torch.cuda.is_available():
            return tensor.cuda()
        return tensor

    def _make_seed(self, audio):
        """Left-pad a 2-D ``audio`` array and cut out the seed window for the network.

        A batch axis is added, ``receptive_fields`` zero rows are prepended along
        axis 0 of ``audio`` (presumably time — confirm against utils), and the
        first ``args.sample_size`` rows are kept, or ``2 * receptive_fields``
        rows when ``sample_size`` is falsy. Returns shape (1, window, channels).
        """
        audio = np.pad([audio], [[0, 0], [self.wavenet.receptive_fields, 0], [0, 0]], 'constant')

        if self.args.sample_size:
            return audio[:, :self.args.sample_size, :]
        return audio[:, :self.wavenet.receptive_fields * 2, :]

    def _get_seed_from_audio(self, filepath):
        """Load ``filepath`` and encode it into the network's seed input.

        Returns ``(seed_tensor, audio_length)``, where ``audio_length`` is
        ``len()`` of the loaded audio before mu-law/one-hot encoding and padding.
        """
        audio = utils.load_audio(filepath, self.args.sample_rate)
        audio_length = len(audio)

        audio = utils.mu_law_encode(audio, self.args.in_channels)
        audio = utils.one_hot_encode(audio, self.args.in_channels)

        seed = self._make_seed(audio)

        return self._variable(seed), audio_length

    def _save_to_audio_file(self, data):
        """Decode the first batch item of ``data``, write it to ``args.out`` and return it.

        The item is one-hot decoded along axis 1 and mu-law decoded with
        ``args.in_channels`` levels. It is written as a float32 wav at
        ``args.sample_rate``. The decoded waveform is returned as produced by
        ``utils.mu_law_decode``.
        """
        data = data[0].cpu().data.numpy()
        data = utils.one_hot_decode(data, axis=1)
        audio = utils.mu_law_decode(data, self.args.in_channels)

        # NOTE: librosa.output (and write_wav) were removed in librosa 0.8;
        # this call requires an older librosa.
        librosa.output.write_wav(self.args.out, np.array(audio, dtype="float32"), self.args.sample_rate)
        print('Saved wav file at {}'.format(self.args.out))

        return audio

    def generate(self):
        """Run the network repeatedly until ``length * sample_rate`` samples exist, then save.

        Each newly generated chunk is fed back into the network's input.
        The collected output is trimmed to the seed clip's length, written to
        ``args.out`` and returned as a decoded waveform (see
        ``_save_to_audio_file``).

        Raises:
            RuntimeError: if the network returns an empty chunk. Without this
                check the loop could never reach the target.
        """
        target = self.args.length * self.args.sample_rate

        with torch.no_grad():
            outputs = None
            inputs, audio_length = self._get_seed_from_audio(self.args.seed)

            while True:
                new = self.wavenet.generate(inputs)
                n_new = len(new[0])
                if n_new == 0:
                    raise RuntimeError('WaveNet.generate returned no samples; cannot make progress.')

                outputs = new if outputs is None else torch.cat((outputs, new), dim=1)

                print('{0}/{1} samples are generated.'.format(len(outputs[0]), target))

                if len(outputs[0]) >= target:
                    break

                # NOTE(review): this keeps the *head* of the previous input and
                # replaces its tail with the new chunk; a sliding window would
                # drop the head (inputs[:, n_new:, :]) instead — confirm intent.
                inputs = torch.cat((inputs[:, :-n_new, :], new), dim=1)

            # Trim to the seed clip's length. If the seed is shorter than
            # `target`, the extra generated samples are discarded here.
            outputs = outputs[:, :audio_length, :]

        return self._save_to_audio_file(outputs)

In [10]:
# Checkpoint to load, seed clip to condition on, and destination of the generated wav.
vars(params).update(
    model_dir='../../network/weights/wavenet/',
    model_name='wavenet-tapping-glass-tiny-jar',
    step=0,
    seed='../../data/processed/tapping/tapping-glass/PLhDdb5CgZ4-tiny-jar.wav',
    out='../../network/outputs/wavenet/wavenet-out-tapping-glass-tiny-jar.wav',
)

In [11]:
# Build the generator (loads the checkpoint) and run sampling; the decoded
# waveform that generate() returns is kept in `x` for the cells below.
generator = Generator(args=params)
x = generator.generate()

2 GPUs are detected.
Loading model from ../../network/weights/wavenet/
../../data/processed/tapping/tapping-glass/PLhDdb5CgZ4-tiny-jar.wav
4885/752000 samples are generated.
9770/752000 samples are generated.
14655/752000 samples are generated.
19540/752000 samples are generated.
24425/752000 samples are generated.
29310/752000 samples are generated.
34195/752000 samples are generated.
39080/752000 samples are generated.
43965/752000 samples are generated.
48850/752000 samples are generated.
53735/752000 samples are generated.
58620/752000 samples are generated.
63505/752000 samples are generated.
68390/752000 samples are generated.
73275/752000 samples are generated.
78160/752000 samples are generated.
83045/752000 samples are generated.
87930/752000 samples are generated.
92815/752000 samples are generated.
97700/752000 samples are generated.
102585/752000 samples are generated.
107470/752000 samples are generated.
112355/752000 samples are generated.
117240/752000 samples are genera

In [24]:
# Load the seed clip at the same rate the generator used, rather than relying on
# librosa's own default, so its length is comparable with the generated output.
# `sr` is the rate the audio was loaded at.
seed, sr = librosa.load(params.seed, sr=params.sample_rate)

(1036350,)

In [40]:
# Fresh copy of the generated waveform, cast to the seed clip's dtype.
# np.array copies by default, so later edits to `audio` never touch `x`.
audio = np.array(x, dtype=seed.dtype)

In [42]:
# Inspect the container type of the waveform about to be written.
type(audio)

numpy.ndarray

In [46]:
# Write the waveform to the configured output path at the configured rate.
# NOTE: librosa.output (and write_wav) were removed in librosa 0.8.
librosa.output.write_wav(path=params.out, y=audio, sr=params.sample_rate)