Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch img inference support for ocr det with readtext_batched #458

Merged
merged 3 commits into from
Jun 24, 2021
Merged

Add batch img inference support for ocr det with readtext_batched #458

merged 3 commits into from
Jun 24, 2021

Conversation

SamSamhuns
Copy link
Contributor

@SamSamhuns SamSamhuns commented Jun 12, 2021

Batched image inference for text detection

reader = easyocr.Reader(['en'], cudnn_benchmark=True)
img_path = [
        "https://pytorch.org/tutorials/_static/img/thumbnails/cropped/profiler.png",
        "https://www.tensorflow.org/images/tf_logo_social.png",
        "https://storage.googleapis.com/gd-wagtail-prod-assets/original_images/evolving_google_identity_2x1.jpg"]
reader.readtext_batched(img_path, n_width=800, n_height=600)

Caveats:

  1. For batched inference, all input images must be of the same size. They can be resized before or the n_width and n_height parameters can be used in readtext_batched. readtext_batched can also take a single image as input but returns a result list with one element, i.e. a further result[0] access will be required.
  2. cudnn.benchmark mode set to True is better for batched inference hence I pass cudnn_benchmark=True in easyocr.Reader
  3. GPU batched inference needs some warmup for the same batch size to see better performance on batched mode hence I use dummy = np.zeros([batch_size, 600, 800, 3], dtype=np.uint8); reader.readtext_batched(dummy)before timing the inferences.
  4. Batched inference mode should be used when a large number of frames are needed to be processed i.e. detecting and recognizing text in a video, otherwise, sequential processing will be faster for processing one image per API call.
  5. When running on GPU mode, the user will have to take care of the batch size themselves to prevent cuda out of memory error

Edited files

These changes although major should have no backward compatibility issues, but I would greatly appreciate extensive testing @rkcosmos . I am open to any suggestions or changes

utils.py

  • Added a new functionreformat_input_batched to take a list of file paths, numpy ndarrays, or byte stream objects

detection.py

  • Changed the get_textbox function to process a list of lists of bboxes and polys
  • Changed the test_net functions to accumulate the input image and send all the inputs in a single tensor to the CRAFT torch model

easyocr.py

  • Added a new function readtext_batched to take a list of file paths, numpy ndarrays, or byte stream objects now to process them in batch.
  • Change the detect function to process a list of images

I have a test script here to verify the functions are working as intended and added results for both CPU and GPU

As expected GPU batched inference is almost twice as fast as sequential GPU inference.

GPU results

gpu

CPU results

cpu

test_batch_easyocr.py program to generate the outputs above

from __future__ import print_function

import easyocr
import numpy as np
import time
import cv2
import sys
import os

if sys.version_info[0] == 2:
    from six.moves.urllib.request import urlretrieve
else:
    from urllib.request import urlretrieve


