In [1]:
import os
from zipfile import *
import shutil
import json
import pandas as pd
import numpy as np

In [2]:
models = os.listdir('output')
factors = np.array(["Blur", "Condition", "Contrast", "Curved", "FontSize", "FontStyle", "Noise", "Resolution", "Ripple", "Rotated", "SkewLevel"])
df = pd.DataFrame(columns = factors)
for model in models:
    path = os.path.join('output/',model,'results.zip')
    source_results = ZipFile(path, 'r')
    target_dir = os.path.join('tmpresults/', model)
    source_results.extractall(target_dir)
    print("getting results for " + model)
    results_filelist = os.listdir(target_dir)
    results_filelist.remove('method.json')
    factor_counts = pd.Series(index=factors, dtype="int8")
    factor_sums = pd.Series(index=factors, dtype="float64").fillna(0)
    for file in results_filelist:
        factor = file.split("_")[0]
        factor_counts[factor] += 1
        file_path = os.path.join(target_dir,file)
        f = open(file_path, 'r')
        result = json.loads(f.read())
        f.close()
        factor_sums[factor] += result["hmean"]
    averages = factor_sums / factor_counts
    averages = averages.rename(model)
    averages = averages.to_frame()
    averages = averages.transpose()
    df = df.append(averages)


getting results for FOTS
getting results for CRAFT
getting results for Pixel_link
getting results for DB_resnet50
getting results for Charnet
getting results for textfusenet
getting results for EAST
getting results for DB_resnet18


In [4]:
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import Output, VBox
from IPython.display import clear_output
import cv2
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
%matplotlib widget
        

def show_factor(x,y):
    print("rendering images... This might take a while")
    model = df.index[x]
    factor = df.columns[y]
    image_path = os.path.join(os.pardir,"IVtext")
    image_files = os.listdir(image_path)
    image_files.sort()
    clear_output()
    for image in [x for x in image_files if x.startswith(factor)]:
        img1 = cv2.imread(os.path.join(image_path,image))
        img2 = cv2.imread(os.path.join(image_path,image))
        json_file = open(os.path.join("tmpresults",model,os.path.splitext(image)[0]) +".txt.json", "r")
        result = json.loads(json_file.read())
        gt_points = result["gtPolPoints"]
        det_points = result["detPolPoints"]
        for box in gt_points:
            box = list(map(np.int32,box))
            x1 = min(box[0::2])
            y1 = min(box[1::2])
            x2 = max(box[0::2])
            y2 = max(box[1::2])
            cv2.rectangle(img1, (x1,y1), (x2,y2), (0,255,0), 4)
        for box in det_points:
            box = list(map(np.int32,box))
            x1 = min(box[0::2])
            y1 = min(box[1::2])
            x2 = max(box[0::2])
            y2 = max(box[1::2])
            cv2.rectangle(img2, (x1,y1), (x2,y2), (0,0,255), 4)
        f = plt.figure(figsize=(10,4))
        f.suptitle(model + ": " + image, fontsize=16)
        #gs = gridspec.GridSpec(1, 2)
        #gs.update(wspace=0.0025, hspace=0.05)
        ax = f.add_subplot(121)
        ax2 = f.add_subplot(122)
        ax.imshow(img1)
        ax.axis("off")
        ax.set_title("ground truth")
        ax2.imshow(img2)
        ax2.axis("off")
        ax2.set_title("Prediction")
        ax2.text(0.5,-0.1, "F-score: " + str(result["hmean"])[0:5], size=12, ha="center", transform=ax2.transAxes)
        plt.tight_layout(pad=0.001, w_pad=0.001, h_pad=1.0)
        show_inline_matplotlib_plots()
        

out = Output()
@out.capture(clear_output=True)
def handle_click(trace, points, state):
    x = points.point_inds[0][0]
    y = points.point_inds[0][1]
    show_factor(x,y)

fig = go.FigureWidget()
fig.add_heatmap(
    x = df.columns, y=df.index, z=df,colorscale="thermal")
heatmapp = fig.data[0]

fig.update_layout(height = 600, width = 900)



heatmapp.on_click(handle_click)

VBox([fig, out])


VBox(children=(FigureWidget({
    'data': [{'colorscale': [[0.0, 'rgb(3, 35, 51)'], [0.09090909090909091,
    …