`NOTE:` Merge cells function in heuristics is not correct

In [1]:
SCITSR_PATH = '/Users/longhoang/Developer/table-recognition/data/SciTSR-partition'
MODEL_WEIGHT = '/Users/longhoang/Developer/table-recognition/pret-models/split0.pth'

In [2]:
import numpy as np
from tqdm import tqdm

from data_utils.utils import *
from merge.heuristics import *
from dataset.dataset import ImageDataset
from modules.split_modules import SplitModel



In [4]:
# LOAD MODEL
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = SplitModel(3)
net = torch.nn.DataParallel(net).to(device)

if device == 'cuda':
    net.load_state_dict(torch.load(MODEL_WEIGHT))
else:
    net.load_state_dict(torch.load(MODEL_WEIGHT, map_location=torch.device('cpu')))

device

'cpu'

In [5]:
# Load images
img_dir = os.path.join(SCITSR_PATH, 'train', 'img')
imgs_paths = [os.path.join(img_dir, p) for p in os.listdir(img_dir)]
print(f'- Got {len(imgs_paths)} images')

# Load Merge labels
merge_json = os.path.join(SCITSR_PATH, 'train', 'label', 'merge_label.json')
merge_labels = load_json(merge_json)
print(f"- Loaded {len(merge_labels)} labels for Merge module")

# Load Text posisions
chunk_json = os.path.join(SCITSR_PATH, 'train', 'label', 'chunk_label.json')
chunk_labels = load_json(chunk_json)
print(f"- Loaded texts positions for {len(chunk_labels)} images")

# Load dataset
split_json = os.path.join(SCITSR_PATH, 'train', 'label', 'split_label.json')
split_labels = load_json(split_json)
dataset = ImageDataset(img_dir, split_labels, 8, scale=1, min_width=10, returns_image_name=True)
print(f'- Loaded dataset with {len(dataset)} examples')

- Got 10000 images
- Loaded 10000 labels for Merge module
- Loaded texts positions for 10000 images
- Loaded dataset with 10000 examples


In [7]:
IDX = 20
img, label, img_name = dataset[IDX]

In [8]:
def merge_heur_pred(model: SplitModel, 
                    img_dir: str, 
                    split_json: str, 
                    merge_json: str, 
                    chunk_json: str,
                    iou_th: float = 0.7):
    '''
    Args:
    model -- Split model
    img_dir -- string, path to image folder of on of train, val, or test set
    split_json -- string, path to json ground truth file for Split module
    merge_json -- string, path to json ground truth file for Merge module
    chunk_json -- string, path to json file that contain chunk info (coordinates of texts)
    iou_th -- IoU threshold
    '''
    # Load images
    imgs_paths = [os.path.join(img_dir, p) for p in os.listdir(img_dir)]
    print(f'- Got {len(imgs_paths)} images')

    # Load Merge labels
    merge_labels = load_json(merge_json)
    print(f"- Loaded {len(merge_labels)} labels for Merge module")

    # Load Text posisions
    chunk_labels = load_json(chunk_json)
    print(f"- Loaded texts positions for {len(chunk_labels)} images")

    # Load dataset
    split_labels = load_json(split_json)
    dataset = ImageDataset(img_dir, split_labels, 8, scale=1, min_width=10, returns_image_name=True)
    print(f'- Loaded dataset with {len(dataset)} examples')

    single_col_or_row = []
    shape_mismatch = []
    wrong_label = []
    wrong_coordinates = []
    f1s, recalls, precisions = [], [], []
    model.eval()

    for img, label, img_name in tqdm(dataset):
        texts_pos = chunk_labels[img_name]
        
        # load ground truth
        r_gt, c_gt, R_gt, D_gt = load_merge_gt(merge_labels, img_name)
        if R_gt.ndim != 2 or D_gt.ndim != 2: 
            single_col_or_row.append(img_name)
            continue
        row_gt_idxs, col_gt_idxs = borders(r_gt), borders(c_gt)
        cells_gt = get_cells(row_gt_idxs, col_gt_idxs)
        if len(cells_gt) != R_gt.shape[0] * D_gt.shape[1]:
            shape_mismatch.append(img_name)
            continue
        cells_merged_gt = merge_cells(cells_gt, R_gt, D_gt)
        
        # get predictions
        with torch.no_grad():
            r_pred, c_pred = model(img.unsqueeze(0))
        r_pred, c_pred = process_split_results(r_pred, c_pred)
        r_pred, c_pred = refine_split_results(r_pred), refine_split_results(c_pred)
        row_pred_idxs, col_pred_idxs = borders(r_pred), borders(c_pred)                                      
        cells_pred = get_cells(row_pred_idxs, col_pred_idxs)
        if len(row_pred_idxs) == 0 or len(col_pred_idxs) == 0:
            wrong_label.append(img_name)
            continue
        
        # apply Merge heuristics
        R_pred, D_pred = create_pred_matrices(row_pred_idxs, col_pred_idxs)
        rule1(cells_pred, texts_pos, R_pred, D_pred)
        rule2(cells_pred, texts_pos, R_pred, D_pred)
        cells_merged_pred = merge_cells(cells_pred, R_pred, D_pred)
        
        # get evaluation scores
        f1, rec, prec, name = eval(cells_merged_pred, cells_merged_gt, threshold=iou_th, img_name=img_name)
        if name: 
            wrong_coordinates.append(name)
        f1s.append(f1); recalls.append(rec); precisions.append(prec)

    f1_avg, rec_avg, prec_avg = np.mean(f1s), np.mean(recalls), np.mean(precisions)
    print(f'F1: {f1_avg:.4f} ; Recall: {rec_avg:.4f} ; Precision: {prec_avg:.4f}')
    scores = {
        'f1': f1_avg, 
        'recall': rec_avg, 
        'precision': prec_avg
    }
    errors = {
        'single': single_col_or_row, 
        'shape_mismatch': shape_mismatch, 
        'wrong_label': wrong_label,
        'wrong_coordinates': wrong_coordinates
    }
    return scores, errors

