# Evaluation for grouping task

In [None]:
# Enable IPython's autoreload extension so edits to imported modules are
# picked up without restarting the kernel (mode 2 = reload all modules
# before executing each cell).
%load_ext autoreload
%autoreload 2

In [None]:
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import List, Dict
import collections
import matplotlib.pyplot as plt
import numpy as np
import os
import torch

from docreader.evaluation.metrics import bbox_evaluation

import simple_grouping_task_helpers as helpers
from pointer_net import PointerNet

In [None]:
# Render matplotlib figures directly in the notebook output.
%matplotlib inline

## Data

In [None]:
# Load the test split.
# np.load on an .npz returns a lazy NpzFile that keeps the archive open, so
# read the array inside a context manager to release the file handle.
with np.load('data/lines-gaussian-0.3/lines_data_test.npz') as archive:
    data_test = archive['arr_0']
# NOTE(review): random_shuffle=True presumably randomises something per
# sample; if so, fixed indices used later are not reproducible — confirm.
dataset_test = helpers.LineDataset(data_test, random_shuffle=True)
print(len(dataset_test))

## Model

In [None]:
# CUDA device that the model and the input tensors are moved to.
# NOTE(review): the GPU index is hard-coded; adjust it for the machine in use.
DEVICE = 'cuda:3'

In [None]:
# Build the pointer network on the evaluation GPU and restore trained weights.
model = PointerNet(n_in=2).cuda(device=DEVICE)
# map_location places the checkpoint tensors on DEVICE instead of whichever
# device they were saved from (which may be busy or absent on this machine).
model.load_state_dict(
    torch.load(
        '/opt/weights/ptr-line-grouping-gaussian-0.3-1.02/115_0.pt',
        map_location=DEVICE,
    )
)
# This notebook only evaluates: switch train-time layers (dropout, batch
# norm, if the model has any) to inference behaviour.
_ = model.eval()

## Visualize

In [None]:
# Pick one test sample and compare its target grouping with the prediction.
# NOTE(review): `predict` is not defined in any visible cell of this
# notebook — confirm where it comes from before running.
datum = dataset_test[221]
points, target_pointers = datum['sequence'], datum['pointers']

# The decoder has no EOS token yet, so the output length is taken from the
# target sequence.
n_targets = len(target_pointers)
pred_pointers = predict(points, n_targets)

# Both panels share the same y ticks: positions are offset by 0.015 while the
# labels show the un-offset values.
tick_positions = np.arange(13) / 10 + 0.015
tick_labels = np.arange(13) / 10

plt.figure(figsize=(20, 7))

# Left panel: ground truth (ticks and grid are set before drawing).
plt.subplot(1, 2, 1)
plt.title('Target')
plt.yticks(tick_positions, tick_labels)
plt.grid()
helpers.plot_points_and_lines(points, target_pointers)

# Right panel: model output (ticks and grid are set after drawing).
plt.subplot(1, 2, 2)
plt.title('Predictions')
helpers.plot_points_and_lines(points, pred_pointers)
plt.yticks(tick_positions, tick_labels)
plt.grid()

## Metrics

The overall task has two components:
1. "Grouping": Can the model correctly group all the words in one text-line together?
2. "Ordering": Can the model predict the correct order of the words in the text-line?


Accordingly the metrics are defined as follows:
1. "Grouping correctness": consider each line as a "set" of words and calculate IoU
    1. Post-processing
        1. Order does NOT matter
        2. Duplicates removed (1, 2, 2, 3 -> 1, 2, 3)
    2. Evaluation: calculate by controlling two parameters:
        1. IoU of the text-line bounding boxes
        2. Edit-distance between the sorted labels (target: {11, 12, 13}, pred: {11, 13} -> edit distance = 1)
2. "Ordering correctness" (absolute order of the words within each line): not yet defined in this notebook.
    

In [None]:
from tqdm import tqdm
import collections
import numpy as np
import pandas as pd

def _calculate_iou_lists(list1, list2) -> float:
    """Return the intersection-over-union of two label collections.

    Both inputs are converted to sets first, so duplicate elements are
    ignored and the result is order insensitive.

    Args:
        list1: Iterable of hashable labels.
        list2: Iterable of hashable labels.

    Returns:
        ``|set1 & set2| / |set1 | set2|``, a value in ``[0, 1]``. Returns
        ``0.0`` when both inputs are empty: the ratio is undefined there, and
        an empty line should never count as a match.
    """
    set1 = set(list1)
    set2 = set(list2)

    union_size = len(set1 | set2)
    if union_size == 0:
        # Avoid ZeroDivisionError on two empty collections.
        return 0.0

    return len(set1 & set2) / union_size

@np.vectorize
def calculate_iou_lists(list1, list2) -> float:
    """Element-wise set IoU over arrays of label collections.

    ``np.vectorize`` calls this once per pair of corresponding elements of
    the two (broadcast) input arrays; each element is expected to be a
    collection of labels and is scored with ``_calculate_iou_lists``.
    """
    pair_iou = _calculate_iou_lists(list1=list1, list2=list2)
    return pair_iou


