![](../assets/architecture.png)

In this notebook we will be implementing the part where the skip connection go to or the so called head of the
model there are some consideration I have to make when building this part which are:
+ How will the skip connections be summed up?
    + I will clip the skip connections to be the smallest of the skip connection size

In [1]:
import torch

In [2]:
from model import ResidualLayer

In [3]:
residual_channels = 32
skip_channels = 512
dialation = 2
categories = 256
num_layers = 5

In [4]:
inputs = torch.arange(1, 501, dtype=torch.float).view(1,1,-1)

In [5]:
res_out, skip_out = ResidualLayer(dialation, 1, residual_channels, skip_channels)(inputs)
res_out.shape, skip_out.shape

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.


(torch.Size([1, 32, 496]), torch.Size([1, 512, 496]))

In [6]:
dilations_per_layer = [2**i for i in range(1, num_layers)]
dilations_per_layer

[2, 4, 8, 16]

In [7]:
x, skip_out = ResidualLayer(dialation, 1, residual_channels, skip_channels)(inputs)
skip_connections = [skip_out]
for dilation in dilations_per_layer:
    x, skip_out = ResidualLayer(dialation, residual_channels, residual_channels, skip_channels)(x)
    skip_connections.append(skip_out)

len(skip_connections)

5

In [8]:
for skip in skip_connections:
    print(skip.shape)

torch.Size([1, 512, 496])
torch.Size([1, 512, 492])
torch.Size([1, 512, 488])
torch.Size([1, 512, 484])
torch.Size([1, 512, 480])


In [9]:
smallest_skip_size = skip_connections[-1].size(2)
smallest_skip_size

480

In [10]:
skip_connections = list(map(lambda skip: skip[:,:,-smallest_skip_size], skip_connections))

In [13]:
stacked_skip_connections = torch.stack(skip_connections)

In [14]:
summed_skip_connections = stacked_skip_connections.sum(dim=0)

In [None]:
import torch.nn as nn
from model import Conv1d1x1

In [25]:
layers = [
    nn.ReLU(),
    Conv1d1x1(skip_channels, skip_channels),
    nn.ReLU(),
    Conv1d1x1(skip_channels, categories),
    nn.Softmax(dim=0)
]

In [36]:
x = summed_skip_connections.view(-1,1)
for layer in layers:
    print(layer)
    print(x.shape)
    x = layer(x)
print(x.shape)

ReLU()
torch.Size([512, 1])
Conv1d1x1(512, 512, kernel_size=(1,), stride=(1,), bias=False)
torch.Size([512, 1])
ReLU()
torch.Size([512, 1])
Conv1d1x1(512, 256, kernel_size=(1,), stride=(1,), bias=False)
torch.Size([512, 1])
Softmax(dim=0)
torch.Size([256, 1])
torch.Size([256, 1])


In [55]:
import torch.nn as nn
from model import Conv1d1x1

class Head(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Head, self).__init__(
            nn.ReLU(),
            Conv1d1x1(in_channels, in_channels),
            nn.ReLU(),
            Conv1d1x1(in_channels, out_channels),
            nn.Softmax(dim=1)
        )
    def forward(self, inputs):
        '''
        Inputs: a torch tensor of stacked skip connections
        '''
        summed_inputs = inputs.sum(dim=0).view(1,-1,1)
        return super(Head, self).forward(summed_inputs)

In [56]:
stacked_skip_connections.shape

torch.Size([5, 1, 512])

In [62]:
Head(skip_channels, categories)(stacked_skip_connections).view(-1).shape

torch.Size([256])