In [1]:
import torch.nn as nn
import torch 
import torch.nn.functional as F

In [2]:
class CRNN(nn.Module):
    def __init__(self, img_channel, img_height, img_width, num_classes,
                 leaky_relu = True, map_to_sequence = 64, lstm_hidden = 256):
        super(CRNN, self).__init__()
        self.cnn, dimension = self._create_cnn(img_channel, img_height, img_width, leaky_relu)
        output_channel, output_height, output_width = dimension
        
        self.map_2_sequence = nn.Linear(output_channel * output_height, map_to_sequence)
        self.lstm1 = nn.LSTM(map_to_sequence, lstm_hidden, 
                             bidirectional = True, batch_first = True)
        self.lstm2 = nn.LSTM(2 * lstm_hidden, lstm_hidden,
                             bidirectional = True, batch_first = True)

        self.dense = nn.Linear(2 * lstm_hidden, num_classes)
        
    def _create_cnn(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # m means the mode: 0 is convolution, 1 is max pooling
        # k means kernel_size
        # s means stride
        # p means padding
        cfgs = [
            #m   #k,     #s    #p    #c    #bn
            [0,  (3, 3),  1,   1,    64,  False],
            [1,  (2, 2),  2,   None, None, False],
            [0,  (3, 3),  1,   1,    128, False],
            [1,  (2, 2),  2,   None, None, False],
            [0,  (3, 3),  1,   1,    256,  False],
            [0,  (3, 3),  1,   1,    256,  False],
            [1,  (2, 1),  2,   None,  None, False],
            [0,  (3, 3),  1,   1,    512,   True],
            [0,  (3, 3),  1,   1,    512,   True],
            [1,  (2, 1),  2,   None,  None, False],
            [0,  (2, 2),  1,   1,    512,   True],
        ]

        cnn = []
        input_channels = img_channel
        output_channels = None
        
        for m, k, s, p, c, bn in cfgs:
            if m == 0: # Convolution 
                output_channels = c
                
                cnn.append(nn.Conv2d(input_channels, output_channels, 
                                     kernel_size = k, stride = s, 
                                     padding = p))
                relu = nn.LeakyReLU(0.2, inplace = True) if leaky_relu == True \
                            else nn.ReLU(inplace = True)
                cnn.append(relu)
                
                if bn == True:
                    cnn.append(nn.BatchNorm2d(output_channels))
                           
                input_channels = output_channels
                
            elif m == 1:
                cnn.append(nn.MaxPool2d(kernel_size = k, stride = s))

        cnn_module = nn.Sequential(*cnn)
        # The output height and width of an image after passing through CNN
        output_height = img_height // 16 - 1 
        output_width = img_width // 4 - 1
        
        return cnn_module, (output_channels, output_height, output_width)

    def forward(self, image):
        conv = self.conv(image)
        batch, channel, height, width = conv.size()

        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(0, 2, 1)  # (batch, width, channel * height)
        seq = self.map_2_sequence(conv)

        recurrent, _ = self.lstm1(seq)
        recurrent, _ = self.lstm2(recurrent)

        logits = self.dense(recurrent)

        return F.log_softmax(logits, dim = 2)
        

In [3]:
crnn = CRNN(1, 32, 100, 100)
crnn 

CRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): LeakyReLU(negative_slope=0.2, inplace=True)
    (10): MaxPool2d(kernel_size=(2, 1), stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): BatchNorm2d(512, eps=1e-05, momentum=0.1, affin

# CTC Decoder

In [4]:
import math 
import numpy as np

class Hypothesis:
    def __init__(self, sequence, log_prob):
        self.sequence = sequence
        self.log_prob = log_prob

In [5]:
def beam_search_decode(probabilities, alphabet, beam_width):
    initial_hypothesis = Hypothesis(sequence = [], log_prob = 0.0)
    beam = [initial_hypothesis]

    for timestep in range(len(probabilities)):
        new_beam = []

        for hypothesis in beam:
            # for label, prob in enumerate(probabilities[timestep]):
            #     extended_sequence = hypothesis.sequence + [alphabet[label]]
            #     log_prob = hypothesis.log_prob + math.log(prob)

            #     new_hypothesis = Hypothesis(sequence = extended_sequence,
            #                                 log_prob = log_prob)
            #     new_beam.append(new_hypothesis)

            for c in range(probabilities.shape[1]):
                extended_sequence = hypothesis.sequence + [alphabet[c]]
                log_prob = hypothesis.log_prob + math.log(probabilities[timestep, c])
                new_hypothesis = Hypothesis(sequence = extended_sequence,
                                            log_prob = log_prob)
                new_beam.append(new_hypothesis)
        # Select top-k hypothesis
        beam = sorted(new_beam, key = lambda x: x.log_prob, reverse = True)[:beam_width]

    # Select the best hypothesis from the final beam
    best_sequence = max(beam, key = lambda x: x.log_prob)

    # Select the best hypothesis 
    return best_sequence
    

alphabet = ['a', 'b', 'c']
probabilities = np.array([
    [0.2, 0.7, 0.1],  # Timestep 1
    [0.3, 0.4, 0.3],  # Timestep 2
    [0.1, 0.2, 0.7]   # Timestep 3
])

decoded_sequence = beam_search_decode(probabilities, alphabet, beam_width = 2)
print("Decoded sequence: ", ' '.join(decoded_sequence.sequence))
print(decoded_sequence.log_prob)

Decoded sequence:  b b c
-1.62964061975162


In [6]:
math.log(0.7) + math.log(0.4) + math.log(0.7)

-1.62964061975162

In [7]:
np.argmax(probabilities, axis = -1)

array([1, 1, 2])

# CTC Loss

In [8]:
# Target are to be padded

T = 50  # Input Sequence Length
C = 20  # Number of classes (including blank)
N = 16  # Batch size
S = 30  # Target sequence Length

S_min = 10

# Input has size (50, 16, 20)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

# Target has size (16, 30)
target = torch.randint(low = 1, high = C, size = (N, S), dtype = torch.long)

# Input lengths has size (16)
input_lengths = torch.full(size = (N,), fill_value = T, dtype = torch.long)

# Target lenths has size (16)
target_lengths = torch.randint(low = S_min, high = S, size = (N, ),
                                dtype = torch.long)

ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

In [9]:
input.size()

torch.Size([50, 16, 20])

In [10]:
input_lengths.size()

torch.Size([16])

In [11]:
target.size()

torch.Size([16, 30])

In [12]:
loss

tensor(6.5691, grad_fn=<MeanBackward0>)

In [13]:
# Target are to be un-padded
T = 50
C = 20
N = 16

input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size = (N, ), fill_value = T, dtype = torch.long)

target_lengths = torch.randint(low = 1, high = T, size = (N, ),
                                dtype = torch.long)
target = torch.randint(low = 1, high = C, size = (sum(target_lengths),), dtype = torch.long)

target_lengths

ctc_loss = nn.CTCLoss(reduction = 'sum')
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss

tensor(2087.9353, grad_fn=<SumBackward0>)

In [14]:
input.shape

torch.Size([50, 16, 20])

In [15]:
target.size()

torch.Size([419])

In [16]:
target_lengths.size()

torch.Size([16])

In [17]:
sum(target_lengths)

tensor(419)

In [18]:
input_lengths

tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])