# Getting text in the wild in natural images

In [None]:
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
import coco_text
from PIL import Image

We use the COCO-Text annotations, which label the text found in images of the COCO dataset.

In [None]:
# Load the COCO-Text v2 annotation file; `ct` is the handle the cells below query
# for image ids, image records and text annotations.
ct = coco_text.COCO_Text('cocotext.v2.json')

## Preparing the dataset along with the respective image annotations of text

In [None]:
from detectron2.structures import BoxMode

def get_coco_dict(img_dir,d):
    """Build the list of detectron2 dataset records for one COCO-Text split.

    Args:
        img_dir: Directory holding the image files; it is joined with each
            image's ``file_name``.
        d: Split name. A name containing ``"train"`` selects ``ct.train``;
            otherwise a name containing ``"val"`` selects ``ct.val``.

    Returns:
        A list with one dict per image that has legible text. Each dict holds
        ``file_name``, ``image_id``, ``height``, ``width`` and ``annotations``;
        every annotation carries its ``bbox`` tagged as ``BoxMode.XYWH_ABS``
        and a ``category_id`` of 0 for English text and 1 for anything else.

    Raises:
        ValueError: If ``d`` names neither a train nor a val split.
    """
    # Resolve the split first so an unknown name fails loudly instead of
    # leaving the image-id list undefined.
    if 'train' in d:
        split_ids = ct.train
    elif 'val' in d:
        split_ids = ct.val
    else:
        raise ValueError(
            "Unknown split %r: expected a name containing 'train' or 'val'" % (d,)
        )
    imgIds = ct.getImgIds(imgIds=split_ids, catIds=[('legibility','legible')])

    dataset_dicts = []
    for ids in imgIds:
        img = ct.loadImgs(ids)[0]
        record = {
            "file_name": os.path.join(img_dir, img["file_name"]),
            "image_id": ids,
            "height": img['height'],
            "width": img['width'],
        }

        # No category filter is passed here, so every annotation attached to
        # the image is included.
        annos = ct.loadAnns(ct.getAnnIds(imgIds = ids))
        record["annotations"] = [
            {
                "bbox": anno['bbox'],
                "bbox_mode": BoxMode.XYWH_ABS,
                "category_id": 0 if anno['language']=='english' else 1,
            }
            for anno in annos
        ]
        dataset_dicts.append(record)
    return dataset_dicts

# Register both COCO-Text splits with detectron2 and attach the class names.
for d in ("train2014", "val2014"):
    dataset_name = "coco_text_" + d
    # `d=d` binds the current split as a default argument, so each loader keeps
    # its own split instead of the loop variable's final value.
    DatasetCatalog.register(dataset_name, lambda d=d: get_coco_dict("/coco/train2014/", d))
    MetadataCatalog.get(dataset_name).set(thing_classes=["english", "others"])

# Metadata handle reused when drawing predictions.
coco_text_metadata = MetadataCatalog.get("coco_text_train2014")

We take the Faster R-CNN (R50-FPN, 3x schedule) COCO object-detection model from the Detectron2 model zoo and retrain it on the text annotations.

In [None]:
from detectron2.engine import DefaultTrainer

# Model-zoo entry used for both the base config and the initial weights.
zoo_config = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"

# Start from the default config and overlay the model-zoo settings.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(zoo_config))

# Data: the dataset names registered above, plus the loader worker count.
cfg.DATASETS.TRAIN = ("coco_text_train2014",)
cfg.DATASETS.TEST = ("coco_text_val2014",)
cfg.DATALOADER.NUM_WORKERS = 8

# Model: initialise from the matching model-zoo checkpoint.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_config)
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # the two registered thing_classes: "english", "others"

# Solver: batch size, learning rate and schedule length.
cfg.SOLVER.IMS_PER_BATCH = 8
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 30000
cfg.SOLVER.STEPS = []  # NOTE(review): presumably disables step-wise LR decay — confirm

