In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable
from torch.autograd import Variable
import torchvision.models as models

import numpy as np
import matplotlib.pyplot as plt
import math
from graphviz import Digraph
import re

% matplotlib inline

# Module-level switch, presumably meant to enable GPU execution.
# NOTE(review): no cell visible in this notebook reads it — confirm it is still needed.
CUDA = False

In [2]:
class CausalConv1d(nn.Conv1d):
    """1-D convolution that pads only on the left of the time axis.

    The wrapped ``nn.Conv1d`` is built with ``padding=0``; ``forward`` prepends
    ``dilation * (kernel_size - 1)`` zeros to the last dimension before
    convolving. The output at position ``t`` therefore never depends on inputs
    after ``t``, and with ``stride=1`` the output has the same length as the
    input.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            forwarded unchanged to ``nn.Conv1d``. ``kernel_size`` and
            ``dilation`` may be ints or 1-tuples, as ``nn.Conv1d`` allows.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,
            dilation=dilation, groups=groups, bias=bias)

        # nn.Conv1d normalises kernel_size/dilation to 1-tuples; reading them
        # back (instead of the raw arguments) makes tuple arguments work too.
        self.left_padding = self.dilation[0] * (self.kernel_size[0] - 1)

    def forward(self, input):
        """Left-pad the last (time) axis with zeros, then run the convolution."""
        # (left, right) padding applies to the last dimension only.
        x = F.pad(input, (self.left_padding, 0))

        return super(CausalConv1d, self).forward(x)

In [67]:
class OneHot(nn.Module):
    """Turn a 1-D tensor of class indices into a one-hot tensor.

    Args:
        quant: number of classes; indices passed to ``forward`` must lie in
            ``[0, quant)``.
    """

    def __init__(self, quant):
        super(OneHot, self).__init__()
        # Identity matrix used as a lookup table: row i is the one-hot vector
        # for class i. Registered as a buffer so it follows the module through
        # .to()/.cuda(); non-persistent so state_dict keys stay unchanged.
        self.register_buffer('one', torch.eye(quant), persistent=False)

    def forward(self, input):
        """Return a float tensor of shape ``(1, quant, len(input))``.

        ``input`` is a 1-D integer index tensor. Column ``t`` of the result is
        the one-hot vector for ``input[t]``.
        """
        # (N, quant) -> (1, N, quant) -> (1, quant, N)
        return self.one.index_select(0, input).unsqueeze(0).transpose(1, 2)

In [71]:
# Ten random class indices. randint's upper bound is exclusive and the lower
# bound is 1, so only the values 1..4 can appear (index 0 is never sampled).
data = Variable(torch.LongTensor(np.random.randint(low=1, high=5, size=10)))

onehot = OneHot(5)
# Show the encoded tensor next to the indices it was built from.
(onehot(data), data)

(Variable containing:
 (0 ,.,.) = 
    0   0   0   0   0   0   0   0   0   0
    0   0   0   0   1   1   0   0   0   0
    0   0   0   0   0   0   0   0   0   1
    0   1   1   0   0   0   1   0   0   0
    1   0   0   1   0   0   0   1   1   0
 [torch.FloatTensor of size 1x5x10], Variable containing:
  4
  3
  3
  4
  1
  1
  3
  4
  4
  2
 [torch.LongTensor of size 10])

In [76]:
class WaveNet(nn.Module):
    """Stack of gated, dilated causal convolutions over a one-hot encoded sequence.

    The input indices are one-hot encoded and projected to ``res_size``
    channels. Each layer splits those channels in two halves, feeds one half
    to a tanh branch and the other to a sigmoid branch, multiplies the two
    activations, and from that product derives a residual update (added back
    to the layer input) and a skip output. All skip outputs are summed and
    passed through two 1x1 convolutions to produce ``quant`` values per step.

    Args:
        quant: number of classes for the one-hot input and for the output.
        res_size: channel count of the residual path. Must equal
            ``2 * skip_size`` because each layer splits it in two halves of
            ``skip_size`` channels.
        skip_size: channel count of the gate branches and skip path.
        dilation_layers: layers per stack; dilations are ``1, 2, ..., 2**(n-1)``.
        stacks: how many times the dilation sequence is repeated.

    Raises:
        ValueError: if ``res_size != 2 * skip_size`` or if the configuration
            yields no layers.
    """

    def __init__(self, quant = 256, res_size = 512, skip_size = 256, dilation_layers = 10, stacks = 3):
        super(WaveNet, self).__init__()
        # Fail early on configurations that could only crash inside forward().
        if res_size != 2 * skip_size:
            raise ValueError(
                "res_size must equal 2 * skip_size (got res_size=%r, skip_size=%r)" % (res_size, skip_size))
        if dilation_layers < 1 or stacks < 1:
            raise ValueError("dilation_layers and stacks must both be >= 1")

        self.quant = quant
        self.res_size = res_size
        self.skip_size = skip_size
        self.dilation_layers = dilation_layers
        self.dilations = dilations = [(2**dilation) for dilation in range(dilation_layers)] * stacks
        self.one_hot = OneHot(quant)

        # 1x1 projection from one-hot channels to the residual width.
        self.causal_conv = CausalConv1d(quant, res_size, 1)

        # One tanh branch and one sigmoid branch per layer, kernel size 2.
        self.dial_tanh_conv = nn.ModuleList([CausalConv1d(skip_size, skip_size, 2, dilation = d) for d in dilations])
        self.dial_sigm_conv = nn.ModuleList([CausalConv1d(skip_size, skip_size, 2, dilation = d) for d in dilations])

        # 1x1 convolutions producing the skip and residual outputs of each layer.
        self.dial_skip_conv = nn.ModuleList([CausalConv1d(skip_size, skip_size, 1) for _ in dilations])
        self.dial_res_conv = nn.ModuleList([CausalConv1d(skip_size, res_size, 1) for _ in dilations])

        self.end_conv1 = nn.Conv1d(in_channels = skip_size, out_channels = skip_size, kernel_size = 1)
        self.end_conv2 = nn.Conv1d(in_channels = skip_size, out_channels = quant, kernel_size = 1)

    def forward(self, input):
        """Run the network on a 1-D tensor of class indices.

        Returns:
            The result of ``postprocess``: for a length-``T`` input, a tensor
            of shape ``(T, quant)``.
        """
        output = self.one_hot(input)
        output = self.causal_conv(output)

        skip_outputs = []
        layers = zip(self.dial_sigm_conv, self.dial_tanh_conv, self.dial_skip_conv, self.dial_res_conv)
        for sigm_conv, tanh_conv, skip_conv, res_conv in layers:
            residual = output

            gated = self.gated_unit(residual, sigm_conv, tanh_conv)

            # Residual connection: align on the most recent time steps.
            output = res_conv(gated)
            output = output + residual[:, :, -output.size(2):]

            skip_outputs.append(skip_conv(gated))

        # Sum the skip outputs, trimmed to the final layer's length.
        output = sum(s[:, :, -output.size(2):] for s in skip_outputs)

        # The original ended with a bare `return`, discarding this result.
        return self.postprocess(output)

    def gated_unit(self, input, dial_sigm_conv, dial_tanh_conv):
        """Apply the gated activation ``sigmoid(conv_s(b)) * tanh(conv_t(a))``.

        ``input`` is split along the channel axis: the first
        ``dial_tanh_conv.in_channels`` channels (``a``) feed the tanh branch
        and the remaining channels (``b``) feed the sigmoid branch.
        """
        # Take the split point from the layer itself rather than a literal 256.
        split = dial_tanh_conv.in_channels
        input_tanh = input[:, :split]
        input_sigmoid = input[:, split:]

        output_sigmoid = dial_sigm_conv(input_sigmoid)
        output_tanh = dial_tanh_conv(input_tanh)

        return torch.sigmoid(output_sigmoid) * torch.tanh(output_tanh)

    def postprocess(self, input):
        """ELU -> 1x1 conv -> ELU -> 1x1 conv, then drop the batch axis and transpose.

        For a ``(1, skip_size, T)`` input the result has shape ``(T, quant)``.
        """
        output = F.elu(input)
        output = self.end_conv1(output)
        output = F.elu(output)
        output = self.end_conv2(output).squeeze(0).transpose(0,1)

        return output

In [77]:
# Smoke test: run a default-sized network on 10000 random class indices.
net = WaveNet()
# np.int64 rather than np.long: the np.long alias is missing from NumPy 1.24+
# and maps to the platform C long in NumPy 2.x, while the one-hot lookup
# needs 64-bit indices.
batch = Variable(torch.from_numpy(np.random.randint(0,256,10000).astype(np.int64)))
net(batch)

In [56]:
# Scratch tensor of uniform random values with 1 batch element, 512 channels
# and 5000 steps, used below to try out the channel split. Note that this
# rebinds `data`, which an earlier cell used for an index vector.
data = Variable(torch.rand(1, 512, 5000))
data

Variable containing:
( 0  ,.,.) = 
  5.6435e-01  5.4275e-01  7.0707e-01  ...   4.9807e-01  7.1019e-01  2.9433e-01
  4.2612e-01  1.9716e-01  9.5337e-01  ...   3.3937e-01  8.2537e-01  3.3836e-02
  3.1062e-01  6.4660e-01  7.5954e-02  ...   6.8057e-01  6.8052e-01  1.0355e-01
                 ...                   ⋱                   ...                
  9.4133e-01  8.9128e-01  8.8054e-01  ...   8.1223e-01  1.2172e-01  6.8401e-01
  8.3522e-01  3.3828e-01  8.3925e-01  ...   1.5371e-01  2.7752e-01  7.9269e-01
  9.2236e-01  2.5876e-01  3.6675e-04  ...   2.6285e-01  5.5589e-01  2.4504e-02
[torch.FloatTensor of size 1x512x5000]

In [58]:
# Channels 256 and up of the first batch element; a leading axis of size 1 is
# added back only for display.
data1 = data[0, 256:]
data1.unsqueeze(0)

Variable containing:
( 0  ,.,.) = 
  8.1559e-01  7.2862e-01  8.2748e-01  ...   4.8296e-01  3.3091e-01  5.1204e-01
  3.7281e-01  9.1877e-01  3.3286e-01  ...   4.6933e-01  1.8827e-01  1.0883e-01
  7.5145e-01  4.0418e-01  4.2649e-01  ...   4.6620e-01  8.4367e-01  1.5885e-01
                 ...                   ⋱                   ...                
  9.4133e-01  8.9128e-01  8.8054e-01  ...   8.1223e-01  1.2172e-01  6.8401e-01
  8.3522e-01  3.3828e-01  8.3925e-01  ...   1.5371e-01  2.7752e-01  7.9269e-01
  9.2236e-01  2.5876e-01  3.6675e-04  ...   2.6285e-01  5.5589e-01  2.4504e-02
[torch.FloatTensor of size 1x256x5000]

In [59]:
# First 256 channels of the first batch element; a leading axis of size 1 is
# added back only for display.
data2 = data[0, :256]
data2.unsqueeze(0)

Variable containing:
( 0  ,.,.) = 
  5.6435e-01  5.4275e-01  7.0707e-01  ...   4.9807e-01  7.1019e-01  2.9433e-01
  4.2612e-01  1.9716e-01  9.5337e-01  ...   3.3937e-01  8.2537e-01  3.3836e-02
  3.1062e-01  6.4660e-01  7.5954e-02  ...   6.8057e-01  6.8052e-01  1.0355e-01
                 ...                   ⋱                   ...                
  3.6317e-01  1.2844e-01  1.7922e-01  ...   6.7068e-01  6.4877e-01  9.3521e-02
  9.7565e-02  6.3691e-01  3.1665e-01  ...   8.2483e-01  2.6525e-01  2.0491e-01
  8.3155e-01  6.1810e-01  8.8408e-01  ...   5.6680e-02  2.1012e-01  1.1484e-01
[torch.FloatTensor of size 1x256x5000]