In [1]:
SCITSR_PATH = '/Users/admin/Developer/table-recognition/data/SciTSR-partition'
MODEL_WEIGHT = '/Users/admin/Developer/table-recognition/pret-models/split2.pth'
SAVE_PATH = '/Users/admin/Developer/table-recognition/results/stats_split'

In [2]:
import numpy as np
from tqdm import tqdm
import glob
import cv2 as cv

from data_utils.utils import *
from merge.heuristics import *
from dataset.dataset import ImageDataset
from modules.split_modules import SplitModel

In [3]:
# LOAD MODEL
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SplitModel(3)
model = torch.nn.DataParallel(model).to(device)

if device == 'cuda':
    model.load_state_dict(torch.load(MODEL_WEIGHT))
else:
    model.load_state_dict(torch.load(MODEL_WEIGHT, map_location=torch.device('cpu')))

device

'cpu'

In [4]:
@torch.no_grad
def get_split_results(model: SplitModel, 
                    set_dir: str, 
                    selected_images=None
):
    '''
    Get split results as a dictionary for further postprocessing or analysis
    Args:
        model -- Split model
        set_dir -- string, path to train, val, or test set
    Returns:
        res -- dictionary where keys are image names
    '''
    img_dir = os.path.join(set_dir, 'img')
    split_json = os.path.join(set_dir, 'label', 'split_label.json')
    
    # Load dataset
    split_labels = load_json(split_json)
    dataset = ImageDataset(img_dir, split_labels, 8, scale=1, min_width=10, returns_image_name=True)
    print(f'- Loaded dataset with {len(dataset)} examples')

    res = {}
    for img, label, name in tqdm(dataset):
        if selected_images is not None and name not in selected_images:
            continue
        # ground truth
        r_gt, c_gt = label
        r_gt, c_gt = r_gt.cpu().numpy(), c_gt.cpu().numpy() 
        row_gt_idxs, col_gt_idxs = borders(r_gt), borders(c_gt)
        num_rows_gt, num_cols_gt = round(len(row_gt_idxs) / 2), round(len(col_gt_idxs) / 2)

        # prediction
        r_pred, c_pred = model(img.unsqueeze(0))
        r_pred, c_pred = process_split_results(r_pred, c_pred)
        r_pred, c_pred = refine_split_results(r_pred), refine_split_results(c_pred)
        row_pred_idxs, col_pred_idxs = borders(r_pred), borders(c_pred)
        num_rows_pred, num_cols_pred = round(len(row_pred_idxs) / 2), round(len(col_pred_idxs) / 2)

        # eval on precision, recall, and f1
        r_metrics = eval3(r_pred, r_gt)
        c_metrics = eval3(c_pred, c_gt)

        # log results
        res[name] = {
            'num_rows_gt': num_rows_gt,
            'num_cols_gt': num_cols_gt,
            'num_rows_pred': num_rows_pred,
            'num_cols_pred': num_cols_pred,
            'row_gt': r_gt.tolist(), 'col_gt': c_gt.tolist(),
            'row_pred': r_pred.tolist(), 'col_pred': c_pred.tolist(), 
            'row_precision': r_metrics['precision'], 
            'row_recall': r_metrics['recall'], 
            'row_f1': r_metrics['f1'],
            'col_precision': c_metrics['precision'], 
            'col_recall': c_metrics['recall'], 
            'col_f1': c_metrics['f1']
        }

    return res

In [5]:
# save result on validation set
val_res = get_split_results(model, os.path.join(SCITSR_PATH, 'val'))

- Loaded dataset with 1971 examples


100%|██████████| 1971/1971 [20:28<00:00,  1.60it/s]  


In [6]:
import csv

def save_csv(results, filepath):
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    fieldnames = ['image_name', 'num_rows_gt', 'num_cols_gt', 'num_rows_pred', 'num_cols_pred',
                  'row_gt', 'col_gt', 'row_pred', 'col_pred',
                  'row_precision', 'row_recall', 'row_f1',
                  'col_precision', 'col_recall', 'col_f1']

    with open(filepath, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for image_name, data in results.items():
            writer.writerow({'image_name': image_name,
                             'num_rows_gt': data['num_rows_gt'],
                             'num_cols_gt': data['num_cols_gt'],
                             'num_rows_pred': data['num_rows_pred'],
                             'num_cols_pred': data['num_cols_pred'],
                             'row_gt': data['row_gt'],
                             'col_gt': data['col_gt'],
                             'row_pred': data['row_pred'],
                             'col_pred': data['col_pred'],
                             'row_precision': data['row_precision'],
                             'row_recall': data['row_recall'],
                             'row_f1': data['row_f1'],
                             'col_precision': data['col_precision'],
                             'col_recall': data['col_recall'],
                             'col_f1': data['col_f1']})

In [7]:
save_csv(val_res, os.path.join(SAVE_PATH, 'val.csv'))

In [8]:
# save result on test set
test_res = get_split_results(model, os.path.join(SCITSR_PATH, 'test'))
save_csv(test_res, os.path.join(SAVE_PATH, 'test.csv'))

- Loaded dataset with 3000 examples


100%|██████████| 3000/3000 [31:45<00:00,  1.57it/s]  


In [9]:
# save result on test COMP set
comp_list = load_txt(os.path.join(SCITSR_PATH, 'SciTSR-COMP.list'))
test_comp_res = get_split_results(model, os.path.join(SCITSR_PATH, 'test'), comp_list)
save_csv(test_comp_res, os.path.join(SAVE_PATH, 'test_comp.csv'))

- Loaded dataset with 3000 examples


100%|██████████| 3000/3000 [09:03<00:00,  5.52it/s]
