In [5]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn

# from ..wavenet_modules import constant_pad_1d

## Helpers

In [6]:
class ConstantPad1d(nn.Module):
    """Pad a tensor along one dimension with a constant up to a fixed size.

    Args:
        target_size: size the chosen dimension has after padding.
        dimension: index of the dimension to pad.
        value: constant written into the padded positions.
        pad_start: if True the padding goes before the data, otherwise after.

    ``forward`` records ``num_pad`` and ``input_size`` on the module; the
    manual ``backward`` reads them, so it is only meaningful after a
    ``forward`` call on the same instance.
    """

    def __init__(self, target_size, dimension=0, value=0, pad_start=False):
        super(ConstantPad1d, self).__init__()
        self.target_size = target_size
        self.dimension = dimension
        self.value = value
        self.pad_start = pad_start

    def _unpadded_region(self, tensor):
        """Return the view of a padded-size ``tensor`` that excludes the padding."""
        # With start-padding the data begins after the first `num_pad` entries;
        # with end-padding it begins at index 0.
        offset = self.num_pad if self.pad_start else 0
        length = tensor.size(self.dimension) - self.num_pad
        return tensor.narrow(self.dimension, offset, length)

    def forward(self, input):
        """Return a new tensor equal to ``input`` padded to ``target_size``.

        Raises:
            AssertionError: if ``input`` is already larger than ``target_size``
                along ``dimension`` (equal size is accepted and pads nothing).
        """
        self.num_pad = self.target_size - input.size(self.dimension)
        assert self.num_pad >= 0, 'target size has to be greater than input size'

        self.input_size = input.size()

        # Allocate the full-size result pre-filled with the pad value ...
        padded_shape = list(self.input_size)
        padded_shape[self.dimension] = self.target_size
        output = input.new(*padded_shape).fill_(self.value)

        # ... then overwrite the non-padded slice (a view into `output`) in place.
        self._unpadded_region(output).copy_(input)
        return output

    def backward(self, grad_output):
        """Crop ``grad_output`` back to the shape of the last ``forward`` input.

        NOTE(review): this is an ordinary method; nothing visible in this
        notebook calls it.
        """
        grad_input = grad_output.new(*self.input_size).zero_()
        grad_input.copy_(self._unpadded_region(grad_output))
        return grad_input


def constant_pad_1d(input,
                    target_size,
                    dimension=0,
                    value=0,
                    pad_start=False):
    """Functional shortcut: pad ``input`` with a one-off ``ConstantPad1d``.

    Builds a fresh ``ConstantPad1d`` from the given arguments, applies it to
    ``input`` and returns the padded tensor.
    """
    return ConstantPad1d(target_size=target_size,
                         dimension=dimension,
                         value=value,
                         pad_start=pad_start)(input)

## Example

In [8]:
# Example input: the values 0..23 arranged as a (1, 3, 8) tensor.
# The deprecated `Variable` wrapper is dropped: since PyTorch 0.4 it simply
# returns the tensor it is given, so a plain tensor is equivalent here.
t = torch.linspace(0, 23, steps=24).view(1, 3, 8)
print(t)

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11., 12., 13., 14., 15.],
         [16., 17., 18., 19., 20., 21., 22., 23.]]])


In [11]:
# Shape check: t was built with .view(1, 3, 8); dilate() unpacks this size as [n, c, l].
t.size()

torch.Size([1, 3, 8])

