In [None]:
# Notebook setup: imports plus working-directory / module-search-path adjustments.
import os

# Move the working directory one level up; relative paths used later in the notebook
# (e.g. './statistics/...') are resolved from there.
# NOTE(review): not idempotent - re-running this cell climbs one more directory level.
os.chdir('../')
import sys

import numpy as np

# Relative search-path entry, so it is looked up relative to the current working directory.
# Presumably this is where `table_datasets` lives - verify against the repository layout.
sys.path.append("../detr")
from table_datasets import read_pascal_voc
from src.main import get_class_map
from docai_util import bboxes_to_cells, image_replace, binarize, _boundary_range, box_shrink
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import bisect
from PIL import Image
from glob import glob

In [None]:
# Class name -> integer label id for the 'structure' task, as returned by get_class_map.
class_map = get_class_map('structure')
# Reverse lookup: integer label id -> class name.
index_to_class_map = dict((index, name) for name, index in class_map.items())

# Colour triple used when drawing the boxes of each structure class
# (images are converted to RGB before drawing in the cells below).
color_map = dict([
    ('table', (0, 0, 0)),
    ('table column', (255, 0, 0)),
    ('table row', (0, 255, 0)),
    ('table column header', (0, 0, 255)),
    ('table projected row header', (255, 255, 0)),
    ('table spanning cell', (255, 0, 255)),
])

## 测试

In [None]:
# Read one annotation file; each box is (xmin, ymin, xmax, ymax).
# Both returned sequences are converted to int32 arrays.
bboxes, labels = (
    np.array(values).astype(np.int32)
    for values in read_pascal_voc(
        '/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/train/PMC1592305_table_0.xml',
        class_map)
)

In [None]:
# Show the boxes whose label id is 2.
# NOTE(review): presumably the 'table row' class - confirm against class_map.
label_2_mask = labels == 2
print(bboxes[label_2_mask])

In [None]:
# Number of boxes and number of labels, one per line (they should agree).
print(len(bboxes), len(labels), sep='\n')

In [None]:
# Visual check of one annotated table: one figure per structure class, then one figure
# with the cells derived from the annotations by bboxes_to_cells.
image_path = '/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/images/PMC1592305_table_0.jpg'

# Read and colour-convert the page once; every figure below draws on its own copy.
source_image = cv2.imread(image_path)
if source_image is None:
    # cv2.imread reports a missing or undecodable file by returning None instead of raising.
    raise OSError(f'cv2.imread could not read {image_path}')
source_image = cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB)

for label in range(6):
    image = source_image.copy()

    # Draw only the boxes that carry the current label id.
    for bbox in bboxes[labels == label]:
        x_min, y_min, x_max, y_max = (int(v) for v in bbox.tolist())
        color = color_map[index_to_class_map[label]]
        image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=color,
                              thickness=1)
    plt.imshow(image)
    plt.title(index_to_class_map[label])
    plt.show()

# Cells computed from the boxes and labels, drawn on a fresh copy of the page.
image = source_image.copy()
cells = bboxes_to_cells(bboxes, labels)
print(cells.shape)

for cell in cells:
    x_min, y_min, x_max, y_max = (int(v) for v in cell.tolist())
    image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=(255, 255, 0), thickness=1)

plt.imshow(image)
plt.title("cell")
plt.show()

## 统计表格中grid的大小

In [None]:
# Annotation directories to scan (train / test / val of FinTabNet.c-Structure).
anno_root = [
    '/home/suqi/dataset/FinTabNet.c/FinTabNet.c-Structure/train/',
    '/home/suqi/dataset/FinTabNet.c/FinTabNet.c-Structure/test/',
    '/home/suqi/dataset/FinTabNet.c/FinTabNet.c-Structure/val/'
]

# Every '.xml' file of every directory, in sorted order within each directory.
# (The original author noted a length of 97475.)
xml_paths = [
    os.path.join(split_dir, file_name)
    for split_dir in anno_root
    for file_name in sorted(os.listdir(split_dir))
    if file_name.endswith('.xml')
]
print(len(xml_paths))

In [None]:
# Annotation path -> array of cells computed from that file's boxes and labels.
cell_map = {}

for xml_path in tqdm(xml_paths):
    # Cast both returned sequences to int32 arrays before deriving the cells.
    bboxes, labels = (
        np.array(values).astype(np.int32)
        for values in read_pascal_voc(xml_path, class_map)
    )
    cell_map[xml_path] = bboxes_to_cells(bboxes, labels)

In [None]:
# Persist the per-file cell arrays so the extraction loop need not be re-run.
cell_shapes_path = './statistics/cell_shapes_fintable.pth'
torch.save(cell_map, cell_shapes_path)
# To reuse a previously saved result instead, load it, e.g.:
# cell_map = torch.load('./statistics/cell_shapes_pubtable.pth')

