In [7]:
#config.py
common_config = {
    'data_dir': 'data/mnt/ramdisk/max/90kDICT32px/',
    'img_width': 100,
    'img_height': 32,
    'map_to_seq_hidden': 64,
    'rnn_hidden': 256,
    'leaky_relu': False,
}

train_config = {
    'epochs': 10000,
    'train_batch_size': 32,
    'eval_batch_size': 512,
    'lr': 0.0005,
    'show_interval': 10,
    'valid_interval': 500,
    'save_interval': 2000,
    'cpu_workers': 4,
    'reload_checkpoint': None,
    'valid_max_iter': 100,
    'decode_method': 'greedy',
    'beam_size': 10,
    'checkpoints_dir': 'checkpoints/'
}
train_config.update(common_config)

evaluate_config = {
    'eval_batch_size': 512,
    'cpu_workers': 4,
    'reload_checkpoint': 'checkpoints/crnn_synth90k.pt',
    'decode_method': 'beam_search',
    'beam_size': 10,
}
evaluate_config.update(common_config)

In [1]:
#model.py
import torch.nn as nn

In [12]:
class CRNN(nn.Module):
    
    def __init__(self, channels, height, width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, use_leaky_relu=False):
        
        super(CRNN, self).__init__()
        
        self.cnn, (output_channels, output_height, output_width) = \
            self.cnn_backbone(channels, height, width, use_leaky_relu)
        
        self.map_to_sequential = nn.Linear(output_channel *  output_height, map_to_seq_hidden)
        
        self.rnn1 = nn.LSTM(cnn_to_rnn_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        
        self.dense = nn.Linear(2 * rnn_hidden, num_class)
    
    def cnn_backbone(self, channels, height, width, use_leaky_relu):
        channels = [channels, 64, 128, 256, 256, 512, 512, 512]
        kernels = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]
        
        cnn = nn.Sequential()
        
        def convolution_relu(i, batch_norm=False):
            # input shape: (batch size, input_channels, height, width)
            input_channels = channels[i]
            output_channels = channels[i + 1]
            
            cnn.add_module('conv-{}'.format(i), nn.Conv2d(input_channels, output_channels, kernels[i], strides[i], paddings[i]))
            
            if batch_norm:
                cnn.add_module('batchnorm-{}'.format(i), nn.BatchNorm2d(output_channel))
            
            if use_leaky_relu:
                relu = nn.LeakyReLU(0.2, inplace = True)
            else:
                relu = nn.ReLU(inplace = True)
                
            cnn.add_module('relu-{}'.format(i))
            
        
        # size of image: (channels, height, width)
        
        convolution_relu(0)
        cnn.add_module('maxpool-0', nn.MaxPool2d(kernel_size = 2, stride = 2))
        # (64, height // 2, width // 2)
        
        convolution_relu(1)
        cnn.add_module('maxpool-1', nn.MaxPool2d(kernel_size = 2, stride = 2))   
        # (128, height // 4, width // 4)
        
        convolution_relu(2)
        convolution_relu(3)
        cnn.add_module('maxpool-2', nn.MaxPool2d(kernel_size = (2,1)))
        # (256, height // 8, width // 4)
        
        convolution.relu(4, batch_norm=True)
        convolution.relu(5, batch_norm=True)
        cnn.add_module('maxpool-3', nn.MaxPool2d(kernel_size = (2,1)))
        # (512, height // 16, width // 4)
        
        convolution_relu(6)
        # (512, height // 16 - 1, width // 4 - 1)
        
        output_channels, output_height, output_width = channels[-1], height // 16 - 1, width // 4 - 1
        return cnn, (output_channels, output_height, output_width)
    
    def forward(self, images):
        # shape of images: (batch_size, channels, height, width)
        
        convolution = self.cnn(images)
        batch_size, channels, height, width = convolution.size()
        
        convolution = convolution.view(batch, channel * height, width)
        convolution = convolution.permute(2, 0, 1) # (width, batch_size, features)
        
        sequential = self.map_to_sequential(convolution)
        
        recurrent, _ = self.rnn1(sequential)
        recurrent, _ = self.rnn2(recurrent)
        
        output = self.dense(recurrent)
        return output # shape: (sequential_length, batch_size, num_class)
        
        

In [13]:
#train.py
import os

import cv2
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

#from config import train_config as config
#from dataset import Synth90kDataset, synth90k_collate_fn
#from model import CRNN
#from evaluate import evaluate

In [14]:
def train_batch(crnn, data, optimizer, criterion, device):
    crnn.train()
    images, labels, label_lengths = [d.to(device) for d in data]
    
    logits = crnn(images)
    log_probs = torch.nn.functional.log_softmax(logits, dim=2)
    
    batch_size = images.size(0)
    input_lenghts = torch.LongTensor([logits.size(0)] * batch_size)
    label_lengths = torch.flatten(label_lenghts)
    
    loss = criterion(log_probs, targets, input_lengths, label_lengths)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def main():
    epochs = config['epochs']
    train_batch_size = config['train_batch_size']
    eval_batch_size = config['eval_batch_size']
    lr = config['lr']
    show_interval = config['show_interval']
    valid_interval = config['valid_interval']
    save_interval = config['save_interval']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']
    valid_max_iter = config['valid_max_iter']

    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']
    
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
        
    train_dataset = Synth90kDataset(root_dir=data_dir, mode='train', 
                                     img_height=img_height, img_width=img_width)
    
    valid_dataset = Synth90kDataset(root_dir=data_dir, mode='dev', 
                                     img_height=img_height, img_width=img_width)
    
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)
    
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=eval_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    optimizer = optim.Adam(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)
    
    i = 1
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in train_loader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print('train_batch_loss[', i, ']: ', loss / train_size)

            if i % valid_interval == 0:
                evaluation = evaluate(crnn, valid_loader, criterion,
                                      decode_method=config['decode_method'],
                                      beam_size=config['beam_size'])
                print('valid_evaluation: loss={loss}, acc={acc}'.format(**evaluation))

                if i % save_interval == 0:
                    prefix = 'crnn'
                    loss = evaluation['loss']
                    save_model_path = os.path.join(config['checkpoints_dir'],
                                                   f'{prefix}_{i:06}_loss{loss}.pt')
                    torch.save(crnn.state_dict(), save_model_path)
                    print('save model at ', save_model_path)

            i += 1

        print('train_loss: ', tot_train_loss / tot_train_count)


if __name__ == '__main__':
    main()

NameError: name 'config' is not defined