In [1]:
import torch
import torch.nn as nn

from torchparse import parse_cfg

In [2]:
# Define a CRNN as a torchparse config string.
# Sections, in the order written below:
#   [input]        - sample shape [2, 200, 400]; no batch dimension is listed
#   [convs_module] - conv2d / batchnorm2d / elu / maxpool2d / dropout, repeated 3 times
#   [moddims]      - permute / collapse directives placed between modules
#   [recur_module] - an LSTM with hidden_size 64 and num_layers 2
#   [dense_module] - dropout, batchnorm1d, then a linear layer with 10 output features
# The string is handed unchanged to parse_cfg by the cells that follow.
cfg_model = """
[input]
    shape = [2, 200, 400]
    
[convs_module]
    REPEATx3
        [conv2d]
            out_channels = 32
            kernel_size = 3
            stride = 1
            padding = valid
        [batchnorm2d]
        [elu]
        [maxpool2d]
            kernel_size = 4
            stride = 4
        [dropout]
            p = 0.1
    END

[moddims]
    permute = [2,1,0]
    collapse = [1,2]

[recur_module]
    [lstm]
        hidden_size = 64
        num_layers = 2

[moddims]
    permute = [1]

[dense_module]
    [dropout]
        p = 0.3
    [batchnorm1d]
    [linear]
        out_features = 10
"""

In [3]:
# Build the model from the config string. The call is the cell's last
# expression, so its return value is only displayed (a ModuleDict, per the
# repr shown below); it is not stored in a variable here.
parse_cfg(cfg_model)

  return (spatial + p2 - k)//s + 1


ModuleDict(
  (convs): Sequential(
    (conv2d_0): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1))
    (batchnorm2d_0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (elu_0): ELU(alpha=1.0)
    (maxpool2d_0): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (dropout_0): Dropout(p=0.1, inplace=False)
    (conv2d_1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (batchnorm2d_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (elu_1): ELU(alpha=1.0)
    (maxpool2d_1): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (dropout_1): Dropout(p=0.1, inplace=False)
    (conv2d_2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (batchnorm2d_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (elu_2): ELU(alpha=1.0)
    (maxpool2d_2): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (

In [4]:
# Use the parsed modules inside a custom nn.Module
class MyNet(nn.Module):
    """CRNN wrapper around the modules produced by ``parse_cfg``.

    The parsed result is stored as ``self.model`` and is indexed by the keys
    ``'convs'``, ``'recur'`` and ``'dense'`` in ``forward``.
    """

    def __init__(self, cfg_model):
        """Parse ``cfg_model`` and keep the resulting modules on the instance.

        Args:
            cfg_model: configuration passed straight to ``parse_cfg``.
        """
        super().__init__()
        self.model = parse_cfg(cfg_model)

    def forward(self, x):
        """Run conv stack -> recurrent stack -> dense head on ``x``.

        Args:
            x: input tensor fed to the ``'convs'`` module (the notebook uses
                a 4-D batch).

        Returns:
            Output of the ``'dense'`` module applied to the slice ``[:, -1]``
            of the recurrent output.
        """
        conv_maps = self.model['convs'](x)

        # Swap dim 1 with the last dim (channel <-> width/time for a
        # (N, C, H, W) feature map).
        time_major = conv_maps.transpose(1, -1)

        # Keep the two leading dims and merge everything after them into one
        # feature axis (height and channel for a 4-D feature map).
        lead_dims = time_major.shape[:2]
        sequence = time_major.reshape(*lead_dims, -1)

        # The recurrent module returns an indexable result; element 0 is used
        # as the per-step output sequence.
        recur_out = self.model['recur'](sequence)
        hidden_seq = recur_out[0]

        # Many-to-one: keep only index -1 along dim 1.
        # NOTE(review): this is the last time step only if the parsed LSTM is
        # batch-first -- confirm against parse_cfg's output.
        last_step = hidden_seq[:, -1]

        return self.model['dense'](last_step)

In [7]:
# Smoke-test the wrapper: build it from the config string and push one random
# batch through it. The input is 16 samples shaped like the [input] section
# (2, 200, 400); only the shape of the result is displayed.
m = MyNet(cfg_model)
m(torch.randn(16, 2, 200, 400)).shape

torch.Size([16, 10])