In [None]:
# Create the output directory; exist_ok=True makes this a no-op when it already exists.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# NOTE(review): resume=True presumably continues from the latest checkpoint found in
# cfg.OUTPUT_DIR and otherwise starts from cfg.MODEL.WEIGHTS — confirm against the
# detectron2 version in use.
trainer.resume_or_load(resume=True)
trainer.train()

In [None]:
# Load the TensorBoard notebook extension and open it on the `output` log directory.
%load_ext tensorboard
%tensorboard --logdir output

In [None]:
# Switch the config to the checkpoint we just trained.
final_weights = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.WEIGHTS = final_weights

# Custom score threshold used at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

predictor = DefaultPredictor(cfg)

In [None]:
from detectron2.utils.visualizer import ColorMode
import skimage.io as io
# Draw the model's detections on three randomly picked validation images
# that contain legible text.
imgIds = ct.getImgIds(imgIds=ct.val, catIds=[('legibility','legible')])
plt.figure(figsize=(20,20))
for row in range(3):
    img = ct.loadImgs(imgIds[np.random.randint(0,len(imgIds))])[0]
    image_path = '/coco/train2014/'+img['file_name']
    im = cv2.imread(image_path)  # OpenCV loads the image in BGR channel order
    if im is None:
        # cv2.imread reports a missing or unreadable file by returning None.
        raise FileNotFoundError("Could not read image: " + image_path)
    outputs = predictor(im)
    # Visualizer works in RGB, so reverse the channel axis of the BGR image.
    v = Visualizer(im[:, :, ::-1],
                   metadata=coco_text_metadata,
                   scale=1,
                   instance_mode=ColorMode.IMAGE
    )
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    plt.subplot(3,1,row+1)
    # get_image() is already RGB, which is what matplotlib displays; flipping
    # the channels again here swapped red and blue on screen.
    plt.imshow(out.get_image())
    plt.axis("off")
plt.show()

In [None]:
def crop_object(image, box):
    """Return the part of ``image`` covered by ``box``.

    Args:
        image: Array indexed as ``image[row, column, ...]``.
        box: Four values ``(x_top_left, y_top_left, x_bottom_right,
            y_bottom_right)``. Each one is truncated with ``int()``.

    Returns:
        The slice ``image[y_top_left:y_bottom_right, x_top_left:x_bottom_right]``.
        Negative coordinates are clamped to 0 so a box that hangs over the
        top or left edge yields its in-image part.
    """
    # Clamp at 0: a negative slice index would count from the far edge of the
    # array and produce an empty or unrelated crop.
    x_top_left, y_top_left, x_bottom_right, y_bottom_right = (
        max(int(box[i]), 0) for i in range(4)
    )
    return image[y_top_left:y_bottom_right, x_top_left:x_bottom_right]

# Run the detector on one random validation image, then show and save every
# detected text region.
imgIds = ct.getImgIds(imgIds=ct.val, catIds=[('legibility','legible')])
img = ct.loadImgs(imgIds[np.random.randint(0,len(imgIds))])[0]
image_path = '/coco/train2014/'+img['file_name']
image = cv2.imread(image_path)  # BGR channel order
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None.
    raise FileNotFoundError("Could not read image: " + image_path)
outputs = predictor(image)
boxes = outputs["instances"].to('cpu').pred_boxes

# matplotlib displays RGB, so reverse the channel axis of the BGR image.
plt.imshow(image[:, :, ::-1])
plt.axis("off")
plt.show()

crop_dir = "/scratch/ac9025/test_images/"
# Make sure the target directory exists; cv2.imwrite does not create it.
os.makedirs(crop_dir, exist_ok=True)
for i,box in enumerate(boxes):
    crop_img = crop_object(image,box)
    if crop_img.size == 0:
        # A degenerate box leaves nothing to show or write.
        continue
    plt.imshow(crop_img[:, :, ::-1])
    plt.axis("off")
    plt.show()
    # cv2.imwrite expects BGR, so the untouched crop is saved.
    cv2.imwrite(crop_dir+img['file_name']+"_cropped_"+str(i)+".png", crop_img)