In [10]:
def dilate(x, dilation):
    """Regroup a 3-D tensor so that its first dimension has size ``dilation``.

    ``x.size()`` is read as ``(n, c, l)``. If ``dilation == n`` the input is
    returned unchanged. Otherwise the last dimension is padded at the start
    (via ``constant_pad_1d``) when needed, and the data is rearranged to shape
    ``(dilation, c, l * n // dilation)``. Every intermediate tensor is printed,
    which is the point of this exploratory helper.

    Returns:
        The permuted tensor; note it is returned without ``.contiguous()``.
    """
    n, c, l = x.size()
    dilation_factor = dilation / n
    print(f'dilation factor {dilation_factor}')
    if dilation == n:
        return x

    # Round the length up to a multiple of the dilation factor, padding with
    # zeros at the start so the reshape below is possible.
    padded_length = int(np.ceil(l / dilation_factor) * dilation_factor)
    if padded_length != l:
        x = constant_pad_1d(x, padded_length, dimension=2, pad_start=True)

    # (n, c, l) -> (c, l, n); contiguous memory is required for .view().
    channels_first = x.permute(1, 2, 0).contiguous()
    print("first transpose: ", channels_first)

    # Reinterpret the trailing (l, n) block as (l * n // dilation, dilation).
    regrouped = channels_first.view(c, (padded_length * n) // dilation, dilation)
    print("view change: ", regrouped)

    # Move the new dilation axis to the front: (dilation, c, new_l).
    result = regrouped.permute(2, 0, 1)
    #x = x.transpose(1, 2).transpose(0, 2).contiguous()
    print("second transpose: ", result)

    return result

# First dilation step on the example: size (1, 3, 8) -> (2, 3, 4).
r = dilate(t, 2)
print(r)

dilation factor 2.0
first transpose:  tensor([[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.],
         [ 4.],
         [ 5.],
         [ 6.],
         [ 7.]],

        [[ 8.],
         [ 9.],
         [10.],
         [11.],
         [12.],
         [13.],
         [14.],
         [15.]],

        [[16.],
         [17.],
         [18.],
         [19.],
         [20.],
         [21.],
         [22.],
         [23.]]])
view change:  tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.],
         [ 6.,  7.]],

        [[ 8.,  9.],
         [10., 11.],
         [12., 13.],
         [14., 15.]],

        [[16., 17.],
         [18., 19.],
         [20., 21.],
         [22., 23.]]])
second transpose:  tensor([[[ 0.,  2.,  4.,  6.],
         [ 8., 10., 12., 14.],
         [16., 18., 20., 22.]],

        [[ 1.,  3.,  5.,  7.],
         [ 9., 11., 13., 15.],
         [17., 19., 21., 23.]]])
tensor([[[ 0.,  2.,  4.,  6.],
         [ 8., 10., 12., 14.],
         [16., 18., 20., 22.]

In [49]:
# Second dilation step, applied to the already-dilated tensor: size (2, 3, 4) -> (4, 3, 2).
# NOTE(review): the saved output for this cell has no 'dilation factor' line and uses the
# old 'Variable containing:' repr, so it looks like it predates the current dilate() —
# re-run the notebook top to bottom to refresh it.
r2 = dilate(r, 4)
print(r2)

first transpose:  Variable containing:
(0 ,.,.) = 
   0   1
   2   3
   4   5
   6   7

(1 ,.,.) = 
   8   9
  10  11
  12  13
  14  15

(2 ,.,.) = 
  16  17
  18  19
  20  21
  22  23
[torch.FloatTensor of size 3x4x2]

view change:  Variable containing:
(0 ,.,.) = 
   0   1   2   3
   4   5   6   7

(1 ,.,.) = 
   8   9  10  11
  12  13  14  15

(2 ,.,.) = 
  16  17  18  19
  20  21  22  23
[torch.FloatTensor of size 3x2x4]

second transpose:  Variable containing:
(0 ,.,.) = 
   0   4
   8  12
  16  20

(1 ,.,.) = 
   1   5
   9  13
  17  21

(2 ,.,.) = 
   2   6
  10  14
  18  22

(3 ,.,.) = 
   3   7
  11  15
  19  23
[torch.FloatTensor of size 4x3x2]

Variable containing:
(0 ,.,.) = 
   0   4
   8  12
  16  20

(1 ,.,.) = 
   1   5
   9  13
  17  21

(2 ,.,.) = 
   2   6
  10  14
  18  22

(3 ,.,.) = 
   3   7
  11  15
  19  23
[torch.FloatTensor of size 4x3x2]

