In [None]:
import cv2
import numpy as np
import pandas as pd
import pytesseract # need to install tesseract and set system PATH first
from pytesseract import Output
from PIL import Image
import text_.classifier
import chart

# Sentinel used for "no finite distance" between two boxes.
INF = float('inf')

# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message instead of a cryptic error in a later cell.
src = cv2.imread("example.png")
if src is None:
    raise FileNotFoundError("cv2.imread could not read 'example.png'")

In [None]:
# Otsu-binarise the grayscale image with inverted polarity, so pixels at or
# below the threshold become 255 and the rest 0, then label 4-connected blobs.
gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
ret, binary_ = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
# cv2.imshow("binary", binary_)
# cv2.waitKey(0)
num_labels, labels, stats, centers = cv2.connectedComponentsWithStats(
    binary_,
    connectivity=4,
    ltype=cv2.CV_32S,
)

In [None]:
## Filtering using aspect ratio and area
# Bounding-box centres: (left + width / 2, top + height / 2), one row per component.
centers_x = stats[:, 0] + stats[:, 2] / 2
centers_y = stats[:, 1] + stats[:, 3] / 2
centers = np.column_stack((centers_x, centers_y))

# Build ONE boolean mask over the unfiltered rows and apply it to both arrays.
# Chaining masks computed from already-filtered arrays gives masks whose length
# no longer matches the array being indexed (IndexError as soon as a row is dropped).
areas = stats[:, -1]
aspect_ratio = stats[:, 2] / stats[:, 3]
keep = (areas > 4) & (areas < 1000) & (aspect_ratio > (1 / 15)) & (aspect_ratio < 15)
filtered_stats = stats[keep]
filtered_centroids = centers[keep]
print("BEFORE filtering: " + str(stats.shape), "AFTER filtering: " + str(filtered_stats.shape))

if filtered_stats.shape[0] == 0:
    raise ValueError("no connected component survived the area / aspect-ratio filter")

# Most frequent box side length (widths and heights pooled); later cells scale
# their distance thresholds by this value.
widths = filtered_stats[:, 2]
heights = filtered_stats[:, 3]
length_set = np.concatenate([widths, heights])
mode = np.argmax(np.bincount(length_set))
print("MODE: " + str(mode))

In [None]:
# One random BGR colour per surviving component.
colors = [
    tuple(np.random.randint(0, 256) for _ in range(3))
    for _ in range(filtered_stats.shape[0])
]

image = np.copy(src)
# Start at row 0: after filtering, row 0 is no longer guaranteed to be the
# background label, so skipping it would leave a real component undrawn.
for t in range(filtered_stats.shape[0]):
    x, y, w, h = (int(v) for v in filtered_stats[t, :4])
    cv2.rectangle(image, (x, y), (x + w, y + h), colors[t], 1, 8, 0)
    cv2.putText(image, str(t), (x, y), cv2.FONT_HERSHEY_SIMPLEX, .35, (0, 0, 255), 1)

cv2.imshow("colored labels", image)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite("colored labels.png", image)

In [None]:
def calculate_distance(center1, center2):
    """Return the Euclidean distance between two (x, y) points."""
    dx = center1[0] - center2[0]
    dy = center1[1] - center2[1]
    return np.sqrt(dx ** 2 + dy ** 2)

def calculate_closest_edge_distance(stats, id1, id2):
    """
    Gap between the facing edges of two bounding boxes.

    INPUTS:
        stats: rows of (left, top, width, height, area)
        id1, id2: row indices of the two boxes
    OUTPUTS:
        For 'vertical' alignment the gap is measured along x, for 'horizontal'
        alignment along y. When the boxes overlap along that axis the result is
        the smaller (negative) of the two edge differences. Returns INF when
        check_alignment reports 'neither'.
    """
    alignment = check_alignment(stats, id1, id2)
    if alignment == 'vertical':
        origin_col, extent_col = 0, 2   # left / width
    elif alignment == 'horizontal':
        origin_col, extent_col = 1, 3   # top / height
    else:
        return INF

    start_a, start_b = stats[id1, origin_col], stats[id2, origin_col]
    end_a = start_a + stats[id1, extent_col]
    end_b = start_b + stats[id2, extent_col]

    if start_a >= end_b:
        return start_a - end_b
    if start_b >= end_a:
        return start_b - end_a
    return min(start_a - end_b, start_b - end_a)

