# KuroNet Vs Hanya Comparison

In [1]:
import chainer
from PIL import Image, ImageDraw, ImageFont, ImageOps
import csv
import json
import cv2
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import imutils
import easyocr
from manga_ocr import MangaOcr
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from scipy.ndimage import rotate
from imutils.contours import sort_contours

import sys
sys.path.append('../../')

from kr.detector.centernet.resnet import Res18UnetCenterNet
from kr.classifier.softmax.mobilenetv3 import MobileNetV3
from kr.datasets import KuzushijiUnicodeMapping

In [22]:
class KuroVsHanya:
    def __init__(self,
                 detector_model='/home/ec2-user/code/t-hanya/kuzushiji-recognition/results/detector/model_700.npz',
                 classifier_model='/home/ec2-user/code/t-hanya/kuzushiji-recognition/results/classifier/model_1000.npz',
                 full_dataset='/home/ec2-user/code/restor-ai-tion/data/full'):
        self.mapping = KuzushijiUnicodeMapping()
                
        # load trained detector
        self.detector = Res18UnetCenterNet()
        chainer.serializers.load_npz(detector_model, self.detector)
        
        # load trained classifier
        self.classifier = MobileNetV3(out_ch=len(self.mapping))
        chainer.serializers.load_npz(classifier_model, self.classifier)

        self.dataset = full_dataset
        self.test_books = ['200004107', '200005798', '200006665', '200008003', '200008316',
                           '200010454', '200015843', '200017458', '200018243', '200019865',
                           '200020019', '200021063', '200021071', '200021086', '200025191']

        self.results = dict()
        self.page_metrics = dict()
        self.book_metrics = dict()
        self.overall_metrics = dict()


    def predict(self, image_filename):
        # load image
        image = Image.open(image_filename)
        
        # character detection
        bboxes, bbox_scores = self.detector.detect(image)
        
        # character classification
        unicode_indices, scores = self.classifier.classify(image, bboxes)
        unicodes = [self.mapping.index_to_unicode(idx) for idx in unicode_indices]
        return unicodes, scores, bboxes, bbox_scores

    def collect_ground_truth(self, book, page_name):
        csv_file = os.path.join(self.dataset, book, '{}_coordinate.csv'.format(book))
        true_unicodes = []
        true_bboxes = []
        with open(csv_file) as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                if row['Image'] == page_name:
                    uc = row['Unicode']
                    true_unicodes.append(chr(int(uc[2:], 16)))
                    true_bboxes.append([int(row['X']), int(row['Y']), int(row['Width']), int(row['Height'])])
        page_ground_truth = dict()
        for index, (uc, bbox) in enumerate(zip(true_unicodes, true_bboxes)):
            page_ground_truth[index] = (uc, bbox)
        return page_ground_truth

    def bbox_values(self, bbox):
        divisor = 1
        return int(bbox[0]/divisor), int(bbox[1]/divisor), int(bbox[2]/divisor), int(bbox[3]/divisor)    
  
    def do_ocr(self, book, page):
        unicodes, unicode_scores, bboxes, bbox_scores = self.predict(page)
        
        page_name = os.path.basename(page).replace('.jpg', '')
        page_ground_truth = self.collect_ground_truth(book, page_name)
        
        num_correct, num_wrong = 0, 0
        total = len(page_ground_truth)

        true_labels = []
        pred_labels = []
        
        for index, (uc, bbox) in page_ground_truth.items():
            x, y, w, h = self.bbox_values(bbox)
            w = w + x
            h = h + y

            key = '{}:{}'.format(page_name, index)

            found_char = False
            close_match = False
            for (pred, pbbox) in zip(unicodes, bboxes):
                # if abs(xpred - x) < 200.0 and abs(ypred - y) < 200.0 and cpred == char:
                xpred, ypred, width, height = self.bbox_values(pbbox)
                puc = chr(int(pred[2:], 16))
                #if abs(xpred - x) < 200 and abs(ypred - y) < 200 and puc == uc:
                if abs(xpred - x) < 60.0 and abs(ypred - y) < 80.0:
                    if puc == uc:
                        found_char = True
                        self.results[key] = {
                            'page': page_name,
                            'index': index,
                            'true_bbox': bbox,
                            'pred_bbox': pbbox,
                            'true_char': uc,
                            'pred_char': puc,
                            'match': True
                        }
                        break
                    else:
                        close_match = True
                        self.results[key] = {
                            'page': page_name,
                            'index': index,
                            'true_bbox': bbox,
                            'pred_bbox': pbbox,
                            'true_char': uc,
                            'pred_char': puc,
                            'match': False
                        }                        
            if found_char:
                num_correct += 1
            else:
                num_wrong += 1        
                if not close_match:
                    self.results[key] = {
                        'page': page_name,
                        'index': index,
                        'true_bbox': bbox,
                        'pred_bbox': [],
                        'true_char': uc,
                        'pred_char': '',
                        'match': False
                    }                

            true_labels.append(uc)
            pred_labels.append(self.results[key]['pred_char'])


        # cm = confusion_matrix(true_labels, pred_labels)
        # recall = np.diag(cm) / np.sum(cm, axis = 1)
        # precision = np.diag(cm) / np.sum(cm, axis = 0)

        # page_recall = np.mean(recall)
        if total > 0:
            # page_accuracy = accuracy_score(true_labels, pred_labels)
            # page_precision = precision_score(true_labels, pred_labels, average='micro')
            # page_recall = recall_score(true_labels, pred_labels, average='micro')
            # page_f1 = f1_score(true_labels, pred_labels, average='micro')

            # The number of correctly predicted characters divided by the 
            # total number of predicted characters is defined as precision.
            if len(unicodes) > 0:
                page_precision = (num_correct*1.0) / len(unicodes)
                # We recall as the number of correctly predicted characters 
                # divided by the total number of characters present in the ground truth.
                page_recall = (num_correct*1.0) / total
                # F1-score, which is simply the harmonic mean of the precision and the recall
                page_f1 = 2 * (page_precision * page_recall) / (page_precision + page_recall)
        
                print('page prediction metrics', page_name, page_precision, page_recall, page_f1)
                self.page_metrics[page_name] = {
                    'num_correct': num_correct,
                    'num_wrong': num_wrong,
                    'ground_truth_total_chars': total,
                    'pred_total_chars': len(unicodes),
                    'precision': page_precision,
                    'recall': page_recall,
                    'f1': page_f1
                }
            else:
                num_correct = 0
                num_wrong = total
                print('page prediction metrics', 'no characters predicted in the image')
            
        else:
            print('page prediction metrics', 'no characters in the image')
        return num_correct, num_wrong, total, len(unicodes)
        

    def eval_book(self, book):
        book_folder = os.path.join(self.dataset, book, 'images')
        pages = [os.path.join(book_folder, p) for p in os.listdir(book_folder) if os.path.isfile(os.path.join(book_folder, p))]
        print('book', book, 'pages', len(pages))
        
        c, w, gt, pt = 0, 0, 0, 0
        for page in pages:
            pc, pw, pgt, ppt = self.do_ocr(book, page)
            c += pc
            w += pw
            gt += pgt
            pt += ppt
        book_precision = (c*1.0)/pt
        book_recall = (c*1.0)/gt
        book_f1 = 2 * (book_precision * book_recall) / (book_precision + book_recall)
        self.book_metrics[book] = {
            'num_correct': c,
            'num_wrong': w,
            'ground_truth_total_chars': gt,
            'pred_total_chars': pt,
            'precision': book_precision,
            'recall': book_recall,
            'f1': book_f1
        }
        return c, w, gt, pt

    def eval(self):
        c, w, gt, pt = 0, 0, 0, 0
        for book in self.test_books:
            bc, bw, bgt, bpt = self.eval_book(book)
            c += bc
            w += bw
            gt += bgt
            pt += bpt
            
        overall_precision = (c*1.0)/pt
        overall_recall = (c*1.0)/gt
        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall)
        self.overall_metrics = {
            'num_correct': c,
            'num_wrong': w,
            'ground_truth_total_chars': gt,
            'pred_total_chars': pt,
            'precision': overall_precision,
            'recall': overall_recall,
            'f1': overall_f1
        }


