<a href="https://colab.research.google.com/github/alexandrumeterez/ai_notebooks/blob/master/wavenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import time
import math
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
%matplotlib inline

In [140]:
class CausalConv1d(nn.Module):
    """
    1-D convolution that only looks at the current and earlier timesteps.

    The wrapped ``nn.Conv1d`` pads both ends of the time axis with
    ``(kernel_size - 1) * dilation`` zeros; ``forward`` then drops the
    trailing ``pad`` outputs so the result has the same length as the input
    and output step ``t`` never depends on inputs after ``t``.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: size of the convolution kernel
    :param dilation: spacing between kernel taps
    :param kwargs: forwarded unchanged to ``nn.Conv1d``
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.pad, dilation=dilation, **kwargs)

    def forward(self, x):
        """
        :param x: tensor indexed as (batch, channels, time)
        :return: convolved tensor with the trailing padded steps removed
        """
        x = self.conv(x)
        pad = self.conv.padding[0]
        if pad == 0:
            # kernel_size == 1: nothing was padded. Slicing with ``:-0`` would
            # be ``:0`` and silently return an empty time axis.
            return x
        return x[:, :, :-pad]

class ResidualBlock(nn.Module):
    """
    Gated residual block built around one dilated causal convolution.

    :param residual_channels: channels of the block input and residual output
    :param skip_channels: channels of the skip output
    :param dilation: dilation of the kernel-size-2 causal convolution
    """

    def __init__(self, residual_channels, skip_channels, dilation):
        super().__init__()
        self.conv = CausalConv1d(residual_channels, residual_channels, 2, dilation=dilation)
        # 1x1 projections for the residual path and the skip path.
        self.conv_residual = nn.Conv1d(residual_channels, residual_channels, 1)
        self.conv_skip = nn.Conv1d(residual_channels, skip_channels, 1)

        self.gate_tanh = nn.Tanh()
        self.gate_sigmoid = nn.Sigmoid()

    def forward(self, x, skip_size):
        """
        :param x: input tensor, time on the last axis
        :param skip_size: number of trailing timesteps kept in the skip
            output; ``0`` keeps every timestep because ``-0:`` slices from
            the start
        :return: tuple ``(residual_output, skip)``
        """
        features = self.conv(x)

        # Both activations read the same convolution output.
        # NOTE(review): the WaveNet paper appears to use separate filter and
        # gate convolutions - confirm whether sharing is intended.
        gated = self.gate_tanh(features) * self.gate_sigmoid(features)

        # Skip path: project, then keep only the requested tail.
        skip = self.conv_skip(gated)[:, :, -skip_size:]

        # Residual path: project and add the time-aligned tail of the input.
        output = self.conv_residual(gated)
        output += x[:, :, -output.size(2):]

        return output, skip



class DensNet(torch.nn.Module):
    """
    Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv.

    The first convolution keeps the channel count; the second reduces it to
    a single output channel.
    """

    def __init__(self, channels):
        """
        :param channels: number of channels of the input tensor
        """
        super().__init__()

        self.conv1 = torch.nn.Conv1d(channels, channels, kernel_size=1)
        self.conv2 = torch.nn.Conv1d(channels, 1, kernel_size=1)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """
        :param x: tensor with ``channels`` entries on axis 1
        :return: tensor with a single channel on axis 1
        """
        hidden = self.conv1(self.relu(x))
        return self.conv2(self.relu(hidden))

class WaveNet(nn.Module):
    """
    Stack of gated residual blocks whose summed skip outputs feed a small
    convolutional head.

    :param n_dilations: each stack uses dilations ``2 ** i`` for
        ``i`` in ``range(1, n_dilations)``, i.e. ``n_dilations - 1`` blocks
    :param n_residuals: number of times the dilation stack is repeated
    :param in_channels: channels of the input; also used as the skip channel
        count and as the input width of the head
    :param res_channels: channels carried through the residual path
    """

    def __init__(self, n_dilations, n_residuals, in_channels, res_channels):
        super().__init__()
        self.causal = CausalConv1d(in_channels, res_channels, 2, 1)
        # The head consumes the summed skip connections, which have
        # ``in_channels`` channels, so it must be sized from that argument
        # rather than a fixed constant.
        self.densenet = DensNet(in_channels)
        # nn.ModuleList registers the blocks as sub-modules. A plain Python
        # list would hide their parameters from ``parameters()``,
        # ``state_dict()`` and ``.to(device)``, so they would never be
        # trained, saved or moved.
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(res_channels, in_channels, 2 ** dilation_id)
            for _ in range(n_residuals)
            for dilation_id in range(1, n_dilations)
        ])

    def forward(self, x):
        """
        :param x: input tensor indexed as (batch, in_channels, time)
        :return: head output computed from the sum of all skip connections

        Needs at least one residual block: ``torch.stack`` fails on an empty
        list.
        """
        output = self.causal(x)

        skip_connections = []
        for block in self.residual_blocks:
            # skip_size=0 keeps the full-length skip tensor of every block.
            output, skip = block(output, 0)
            skip_connections.append(skip)

        # Element-wise sum over blocks.
        output = torch.sum(torch.stack(skip_connections), dim=0)
        return self.densenet(output)


In [141]:
model = WaveNet(n_dilations=5, n_residuals=10, in_channels=9, res_channels=3)
# Random input laid out as (batch, channels, timesteps) = (33, 9, 17).
# Data stored as (batch, timesteps, channels) = (33, 17, 9) has to be
# transposed to this layout first.
x = torch.randn(33, 9, 17)

In [142]:
# Forward pass on the random batch; the shape is displayed as the cell result
# (recorded below as one output channel per input timestep).
output = model(x)
output.shape

torch.Size([33, 1, 17])

In [95]:
# NOTE(review): `sc` is not defined in any cell shown here, and this cell's
# execution count (95) predates the cells above - presumably a leftover debug
# inspection of the stacked skip connections; confirm or remove.
sc.shape

torch.Size([40, 1, 9, 1])