In [None]:
# Flatten the per-file cell arrays into one array and record, for every cell,
# which file it came from.
pathes = list(cell_map.keys())
cell_arrays = list(cell_map.values())

# All cells stacked row-wise; columns 0..3 are used below as x_min, y_min, x_max, y_max.
cells = np.concatenate(cell_arrays, axis=0)
widths = cells[:, 2] - cells[:, 0]
heights = cells[:, 3] - cells[:, 1]

# indices[i] is the position in `pathes` of the file that cell i belongs to.
# Built with np.repeat from integer counts so the dtype stays integral even when a file
# contributes zero cells (concatenating an empty Python list would promote to float64).
cell_counts = np.array([len(cell) for cell in cell_arrays], dtype=np.intp)
indices = np.repeat(np.arange(len(cell_arrays)), cell_counts)

In [None]:
# Shapes of the flattened cell array and of the matching file-index array, one per line.
print(cells.shape, indices.shape, sep='\n')

In [None]:
# Orderings of the cells by ascending height and by ascending area (height * width);
# each array lists cell positions in sorted order.
indices_height = heights.argsort(axis=0)
cell_areas = heights * widths
indices_area = cell_areas.argsort(axis=0)

In [None]:
# Score of a cell = its rank by height + its rank by area (rank 0 is the smallest).
# Scattering arange through an argsort result turns the ordering into per-cell ranks.
height_rank = np.zeros_like(indices_height)
height_rank[indices_height] = np.arange(len(indices_height))

area_rank = np.zeros_like(indices_height)
area_rank[indices_area] = np.arange(len(indices_area))

scores = height_rank + area_rank

In [None]:
# Cell positions ordered from the highest combined score to the lowest.
ascending_by_score = np.argsort(scores)
sorted_indices = ascending_by_score[::-1]

In [None]:
# Count the cells that are both wider and taller than 100 (in annotation coordinates).
width = np.subtract(cells[:, 2], cells[:, 0])
height = np.subtract(cells[:, 3], cells[:, 1])
is_large = np.logical_and(width > 100, height > 100)
print(np.sum(is_large))


## 插图测试

In [None]:
# Load one FinTabNet.c table image as an RGB array, together with the cells derived
# from its annotation file.
page = Image.open('/home/suqi/dataset/FinTabNet.c/FinTabNet.c-Structure/images/BIIB_2015_page_115_table_0.jpg')
image = np.array(page.convert('RGB'))

annotation = read_pascal_voc(
    '/home/suqi/dataset/FinTabNet.c/FinTabNet.c-Structure/train/BIIB_2015_page_115_table_0.xml',
    class_map)
bboxes, labels = annotation
cells = bboxes_to_cells(np.array(bboxes), np.array(labels))

## 载入分子数据集的路径，供随机选择

In [None]:
# Load the saved mapping image name -> image shape and split it into parallel arrays.
shape_by_name = torch.load('./statistics/molecular_image_shape.pth')
image_names = np.array(list(shape_by_name.keys()))
image_shapes = np.stack(list(shape_by_name.values()), axis=0)

# Ratio of shape column 0 to shape column 1.
# NOTE(review): the name implies shapes are stored as (width, height) - confirm.
width_height_ratio = image_shapes[:, 0] / image_shapes[:, 1]

# Re-order all three arrays by ascending ratio so the ratio array can be bisected later.
whr_indices = np.argsort(width_height_ratio)
image_names = image_names[whr_indices]
image_shapes = image_shapes[whr_indices]
width_height_ratio = width_height_ratio[whr_indices]


In [None]:
# For every table cell, sample a molecule image whose aspect ratio is close to the cell's
# and paste it into the (shrunk) cell, showing the result.

# Number of aspect-ratio neighbours considered on each side of the bisect position.
cand_radius = 20

for bbox in cells:
    box_width, box_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    # Degenerate cells have no usable aspect ratio (division by zero / invalid
    # probabilities further down), so they are skipped.
    if box_width <= 0 or box_height <= 0:
        continue
    whr = box_width / box_height

    # width_height_ratio is sorted ascending, so bisect locates the neighbourhood of
    # images with a similar ratio.
    center = bisect.bisect(width_height_ratio, whr)
    cand_indices = np.arange(max(center - cand_radius, 0),
                             min(center + cand_radius, len(width_height_ratio)))

    # Scale each candidate needs to fit into the cell (the tighter of the two sides).
    scales = (np.array([box_width, box_height]) / image_shapes[cand_indices]).min(axis=-1)
    # Fold shrink factors (< 1) onto their reciprocal so every value is >= 1 and
    # 1 means "no resizing needed".
    shrinking = scales < 1
    scales[shrinking] = 1 / scales[shrinking]

    # Candidates whose scale is closer to 1 are more likely to be drawn.
    # Subtracting the minimum leaves the normalised probabilities unchanged but keeps
    # exp() from underflowing to all zeros for large scales.
    prob_no_norm = np.exp(-(scales - scales.min()))
    prob = prob_no_norm / np.sum(prob_no_norm)
    idx = np.random.choice(cand_indices, p=prob)

    target = Image.open('/home/suqi/dataset/MolScribe/preprocessed/' + image_names[idx]).convert('RGB')
    target = np.array(target)
    bbox = box_shrink(image, bbox)
    merged_image = image_replace(image, bbox, target)
    plt.imshow(merged_image)
    plt.show()

 ## 改进的边界线检测算法

