In [None]:
import mmcls
from mmcls.apis import inference_model, init_model, show_result_pyplot

import os
import torch 
import io

import ipywidgets as widgets
from PIL import Image

import mmcv
import matplotlib.pyplot as plt

from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.pipelines.crop import crop_img

# Choose the compute device: the default CUDA device if one is visible, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# On a GPU, report the card name and how much memory PyTorch currently holds,
# expressed in units of 1024**3 bytes and rounded to one decimal place.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, query in (('Allocated:', torch.cuda.memory_allocated),
                         ('Cached:   ', torch.cuda.memory_reserved)):
        print(label, round(query(0) / 1024**3, 1), 'GB')
    

In [None]:
# Text-detection model (DRRG, per the config path): config file + pretrained weights.
dtcfg = 'configs/configs_ocr/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py'
dtckpt = 'https://download.openmmlab.com/mmocr/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'

# Text-recognition model (RobustScanner, per the config path): config file + pretrained weights.
rccfg = 'configs/configs_ocr/textrecog/robust_scanner/robustscanner_r31_academic.py'
rcckpt = 'https://download.openmmlab.com/mmocr/textrecog/robustscanner/robustscanner_r31_academic-5f05874f.pth'

# Image-classifier checkpoint, the path the uploaded image is saved to,
# and the path the OCR visualisation is written to.
checkpoint = 'pretrains/latest.pth'
img = 'uploaded.png'
out = 'ocr_out.jpg'

In [None]:
def _text_crop_box(bbox):
    """Return the 8-coordinate quadrilateral used to crop one detected text region.

    ``bbox`` is a flat sequence ``x0, y0, x1, y1, ...`` followed by a trailing
    detection score. Boxes with at most 9 entries (a quadrilateral plus score)
    are cropped with their first 8 values; longer polygons are replaced by their
    axis-aligned bounding rectangle, listed as top-left, top-right, bottom-right,
    bottom-left.
    """
    if len(bbox) <= 9:
        return bbox[:8]
    coords = bbox[:-1]  # drop the trailing score
    xs = coords[0::2]
    ys = coords[1::2]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)
    return [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]


def det_and_recog_inference(img, det_model, recog_model):
    """Detect text regions in an image, then recognise the text inside each one.

    Args:
        img: Image path, read with ``mmcv.imread`` and stored verbatim under
            the ``'filename'`` key of the result.
        det_model: Text-detection model passed to ``model_inference``; its
            result must contain ``'boundary_result'`` (one flat box per region,
            score last).
        recog_model: Text-recognition model passed to ``model_inference``; its
            result must contain ``'text'`` and ``'score'``.

    Returns:
        dict: ``{'filename': img, 'result': [...]}``. Each result entry has
        ``'box'`` (rounded box coordinates, score removed), ``'box_score'``
        (detection score as float), ``'text'`` and ``'text_score'`` (when the
        recogniser returns a list of per-character scores, their mean).
    """
    image = mmcv.imread(img)
    det_result = model_inference(det_model, image)

    results = []
    for bbox in det_result['boundary_result']:
        box_img = crop_img(image, _text_crop_box(bbox))
        recog_result = model_inference(recog_model, box_img)
        text = recog_result['text']
        text_score = recog_result['score']
        if isinstance(text_score, list):
            # Mean of the per-character scores: divide by the number of scores
            # (not the text length, which need not match); max() guards [].
            text_score = sum(text_score) / max(1, len(text_score))
        results.append({
            'box': [round(x) for x in bbox[:-1]],
            'box_score': float(bbox[-1]),
            'text': text,
            'text_score': text_score,
        })

    return {'filename': img, 'result': results}

# For reference, the MMOCR command-line demo for a similar detect-and-recognise run:
# !python demo/ocr_image_demo.py /content/Invoice.png demo/output.jpg

def _load_mmocr_model(config, checkpoint, device):
    """Build a model with ``init_detector`` and prepare it for single-image inference.

    Unwraps a ``.module`` wrapper when present. When the test dataset is a
    ``ConcatDataset``, copies the first sub-dataset's pipeline onto
    ``cfg.data.test`` (presumably because ``model_inference`` reads
    ``cfg.data.test.pipeline`` — confirm against MMOCR).
    """
    model = init_detector(config, checkpoint, device=device)
    if hasattr(model, 'module'):
        model = model.module
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = \
            model.cfg.data.test['datasets'][0].pipeline
    return model


