In [1]:
import numpy as np
from matplotlib import pyplot as plt
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn,utils 
import mxnet.ndarray as F
from IPython.display import Audio
from scipy.io import wavfile
from tqdm import tqdm
import sys

#### WaveNet 구현을 위한 class를 정의
+ 우선 integer값을 one-hot으로 변환하기 위한 class를 정의

In [2]:
class One_Hot(nn.Block):
    """Gluon block that turns integer class indices into one-hot row vectors.

    Parameters
    ----------
    depth : int
        Number of classes, i.e. the length of every one-hot vector.
    """
    def __init__(self, depth):
        super(One_Hot,self).__init__()
        self.depth = depth

    def forward(self, X_in):
        """Return the rows of a ``depth x depth`` identity table selected by ``X_in``."""
        # Build the lookup table with the input's context as the default
        # context, so the table is created where the indices live.
        with X_in.context:
            self.ones = nd.one_hot(nd.arange(self.depth), self.depth)
            return self.ones[X_in,:]

    def __repr__(self):
        # Shown when the block (or a network containing it) is printed.
        return self.__class__.__name__+"({})".format(self.depth)

In [3]:
class WaveNet(nn.Block):
    """Stack of gated, dilated 1-D convolutions with residual and skip paths.

    Parameters
    ----------
    mu : int
        Audio quantization size (width of the one-hot input and of the output).
    n_residue : int
        Number of residue channels.
    n_skip : int
        Number of skip channels.
    dilation_depth : int
        One cycle uses the dilations ``2**0 .. 2**(dilation_depth-1)``.
    n_repeat : int
        Number of times the dilation cycle is repeated.
    """
    def __init__(self, mu=256,n_residue=32, n_skip= 512, dilation_depth=10, n_repeat=5):
        super(WaveNet, self).__init__()
        self.dilation_depth = dilation_depth
        one_cycle = [2**k for k in range(dilation_depth)]
        self.dilations = one_cycle * n_repeat
        with self.name_scope():
            def dilated_conv(out_channels, kernel, dilation):
                # Every per-layer convolution reads n_residue input channels.
                return nn.Conv1D(in_channels=n_residue, channels=out_channels,
                                 kernel_size=kernel, dilation=dilation)

            self.one_hot = One_Hot(mu)
            self.from_input = nn.Conv1D(in_channels=mu, channels=n_residue, kernel_size=1)
            self.conv_sigmoid = nn.Sequential()
            self.conv_tanh = nn.Sequential()
            self.skip_scale = nn.Sequential()
            self.residue_scale = nn.Sequential()
            # Creation order inside the loop is kept fixed: sigmoid, tanh,
            # skip, residue for each dilation.
            for dilation in self.dilations:
                self.conv_sigmoid.add(dilated_conv(n_residue, 2, dilation))
                self.conv_tanh.add(dilated_conv(n_residue, 2, dilation))
                self.skip_scale.add(dilated_conv(n_skip, 1, dilation))
                self.residue_scale.add(dilated_conv(n_residue, 1, dilation))
            self.conv_post_1 = nn.Conv1D(in_channels=n_skip, channels=n_skip, kernel_size=1)
            self.conv_post_2 = nn.Conv1D(in_channels=n_skip, channels=mu, kernel_size=1)

    def forward(self,x):
        """Run the full network on ``x`` and return the post-processed output."""
        with x.context:
            residual = self.preprocess(x)
            skips = []  # one skip tensor per layer
            layers = zip(self.conv_sigmoid, self.conv_tanh, self.skip_scale, self.residue_scale)
            for gate_conv, filter_conv, to_skip, to_residue in layers:
                residual, skip = self.residue_forward(residual, gate_conv, filter_conv, to_skip, to_residue)
                skips.append(skip)
            # Earlier layers produce longer skip tensors; keep only the trailing
            # part that matches the length of the last layer's output.
            valid_len = residual.shape[2]
            summed = sum([skip[:,:,-valid_len:] for skip in skips])
            result = self.postprocess(summed)
        return result

    def preprocess(self, x):
        """One-hot encode ``x``, add a leading axis, swap the last two axes and project."""
        encoded = self.one_hot(x).expand_dims(0)
        swapped = F.transpose(encoded, axes=(0,2,1))
        return self.from_input(swapped)

    def postprocess(self, x):
        """Apply relu -> 1x1 conv -> relu -> 1x1 conv, then drop axis 0 and transpose."""
        hidden = self.conv_post_1(F.relu(x))
        scores = self.conv_post_2(F.relu(hidden))
        # Reshape to (axis-1 size, axis-2 size); this assumes axis 0 has size 1,
        # which is what preprocess() produces via expand_dims(0).
        scores = nd.reshape(scores, (scores.shape[1], scores.shape[2]))
        return F.transpose(scores, axes=(1,0))

    def residue_forward(self, x, conv_sigmoid, conv_tanh, skip_scale, residue_scale):
        """Apply one gated layer; return ``(residual_output, skip_output)``."""
        gate_in = conv_sigmoid(x)
        filter_in = conv_tanh(x)
        gated = F.sigmoid(gate_in) * F.tanh(filter_in)
        skip = skip_scale(gated)
        projected = residue_scale(gated)
        # The convolutions shorten the last axis, so add only the trailing
        # part of the input that lines up with the projected output.
        return projected + x[:,:,-projected.shape[2]:], skip