In [None]:
# Load one PubTables-1M test image and the cells derived from its annotation file.
image_path = '/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/images/PMC4741940_table_0.jpg'
# Force 3-channel RGB, as the other image-loading cells of this notebook do, so that
# box_shrink and the RGB-coloured cv2.rectangle calls always see the same channel layout
# (a grayscale or CMYK JPEG would otherwise yield a differently shaped array).
image = Image.open(image_path).convert('RGB')
image = np.array(image)

bboxes, labels = read_pascal_voc('/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/test/PMC4741940_table_0.xml',
                                 class_map)
cells = bboxes_to_cells(np.array(bboxes), np.array(labels))

In [None]:
# For every cell show two figures: the cell box as annotated, then the box returned by
# box_shrink for the same cell. Each rectangle is drawn on a fresh copy of the image.
outline_color = (255, 0, 0)

for idx, cell in enumerate(cells):
    bbox = [int(coord) for coord in cell]

    # Figure 1: the original cell box.
    top_left, bottom_right = (bbox[0], bbox[1]), (bbox[2], bbox[3])
    box_image = cv2.rectangle(image.copy(), top_left, bottom_right, color=outline_color, thickness=1)
    plt.imshow(box_image)
    plt.title(idx)
    plt.show()

    # Figure 2: the box after box_shrink.
    shrunk_box = box_shrink(image, bbox)
    x_min, y_min, x_max, y_max = shrunk_box
    box_image = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), color=outline_color, thickness=1)
    plt.imshow(box_image)
    plt.title(idx)
    plt.show()

## Badcase

In [None]:
# Crops flagged as bad cases; the part of each name before '_COL_' identifies the table.
# NOTE(review): this rebinds `image_names`, which an earlier cell uses for the molecule
# image names - re-running that earlier section afterwards needs its loading cell re-run.
image_names = [
    "PMC2698926_table_0_COL_01_cells.jpg",
    "PMC3398646_table_1_COL_01_cells.jpg",
    "PMC4508365_table_0_COL_01_cells.jpg",
    "PMC4629724_table_0_COL_08_cells.jpg",
    "PMC4707276_table_1_COL_01_cells.jpg",
    "PMC4869378_table_0_COL_00_cells.jpg",
    "PMC5039233_table_0_COL_00_cells.jpg",
    "PMC5129660_table_2_COL_03_cells.jpg",
    "PMC5124561_table_1_COL_03_cells.jpg",
    "PMC5854118_table_0_COL_02_cells.jpg",
    "PMC5855521_table_0_COL_01_cells.jpg",
    "PMC5937855_table_0_COL_01_cells.jpg",
    "PMC6166094_table_0_COL_00_cells.jpg",
    "PMC5051389_table_0_COL_01_cells.jpg",
    "PMC2672029_table_0_COL_03_cells.jpg",
]

for name in image_names:
    # Table identifier = everything before the first '_COL_'.
    base_name, _, _ = name.partition('_COL_')

    source_path = f'/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/images/{base_name}.jpg'
    source_image = np.array(Image.open(source_path).convert('RGB'))

    # The annotation sits in one sub-directory of the dataset root; without recursive=True
    # the '**' component matches exactly one directory level. Exactly one hit is required.
    label_pattern = os.path.join("/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/", "**", f"{base_name}.xml")
    label_file = glob(label_pattern)
    assert len(label_file) == 1
    bboxes, labels = read_pascal_voc(label_file[0], class_map)
    cells = bboxes_to_cells(np.array(bboxes), np.array(labels))

    # Untouched table image.
    plt.imshow(source_image)
    plt.title('Source Image')
    plt.show()

    # Same image with every derived cell outlined.
    cell_image = source_image.copy()
    for cell in cells:
        x_min, y_min, x_max, y_max = (int(coord) for coord in cell)
        cell_image = cv2.rectangle(cell_image, (x_min, y_min), (x_max, y_max), color=(255, 0, 0), thickness=1)
    plt.imshow(cell_image)
    plt.title('Cell Image')
    plt.show()
