# Inference for CRNN

Inference runs on prepared image cutouts, not on whole library images.

# Preprocessing

Run this cell once to define the preprocessing helper functions.

In [1]:
import cv2
import numpy as np
from typing import Tuple, Union
import math

# get grayscale image
def get_grayscale(image: np.ndarray) -> np.ndarray:
    """Return a single-channel grayscale version of ``image``.

    Args:
        image: Either a colour image, which is converted with
            ``cv2.COLOR_BGR2GRAY`` (i.e. interpreted as BGR), or an image that
            is already two-dimensional.

    Returns:
        A new two-dimensional array. Two-dimensional input is returned as a
        copy rather than being handed to ``cv2.cvtColor``, so the helper can
        safely be applied to images that are already grayscale.
    """
    if image.ndim == 2:
        # Already single-channel: nothing to convert, but still hand back a
        # fresh array so callers may modify the result freely.
        return image.copy()
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# noise removal - blur
def blur(image: np.ndarray, ksize: int = 5) -> np.ndarray:
    """Remove noise from ``image`` with a median filter.

    Args:
        image: Image to smooth.
        ksize: Aperture size passed to ``cv2.medianBlur``. Must be an odd
            integer greater than 1. Defaults to 5.

    Returns:
        The median-filtered image.

    Raises:
        ValueError: If ``ksize`` is even or smaller than 3.
    """
    if ksize < 3 or ksize % 2 == 0:
        raise ValueError(f'ksize must be an odd integer greater than 1, got {ksize}')
    return cv2.medianBlur(image, ksize)
 
# thresholding
def thresholding(image: np.ndarray, block_size: int = 21, offset: float = 15) -> np.ndarray:
    """Binarize ``image`` with mean adaptive thresholding.

    Args:
        image: Single-channel image to binarize.
        block_size: Side length of the pixel neighbourhood used to compute the
            local threshold. Must be an odd integer greater than 1.
            Defaults to 21.
        offset: Constant handed to ``cv2.adaptiveThreshold`` as its ``C``
            argument. Defaults to 15.

    Returns:
        The result of ``cv2.adaptiveThreshold`` with a maximum value of 255,
        ``cv2.ADAPTIVE_THRESH_MEAN_C`` and ``cv2.THRESH_BINARY``.

    Raises:
        ValueError: If ``block_size`` is even or smaller than 3.
    """
    if block_size < 3 or block_size % 2 == 0:
        raise ValueError(f'block_size must be an odd integer greater than 1, got {block_size}')
    return cv2.adaptiveThreshold(
        image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, block_size, offset
    )

# dilation
def dilate(image: np.ndarray, kernel_size: int = 5, iterations: int = 1) -> np.ndarray:
    """Dilate ``image`` with a square structuring element of ones.

    Args:
        image: Image to dilate.
        kernel_size: Side length of the square kernel. Must be at least 1.
            Defaults to 5.
        iterations: Number of times the dilation is applied. Defaults to 1.

    Returns:
        The dilated image.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1.
    """
    if kernel_size < 1:
        raise ValueError(f'kernel_size must be a positive integer, got {kernel_size}')
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.dilate(image, kernel, iterations=iterations)
    
# erosion
def erode(image: np.ndarray, kernel_size: int = 5, iterations: int = 1) -> np.ndarray:
    """Erode ``image`` with a square structuring element of ones.

    Args:
        image: Image to erode.
        kernel_size: Side length of the square kernel. Must be at least 1.
            Defaults to 5.
        iterations: Number of times the erosion is applied. Defaults to 1.

    Returns:
        The eroded image.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1.
    """
    if kernel_size < 1:
        raise ValueError(f'kernel_size must be a positive integer, got {kernel_size}')
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.erode(image, kernel, iterations=iterations)

