# In [123]:
import cv2
import matplotlib.pyplot as plt
import os
import paddle as paddle
import paddle.fluid as fluid
import PIL.Image as Image
import random
from htgCrap import htgCrap
import numpy as np
import time




# Prediction (inference) model
class inferModel:
    """Wrap a saved Paddle inference program that classifies one character image.

    ``math_map`` translates the index of the highest network output into the
    symbol it stands for.
    """

    def __init__(self, model_name):
        """Load the inference model saved under ``models/infer_model/<model_name>/``.

        Args:
            model_name: Name of the model sub-directory (passed through ``str``).
        """
        self.math_map = [
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            '(', ')', '+', '-', '*', '.', 'sqrt', '[', ']', 'E', 'T',
        ]
        model_dir = 'models/infer_model/' + str(model_name) + '/'
        executor = fluid.Executor(fluid.CPUPlace())
        program, feed_names, fetch_targets = fluid.io.load_inference_model(
            dirname=model_dir, executor=executor)
        self.args = [executor, program, feed_names, fetch_targets]

    def infer_char(self, image):
        """Classify a single character image and return its symbol.

        Args:
            image: Array-like grayscale image. Only ``shape[0]`` is read and is
                used as both height and width, so a square image is assumed.

        Returns:
            The entry of ``self.math_map`` at the argmax of the network output.
        """
        executor, program, feed_names, fetch_targets = self.args
        pixels = np.array(image)
        side = pixels.shape[0]
        # Batch of single-channel side x side images, scaled from 0-255 to 0-1.
        batch = pixels.reshape(-1, 1, side, side).astype('float32') / 255.0
        outputs = executor.run(program=program,
                               feed={feed_names[0]: batch},
                               fetch_list=fetch_targets)
        return self.math_map[np.argmax(outputs)]


