In [1]:
import chainer
from PIL import Image, ImageDraw, ImageFont, ImageOps
import csv
import cv2
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import imutils
import easyocr
from manga_ocr import MangaOcr
from scipy.ndimage import rotate
from imutils.contours import sort_contours

import sys
sys.path.append('../../')

from kr.detector.centernet.resnet import Res18UnetCenterNet
from kr.classifier.softmax.mobilenetv3 import MobileNetV3
from kr.datasets import KuzushijiUnicodeMapping

In [2]:
# Font used to draw predicted / ground-truth characters onto the page images.
font_path = '/home/ec2-user/code/t-hanya/kuzushiji-recognition/data/font/NotoSansCJKjp-Regular.otf'
# Pillow's `encoding` argument names a FreeType charmap ('unic', 'symb', ...),
# not a text codec, so 'utf-8' is not a valid value; the default charmap
# (Unicode) is what is wanted here.
font = ImageFont.truetype(font_path, 50)
pred_color = 'rgb(255, 0, 0)'  # red: fill colour for predicted characters
true_color = 'rgb(0, 0, 255)'  # blue: fill colour for ground-truth characters

In [30]:
class Predictor:
    """Detect and classify Kuzushiji characters on page images and score the result.

    For every processed image the predictor

    * runs the detector and the classifier (``predict``),
    * draws the predicted characters and, when ground truth is found, the
      ground-truth characters onto a copy of the page and saves that copy as
      ``pred_<basename>`` in ``save_folder``,
    * appends one dict per predicted character to ``results_csv``,
    * appends one accuracy row per image to ``results``.

    The drawing code relies on the module-level ``font``, ``pred_color`` and
    ``true_color`` objects.
    """

    # File-name prefixes whose ground-truth folder / CSV use one fixed book name.
    _BOOK_ALIASES = {
        'umgy': 'umgy00000',
        'hnsd': 'hnsd00000',
        'brsk': 'brsk00000',
    }

    # A ground-truth character counts as correctly predicted when a prediction
    # of the same character starts less than this many pixels away on both axes.
    MATCH_TOLERANCE = 200

    def __init__(self,
                 detector_model='/home/ec2-user/code/t-hanya/kuzushiji-recognition/results/detector/model_700.npz',
                 classifier_model='/home/ec2-user/code/t-hanya/kuzushiji-recognition/results/classifier/model_1000.npz',
                 full_dataset='/home/ec2-user/code/restor-ai-tion/data/full',
                 save_folder='/home/ec2-user/code/t-hanya/kuzushiji-recognition/data/preds'
                ):
        """Load both trained models and prepare the output folder.

        Args:
            detector_model: path of the ``.npz`` weights for the detector.
            classifier_model: path of the ``.npz`` weights for the classifier.
            full_dataset: root folder holding ``<book>/<book>_coordinate.csv``
                ground-truth files.
            save_folder: folder for annotated images and CSV output; created
                if it does not exist.
        """
        self.detector_model_file = detector_model
        self.classifier_model_file = classifier_model
        self.full_dataset = full_dataset
        self.save_folder = save_folder
        # Running totals over every image processed by this instance.
        self.total = 0
        self.correct = 0
        self.wrong = 0
        # One dict per predicted character (written out by save_csv).
        self.results_csv = []
        # One [name, total, correct, correct %, wrong, wrong %] row per image.
        self.results = []
        # book name -> {image name -> [(char, [X, Y, Width, Height]), ...]};
        # lets each coordinate CSV be parsed once instead of once per image.
        self._truth_cache = {}

        os.makedirs(save_folder, exist_ok=True)

        # unicode <-> unicode index mapping
        self.mapping = KuzushijiUnicodeMapping()

        # load trained detector
        self.detector = Res18UnetCenterNet()
        chainer.serializers.load_npz(self.detector_model_file, self.detector)

        # load trained classifier
        self.classifier = MobileNetV3(out_ch=len(self.mapping))
        chainer.serializers.load_npz(self.classifier_model_file, self.classifier)

    def predict(self, image_filename):
        """Run detection and classification on one image file.

        Returns:
            ``(unicodes, scores, bboxes, bbox_scores)`` where ``unicodes`` are
            the mapped code strings of the classified characters and the other
            three values are passed through from the classifier / detector.
        """
        # The context manager closes the underlying file once both models ran.
        with Image.open(image_filename) as image:
            # character detection
            bboxes, bbox_scores = self.detector.detect(image)
            # character classification
            unicode_indices, scores = self.classifier.classify(image, bboxes)
        unicodes = [self.mapping.index_to_unicode(idx) for idx in unicode_indices]
        return unicodes, scores, bboxes, bbox_scores

    def bbox_values(self, bbox):
        """Return the first four entries of ``bbox`` truncated to ``int`` as a tuple."""
        return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

    @staticmethod
    def _to_char(code):
        """Convert a ``'U+XXXX'`` style code string to the character it names."""
        return chr(int(code[2:], 16))

    @staticmethod
    def _rates(total, num_correct, num_wrong):
        """Return (correct fraction, wrong fraction); both 0.0 when ``total`` is 0."""
        if total > 0:
            return num_correct / total, num_wrong / total
        return 0.0, 0.0

    def _ground_truth_key(self, basename):
        """Derive ``(book name, image name)`` from an image file name.

        Dashes are treated like underscores. Returns ``None`` when the name
        has no separator, i.e. no book name can be derived from it.
        """
        normalized = basename.replace('-', '_')
        parts = normalized.split('_')
        if len(parts) < 2:
            return None
        bookname = parts[0]
        for prefix, canonical in self._BOOK_ALIASES.items():
            if bookname.startswith(prefix):
                bookname = canonical
                break
        return bookname, normalized.split('.')[0]

    def _load_ground_truth(self, bookname, image_name):
        """Return ``[(char, [X, Y, Width, Height]), ...]`` for one image.

        The book's coordinate CSV is parsed on first use and cached. A book
        without a coordinate CSV is reported once and yields an empty list, so
        pages without ground truth are still annotated and saved.
        """
        if bookname not in self._truth_cache:
            per_image = {}
            csv_file = os.path.join(self.full_dataset, bookname, '{}_coordinate.csv'.format(bookname))
            if os.path.isfile(csv_file):
                with open(csv_file, newline='', encoding='utf-8') as csvfile:
                    for row in csv.DictReader(csvfile):
                        box = [int(row['X']), int(row['Y']), int(row['Width']), int(row['Height'])]
                        per_image.setdefault(row['Image'], []).append((self._to_char(row['Unicode']), box))
            else:
                print('no ground truth CSV found: {}'.format(csv_file))
            self._truth_cache[bookname] = per_image
        return self._truth_cache[bookname].get(image_name, [])

    def _count_matches(self, truths, predicted):
        """Count ground-truth characters that have a matching prediction.

        Args:
            truths: ``[(char, bbox), ...]`` as returned by ``_load_ground_truth``.
            predicted: ``[(char, x, y), ...]`` for the predicted characters.

        Returns:
            ``(num_correct, num_wrong)``; the two always sum to ``len(truths)``.
        """
        tolerance = self.MATCH_TOLERANCE
        num_correct = 0
        for uc, bbox in truths:
            x, y, _, _ = self.bbox_values(bbox)
            if any(puc == uc and abs(px - x) < tolerance and abs(py - y) < tolerance
                   for puc, px, py in predicted):
                num_correct += 1
        return num_correct, len(truths) - num_correct

    def _do_ocr(self, orig_filename, verbose=False):
        """Process one image: predict, annotate, save and score it.

        Appends to ``results_csv`` and ``results`` and updates the running
        ``total`` / ``correct`` / ``wrong`` counters. An image without ground
        truth is recorded with zero characters.
        """
        basename = os.path.basename(orig_filename)
        image_id = os.path.splitext(basename)[0]

        unicodes, unicode_scores, bboxes, bbox_scores = self.predict(orig_filename)

        predicted = []  # (char, x, y) per prediction, decoded once for matching
        with Image.open(orig_filename) as pred_img:
            pred_draw = ImageDraw.Draw(pred_img)

            for pred, uscore, bbox, bscore in zip(unicodes, unicode_scores, bboxes, bbox_scores):
                # Presumably (x_min, y_min, x_max, y_max): width / height are
                # stored below as coordinate differences. TODO confirm with detector.
                x, y, x2, y2 = self.bbox_values(bbox)
                label = self._to_char(pred)
                if verbose:
                    print('!!!', pred, label, uscore, x, y, x2, y2, bscore)

                self.results_csv.append({
                    'image': image_id,
                    'char': label,
                    'uc': pred,
                    'x': x,
                    'y': y,
                    'w': abs(x - x2),
                    'h': abs(y - y2)
                })
                pred_draw.text((x - 50, y - 10), label, fill=pred_color, font=font)
                predicted.append((label, x, y))

            truths = []
            key = self._ground_truth_key(basename)
            if key is not None:
                truths = self._load_ground_truth(*key)

            for uc, bbox in truths:
                # Ground-truth rows are stored as X, Y, Width, Height.
                x, y, w, h = self.bbox_values(bbox)
                if verbose:
                    print('>', uc, x, y, x + w, y + h)
                pred_draw.text((x + 100, y - 10), uc, fill=true_color, font=font)

            num_correct, num_wrong = self._count_matches(truths, predicted)

            # The annotated page is written straight from Pillow.
            pred_img.save(os.path.join(self.save_folder, 'pred_{}'.format(basename)))

        total = len(truths)
        self.total += total
        self.correct += num_correct
        self.wrong += num_wrong
        cp, wp = self._rates(total, num_correct, num_wrong)
        self.results.append([basename, total, num_correct, cp, num_wrong, wp])

    def do_ocr(self, files):
        """Process ``files`` and return the accuracy table as a DataFrame.

        The table has one row per image processed by this instance so far,
        followed by a single ``'TEST RESULTS'`` row holding the running totals.
        The summary row is built for the returned frame only, so calling this
        method again never leaves an outdated summary inside ``results``.
        """
        for i, f in enumerate(files):
            self._do_ocr(f)
            if (i+1) % 50 == 0:
                print('processed {} files, {} total chars, {} total correct, {} total wrong'.format(i+1, self.total, self.correct, self.wrong))
        cp, wp = self._rates(self.total, self.correct, self.wrong)
        summary = ['TEST RESULTS', self.total, self.correct, cp, self.wrong, wp]
        return pd.DataFrame(self.results + [summary],
                            columns=['', 'Total characters', 'Correct Predictions', 'Correct %', 'Wrong Predictions', 'Wrong %'])

    def save_csv(self, csv_filename='results.csv'):
        """Write ``results_csv`` to ``<save_folder>/<csv_filename>`` as UTF-8 CSV."""
        csv_absfilename = os.path.join(self.save_folder, csv_filename)
        # Explicit UTF-8: the 'char' column holds Japanese characters and the
        # platform default encoding is not guaranteed to represent them.
        with open(csv_absfilename, 'w', newline='', encoding='utf-8') as cf:
            fieldnames = ['image', 'char', 'uc', 'x', 'y', 'w', 'h']
            writer = csv.DictWriter(cf, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.results_csv)

