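## Detector: CRAFT text detection

Load the pretrained CRAFT model and run it on every image in `./data/`. For each image, the text/link score heatmap and the detected regions are written to `./result/`.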

In [39]:
import torch
import cv2
import numpy as np
import os
from craft_text_detector import craft_utils, file_utils
from craft_text_detector import imgproc
from craft_text_detector.craft import CRAFT
from torch.autograd import Variable

In [36]:
from collections import OrderedDict

In [37]:
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` key prefix removed.

    If the first key starts with ``module.``, the first dotted component is
    dropped from every key. This is typically the case for checkpoints saved
    from a wrapped model, where ``module.basenet.x`` becomes ``basenet.x``.
    Otherwise the keys are copied unchanged. Insertion order is preserved.

    Args:
        state_dict: mapping of parameter name -> tensor (any values are accepted).

    Returns:
        A new ``OrderedDict``. An empty input yields an empty ``OrderedDict``
        instead of raising ``IndexError``.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        return new_state_dict
    # Decide from the first key only, as all keys of one checkpoint share the prefix.
    # Match "module." (with the dot) so keys like "module_list.0.w" are left intact.
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module.") else 0
    for k, v in state_dict.items():
        new_state_dict[".".join(k.split(".")[start_idx:])] = v
    return new_state_dict

In [None]:
def test_net(net, image, text_threshold, link_threshold, low_text, poly, refine_net=None):
    """Run CRAFT on a single image and return the detected text regions.

    Reads the module-level ``canvas_size`` and ``mag_ratio`` settings, which
    must be defined (in the configuration cell) before this is called.

    Args:
        net: CRAFT network. Its output ``y`` is indexed as ``y[0, :, :, 0]``
            (text score map) and ``y[0, :, :, 1]`` (link score map). The input
            tensor is moved to the device of ``net``'s parameters.
        image: image array as returned by ``imgproc.loadImage``
            (presumably (h, w, c); TODO confirm).
        text_threshold, link_threshold, low_text: forwarded unchanged to
            ``craft_utils.getDetBoxes``.
        poly: forwarded unchanged to ``craft_utils.getDetBoxes``.
        refine_net: optional refinement network called as
            ``refine_net(y, feature)``. When given, channel 0 of its output
            replaces the link score map.

    Returns:
        ``(boxes, polys, ret_score_text)``. ``boxes`` and ``polys`` are rescaled
        by ``craft_utils.adjustResultCoordinates``; any ``None`` polygon is
        replaced by its box. ``ret_score_text`` is a heatmap image of the text
        and link score maps placed side by side.
    """
    # Resize onto the canvas; 1 / target_ratio maps results back to the input size.
    img_resized, target_ratio, _size_heatmap = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    # Preprocess: normalize, [h, w, c] -> [c, h, w], add the batch dimension, and
    # place the tensor on the network's device. Without this, a model moved to the
    # GPU fails with a device-mismatch error. (torch.autograd.Variable is a
    # deprecated no-op wrapper, so it is no longer used.)
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)
    x = x.to(next(net.parameters()).device)

    # Forward pass.
    with torch.no_grad():
        y, feature = net(x)

    # Text score map and link score map.
    score_text = y[0, :, :, 0].detach().cpu().numpy()
    score_link = y[0, :, :, 1].detach().cpu().numpy()

    # Optional link refinement.
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].detach().cpu().numpy()

    # Post-processing: threshold the score maps into boxes / polygons.
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    # Map coordinates back to the original image; fall back to the box when no polygon.
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k, p in enumerate(polys):
        if p is None:
            polys[k] = boxes[k]

    # Render the score maps side by side (np.hstack already returns a new array).
    ret_score_text = imgproc.cvt2HeatmapImg(np.hstack((score_text, score_link)))

    return boxes, polys, ret_score_text

In [None]:
# Input images and the pretrained CRAFT checkpoint.
image_path = './data/'
model_path = './models/craft_ic15_20k.pth'

# Detection thresholds; passed to test_net in the detection loop.
text_threshold = 0.7
link_threshold = 0.4
low_text = 0.4

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: {}".format(device))

Using device: cpu




Loading CRAFT model from: ./models/craft_ic15_20k.pth
CRAFT model loaded successfully.


In [42]:
# canvas_size and mag_ratio are read as globals inside test_net (for
# imgproc.resize_aspect_ratio); poly is passed to test_net explicitly.
canvas_size = 1280   # canvas size handed to imgproc.resize_aspect_ratio
mag_ratio = 1.5      # magnification ratio handed to imgproc.resize_aspect_ratio
poly = False         # forwarded to craft_utils.getDetBoxes via test_net

In [43]:
# Collect the input images and make sure the output folder exists.
image_list, _, _ = file_utils.get_files(image_path)
if len(image_list) == 0:
    print("No images found in " + image_path)

result_folder = './result/'
# makedirs(exist_ok=True) replaces the isdir/mkdir pair: it avoids the race
# between check and create, and also creates missing parent directories.
os.makedirs(result_folder, exist_ok=True)

In [None]:
net = CRAFT()  # build the network; weights are loaded from the checkpoint below

print('Loading weights from checkpoint (' + model_path + ')')
# copyStateDict strips a leading "module" key component so checkpoint keys match CRAFT.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): `device` only controls where the checkpoint tensors are loaded;
# `net` is not moved to `device` in this cell.
net.load_state_dict(copyStateDict(torch.load(model_path, map_location=device)))
net.eval()

# Run detection on every image and save its score heatmap and detected regions.
# The loop variable is `img_file`, not `image_path`. Reusing that name would
# overwrite the data-directory setting and break re-running the file-listing cell.
for k, img_file in enumerate(image_list):
    # Newline-terminated progress: with '\r', a shorter path left stale
    # characters of the previous, longer one on screen.
    print("Test image {:d}/{:d}: {:s}".format(k + 1, len(image_list), img_file))
    image = imgproc.loadImage(img_file)

    bboxes, polys, score_text = test_net(net, image, text_threshold, link_threshold, low_text, poly)

    # Save the score heatmap; cv2.imwrite reports failure only via its return value.
    filename = os.path.splitext(os.path.basename(img_file))[0]
    mask_file = os.path.join(result_folder, "res_" + filename + "_mask.jpg")
    if not cv2.imwrite(mask_file, score_text):
        print("Warning: could not write " + mask_file)

    file_utils.saveResult(img_file, image[:, :, ::-1], polys, dirname=result_folder)



Loading weights from checkpoint (./models/craft_ic15_20k.pth)
Test image 17/17: ./data/test18.jpgg