# Text Detection 

In [7]:
import os
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from PIL import Image
import cv2
from skimage import io
import numpy as np
from detection_utils import craft_utils
from detection_utils import imgproc
from detection_utils import file_utils
import json
import zipfile
from detection_utils.craft import CRAFT
from collections import OrderedDict

In [9]:
# Settings for the CRAFT text-detection run below.
config = dict(
    # Checkpoint loaded into the CRAFT network.
    trained_model="ai_models/craft_mlt_25k.pth",
    # Thresholds forwarded to craft_utils.getDetBoxes.
    text_threshold=0.7,
    low_text=0.4,
    link_threshold=0.4,
    # Move the model and the input tensor to the GPU when True.
    cuda=False,
    # Resize parameters forwarded to imgproc.resize_aspect_ratio.
    canvas_size=1280,
    mag_ratio=1.5,
    # Forwarded to craft_utils.getDetBoxes as its `poly` argument.
    poly=False,
    # Print inference / post-processing timings per image when True.
    show_time=False,
    # Folder scanned for input images.
    test_folder="./input_frames/",
    # Refiner settings; not read by the detection code in this file.
    refine=False,
    refiner_model="weights/craft_refiner_CTW1500.pth",
)


def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` without a leading ``module`` key component.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` typically
    prefix every parameter name with ``module.``; this strips that prefix so
    the weights can be loaded into an unwrapped model.

    The decision is made from the first key only and then applied to every
    key, so a state dict is expected to be either fully prefixed or not at all.

    Args:
        state_dict: Mapping of parameter name to value. May be empty.

    Returns:
        A new ``OrderedDict`` with the same values in the same order and the
        (possibly shortened) keys. The input mapping is not modified.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # Nothing to inspect; indexing the first key would raise IndexError.
        return new_state_dict

    first_key = next(iter(state_dict))
    # Compare the whole first component so keys such as "modules_list.0.weight"
    # are not mistaken for a DataParallel prefix.
    start_idx = 1 if first_key.split(".")[0] == "module" else 0

    for k, v in state_dict.items():
        new_state_dict[".".join(k.split(".")[start_idx:])] = v
    return new_state_dict

##---------------------- inference Function ----------------------