def check_alignment(stats, id1, id2):
    '''
    Classify how two bounding boxes line up with each other.

    INPUTS:
        stats: (left, top, width, height, area);
        id1: id of the first bounding box
        id2: id of the second bounding box
    OUTPUTS:
        mode: 'vertical' / 'horizontal' / 'neither'
        'vertical' means the boxes' vertical extents overlap by more than half
        of the smaller height.
        'horizontal' means the boxes' horizontal extents overlap by more than
        half of the smaller width ('horizontal' wins when both hold).
    '''
    l1, l2 = stats[id1, 0], stats[id2, 0]
    t1, t2 = stats[id1, 1], stats[id2, 1]
    w1, w2 = stats[id1, 2], stats[id2, 2]
    h1, h2 = stats[id1, 3], stats[id2, 3]
    r1, r2 = l1 + w1, l2 + w2
    b1, b2 = t1 + h1, t2 + h2
    # Length of the intersection of two intervals is min(ends) - max(starts),
    # clamped at 0. The previous max(b1 - t2, b2 - t1) is always >= (h1 + h2) / 2
    # for intersecting boxes, which made the "more than half" test a no-op.
    vertical_overlap = max(0, min(b1, b2) - max(t1, t2))
    horizontal_overlap = max(0, min(r1, r2) - max(l1, l2))
    if horizontal_overlap > min(w1, w2) / 2:
        return 'horizontal'
    elif vertical_overlap > min(h1, h2) / 2:
        return 'vertical'
    else:
        return 'neither'

def graph_initialization(stats, centers, mode):
    """
    Build a dense adjacency matrix (list of lists) over the centres.

    Entry [i][j] is 0 on the diagonal, the centre-to-centre distance when that
    distance is below 2 * mode, and INF otherwise. `stats` is accepted but not
    used.
    """
    count = len(centers)
    cutoff = 2 * mode
    graph = []
    for i in range(count):
        row = []
        for j in range(count):
            if i == j:
                row.append(0)
                continue
            distance = calculate_distance(centers[i], centers[j])
            row.append(distance if distance < cutoff else INF)
        graph.append(row)
    return graph

def edges_initialization(stats, centers, mode):
    """
    Collect candidate edges between boxes whose facing edges are close.

    INPUTS:
        stats: rows of (left, top, width, height, area)
        centers: one entry per box; only its length is used here
        mode: reference length; boxes closer than mode / 3 are linked
    OUTPUTS:
        list of (i, j, distance) tuples with i < j, ordered by i then j.
    """
    edges = []
    count = len(centers)
    threshold = mode / 3
    for i in range(count):
        # Only visit j > i: the distance is symmetric, so this yields the same
        # edges as scanning all ordered pairs and rejecting (j, i, d) duplicates,
        # without a linear search through `edges` for every pair.
        for j in range(i + 1, count):
            distance = calculate_closest_edge_distance(stats, i, j)
            # Unaligned pairs come back as INF and can never pass this test,
            # so no separate check_alignment call is needed.
            if distance < threshold:
                edges.append((i, j, distance))
    return edges

In [None]:
# Spot-check the alignment classifier on two hand-picked boxes.
for box_id in (9, 19):
    print(filtered_stats[box_id])
check_alignment(filtered_stats, 9, 19)

In [None]:
## Build the candidate edge list between nearby, aligned boxes
edges = edges_initialization(
    stats=filtered_stats,
    centers=filtered_centroids,
    mode=mode,
)
print(len(edges))
# edges

In [None]:
## Kruskal algorithm
class Edge:
    """Weighted edge joining node ids `x` and `y` with weight `length`."""

    def __init__(self, x, y, length):
        self.x, self.y, self.length = x, y, length