In [23]:
kh = KuroVsHanya()
kh.eval()

book 200004107 pages 78
page prediction metrics 200004107_00032_2 0.688034188034188 0.8944444444444445 0.7777777777777777
page prediction metrics 200004107_00041_1 0.9352941176470588 0.9137931034482759 0.9244186046511628
page prediction metrics no characters in the image
page prediction metrics 200004107_00013_2 0.9066666666666666 0.9066666666666666 0.9066666666666666
page prediction metrics 200004107_00026_1 0.9248554913294798 0.903954802259887 0.9142857142857141
page prediction metrics 200004107_00016_1 0.8961038961038961 0.8903225806451613 0.8932038834951457
page prediction metrics no characters in the image
page prediction metrics 200004107_00012_2 0.920863309352518 0.9078014184397163 0.9142857142857144
page prediction metrics no characters in the image
page prediction metrics 200004107_00017_1 0.9281437125748503 0.9226190476190477 0.9253731343283582
page prediction metrics 200004107_00005_1 0.9019607843137255 0.8846153846153846 0.8932038834951457
page prediction metrics no charact

In [24]:
pdf = pd.DataFrame.from_dict(kh.page_metrics, orient='index', columns=['num_correct', 'num_wrong', 'ground_truth_total_chars', 'pred_total_chars', 'precision', 'recall', 'f1'])
pdf

