In [3]:
import os
import random

import cv2
import numpy as np
import scipy.io as scio
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

In [None]:
class SynthText(Dataset):
    """SynthText samples as an image plus word- and character-level masks.

    Annotations are read from ``gt.mat`` inside ``data_path``. Each item is a
    dict with the keys ``'image'``, ``'chars'`` and ``'words'`` (see
    ``__getitem__``).

    The augmentation callables used in ``__getitem__`` (``RandomScalewithCoords``,
    ``RandomCrop``, ``RandomHFlip``, ``RandomRotate``, ``color_jitter``,
    ``normalizeMeanVariance``) are not defined in this class and must be in
    scope at call time.
    """

    # Shrink ratios whose filled polygons are averaged into the word mask.
    # Alternative noted by the original author: 0.1, 0.2, ..., 0.7.
    WORD_SHRINK_SCALES = (0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
    # Shrink ratio for character boxes (original author's note: default = 0.2).
    CHAR_SHRINK_SCALE = 0.25

    def __init__(self, target_size, data_path=None):
        """Load the annotation arrays from ``<data_path>/gt.mat``.

        Args:
            target_size: forwarded to ``RandomCrop(crop_size=...)``.
            data_path: directory containing ``gt.mat``; image paths stored in
                the annotations are joined onto this directory.

        Raises:
            ValueError: if ``data_path`` is None.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if data_path is None:
            raise ValueError("data_path must be the directory containing gt.mat")
        super().__init__()

        self.target_size = target_size
        self.data_path = data_path

        gt = scio.loadmat(os.path.join(self.data_path, 'gt.mat'))
        # Per the original author's note, gt is a dict with keys '__header__',
        # '__version__', '__globals__', 'charBB', 'wordBB', 'imnames', 'txt';
        # the arrays have a leading singleton axis that [0] strips.
        self.name = gt['imnames'][0]
        self.char = gt['charBB'][0]
        self.word = gt['wordBB'][0]

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.name)

    def get_image(self, path):
        """Read the image at ``path`` and return it in RGB channel order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
                signals failure by returning None rather than raising).
        """
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def transpose_bboxes(self, bboxes):
        """Return a copy of ``bboxes`` with its axes reversed for pre-processing.

        A 2-D input (a single box) first gets a trailing axis, so the result
        is always 3-D: an input of shape (A, B) or (A, B, N) becomes
        (1, B, A) or (N, B, A).
        """
        src = np.array(bboxes)
        if src.ndim == 2:
            src = src[:, :, np.newaxis]
        return src.transpose((2, 1, 0))

    @staticmethod
    def _word_crop_rect(box, img_h, img_w):
        """Return the integer rect (x0, y0, x1, y1) bounding ``box``.

        The rect is clipped to the image; x1/y1 are exclusive. Returns None
        when the clipped rect has no area.
        """
        pts = np.asarray(box, dtype=np.float64).reshape(-1, 2)
        x0 = int(max(np.floor(pts[:, 0].min()), 0))
        y0 = int(max(np.floor(pts[:, 1].min()), 0))
        x1 = int(min(np.ceil(pts[:, 0].max()), img_w))
        y1 = int(min(np.ceil(pts[:, 1].max()), img_h))
        if x1 <= x0 or y1 <= y0:
            return None
        return x0, y0, x1, y1

    @staticmethod
    def _remap_chars(char_bboxes, rect, out_h, out_w):
        """Map character boxes into a crop that was resized to out_w x out_h.

        Only boxes whose centre lies inside ``rect`` are kept. Kept boxes are
        shifted by the crop origin, scaled by the resize factors and clipped
        to the output frame. Returns a float32 array of shape (K, 4, 2).
        """
        chars = np.asarray(char_bboxes, dtype=np.float32).reshape(-1, 4, 2)
        x0, y0, x1, y1 = rect
        centers = chars.mean(axis=1)
        inside = ((centers[:, 0] >= x0) & (centers[:, 0] < x1)
                  & (centers[:, 1] >= y0) & (centers[:, 1] < y1))
        offset = np.array([x0, y0], dtype=np.float32)
        scale = np.array([out_w / (x1 - x0), out_h / (y1 - y0)],
                         dtype=np.float32)
        remapped = (chars[inside] - offset) * scale
        remapped[..., 0] = np.clip(remapped[..., 0], 0, out_w - 1)
        remapped[..., 1] = np.clip(remapped[..., 1], 0, out_h - 1)
        return remapped

    ###### Cropping one word in each image (Augmentation) ######
    def crop_one_word(self, img, word_bboxes, char_bboxes):
        """Zoom into one randomly chosen word of the image.

        The axis-aligned bounding rect of the chosen word is cut out and
        resized (nearest neighbour) to the size of the input image. The word
        ground truth becomes the full output frame, and character boxes whose
        centre lies in the crop are mapped into the resized crop.

        NOTE(review): the original implementation was not runnable, so the
        intended output size is a guess. Resizing to the source image size
        keeps the frame size unchanged for the later transforms but does not
        preserve the word's aspect ratio -- confirm this is what is wanted.

        Args:
            img: H x W x C image array.
            word_bboxes: N x 4 x 2 array of (x, y) word corners.
            char_bboxes: M x 4 x 2 array of (x, y) character corners.

        Returns:
            ``(image, word_bboxes, char_bboxes)``. If there are no words, or
            the chosen word has no area inside the image, the three inputs are
            returned unchanged.
        """
        words = np.asarray(word_bboxes)
        if len(words) == 0:
            return img, word_bboxes, char_bboxes

        img_h, img_w = img.shape[0], img.shape[1]
        # randrange excludes the upper bound; randint(0, n) could yield n.
        ind = random.randrange(len(words))
        rect = self._word_crop_rect(words[ind], img_h, img_w)
        if rect is None:
            return img, word_bboxes, char_bboxes

        x0, y0, x1, y1 = rect
        # Image arrays are indexed [row (y), column (x)].
        crop_image = img[y0:y1, x0:x1]
        # dsize is an integer (width, height) pair, not a scale factor.
        resize_image = cv2.resize(crop_image, dsize=(img_w, img_h),
                                  interpolation=cv2.INTER_NEAREST)

        # Kept as a NumPy 1 x 4 x 2 array because word_mask() consumes it so.
        word_gt = np.array([[[0, 0],
                             [img_w - 1, 0],
                             [img_w - 1, img_h - 1],
                             [0, img_h - 1]]], dtype=np.float32)
        chars = self._remap_chars(char_bboxes, rect, img_h, img_w)
        return resize_image, word_gt, chars

    def clipping_bboxes(self, bbox, scale):
        """Shrink a polygon towards its interior and return integer vertices.

        The inward offset is ``scale`` times the distance from the mean of the
        vertices to the polygon boundary, so it adapts to the box size.

        ``Polygon`` and ``geometry`` come from shapely (per the original
        comment) and must be in scope; they are not imported in this class.

        Args:
            bbox: K x 2 (or 2 x K) array of (x, y) vertices. The original note
                gives K = 2 for ICDAR2013 and K = 4 for ICDAR2015.
            scale: fraction of the centre-to-boundary distance to shrink by.

        Returns:
            int32 array of the shrunk polygon's unique vertices; empty when the
            shrunk polygon has no exterior coordinates.
        """
        src = np.asarray(bbox)
        # Accept the (2, K) layout as well.
        if np.shape(src)[1] != 2:
            src = src.transpose((1, 0))

        center_x, center_y = np.mean(src[:, 0]), np.mean(src[:, 1])
        polygon = Polygon(src)
        center = geometry.Point(center_x, center_y)
        shrink_distance = polygon.exterior.distance(center) * scale
        # A negative buffer moves the boundary inwards.
        shrunk = Polygon(polygon.buffer(-shrink_distance))

        # Drop repeated vertices (the ring repeats its first point) while
        # keeping first-occurrence order.
        unique_coords = list(dict.fromkeys(shrunk.exterior.coords))
        return np.array(unique_coords).astype(np.int32)

    def _shrunk_polygons(self, bboxes, scale):
        """Return the non-empty shrunk polygons for every box in ``bboxes``."""
        polygons = []
        for box in bboxes:
            points = self.clipping_bboxes(box, scale=scale)
            # Drop boxes that collapsed to nothing when shrunk.
            if len(points) > 0:
                polygons.append(points.astype(np.int32))
        return polygons

    def char_mask(self, img_size, bboxes):
        """Return a float32 mask with 1 inside every shrunk character box.

        Args:
            img_size: shape of the mask to create.
            bboxes: N x 4 x 2 character boxes.
        """
        mask = np.zeros(img_size, dtype=np.uint8)
        polygons = self._shrunk_polygons(bboxes, self.CHAR_SHRINK_SCALE)
        if polygons:
            mask = cv2.fillPoly(mask, polygons, 1)
        return mask.astype(np.float32)

    def word_mask(self, img_size, bboxes):
        """Return a float32 word mask averaged over several shrink ratios.

        For each ratio in ``WORD_SHRINK_SCALES`` the shrunk word boxes are
        filled with 1; the per-ratio masks are summed and divided by the
        number of ratios.

        Args:
            img_size: shape of the mask to create.
            bboxes: N x 4 x 2 word boxes.
        """
        res = np.zeros(img_size, dtype=np.float32)
        for scale in self.WORD_SHRINK_SCALES:
            polygons = self._shrunk_polygons(bboxes, scale)
            # Same guard as char_mask: never hand fillPoly an empty list.
            if not polygons:
                continue
            mask = np.zeros(img_size, dtype=np.uint8)
            res += cv2.fillPoly(mask, polygons, 1).astype(np.float32)

        res /= len(self.WORD_SHRINK_SCALES)
        return res

    def __getitem__(self, index):
        """Build the training sample for image ``index``.

        Returns:
            dict with ``'image'`` (float tensor, channels first, scaled to
            [0, 1] and then passed through ``normalizeMeanVariance``),
            ``'chars'`` and ``'words'`` (float tensors with a leading
            singleton axis added when the mask is 2-D).
        """
        image_path = os.path.join(self.data_path, self.name[index][0])
        image = self.get_image(image_path)
        # image: uint8

        chars = self.transpose_bboxes(self.char[index])
        words = self.transpose_bboxes(self.word[index])

        # Added step: zoom into a single word before the other augmentations.
        image, words, chars = self.crop_one_word(image, words, chars)

        image, words, chars = RandomScalewithCoords()(image, words, chars)
        h_re, w_re = image.shape[0], image.shape[1]

        words = self.word_mask(img_size=(h_re, w_re), bboxes=words)
        chars = self.char_mask(img_size=(h_re, w_re), bboxes=chars)

        GT = {}
        GT['image'] = image

        # augmentation
        GT['chars'] = chars
        GT['words'] = words
        GT = RandomCrop(crop_size=self.target_size)(GT)
        GT = RandomHFlip(threshold=0.5)(GT)
        GT = RandomRotate(max=10)(GT)
        GT['image'] = color_jitter(GT['image'])

        if len(GT['chars'].shape) == 2:
            GT['chars'] = GT['chars'][np.newaxis, :, :]

        if len(GT['words'].shape) == 2:
            GT['words'] = GT['words'][np.newaxis, :, :]

        GT['chars'] = torch.FloatTensor(GT['chars'])
        GT['words'] = torch.FloatTensor(GT['words'])

        GT['image'] = GT['image'].astype(np.float32)/255.
        # Only mean/std normalization is applied on image
        GT['image'] = normalizeMeanVariance(GT['image'])
        GT['image'] = torch.FloatTensor(GT['image'].transpose((2, 0, 1)))
        return GT