In [2]:
import os
import sys

sys.path.append('../../detr')
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from scipy import ndimage
import bisect
from src.eval import infer_visualize
from src.main import get_model
import json
import argparse
import torch.nn as nn
from PIL import Image, ImageOps
import warnings
import torchvision.models as tmd
from torchvision.transforms import transforms
from torchvision.models._utils import IntermediateLayerGetter
from thop import profile


warnings.filterwarnings('ignore')

In [3]:
def get_args():
    """Declare the training/evaluation command-line options and parse ``sys.argv``.

    ``parse_known_args`` is used, so arguments the parser does not recognise
    (e.g. extra flags on a notebook kernel's command line) are ignored instead
    of making argparse exit.

    Returns:
        argparse.Namespace with one attribute per option declared below. Options
        that were not supplied take their declared default, or ``None`` when no
        default is declared.
    """
    # (flag, keyword arguments for add_argument), in registration order.
    option_specs = (
        ('--data_root_dir', dict(help="Root data directory for images and labels")),
        ('--config_file', dict(help="Filepath to the config containing the args")),
        ('--backbone', dict(default='resnet18', help="Backbone for the model")),
        ('--data_type', dict(choices=['detection', 'structure'],
                             default='structure',
                             help="toggle between structure recognition and table detection")),
        ('--model_load_path', dict(help="The path to trained model")),
        ('--load_weights_only', dict(action='store_true')),
        ('--model_save_dir', dict(help="The output directory for saving model params and checkpoints")),
        ('--metrics_save_filepath', dict(help='Filepath to save grits outputs', default='')),
        ('--debug_save_dir', dict(help='Filepath to save visualizations', default='debug')),
        ('--table_words_dir', dict(help="Folder containg the bboxes of table words")),
        ('--mode', dict(choices=['train', 'eval'],
                        default='train',
                        help="Modes: training (train) and evaluation (eval)")),
        ('--debug', dict(action='store_true')),
        ('--device', dict()),
        ('--lr', dict(type=float)),
        ('--lr_drop', dict(type=int)),
        ('--lr_gamma', dict(type=float)),
        ('--epochs', dict(type=int)),
        ('--checkpoint_freq', dict(default=1, type=int)),
        ('--batch_size', dict(type=int)),
        ('--num_workers', dict(type=int)),
        ('--train_max_size', dict(type=int)),
        ('--val_max_size', dict(type=int)),
        ('--test_max_size', dict(type=int)),
        ('--eval_pool_size', dict(type=int, default=1)),
        ('--eval_step', dict(type=int, default=1)),
        ('--overlap', dict(action='store_true')),
        ('--overlap_loss_coef', dict(type=int, default=1)),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    known_args, _unrecognised = parser.parse_known_args()
    return known_args

In [4]:
image_folder = '/home/suqi/dataset/screenshot'

# Optionally test on a random subset of the images: set `k` to an int
# (e.g. 500) to shuffle the file list with a fixed seed and keep the first `k`
# entries; None uses every file in the folder.
k = None
np.random.seed(0)
# Alternative: only files whose name contains 'COL'
# test_list = sorted(name for name in os.listdir(image_folder) if 'COL' in name)
test_list = sorted(os.listdir(image_folder))
if k is not None:
    np.random.shuffle(test_list)
    test_list = test_list[:k]
img_paths = [os.path.join(image_folder, name) for name in test_list]

# Alternative: rebuild the image list from the file names found in the temp
# visualize folder (each name with its last '_' suffix stripped, plus '.jpg').
# img_names = ['_'.join(name.split('_')[:-1]) + '.jpg' for name in os.listdir('/home/suqi/dataset/temp/visualize')]
# img_paths = [os.path.join(image_folder, name) for name in list(set(img_names))]
print(len(img_paths))

31


In [5]:
# Load every image as RGB and pad it with a white border whose width is
# `padding_ratio` times the shorter side. Files that cannot be loaded are
# skipped (best effort), reported, and counted in `errors`.
padding_ratio = 0.2
samples = []
img_paths_filter = []
errors = 0
for path in img_paths:
    try:
        # Context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')
        w, h = image.size
        padding_size = int(min(w, h) * padding_ratio)
        image = ImageOps.expand(image, border=padding_size, fill='white')
        samples.append(image)
        img_paths_filter.append(path)
    except Exception as e:
        errors += 1
        print(f"Exception when loading image {path}: {e}")
# Keep img_paths aligned 1:1 with samples (only successfully loaded files).
img_paths = img_paths_filter
print(errors)

0


In [6]:
args = get_args()
args.config_file = "../structure_config.json"
args.data_type = None  # None: keep the data_type from the JSON config (see merge below)
args.model_load_path = "/home/suqi/model/TATR/finetune/train_finetune_resnet34/model_best.pth"
args.backbone = "resnet34"
args.debug_save_dir = "/home/suqi/dataset/temp/visualize_html_image_resnet34/"
os.makedirs(args.debug_save_dir, exist_ok=True)
args.device = "cuda:3"

# Merge the command-line options into the JSON config: a command-line value
# wins unless it is None, and keys absent from the config are always added.
# Options with a non-None argparse default (e.g. checkpoint_freq, eval_step,
# overlap_loss_coef) therefore always override the config file.
cmd_args = args.__dict__
with open(cmd_args['config_file'], 'rb') as config_fp:
    config_args = json.load(config_fp)
for key, value in cmd_args.items():
    if key not in config_args or value is not None:
        config_args[key] = value
# Expose the merged options as attributes (args.<name>).
args = type('Args', (object,), config_args)

# define and load model
device = torch.device(args.device)
model, _, _ = get_model(args, device)

loading model from checkpoint
load model parameters successfully!


In [7]:
# Run inference over all padded samples; output goes under args.debug_save_dir.
# Originally noted at roughly 3.00 s/it (the run captured below took ~4.7 s/it).
save_dir = args.debug_save_dir
infer_visualize(
    model,
    samples,
    device,
    img_paths,
    save_dir,
    True,  # trailing positional flag, passed through as-is; meaning is defined by src.eval.infer_visualize (TODO confirm)
)

Inferring: : 31it [02:26,  4.73s/it]


In [8]:
# Release the inference model's GPU memory before the next experiment.
import gc

del model
gc.collect()              # collect unreachable objects (e.g. reference cycles) still holding CUDA tensors
torch.cuda.empty_cache()  # hand the now-unused cached blocks back to the driver

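## 分开测试 — split the validation set by file name into `pubset` and `synset` (names containing `CELL`/`COL`) for separate evaluation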

In [None]:
data_root = '/home/suqi/dataset/Pub_Fin_Syn_Union_Clean'
curr_root = '/home/suqi/dataset/temp/Pub_Fin_Syn_split'

# Validation sample names: annotation file names in <data_root>/val without
# the '.xml' extension (non-XML entries are ignored).
val_list = [os.path.splitext(name)[0]
            for name in sorted(os.listdir(os.path.join(data_root, 'val')))
            if name.endswith('.xml')]
# Names containing 'CELL' or 'COL' go to the `syn` subset, all others to `pub`.
pub_list = [name for name in val_list if 'CELL' not in name and 'COL' not in name]
# sorted(): iteration order of a set of str varies between interpreter runs,
# and this order is later written to val_filelist.txt.
syn_list = sorted(set(val_list) - set(pub_list))

In [None]:
def _link_subset(names, src_root, dst_root):
    """Hard-link the images and XML annotations of `names` from `src_root` into
    `dst_root`, then write `dst_root/val_filelist.txt` listing them.

    Expects `<src_root>/images/<name>.jpg` and `<src_root>/val/<name>.xml`.
    Raises FileExistsError (from os.link) if a link target already exists,
    e.g. when the cell is re-run.
    """
    os.makedirs(os.path.join(dst_root, 'images'), exist_ok=True)
    os.makedirs(os.path.join(dst_root, 'val'), exist_ok=True)
    for name in names:
        os.link(os.path.join(src_root, 'images', name + '.jpg'),
                os.path.join(dst_root, 'images', name + '.jpg'))
        os.link(os.path.join(src_root, 'val', name + '.xml'),
                os.path.join(dst_root, 'val', name + '.xml'))
    # Written once, after every link succeeded. Re-opening the file in append
    # mode per sample would accumulate stale entries if the file already existed.
    with open(os.path.join(dst_root, 'val_filelist.txt'), 'w') as f:
        f.writelines(f'val/{name}.xml\n' for name in names)


for subset, names in (('pubset', pub_list), ('synset', syn_list)):
    _link_subset(names, data_root, os.path.join(curr_root, subset))

# NOTE(review): numbers recorded from an earlier run; their meaning is not
# documented (presumably file counts) — confirm.
# 40057 43155
# 45251 45251

## 测试模型计算量

In [None]:
args = get_args()
args.config_file = "../structure_config_res34_large.json"
args.data_type = None  # None: keep the data_type from the JSON config (see merge below)
args.backbone = "resnet34"
# args.model_load_path = "/home/suqi/model/TATR/finetune/20231016093942/model_1.pth"
args.device = "cuda"

# Merge the command-line options into the JSON config: a command-line value
# wins unless it is None, and keys absent from the config are always added.
# Options with a non-None argparse default (e.g. debug_save_dir='debug',
# checkpoint_freq) therefore always override the config file.
cmd_args = args.__dict__
with open(cmd_args['config_file'], 'rb') as config_fp:
    config_args = json.load(config_fp)
for key, value in cmd_args.items():
    if key not in config_args or value is not None:
        config_args[key] = value
# Expose the merged options as attributes (args.<name>).
args = type('Args', (object,), config_args)

os.makedirs(args.debug_save_dir, exist_ok=True)

# define and load model
device = torch.device(args.device)
model, _, _ = get_model(args, device)
model.eval();  # trailing ';' suppresses the (very long) module repr in the cell output
model.eval()

In [None]:
from detr.util.misc import nested_tensor_from_tensor_list

# Sub-module to profile: one of 'model', 'backbone', 'transformer', 'encoder', 'decoder'.
PROFILE_TARGET = 'encoder'

# Inference only: without no_grad the preparatory forward passes below
# (backbone, input_proj, encoder) would also record an autograd graph.
with torch.no_grad():
    x = torch.randn(size=(1, 3, 1024, 1024), device=args.device)

    # backbone: wrap the batch as a NestedTensor (tensors + padding mask)
    if isinstance(x, (list, torch.Tensor)):
        x = nested_tensor_from_tensor_list(x)
    features, pos = model.backbone(x)

    # transformer inputs: last feature map, projected by input_proj
    src_map, mask_map = features[-1].decompose()
    assert mask_map is not None
    src_map = model.input_proj(src_map)

    # encoder / decoder inputs: flatten (bs, c, h, w) -> (h*w, bs, c)
    query_embed = model.query_embed.weight
    pos_embed = pos[-1]
    bs, c, h, w = src_map.shape
    src = src_map.flatten(2).permute(2, 0, 1)
    pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
    query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
    mask = mask_map.flatten(1)

    tgt = torch.zeros_like(query_embed)
    memory = model.transformer.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

    # (module, positional inputs) for each profiling target. The full
    # transformer takes the un-flattened projected map and mask.
    profile_targets = {
        'model': (model, (x,)),
        'backbone': (model.backbone, (x,)),
        'transformer': (model.transformer, (src_map, mask_map, model.query_embed.weight, pos[-1])),
        'encoder': (model.transformer.encoder, (src, None, mask, pos_embed)),
        'decoder': (model.transformer.decoder, (tgt, memory, None, None, None, mask, pos_embed, query_embed)),
    }
    module, inputs = profile_targets[PROFILE_TARGET]

    # thop.profile (imported in the first cell) returns multiply-accumulate
    # counts (MACs), not FLOPs; FLOPs are commonly approximated as 2 x MACs.
    # NOTE(review): thop may not count ops inside nn.MultiheadAttention, so
    # attention-heavy targets are likely under-counted — verify with custom_ops.
    macs, params = profile(module, inputs)
    print('MACs = ' + str(macs / 1000 ** 3) + 'G')
    print('FLOPs ~= ' + str(2 * macs / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')

In [None]:
# Release the profiled model's GPU memory (same cleanup as after the inference section).
import gc

del model
gc.collect()              # collect unreachable objects (e.g. reference cycles) still holding CUDA tensors
torch.cuda.empty_cache()  # hand the now-unused cached blocks back to the driver

## 测试Swin Transformer Backbone — note: the cell below currently loads a ResNet-50, not a Swin Transformer

In [None]:
# Backbone smoke test: wrap a torchvision ResNet-50 with IntermediateLayerGetter
# and print the shape of its layer4 feature map (exposed under key '0') for a
# 1x3x1024x1024 input.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; the warning is hidden by warnings.filterwarnings('ignore').
model = tmd.resnet50(pretrained=True).cuda().eval()
x = torch.randn(size=(1, 3, 1024, 1024)).cuda()

model = IntermediateLayerGetter(model, return_layers={'layer4': '0'})
with torch.no_grad():  # inference only: don't build an autograd graph
    print(model(x)['0'].shape)

# Optional cost estimate (thop.profile returns MACs, not FLOPs; FLOPs ~= 2 x MACs):
# with torch.no_grad():
#     macs, params = profile(model, (x,))
#     print('MACs = ' + str(macs/1000**3) + 'G')
#     print('Params = ' + str(params/1000**2) + 'M')