+ stack of dilated causal convolution을 수행하면서 각 층마다 residue_forward를 거치고, 최종적으로 Softmax를 거쳐서 우리가 원하는 class별 확률값을 산출한다. (Softmax 자체는 학습 시 loss 함수 내부에서 적용된다.)
+ 다음으로 Softmax Distribution 적용을 위한 Encoder/Decoder 함수를 정의. 함수 내용은 수식을 코드로 옮긴 것이다.

In [4]:
def encode_mu_law(x, mu=256):
    """Compand ``x`` with the mu-law curve and quantize it to integer levels.

    Parameters
    ----------
    x : array_like
        Signal values; the formula maps -1 to level 0 and +1 to level ``mu-1``.
    mu : int
        Number of quantization levels.

    Returns
    -------
    numpy.ndarray
        Integer (int64) levels, same shape as ``x``.
    """
    mu = mu-1
    fx = np.sign(x) * np.log(1+mu*np.abs(x))/np.log(1+mu)
    # floor(v + 0.5) rounds to the nearest level; astype is a method call
    # (the original divided by an undefined name instead).
    return np.floor((fx+1)/2*mu+0.5).astype(np.int64)

def decode_mu_law(y, mu=256):
    """Map quantized mu-law levels ``y`` back to signal values.

    Parameters
    ----------
    y : array_like
        Quantized levels.
    mu : int
        Number of quantization levels used when encoding.

    Returns
    -------
    Expanded signal values, same shape as ``y``.
    """
    levels = mu-1
    # Undo the level scaling (with a half-step offset) to get the companded value.
    scaled = (y-0.5)/levels
    companded = scaled*2-1
    # Invert the logarithmic companding curve.
    magnitude = (1+levels)**np.abs(companded)-1
    return np.sign(companded)/levels*magnitude

+ 다음으로 네트워크 학습을 위해 네트워크를 정의하고, parameter 및 Optimizer를 설정.
+ Gluon에서 GPU를 사용하기 위해서는 별도의 context를 지정해주어야 하기 때문에 그 부분을 포함해서 코드를 작성

In [None]:
# Device (GPU with device id 1) used for the parameters and the training batches.
ctx = mx.gpu(1)
net = WaveNet(mu=128, n_residue=24, n_skip=128, dilation_depth=10, n_repeat=2)
net.collect_params().initialize(ctx=ctx)
# Set optimizer. optimizer_params must be a dict of keyword arguments;
# the original passed a set containing the single string 'learning_rate:0.01'.
trainer = gluon.Trainer(net.collect_params(), optimizer='adam', optimizer_params={'learning_rate': 0.01})
# NOTE: `data`, `fs` and `data_generation` are defined in cells further down
# this notebook; those cells must have been run before this line.
g = data_generation(data, fs, mu=128, seq_size = 20000, ctx=ctx)
# Used by the training loop as the number of sequences drawn per epoch.
batch_size = 64
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

+ WAV 데이터는 다음의 module을 통해 불러옴.
+ 이 때 loading한 WAV가 stereo인 경우에는 2-dim 형태이기 때문에 이를 mono 형태로 변환해야 함

In [None]:
import os
from scipy.io import wavfile
fs, data = wavfile.read(os.path.join(os.getcwd(), 'parametric-2.wav'))
# Work in float so later scaling and abs() cannot overflow an integer sample type.
data = data.astype(np.float64)
# Stereo (multi-channel) files load as a 2-D array with one column per channel;
# average the channels into mono. Mono files are already 1-D and are left as is.
if data.ndim > 1:
    data = data.mean(axis=1)

+ 해당 module을 통해 입력한 데이터를 Softmax Distribution을 수행하기 전에 전처리 과정이 필요
+ 데이터는 -1 ~ 1 사이의 scale로 변환시켜야 함
+ 그 이후 학습을 위한 데이터를 random으로 생성하기 위해 while문 활용