In [10]:
# Evaluate on train set
def evaluate_set(set_dir, **kwargs):
    img_dir = os.path.join(set_dir, 'img')
    label_dir = os.path.join(set_dir, 'label')
    split_json = os.path.join(label_dir, 'split_label.json')
    merge_json = os.path.join(label_dir, 'merge_label.json')
    chunk_json = os.path.join(label_dir, 'chunk_label.json')
    return merge_heur_pred(net, img_dir, split_json, merge_json, chunk_json, **kwargs)

In [11]:
train_dir = os.path.join(SCITSR_PATH, 'train')
scores, errors = evaluate_set(train_dir, iou_th=0.6)

- Got 10000 images
- Loaded 10000 labels for Merge module
- Loaded texts positions for 10000 images
- Loaded dataset with 10000 examples


 32%|███▏      | 3201/10000 [16:21<34:45,  3.26it/s]  


IndexError: tuple index out of range

In [None]:
len(errors['single']), len(errors['shape_mismatch']), len(errors['wrong_label'], len(errors['wrong_coordinates']))

(8, 1057, 10)

In [None]:
def save_dict(dict, dir_path):
    os.makedirs(dir_path, exist_ok=True)
    for key in errors.keys():
        p = os.path.join(dir_path, key + '.json')
        with open(p, 'w') as f:
            json.dump(dict[key], f)

In [None]:
ERROR_PATH = '/Users/longhoang/Developer/table-reg/code/deep-split-merge-scitsr/merge/error/train'
save_dict(errors, ERROR_PATH)

In [None]:
VAL_DIR = '/Users/longhoang/Developer/table-reg/data/scitsr-split-train/val'
val_scores, val_errors = evaluate_set(VAL_DIR, iou_th=0.6)

Got 1971 images
Loaded 1971 labels for Merge module
Loaded texts positions for 11971 images
Loaded dataset with 1971 examples


100%|██████████| 1971/1971 [09:37<00:00,  3.41it/s]


F1: 0.7573 ; Recall: 0.7973 ; Precision: 0.7319


In [None]:
VAL_ERROR_PATH = '/Users/longhoang/Developer/table-reg/code/deep-split-merge-scitsr/merge/error/val'
save_dict(val_errors, VAL_ERROR_PATH)

In [None]:
len(val_errors['single']), len(val_errors['shape_mismatch']), len(val_errors['wrong_label'], len(val_errors['wrong_coordinates']))

(1, 229, 2)