In [12]:
import os
import sys
import numpy as np

sys.path.append("../detr")
from docai_util import image_replace, box_shrink
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from scipy import ndimage
import bisect
from glob import glob

In [2]:
def denoise_molecular(image, dilate_degree, threshold=250):
    """
    Binarize an image, dilate its foreground and label the connected components.

    This is the first stage of removing noise points from an image: a pixel whose
    channel mean (truncated to int) is <= ``threshold`` is foreground, anything
    brighter is background.

    @param image: np.ndarray, [h, w] single-channel or [h, w, c] multi-channel
    @param dilate_degree: int, number of dilation iterations with a 2x2 kernel;
        values < 1 skip dilation
    @param threshold: int, brightness above which a pixel counts as background
    @return: (labels, num_components) from ``scipy.ndimage.label`` -- an int array
        of shape [h, w] (0 = background, 1..num_components = connected regions)
        and the number of regions
    """
    image = np.asarray(image)
    # Collapse the channel axis; a 2-D input is already single-channel.
    gray = np.mean(image, axis=-1) if image.ndim == 3 else image
    # Same rule as before: truncate to int, anything brighter than threshold is background.
    bin_image = (gray.astype(np.int32) <= threshold).astype(np.uint8)

    kernel = np.ones(shape=(2, 2), dtype=np.uint8)
    if dilate_degree > 0:
        # Previously `dilate_degree` was passed to cv2.dilate as the kernel (and on an int32
        # image, a depth OpenCV morphology rejects); dilate with the 2x2 kernel instead.
        # ndimage treats iterations < 1 as "repeat until convergence", hence the guard.
        bin_image = ndimage.binary_dilation(
            bin_image, structure=kernel, iterations=dilate_degree
        ).astype(np.uint8)

    # Calculate connected domains (this result used to be discarded).
    return ndimage.label(bin_image)


def load_cells(k_remain, cell_map_path='./statistics/cell_shapes_pubtable.pth'):
    """
    Select the ``k_remain`` cells that rank highest by height and area combined.

    @param k_remain: int, number of cells to keep
    @param cell_map_path: path of a ``torch.save``-d dict mapping an XML path to that
        table's cell boxes, one row per cell; columns 0-3 are used as x0, y0, x1, y1
        (width = col2 - col0, height = col3 - col1)
    @return: (cells, xml_indices, xml_paths) -- the selected boxes (best first), for
        each box the index into ``xml_paths`` of its source XML, and all XML paths
    @raise ValueError: if the cell map is empty
    """
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may refuse
    # pickled numpy arrays; if the map stores numpy boxes, loading may need weights_only=False.
    cell_map = torch.load(cell_map_path)
    if not cell_map:
        raise ValueError(f'no cells found in {cell_map_path}')

    # Flatten every table's cells into one array, remembering which XML each cell came from.
    xml_paths, cells, xml_indices = [], [], []
    for idx, (path, cell) in enumerate(cell_map.items()):
        xml_paths.append(path)
        cells.append(cell)
        xml_indices.extend([idx] * len(cell))

    xml_paths = np.array(xml_paths)
    xml_indices = np.array(xml_indices, dtype=np.intp)
    cells = np.concatenate(cells, axis=0)
    widths = cells[:, 2] - cells[:, 0]
    heights = cells[:, 3] - cells[:, 1]

    # Score = rank by height + rank by area (rank 0 = smallest), so the cells kept are
    # both tall and large. Each argsort is a permutation, so the fancy `+=` is safe.
    num_cells = len(cells)
    scores = np.zeros(num_cells, dtype=np.intp)
    scores[np.argsort(heights, axis=0)] += np.arange(num_cells)
    scores[np.argsort(heights * widths, axis=0)] += np.arange(num_cells)

    # Highest combined score first; keep the top k_remain.
    sorted_indices = scores.argsort()[::-1][:k_remain]
    return cells[sorted_indices], xml_indices[sorted_indices], xml_paths


def load_molecular_images(shape_path='./statistics/molecular_image_shape.pth'):
    """
    Load the molecule image catalogue, sorted by the ratio of its first two shape entries.

    @param shape_path: path of a ``torch.save``-d dict mapping an image name to its shape
    @return: (image_names, image_shapes, whr_ratios), all sorted ascending by
        ``shape[0] / shape[1]``
    @raise ValueError: if the shape map is empty
    """
    shape_map = torch.load(shape_path)
    if not shape_map:
        raise ValueError(f'no molecule images found in {shape_path}')

    image_names = np.array(list(shape_map.keys()))
    image_shapes = np.stack(list(shape_map.values()), axis=0)

    # NOTE(review): the generation cell bisects these ratios with cell width/height ratios.
    # That only matches if shape[0] is the width; an ndarray/cv2 shape is (h, w, ...),
    # which would make this height/width -- confirm how the .pth file was produced.
    whr_ratios = image_shapes[:, 0] / image_shapes[:, 1]
    order = np.argsort(whr_ratios)
    return image_names[order], image_shapes[order], whr_ratios[order]

In [3]:
table_image_root = '/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/images'
mole_image_root = '/home/suqi/dataset/MolScribe/preprocessed'
save_dir = '/home/suqi/dataset/synthesis_table_test'

os.makedirs(save_dir, exist_ok=True)

cells, xml_indices, xml_paths = load_cells(50000)
mole_names, mole_shapes, mole_ratios = load_molecular_images()

# Handle tables that contain several selected cells: group the selected cells by source XML.
# A stable sort + split is O(n log n); masking `xml_indices == xml_idx` once per XML was
# O(n * num_xml). Within each group the cells keep their selection order, as before.
xml_order = np.argsort(xml_indices, kind='stable')
unique_xml, group_starts = np.unique(xml_indices[xml_order], return_index=True)
xml_cell_table = {
    xml_paths[xml_idx]: xml_cells
    for xml_idx, xml_cells in zip(unique_xml, np.split(cells[xml_order], group_starts[1:]))
}

In [4]:
# Hand-picked (image file, PubTables-1M split) pairs to inspect.
test_list = [
    ("PMC2706793_table_0.jpg", 'train'),
    ("PMC2711041_table_0.jpg", 'train'),
    ("PMC2713272_table_0.jpg", 'train'),
    ("PMC2717052_table_0.jpg", 'train'),
    ("PMC2719080_table_2.jpg", 'train'),
    ("PMC2739837_table_0.jpg", 'train'),
    ("PMC2745412_table_0.jpg", 'val'),
    ("PMC2753557_table_1.jpg", 'val'),
    ("PMC2756279_table_1.jpg", 'test'),
    ("PMC2758861_table_0.jpg", 'train'),
    ("PMC2768712_table_0.jpg", 'train'),
    ("PMC2795741_table_1.jpg", 'train'),
    ("PMC2796493_table_0.jpg", 'train'),
    ("PMC2798617_table_0.jpg", 'train'),
]
# Replace each pair by the path of its XML annotation inside that split's directory.
test_list = [
    f'/home/suqi/dataset/pubtables-1m/PubTables-1M-Structure/{split}/' + image_name.replace('.jpg', '.xml')
    for image_name, split in test_list
]
# Keep only the hand-picked tables; a KeyError here means load_cells selected no cell of that table.
xml_cell_table_part = dict((xml_path, xml_cell_table[xml_path]) for xml_path in test_list)

In [6]:
# For each hand-picked table, paste molecule images into its selected cells twice: into the
# raw boxes ("original") and into boxes shrunk by box_shrink ("corrected"), then save both.
debug_dir = './data_debug/line/'
os.makedirs(debug_dir, exist_ok=True)

# `table_cells` (not `cells`) so the loop does not clobber the global `cells` from load_cells.
for xml_path, table_cells in tqdm(xml_cell_table_part.items(), desc='Generating: '):
    # load table image
    table_base_name = xml_path.split('/')[-1].split('.xml')[0]
    table_image_path = os.path.join(table_image_root, table_base_name + '.jpg')
    table_image = cv2.imread(table_image_path)
    if table_image is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(table_image_path)
    # cv2 loads BGR, but the molecules below are converted to RGB and plt.imsave writes RGB;
    # convert the table as well, otherwise its colour channels are saved swapped.
    table_image = cv2.cvtColor(table_image, cv2.COLOR_BGR2RGB)
    table_image_correct = table_image.copy()

    for bbox in table_cells:
        whr = (bbox[2] - bbox[0]) / (bbox[3] - bbox[1])
        # Locate the molecule whose aspect ratio is closest to the cell's (mole_ratios is
        # sorted ascending) and pick one at random within 20 positions on either side.
        idx = bisect.bisect(mole_ratios, whr)
        idx = np.random.randint(max(idx - 20, 0), min(idx + 20, len(mole_ratios)))
        mole_path = os.path.join(mole_image_root, str(mole_names[idx]))
        mole_image = cv2.imread(mole_path)
        if mole_image is None:
            raise FileNotFoundError(mole_path)
        mole_image = cv2.cvtColor(mole_image, cv2.COLOR_BGR2RGB)

        bbox_correct = box_shrink(table_image_correct, bbox, 0.95)
        table_image = image_replace(table_image, bbox, mole_image, scale_shrink=0.8)
        table_image_correct = image_replace(table_image_correct, bbox_correct, mole_image, scale_shrink=0.8)

    plt.imsave(debug_dir + f'{table_base_name}_original.png', table_image)
    plt.imsave(debug_dir + f'{table_base_name}_corrected.png', table_image_correct)
    

Generating: 100%|██████████| 14/14 [00:03<00:00,  4.44it/s]


In [16]:
# Weighted sampling without replacement: draw 3 of the values 1..5, heavily favouring 1.
score = np.array([1000, 5, 1, 1, 1], dtype=float)
score /= score.sum()
np.random.choice(np.arange(1, 6), 3, replace=False, p=score)

array([1, 3, 2])

In [3]:
table_root = '/home/suqi/dataset/synthesis_table_new'
# List the directory once and order the file names by their first and second
# underscore-separated fields (presumably the mean and max scores, per the variable names).
# NOTE(review): the fields are compared as strings, i.e. lexicographically; this equals
# numeric order only if they are formatted consistently (e.g. all "0.xxx") -- confirm.
table_names = os.listdir(table_root)
sort_by_mean = sorted(table_names, key=lambda fname: fname.split('_')[0])
sort_by_max = sorted(table_names, key=lambda fname: fname.split('_')[1])

mean_save_dir = '/home/suqi/dataset/temp/syn_table/mean'
max_save_dir = '/home/suqi/dataset/temp/syn_table/max'
for out_dir in (mean_save_dir, max_save_dir):
    os.makedirs(out_dir, exist_ok=True)

print(len(sort_by_max))

73229


In [4]:
# Hard-link samples from the top, a fixed middle slice, and the bottom of each ordering
# into <save dir>/{high,mid,low} for visual inspection.
save_name = ['high', 'mid', 'low']
num_tables = len(sort_by_mean)  # was hard-coded as 73229 (the count printed above)
# NOTE: (30000, 30100) is a fixed slice, not the exact middle of the ordering.
ranges = [(max(num_tables - 100, 0), num_tables), (30000, 30100), (0, 100)]

for sorted_names, root_dir in ((sort_by_mean, mean_save_dir), (sort_by_max, max_save_dir)):
    for bucket, (start, stop) in zip(save_name, ranges):
        bucket_dir = os.path.join(root_dir, bucket)
        os.makedirs(bucket_dir, exist_ok=True)
        for name in sorted_names[start:stop]:
            link_path = os.path.join(bucket_dir, name)
            # os.link raises FileExistsError; skip links from a previous run so the cell can be re-run.
            if not os.path.exists(link_path):
                os.link(os.path.join(table_root, name), link_path)

In [13]:
source_root = '/home/suqi/dataset/Pub_Syn_Union'  # NOTE(review): not used in this cell
target_root = '/home/suqi/dataset/Pub_Fin_Syn_Union'

fintable = '/home/suqi/dataset/synthesis_fintable'
xml_root = '/home/suqi/dataset/FinTabNet.c/FinTabNet.c-Structure'

# Index the FinTabNet.c XMLs once instead of globbing the split directories for every image.
# Without recursive=True, glob's '**' behaves like '*', so the previous per-image pattern
# xml_root/**/<name>.xml searched exactly one directory level below xml_root -- as does this.
xml_index = {}
for xml_file in glob(os.path.join(xml_root, '*', '*.xml')):
    xml_index.setdefault(os.path.basename(xml_file), []).append(xml_file)

# For every synthesized image, hard-link its source table's XML annotation into
# <target_root>/train under the synthesized image's name and append it to the file list.
# (Only the XMLs are linked here, not the images.)
with open(os.path.join(target_root, 'train_filelist.txt'), 'a') as f:
    for name in sorted(os.listdir(fintable)):
        # The source table name precedes the _COL_ / _CELL_ marker.
        if '_COL_' in name:
            base_name = name.split('_COL_')[0]
        else:
            base_name = name.split('_CELL_')[0]
        link_name = name.split('.jpg')[0] + '.xml'

        xml_source = xml_index.get(base_name + '.xml', [])
        assert len(xml_source) == 1, f'expected exactly one XML for {base_name}, found {len(xml_source)}'
        os.link(xml_source[0], os.path.join(target_root, 'train', link_name))
        f.write(f'train/{link_name}\n')


In [7]:
import torch
import os

model_root = '/home/suqi/model/TATR/TATR-v1.1-All-msft.pth'
# NOTE(review): this rebinds `target_root`, which an earlier cell uses for a dataset root.
target_root = "/home/suqi/model/TATR/finetune/20231027181701"

# Re-save the released TATR weights (presumably a bare state dict -- confirm) in the
# {'model_state_dict': ...} layout, as model.pth inside the fine-tuning run directory.
state_dict = torch.load(model_root, map_location='cpu')
checkpoint = {'model_state_dict': state_dict}
os.makedirs(target_root, exist_ok=True)  # torch.save does not create missing directories
torch.save(checkpoint, os.path.join(target_root, 'model.pth'))

True
