## Sample Inference

This section runs inference with our model on 20 sample images from the `PubTables-1M` test set, located in the `./sample_pubtables1m` directory.
The folder contains the following subfolders:
- `images`: images used for inference
- `ocr`: OCR results from our custom internal OCR model
- `ocr_gt`: OCR results provided with the PubTables-1M dataset
- `test`: ground-truth JSON files containing bounding-box coordinates for tables, columns, rows, headers, and extra_cells (for spanning cells)

The code below:
- loads the model
- runs inference on the 20 sample images
- visualizes the output and saves it to the `./demo_results` directory
    - images with the `.table.jpg` suffix contain the model's predictions
    - other images show the visualized ground truth and prediction for the classes where the model made a mistake
- computes the average precision (AP) and average recall (AR) metrics


In [1]:
import os
import sys
import json
import torch
import numpy as np
import glob
from torch.utils.data import DataLoader

from model_functions.transformer_tf_copy import TransformerEncoderTable
from train_data_preparation.coco_tables_dataset import CocoValidDataset
from train.table_extraction import collate_fn_pad, run_all_validations


# Load the table-recognition model, run it on the sample PubTables-1M images,
# save the debug visualizations and print the COCO-style AP/AR metrics.

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Shared by the model and the dataset; defining it once keeps the two in sync.
num_clustering_heads = 4

model_recognition = TransformerEncoderTable(num_clustering_heads=num_clustering_heads, is_use_4_points=True,
                                            is_use_image_patches=True, use_content_emb=True)
model_recognition.to(device)

checkpoint_path_recognition = 'model_weights/table_recognition.pth'
# The stored weights are in float16 precision. load_state_dict copies every
# tensor into the existing parameter, converting it to the parameter's dtype,
# so no .half()/.float() round trip is needed around the load.
checkpoint_recognition = torch.load(checkpoint_path_recognition, map_location=device)
load_result = model_recognition.load_state_dict(checkpoint_recognition, strict=False)
# strict=False tolerates key mismatches; report them so a stale or wrongly
# structured checkpoint does not go unnoticed.
if load_result.missing_keys:
    print('Missing keys in checkpoint:', load_result.missing_keys)
if load_result.unexpected_keys:
    print('Unexpected keys in checkpoint:', load_result.unexpected_keys)
# Make sure inference runs in float32 regardless of the checkpoint dtype.
model_recognition = model_recognition.float()
model_recognition.eval()


# Multiple annotation label names map onto the same class index.
class_map = {
    'table': 0,
    'table column': 1,
    'column': 1,
    'table row': 2,
    'row': 2,
    'table column header': 3,
    'table header': 3,
    'header': 3,
    'no object': 6
}

# path to our validation datasets
dataset_paths_validation = {
    './sample_pubtables1m': 1,
}
eval_set = 'test'

valid_dataset = CocoValidDataset(dataset_paths_validation,
                                 eval_set,
                                 class_map=class_map,
                                 num_clustering_heads=num_clustering_heads,
                                 ocr_labels_folder='ocr_gt',
                                 is_use_4_points=True,
                                 is_use_image_patches=True,
                                 is_one_model='both',
                                 use_cell_pointers=True,
                                 is_augment_in_eval=False)

valid_loader = DataLoader(valid_dataset,
                          batch_size=1,
                          shuffle=False,
                          # pinned host memory only helps host->GPU copies
                          pin_memory=torch.cuda.is_available(),
                          num_workers=8,
                          drop_last=False,
                          collate_fn=collate_fn_pad)

output_dir = './demo_results'
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): assumes run_all_validations only runs forward passes; disabling
# autograd avoids building graphs (and holding activations) during evaluation.
with torch.no_grad():
    best_score = run_all_validations(model_recognition,
                                     valid_loader,
                                     output_dir,
                                     is_use_4_points=True,
                                     is_run_4_5_classes=False,
                                     is_debug_plot=True,
                                     is_augment_in_eval=False)

print('best_score', best_score)
print('END')


d_model 256
is_sum_embeddings True
use_content_emb True
num_coords 8




Looking for jsons


100%|██████████| 1/1 [00:00<00:00, 762.60it/s]


Looking for images


