In [1]:
#config.py
# Hyper-parameters and paths shared by the training and evaluation cells.
import os

common_config = {
    # Root of the MJSynth (90kDICT32px) data; OCR_Dataset reads lexicon.txt and
    # annotation_{train,val,test}.txt from here. Set the OCR_DATA_DIR environment
    # variable to point elsewhere instead of editing this machine-specific default.
    'data_dir': os.environ.get('OCR_DATA_DIR', 'E:/indirilenler/mjsynth/mnt/ramdisk/max/90kDICT32px/'),
    # every input image is resized to img_width x img_height
    'img_width': 100,
    'img_height': 32,
    # CRNN sizes: projection of CNN column features, and LSTM hidden size per direction
    'map_to_seq_hidden': 64,
    'rnn_hidden': 256,
    'use_leaky_relu': False,
}

# Training settings; the shared common_config entries are merged in below.
train_config = {
    'epochs': 1000,
    'train_batch_size': 32,
    'eval_batch_size': 512,
    'lr': 0.001,
    'show_interval': 10,
    'valid_interval': 10,
    'save_interval': 2000,
    'cpu_workers': 4,
    'reload_checkpoint': None,
    'valid_max_iter': 100,
    'decode_method': 'greedy',
    'beam_size': 10,
    'checkpoints_dir': 'checkpoints/'
}
train_config.update(common_config)

# Evaluation settings; the shared common_config entries are merged in below.
eval_config = {
    'eval_batch_size': 512,
    'cpu_workers': 4,
    'reload_checkpoint': 'checkpoints/crnn_ocr.pt',
    'decode_method': 'beam_search',
    'beam_size': 10,
}
eval_config.update(common_config)

In [2]:
#model.py
import torch.nn as nn