class UnionFindSet:
    """
    Disjoint-set forest over the integer ids start..n (inclusive), using
    union by rank and path compression.

    `pre[i]` is the parent of id i and `rank[i]` its rank. Ids are used
    directly as list indices.
    """

    def __init__(self, start, n):
        self.start = start
        self.n = n
        # Ids index the lists directly, so they must hold at least n + 1 slots.
        # The old size (n - start + 2) was too short whenever start > 1; it is
        # kept as a lower bound so existing callers see no shorter lists.
        size = max(self.n + 1, self.n - self.start + 2, 0)
        self.pre = [0] * size
        self.rank = [0] * size
        # Make the structure usable immediately; init() can still be called
        # again to reset it.
        self.init()

    def init(self):
        """Reset every id in start..n to its own singleton set with rank 1."""
        for i in range(self.start, self.n + 1):
            self.pre[i] = i
            self.rank[i] = 1

    def find_pre(self, x):
        """Return the representative of x's set, compressing the path to it."""
        root = x
        while self.pre[root] != root:
            root = self.pre[root]
        # Second pass: point every node on the walked path straight at the root.
        while self.pre[x] != root:
            self.pre[x], x = root, self.pre[x]
        return root

    def is_same(self, x, y):
        """Return True when x and y belong to the same set."""
        return self.find_pre(x) == self.find_pre(y)

    def unite(self, x, y):
        """Merge the sets of x and y; return False if they were already merged."""
        x = self.find_pre(x)
        y = self.find_pre(y)
        if x == y:
            return False
        if self.rank[x] > self.rank[y]:
            self.pre[y] = x
        else:
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1
            self.pre[x] = y
        return True

    def is_one(self):
        """Return True when all ids in start..n share one representative."""
        if self.start > self.n:
            return True
        root = self.find_pre(self.start)
        return all(self.find_pre(i) == root for i in range(self.start + 1, self.n + 1))


class Kruskal:
    """
    Kruskal minimum spanning forest over an edge list.

    n: number of nodes
    m: number of edges to read from `edges`
    edges: sequence of (x, y, length) triples
    After graphy() and run(), `s` holds the accepted Edge objects.
    """

    def __init__(self, n, m, edges):
        self.n = n
        self.m = m
        # Keep the edge list that was passed in; graphy() previously read a
        # module-level `edges` variable and ignored this argument.
        self.edges = edges
        self.e = []  # all edges, sorted by length once graphy() has run
        self.s = []  # edges accepted into the spanning forest
        # Cover ids 0..n so that id 0 is an initialised singleton as well.
        self.u = UnionFindSet(0, self.n)

    def graphy(self):
        """Load the first m edges, sort them by length and reset the union-find."""
        # Start from clean lists so re-running the cell does not duplicate edges.
        self.e = []
        self.s = []
        for i in range(self.m):
            # NOTE(review): int() also truncates the length; kept as in the
            # original, which matters only for non-integer distances.
            x, y, length = list(map(int, self.edges[i]))
            self.e.append(Edge(x, y, length))
        self.e.sort(key=lambda e: e.length)
        self.u.init()

    def run(self):
        """Scan edges in ascending length, keeping those that join two components."""
        for i in range(self.m):
            if self.u.unite(self.e[i].x, self.e[i].y):
                self.s.append(self.e[i])
            # A spanning tree over n nodes has n - 1 edges; past that point every
            # remaining edge would be rejected, so stop early. This replaces an
            # O(n) is_one() scan per edge.
            if len(self.s) >= self.n - 1:
                break

    def print(self):
        """Print each accepted edge followed by the total weight."""
        print('Edges: ')
        for edge in self.s:
            print(f'edge <{edge.x},{edge.y}> = {edge.length}')
        edge_sum = sum(edge.length for edge in self.s)
        print(f'Weights: {edge_sum}')

In [None]:
# Node count and edge count for the spanning-forest run.
n = len(filtered_centroids)
m = len(edges)

kruskal = Kruskal(n, m, edges)
kruskal.graphy()
kruskal.run()
kruskal.print()

In [None]:
## Combine all connected components
# Pass 1: turn the accepted spanning-forest edges into lists of node ids.
# An edge is added to EVERY existing list that already contains one of its
# endpoints, so two lists can end up sharing ids; a new list is started only
# when no existing list matched.
trees = []
for edge in kruskal.s:
    fail = 0
    id1 = edge.x
    id2 = edge.y
    for i in range(len(trees)):
        if(id1 in trees[i]):
            trees[i].append(id2)
        elif(id2 in trees[i]):
            trees[i].append(id1)
        else:
            fail += 1
    if(fail == len(trees)):
        trees.append([id1, id2])

# Pass 2: merge lists that share ids. For each i only the earlier lists j < i
# are visited, because the loop breaks as soon as j reaches i.
b = len(trees)
for i in range(b):
    for j in range(b):
        # x: de-duplicated union of both lists (element order comes from set()).
        x = list(set(trees[i]+trees[j]))
        y = len(trees[j])+len(trees[i])
        # NOTE(review): trees[i] and trees[j] are lists, so the `== 0` tests are
        # never true; only `i == j` can trigger this break.
        if i == j or trees[i] == 0 or trees[j] == 0:
            break
        # A union shorter than the summed lengths means the lists share an id:
        # fold list j into list i and leave the placeholder [0] behind.
        # NOTE(review): the placeholder [0] is also a valid node id; a list that
        # contains id 0 will match an emptied slot here (it only gets
        # de-duplicated) — confirm this is acceptable.
        elif len(x) < y:
            trees[i] = x
            trees[j] = [0]
