To condition the wavenet on linguistic feature we will need to implment local conditioning to our current implmentation of the residual layer. There are three additions that are needed to achieve a locally conditioned wavenet:
- upsampling to make conditioned input the same resolution of waveform
- conditoned input
- 1x1 conovlution filter for conditioned input

To achieve this we use the following equation from the paper:

z = tanh (Wf,k ∗ x + Vf,k ∗ y) σ (Wg,k ∗ x + Vg,k ∗ y),

where Vf,k ∗y is now a 1×1 convolution and y = f (h) where f is an upsampling transformation

In [1]:
import torch

We will first look at how upsampling 1d audio samples work

In [2]:
import torch.nn as nn

In [16]:
sample = torch.arange(1, 21, dtype=torch.float).view(1,1,-1)
sample.shape

torch.Size([1, 1, 20])

In [10]:
nn.ConvTranspose1d(1, 1, 10, 10)(sample).shape

torch.Size([1, 1, 200])

Our previous implmentation of the residual layer

```
class ResidualLayer(nn.Module):
    def __init__(self, dilation, in_channels=1, residual_channels=32, skip_channels=512):
        super(ResidualLayer, self).__init__()
        self.filter_dialated_conv = DialatedConv1d(in_channels, residual_channels, dilation)
        self.gate_dialted_conv = DialatedConv1d(in_channels, residual_channels, dilation)
        self.residual_1x1 = Conv1d1x1(residual_channels, residual_channels)
        self.skip_1x1 = Conv1d1x1(residual_channels, skip_channels)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, inputs):
        assert inputs.dim() == 3 # To clip the inputs
        filter_out, gate_out = self.filter_dialated_conv(inputs), self.gate_dialted_conv(inputs)
        z = self.tanh(filter_out) * self.sigmoid(gate_out)
        residual_out, skip_out  = self.residual_1x1(z), self.skip_1x1(z)
        clipped_inputs = inputs[:,:, -residual_out.size(2):]
        residual_out += clipped_inputs
        
        return residual_out, skip_out
```

In [99]:
from model import DialatedConv1d, Conv1d1x1

class ResidualLayer(nn.Module):
    def __init__(self, dilation, residual_channels=32, gate_channels=32, skip_channels=512,
                 local_channels=0, global_channels=0):
        super(ResidualLayer, self).__init__()
        self.dilated_conv = DialatedConv1d(residual_channels, gate_channels, dilation)
        
        self.local_1x1 = self.global_1x1 = None
        if local_channels > 0:
            self.local_1x1 = Conv1d1x1(local_channels, gate_channels, bias=False)
        
        if global_channels > 0:
            self.global_1x1 = Conv1d1x1(global_channels, gate_channels, bias=False)
        
        
        self.residual_1x1 = Conv1d1x1(gate_channels, residual_channels)
        self.skip_1x1 = Conv1d1x1(gate_channels, skip_channels)
        
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, inputs, global_inputs=None, local_inputs=None):
        assert inputs.dim() == 3 # To clip the inputs
        conv_out = self.dilated_conv(inputs)
        
        if global_inputs != None:
            print(global_inputs.size(2), inputs.size(2))
            assert self.global_1x1 != None and global_inputs.size(2) == conv_out.size(2)
            global_out = self.global_1x1(global_inputs)   
            conv_out += global_out
            
        if local_inputs != None:
            assert self.local_1x1 != None and local_inputs.size(2) == conv_out.size(2)
            local_out = self.local_1x1(local_inputs)
            conv_out += local_out
        
        z = self.tanh(conv_out) * self.sigmoid(conv_out)
        residual_out, skip_out  = self.residual_1x1(z), self.skip_1x1(z)
        
        clipped_inputs = inputs[:,:, -residual_out.size(2):]
        residual_out += clipped_inputs
        
        return residual_out, skip_out

In [100]:
sample_data = torch.randn((1,32,300), dtype=torch.float)
sample_data.shape

torch.Size([1, 32, 300])

In [101]:
residual_out, skip_out = ResidualLayer(2)(sample_data)
residual_out.shape, skip_out.shape

(torch.Size([1, 32, 296]), torch.Size([1, 512, 296]))

In [102]:
residual_out, skip_out = ResidualLayer(4)(residual_out)
residual_out.shape, skip_out.shape

(torch.Size([1, 32, 288]), torch.Size([1, 512, 288]))

In [103]:
global_input = torch.stack((torch.ones((1, 296)), torch.zeros((1,296))), dim=1)
global_input.shape

torch.Size([1, 2, 296])

In [105]:
residual_out, skip_out = ResidualLayer(2,global_channels=2)(sample_data, global_inputs=global_input)
residual_out.shape, skip_out.shape

296 300


(torch.Size([1, 32, 296]), torch.Size([1, 512, 296]))

To make the global conditioning input work we have to calculate the size of tensor before inputting it into the residual layers and same with the local conditioning

In [106]:
res_layer = ResidualLayer(2,global_channels=2, local_channels=2)
res_layer(sample_data, global_inputs=global_input, local_inputs=global_input)

296 300


(tensor([[[ 0.6716, -0.0872,  0.5723,  ..., -0.8595, -1.0565, -0.7606],
          [ 0.8214, -1.6721, -0.0699,  ..., -1.1350, -0.5650,  0.7269],
          [ 1.3577,  0.4841, -0.4716,  ...,  2.6977,  1.1704, -0.9785],
          ...,
          [-0.0090,  1.2861,  2.2006,  ..., -0.7829,  1.3235,  2.0160],
          [-1.0036, -1.4957,  2.0787,  ...,  1.0450,  1.5836,  0.7142],
          [-0.3482, -0.8100, -0.9892,  ...,  0.6066, -1.4210,  0.2603]]],
        grad_fn=<AddBackward0>),
 tensor([[[ 0.0880, -0.1583,  0.1551,  ..., -0.1778, -0.1762, -0.1299],
          [ 0.1606,  0.3654,  0.2770,  ...,  0.5431, -0.0458,  0.4018],
          [ 0.0241, -0.0170,  0.0065,  ..., -0.0360, -0.0750, -0.2279],
          ...,
          [ 0.2181,  0.1260,  0.1515,  ...,  0.2540,  0.1055,  0.2321],
          [-0.0189,  0.0228, -0.0284,  ...,  0.1304, -0.1437, -0.0261],
          [ 0.0742,  0.1859,  0.1422,  ...,  0.2148,  0.0565,  0.2827]]],
        grad_fn=<ConvolutionBackward0>))