class CRNN(nn.Module):
    """Convolutional recurrent network: CNN feature extractor -> 2 BiLSTMs -> linear classifier.

    Args:
        channels: number of channels of the input images.
        height, width: input image size, used to size the layer after the CNN.
        num_class: number of output classes of the final linear layer.
        map_to_seq_hidden: size each CNN feature column is projected to before the RNNs.
        rnn_hidden: hidden size of each direction of both bidirectional LSTMs.
        use_leaky_relu: use LeakyReLU(0.2) instead of ReLU after every convolution.
    """

    def __init__(self, channels, height, width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, use_leaky_relu=False):
        super().__init__()

        backbone, feature_shape = self.cnn_backbone(channels, height, width, use_leaky_relu)
        self.cnn = backbone
        feat_channels, feat_height, _ = feature_shape

        # every column of the feature map (channels x height values) becomes one time step
        self.map_to_sequential = nn.Linear(feat_channels * feat_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def cnn_backbone(self, channels, height, width, use_leaky_relu):
        """Build the 7-convolution feature extractor.

        Returns:
            (nn.Sequential, (out_channels, out_height, out_width)) with
            out_height = height // 16 - 1 and out_width = width // 4 - 1.
        """
        widths = [channels, 64, 128, 256, 256, 512, 512, 512]
        # (kernel, stride, padding) per convolution; the last one shrinks H and W by 1
        conv_specs = [(3, 1, 1)] * 6 + [(2, 1, 0)]
        batch_norm_layers = {4, 5}
        # pooling inserted right after the given convolution index
        pools_after = {
            0: ('maxpool-0', dict(kernel_size=2, stride=2)),  # -> (64, H/2, W/2)
            1: ('maxpool-1', dict(kernel_size=2, stride=2)),  # -> (128, H/4, W/4)
            3: ('maxpool-2', dict(kernel_size=(2, 1))),       # -> (256, H/8, W/4)
            5: ('maxpool-3', dict(kernel_size=(2, 1))),       # -> (512, H/16, W/4)
        }

        feature_extractor = nn.Sequential()
        for idx, (kernel, stride, padding) in enumerate(conv_specs):
            in_ch, out_ch = widths[idx], widths[idx + 1]
            feature_extractor.add_module(
                f'conv-{idx}', nn.Conv2d(in_ch, out_ch, kernel, stride, padding))
            if idx in batch_norm_layers:
                feature_extractor.add_module(f'batchnorm-{idx}', nn.BatchNorm2d(out_ch))
            activation = nn.LeakyReLU(0.2, inplace=True) if use_leaky_relu else nn.ReLU(inplace=True)
            feature_extractor.add_module(f'relu-{idx}', activation)
            if idx in pools_after:
                pool_name, pool_kwargs = pools_after[idx]
                feature_extractor.add_module(pool_name, nn.MaxPool2d(**pool_kwargs))

        # after conv-6: (512, height // 16 - 1, width // 4 - 1)
        return feature_extractor, (widths[-1], height // 16 - 1, width // 4 - 1)

    def forward(self, images):
        """Map images (batch, channels, height, width) to scores (seq_len, batch, num_class).

        seq_len is the width of the CNN feature map.
        """
        features = self.cnn(images)
        n, c, h, w = features.size()
        # flatten channels x height per column, then move width to the front as time
        columns = features.view(n, c * h, w).permute(2, 0, 1)
        seq = self.map_to_sequential(columns)
        seq, _ = self.rnn1(seq)
        seq, _ = self.rnn2(seq)
        return self.dense(seq)
        
        

In [7]:
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np
from PIL import Image
import os

class OCR_Dataset(Dataset):
    """Word-image dataset laid out like MJSynth (90kDICT32px).

    ``root_dir`` must contain ``lexicon.txt`` (one word per line; a word's index is its
    0-based line number) and the annotation file of the requested split, whose lines are
    ``<relative image path> <lexicon index>``. Items are ``(image, target, target_length)``:
    ``image`` is a (1, img_height, img_width) float tensor scaled to [-1, 1], ``target``
    holds the character labels and ``target_length`` is a one-element tensor with its length.

    Raises:
        ValueError: if ``mode`` is not 'train', 'val' or 'test'.
    """
    CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
    # Labels start at 1: 0 is left free for the CTC blank.
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}
    # Annotation file read for each supported `mode`.
    ANNOTATION_FILES = {
        'train': 'annotation_train.txt',
        'val': 'annotation_val.txt',
        'test': 'annotation_test.txt',
    }

    def __init__(self, mode=None, root_dir=None, img_height=100, img_width=100):
        # Validate the split before reading anything from disk.
        if mode not in self.ANNOTATION_FILES:
            raise ValueError("Incorrect argument for variable mode!")

        with open(os.path.join(root_dir, 'lexicon.txt'), 'r') as fr:
            lexicon = dict(enumerate(line.strip() for line in tqdm(fr.readlines())))

        paths = []
        texts = []
        with open(os.path.join(root_dir, self.ANNOTATION_FILES[mode]), 'r') as fr:
            for line in tqdm(fr.readlines()):
                parsed = self._parse_annotation_line(line, root_dir, lexicon)
                if parsed is None:  # blank line (e.g. a trailing newline)
                    continue
                paths.append(parsed[0])
                texts.append(parsed[1])

        self.paths = paths
        self.texts = texts
        self.mode = mode
        self.img_height = img_height
        self.img_width = img_width

    @staticmethod
    def _parse_annotation_line(line, root_dir, lexicon):
        """Parse one ``<relative path> <lexicon index>`` line.

        Returns ``(path, text)`` or ``None`` for a blank line. The first two characters of
        the relative path are dropped before joining with ``root_dir`` (presumably a
        ``./`` prefix — confirm against the annotation files).
        """
        stripped = line.strip()
        if not stripped:
            return None
        rel_path, index = stripped.split(' ')
        return os.path.join(root_dir, rel_path[2:]), lexicon[int(index)]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]

        try:
            # `with` closes the underlying file; convert() returns a separate loaded image
            with Image.open(path) as raw_image:
                image = raw_image.convert('L')  # grey-scale
        except IOError:
            print('Corrupted image for %d' % index)
            # Fall back to the next sample, wrapping around so the last index cannot overflow.
            return self[(index + 1) % len(self)]

        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image).reshape((1, self.img_height, self.img_width))
        # map grey levels [0, 255] to [-1, 1]
        image = torch.FloatTensor((image / 127.5) - 1.0)

        if not self.texts:
            return image

        target = [self.CHAR2LABEL[c] for c in self.texts[index]]
        return image, torch.LongTensor(target), torch.LongTensor([len(target)])