# Drop the placeholders left by merged lists.
word_trees = [i for i in trees if i != [0]]
word_trees

In [None]:
## Draw integrated textboxes.
temp_stats = filtered_stats.copy()
for tree in word_trees:
    # Fold each box into its successor: the earlier row is zeroed and the later
    # row becomes the union of both, so the last id of the tree ends up holding
    # the bounding box of the whole group.
    for id1, id2 in zip(tree, tree[1:]):
        l1, t1, w1, h1 = temp_stats[id1, :4]
        l2, t2, w2, h2 = temp_stats[id2, :4]
        new_left = min(l1, l2)
        new_top = min(t1, t2)
        new_width = max(l1 + w1, l2 + w2) - new_left
        new_height = max(t1 + h1, t2 + h2) - new_top
        temp_stats[id1] = [0, ] * 5
        temp_stats[id2] = (new_left, new_top, new_width, new_height, new_width * new_height)

# Rows that are entirely zero were absorbed into another box; remove them.
indexes = np.where(~temp_stats.any(axis=1))
print("BEFORE filtering: ", temp_stats.shape)
print("Removed edges: ", indexes)
temp_stats = np.delete(temp_stats, indexes, axis=0)
print("AFTER filtering: ", temp_stats.shape)

# Random BGR colour per label; the first entry is forced to black.
colors = [
    tuple(np.random.randint(0, 256) for _ in range(3))
    for _ in range(num_labels)
]
colors[0] = (0, 0, 0)

image = np.copy(src)
for t, (x, y, w, h, area) in enumerate(temp_stats):
    cv2.rectangle(image, (x, y), (x+w, y+h), colors[t], 1, 8, 0)
    cv2.putText(image, str(t), (x, y), cv2.FONT_HERSHEY_SIMPLEX, .35, (0, 0, 255), 1)

cv2.imshow("combined", image)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite("bounded.png", image)

## Tesseract OCR

In [None]:
## Filter out all the textual type data, assign None to other non-detectable textboxes.
# NOTE(review): machine-specific absolute path; the import cell says tesseract is
# expected on the system PATH — consider removing this override.
pytesseract.pytesseract.tesseract_cmd = r'E:\Tesseract-OCR\tesseract.exe'
normal_binary = 255 - binary_
text_list = []
conf_list = []
for i in range(temp_stats.shape[0]):
    best_conf = -INF
    best_word = None
    l, t, w, h = temp_stats[i, :4]
    # Crop the box with a one-pixel margin; the start is clamped at 0 and numpy
    # clamps slice ends that run past the image border.
    text_box = normal_binary[max(0, t-1):t+h+1, max(0, l-1):l+w+1]
    # cv2.imshow("OpenCV",text_box)
    # cv2.waitKey()
    # NOTE(review): this pass tries 0, 90 CW and 180 degrees only, while the
    # numeric pass also tries 90 CCW — confirm the difference is intentional.
    rotate_text_boxes = text_box, cv2.rotate(text_box, cv2.ROTATE_90_CLOCKWISE), cv2.rotate(text_box, cv2.ROTATE_180)
    for candidate in rotate_text_boxes:
        rotate_text_box = cv2.resize(candidate, dsize=None, fx=3, fy=3, interpolation=cv2.INTER_LINEAR)
        # The crop is single-channel, so it must be expanded with GRAY2RGB;
        # BGR2RGB requires a 3- or 4-channel input and raises on this image.
        textbox = Image.fromarray(cv2.cvtColor(rotate_text_box, cv2.COLOR_GRAY2RGB))
        text = pytesseract.image_to_data(textbox, output_type=Output.DICT, config='--psm 6')
        # Pair each word with its own confidence before dropping empty entries,
        # so the two lists cannot drift apart (and no remove(-1) ValueError).
        words = [(word, conf) for word, conf in zip(text['text'], text['conf']) if word != '']
        # Only accept a rotation that yields exactly one word.
        if len(words) != 1:
            continue
        word, conf = words[0]
        conf = float(conf)
        if best_conf < conf:
            best_conf = conf
            best_word = word
    print(best_word)
    text_list.append((temp_stats[i, 0], temp_stats[i, 1], temp_stats[i, 2], temp_stats[i, 3], best_word))
    conf_list.append(best_conf)

