<font face = 'Arial'  size = '6'> 1. Import Libraries </font>

In [1]:
import ultralytics
ultralytics.checks()

Ultralytics YOLOv8.2.92  Python-3.12.4 torch-2.3.0 CPU (AMD Ryzen 5 5500U with Radeon Graphics)
Setup complete  (12 CPUs, 19.4 GB RAM, 26.9/100.1 GB disk)


In [None]:
import os
import numpy as np
import timm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms
from PIL import Image

<font face = 'Arial' size = '6'> 2. Load Model </font>

<font face = 'Arial' size = '4'> 2.1 Loade YOLO </face>

In [None]:
from ultralytics import YOLO

text_det_model_path = ' '
yolo = YOLO(text_det_model_path)

<font face = 'Arial' size = '4'> 2.2 Load CRNN </font>

In [None]:
chars = '0123456789abcdefghijklmnopqrstuvwxyz-'
vocab_size = len(chars)
char_to_idx = {char : idx +1 for idx, char in  enumerate(sorted(chars))}
idx_to_char = {idx : char for char, idx in char_to_idx.items()}

In [2]:

class CRNN(nn.Module):
    def __init__(
            self,
            vocab_size,
            hidden_size,
            n_layers,
            dropout = 0.2,
            unfreeze_layers = 3
    ):
        super(CRNN, self).__init__()

        # initialize pretrained model resnet101
        backbone = timm.create_model(
            'resnet101',
            in_chans = 1,
            pretrained = True
        )

        # remove the original pretrained classification class
        modules = list(backbone.children())[:2]
        # add the adaptiveAvgPool2d class
        modules.append(nn.AdaptiveAvgPool2d((1, None)))
        self.backbone = nn.Sequential(*modules)

        # unfreeze some final layer of the pretrained model
        for parameter in self.backbone[-unfreeze_layers : ].parameters():
            parameter.requires_grad = True
        
        # layer map from CNN features maps to LSTM
        self.mapSeq = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.lstm = nn.LSTM(
            1024, hidden_size,
            n_layers, bidirectional = True, batch_first = True,
            droppout = dropout if n_layers > 1 else 0
        )

        self.layer_norm = nn.LayerNorm(hidden_size * 2 )
        self.out = nn.Sequential(
            nn.Linear(hidden_size * 2 , vocab_size),
            nn.LogSoftmax(dim = 2)
        )
    
    def forward(self, x):
        x = self.backbone(x) # shape (bs , channels, height, width)
        x = x.permute(0, 3, 1, 2) # shape ( bs, w , c, h)
        x = x.view(x.size(0), x.size(1), -1)
        x = self.mapSeq(x)
        x,_ = self.lstm(x)
        x = self.layer_norm(x)
        x = self.out(x)
        x = x.permute(1, 0, 2)

        return x
    
hidden_size = 256
n_layers = 2
dropout_prob = 0.3
unfreeze_layers=3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = 'models/ocr_crnn_resnet_best.pt'

crnn_model = CRNN(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    n_layers=n_layers,
    dropout=dropout_prob,
    unfreeze_layers=unfreeze_layers
).to(device)
crnn_model.load_state_dict(torch.load(model_path))


SyntaxError: expected ':' (4270271084.py, line 2)

<font face = 'Arial' size = '6'> 3. Inference</font>

In [None]:
def decode(encoded_sequences, idx_to_char, blank_char = '-'):
    """
    Decode encoded_sequences to string
    parameters:
       encoded_sequences (list) : The lists tensor label
       idx_to_char (dict) : mapping ID -> classname
       blank_char(str) : '-'

    Return:
          decoded_sequences (list) : The List of decoded_labels.
    
    """

    # Declare empty list to contains result decoded
    decoded_sequences = []

    for seq in encoded_sequences:
        
        decoded_label = []

        for idx, token in enumerate(seq) :
            if token != 0: 
                char = idx_to_char[token.item()]
                if char != blank_char:
                    decoded_label.append(char)
        
        decoded_sequences.append(''.join(decoded_label))
    
    return decoded_sequences