def ocr_dataset_collate_fn(batch):
    """Combine (image, target, target_length) samples into one batch.

    Images are stacked along a new batch dimension; the variable-length targets and their
    lengths are concatenated end to end, which is the flat target layout CTCLoss accepts.
    """
    image_seq, target_seq, length_seq = zip(*batch)
    return (
        torch.stack(image_seq, 0),
        torch.cat(target_seq, 0),
        torch.cat(length_seq, 0),
    )

In [8]:
#ctc_decoder.py
from collections import defaultdict

import torch
import numpy as np
from scipy.special import logsumexp  # log(p1 + p2) = logsumexp([log_p1, log_p2])

# log(0.0): used as the "impossible" log-probability by the decoders.
NINF = float('-inf')
# Probability below which a class is skipped at a time step during beam search
# (the decoders compare against its log).
DEFAULT_EMISSION_THRESHOLD = 0.01


def _reconstruct(labels, blank=0):
    new_labels = []
    # merge same labels
    previous = None
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l
    # delete blank
    new_labels = [l for l in new_labels if l != blank]

    return new_labels


def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Best-path CTC decoding: pick the most likely class per frame, then collapse.

    `kwargs` is accepted and ignored so every decoder shares the call signature used by
    `ctc_decode` (which always passes `beam_size`).
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)


def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Naive CTC beam search over frame-level label paths.

    Keeps the `beam_size` most probable frame-label paths at each time step, collapses each
    surviving path with `_reconstruct`, and merges paths that collapse to the same label
    sequence by summing their probabilities.

    Args:
        emission_log_prob: array of shape (length, class_count) of per-frame log-probabilities.
        blank: index of the CTC blank label (removed when collapsing paths).
        beam_size (kwarg, required): number of paths kept after each step.
        emission_threshold (kwarg, optional): log-probability below which a class is skipped
            at a step; defaults to log(DEFAULT_EMISSION_THRESHOLD). If every class of a step
            falls below it, the step's best class is still kept so the beam never empties.

    Returns:
        list of labels of the most probable collapsed sequence.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (frame-label path, accumulated_log_prob)
    for t in range(length):
        # relax the threshold when it would prune every class at this step
        step_threshold = min(emission_threshold, np.max(emission_log_prob[t]))
        new_beams = []
        for prefix, accumulated_log_prob in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                if log_prob < step_threshold:
                    continue
                # log(p1 * p2) = log_p1 + log_p2
                new_beams.append((prefix + [c], accumulated_log_prob + log_prob))

        # keep the most probable paths
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # merge paths that collapse to the same labels: log(p1 + p2) = logsumexp([log_p1, log_p2])
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank=blank))
        total_accu_log_prob[labels] = \
            logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])

    # max() returns the first maximum, matching a stable descending sort's first element
    best_labels, _ = max(total_accu_log_prob.items(), key=lambda item: item[1])
    return list(best_labels)