# Per-sample line-grouping counts:
#   tp        - target lines matched by a prediction
#   fn        - target lines left unmatched (n_targets - tp)
#   fp        - predicted lines left unmatched (n_preds - tp)
#   n_targets - number of target lines
#   n_preds   - number of predicted lines
LineMetrics = collections.namedtuple('LineMetrics', ['tp', 'fn', 'fp', 'n_targets', 'n_preds',])


def get_line_grouping_metrics(points, target_pointers, pred_pointers, iou_thresh=.75):
    """Count matched, missed and spurious text-lines for one sample.

    Target and predicted pointers are each split into lines with
    ``helpers.get_lines_from_pointers``, and every line is treated as a set
    of labels. A predicted line can be matched to a target line when their
    set IoU is at least ``iou_thresh``. Candidate pairs are consumed greedily
    from the highest IoU down, and each predicted / target line is used in at
    most one match.

    Args:
        points: Forwarded unchanged to ``helpers.get_lines_from_pointers``.
        target_pointers: Ground-truth pointers for the sample.
        pred_pointers: Predicted pointers for the sample.
        iou_thresh: Minimum set IoU for a (prediction, target) pair to match.

    Returns:
        LineMetrics: ``tp`` matched lines, ``fn = n_targets - tp``,
        ``fp = n_preds - tp``, plus the line counts themselves.
    """
    # Split into lines.
    # NOTE(review): assumes the second return value is a sequence of lines,
    # each an iterable of hashable labels — confirm against the helper.
    _, target_lines = helpers.get_lines_from_pointers(points, target_pointers, return_labels=True)
    _, pred_lines = helpers.get_lines_from_pointers(points, pred_pointers, return_labels=True)

    # Keep the lines as plain Python lists. Lines generally differ in length,
    # so packing them into a NumPy array is fragile: equal-length (or single)
    # lines collapse into a 2-D array, and ragged input is rejected outright
    # by recent NumPy versions.
    target_lines = [list(line) for line in target_lines]
    pred_lines = [list(line) for line in pred_lines]
    n_targets = len(target_lines)
    n_preds = len(pred_lines)

    # Score every (prediction, target) pair and keep those above threshold.
    candidates = []
    for pred_ix, pred in enumerate(pred_lines):
        for target_ix, target in enumerate(target_lines):
            iou = _calculate_iou_lists(target, pred)
            if iou >= iou_thresh:
                candidates.append((iou, pred_ix, target_ix))

    # Highest IoU first; the sort is stable, so ties keep row-major order.
    candidates.sort(key=lambda candidate: -candidate[0])

    # Greedy one-to-one matching: a pair is accepted only if neither of its
    # lines has already been claimed by a higher-IoU pair.
    matched_pred = set()
    matched_gt = set()
    for _, pred_ix, target_ix in candidates:
        if pred_ix in matched_pred or target_ix in matched_gt:
            continue
        matched_pred.add(pred_ix)
        matched_gt.add(target_ix)

    tp = len(matched_gt)  # Number of textlines found with IoU >= iou_thresh
    fn = n_targets - tp  # Misses
    fp = n_preds - tp  # Extra predictions

    return LineMetrics(
        tp=tp,
        fp=fp,
        fn=fn,
        n_targets=n_targets,
        n_preds=n_preds,
    )

In [None]:
def get_all_gts_preds(dataset, n_iters):
    """Run the model on the first ``n_iters`` samples of ``dataset``.

    Each sample is read by index, its ``'sequence'`` is passed to ``predict``
    together with the number of target pointers, and the results are
    collected sample by sample.

    Args:
        dataset: Indexable collection whose items provide ``'sequence'`` and
            ``'pointers'`` entries.
        n_iters: Number of leading samples to process.

    Returns:
        Tuple ``(all_targets, all_preds, all_points)`` of equal-length lists.
    """
    targets_acc, preds_acc, points_acc = [], [], []

    for sample_ix in tqdm(range(n_iters)):
        sample = dataset[sample_ix]
        sample_points = sample['sequence']
        sample_targets = sample['pointers']

        # There is no EOS token yet, so the target length decides how many
        # pointers the model is asked to produce.
        sample_preds = predict(sample_points, len(sample_targets))

        preds_acc.append(sample_preds)
        targets_acc.append(sample_targets)
        points_acc.append(sample_points)

    return targets_acc, preds_acc, points_acc

def get_textline_accuracy_by_sets(all_points, all_targets, all_preds, iou_thresh=0.75):
    """Compute text-line accuracy over a collection of samples.

    For every sample the number of matched target lines (``tp``) and the
    number of target lines are obtained from ``get_line_grouping_metrics``
    and summed; accuracy is the ratio of the two totals.

    Args:
        all_points: Per-sample points.
        all_targets: Per-sample target pointers.
        all_preds: Per-sample predicted pointers.
        iou_thresh: Minimum set IoU for a predicted line to match a target.

    Returns:
        Tuple ``(accuracy, (tp_total, n_targets_total))``. ``accuracy`` is
        ``0.0`` when there are no target lines at all.
    """
    assert len(all_targets) == len(all_preds)
    # zip() below would silently drop samples if the points list were shorter.
    assert len(all_points) == len(all_targets)

    tp_total, n_targets_total = 0, 0
    for points, target_pointers, pred_pointers in zip(all_points, all_targets, all_preds):
        metrics = get_line_grouping_metrics(
            points,
            target_pointers=target_pointers,
            pred_pointers=pred_pointers,
            iou_thresh=iou_thresh,
        )
        assert metrics.tp <= metrics.n_targets
        tp_total += metrics.tp
        n_targets_total += metrics.n_targets

    if n_targets_total == 0:
        # Nothing to evaluate: report zero accuracy instead of dividing by zero.
        return 0.0, (tp_total, n_targets_total)

    return tp_total / n_targets_total, (tp_total, n_targets_total)