def detectLogoText( img, out_file, 
                 det_config ='./configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py',
                 det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
                 recog_config = './configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
                 recog_ckpt = 'https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth',
                 device = 'cuda:0',
                 figsize = (20, 10)
                ):
    """Detect and recognise text in an image, save the results and plot them.

    Writes the recognition results as JSON to ``out_file + '.json'`` and the
    rendered visualisation to ``out_file``, then displays it with matplotlib.

    Args:
        img: Path of the input image.
        out_file: Output path of the visualisation image.
        det_config, det_ckpt: Text-detection config file and checkpoint.
        recog_config, recog_ckpt: Text-recognition config file and checkpoint.
        device: Device string passed to ``init_detector`` for both models.
        figsize: Matplotlib figure size in inches (width, height).
    """
    detect_model = _load_mmocr_model(det_config, det_ckpt, device)
    recog_model = _load_mmocr_model(recog_config, recog_ckpt, device)

    det_recog_result = det_and_recog_inference(img, detect_model, recog_model)
    mmcv.dump(
        det_recog_result,
        out_file + '.json',
        ensure_ascii=False,
        indent=4)

    vis_img = det_recog_show_result(img, det_recog_result)
    mmcv.imwrite(vis_img, out_file)

    # vis_img is saved with mmcv.imwrite (OpenCV, BGR channel order) while
    # matplotlib expects RGB, so convert before plotting. figsize is in inches,
    # so keep it modest: each inch is `dpi` pixels of canvas.
    plt.figure(figsize=figsize)
    plt.imshow(mmcv.bgr2rgb(vis_img))
    plt.axis('off')
    plt.show()

In [None]:
def ocr_and_classify(checkpoint_file = 'pretrains/latest.pth', img='examples/exp.jpg',
                     config_file='configs/custom_config/myconfig.py', device=None):
    """Classify an image with the custom MMClassification model, then OCR it.

    Prints and plots the classification result, then runs text detection and
    recognition via ``detectLogoText`` using the module-level ``dtcfg``,
    ``dtckpt``, ``rccfg``, ``rcckpt`` models, writing output to ``out``.

    Args:
        checkpoint_file: Classifier checkpoint path.
        img: Path of the image to classify and OCR.
        config_file: Classifier config path.
        device: Device string for all models; when ``None``, ``'cuda:0'`` if
            CUDA is available, otherwise ``'cpu'``.
    """
    if device is None:
        # Fall back to the CPU so the notebook also runs without a GPU.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = init_model(config_file, checkpoint_file, device=device)

    result = inference_model(model=model, img=img)
    print(result)
    show_result_pyplot(model, img, result)

    detectLogoText(img, out, det_config=dtcfg, recog_config=rccfg,
                   det_ckpt=dtckpt, recog_ckpt=rcckpt, device=device)

In [None]:
# Single-image upload widget; the following cell saves the chosen file to `img`.
# accept='image/*' limits the picker to images, since that cell opens it with PIL.
uploader = widgets.FileUpload(accept='image/*', multiple=False)
display(uploader)

In [None]:
# ipywidgets 7.x exposes `uploader.value` as {filename: {'content': bytes, ...}};
# 8.x exposes a tuple of dicts whose 'content' is a memoryview. Accept both.
uploaded_files = (list(uploader.value.values())
                  if isinstance(uploader.value, dict)
                  else list(uploader.value))
if not uploaded_files:
    raise ValueError('No file uploaded yet: pick an image with the upload widget above, then re-run this cell.')
for file_info in uploaded_files:
    image = Image.open(io.BytesIO(file_info['content']))
    # Normalise to RGB so modes PNG cannot store (e.g. CMYK JPEGs) still save.
    image.convert('RGB').save(img)

In [None]:
# Classify and OCR the uploaded image using the configured classifier checkpoint.
ocr_and_classify(checkpoint, img)