def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """CTC prefix beam search.

    Beams are collapsed label prefixes (no blanks, repeats merged). Each one tracks two
    log-probabilities: ending in a blank frame and ending in a non-blank frame. Keeping them
    apart lets a repeated label be emitted twice only when a blank separates the occurrences.

    Args:
        emission_log_prob: array of shape (length, class_count) of per-frame log-probabilities.
        blank: index of the CTC blank label.
        beam_size (kwarg, required): number of prefixes kept after each step.
        emission_threshold (kwarg, optional): log-probability below which a class is skipped
            at a step; defaults to log(DEFAULT_EMISSION_THRESHOLD). If every class of a step
            falls below it, the step's best class is still considered so the beam never empties.

    Returns:
        list of labels of the most probable prefix.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [(tuple(), (0, NINF))]  # (prefix, (blank_log_prob, non_blank_log_prob))
    # initial of beams: (empty_str, (log(1.0), log(0.0)))

    for t in range(length):
        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF
        # relax the threshold when it would prune every class at this step
        step_threshold = min(emission_threshold, np.max(emission_log_prob[t]))

        for prefix, (lp_b, lp_nb) in beams:
            end_t = prefix[-1] if prefix else None
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                if log_prob < step_threshold:
                    continue

                # case 1: new_prefix == prefix
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    # a blank frame keeps the prefix and makes it blank-ending
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb
                    )
                    continue
                if c == end_t:
                    # repeating the last label with no blank in between merges into it
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob])
                    )

                # case 2: new_prefix == prefix + (c,)
                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob, lp_nb + log_prob])
                    )
                else:
                    # a repeated label only starts a new symbol after a blank
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob])
                    )

        # sorted by log(blank_prob + non_blank_prob)
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    labels = list(beams[0][0])
    return labels


def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search', beam_size=10):
    """Decode a batch of CTC model outputs into label (or character) sequences.

    Args:
        log_probs: tensor of shape (length, batch, class) with log-probabilities.
        label2char: optional mapping label -> character; when given (and non-empty),
            each decoded label is translated through it.
        blank: index of the CTC blank label.
        method: 'greedy', 'beam_search' or 'prefix_beam_search'.
        beam_size: beam width passed to the decoder (ignored by 'greedy').

    Returns:
        list with one decoded sequence (list of labels or characters) per batch item.

    Raises:
        KeyError: if `method` is not a known decoder.
    """
    # detach() so this also works on tensors that still require grad (outside no_grad)
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    if method not in decoders:
        raise KeyError(f'unknown decode method {method!r}; expected one of {sorted(decoders)}')
    decoder = decoders[method]

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, beam_size=beam_size)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list

In [9]:
#evaluate.py
import torch
from torch.utils.data import DataLoader
from torch.nn import CTCLoss
from tqdm import tqdm

#from dataset import OCR_Dataset, ocr_dataset_collate_fn
#from model import CRNN
#from ctc_decoder import ctc_decode
#from config import evaluate_config as config

# NOTE(review): this disables cuDNN for the whole process (it is a global torch flag, so it
# also affects the training cell once this cell has run). The reason is not recorded here;
# presumably a workaround for a cuDNN problem — confirm and document it, or remove the line
# if it is no longer needed.
torch.backends.cudnn.enabled = False


def evaluate(crnn, dataloader, criterion,
             max_iter=None, decode_method='beam_search', beam_size=10):
    """Compute the loss and exact-match accuracy of `crnn` over (part of) `dataloader`.

    Args:
        crnn: model mapping images to (seq_len, batch, num_class) logits.
        dataloader: yields (images, targets, target_lengths) batches whose targets are the
            label sequences concatenated end to end (as built by ocr_dataset_collate_fn).
        criterion: CTC loss. `loss.item()` is summed per batch and divided by the number of
            samples, so the result is a per-sample mean when it uses reduction='sum'.
        max_iter: stop after this many batches; None or 0 evaluates the whole loader.
        decode_method, beam_size: forwarded to `ctc_decode`.

    Returns:
        dict with 'loss' (per sample), 'acc' (fraction of exactly matching sequences) and
        'wrong_cases' (list of (real_labels, predicted_labels)). 'loss' and 'acc' are NaN
        when no sample was evaluated.
    """
    crnn.eval()
    # The model does not move during evaluation; use its actual device (handles cuda:N).
    device = next(crnn.parameters()).device

    tot_count = 0
    tot_loss = 0.0
    tot_correct = 0
    wrong_cases = []

    pbar_total = max_iter if max_iter else len(dataloader)
    pbar = tqdm(total=pbar_total, desc="Evaluate")

    try:
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                if max_iter and i >= max_iter:
                    break

                images, targets, target_lengths = [d.to(device) for d in data]

                logits = crnn(images)
                log_probs = torch.nn.functional.log_softmax(logits, dim=2)

                batch_size = images.size(0)
                # every sample uses the full output sequence length
                input_lengths = torch.LongTensor([logits.size(0)] * batch_size)

                loss = criterion(log_probs, targets, input_lengths, target_lengths)

                preds = ctc_decode(log_probs, method=decode_method, beam_size=beam_size)
                reals = targets.cpu().numpy().tolist()
                target_lengths = target_lengths.cpu().numpy().tolist()

                tot_count += batch_size
                tot_loss += loss.item()

                # split the concatenated targets back into one label list per sample
                offset = 0
                for pred, target_length in zip(preds, target_lengths):
                    real = reals[offset:offset + target_length]
                    offset += target_length
                    if pred == real:
                        tot_correct += 1
                    else:
                        wrong_cases.append((real, pred))

                pbar.update(1)
    finally:
        pbar.close()

    evaluation = {
        'loss': tot_loss / tot_count if tot_count else float('nan'),
        'acc': tot_correct / tot_count if tot_count else float('nan'),
        'wrong_cases': wrong_cases
    }
    return evaluation



In [10]:
def test():
    """Evaluate the checkpoint at ``eval_config['reload_checkpoint']`` on the 'test' split.

    Prints the device in use and the resulting loss and accuracy.
    """
    # eval_config (from the config cell) already includes every common_config entry.
    eval_batch_size = eval_config['eval_batch_size']
    cpu_workers = eval_config['cpu_workers']
    reload_checkpoint = eval_config['reload_checkpoint']

    img_height = eval_config['img_height']
    img_width = eval_config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    test_dataset = OCR_Dataset(root_dir=eval_config['data_dir'], mode='test',
                               img_height=img_height, img_width=img_width)

    # NOTE(review): num_workers > 0 may fail when OCR_Dataset is defined in a notebook cell
    # and worker processes are spawned (e.g. on Windows) — confirm on the target platform.
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=cpu_workers,
        collate_fn=ocr_dataset_collate_fn)

    num_class = len(OCR_Dataset.LABEL2CHAR) + 1  # +1 for the CTC blank (label 0)
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=eval_config['map_to_seq_hidden'],
                rnn_hidden=eval_config['rnn_hidden'],
                use_leaky_relu=eval_config['use_leaky_relu'])
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    evaluation = evaluate(crnn, test_loader, criterion,
                          decode_method=eval_config['decode_method'],
                          beam_size=eval_config['beam_size'])
    print('test_evaluation: loss={loss}, acc={acc}'.format(**evaluation))


#if __name__ == '__main__':
#    main()

In [16]:
#train.py
import os

import cv2
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

#from config import train_config as config
#from dataset import OCR_Dataset, ocr_dataset_collate_fn
#from model import CRNN
#from evaluate import evaluate

def train_batch(crnn, data, optimizer, criterion, device):
    """Run one optimisation step of `crnn` on a collated batch.

    Args:
        crnn: model returning (seq_len, batch, num_class) logits; switched to train mode.
        data: (images, labels, label_lengths) tensors, moved to `device`.
        optimizer: optimizer over the model parameters.
        criterion: CTC loss taking (log_probs, targets, input_lengths, target_lengths).
        device: device the batch is moved to.

    Returns:
        the batch loss as a Python float.
    """
    crnn.train()
    images, labels, label_lengths = (tensor.to(device) for tensor in data)

    log_probs = torch.nn.functional.log_softmax(crnn(images), dim=2)
    seq_len, batch_size = log_probs.size(0), images.size(0)
    # every sample uses the full output sequence length
    input_lengths = torch.LongTensor([seq_len] * batch_size)

    loss = criterion(log_probs, labels, input_lengths, label_lengths.flatten())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def main():
    """Train the CRNN configured by ``common_config`` / ``train_config``.

    Each epoch makes one pass over the 'train' split. It then saves a checkpoint to
    ``train_config['checkpoints_dir']`` and evaluates on at most ``valid_max_iter``
    batches of the 'val' split.
    """
    epochs = train_config['epochs']
    train_batch_size = train_config['train_batch_size']
    eval_batch_size = train_config['eval_batch_size']
    lr = train_config['lr']
    show_interval = train_config['show_interval']
    reload_checkpoint = train_config['reload_checkpoint']
    valid_max_iter = train_config['valid_max_iter']
    checkpoints_dir = train_config['checkpoints_dir']
    # file-name prefix for saved checkpoints
    prefix = 'crnn'

    img_width = common_config['img_width']
    img_height = common_config['img_height']
    data_dir = common_config['data_dir']

    if torch.cuda.is_available():
        device = torch.device('cuda')
        print("Running on GPU!")
    else:
        device = torch.device('cpu')
        print("Running on CPU!")

    train_dataset = OCR_Dataset(root_dir=data_dir, mode='train',
                                img_height=img_height, img_width=img_width)
    valid_dataset = OCR_Dataset(root_dir=data_dir, mode='val',
                                img_height=img_height, img_width=img_width)

    # NOTE(review): train_config['cpu_workers'] is not passed as num_workers; worker processes
    # may fail to find classes defined in notebook cells when spawned (e.g. on Windows).
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=ocr_dataset_collate_fn)
    # No shuffling for validation: a fixed order keeps the `valid_max_iter` subset
    # identical from epoch to epoch, so the numbers are comparable.
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        collate_fn=ocr_dataset_collate_fn)

    num_class = len(OCR_Dataset.LABEL2CHAR) + 1  # +1 for the CTC blank (label 0)
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=common_config['map_to_seq_hidden'],
                rnn_hidden=common_config['rnn_hidden'],
                use_leaky_relu=common_config['use_leaky_relu'])
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    optimizer = optim.Adam(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    os.makedirs(checkpoints_dir, exist_ok=True)

    step = 0
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in tqdm(train_loader):
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)
            step += 1

            tot_train_loss += loss
            tot_train_count += train_size
            if show_interval and step % show_interval == 0:
                # reduction='sum' -> divide by the batch size for a per-sample value
                print('train_batch_loss[', step, ']: ', loss / train_size)

        # per-sample training loss of this epoch (NaN if the loader yielded nothing)
        train_loss = tot_train_loss / tot_train_count if tot_train_count else float('nan')

        save_model_path = os.path.join(checkpoints_dir,
                                       f'{prefix}_{epoch:06}_loss{train_loss:.4f}.pt')
        torch.save(crnn.state_dict(), save_model_path)
        print('save model at ', save_model_path)

        # Validation uses train_config's decoder settings ('greedy' by default), which is
        # much cheaper than the Python beam search selected in eval_config.
        evaluation = evaluate(crnn, valid_loader, criterion,
                              max_iter=valid_max_iter,
                              decode_method=train_config['decode_method'],
                              beam_size=train_config['beam_size'])

        print('valid_evaluation: loss={loss}, acc={acc}'.format(**evaluation))
        print('epoch ', str(epoch), ': train_loss: ', train_loss)


if __name__ == '__main__':
    main()

100%|██████████| 88172/88172 [00:00<00:00, 658019.59it/s]

Sa
Running on GPU!



100%|██████████| 1000/1000 [00:00<00:00, 32255.69it/s]
100%|██████████| 88172/88172 [00:00<00:00, 419855.15it/s]
100%|██████████| 1000/1000 [00:00<00:00, 90905.83it/s]
  0%|          | 0/32 [00:00<?, ?it/s]

epoch: 1


 28%|██▊       | 9/32 [00:03<00:08,  2.58it/s]
Evaluate:   0%|          | 0/2 [00:00<?, ?it/s][A

train_batch_loss[ 10 ]:  30.772218704223633



 28%|██▊       | 9/32 [00:36<01:34,  4.11s/it]3.89s/it][A


KeyboardInterrupt: 

In [None]:
#predict.py
from docopt import docopt
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from config import common_config as config
from dataset import OCR_Dataset, ocr_dataset_collate_fn
from model import CRNN
from ctc_decoder import ctc_decode


def predict(crnn, dataloader, label2char, decode_method, beam_size):
    """Decode every image batch yielded by `dataloader`.

    Args:
        crnn: model returning (seq_len, batch, num_class) logits; switched to eval mode.
        dataloader: yields image tensors only (no targets).
        label2char: mapping label -> character used to turn decoded labels into text.
        decode_method, beam_size: forwarded to `ctc_decode`.

    Returns:
        one list of characters per image, in loader order.
    """
    crnn.eval()
    progress = tqdm(total=len(dataloader), desc="Predict")

    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            target_device = 'cuda' if next(crnn.parameters()).is_cuda else 'cpu'
            scores = crnn(batch.to(target_device))
            batch_log_probs = torch.nn.functional.log_softmax(scores, dim=2)
            predictions.extend(ctc_decode(batch_log_probs, method=decode_method,
                                          beam_size=beam_size, label2char=label2char))
            progress.update(1)
        progress.close()

    return predictions


def show_result(paths, preds):
    """Print each image path next to its decoded text (a list of characters)."""
    print('\n===== result =====')
    for image_path, chars in zip(paths, preds):
        print(f"{image_path} > {''.join(chars)}")


def main():
    """Command-line entry point: decode the images named on the command line and print the text.

    Arguments are parsed by docopt from this module's ``__doc__``. It needs a usage pattern
    that defines IMAGE, -m (checkpoint), -s (batch size), -d (decode method) and -b (beam size).
    """
    arguments = docopt(__doc__)

    images = arguments['IMAGE']
    reload_checkpoint = arguments['-m']
    batch_size = int(arguments['-s'])
    decode_method = arguments['-d']
    beam_size = int(arguments['-b'])

    img_height = config['img_height']
    img_width = config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    # TODO: jotform dataset.
    # NOTE(review): Synth90kDataset is neither defined nor imported in this file. Only
    # OCR_Dataset is imported, and it takes no `paths` argument, so this raises NameError
    # until a path-based prediction dataset is provided.
    predict_dataset = Synth90kDataset(paths=images,
                                      img_height=img_height, img_width=img_width)

    predict_loader = DataLoader(
        dataset=predict_dataset,
        batch_size=batch_size,
        shuffle=False)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1  # +1 for the CTC blank
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                # both CRNN's keyword and the config key are named `use_leaky_relu`
                use_leaky_relu=config['use_leaky_relu'])
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    preds = predict(crnn, predict_loader, Synth90kDataset.LABEL2CHAR,
                    decode_method=decode_method,
                    beam_size=beam_size)

    show_result(images, preds)


if __name__ == '__main__':
    main()

In [14]:
# Stand-alone training loader for interactive inspection. The image size comes from
# common_config, so it always matches the model instead of being hard-coded.
train_dataset = OCR_Dataset(root_dir=train_config['data_dir'], mode='train',
                            img_height=common_config['img_height'],
                            img_width=common_config['img_width'])

train_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True,
                          collate_fn=ocr_dataset_collate_fn)

100%|██████████| 88172/88172 [00:00<00:00, 722735.56it/s]
100%|██████████| 1000/1000 [00:00<00:00, 142857.77it/s]


In [None]:
# Sanity check: fetch one batch from train_loader. ocr_dataset_collate_fn returns
# (images, targets, target_lengths), so this cell should display 3.
dataloader_iterator = iter(train_loader)
len(next(dataloader_iterator))



In [12]:
# Number of batches train_loader yields per pass over its dataset.
len(train_loader)

10