In [None]:
%%time
# Run the model over the first 1000 test samples once and keep the targets,
# predictions and points for the metric cells that follow.
all_targets, all_preds, all_points = get_all_gts_preds(dataset_test, n_iters=1000)

In [None]:
%%time
# Sweep the matching threshold and record the text-line accuracy at each value.
ious = [0.2, 0.4, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 1]
accs = []
for iou_thresh in ious:
    # Only the accuracy is kept; the raw (tp, n_targets) counts are unpacked but unused here.
    acc, (tp, n_targets) = get_textline_accuracy_by_sets(all_points, all_targets, all_preds, iou_thresh=iou_thresh)
    accs.append(acc)

# Accuracy as a function of the IoU threshold.
plt.plot(ious, accs, 'x-')
plt.title('Text-line accuracy (as sets) vs IoU')
plt.xlabel('IoU')
plt.ylabel('Accuracy')
plt.grid()

In [None]:
# Display the accuracy values, in the same order as the thresholds in `ious`.
accs

## Qualitative results on document data

This **will perform poorly**, because the training data had its y-coordinates discretized to 0.1 intervals, which is not the case for real document data.

In [None]:
from docschema.semantic import Document
import cv2
import pathlib

In [None]:
import doc_data, doc_visualize

In [None]:
# Build the document preprocessor that turns a Document into model inputs.
# NOTE(review): `TextLine` is not imported in any visible cell; presumably it
# comes from docschema.semantic alongside Document — confirm and import it,
# otherwise this line raises NameError.
preprocessor = doc_data.Preprocessor(TextLine, crop_h=500, crop_w=500, random_shuffle=False, only_midpoints=True)

In [None]:
# Annotation file for the document to inspect; the page image is expected at
# the same path with a .png suffix.
path_to_json = pathlib.Path('/opt/data/document-datasets/acord/acord-test-files/125-2007-10/125 2007-10 - (17.12)-copy(2).json')
path_to_image = path_to_json.with_suffix('.png')

In [None]:
# Load the document annotation and attach its page image.
doc = Document.load(str(path_to_json))
image = cv2.imread(str(path_to_image))
# cv2.imread reports a missing or unreadable file by returning None instead
# of raising, so fail here rather than deep inside the preprocessor.
if image is None:
    raise FileNotFoundError('Could not read image: {}'.format(path_to_image))
doc.rendered_image = image

In [None]:
# Run the preprocessor on the document; the resulting dict is unpacked in the next cell.
datum = preprocessor(doc)

In [None]:
# Pull the model inputs and plotting data out of the preprocessed sample.
# Note that `points` and `image` replace the earlier notebook variables of
# the same name.
points, pointers = datum['bboxes'], datum['pointers']
image, scale = datum['image'], datum['scale']

In [None]:
# Draw the ground-truth grouping of the document on top of the page image.
plt.figure(figsize=(15, 20))
doc_visualize.plot_points_and_lines(points=points, pointers=pointers, image=image, scale=scale, fontsize=0)

In [None]:
# Run the network on the document's points as a single-sample batch.
# Inference only: no_grad skips building the autograd graph, which would
# otherwise hold activations in GPU memory for no benefit.
with torch.no_grad():
    # Add a leading batch axis and move the float32 points to the GPU.
    sequence = torch.from_numpy(points.astype(np.float32)[np.newaxis, ...]).cuda(DEVICE)
    seq_lens = [sequence.shape[1]]
    # No EOS token: the number of outputs is taken from the ground truth.
    n_outputs = len(pointers)
    pred_pointer_probs = model(sequence, seq_lens, max_output_len=n_outputs)

In [None]:
# Take the most probable pointer at each output step and bring it to NumPy.
# detach() is the supported way to drop autograd tracking (`.data` bypasses
# autograd's safety checks).
# NOTE(review): squeeze() removes every size-1 axis, not only the batch axis,
# so a single-step output would collapse to a scalar — confirm this is fine.
pred_pointers = pred_pointer_probs.argmax(dim=-1).detach().cpu().numpy().squeeze()

In [None]:
# Draw the predicted grouping on the page image, with the same plot settings
# as the ground-truth figure so the two can be compared directly.
plt.figure(figsize=(15, 20))
doc_visualize.plot_points_and_lines(points=points, pointers=pred_pointers, image=image, scale=scale, fontsize=0)

In [None]:
# Display the raw point array that was fed to the model.
points