In [None]:
def text_detection(img_path, text_det_model):
    """
    Locate (bbox) the text in the image

    Parameters: 
         img_path (str) : path to file image
         text_det_model (YOLO) : Model YOLO text detection
    
    Returns:
         tuple: Includes identified components(bboxes, classes, names, confs)
    
    """
    # Perform detection according to YOLO
    text_det_results = text_det_model(img_path, verbose = False)[0]

    bboxes = text_det_results.boxes.xyxy.tolist()
    # get classes, confidence scores
    classes = text_det_results.boxes.cls.tolist()
    names = text_det_results.names
    confs = text_det_results.boxes.conf.tolist()

    return bboxes, classes, names, confs

In [None]:
def text_recognition(img, data_transforms, text_reg_model, idx_to_char, device):
    """
    Recognition text in image
    Parameters:
        img(PIL.image) : image Object
        data_transforms (transforms.Compose) : Preprocessing.
        text_reg_model (CRNN) : Model CRNN text recognition.
        idx_to_char (dict) :  mapping ID -> classname
    
        Returns :
            text(str) : output text
        """
    
    transformed_image = data_transforms(img)
    transformed_image = transformed_image.unsqueeze(0).to(device)
    text_reg_model.eval()
    with torch.no_grad():
        logits = text_reg_model(transformed_image).detach().cpu()
    text = decode(logits.permute(1, 0, 2).argmax(2), idx_to_char)

    return text

In [None]:
def visualize_detections(img, detections):
    """
    Visualize result Scene Text Recognition (STR) 

    Parameters :
        img (PIL.Image) : image Object
        detections (list) : The lists contains result STR on the image 
    """
    plt.figure(figsize = (12, 8))
    plt.imshow(img)
    plt.axis('off')

    for bbox, detected_class, confidence, transcribed_text in detections : 
        x1, y1, x2, y2 = bbox
        plt.gca().add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill = False, edgecolor = 'red', linewidth = 2))
        plt.text(
            x1, y1-10, f"{detected_class} ({confidence:.2f}) : {transcribed_text}",
            fontsize = 9, bbox = dict(facecolor = 'red', alpha = 0.5)
        )
    
    plt.show()


In [None]:
data_transforms = {
    'train' : transforms.Compose([
        transforms.Resize((100, 420)),
        transforms.ColorJitter(brightness = 0.5, constrast = 0.5, saturation = 0.5),
        transforms.Graysscale(num_output_channels = 1 ),
        transforms.GaussianBlur(3),
        transforms.RandomAffine(degress = 1, sehar = 1),
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.3, interpolation = 3),
        transforms.RandomRotation(degrees = 2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
    'val'  : transforms.Compose([
        transforms.Resize((100,420)),
        transforms.Grayscale(num_outout_channels = 1),
        transforms.ToTnesor(),
        transforms.Normalize((0.5,), (0.5)),
    ])
}
def predict(img_path, data_transforms, text_det_model, text_reg_model,
            idx_to_char, device, visualize = True):
    """
    Scene Text Recognition with any image
    parameters:
       img_path (str) : path to image
       data_transfroms (trainsforms.compose) : function preprocessing image
       text_det_model ( YOLO) : model YOLO text detection
       text_reg_model (CRNN) : model CRNN text recognition
       idx_to_char ( dict) : mapping idx -> classname
       device(str) : 'cpu' or 'gpu'
       visualize (bool) : visualization result STR
    Returns : 
    predictions (list): The list results STR on image
    """
    bboxes, classes, names, confs = text_detection(img_path, text_det_model)

    img = Image.open(img_path)

    predictions = []

    for bbox, cls, conf in zip(bboxes, classes, confs):
        x1, y1, x2, y2 = bbox
        confidence = conf
        detected_class = cls
        name = name[int(cls)] 

        cropped_image = img.crop((x1, y1, x2, y2))

        transcribed_text = text_recognition(
            cropped_image,
            data_transforms,
            text_reg_model,
            idx_to_char,
            device
        )

        predictions.append(( bbox, name, confidence, transcribed_text))

    if visualize:
        visualize_detections(img, predictions)
    
    return predictions
        

In [None]:
img_dir = ''
inf_transforms = data_transforms['val']
for img_filename in os.listdir(img_dir):
    img_path = os.path.join(img_dir, img_filename)
    predictions = predict(
        img_path, 
        data_transforms=inf_transforms, 
        text_det_model=yolo, 
        text_reg_model=crnn_model, 
        idx_to_char=idx_to_char,
        device=device,
        visualize=True
    )