In [1]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn

In [2]:
from model import BatchNormConv1d, CBHG, Prenet, Highway, Encoder

In [3]:
# Stand-alone conv block used for the shape checks below: 256 input channels,
# 128 output channels, width-3 kernel with one zero-padded frame on each side.
layer = BatchNormConv1d(
    256,
    128,
    kernel_size=3,
    stride=1,
    padding=[1, 1],
    activation=nn.ReLU(),
)

In [4]:
# Show the module repr to check which sub-layers were built.
layer

BatchNormConv1d(
  (padder): ConstantPad1d(padding=[1, 1], value=0)
  (conv1d): Conv1d(256, 128, kernel_size=(3,), stride=(1,), bias=False)
  (bn): BatchNorm1d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
  (activation): ReLU()
)

In [5]:
# Random batch to exercise the layer; only the resulting shapes are inspected.
# NOTE(review): `input` shadows the Python builtin. A later cell reads this name too,
# so any rename has to be applied in both places at once.
input = torch.randn(32, 256, 64)  # (batch, input_size, seq_len)
output = layer(input)  # forward pass through the whole layer

In [6]:
# Channels go 256 -> 128 while the sequence length stays at 64.
output.shape

torch.Size([32, 128, 64])

In [7]:
# Manual walk through the layer, step 1: zero-pad the last axis by [1, 1],
# which grows the sequence length from 64 to 66.
x = layer.padder(input)
x.shape

torch.Size([32, 256, 66])

In [8]:
# Manual walk through the layer, step 2: convolve the padded input
# (256 -> 128 channels, width-3 kernel brings the length back from 66 to 64).
# The value is recomputed from `input` rather than from the `x` left behind by
# the previous cell, so this cell can be re-run safely: feeding its own
# 128-channel result back into the 256-channel conv would raise an error.
x = layer.conv1d(layer.padder(input))
x.shape

torch.Size([32, 128, 64])

In [9]:
# Manual walk through the layer, step 3: batch-normalize the conv output.
# The pad -> conv chain is rebuilt from `input` so that re-running this cell
# normalizes the conv output once, instead of normalizing an already
# normalized `x` a second time. The shape is unchanged by this step.
x = layer.bn(layer.conv1d(layer.padder(input)))
x.shape

torch.Size([32, 128, 64])

In [10]:
# Print the [left, right] padding pair for every kernel width from 1 to 16.
# left + right always equals k - 1, with the extra frame going to the right
# side when k is even.
for k in range(1, 17):
    pad_left = (k - 1) // 2
    pad_right = k // 2
    print(k, [pad_left, pad_right])

1 [0, 0]
2 [0, 1]
3 [1, 1]
4 [1, 2]
5 [2, 2]
6 [2, 3]
7 [3, 3]
8 [3, 4]
9 [4, 4]
10 [4, 5]
11 [5, 5]
12 [5, 6]
13 [6, 6]
14 [6, 7]
15 [7, 7]
16 [7, 8]


In [11]:
# Prenet on 128-dim input; the repr shows two linear layers, 128 -> 256 and 256 -> 256.
pre = Prenet(128, out_features=[256, 256])
pre

Prenet(
  (layers): ModuleList(
    (0): Linear(
      (linear_layer): Linear(in_features=128, out_features=256, bias=True)
    )
    (1): Linear(
      (linear_layer): Linear(in_features=256, out_features=256, bias=True)
    )
  )
)

In [12]:
# CBHG built for 256 input features; the repr lists a bank of BatchNormConv1d
# layers (256 -> 128 channels) with kernel sizes 1, 2, 3, ...
cbhg = CBHG(256)
cbhg

CBHG(
  (relu): ReLU()
  (conv1d_banks): ModuleList(
    (0): BatchNormConv1d(
      (padder): ConstantPad1d(padding=[0, 0], value=0)
      (conv1d): Conv1d(256, 128, kernel_size=(1,), stride=(1,), bias=False)
      (bn): BatchNorm1d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (1): BatchNormConv1d(
      (padder): ConstantPad1d(padding=[0, 1], value=0)
      (conv1d): Conv1d(256, 128, kernel_size=(2,), stride=(1,), bias=False)
      (bn): BatchNorm1d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (2): BatchNormConv1d(
      (padder): ConstantPad1d(padding=[1, 1], value=0)
      (conv1d): Conv1d(256, 128, kernel_size=(3,), stride=(1,), bias=False)
      (bn): BatchNorm1d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (3): BatchNormConv1d(
      (padder): ConstantPad1d(padding=[1, 2], value=0)
      (conv1d)

In [13]:
# Highway layer 128 -> 128; the repr shows two linear maps (H and T) plus
# ReLU and Sigmoid modules.
highway = Highway(128, 128)
highway

Highway(
  (H): Linear(in_features=128, out_features=128, bias=True)
  (T): Linear(in_features=128, out_features=128, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

In [14]:
# Encoder for 256-dim inputs; the repr shows a Prenet (256 -> 256 -> 128) and a
# CBHG whose conv bank takes 128 input channels.
encoder = Encoder(256)
encoder

Encoder(
  (prenet): Prenet(
    (layers): ModuleList(
      (0): Linear(
        (linear_layer): Linear(in_features=256, out_features=256, bias=True)
      )
      (1): Linear(
        (linear_layer): Linear(in_features=256, out_features=128, bias=True)
      )
    )
  )
  (cbhg): CBHG(
    (relu): ReLU()
    (conv1d_banks): ModuleList(
      (0): BatchNormConv1d(
        (padder): ConstantPad1d(padding=[0, 0], value=0)
        (conv1d): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn): BatchNorm1d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (1): BatchNormConv1d(
        (padder): ConstantPad1d(padding=[0, 1], value=0)
        (conv1d): Conv1d(128, 128, kernel_size=(2,), stride=(1,), bias=False)
        (bn): BatchNorm1d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (2): BatchNormConv1d(
        (padder): ConstantPad1d(padding=

In [15]:
# Dummy encoder input: uniform noise of shape (32, 71, 256). The last axis
# matches the 256 features the encoder was constructed with.
# Presumably laid out as (batch, seq_len, features) -- confirm against Encoder.
inputs = torch.rand((32, 71, 256))
inputs.shape

torch.Size([32, 71, 256])

In [17]:
# Run the dummy batch through the encoder; the first two dimensions (32, 71)
# are preserved and the last dimension comes out as 256.
encoder_outputs = encoder(inputs)
encoder_outputs.shape

torch.Size([32, 71, 256])