In [None]:
def data_generation(data,framerate,seq_size = 6000, mu=256,ctx=None):
    """Endlessly yield random, mu-law encoded windows of ``data``.

    Parameters
    ----------
    data : array_like
        1-D audio samples; they are scaled by the peak absolute value.
    framerate : int
        Sample rate of ``data``. Accepted for call compatibility; not used.
    seq_size : int
        Number of samples per yielded window.
    mu : int
        Number of quantization levels passed to ``encode_mu_law``.
    ctx : mxnet context or None
        Context for the yielded NDArray. ``None`` leaves the choice to
        ``nd.array`` instead of binding a global ``ctx`` when this function
        is defined.

    Yields
    ------
    NDArray holding ``seq_size`` encoded samples.

    Raises
    ------
    ValueError
        On the first ``next()`` call, if ``data`` has fewer than ``seq_size`` samples.
    """
    # Float conversion first: abs() of the most negative value of an integer
    # sample type would overflow.
    data = np.asarray(data, dtype=np.float64)
    n_samples = data.shape[0]
    if n_samples < seq_size:
        raise ValueError(
            'data has {} samples but seq_size is {}'.format(n_samples, seq_size))
    div = max(data.max(),abs(data.min()))
    # An all-zero signal has no peak to scale by; leave it untouched
    # rather than dividing by zero.
    if div > 0:
        data = data/div
    while True:
        # The upper bound of randint is exclusive, so +1 makes the last valid
        # start position (n_samples - seq_size) reachable.
        start = np.random.randint(0,n_samples-seq_size+1)
        ys = encode_mu_law(data[start:start+seq_size],mu)
        yield nd.array(ys,ctx=ctx)

+ 그 다음 네트워크 학습
+ 1000epoch를 학습하는 형태로 작성

In [None]:
loss_save = []
max_epoch = 1000
best_loss = float('inf')
# Directory where the best checkpoints are written.
model_dir = '/home/skinet/work/research/WaveNet/models'
for epoch in range(max_epoch):
    # Running sum of the per-step losses, kept as a plain Python float so that
    # no recorded graph is carried from one step to the next.
    epoch_loss = 0.0
    for _ in tqdm(range(batch_size)):
        batch = next(g)
        x = batch[:-1]
        with autograd.record():
            logits = net(x)
            sz = logits.shape[0]
            # The network output is shorter than its input; compare it with
            # the last `sz` samples of the batch.
            step_loss = loss_fn(logits, batch[-sz:])
        # Backpropagate only this step's loss.
        step_loss.backward()
        trainer.step(1,ignore_stale_grad=True)
        epoch_loss += float(nd.sum(step_loss).asscalar())
    current_loss = epoch_loss/batch_size
    loss_save.append(current_loss)

    # save the best model
    if best_loss > current_loss:
        print('epoch {}, loss {}'.format(epoch, current_loss))
        print("====save best model====")
        filename = '{}/best_perf_epoch_{}_loss_{}'.format(model_dir, epoch, current_loss)
        # NOTE(review): newer MXNet releases replace save_params with
        # save_parameters - confirm against the installed version.
        net.save_params(filename)
        best_loss = current_loss

    # monitor progress
    if epoch%100==0:
        batch = next(g)
        logits = net(batch[:-1])
        predicted = logits.argmax(1).asnumpy()
        plt.figure(figsize=[16,4])
        plt.plot(list(predicted))
        # Skip the leading samples that have no corresponding network output.
        plt.plot(list(batch.asnumpy())[sum(net.dilations)+1:],'.',ms=1)
        plt.title('epoch {}'.format(epoch))
        plt.tight_layout()
        plt.show()

+ 네트워크 학습을 수행한 후에 학습 결과를 바탕으로 하여 generation을 수행
+ generation은 다음의 함수를 통해 수행. for-loop 형태로 되어 있어서 속도가 다소 느림

In [None]:
def generate_slow(x,models,dilation_depth,n_repeat,ctx, n=100):
    """Extend the seed ``x`` by ``n`` samples, one model call per new sample.

    Parameters
    ----------
    x : NDArray
        Seed sequence; converted to a Python list via ``asnumpy()``.
    models : callable
        Called with the trailing window; its output's ``argmax(1)`` is used.
    dilation_depth, n_repeat : int
        Dilation layout; the window fed to ``models`` holds
        ``sum(dilations) + 1`` trailing samples.
    ctx : mxnet context
        Context on which each input window is created.
    n : int
        Number of samples to generate.

    Returns
    -------
    list
        The seed samples followed by the ``n`` generated ones.
    """
    one_cycle = [2**k for k in range(dilation_depth)]
    window = sum(one_cycle * n_repeat) + 1
    samples = list(x.asnumpy())
    for _ in tqdm(range(n)):
        recent = nd.array(samples[-window:],ctx=ctx)
        scores = models(recent)
        # Take the highest-scoring class at the last position as the next sample.
        samples.append(scores.argmax(1).asnumpy()[-1])
    return samples