In [None]:
%load_ext autoreload
%autoreload 2

import random
import sys
import torch
import numpy as np
from PIL import Image

sys.path.extend(["../src/table_detr/detr", "../src/table_detr/src"])
from app.inference import (
    detect_text, 
    load_artifacts,
    objects_to_cells, 
    predict, 
    select_structure_predictions, 
    select_table_predictions,     
    structure_class_map,
    structure_class_names,
    structure_class_thresholds,
)
from app.visualize import visualize_bbox, visualize_postprocessed_cell

In [None]:
# Seed every RNG source the pipeline may draw from (Python, NumPy, PyTorch),
# in the same order as before, so repeated runs are reproducible.
seed = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(seed)

# Artifact locations, relative to this notebook's working directory.
DETECTION_CONFIG = "../src/table_detr/src/detection_config.json"
DETECTION_WEIGHTS = "../models/pubtables1m_detection_detr_r18.pth"
STRUCTURE_CONFIG = "../src/table_detr/src/structure_config.json"
STRUCTURE_WEIGHTS = "../models/pubtables1m_structure_detr_r18.pth"

# Detection artifacts: applied to the full page image in the pipeline cell.
detection_preprocessor, detection_model = load_artifacts(
    data_type="detection",
    config_file=DETECTION_CONFIG,
    model_load_path=DETECTION_WEIGHTS,
)

# Structure artifacts: applied to each cropped table in the pipeline cell.
structure_preprocessor, structure_model = load_artifacts(
    data_type="structure",
    config_file=STRUCTURE_CONFIG,
    model_load_path=STRUCTURE_WEIGHTS,
)

In [None]:
# Page images to run through the two-stage (detect tables -> recognize
# structure) pipeline.
paths = [
    "../data/1.jpg",
]

# Extra margin, in pixels, added on every side of a detected table before
# cropping. 0 gives a tight crop on the predicted bbox. Note that PIL fills any
# region outside the page with zeros, so large values can add black borders
# for tables near the page edge.
TABLE_CROP_PADDING = 0

for path in paths:
    # Image.open is lazy and holds the file handle open; the context manager
    # closes it, while convert() returns an independent in-memory image.
    with Image.open(path) as page_file:
        img = page_file.convert("RGB")

    # Stage 1: find tables on the full page.
    det_labels, det_scores, det_bboxes = predict(detection_preprocessor, detection_model, img)
    pred_tables = select_table_predictions(det_labels, det_scores, det_bboxes)

    visualize_bbox(img, pred_tables, "detection")

    for pred in pred_tables:
        # Stage 2: crop each detected table (optionally padded) and run the
        # structure model on the crop.
        x0, y0, x1, y1 = pred["bbox"]
        pad = TABLE_CROP_PADDING
        table_img = img.crop((x0 - pad, y0 - pad, x1 + pad, y1 + pad))

        # Text tokens detected on the crop (presumably in crop-local
        # coordinates, matching the structure predictions below — confirm).
        tokens = detect_text(table_img)

        struct_labels, struct_scores, struct_bboxes = predict(structure_preprocessor, structure_model, table_img)
        pred_structures = select_structure_predictions(struct_labels, struct_scores, struct_bboxes)

        visualize_bbox(table_img, pred_structures, "structure")

        # NOTE(review): passes the raw structure predictions, not
        # pred_structures; presumably objects_to_cells applies
        # structure_class_thresholds itself — TODO confirm in app.inference.
        _, pred_cells, _ = objects_to_cells(
            struct_bboxes,
            struct_labels,
            struct_scores,
            tokens,
            structure_class_names,
            structure_class_thresholds,
            structure_class_map,
        )

        visualize_postprocessed_cell(table_img, pred_cells)