In [1]:
import torch.nn as nn

In [None]:
class CRNN(nn.Module):
    """Convolutional feature extractor of a CRNN-style network.

    Only the CNN stage is built here. It is stored in ``self.cnn`` and the
    shape of the feature map it produces for one image is stored in
    ``self.cnn_output_shape`` as ``(channels, height, width)``.

    Args:
        img_channel: Number of channels of the input image.
        img_height: Input image height; must be a multiple of 16.
        img_width: Input image width; must be a multiple of 4.
        leaky_relu: If true, use ``LeakyReLU(0.2)`` after every convolution,
            otherwise plain ``ReLU``.
    """

    def __init__(self, img_channel, img_height, img_width, leaky_relu=True):
        super(CRNN, self).__init__()
        self.cnn, self.cnn_output_shape = self._create_cnn(
            img_channel, img_height, img_width, leaky_relu
        )

    def _create_cnn(self, img_channel, img_height, img_width, leaky_relu):
        """Build the convolutional stack and report its output shape.

        Args:
            img_channel: Number of channels of the input image.
            img_height: Input image height; must be a multiple of 16.
            img_width: Input image width; must be a multiple of 4.
            leaky_relu: Selects ``LeakyReLU(0.2)`` (true) or ``ReLU`` (false).

        Returns:
            A tuple ``(cnn_module, (channels, height, width))`` where
            ``cnn_module`` is an ``nn.Sequential`` and the inner tuple is the
            per-image shape of its output, i.e.
            ``(512, img_height // 16 - 1, img_width // 4 - 1)``.

        Raises:
            AssertionError: If ``img_height`` is not a multiple of 16 or
                ``img_width`` is not a multiple of 4.

        Note:
            ``img_height == 16`` passes the assertion but gives a reported
            output height of 0, so a usable feature map needs a height of at
            least 32.
        """
        assert img_height % 16 == 0, "img_height must be a multiple of 16"
        assert img_width % 4 == 0, "img_width must be a multiple of 4"

        # m means the mode: 0 is convolution, 1 is max pooling
        # k means kernel_size
        # s means stride
        # p means padding
        # c means output channels (convolution only)
        # bn means a BatchNorm2d layer follows the activation
        #
        # The two (2, 1) poolings use stride (2, 1) so that they halve the
        # height only, and the last convolution is unpadded so that its 2x2
        # kernel trims one row and one column. Together this yields exactly
        # (img_height // 16 - 1, img_width // 4 - 1), the shape returned below.
        cfgs = [
            # m  k       s       p     c     bn
            [0, (3, 3), 1,      1,    64,   False],
            [1, (2, 2), 2,      None, None, False],
            [0, (3, 3), 1,      1,    128,  False],
            [1, (2, 2), 2,      None, None, False],
            [0, (3, 3), 1,      1,    256,  False],
            [0, (3, 3), 1,      1,    256,  False],
            [1, (2, 1), (2, 1), None, None, False],
            [0, (3, 3), 1,      1,    512,  True],
            [0, (3, 3), 1,      1,    512,  True],
            [1, (2, 1), (2, 1), None, None, False],
            [0, (2, 2), 1,      0,    512,  True],
        ]

        cnn = []
        input_channels = img_channel
        output_channels = None

        for m, k, s, p, c, bn in cfgs:
            if m == 0:  # Convolution
                output_channels = c

                cnn.append(nn.Conv2d(input_channels, output_channels,
                                     kernel_size=k, stride=s,
                                     padding=p))
                # A fresh activation module per layer; in-place to save memory.
                if leaky_relu:
                    relu = nn.LeakyReLU(0.2, inplace=True)
                else:
                    relu = nn.ReLU(inplace=True)
                cnn.append(relu)

                if bn:
                    cnn.append(nn.BatchNorm2d(output_channels))

                # The next convolution consumes what this one produced.
                input_channels = output_channels

            elif m == 1:  # Max pooling
                cnn.append(nn.MaxPool2d(kernel_size=k, stride=s))

        cnn_module = nn.Sequential(*cnn)
        # The output height and width of an image after passing through CNN
        output_height = img_height // 16 - 1
        output_width = img_width // 4 - 1

        return cnn_module, (output_channels, output_height, output_width)


        