In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import time
class Timer(object):
    """Accumulating stopwatch based on ``time.time()``.

    Call ``tic()`` to mark the start of an interval and ``toc()`` to close
    it. Every ``toc()`` updates the running statistics below.

    Attributes are plain floats/ints:
      total_time   -- sum of all measured intervals, in seconds
      calls        -- number of completed ``toc()`` calls
      start_time   -- timestamp recorded by the most recent ``tic()``
      diff         -- length of the most recently measured interval
      average_time -- ``total_time / calls``
    """

    def __init__(self):
        self.start_time = 0.
        self.diff = 0.
        self.total_time = 0.
        self.average_time = 0.
        self.calls = 0

    def tic(self):
        """Record the current time as the start of a new interval."""
        self.start_time = time.time()

    def toc(self, average=True):
        """Close the current interval and update the statistics.

        Returns the running average over all intervals when ``average`` is
        true, otherwise the duration of the interval just closed.
        """
        elapsed = time.time() - self.start_time
        self.diff = elapsed
        self.total_time = self.total_time + elapsed
        self.calls = self.calls + 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff
        
        
 

In [None]:
from keras.layers import Input,Conv2D,MaxPooling2D,ZeroPadding2D
from keras.layers import Flatten,BatchNormalization,Permute,TimeDistributed,Dense,Bidirectional,GRU
from keras.models import Model

import shutil
import numpy as np
from PIL import Image
import keras.backend  as K

from imp import reload 
import densenet
reload(densenet)

import os
from keras.layers import Lambda
from keras.optimizers import SGD

import tensorflow as tf  
import keras.backend.tensorflow_backend as K  
from matplotlib import pyplot as plt
import pandas as pd

import glob