def test_single_and_batched_text_detection_and_prediction():
    reader = easyocr.Reader(['en'])
    # test with easy logos to ensure same results
    # test for single image with old api
    result = reader.readtext(
        "https://pytorch.org/tutorials/_static/img/thumbnails/cropped/profiler.png")
    assert len(result), 1
    assert result[0][1], 'PyTorch'
    print(result)
    print("Single image test with readtext successful")

    # test for single image with new api
    result = reader.readtext_batched(
        "https://pytorch.org/tutorials/_static/img/thumbnails/cropped/profiler.png")
    assert len(result), 1
    assert result[0][0][1], 'PyTorch'
    print(result)
    print("Single image test with readtext_batched successful")

    # test for a list of images in batch
    img_path = [
        "https://pytorch.org/tutorials/_static/img/thumbnails/cropped/profiler.png",
        "https://www.tensorflow.org/images/tf_logo_social.png",
        "https://storage.googleapis.com/gd-wagtail-prod-assets/original_images/evolving_google_identity_2x1.jpg"]

    """
    all images in image list must be of the same size for batched inference
        for eg, result = reader.readtext_batched(img_path) will fail here
        so either resize all images to the same size before passing to readtext_batched
        or call the func like so reader.readtext_batched(img_path, n_width=800, n_height=600)
    """
    # warning, for better results, it is recommended to maintain aspect while resizing
    result = reader.readtext_batched(img_path, n_width=800, n_height=600)
    assert len(result), 3
    assert result[0][0][1], 'PyTorch'
    assert result[1][0][1], 'TensorFlow'
    assert result[2][0][1], 'Google'
    print(result)
    print("Batched image test with readtext_batched successful")

    ############################################################################
    # inference time test between sequential and batch processing
    # batch processing will be faster when using GPU
    ############################################################################
    # pre-download, load and resize images for inference time test
    img_path = [
        "https://pytorch.org/tutorials/_static/img/thumbnails/cropped/profiler.png",
        "https://www.tensorflow.org/images/tf_logo_social.png",
        "https://storage.googleapis.com/gd-wagtail-prod-assets/original_images/evolving_google_identity_2x1.jpg"]

    cv2_images = []
    for i, path in enumerate(img_path):
        tmp, _ = urlretrieve(path)
        cv2_img = cv2.resize(cv2.imread(tmp), (800, 600))
        cv2_images.append(cv2_img)
        os.remove(tmp)

    img_repeat, num_loop = 5, 1
    cv2_images = np.array(cv2_images)
    # np repeat to get a batch of 15 images, getting arr 15,600,800,3
    cv2_images_repeat1 = np.repeat(cv2_images, repeats=img_repeat, axis=0)
    cv2_images_repeat2 = cv2_images_repeat1.copy()
    print(
        f"Running inference speed test with an image array of shape {cv2_images_repeat1.shape} for {num_loop} iterations")

    # sequential processing
    # run batch processing test
    reader = easyocr.Reader(['en'])
    itime = time.time()
    for i in range(num_loop):
        for img in cv2_images_repeat1:
            reader.readtext(img)
    print(
        "Single/Sequential image inference time per image: " +
        f"{(time.time()-itime)/(num_loop*cv2_images_repeat1.shape[0]):.3f}s")
    # batched processing
    reader = easyocr.Reader(['en'], cudnn_benchmark=True)

    # warmup for batched inference on GPU, using same batch size for all subsequent inference
    # cudnn benchmark should be set to True
    # see this issue https://discuss.pytorch.org/t/model-inference-very-slow-when-batch-size-changes-for-the-first-time/44911
    dummy = np.zeros([len(img_path) * img_repeat, 600, 800, 3], dtype=np.uint8)
    reader.readtext_batched(dummy)

    # run batch processing test
    itime = time.time()
    for i in range(num_loop):
        reader.readtext_batched(cv2_images_repeat2)
    print(
        "Batched image inference time per image: " +
        f"{(time.time()-itime)/(num_loop*cv2_images_repeat1.shape[0]):.3f}s")


test_single_and_batched_text_detection_and_prediction()

@SamSamhuns
Copy link
Contributor Author

@rkcosmos, let me know if you need to have some tests or pipeline for verification as well. If you think this PR is good, we can later discuss a PR for improving the general code formatting with PEP guidelines as well. thanks

@SaddamBInSyed
Copy link

@SamSamhuns
thanks for your PR.
can u advise about GPU model details which you used ?

thank you

@SamSamhuns
Copy link
Contributor Author

@SamSamhuns
thanks for your PR.
can u advise about GPU model details which you used ?

thank you

Tesla V100 DGX

@rkcosmos rkcosmos merged commit 89ec92f into JaidedAI:master Jun 24, 2021
@myxzlpltk
Copy link

why this method load all tensor into gpu memory? I got memory leaks about 26 gb while batch_size didn't work

@ash2703
Copy link

ash2703 commented Jun 23, 2023

When benchmarking on GPU
Did you clear cache after running sequential inference?
GPU inference tends to be much faster after warmup!

thuc-moreh pushed a commit to moreh-dev/EasyOCR that referenced this pull request Jul 5, 2023
Add batch img inference support for ocr det with readtext_batched
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants