## 从pdf文件中提取表格并进行分割

In [1]:
import warnings
warnings.filterwarnings('ignore')
from docai_table.model.layout_model import LayoutModel, LayoutType
from docai_table.util.pdf_helper import get_pdf_page_images
from docai_table.util.visualize_util import visualize_contents
from collections import defaultdict
from pathlib import Path
import os
import pdb
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import matplotlib.pyplot as plt
from scipy import ndimage
import cv2
import random
from glob import glob

## 测试Layout模型

In [None]:
# Source PDFs are collected recursively from pdf_root; cropped table images
# are written to save_root, which is created if it does not exist yet.
pdf_root = '/home/suqi/dataset/pdf_data'
save_root = '/home/suqi/dataset/pdf_images'
os.makedirs(save_root, exist_ok=True)

# Sort in place so the processing order is stable between runs.
pdf_paths = glob(os.path.join(pdf_root, '**', '*.pdf'), recursive=True)
pdf_paths.sort()

In [None]:
# Instantiate LayoutModel with its default constructor arguments; the cells
# below call its .predict() on page images.
layout_model = LayoutModel()

In [None]:
# Run the layout model over every PDF, crop each region it labels as a table
# and save the crop as '<pdf>_PAGE_<page idx>_NO_<layout idx>.jpg' in save_root.
table_count = 0
# torch.cuda.empty_cache()
# NOTE(review): the [774:] offset presumably resumes an interrupted run — confirm.
for pdf_path in tqdm(pdf_paths[774:]):
    pdf_name = pdf_path.split('/')[-1].split('.pdf')[0]
    images = get_pdf_page_images(pdf_path)
    layout_results = layout_model.predict(images)
    for idx, layout_res in enumerate(layout_results):
        for j, layout in enumerate(layout_res.layouts):
            if layout.type != LayoutType.Table:
                continue
            table_count += 1
            table_crop = images[idx].crop(layout.bbox)
            # Both indices are zero-padded to three digits.
            crop_name = f"{pdf_name}_PAGE_{idx:03d}_NO_{j:03d}.jpg"
            table_crop.save(os.path.join(save_root, crop_name))

In [None]:
# Single-page experiment: fetch one page (page_number=7), run the layout
# model on it and show the detected layouts.
page_image = get_pdf_page_images("../data/wu2021.pdf", page_number=7)
layout_model = LayoutModel()
layout_result = layout_model.predict(page_image)
# `display` is not imported in this file — presumably the IPython/Jupyter
# builtin; confirm when running outside a notebook.
display(visualize_contents(page_image, layout_result.get_visualize_contents()))

# Crop the region of the layout at index 8.
# NOTE(review): index 8 is presumably the table on this particular page — confirm.
table_image = page_image.crop(layout_result.layouts[8].bbox)

## 分子图像处理，将多余的白边去掉

In [None]:
# image = np.array(Image.open("/home/suqi/dataset/MolScribe/Supple/indigo_validation_set_examples/images/700.png"))
# if len(image.shape) == 3:
#     image = image.mean(axis=-1)

# print(image.shape)
# print(image.max())
# print(image.min())

# plt.imshow(image, cmap='gray')

def cut_white_border(image: np.ndarray, threshold=250) -> np.ndarray:
    """Crop away the near-white margin around the content of an image.

    A pixel counts as content when its value is strictly below ``threshold``.
    The result is the tightest rectangle containing every content pixel.

    Args:
        image: 2-D array of shape (H, W), e.g. a grayscale image.
        threshold: Values greater than or equal to this are treated as
            white border. Use the same scale as ``image`` (for instance
            ``250`` for 0-255 data, ``250 / 255`` for 0-1 data).

    Returns:
        A view of ``image`` restricted to the content bounding box. If no
        pixel is below ``threshold`` there is nothing to crop to, and
        ``image`` is returned unchanged.

    Raises:
        ValueError: If ``image`` is not 2-D.
    """
    if image.ndim != 2:
        raise ValueError(f'expected a 2-D image, got shape {image.shape}')

    content = image < threshold
    rows = np.flatnonzero(content.any(axis=1))
    if rows.size == 0:
        # Entirely white: min()/max() of an empty index set would raise.
        return image
    cols = np.flatnonzero(content.any(axis=0))
    return image[rows[0]: rows[-1] + 1, cols[0]: cols[-1] + 1]

# image = cut_white_border(image)
# plt.imshow(image, cmap='gray')
# print(image.shape)


In [None]:

# Normalise every raw molecule image found under data_root: convert to
# grayscale, crop the white border and save it as a 3-channel PNG named
# '<dataset>_<7-digit running index>.png' inside data_root/target_folder.
data_root = '/home/suqi/dataset/MolScribe/'
target_folder = 'preprocessed'
os.makedirs(os.path.join(data_root, target_folder), exist_ok=True)

# File extensions processed in this run; the commented entries are the other
# extensions this cell has handling for.
patterns = [
    # '.png', '.jpg',
    '.TIF',
    # '.tiff', '.bmp'
]

# png_folders = [
#     'indigo_validation_set_examples/images',
#     'perturb/CLEF_pertubations/*',
#     'perturb/STAKER',
#     'perturb/UOB_pertubations/*',
#     'perturb/USPTO_pertubations',
#     'synthetic/chemdraw',
#     'synthetic/indigo',
#     'uspto_validation_set_examples/images',
#     'valko_testset_results/image_results'
# ]

dataset_counter = defaultdict(lambda: 0)  # dataset name -> next running index
data_dict = {}  # output file name -> source image path
for pattern in patterns:
    # The output folder lives under data_root too, so exclude it: otherwise a
    # re-run with '.png' enabled would pick up this cell's own results.
    image_list = sorted(
        path for path in Path(data_root).rglob('*' + pattern)
        if path.relative_to(data_root).parts[0] != target_folder
    )
    print(len(image_list))
    for image_path in tqdm(image_list, desc=f'Now: {pattern}'):
        try:
            # The first path component below data_root names the dataset
            # (independent of how deep data_root itself is).
            dataset = image_path.relative_to(data_root).parts[0]
            data_idx = dataset_counter[dataset]
            dataset_counter[dataset] = data_idx + 1

            save_name = dataset + '_' + str(data_idx).rjust(7, '0') + '.png'
            data_dict[save_name] = image_path
            save_path = os.path.join(data_root, target_folder, save_name)
            # The counter is advanced before this check so that names stay
            # stable when an interrupted run is resumed.
            if os.path.exists(save_path):
                continue

            # rglob already yields full paths; close the file once decoded.
            with Image.open(image_path) as pil_image:
                image = np.array(pil_image)
            if len(image.shape) == 3:
                image = image.mean(axis=-1)

            if pattern == '.tiff' or pattern == '.bmp':
                # Keep only the right-most quarter of the image.
                # NOTE(review): presumably these files hold several panels
                # side by side — confirm.
                width = image.shape[1]
                image = image[:, width * 3 // 4:]

            if image.max() <= 1.1:
                # Data on a 0-1 scale: crop with a scaled threshold, then
                # bring the result to 0-255.
                threshold = 250 / 255
                image = cut_white_border(image.astype(np.float16), threshold)
                image = (image * 255).astype(np.uint8)
            else:
                threshold = 250
                image = cut_white_border(image.astype(np.uint8), threshold)

            # Replicate the single channel three times -> (H, W, 3).
            image = np.tile(image[..., None], 3)
            Image.fromarray(image).save(save_path)
        except Exception as e:
            # Interactive notebook workflow: stop in the debugger on the
            # offending file instead of silently skipping it.
            print(f'Exception at: {image_path}')
            pdb.set_trace()


## 随机选择图形进行查看

In [None]:
data_root = '/home/suqi/dataset/MolScribe/'
target_folder = 'preprocessed'

# First five characters of every file name in the preprocessed folder.
# NOTE(review): presumably used as a rough per-dataset key, since the
# preprocessing cell names files '<dataset>_<index>.png' — confirm that five
# characters are enough to tell the datasets apart.
img_prefix = [img_name[:5] for img_name in os.listdir(os.path.join(data_root, target_folder))]

In [None]:
# Reduce the prefix list to its distinct values, then show how many there are
# and what they are.
img_prefix = set(img_prefix)
print(len(img_prefix))
print(img_prefix)

In [None]:
import shutil

# Copy up to 20 randomly chosen images per prefix into a scratch folder for
# manual inspection.
temp_dir = '/home/suqi/dataset/MolScribe/temp'
# Without the folder, each copy would overwrite a single file called 'temp'.
os.makedirs(temp_dir, exist_ok=True)
for prefix in img_prefix:
    paths = sorted(Path(os.path.join(data_root, target_folder)).glob(f"{prefix}*.png"))
    np.random.shuffle(paths)
    for path in paths[:20]:
        # shutil.copy avoids building a shell command, so file names with
        # spaces or shell metacharacters are handled correctly.
        shutil.copy(path, temp_dir)

## 统计图像的平均大小

In [None]:
# Record the size of every preprocessed image as [width, height]
# (PIL's Image.size order), keyed by file name.
image_root = os.path.join(data_root, target_folder)
image_shapes = {}

for image_name in tqdm(sorted(os.listdir(image_root))):
    # Image.open keeps the file handle open; close it explicitly so a scan
    # over a very large folder does not depend on garbage collection.
    with Image.open(os.path.join(image_root, image_name)) as pil_image:
        image_shapes[image_name] = list(pil_image.size)

In [None]:
# Persist the {file name: [width, height]} mapping for later analysis.
# torch.save does not create missing parent folders, so make sure it exists.
os.makedirs('./statistics', exist_ok=True)
torch.save(image_shapes, './statistics/molecular_image_shape.pth')

# image_names = []
# image_shapes = []
# 
# for key, val in torch.load('./statistics/molecular_image_shape.pth').items():
#     image_names.append(key)
#     image_shapes.append(val)
# 
# image_shapes = np.stack(image_shapes, axis=0)


In [None]:
# image_shapes is a {file name: [width, height]} dict when it comes straight
# from the scan cell, and an (N, 2) array when rebuilt from the saved file;
# accept both so this cell works in either order.
if isinstance(image_shapes, dict):
    shape_array = np.array(list(image_shapes.values())).reshape(-1, 2)
else:
    shape_array = np.asarray(image_shapes)

# Histogram of width / height over 0..5 (499 bins).
width_height_ratio = shape_array[:, 0] / shape_array[:, 1]
plt.hist(width_height_ratio, np.linspace(0, 5, 500))
plt.show()

## 处理分子最大连通域

In [None]:
def denoise_molecular(image, dilate_degree=15):
    """
    Remove small isolated specks from a molecule image and crop to what is left.

    Steps: binarise (foreground = channel mean <= 250), dilate the foreground
    so that nearby strokes merge, label connected components on the dilated
    mask, blank every component whose dilated area is below half of the
    largest one, then crop to the bounding box of the remaining foreground.

    @param image: np.ndarray, [h, w, c], values on a 0-255 scale
    @param dilate_degree: number of dilation iterations with a 2x2 kernel;
        larger values merge strokes that are further apart
    @return: denoised and cropped copy of the image (the input array is not
        modified). An image with no foreground pixel is returned unchanged.
    """
    fill_value = image.max()
    assert image.dtype == np.uint8 or fill_value > 2

    # binarize: 1 = foreground (dark), 0 = background (near white)
    threshold = 250
    bin_image = (np.mean(image, axis=-1).astype(np.uint8) <= threshold).astype(np.uint8)
    if not bin_image.any():
        # Blank image: there is no component to keep and nothing to crop to.
        return image

    # Work on a copy so the caller's array is left untouched.
    image = image.copy()

    kernel = np.ones(shape=(2, 2), dtype=np.int8)
    # The iteration count must be passed by keyword: the third positional
    # parameter of cv2.dilate is the output array `dst`, not a count.
    dilated_bin_image = cv2.dilate(bin_image, kernel, iterations=dilate_degree).astype(np.uint8)

    # calculate connected domain, remove too small domain
    size_threshold = 0.5
    label, n_dom = ndimage.label(dilated_bin_image)
    dom_size = np.bincount(label.reshape(-1))  # index 0 is the background
    max_dom = np.max(dom_size[1:])
    too_small = dom_size < int(max_dom * size_threshold)
    # Look up the flag of each pixel's label; clearing background is a no-op.
    dilated_bin_image[too_small[label]] = 0

    # mask and cut
    background = dilated_bin_image == 0
    image[background] = fill_value
    bin_image[background] = 0
    rows, cols = np.nonzero(bin_image)

    return image[rows.min(): rows.max() + 1, cols.min(): cols.max() + 1]

In [None]:
# Folder with the molecule images to denoise; this is the same path the
# preprocessing cell writes its PNGs to.
mole_root = '/home/suqi/dataset/MolScribe/preprocessed'
# test_sample = [
#     'uspto_mol_0000000.png',
#     'uspto_mol_0000003.png',
#     'uspto_mol_0000004.png',
#     'uspto_mol_0000007.png',
#     'uspto_mol_0040757.png',
#     'uspto_mol_0099345.png',
#     'uspto_mol_0106860.png',
#     'valko_testset_results_0000004.png',
#     'valko_testset_results_0000005.png',
#     'valko_testset_results_0000008.png',
#     'valko_testset_results_0000009.png',
#     'valko_testset_results_0000393.png'
# ]
# 
# for i in range(len(test_sample)):
#     mole_image = cv2.cvtColor(cv2.imread(os.path.join(mole_root, test_sample[i])), cv2.COLOR_BGR2RGB)
# 
#     plt.imshow(mole_image)
#     plt.show()
# 
#     mole_image = denoise_molecular(mole_image, 15)
#     plt.imshow(mole_image, cmap='gray')
#     plt.show()

In [None]:
# Denoise every image in mole_root and write the result under the same file
# name in new_mole_root.
# NOTE(review): new_mole_root is the same folder as mole_root, so the inputs
# are overwritten in place — confirm this is intended.
new_mole_root = '/home/suqi/dataset/MolScribe/preprocessed'
os.makedirs(new_mole_root, exist_ok=True)

for name in tqdm(sorted(os.listdir(mole_root))):
    # cv2.imread returns BGR and cv2.imwrite expects BGR, so no channel
    # conversion is done in between: a one-way swap would save files with red
    # and blue exchanged. denoise_molecular treats all channels alike.
    mole_image = cv2.imread(os.path.join(mole_root, name))
    if mole_image is None:
        # imread signals an unreadable / non-image file by returning None.
        print(f'Skip unreadable image: {name}')
        continue
    mole_image = denoise_molecular(mole_image, 15)
    cv2.imwrite(os.path.join(new_mole_root, name), mole_image)


## 数据集结构化

In [3]:
# Dataset locations: PubTables-1M structure data, the synthetic tables, and
# the folder where the merged dataset is assembled.
pub_root = '/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure'
syn_root = '/home/suqi/dataset/synthesis_table'
target_root = '/home/suqi/dataset/Pub_Syn_Union'

# Fixed seeds keep the shuffles below reproducible.
random.seed(1327)
np.random.seed(1327)

# Shuffle the synthetic image names and split them 8:1 into train / val.
syn_list = sorted(os.listdir(os.path.join(syn_root, 'images')))
np.random.shuffle(syn_list)
syn_train_num = int(8 * len(syn_list) / 9)

# Keep only the part of each file name before its first '.'.
syn_train_list = [name.split('.')[0] for name in syn_list[:syn_train_num]]
syn_val_list = [name.split('.')[0] for name in syn_list[syn_train_num:]]

407259 362008 45251


In [4]:
# Draw a random PubTables subset of the same size as the synthetic split, so
# both sources contribute equally to train and val.
pub_train_list = sorted(os.listdir(os.path.join(pub_root, 'train')))
pub_val_list = sorted(os.listdir(os.path.join(pub_root, 'val')))

# Shuffle order (train first, then val) matters for reproducibility.
np.random.shuffle(pub_train_list)
np.random.shuffle(pub_val_list)

# Truncate to the synthetic sizes and keep the part before the first '.'.
pub_train_list = [name.split('.')[0] for name in pub_train_list[:len(syn_train_list)]]
pub_val_list = [name.split('.')[0] for name in pub_val_list[:len(syn_val_list)]]

362008 45251


In [5]:
# Create the three sub-folders of the merged dataset.
for sub_folder in ('images', 'train', 'val'):
    os.makedirs(os.path.join(target_root, sub_folder), exist_ok=True)

In [6]:
# link images
# Hard-link every selected image from both sources into target_root/images.
# Existing links are skipped so the cell can be re-run after an interruption
# (os.link raises FileExistsError when the destination already exists).
for source_root, names in (
    (syn_root, syn_train_list + syn_val_list),
    (pub_root, pub_train_list + pub_val_list),
):
    for name in names:
        link_path = os.path.join(target_root, 'images', name + '.jpg')
        if os.path.lexists(link_path):
            continue
        os.link(os.path.join(source_root, 'images', name + '.jpg'), link_path)

In [7]:
# link xml files
# Hard-link one annotation file per selected sample into target_root/<split>.
# Existing links are skipped so the cell can be re-run after an interruption.

# Synthetic samples: the name minus its first two '_'-separated fields is
# looked up as an xml in the PubTables train/test/val folders.
for name, tp in zip(syn_train_list + syn_val_list, (['train'] * len(syn_train_list)) + (['val'] * len(syn_val_list))):
    base_name = '_'.join(name.split('_')[2:])
    xml_path = [
        found
        for split in ('train', 'test', 'val')
        for found in glob(os.path.join(pub_root, split, base_name + '.xml'))
    ]
    assert len(xml_path) == 1, f'expected exactly one xml for {name}, found {len(xml_path)}'
    link_path = os.path.join(target_root, tp, name + '.xml')
    if not os.path.lexists(link_path):
        os.link(xml_path[0], link_path)

# PubTables samples: the annotation sits in the folder of the sample's split.
for name, tp in zip(pub_train_list + pub_val_list, (['train'] * len(pub_train_list)) + (['val'] * len(pub_val_list))):
    xml_path = glob(os.path.join(pub_root, tp, name + '.xml'))
    assert len(xml_path) == 1, f'expected exactly one xml for {name}, found {len(xml_path)}'
    link_path = os.path.join(target_root, tp, name + '.xml')
    if not os.path.lexists(link_path):
        os.link(xml_path[0], link_path)

In [8]:
# generate txt file
# One list file per sub-folder; each line is '<folder>/<file name>', written
# in sorted order.
for list_name, folder in (
    ('images_filelist.txt', 'images'),
    ('train_filelist.txt', 'train'),
    ('val_filelist.txt', 'val'),
):
    with open(os.path.join(target_root, list_name), 'w') as f:
        entries = sorted(os.listdir(os.path.join(target_root, folder)))
        f.writelines(folder + '/' + entry + '\n' for entry in entries)

## 试加载模型

In [None]:
import sys

def print(*args, **kwargs):
    """Write the arguments to a text stream, like the builtin ``print``.

    Every positional argument is converted with ``str()`` before joining, so
    non-string objects (dicts, numbers, ``None`` ...) can be printed.

    Keyword arguments:
        sep: Separator placed between arguments; ``None`` or missing means ' '.
        end: Text appended after the last argument; ``None`` or missing means
            a newline.
        file: Object with a ``write(str)`` method; ``None`` or missing means
            ``sys.stdout``.
        flush: If true, call ``file.flush()`` after writing.

    Any other keyword argument is ignored.
    """
    sep = kwargs.get('sep')
    if sep is None:
        sep = ' '
    end = kwargs.get('end')
    if end is None:
        end = '\n'
    file = kwargs.get('file')
    if file is None:
        file = sys.stdout
    if file is None:
        # sys.stdout can itself be None (e.g. detached process): nothing to do.
        return

    file.write(sep.join(map(str, args)) + end)
    if kwargs.get('flush'):
        file.flush()

In [None]:
# Try the print function defined in this notebook on a non-string object.
d = {1: 2, 3: 4}
print(d)