# Connected-component image processing
class ConnDomain:
    """Segment formula images into normalised per-character crops.

    Characters are located as contours (each outer contour with its hole
    contours merged in), sorted left to right, redrawn on a blank canvas and
    rescaled onto a 64x64 image for the character classifier.
    """

    def __init__(self):
        pass

    def cropCharContours(self, formula):
        """Extract the outer contours of ``formula``, merging each hole into its parent.

        Args:
            formula: Single-channel binary image accepted by ``cv2.findContours``.

        Returns:
            A list with one entry per kept outer contour. Each entry is a list of
            ``[x, y]`` points: the outer boundary plus the boundaries of its holes.
            Outer contours with fewer points than 1/8 of the average contour
            length are dropped as noise, together with their holes. Returns an
            empty list when no contour is found.
        """
        # OpenCV 3.x returns (image, contours, hierarchy); 2.x/4.x return
        # (contours, hierarchy). The last two items agree in every version, and a
        # single call avoids re-running on an image that OpenCV < 3.2 modifies
        # in place.
        found = cv2.findContours(formula, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        contours, hierarchy = found[-2], found[-1]
        if hierarchy is None or len(contours) == 0:
            return []
        links = hierarchy[0]  # per contour: [next, previous, first_child, parent]
        # Noise threshold: 1/8 of the average contour length (in points).
        min_len = sum(len(c) for c in contours) / len(links) / 8

        merged = [[] for _ in range(len(links))]
        for index, contour in enumerate(contours):
            father = links[index][3]
            if len(contour) < min_len:
                # Re-parent tiny contours to 0 so the collection step below no
                # longer sees them as top-level, i.e. drops them. `father` was
                # read first, so this contour's own points are still routed by
                # its real parent.
                links[index][3] = 0
            points = [[pt[0][0], pt[0][1]] for pt in contour]
            # Holes (father != -1) are folded into their parent's point list.
            merged[index if father == -1 else father].extend(points)
        return [merged[i] for i in range(len(links)) if links[i][3] == -1]

    def getCharLocation(self, formulas):
        """Locate the characters of each formula image and sort them left to right.

        Args:
            formulas: Iterable of single-channel formula images. They are
                inverted before thresholding, so dark ink on a light background
                is presumably expected — TODO confirm against htgCrap.

        Returns:
            One list per formula of ``[[x, y, w, h], points]`` entries, sorted by
            ``x``. ``(x, y)`` is the top-left corner of the character's bounding
            box, and ``w``/``h`` are the max-minus-min coordinate extents.
            ``points`` is the contour point array shifted so the box starts at
            (0, 0).
        """
        charLocationList = []
        for formula in formulas:
            inverted = cv2.bitwise_not(formula)
            feathered = cv2.blur(inverted, (2, 2))  # feather edges before thresholding
            _, binary = cv2.threshold(feathered, 10, 255, cv2.THRESH_BINARY)
            charLocation = []
            for contour in self.cropCharContours(binary):
                pts = np.array(contour)
                x_min, y_min = pts.min(axis=0)
                x_max, y_max = pts.max(axis=0)
                after_contour = pts - [x_min, y_min]
                charLocation.append(
                    [[x_min, y_min, x_max - x_min, y_max - y_min], after_contour])
            # Stable sort on the left edge of each box.
            charLocation.sort(key=lambda item: item[0][0])
            charLocationList.append(charLocation)
        return charLocationList

    def cropChar(self, formulas):
        """Render every located character as a normalised 64x64 binary image.

        Args:
            formulas: Iterable of formula images (see ``getCharLocation``).

        Returns:
            One list per formula of ``[[x, y, w, h], image]`` pairs in left-to-
            right order, where ``image`` is the 64x64 array produced by
            ``cv2.threshold``.
        """
        return [[[loca[0], self._render_char(loca)] for loca in locas]
                for locas in self.getCharLocation(formulas)]

    def _render_char(self, loca):
        """Draw one character's contour on a padded canvas and fit it into 64x64."""
        after_contour = loca[1]
        x_len, y_len = loca[0][2], loca[0][3]
        # Blank canvas with a 5px margin around the bounding box.
        outline = np.zeros((y_len + 10, x_len + 10), np.uint8)
        outline[after_contour[:, 1] + 5, after_contour[:, 0] + 5] = 255

        mask = np.zeros((outline.shape[0] + 2, outline.shape[1] + 2), np.uint8)
        # NOTE(review): with cv2.FLOODFILL_MASK_ONLY OpenCV fills only `mask`
        # and ignores newVal, so these two calls leave `outline` unchanged. The
        # glyph therefore reaches the classifier as an outline, not a filled
        # shape. Kept as-is because the model may have been trained on such
        # outlines — confirm before changing.
        cv2.floodFill(outline, mask, (0, 0), (99, 0, 0), cv2.FLOODFILL_MASK_ONLY)
        cv2.floodFill(outline, mask, self._find_fill_seed(outline),
                      (255, 255, 255), cv2.FLOODFILL_MASK_ONLY)
        _, char_img = cv2.threshold(outline, 127, 255, cv2.THRESH_BINARY)

        # Scale the longer side to 50px (at least 1px on the shorter side, which
        # PIL needs) and centre it on a 64x64 canvas.
        h, w = char_img.shape
        data_shape = 50
        if w > h:
            h = max(1, int(data_shape * (h / w)))
            w = data_shape
        else:
            w = max(1, int(data_shape * (w / h)))
            h = data_shape
        glyph = Image.fromarray(char_img).resize((w, h))
        printBack = Image.new('L', (64, 64))
        printBack.paste(glyph, (int((64 - w) / 2), int((64 - h) / 2)))
        # Thicken strokes: blur, then keep every pixel the blur touched.
        mg = cv2.blur(np.array(printBack), (3, 3))
        _, printBack = cv2.threshold(mg, 10, 255, cv2.THRESH_BINARY)
        return printBack

    @staticmethod
    def _find_fill_seed(read_image):
        """Return an (x, y) seed point for flood filling the character interior.

        Scans the middle row left to right. Returns the first 0 pixel that
        follows the first 255 pixel, provided another 255 pixel appears later
        on the row; otherwise returns (0, 0).
        """
        row = int(read_image.shape[0] / 2)
        seen_stroke = False
        gap_x = None
        for col in range(int(read_image.shape[1])):
            value = read_image[row][col]
            if not seen_stroke and value == 255:
                seen_stroke = True
            elif seen_stroke and gap_x is None and value == 0:
                gap_x = col
            elif gap_x is not None and value == 255:
                return gap_x, row
        return 0, 0

    def getChar(self, image):
        """Split ``image`` into formulas and return ``cropChar``'s output for them.

        Formula extraction is delegated to the project helper ``htgCrap``. Its
        ``crapFormula()`` presumably returns one grayscale image per formula —
        TODO confirm.
        """
        formulas = htgCrap(image).crapFormula()
        return self.cropChar(formulas)


# Error-tolerant character ordering (layout reconstruction)
class charFix:
    """Rebuild formula text from segmented characters and their predictions.

    Args:
        char_list: One list per formula of ``[[x, y, w, h], image]`` entries in
            left-to-right order. Only the ``[x, y, w, h]`` box is read here.
            Boxes are assumed to be in image coordinates (y grows downward), as
            produced by ConnDomain.
        infer_list: One list per formula of predicted symbols, aligned
            index-for-index with ``char_list``.
    """

    # Symbols formula_infer knows about. Indices 0-11 (digits and parentheses),
    # and symbols missing from the table, may take an exponent or a trailing '.'.
    _FORMULA_MAP = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '(', ')',
                    '+', '-', '*', '/', '=', 'a', 't', 's', 'n', 'sqrt', '.', 'PI', 'y')
    # When exactly one of these sits above and one below a '-', it becomes '/'.
    _DIVISION_DOTS = ('.', '1', '0')
    # A first subscript from this set appends '.' after a digit-like symbol
    # (presumably a decimal point misread by the classifier).
    _DECIMAL_MARKS = ('.', '0', '8', '4', '6', '9')

    def __init__(self, char_list, infer_list):
        self.char_list = char_list
        self.infer_list = infer_list

    @staticmethod
    def _merge_trig(infer_char, char_index, count):
        """Merge the symbols starting at ``char_index`` into sin/cos/tan if they match.

        Returns:
            ``(symbol, consumed)``, where ``consumed`` is how many predictions the
            symbol covers. A lone 's' becomes '5'.
        """
        first = infer_char[char_index]
        if first == 's':
            # 'i' may be segmented as '1' plus a separate dot, so allow a gap of one.
            if (char_index + 3 < count and infer_char[char_index + 3] == 'n'
                    and (infer_char[char_index + 1] == '1'
                         or infer_char[char_index + 2] == '1')):
                return 'sin', 4
            if (char_index + 2 < count and infer_char[char_index + 2] == 'n'
                    and infer_char[char_index + 1] == '1'):
                return 'sin', 3
            return '5', 1
        if (first == '(' and char_index + 2 < count
                and infer_char[char_index + 1] == '0'
                and infer_char[char_index + 2] in ('s', '5')):
            return 'cos', 3
        if (first == 't' and char_index + 2 < count
                and infer_char[char_index + 1] == 'a'
                and infer_char[char_index + 2] == 'n'):
            return 'tan', 3
        return first, 1

    def char_list_sort(self):
        """Group each formula's characters around baseline ("father") symbols.

        Characters are visited in order. A character becomes a new father when
        it extends past both the current father's top and bottom edges. It also
        becomes a father when none of the following holds:
        - it lies strictly inside the father's x-span;
        - it is smaller than the father and ends above the father's one-third
          height line;
        - it is under 1/6 of the father's area and starts below the father's
          two-thirds height line.
        Any other character is attached to the current father: as a superscript
        if it starts above the father's 10% height line, as a subscript if it
        starts below the half-height line, otherwise as a middle part.

        Returns:
            One list per formula of ``[symbol, superscripts, middle, subscripts]``.
        """
        sort_formula = []
        for formula_index, char_infor in enumerate(self.char_list):
            infer_char = self.infer_list[formula_index]
            count = len(char_infor)
            sort_char_list = []
            # Zero-size sentinel father: the first character always becomes a father.
            fx, fy, fw, fh = 0, 0, 0, 0
            char_index = 0
            while char_index < count:
                cx, cy, cw, ch = char_infor[char_index][0][:4]
                plus_number = 1
                father_size = fw * fh
                char_size = cw * ch
                inside_x = cx > fx and cx + cw < fx + fw
                small_above = char_size < father_size and cy + ch < fy + fh / 3
                tiny_below = char_size < father_size / 6 and cy > fy + fh - fh / 3
                spans_father = cy < fy and cy + ch > fy + fh
                if not (inside_x or small_above or tiny_below) or spans_father:
                    symbol, plus_number = self._merge_trig(infer_char, char_index, count)
                    sort_char_list.append([symbol, [], [], []])
                    fx, fy, fw, fh = cx, cy, cw, ch
                else:
                    symbol = infer_char[char_index]
                    if symbol == 's':
                        symbol = '5'
                    father = sort_char_list[-1]
                    if cy < fy + fh * 0.1:
                        father[1].append(symbol)
                    elif cy > fy + fh * 0.5:
                        father[3].append(symbol)
                    else:
                        father[2].append(symbol)
                char_index += plus_number
            sort_formula.append(sort_char_list)
        return sort_formula

    @classmethod
    def _render_minus(cls, ups, downs):
        """Interpret a '-' together with the symbols above and below it.

        Returns '/' for a division sign (one dot-like mark above AND below),
        ``'((up)/(down))'`` for a fraction, '=' for two stacked bars, else '-'.
        """
        if ups and downs:
            if (len(ups) == 1 and len(downs) == 1
                    and ups[0] in cls._DIVISION_DOTS
                    and downs[0] in cls._DIVISION_DOTS):
                return '/'
            return '((' + ''.join(ups) + ')/(' + ''.join(downs) + '))'
        if not ups and len(downs) == 1 and downs[0] == '-':
            return '='
        if len(ups) == 1 and not downs and ups[0] == '-':
            return '='
        return '-'

    def formula_infer(self):
        """Render the grouped characters of every formula as text.

        Returns:
            A list with one string per formula.
        """
        text_formula = []
        for sort_char_list in self.char_list_sort():
            parts = []
            for this_char, ups, middle, downs in sort_char_list:
                if this_char in self._FORMULA_MAP:
                    index = self._FORMULA_MAP.index(this_char)
                else:
                    index = -1
                if index < 12:
                    # Digits, parentheses and unknown symbols (e.g. 'sin').
                    parts.append(this_char)
                    if ups:
                        parts.append('^' + ''.join(ups))
                    if downs and downs[0] in self._DECIMAL_MARKS:
                        parts.append('.')
                elif this_char == '-':
                    parts.append(self._render_minus(ups, downs))
                elif this_char == 'sqrt':
                    # Root degree defaults to 2 when nothing sits above the radical.
                    degree = str(ups[0]) if ups else '2'
                    radicand = ''.join(str(i) for i in middle)
                    parts.append('sqrt' + degree + '(' + radicand + ')')
                else:
                    parts.append(this_char)
            text_formula.append(''.join(parts))
        return text_formula
    