In [None]:
## Filter numerical type data from the data that were marked as None.
config = r'-c tessedit_char_whitelist=0123456789 --psm 6'
# Boxes whose text slot is still None. Checked per tuple rather than through
# np.array(text_list) == None, whose result depends on the dtype numpy infers.
unresolved = [idx for idx, entry in enumerate(text_list) if entry[4] is None]
for i in unresolved:
    best_conf = -INF
    best_word = None
    l, t, w, h = temp_stats[i, :4]
    # Crop the box with a one-pixel margin; the start is clamped at 0 and numpy
    # clamps slice ends that run past the image border.
    text_box = normal_binary[max(0, t-1):t+h+1, max(0, l-1):l+w+1]
    # cv2.imshow("OpenCV",text_box)
    # cv2.waitKey()
    rotate_text_boxes = text_box, cv2.rotate(text_box, cv2.ROTATE_90_CLOCKWISE), cv2.rotate(text_box, cv2.ROTATE_90_COUNTERCLOCKWISE), cv2.rotate(text_box, cv2.ROTATE_180)
    for candidate in rotate_text_boxes:
        rotate_text_box = cv2.resize(candidate, dsize=None, fx=3, fy=3, interpolation=cv2.INTER_LINEAR)
        # The crop is single-channel, so it must be expanded with GRAY2RGB;
        # BGR2RGB requires a 3- or 4-channel input and raises on this image.
        textbox = Image.fromarray(cv2.cvtColor(rotate_text_box, cv2.COLOR_GRAY2RGB))
        text = pytesseract.image_to_data(textbox, output_type=Output.DICT, config=config)
        # Pair each word with its own confidence before dropping empty entries,
        # so the two lists cannot drift apart (and no remove(-1) ValueError).
        words = [(word, conf) for word, conf in zip(text['text'], text['conf']) if word != '']
        # Only accept a rotation that yields exactly one word.
        if len(words) != 1:
            continue
        word, conf = words[0]
        conf = float(conf)
        if best_conf < conf:
            best_conf = conf
            best_word = word
    print(best_word)
    text_list[i] = (temp_stats[i, 0], temp_stats[i, 1], temp_stats[i, 2], temp_stats[i, 3], best_word)
    conf_list[i] = best_conf

In [None]:
## Final detected text list: one (x, y, width, height, text) tuple per merged box;
## text stays None when no OCR pass returned exactly one word for that box.
text_list

In [None]:
## Final detected confidence level list: best confidence per box, in the same
## order as text_list; -inf marks boxes for which nothing was recognised.
conf_list

In [None]:
# Force an object array: without it numpy turns every field, coordinates
# included, into strings whenever no entry contains None. reshape(-1, 5) keeps
# the column indexing valid when text_list is empty.
text_arrays = np.array(text_list, dtype=object).reshape(-1, 5)
# Drop boxes whose recognised text is exactly '@'.
text_arrays = text_arrays[text_arrays[:, 4] != '@']
text_arrays

In [None]:
# Tabulate the boxes and expose the row index as a leading 'id' column.
text_infos = pd.DataFrame(text_arrays, columns=['x', 'y', 'width', 'height', 'text'])
text_infos.insert(0, "id", text_infos.index)
text_infos

In [None]:
# text_infos.to_csv('./bounded-pred2-texts.csv')

In [None]:
# Build the project's Chart object from 'bounded.png', the annotated image
# written by the merged-box drawing cell.
# NOTE(review): the meaning of text_from=2 is defined in the chart module and is
# not visible in this notebook — confirm.
chart_i = chart.Chart(fn='bounded.png', text_from=2)
chart_i

In [None]:
# Compare predicted text roles against the labelled CSV.
# text_.classifier is already imported in the notebook's import cell, so the
# duplicate import that used to sit here has been dropped.
cnt = 0
text_clf = text_.classifier.TextClassifier('default')
text_type_preds = text_clf.classify(chart_i)

# The CSV is expected to provide a 'type' column with the ground-truth label
# for each prediction index.
a = pd.read_csv('./bounded-pred2-texts.csv')
for i in range(len(text_type_preds)):
    print("Pred: ", text_type_preds[i], '; ', 'True: ', a['type'][i])
    if text_type_preds[i] != a['type'][i]:
        cnt += 1
        print("wrong answer")
# Report NaN rather than dividing by zero when the classifier returned nothing.
total = len(text_type_preds)
print("Error rate: ", cnt / total if total else float('nan'))