In [1]:
import torch
import torch.nn as nn
import numpy as np

## basic layers
1. Brelu which is relu with a cutoff. In paddlepaddle0.10, the defaut threshould is 24

In [2]:
class BReLU(nn.Hardtanh):
    r"""Applies the element-wise function:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), cutoff)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, cutoff=24., inplace=False):
        super(BReLU, self).__init__(0., cutoff, inplace)

    def extra_repr(self):
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str

2. mask

In [3]:
class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        
    def forward(self,x, length):
        mask = torch.zeros_like(x, dtype=torch.float32)
        for index, length in enumerate(length):
            mask[index, :, :length, :] = 1
        return x * mask

In [4]:
a = torch.ones(5, 10)
b = torch.ones(5, 10)
c = torch.ones(4, 10)
d = torch.ones(4, 10)
sequences = nn.utils.rnn.pad_sequence([a, b, c, d], padding_value=0.4, batch_first=True)
sequences = sequences.unsqueeze(1) # change it into image format
sequence_lengths = torch.Tensor([5,5 ,4,4]).type(torch.LongTensor)
mask = Mask()
masked_seq = mask(sequences, sequence_lengths)

### Conv+bn+mask

In [5]:
class conv_bn_mask(nn.Module):
    def __init__(self, ichannel, ochannel, kernel_size, padding, stride, bias=False, track_running_stats=False):
        super(conv_bn_mask, self).__init__()
        self.conv = nn.Conv2d(ichannel, ochannel, kernel_size=kernel_size, padding=padding, stride=stride, bias=bias)
        self.bn = nn.BatchNorm2d(ochannel, track_running_stats=track_running_stats)
        self.activation = BReLU(cutoff=24)
        self.mask = Mask()
    def forward(self, x, length):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        x = self.mask(x, length)
        return x

### image flatten and sequence

### flatten
image.view()

In [6]:
masked_seq.shape

torch.Size([4, 1, 5, 10])

In [7]:
# a = torch.ones(5, 10)
# b = torch.ones(4, 10)
# c = torch.ones(4, 10)
# sequences = nn.utils.rnn.pad_sequence([a, b, c], padding_value=0.4, batch_first=True)

flattened_seq = masked_seq.view(4, -1, 10)
flattened_seq.shape

torch.Size([4, 5, 10])

### sequence 
nn.utils.rnn.pack_padded_sequence

In [8]:
seqs = nn.utils.rnn.pack_padded_sequence(flattened_seq, sequence_lengths, batch_first=True)
seqs

PackedSequence(data=tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]), batch_sizes=tensor([4, 4, 4, 4, 2]), sorted_indices=None, unsorted_indices=Non

### bidirectional GRU

In [9]:
from custom_gru import BidirRNNLayer, GRUCell

In [10]:
bigru = BidirRNNLayer(GRUCell, input_size=10, hidden_size=20, gate_act="relu", state_act="tanh")

In [11]:
bigru_result, _ = bigru.forward(seqs, torch.zeros((2, 4, 20)))

In [12]:
bigru_result_reshaped = nn.utils.rnn.pad_packed_sequence(bigru_result, batch_first=True)

In [13]:
bigru_result.batch_sizes

tensor([4, 4, 4, 4, 2])

### last fc 
2048 * 29

In [14]:
bottleneck_data, batch_sizes, _, _ = bigru_result
bottleneck = nn.Linear(40, 28)
bottleneck_result = bottleneck(bottleneck_data)

### output layer
softmax

In [15]:
softmax = nn.LogSoftmax(dim=0)
prob = softmax(bottleneck_result)

In [24]:
output = nn.utils.rnn.PackedSequence(prob, batch_sizes)
print(output)
output = nn.utils.rnn.pad_packed_sequence(output)
print(output)

PackedSequence(data=tensor([[-4.1676e+09, -5.0515e+07, -4.4991e+10, -1.5702e+10, -6.9315e-01,
         -6.6085e+09, -7.7123e+07, -2.4533e+10, -5.3310e+10, -2.1041e+10,
         -1.4993e+08, -6.9315e-01, -1.2671e+08, -2.8169e+10, -6.9315e-01,
         -1.0366e+08, -6.9315e-01, -6.9315e-01, -6.9315e-01, -2.5552e+10,
         -4.9122e+07, -6.9315e-01, -1.3816e+10, -4.5174e+10, -6.9315e-01,
         -1.4428e+10, -6.9315e-01, -1.6474e+10],
        [-4.1676e+09, -5.0515e+07, -4.4991e+10, -1.5702e+10, -6.9315e-01,
         -6.6085e+09, -7.7123e+07, -2.4533e+10, -5.3310e+10, -2.1041e+10,
         -1.4993e+08, -6.9315e-01, -1.2671e+08, -2.8169e+10, -6.9315e-01,
         -1.0366e+08, -6.9315e-01, -6.9315e-01, -6.9315e-01, -2.5552e+10,
         -4.9122e+07, -6.9315e-01, -1.3816e+10, -4.5174e+10, -6.9315e-01,
         -1.4428e+10, -6.9315e-01, -1.6474e+10],
        [-4.0711e+09, -8.2979e-01, -4.5112e+10, -1.5653e+10, -6.2133e+07,
         -6.4808e+09, -1.7119e+00, -2.4576e+10, -5.2996e+10, -2.1001

In [17]:
log_likelyhood, seq_lengths = output

### CTC loss

In [18]:
input_lengths = seq_lengths
target_lengths = seq_lengths

In [19]:
target = torch.randint(low=1, high=log_likelyhood.shape[2], size=(4, 5), dtype=torch.long)

In [20]:
ctc_loss = nn.CTCLoss()

In [21]:
loss = ctc_loss(log_likelyhood, target, input_lengths, target_lengths)

In [22]:
loss.backward()

In [23]:
b

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])