def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    """Run the detector on one image and return boxes, polygons and a heatmap.

    The image is resized with ``config["canvas_size"]`` / ``config["mag_ratio"]``,
    normalized, passed through ``net`` (and ``refine_net`` when given), and the
    resulting score maps are turned into boxes by ``craft_utils.getDetBoxes``.
    Coordinates are scaled back by the inverse of the resize ratio.

    Args:
        net: Detection model; called as ``net(x)`` and expected to return
            ``(y, feature)``.
        image: Input image array (assumed H x W x C, since it is permuted to
            channel-first before inference -- TODO confirm with imgproc).
        text_threshold: Forwarded to ``craft_utils.getDetBoxes``.
        link_threshold: Forwarded to ``craft_utils.getDetBoxes``.
        low_text: Forwarded to ``craft_utils.getDetBoxes``.
        cuda: When true, the input tensor is moved to the GPU.
        poly: Forwarded to ``craft_utils.getDetBoxes``.
        refine_net: Optional refiner; when given, its first output channel
            replaces the link score map.

    Returns:
        ``(boxes, polys, ret_score_text)`` where ``polys`` has every ``None``
        entry replaced by the corresponding box and ``ret_score_text`` is the
        heatmap image of the text and link score maps placed side by side.
    """
    infer_start = time.time()

    # The third return value (heatmap size) is not needed here.
    img_resized, target_ratio, _ = imgproc.resize_aspect_ratio(
        image, config["canvas_size"], interpolation=cv2.INTER_LINEAR, mag_ratio=config["mag_ratio"]
    )
    ratio_h = ratio_w = 1 / target_ratio

    # Normalize, reorder to channel-first and add a batch dimension.
    # torch.autograd.Variable is deprecated; plain tensors are used directly.
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)
    if cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)

    # Channel 0 is used as the text score map, channel 1 as the link score map.
    score_text = y[0, :, :, 0].detach().cpu().numpy()
    score_link = y[0, :, :, 1].detach().cpu().numpy()

    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].detach().cpu().numpy()

    infer_time = time.time() - infer_start
    post_start = time.time()

    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    # Fall back to the box wherever no polygon is available.
    for idx, candidate in enumerate(polys):
        if candidate is None:
            polys[idx] = boxes[idx]

    post_time = time.time() - post_start

    render_img = np.hstack((score_text.copy(), score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    if config["show_time"]:
        print("\ninfer/postproc time : {:.3f}/{:.3f}".format(infer_time, post_time))

    return boxes, polys, ret_score_text

# Detection script: load CRAFT, detect text in every image of the test folder
# and save each detected box as a cropped image under result/crops.
if __name__ == '__main__':
    result_folder = './result/'
    os.makedirs(result_folder, exist_ok=True)

    # The crop folder is the same for every image, so create it once up front.
    crop_folder = os.path.join(result_folder, "crops")
    os.makedirs(crop_folder, exist_ok=True) ##--- result/crops

    image_list, _, _ = file_utils.get_files(config["test_folder"])

    net = CRAFT()
    print('Loading weights from ' + config["trained_model"] )
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    # On CPU-only runs remap the stored tensors to the CPU; None keeps the default.
    map_location = None if config["cuda"] else 'cpu'
    net.load_state_dict(copyStateDict(torch.load(config["trained_model"], map_location=map_location)))

    if config["cuda"]:
        net = net.cuda()
        net = nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()
    refine_net = None
    t = time.time()
    for image_path in image_list:

        image = imgproc.loadImage(image_path) #-- load image
        ##--- run detection
        boxes, polys, score_text = test_net(
            net, image,
            config["text_threshold"], config["link_threshold"], config["low_text"],
            config["cuda"], config["poly"], refine_net,
        )

        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        img_h, img_w = image.shape[:2]

        ##---- Save cropped regions
        for i, box in enumerate(boxes):
            rect = np.array(box).astype(np.int32)
            x, y, w, h = cv2.boundingRect(rect)

            # A detected box may extend past the image border. Clamp it:
            # a negative slice start would wrap around and give an empty or
            # wrong crop, and cv2.imwrite fails on an empty image.
            x0, y0 = max(x, 0), max(y, 0)
            x1, y1 = min(x + w, img_w), min(y + h, img_h)
            if x1 <= x0 or y1 <= y0:
                continue  # box lies entirely outside the image
            crop = image[y0:y1, x0:x1]

            # NOTE(review): cv2.imwrite expects BGR channel order; if
            # imgproc.loadImage returns RGB, the saved crops have swapped
            # colour channels -- confirm against imgproc.
            crop_file = os.path.join(crop_folder, f"{filename}_box_{i+1}.jpg")
            cv2.imwrite(crop_file, crop)

    print("\nText detection completed in : {:.2f}s".format(time.time() - t))

Loading weights from ai_models/craft_mlt_25k.pth

Text detection completed in : 2.63s


# Text recognition

In [3]:
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import string
from recognition_utils.utils import CTCLabelConverter, AttnLabelConverter
from recognition_utils.dataset import RawDataset, AlignCollate
from recognition_utils.model import Model

# Run recognition on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
class Opt:
    """Settings for the text-recognition model and the demo run."""

    # ---- Input crops and checkpoint
    image_folder = 'result/crops'
    saved_model = 'ai_models/best_accuracy.pth'

    # ---- Model architecture (Prediction may also be 'Attn')
    Transformation = 'TPS'
    FeatureExtraction = 'ResNet'
    SequenceModeling = 'BiLSTM'
    Prediction = 'CTC'

    # ---- Input / output settings
    imgH, imgW = 32, 100
    rgb = sensitive = PAD = False
    # Digits followed by lowercase ASCII letters.
    character = string.digits + string.ascii_lowercase

    # ---- Training / inference settings
    batch_max_length = 25
    num_fiducial = 20
    input_channel = 1
    output_channel = 512
    hidden_size = 256
    batch_size = 192
    workers = 0

    # ---- Device settings
    num_gpu = torch.cuda.device_count()

# Settings instance passed to demo(); demo() also stores derived fields
# (num_class, and input_channel when rgb is set) on it.
opt = Opt()

# NOTE(review): benchmark=True asks cuDNN to auto-tune its algorithms, while
# deterministic=True asks for deterministic ones -- confirm that enabling
# both together is intended.
cudnn.benchmark = True
cudnn.deterministic = True

##------------------- recognition block -------------------

def demo(opt):
    """Recognize the text in every image of ``opt.image_folder`` and report it.

    Builds the label converter and the model described by ``opt``, loads the
    weights from ``opt.saved_model``, runs inference batch by batch, and for
    each image prints and appends to ``./result/recognition_result.txt`` the
    image path, the predicted label and a confidence score (the product of the
    per-step maximum softmax probabilities).

    Args:
        opt: Settings object. ``opt.num_class`` is set here from the
            converter's character list, and ``opt.input_channel`` is set to 3
            when ``opt.rgb`` is true.

    Returns:
        None. Results are printed and appended to the log file; a header is
        written once per batch.
    """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = Model(opt)
    # The model is wrapped in DataParallel before the checkpoint is loaded.
    model = torch.nn.DataParallel(model).to(device)

    print('Loading model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers,
        collate_fn=AlignCollate_demo, pin_memory=True)

    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # The context manager closes the log even if a row fails to format.
            with open('./result/recognition_result.txt', 'a') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        # find() returns -1 when there is no end-of-sequence
                        # token; slicing with -1 would drop the last character,
                        # so only truncate when the token is present.
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # An empty prediction has no probabilities to multiply;
                    # report zero confidence instead of raising IndexError.
                    if pred_max_prob.numel() > 0:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    else:
                        confidence_score = 0.0

                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')


# Run recognition with the module-level settings when executed as a script.
if __name__ == '__main__':
    demo(opt)


Loading model from ai_models/best_accuracy.pth
--------------------------------------------------------------------------------
image_path               	predicted_labels         	confidence score
--------------------------------------------------------------------------------
result/crops\image_validation_94_box_1.jpg	053660                   	0.9693
result/crops\image_validation_939_box_1.jpg	479643                   	0.4904
result/crops\image_validation_941_box_1.jpg	480149                   	0.4348
result/crops\image_validation_943_box_1.jpg	481096                   	0.9263
result/crops\image_validation_944_box_1.jpg	481867                   	0.8861
result/crops\image_validation_945_box_1.jpg	482198                   	0.9726
result/crops\image_validation_949_box_1.jpg	483540                   	0.6824
result/crops\image_validation_988_box_1.jpg	501703                   	0.9941
result/crops\image_validation_991_box_1.jpg	502942                   	0.9896
result/crops\image_validation_