In [1]:
import torch
import torch.nn as nn

In [3]:
class DilatedCausalConv1d(nn.Module):
    """Dilated 1-D convolution with kernel size 2 and no padding.

    Because no padding is applied, a sequence of length ``T`` comes out with
    length ``T - dilation``. Output step ``t`` combines input steps ``t`` and
    ``t + dilation``, so aligning each output with its later input sample
    makes the layer causal.

    :param channels: number of input and output channels.
    :param dilation: spacing between the two kernel taps.
    """

    def __init__(self, channels, dilation=1):
        super().__init__()

        self.conv = nn.Conv1d(channels, channels, kernel_size=2, stride=1,
                              dilation=dilation, padding=0)

    def init_weights_for_test(self):
        """Fill every Conv1d weight in this module with 1 (biases untouched)."""
        # ``modules`` is a method; iterating the bare attribute raises TypeError.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                m.weight.data.fill_(1)

    def forward(self, x):
        """Apply the dilated convolution; the time axis shrinks by ``dilation``."""
        output = self.conv(x)
        return output

In [5]:
class CausalConv1d(nn.Module):
    """Causal 1-D convolution (kernel size 2) that preserves sequence length.

    The convolution pads one zero on each side, which yields ``T + 1`` output
    steps for ``T`` input steps. Step ``t`` of that result combines inputs
    ``t - 1`` and ``t`` (with ``x[-1]`` being the left zero pad). Dropping the
    final step therefore gives a length-``T`` causal output aligned with the
    input.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=2,
                              stride=1, padding=1)

    def init_weights_for_test(self):
        """Fill every Conv1d weight in this module with 1 (biases untouched)."""
        # ``modules`` is a method; iterating the bare attribute raises TypeError.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                m.weight.data.fill_(1)

    def forward(self, x):
        """Convolve ``x`` (N, C_in, T) causally, returning (N, C_out, T)."""
        output = self.conv(x)
        # Conv output is 3-D; crop the extra trailing step on the time axis
        # (the previous ``[:, :, :, -1]`` indexed a non-existent 4th dim).
        return output[:, :, :-1]

In [7]:
class ResidualBlock(nn.Module):
    """One WaveNet residual layer: dilated conv -> gate -> residual + skip.

    :param res_channels: channels on the residual path (input and output).
    :param skip_channels: channels of the skip-connection output.
    :param dilation: dilation of the underlying ``DilatedCausalConv1d``.
    """

    def __init__(self, res_channels, skip_channels, dilation):
        super().__init__()

        self.dilated = DilatedCausalConv1d(res_channels, dilation=dilation)
        self.conv_res = nn.Conv1d(res_channels, res_channels, 1)
        self.conv_skip = nn.Conv1d(res_channels, skip_channels, 1)

        self.gate_tanh = nn.Tanh()
        self.gate_sigmoid = nn.Sigmoid()

    def forward(self, x, skip_size=None):
        """Run the block.

        :param x: tensor of shape (batch, res_channels, time).
        :param skip_size: if given, keep only the last ``skip_size`` steps of
            the skip output; ``None`` keeps the whole skip output.
        :return: ``(residual, skip)``. ``residual`` has shape
            (batch, res_channels, time') and ``skip`` has shape
            (batch, skip_channels, ...), where ``time'`` is the length after
            the dilated convolution.
        """
        output = self.dilated(x)

        # Both gates read the same conv output (single-conv gated activation).
        gated = self.gate_tanh(output) * self.gate_sigmoid(output)

        residual = self.conv_res(gated)
        # The dilated conv shortened the time axis; add only the most recent
        # input steps so the shapes line up.
        residual = residual + x[:, :, -residual.size(2):]

        skip = self.conv_skip(gated)
        if skip_size is not None:
            skip = skip[:, :, -skip_size:]

        return residual, skip

In [8]:
class ResidualStack(torch.nn.Module):
    def __init__(self, layer_size, stack_size, res_channels, skip_channels):
        """
        Stack residual blocks by layer and stack size
        :param layer_size: integer, 10 = layer[dilation=1, dilation=2, 4, 8, 16, 32, 64, 128, 256, 512]
        :param stack_size: integer, 5 = stack[layer1, layer2, layer3, layer4, layer5]
        :param res_channels: number of residual channel for input, output
        :param skip_channels: number of skip channel for output
        :return:
        """
        super(ResidualStack, self).__init__()

        self.layer_size = layer_size
        self.stack_size = stack_size

        self.res_blocks = self.stack_res_block(res_channels, skip_channels)

    @staticmethod
    def _residual_block(res_channels, skip_channels, dilation):
        """Build one ``ResidualBlock``.

        The block is wrapped in ``DataParallel`` when more than one GPU is
        visible, and moved to CUDA when CUDA is available.
        """
        # Declared static: it never used ``self``, and calling it through
        # ``self`` without this decorator passed one argument too many.
        block = ResidualBlock(res_channels, skip_channels, dilation)

        if torch.cuda.device_count() > 1:
            block = torch.nn.DataParallel(block)

        if torch.cuda.is_available():
            block.cuda()

        return block

    def build_dilations(self):
        """Return dilations 1, 2, ..., 2**(layer_size-1), repeated stack_size times."""
        return [2 ** layer
                for _ in range(self.stack_size)
                for layer in range(self.layer_size)]

    def stack_res_block(self, res_channels, skip_channels):
        """
        Prepare dilated convolution blocks by layer and stack size
        :return: ``nn.ModuleList`` so the blocks' parameters are registered
            with this module (visible to ``parameters()``, ``.to()``, state_dict)
        """
        return torch.nn.ModuleList(
            self._residual_block(res_channels, skip_channels, dilation)
            for dilation in self.build_dilations()
        )

    def forward(self, x, skip_size):
        """Run all blocks in sequence and collect their skip outputs.

        :param x: input of shape (batch, res_channels, time).
        :param skip_size: number of trailing time steps kept from each skip.
        :return: tensor of shape (num_blocks, batch, skip_channels, skip_size).
        """
        output = x
        skip_connections = []

        for res_block in self.res_blocks:
            output, skip = res_block(output, skip_size)
            skip_connections.append(skip)

        return torch.stack(skip_connections)

In [9]:
class DensNet(nn.Module):
    """Output head of WaveNet.

    The forward pass is ReLU -> 1x1 conv -> ReLU -> 1x1 conv -> softmax, with
    the softmax taken over dim 1 (the channel axis of a Conv1d output).

    :param channels: channel count, identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()

        # Two pointwise convolutions that keep the channel count unchanged.
        self.conv1 = nn.Conv1d(channels, channels, 1)
        self.conv2 = nn.Conv1d(channels, channels, 1)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map ``x`` to per-channel probabilities (normalised over dim 1)."""
        hidden = self.conv1(self.relu(x))
        logits = self.conv2(self.relu(hidden))
        return self.softmax(logits)

In [10]:
class WaveNet(nn.Module):
    """WaveNet over integer-quantised samples.

    :param in_depth: number of quantisation levels, used as the embedding
        vocabulary size and as the number of output channels.
    :param res_channels: channels on the residual path.
    :param skip_channels: channels of the summed skip connections.
    :param dilation_depth: layers per stack; dilations are 1, 2, ..., 2**(dilation_depth-1).
    :param n_repeat: number of times the dilation stack is repeated.
    """

    def __init__(self, in_depth=256, res_channels=32, skip_channels=512, dilation_depth=10, n_repeat=5):
        super(WaveNet, self).__init__()
        self.dilations = [2**i for i in range(dilation_depth)] * n_repeat
        self.main = nn.ModuleList([ResidualBlock(res_channels, skip_channels, dilation)
                                   for dilation in self.dilations])
        self.pre = nn.Embedding(in_depth, res_channels)
        self.post = nn.Sequential(nn.ReLU(),
                                  nn.Conv1d(skip_channels, skip_channels, 1),
                                  nn.ReLU(),
                                  nn.Conv1d(skip_channels, in_depth, 1))

    def forward(self, inputs):
        """
        Run the network on a batch of index sequences.

        Each residual layer's unpadded kernel-2 dilated conv trims ``dilation``
        steps. The output therefore has ``timestep - sum(self.dilations)`` steps,
        and the input must be at least ``sum(self.dilations) + 1`` steps long.

        :param inputs: integer tensor of indices in [0, in_depth),
            shape (batch, timestep).
        :return: unnormalised scores of shape
            (batch, in_depth, timestep - sum(self.dilations)).
        :raises ValueError: if ``timestep`` is smaller than the receptive field.
        """
        outputs = self.preprocess(inputs)

        receptive_field = sum(self.dilations) + 1
        skip_size = outputs.size(2) - receptive_field + 1
        if skip_size < 1:
            raise ValueError(
                f"input has {outputs.size(2)} timesteps but the receptive "
                f"field is {receptive_field}; need at least {receptive_field}"
            )

        skip_connections = []
        for layer in self.main:
            # Trim every skip to the final output length so they can be summed.
            outputs, skip = layer(outputs, skip_size)
            skip_connections.append(skip)

        outputs = sum(s[:, :, -outputs.size(2):] for s in skip_connections)
        outputs = self.post(outputs)

        return outputs

    def preprocess(self, inputs):
        """Embed indices (batch, timestep) into (batch, res_channels, timestep)."""
        return self.pre(inputs).transpose(1, 2)