# End-to-end API for recognising arithmetic (four-operation) formulas
class four_operations:
    """End-to-end recogniser for arithmetic formulas in an image.

    Chains ConnDomain (character segmentation), inferModel (per-character
    classification) and charFix (layout reconstruction into text).
    """

    def __init__(self, model_name):
        """Build the pipeline around the inference model named ``model_name``."""
        self.infer_model = inferModel(model_name)
        self.dnn = ConnDomain()

    def infer_image(self, image):
        """Recognise every formula in ``image``.

        Args:
            image: A path or file object accepted by ``PIL.Image.open``, or an
                already-decoded ``numpy.ndarray``, which is used as-is.

        Returns:
            A list with one text string per formula found.
        """
        if isinstance(image, np.ndarray):
            pixels = image
        else:
            # Close the underlying file as soon as the pixels are copied out.
            with Image.open(image) as opened:
                pixels = np.array(opened)
        char_image_list = self.dnn.getChar(pixels)
        # One predicted symbol per character crop, grouped per formula.
        formulas_image_list = [
            [self.infer_model.infer_char(char[1]) for char in formula]
            for formula in char_image_list
        ]
        return charFix(char_image_list, formulas_image_list).formula_infer()

# In [124]:
# Build the pipeline with the Paddle inference model saved under models/infer_model/3/.
test = four_operations(3)

# In [126]:
# Recognise the formula(s) in a sample image; returns a list with one string per formula.
test.infer_image('test_image/54.jpg')

# Output: ['(3426+)428.3432)']