100%|██████████| 1/1 [00:00<00:00, 889.38it/s]


Looking for OCR


100%|██████████| 20/20 [00:00<00:00, 1031.89it/s]
100%|██████████| 20/20 [00:00<00:00, 97.33it/s]


creating index...
index created!


                                                           

current accuracy 0.9959655036832306
For all (6) classes
Accumulating evaluation results...
DONE (t=0.03s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.901
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.946
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.903
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.962
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.865
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.883
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.463
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.804
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.917
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.983
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.

### Sample Model Training

The code below runs a sample training/fine-tuning loop on the same 20 images.
It assumes that the previous cells in this Jupyter notebook have already been run.

In [6]:
from train_data_preparation.tables_dataset import TrainTablesDataset
from tqdm import trange, tqdm
from train.table_extraction import compute_cost


# Fine-tune the already-loaded recognition model for a few short epochs on the
# sample images. Relies on model_recognition, class_map, num_clustering_heads
# and collate_fn_pad from the previous cell.


def _cycle_batches(loader):
    """Yield batches from *loader* forever, starting a new pass when one ends.

    The loader is re-iterated on every pass instead of caching its batches, so
    shuffle=True yields a fresh order each time.

    Raises:
        ValueError: if a full pass over the loader yields no batches (e.g.
            fewer samples than batch_size with drop_last=True), which would
            otherwise loop forever.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError('train_loader yields no batches; '
                             'reduce batch_size or set drop_last=False')


if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = model_recognition
# Move the model before building the optimizer so the optimizer is created
# over parameters that already live on the target device.
model.to(device)
model.train()
print('Number of model parameters', sum(p.numel() for p in model.parameters() if p.requires_grad))

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# for the demo purposes, we use the same 20 sample images for finetuning
dataset_paths = {
    './sample_pubtables1m': 1,
}

train_dataset = TrainTablesDataset(dataset_paths,
                                   class_map=class_map,
                                   num_clustering_heads=num_clustering_heads,
                                   ocr_labels_folder='ocr_gt',
                                   is_use_4_points=True,
                                   is_use_image_patches=True,
                                   is_one_model='both',
                                   use_cell_pointers=True)

train_loader = DataLoader(train_dataset,
                          batch_size=8,
                          shuffle=True,
                          # pinned host memory only helps host->GPU copies
                          pin_memory=torch.cuda.is_available(),
                          num_workers=8,
                          drop_last=True,
                          collate_fn=collate_fn_pad)


# An "epoch" here is a fixed number of steps, not one pass over the data.
train_iterator = _cycle_batches(train_loader)

EPOCHS = 3
STEPS_PER_EPOCH = 10

for epoch in trange(EPOCHS):
    model.train()
    progress_bar = trange(STEPS_PER_EPOCH, leave=False, desc='Train')

    for _ in progress_bar:
        # collate_fn_pad yields a 12-element batch; only the fields consumed by
        # the model and the loss are kept.
        (word_boxes, contents_idx, _, adjacency_matrices, _, mask, _,
         img_patches, _, shadow_mask, _, header_mask) = next(train_iterator)

        word_boxes = word_boxes.to(device)
        img_patches = img_patches.to(device)
        contents_idx = contents_idx.to(device)
        mask = mask.to(device).float()
        shadow_mask = shadow_mask.to(device).float()
        header_mask = header_mask.to(device).float()
        adjacency_matrices = adjacency_matrices.to(device)

        preds_clustering = model(word_boxes, contents_idx, mask, img_patches)

        cost = compute_cost(preds_clustering, adjacency_matrices, mask, shadow_mask, header_mask)

        # optimizer.zero_grad() already clears every model parameter's gradient.
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # cost is a scalar tensor; .item() turns it into a float for display.
        progress_bar.set_description(f'Train loss {cost.item():.6f}')


Number of model parameters 15117192
Looking for jsons


100%|██████████| 1/1 [00:00<00:00, 1110.19it/s]


Looking for images


100%|██████████| 1/1 [00:00<00:00, 1260.69it/s]


Looking for OCR


100%|██████████| 20/20 [00:00<00:00, 20887.97it/s]
100%|██████████| 3/3 [00:17<00:00,  5.81s/it]