Unnamed: 0,num_correct,num_wrong,ground_truth_total_chars,pred_total_chars,precision,recall,f1
200004107_00032_2,161,19,180,234,0.688034,0.894444,0.777778
200004107_00041_1,159,15,174,170,0.935294,0.913793,0.924419
200004107_00013_2,136,14,150,150,0.906667,0.906667,0.906667
200004107_00026_1,160,17,177,173,0.924855,0.903955,0.914286
200004107_00016_1,138,17,155,154,0.896104,0.890323,0.893204
...,...,...,...,...,...,...,...
200025191_00052_1,261,19,280,275,0.949091,0.932143,0.940541
200025191_00012_1,132,24,156,149,0.885906,0.846154,0.865574
200025191_00068_1,227,48,275,257,0.883268,0.825455,0.853383
200025191_00044_1,220,21,241,239,0.920502,0.912863,0.916667


In [25]:
bdf = pd.DataFrame.from_dict(kh.book_metrics, orient='index', columns=['num_correct', 'num_wrong', 'ground_truth_total_chars', 'pred_total_chars', 'precision', 'recall', 'f1'])
bdf

Unnamed: 0,num_correct,num_wrong,ground_truth_total_chars,pred_total_chars,precision,recall,f1
200004107,7855,781,8636,9205,0.853341,0.909565,0.880556
200005798,32957,4930,37887,36993,0.890898,0.869876,0.880262
200006665,15427,1456,16883,17013,0.906777,0.913759,0.910255
200008003,12139,652,12791,13109,0.926005,0.949027,0.937375
200008316,32730,4977,37707,36791,0.88962,0.868009,0.878681
200010454,9314,2252,11566,10760,0.865613,0.805291,0.834364
200015843,45821,4266,50087,49859,0.919012,0.914828,0.916915
200017458,27752,1832,29584,29258,0.948527,0.938075,0.943272
200018243,12567,1056,13623,13878,0.905534,0.922484,0.91393
200019865,34348,4835,39183,38708,0.887362,0.876605,0.88195


In [26]:
odf = pd.DataFrame.from_dict({'overall': kh.overall_metrics}, orient='index', columns=['num_correct', 'num_wrong', 'ground_truth_total_chars', 'pred_total_chars', 'precision', 'recall', 'f1'])
odf

Unnamed: 0,num_correct,num_wrong,ground_truth_total_chars,pred_total_chars,precision,recall,f1
overall,350303,40743,391046,384875,0.910173,0.89581,0.902935


In [29]:
kuro_book = {
    '200004107': [0.778708, 0.884214, 0.828114],
    '200005798': [0.775391, 0.675723, 0.722134],
    '200006665': [0.807131, 0.815499, 0.811294],
    '200008003': [0.931166, 0.935542, 0.933349],
    '200008316': [0.839333, 0.870747, 0.854751],
    '200010454': [0.852287, 0.858848, 0.855555],
    '200015843': [0.742832, 0.746247, 0.744535],
    '200017458': [0.806243, 0.669653, 0.731627],
    '200018243': [0.885414, 0.890162, 0.887782],
    '200019865': [0.789542, 0.77634 , 0.782885],
    '200020019': [0.793021, 0.703133, 0.745377],
    '200021063': [0.549717, 0.259251, 0.352337],
    '200021071': [0.841072, 0.835898, 0.838477],
    '200021086': [0.796496, 0.714183, 0.753097],
    '200025191': [0.714985, 0.713568, 0.714276]
}
kuro_reg_book = {
    '200004107': [0.872809, 0.894761, 0.883649],
    '200005798': [0.878944, 0.889179, 0.884032],
    '200006665': [0.883257, 0.915808, 0.899238],
    '200008003': [0.917325, 0.927240, 0.922256],
    '200008316': [0.842049, 0.863097, 0.852443],
    '200010454': [0.843034, 0.846826, 0.844926],
    '200015843': [0.901105, 0.91964 , 0.910278],
    '200017458': [0.946381, 0.952779, 0.949569],
    '200018243': [0.87668 , 0.924987, 0.900186],
    '200019865': [0.887304, 0.903500, 0.895329],
    '200020019': [0.892866, 0.909477, 0.901095],
    '200021063': [0.895414, 0.865742, 0.880328],
    '200021071': [0.902501, 0.909207, 0.905842],
    '200021086': [0.881725, 0.892401, 0.887031],
    '200025191': [0.87776 , 0.890543, 0.884105]
}
page_precision_comparison = []
page_recall_comparison = []
page_f1_comparison = []
for book in kh.test_books:
    page_precision_comparison.append([kuro_book[book][0], kuro_reg_book[book][0], kh.book_metrics[book]['precision']])
    page_recall_comparison.append([kuro_book[book][1], kuro_reg_book[book][1], kh.book_metrics[book]['recall']])
    page_f1_comparison.append([kuro_book[book][2], kuro_reg_book[book][2], kh.book_metrics[book]['f1']])

