![](../assets/architecture.png)

In this notebook we put all the pieces together to form the full WaveNet architecture. Some considerations:
- Resizing the global and local conditioning inputs
- How many causal convolution layers to use
- What blocks are: a stack of residual layers with dilations 1, 2, 4, …, repeated several times
- Any changes needed to make the model trainable

In [1]:
from model import CasualConv1D, ResidualLayer, Head

In [2]:
import torch
import torchaudio.datasets as datasets

In [3]:
# Load the LJSpeech corpus from the local ../data directory.
dataset = datasets.LJSPEECH(root="../data")

In [4]:
# Each LJSpeech item unpacks to (waveform, sample_rate, transcript, normalized_transcript).
first_item = dataset[0]
waveform, sample_rate, transcript, normalized_transcript = first_item

waveform, sample_rate, normalized_transcript

(tensor([[-7.3242e-04, -7.6294e-04, -6.4087e-04,  ...,  7.3242e-04,
           2.1362e-04,  6.1035e-05]]),
 22050,
 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition')

In [5]:
from preprocess import quantize_and_onehot_waveform

In [6]:
# Model hyper-parameters shared by the cells below.
quantize_categories = 256  # quantization levels = one-hot width (from the paper)
residual_channels, gate_channels = 32, 32
skip_channels = 512

In [7]:
# Reshape to (batch=1, channels=1, time) for the 1-D convolutions below.
waveform = waveform.view((1, 1, -1))

In [8]:
# Quantize and one-hot encode: the channel axis becomes the category axis
# (size quantize_categories, per the shape shown below).
processed_waveform = quantize_and_onehot_waveform(waveform)
processed_waveform.size()

torch.Size([1, 256, 212893])

In [9]:
# First causal convolution: one-hot categories -> residual channels.
# The one-hot tensor is cast to float32 before entering the convolution.
first_causal_conv = CasualConv1D(quantize_categories, residual_channels)
cas_out = first_causal_conv(processed_waveform.float())
cas_out.size()

torch.Size([1, 32, 212892])

In [10]:
# Chain `layers` residual layers with dilations 1, 2, 4, ... and keep every skip output.
layers = 5
skip_connections = []
x = cas_out
for i in range(layers):
    dilation = 2 ** i
    residual_layer = ResidualLayer(dilation, residual_channels, gate_channels, skip_channels)
    x, skip_out = residual_layer(x)
    skip_connections.append(skip_out)

print(len(skip_connections))

5


In [11]:
# The time axis shrinks after every residual layer, so the skip outputs differ in length.
for skip_tensor in skip_connections:
    print(skip_tensor.size())

torch.Size([1, 512, 212890])
torch.Size([1, 512, 212886])
torch.Size([1, 512, 212878])
torch.Size([1, 512, 212862])
torch.Size([1, 512, 212830])


In [12]:
# Index every skip output with the length n of the last (shortest) one, then stack them
# along a new leading axis.
# FIXME(review): `-n` is an integer index, not a slice, so each skip output keeps a single
# time step and loses its time axis. That is why the stacked shape below is (5, 1, 512).
# Cropping to the common length would be `-n:`; confirm what `Head` expects before changing.
skip_connections = [skip[:, :, -skip_connections[-1].size(2)] for skip in skip_connections]
skip_connections = torch.stack(skip_connections)

In [13]:
# (number of residual layers, batch, skip_channels), as displayed below.
skip_connections.size()

torch.Size([5, 1, 512])

In [14]:
# Project the stacked skip outputs to quantize_categories output channels.
head = Head(skip_channels, quantize_categories)
head_out = head(skip_connections)
head_out.size()

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.


torch.Size([1, 256, 1])

In [21]:
import torch.nn as nn
from preprocess import quantize_and_onehot_waveform

class WaveNet(nn.Module):
    """WaveNet: causal-conv stem -> repeated stacks of dilated residual layers -> head.

    Args:
        num_residual_layers: residual layers per block; layer ``i`` of a block uses
            dilation ``2**i``.
        num_blocks: how many times the stack of dilated residual layers is repeated.
            Every repetition gets its own, independently initialised layers.
        num_casual_layers: number of ``CasualConv1D`` layers in the stem. The first maps
            ``quantize_channels`` -> ``residual_channels``; the rest keep
            ``residual_channels``. At least one layer is always created.
        residual_channels, gate_channels, skip_channels: channel widths forwarded to
            ``ResidualLayer`` (and ``skip_channels`` to ``Head``).
        quantize_channels: number of quantization categories, i.e. the one-hot input
            width and the head's output width.
        local_channels, global_channels: conditioning widths forwarded to
            ``ResidualLayer`` (0 presumably means "no conditioning" -- TODO confirm).
        device: forwarded to the conv/residual layers; the head is moved there too.
    """

    def __init__(self, num_residual_layers, num_blocks, num_casual_layers, residual_channels=32,
                 gate_channels=32, skip_channels=512, quantize_channels=256, local_channels=0,
                 global_channels=0, device=None):
        super().__init__()

        casual_layers = [CasualConv1D(quantize_channels, residual_channels, device=device)]
        casual_layers += [
            CasualConv1D(residual_channels, residual_channels, device=device)
            for _ in range(num_casual_layers - 1)
        ]
        self.casual_layers = nn.Sequential(*casual_layers)

        # Construct a fresh ResidualLayer for every (block, layer) pair. Repeating a single
        # list (`layers * num_blocks`) would reuse the same module objects, so all blocks
        # would silently share weights. Ordering is block-major: dilations 1, 2, 4, ... per block.
        self.residual_blocks = nn.ModuleList(
            ResidualLayer(2**i, residual_channels, gate_channels, skip_channels,
                          local_channels, global_channels, device=device)
            for _ in range(num_blocks)
            for i in range(num_residual_layers)
        )

        self.head = Head(skip_channels, quantize_channels)
        if device is not None:
            # Head is not constructed with `device`, so move it to match the other layers.
            self.head = self.head.to(device)

    def forward(self, inputs, local_inputs=None, global_inputs=None):
        """Run the network on a raw waveform.

        Args:
            inputs: raw waveform tensor; must be 3-D (presumably (batch, 1, time), as
                built with ``waveform.view(1, 1, -1)`` in the cells above).
            local_inputs, global_inputs: optional conditioning tensors, passed unchanged
                to every residual layer (resizing is not implemented yet).

        Returns:
            The output of ``Head`` applied to the stacked skip outputs.

        Raises:
            ValueError: if ``inputs`` is not 3-D or the model has no residual layers.
        """
        if inputs.dim() != 3:
            raise ValueError(f"expected 3-D inputs, got shape {tuple(inputs.shape)}")

        # One-hot encode, then cast to float like the exploration cell does before CasualConv1D.
        processed_inputs = quantize_and_onehot_waveform(inputs).to(torch.float)
        residual_out = self.casual_layers(processed_inputs)

        skip_connections = []
        for residual_layer in self.residual_blocks:
            if local_inputs is not None:
                # TODO: resize local inputs to match the current time length.
                pass

            if global_inputs is not None:
                # TODO: resize global inputs.
                pass

            residual_out, skip_out = residual_layer(residual_out, local_inputs, global_inputs)
            skip_connections.append(skip_out)

        if not skip_connections:
            raise ValueError("WaveNet has no residual layers (num_blocks * num_residual_layers == 0)")

        # The last skip output is the shortest one.
        shortest_len = skip_connections[-1].size(2)
        # FIXME(review): `-shortest_len` is an integer index, not a slice, so each skip output
        # keeps a single time step and loses its time axis. Cropping to the common length
        # would be `-shortest_len:`, but that changes the shape passed to Head -- confirm
        # what Head expects before changing it.
        stacked_skips = torch.stack([skip[:, :, -shortest_len] for skip in skip_connections])

        return self.head(stacked_skips)

In [43]:
# Small model: 2 residual layers per block, 2 blocks, 2 causal conv layers.
wavenet = WaveNet(num_residual_layers=2, num_blocks=2, num_casual_layers=2)
probs = wavenet(waveform)

In [44]:
# The head output is assumed to be (batch, quantize_categories, time), matching the
# (1, 256, 1) head output shown above. Take the argmax over the category axis instead of
# flattening. A flat argmax mixes batch/time positions into the index whenever either is
# larger than 1. squeeze() drops the size-1 batch/time axes, so for a single prediction
# this is the same 0-d category index as before. Despite its name, max_prob is an index.
max_prob = probs.argmax(dim=1).squeeze()
max_prob

tensor(74)

In [45]:
from preprocess import decodeMuLaw

In [46]:
# Map the predicted category index back to a sample value with mu-law decoding.
decoded_output = decodeMuLaw(max_prob)
decoded_output

tensor(-0.0371)