In [7]:
import torch.nn as nn
import torch 
import torch.nn.functional as F

In [12]:
class CRNN(nn.Module):
    def __init__(self, img_channel, img_height, img_width, num_classes,
                 leaky_relu = True, map_to_sequence = 64, lstm_hidden = 256):
        super(CRNN, self).__init__()
        self.cnn, dimension = self._create_cnn(img_channel, img_height, img_width, leaky_relu)
        output_channel, output_height, output_width = dimension
        
        self.map_2_sequence = nn.Linear(output_channel * output_height, map_to_sequence)
        self.lstm1 = nn.LSTM(map_to_sequence, lstm_hidden, 
                             bidirectional = True, batch_first = True)
        self.lstm2 = nn.LSTM(2 * lstm_hidden, lstm_hidden,
                             bidirectional = True, batch_first = True)

        self.dense = nn.Linear(2 * lstm_hidden, num_classes)
        
    def _create_cnn(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # m means the mode: 0 is convolution, 1 is max pooling
        # k means kernel_size
        # s means stride
        # p means padding
        cfgs = [
            #m   #k,     #s    #p    #c    #bn
            [0,  (3, 3),  1,   1,    64,  False],
            [1,  (2, 2),  2,   None, None, False],
            [0,  (3, 3),  1,   1,    128, False],
            [1,  (2, 2),  2,   None, None, False],
            [0,  (3, 3),  1,   1,    256,  False],
            [0,  (3, 3),  1,   1,    256,  False],
            [1,  (2, 1),  2,   None,  None, False],
            [0,  (3, 3),  1,   1,    512,   True],
            [0,  (3, 3),  1,   1,    512,   True],
            [1,  (2, 1),  2,   None,  None, False],
            [0,  (2, 2),  1,   1,    512,   True],
        ]

        cnn = []
        input_channels = img_channel
        output_channels = None
        
        for m, k, s, p, c, bn in cfgs:
            if m == 0: # Convolution 
                output_channels = c
                
                cnn.append(nn.Conv2d(input_channels, output_channels, 
                                     kernel_size = k, stride = s, 
                                     padding = p))
                relu = nn.LeakyReLU(0.2, inplace = True) if leaky_relu == True \
                            else nn.ReLU(inplace = True)
                cnn.append(relu)
                
                if bn == True:
                    cnn.append(nn.BatchNorm2d(output_channels))
                           
                input_channels = output_channels
                
            elif m == 1:
                cnn.append(nn.MaxPool2d(kernel_size = k, stride = s))

        cnn_module = nn.Sequential(*cnn)
        # The output height and width of an image after passing through CNN
        output_height = img_height // 16 - 1 
        output_width = img_width // 4 - 1
        
        return cnn_module, (output_channels, output_height, output_width)

    def forward(self, image):
        conv = self.conv(image)
        batch, channel, height, width = conv.size()

        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(0, 2, 1)  # (batch, width, channel * height)
        seq = self.map_2_sequence(conv)

        recurrent, _ = self.lstm1(seq)
        recurrent, _ = self.lstm2(recurrent)

        logits = self.dense(recurrent)

        return F.log_softmax(logits, dim = 2)
        

In [13]:
crnn = CRNN(1, 32, 100, 100)
crnn 

CRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): LeakyReLU(negative_slope=0.2, inplace=True)
    (10): MaxPool2d(kernel_size=(2, 1), stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): BatchNorm2d(512, eps=1e-05, momentum=0.1, affin

# CTC Decoder