class PredictCheck:
    """
    images_dir: 批量验证图片的路径
    image_label: 批量验证图片的标记文件
    dicfile: 字符字典
    unih: 图片高度。
    modelPath:模型路径。
    """
    def __init__(self, images_dir, image_label, dicfile, unih, modelPath, gpu_fraction=0):
        self.dicfile = dicfile
        self.image_label = image_label
        self.images_dir = images_dir
        self.modelPath = modelPath
        self.unih = unih
        self.t = Timer()
        with open(self.image_label, encoding='utf-8') as f:
            self.labelFile = [str(line).strip() for line in f.readlines()]
        # 得到一个labelList
        self.labelList = []
        for i in self.labelFile:
            fname = i.split(" ")[0]
            findex = len(fname)
            aclabel = i[findex:]
            aclabel = aclabel.strip()
            self.labelList.append([fname, aclabel])
        
        tmp = self.__get_session__(gpu_fraction)
        K.set_session(tmp)
        
        self.char = ""
        with open(self.dicfile, encoding='utf-8') as f:
            for ch in f.readlines():
                ch = ch.strip('\r\n')
                self.char = self.char + ch
        print('xxnclass:',len(self.char))
        print(self.char)
        self.char =self.char[1:]+'卍'
        self.nclass = len(self.char) + 1
        print('nclass:',self.nclass)
        self.id_to_char = {i:j for i,j in enumerate(self.char)}
        input = Input(shape=(self.unih,None,1),name='the_input')
        y_pred= densenet.dense_cnn(input,self.nclass)
        self.basemodel = Model(inputs=input,outputs=y_pred)
        self.basemodel.load_weights(self.modelPath)        
        
    def __get_session__(self, gpu_fraction=0.8):  
        '''''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''  
        num_threads = os.environ.get('OMP_NUM_THREADS')  
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)  

        if num_threads:  
            return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))  
        else:  
            return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))  

    def __predict__(self, img_path):

        img = Image.open(img_path)
        im = img.convert('L')
        scale = im.size[1]*1.0 / self.unih
        w = im.size[0] / scale
        w = int(w)
        print('w:',w)

        im = im.resize((w,self.unih),Image.ANTIALIAS)
        img = np.array(im).astype(np.float32)/255.0-0.5
        X  = img.reshape((self.unih,w,1))
        X = np.array([X])

        self.t.tic()
        y_pred = self.basemodel.predict(X)
        self.t.toc()
        print("times,",self.t.diff)
        argmax = np.argmax(y_pred, axis=2)[0]

        y_pred = y_pred[:,:,:]
        out = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0])*y_pred.shape[1], )[0][0])[:, :]
        out = u''.join([self.id_to_char[x] for x in out[0]])
        return out,im
        
    def loopPredict(self):
        self.result = []
        for img in glob.glob(f"{self.images_dir}/*.jpeg"):
            fname = os.path.basename(img)
            b, ximg = self.__predict__(img)
            self.result.append([fname, b, img])
        ## 就在这里通过dataframe进行join操作。
        aaa = pd.DataFrame(np.array(self.labelList), columns=['fileName', 'labelContent'])
        bbb = pd.DataFrame(np.array(self.result), columns=['fileName', 'predictContent', 'filePath'])
        self.ccc = pd.merge(aaa, bbb, on='fileName')
        ## 二者做一个join操作。。
        
    def compare(self, same="same", diff="diff"):
        # 分割成两个文件夹，分别放置全对的和非全对的。。
        # 遍历预测的东西。看看是否在，label中，如果在的话，
        self.rightList = []
        self.wrongList = []
        if os.path.exists(same):
            shutil.rmtree(same)
        if os.path.exists(diff):
            shutil.rmtree(diff)
        os.mkdir(same)
        os.mkdir(diff)
        with open(os.path.join(same,same+".txt"), "w" , encoding='utf-8') as rf, open(os.path.join(diff, diff+".txt"), "w" , encoding='utf-8') as wf:
            for i in self.result:
                tmp  = i[0] + " " + i[1]
                if tmp in self.labelFile:
                    # 识别正确。。
                    shutil.copy2(os.path.join(self.images_dir, i[0]), same)
                    self.rightList.append(i)
                    rf.write(tmp+"\n")
                else:
                    shutil.copy2(os.path.join(self.images_dir, i[0]), diff)
                    self.wrongList.append(i)
                    wf.write(tmp+"\n")
    def displot(self, Right=True, n_size=10):
        # 把图片显示出来。。
        # n_size = -1表示使用对应的长度。。
#         plt.figure()
        if Right:
            if n_size > len(self.rightList) or n_size < 0:
                n_size = len(self.rightList)
            # 然后显示指定数量的图片
            
            for j,i in zip(range(n_size), self.rightList):
                flg.add_subplot(n_size,2,(j % 2)+1)
                print(i[0] + "===" + i[1])
                plt.title(i[0] + "===" + i[1])
                plt.imshow(Image.open(i[2]))
        else:
            if n_size > len(self.wrongList) or n_size < 0:
                n_size = len(self.wrongList)
            for j,i in zip(range(n_size), self.wrongList):
                plt.subplot(n_size,2,(j % 2)+1)
                print(i[0] + "===" + i[1])
                plt.title(i[0] + "===" + i[1])          
                plt.imshow(Image.open(i[2]))           
        plt.show()       

# The working directory is /workspace/densent_ocr, a directory inside the Docker image.

In [None]:
# Build the checker. The constructor reads the label file and the character
# dictionary, creates the TensorFlow session, builds the DenseNet model and
# loads the weights. All paths are relative to the current working directory.
#   unih=28        -> every image is resized to a height of 28 px before inference
#   gpu_fraction=0 -> value passed to per_process_gpu_memory_fraction
x = PredictCheck(images_dir="./ybt_test_images/pad_only_chinese_test_set_debug", 
                 image_label="./ybt_test_images/only_chinese1.txt",
                dicfile="./only_chinese_chn.txt",
                 unih=28,
                 modelPath="./ybt_model_weights/weights_densenet-09-9.79.h5",
                 gpu_fraction=0
                )

In [None]:
# Run recognition on every *.jpeg in images_dir. Predictions are stored in
# x.result, and the label/prediction join (on file name) in x.ccc.
x.loopPredict()

In [None]:
# Split the predictions into the "same" (exact match with the label file) and
# "diff" folders. Both folders are deleted and recreated on every run.
x.compare()

In [None]:
# Show misrecognised images (at most 10 by default), each titled
# "<file name>===<prediction>". Requires x.compare() to have been run first.
x.displot(Right=False)