In [25]:
# Smoke test: score a single page and display the resulting accuracy table.
sample_pages = ['./data/kuzushiji-recognition/test_images/100241706_00011_2.jpg']
p = Predictor()
p.do_ocr(sample_pages)

Unnamed: 0,Unnamed: 1,Total characters,Correct Predictions,Correct %,Wrong Predictions,Wrong %
0,100241706_00011_2.jpg,139,136,0.978417,3,0.021583
1,TEST RESULTS,139,136,0.978417,3,0.021583


In [31]:
test_data_folder = './data/kuzushiji-recognition/test_images'
# Pages of the book to evaluate (point this at test_data_folder to score the test images instead).
book_images_folder = '/home/ec2-user/code/restor-ai-tion/data/full/200014685/images/'
# os.listdir returns names in arbitrary order; sorting makes test_files[:1]
# select the same page on every run.
test_files = sorted(os.path.join(book_images_folder, name) for name in os.listdir(book_images_folder))
p = Predictor()
df = p.do_ocr(test_files[:1])
p.save_csv(csv_filename='200014685.csv')

In [None]:
# Name the unnamed first column (the image file name) without mutating in place,
# so that re-running this cell is harmless.
df = df.rename(columns={'': 'image'})

In [None]:
# Display the accuracy table built by do_ocr (rendered as the cell's output).
df