# opening - erosion followed by dilation
def opening(image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Apply a morphological opening (``cv2.MORPH_OPEN``) to ``image``.

    Args:
        image: Image to process.
        kernel_size: Side length of the square kernel of ones. Must be at
            least 1. Defaults to 5.

    Returns:
        The opened image.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1.
    """
    if kernel_size < 1:
        raise ValueError(f'kernel_size must be a positive integer, got {kernel_size}')
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel)

# closing - dilation followed by erosion
def closing(image: np.ndarray, kernel_size: int = 1) -> np.ndarray:
    """Apply a morphological closing (``cv2.MORPH_CLOSE``) to ``image``.

    Args:
        image: Image to process.
        kernel_size: Side length of the square kernel of ones. Must be at
            least 1. The default of 1 is kept for backward compatibility;
            with a 1x1 kernel both the dilation and the erosion leave every
            pixel unchanged, so pass a larger value (e.g. 5, as used by the
            other morphology helpers here) to actually close gaps.

    Returns:
        The closed image.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1.
    """
    if kernel_size < 1:
        raise ValueError(f'kernel_size must be a positive integer, got {kernel_size}')
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel)

def rotate(image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]]) -> np.ndarray:
    """Rotate ``image`` by ``angle`` degrees about its centre without cropping.

    The output canvas is enlarged to the bounding box of the rotated image and
    every pixel not covered by the source is filled with ``background``.

    Args:
        image: Image to rotate; only its first two dimensions (rows, columns)
            are used to size the output.
        angle: Rotation angle in degrees, forwarded to
            ``cv2.getRotationMatrix2D``.
        background: Constant border value used for the uncovered area.

    Returns:
        The rotated image on the enlarged canvas.
    """
    src_rows, src_cols = image.shape[:2]
    theta = math.radians(angle)
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)

    # Size of the axis-aligned bounding box around the rotated image.
    dst_rows = abs(sin_t * src_cols) + abs(cos_t * src_rows)
    dst_cols = abs(sin_t * src_rows) + abs(cos_t * src_cols)

    # shape[1::-1] is (columns, rows), i.e. the (x, y) order OpenCV expects.
    centre = tuple(np.array(image.shape[1::-1]) / 2)
    transform = cv2.getRotationMatrix2D(centre, angle, 1.0)

    # Translate so the rotated content sits in the middle of the larger canvas.
    transform[0, 2] += (dst_cols - src_cols) / 2
    transform[1, 2] += (dst_rows - src_rows) / 2

    canvas_size = (int(round(dst_cols)), int(round(dst_rows)))
    return cv2.warpAffine(
        image,
        transform,
        canvas_size,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=background,
    )

# CRNN

## Setup

In [None]:
!git clone https://github.com/GitYCC/crnn-pytorch
!mv crnn-pytorch crnnpytorch
!pip install deskew
%cd crnnpytorch/

In [None]:
import torch

from src.config import common_config as config
from src.model import CRNN
from src.ctc_decoder import ctc_decode
from PIL import Image
import numpy as np
import os
import logging
from google.colab.patches import cv2_imshow
from deskew import determine_skew


# Network input size, read from the repository's common configuration.
img_height, img_width = config['img_height'], config['img_width']

# Decoding options forwarded to ctc_decode by the inference cells.
decode_method = 'greedy'  # (greedy, beam_search or prefix_beam_search) [default: beam_search]
beam_size = 10  # default

# Recognizable alphabet. Labels start at 1; num_class below reserves one
# additional class (presumably index 0 is the CTC blank - confirm against
# ctc_decode).
CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
CHAR2LABEL = {char: label for label, char in enumerate(CHARS, start=1)}
LABEL2CHAR = {label: char for label, char in enumerate(CHARS, start=1)}

# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

num_class = len(LABEL2CHAR) + 1
crnn = CRNN(
    1, img_height, img_width, num_class,
    map_to_seq_hidden=config['map_to_seq_hidden'],
    rnn_hidden=config['rnn_hidden'],
    leaky_relu=config['leaky_relu'],
)
# Weight file, included in the repository.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
crnn.load_state_dict(
    torch.load('checkpoints/crnn_synth90k.pt', map_location=device)
)
crnn.to(device)
crnn.eval()

## Inference on folder

### without deskewing

In [None]:
directory = 'path-to-inference-folder'

# Extra rows added above and below every detected text line before cropping.
borderSize = 0
# Size every cutout is resized to before line segmentation, as handed to
# cv2.resize (the resulting array has shape (400, 200)).
targetSize = (200, 400)
# A row counts as text once its scaled ink count reaches this value.
minRowInk = 10

# One entry per directory entry, in sorted filename order. Each entry is the
# concatenation of the per-line predictions, every one followed by a space.
all_preds = []

for filename in sorted(os.listdir(directory)):
    inputImage = cv2.imread(os.path.join(directory, filename))
    if inputImage is None:
        # cv2.imread yields None for anything it cannot decode (e.g. folders
        # or non-image files). Keep all_preds aligned with the sorted listing
        # by recording an empty prediction instead of crashing in cv2.resize.
        logging.warning('could not read %s as an image, skipping', filename)
        all_preds.append('')
        continue
    inputImage = cv2.resize(inputImage, targetSize)

    grayImage = get_grayscale(inputImage)
    blurredImage = blur(grayImage)
    binarizedImage = thresholding(blurredImage)

    # Turn the 0/255 threshold result into an ink mask: 1 = dark pixel,
    # 0 = background, so that a row sum counts the ink pixels of that row.
    binarizedImage = (binarizedImage == 0).astype(np.uint8)
    horizontal_projection = np.sum(binarizedImage, axis=1)

    height, width = binarizedImage.shape

    # Walk down the rows and record alternating start/end rows of text lines.
    blockFlag = False
    lineBlocks = []
    for row in range(height):
        rowInk = int(horizontal_projection[row] * width / height)
        if not blockFlag and rowInk >= minRowInk:
            # add start line of block
            lineBlocks.append(row)
            blockFlag = True
        elif blockFlag and rowInk < minRowInk:
            # add end line of block (this blank row is excluded by the slice)
            lineBlocks.append(row)
            blockFlag = False
        elif blockFlag and row == height - 1:
            # Text reaches the bottom edge: the slice end is exclusive, so use
            # `height` to keep the last inked row inside the crop.
            lineBlocks.append(height)
            blockFlag = False

    # Recognize every (start, end) pair; zip drops a trailing unmatched start.
    output = ''
    with torch.no_grad():
        for top, bottom in zip(lineBlocks[0::2], lineBlocks[1::2]):
            image = grayImage[max(top - borderSize, 0):bottom + borderSize, 0:width]
            image = cv2.resize(image, (img_width, img_height))
            image = image.reshape((1, img_height, img_width))
            # Scale pixel values from [0, 255] to [-1, 1].
            image = (image / 127.5) - 1.0
            image = torch.from_numpy(image).unsqueeze(0).float().to(device)

            logits = crnn(image)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)
            preds = ctc_decode(log_probs, method=decode_method, beam_size=beam_size,
                               label2char=LABEL2CHAR)
            # NOTE(review): if ctc_decode returns one list of characters per
            # batch item, str(e) embeds the list repr in the output - confirm
            # this is the intended format before changing it.
            output += ''.join(str(e) for e in preds) + ' '

    all_preds.append(output)


### with deskewing

In [None]:
directory = 'path-to-inference-folder'

# Extra rows added above and below every detected text line before cropping.
borderSize = 0
# Size every cutout is resized to before deskewing, as handed to cv2.resize
# (the resulting array has shape (400, 200)).
targetSize = (200, 400)
# A row counts as text once its scaled ink count reaches this value.
minRowInk = 10

# One entry per directory entry, in sorted filename order. Each entry is the
# concatenation of the per-line predictions, every one followed by a space.
all_preds = []

for filename in sorted(os.listdir(directory)):
    inputImage = cv2.imread(os.path.join(directory, filename))
    if inputImage is None:
        # cv2.imread yields None for anything it cannot decode (e.g. folders
        # or non-image files). Keep all_preds aligned with the sorted listing
        # by recording an empty prediction instead of crashing in cv2.resize.
        logging.warning('could not read %s as an image, skipping', filename)
        all_preds.append('')
        continue
    inputImage = cv2.resize(inputImage, targetSize)

    grayImage = get_grayscale(inputImage)

    # Straighten the cutout; determine_skew may return None, in which case
    # the image is left untouched.
    angle = determine_skew(grayImage)
    if angle is not None:
        if abs(angle) > 45:
            # Fold the estimate into [-45, 45] by a quarter turn towards zero.
            # Adding 90 to a large positive angle (as was done before) lands
            # near 180 degrees and turns the text upside down.
            angle = angle - 90 if angle > 0 else angle + 90
        grayImage = rotate(grayImage, angle, (255, 255, 255))

    blurredImage = blur(grayImage)
    binarizedImage = thresholding(blurredImage)

    # Turn the 0/255 threshold result into an ink mask: 1 = dark pixel,
    # 0 = background, so that a row sum counts the ink pixels of that row.
    binarizedImage = (binarizedImage == 0).astype(np.uint8)
    horizontal_projection = np.sum(binarizedImage, axis=1)

    # Taken from the (possibly rotated and therefore enlarged) image.
    height, width = binarizedImage.shape

    # Walk down the rows and record alternating start/end rows of text lines.
    blockFlag = False
    lineBlocks = []
    for row in range(height):
        rowInk = int(horizontal_projection[row] * width / height)
        if not blockFlag and rowInk >= minRowInk:
            # add start line of block
            lineBlocks.append(row)
            blockFlag = True
        elif blockFlag and rowInk < minRowInk:
            # add end line of block (this blank row is excluded by the slice)
            lineBlocks.append(row)
            blockFlag = False
        elif blockFlag and row == height - 1:
            # Text reaches the bottom edge: the slice end is exclusive, so use
            # `height` to keep the last inked row inside the crop.
            lineBlocks.append(height)
            blockFlag = False

    # Recognize every (start, end) pair; zip drops a trailing unmatched start.
    output = ''
    with torch.no_grad():
        for top, bottom in zip(lineBlocks[0::2], lineBlocks[1::2]):
            image = grayImage[max(top - borderSize, 0):bottom + borderSize, 0:width]
            image = cv2.resize(image, (img_width, img_height))
            image = image.reshape((1, img_height, img_width))
            # Scale pixel values from [0, 255] to [-1, 1].
            image = (image / 127.5) - 1.0
            image = torch.from_numpy(image).unsqueeze(0).float().to(device)

            logits = crnn(image)
            log_probs = torch.nn.functional.log_softmax(logits, dim=2)
            preds = ctc_decode(log_probs, method=decode_method, beam_size=beam_size,
                               label2char=LABEL2CHAR)
            # NOTE(review): if ctc_decode returns one list of characters per
            # batch item, str(e) embeds the list repr in the output - confirm
            # this is the intended format before changing it.
            output += ''.join(str(e) for e in preds) + ' '

    all_preds.append(output)