In [30]:
page_precision_comparison_df = pd.DataFrame(page_precision_comparison, columns=['KuroNet', 'KuroNet+Reg', 'HanyaOCR'], index=kh.test_books)
page_precision_comparison_df

Unnamed: 0,KuroNet,KuroNet+Reg,HanyaOCR
200004107,0.778708,0.872809,0.853341
200005798,0.775391,0.878944,0.890898
200006665,0.807131,0.883257,0.906777
200008003,0.931166,0.917325,0.926005
200008316,0.839333,0.842049,0.88962
200010454,0.852287,0.843034,0.865613
200015843,0.742832,0.901105,0.919012
200017458,0.806243,0.946381,0.948527
200018243,0.885414,0.87668,0.905534
200019865,0.789542,0.887304,0.887362


In [31]:
page_recall_comparison_df = pd.DataFrame(page_recall_comparison, columns=['KuroNet', 'KuroNet+Reg', 'HanyaOCR'], index=kh.test_books)
page_recall_comparison_df

Unnamed: 0,KuroNet,KuroNet+Reg,HanyaOCR
200004107,0.884214,0.894761,0.909565
200005798,0.675723,0.889179,0.869876
200006665,0.815499,0.915808,0.913759
200008003,0.935542,0.92724,0.949027
200008316,0.870747,0.863097,0.868009
200010454,0.858848,0.846826,0.805291
200015843,0.746247,0.91964,0.914828
200017458,0.669653,0.952779,0.938075
200018243,0.890162,0.924987,0.922484
200019865,0.77634,0.9035,0.876605


In [32]:
page_f1_comparison_df = pd.DataFrame(page_f1_comparison, columns=['KuroNet', 'KuroNet+Reg', 'HanyaOCR'], index=kh.test_books)
page_f1_comparison_df

Unnamed: 0,KuroNet,KuroNet+Reg,HanyaOCR
200004107,0.828114,0.883649,0.880556
200005798,0.722134,0.884032,0.880262
200006665,0.811294,0.899238,0.910255
200008003,0.933349,0.922256,0.937375
200008316,0.854751,0.852443,0.878681
200010454,0.855555,0.844926,0.834364
200015843,0.744535,0.910278,0.916915
200017458,0.731627,0.949569,0.943272
200018243,0.887782,0.900186,0.91393
200019865,0.782885,0.895329,0.88195


In [33]:
kuro_overall = {
    'precision': 0.7964,
    'recall': 0.7509,
    'f1': 0.7730
}
kuro_reg_overall = {
    'precision': 0.8889,
    'recall': 0.9025,
    'f1': 0.8957
}

In [34]:
overall_comp = {
    'precision': {
        'KuroNet': kuro_overall['precision'],
        'KuroNet+Reg': kuro_reg_overall['precision'],
        'Hanya OCR': kh.overall_metrics['precision']
    },
    'recall': {
        'KuroNet': kuro_overall['recall'],
        'KuroNet+Reg': kuro_reg_overall['recall'],
        'Hanya OCR': kh.overall_metrics['recall']
    },
    'f1': {
        'KuroNet': kuro_overall['f1'],
        'KuroNet+Reg': kuro_reg_overall['f1'],
        'Hanya OCR': kh.overall_metrics['f1']
    }    
}
overall_comp_df = pd.DataFrame.from_dict(overall_comp, orient='index', columns=['KuroNet', 'KuroNet+Reg', 'Hanya OCR'])
overall_comp_df

Unnamed: 0,KuroNet,KuroNet+Reg,Hanya OCR
precision,0.7964,0.8889,0.910173
recall,0.7509,0.9025,0.89581
f1,0.773,0.8957,0.902935


In [35]:
print(page_precision_comparison_df.to_markdown())

|           |   KuroNet |   KuroNet+Reg |   HanyaOCR |
|----------:|----------:|--------------:|-----------:|
| 200004107 |  0.778708 |      0.872809 |   0.853341 |
| 200005798 |  0.775391 |      0.878944 |   0.890898 |
| 200006665 |  0.807131 |      0.883257 |   0.906777 |
| 200008003 |  0.931166 |      0.917325 |   0.926005 |
| 200008316 |  0.839333 |      0.842049 |   0.88962  |
| 200010454 |  0.852287 |      0.843034 |   0.865613 |
| 200015843 |  0.742832 |      0.901105 |   0.919012 |
| 200017458 |  0.806243 |      0.946381 |   0.948527 |
| 200018243 |  0.885414 |      0.87668  |   0.905534 |
| 200019865 |  0.789542 |      0.887304 |   0.887362 |
| 200020019 |  0.793021 |      0.892866 |   0.922658 |
| 200021063 |  0.549717 |      0.895414 |   0.884664 |
| 200021071 |  0.841072 |      0.902501 |   0.942779 |
| 200021086 |  0.796496 |      0.881725 |   0.916482 |
| 200025191 |  0.714985 |      0.87776  |   0.916883 |


In [36]:
print(page_recall_comparison_df.to_markdown())

|           |   KuroNet |   KuroNet+Reg |   HanyaOCR |
|----------:|----------:|--------------:|-----------:|
| 200004107 |  0.884214 |      0.894761 |   0.909565 |
| 200005798 |  0.675723 |      0.889179 |   0.869876 |
| 200006665 |  0.815499 |      0.915808 |   0.913759 |
| 200008003 |  0.935542 |      0.92724  |   0.949027 |
| 200008316 |  0.870747 |      0.863097 |   0.868009 |
| 200010454 |  0.858848 |      0.846826 |   0.805291 |
| 200015843 |  0.746247 |      0.91964  |   0.914828 |
| 200017458 |  0.669653 |      0.952779 |   0.938075 |
| 200018243 |  0.890162 |      0.924987 |   0.922484 |
| 200019865 |  0.77634  |      0.9035   |   0.876605 |
| 200020019 |  0.703133 |      0.909477 |   0.889244 |
| 200021063 |  0.259251 |      0.865742 |   0.801341 |
| 200021071 |  0.835898 |      0.909207 |   0.942259 |
| 200021086 |  0.714183 |      0.892401 |   0.894596 |
| 200025191 |  0.713568 |      0.890543 |   0.893645 |


In [37]:
print(page_f1_comparison_df.to_markdown())

|           |   KuroNet |   KuroNet+Reg |   HanyaOCR |
|----------:|----------:|--------------:|-----------:|
| 200004107 |  0.828114 |      0.883649 |   0.880556 |
| 200005798 |  0.722134 |      0.884032 |   0.880262 |
| 200006665 |  0.811294 |      0.899238 |   0.910255 |
| 200008003 |  0.933349 |      0.922256 |   0.937375 |
| 200008316 |  0.854751 |      0.852443 |   0.878681 |
| 200010454 |  0.855555 |      0.844926 |   0.834364 |
| 200015843 |  0.744535 |      0.910278 |   0.916915 |
| 200017458 |  0.731627 |      0.949569 |   0.943272 |
| 200018243 |  0.887782 |      0.900186 |   0.91393  |
| 200019865 |  0.782885 |      0.895329 |   0.88195  |
| 200020019 |  0.745377 |      0.901095 |   0.905643 |
| 200021063 |  0.352337 |      0.880328 |   0.840944 |
| 200021071 |  0.838477 |      0.905842 |   0.942519 |
| 200021086 |  0.753097 |      0.887031 |   0.905407 |
| 200025191 |  0.714276 |      0.884105 |   0.905115 |


In [38]:
print(overall_comp_df.to_markdown())

|           |   KuroNet |   KuroNet+Reg |   Hanya OCR |
|:----------|----------:|--------------:|------------:|
| precision |    0.7964 |        0.8889 |    0.910173 |
| recall    |    0.7509 |        0.9025 |    0.89581  |
| f1        |    0.773  |        0.8957 |    0.902935 |
