In [1]:
import sys
import os
# Put the directory two levels above sys.path[0] at the front of the import path
# (presumably the repository root, so project-local modules can be imported — confirm).
# NOTE(review): the inserted directory is derived from the *current* sys.path[0], so
# running this cell a second time in the same kernel inserts a different, higher directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(sys.path[0]))))

In [None]:
import os
# Set in the environment before the tokenizer-backed libraries below are imported.
# NOTE(review): presumably read by the Hugging Face tokenizers backend — confirm the accepted values.
os.environ['TOKENIZERS_PARALLELISM'] = 'True'
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image,ImageDraw
import easyocr
import numpy as np
import json
import re
import torch
from tqdm import tqdm
# NOTE(review): Html2BboxTree and BboxTree2Html are imported here and again from `utils`
# a few lines below; the later import rebinds both names, so that version is the one in effect.
from scripts.train.utils import smart_tokenizer_and_embedding_resize,Html2BboxTree,move_to_device,BboxTree2Html,add_special_tokens
# Star import: names used later without a visible definition (SEED, processor_name_or_path,
# bbox_padding) presumably come from here — confirm against vars.py.
from vars import *
from my_dataset import UICoderDataset,UICoderCollater
from transformers import AutoProcessor, Pix2StructForConditionalGeneration,AddedToken
from utils import Html2BboxTree, BboxTree2StyleList, BboxTree2Html
from datasets import Dataset, load_dataset


# Seed torch's RNG; generation in this notebook samples (do_sample=True), so this affects the output.
torch.manual_seed(SEED)

# Device string used both for loading the models and for moving inference inputs.
device = 'cuda:0'
# Checkpoints for the two models: bbox/structure prediction and style prediction.
# bbox_model_path = "/data02/users/lz/code/UICoder/checkpoints/stage1/l2048_p1024_vu_3m*1/checkpoint-200000"
bbox_model_path = "/data02/users/lz/code/UICoder/checkpoints/stage1/l2048_p1024_vu2048_3m*1/checkpoint-50000"
style_model_path = "/data02/users/lz/code/UICoder/checkpoints/stage2/l256_p512_vu_100k*1/checkpoint-50000"

# Parquet file the test samples are read from, and the directory results are written to.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
data_path = '/data02/bbox_v2/data/00002.parquet'
# data_path = '/data02/users/lz/code/UICoder/datasets/WebSight-format-parquet/arrow'
output_dir = '/data02/users/lz/code/UICoder/test_result'

In [None]:
# Shared processor (image patching + tokenizer) for both models.
processor = AutoProcessor.from_pretrained(processor_name_or_path)

# Bbox/structure model, loaded in half precision directly onto `device`.
model_bbox = Pix2StructForConditionalGeneration.from_pretrained(
    bbox_model_path,
    is_encoder_decoder=True,
    device_map=device,
    torch_dtype=torch.float16,
)
# Style model, loaded the same way from its own checkpoint.
model_style = Pix2StructForConditionalGeneration.from_pretrained(
    style_model_path,
    is_encoder_decoder=True,
    device_map=device,
    torch_dtype=torch.float16,
)

# Register the project's special tokens for each model against the shared tokenizer.
# NOTE(review): the tokenizer is passed twice; confirm add_special_tokens still adjusts
# the second model when the tokenizer already contains the tokens.
add_special_tokens(model_bbox, processor.tokenizer)
add_special_tokens(model_style, processor.tokenizer)

# Test samples: the 'train' split of the single parquet file.
ds = load_dataset('parquet', data_files={'train': data_path})['train']
# ds = UICoderDataset(path=data_path,processor=processor,max_length=1024,max_patches=1024,max_num=100,drop_longer=True,stage=1,preprocess=True, make_patches_while_training=True, workers=1)

In [8]:

def drawBboxOnImage(draw: ImageDraw,bbox_node):
    """Outline every box of a bbox tree in red on the given drawing context.

    Args:
        draw: Drawing object exposing ``rectangle`` (created with ``ImageDraw.Draw``).
        bbox_node: Tree node with a ``'bbox'`` entry indexed as (x, y, width, height)
            and a ``'children'`` sequence of nodes of the same shape.

    Nodes are visited parent-first, children in their listed order. Boxes whose
    width or height is not positive are skipped, but their children are still visited.
    """
    pending = [bbox_node]
    while pending:
        current = pending.pop()
        box = current['bbox']
        if box[2] > 0 and box[3] > 0:
            left, top = box[0], box[1]
            draw.rectangle((left, top, left + box[2], top + box[3]), outline="red", width=2)
        # Push children reversed so they are popped in their original order.
        pending.extend(reversed(current['children']))
        
def stickImages(images):
    """Paste the given images side by side onto one new RGB canvas.

    Every image gets a slot as wide as the widest image plus 10 pixels; the canvas
    is as tall as the tallest image. Images are pasted at the top-left of their slot.

    Args:
        images: Non-empty sequence of PIL images.

    Returns:
        The combined PIL image.
    """
    widths = [img.size[0] for img in images]
    heights = [img.size[1] for img in images]

    slot_width = max(widths) + 10
    canvas = Image.new('RGB', (slot_width * len(images), max(heights)))

    for position, img in enumerate(images):
        canvas.paste(img, (slot_width * position, 0))

    return canvas

def remove_bbox(html_content):
    """Return ``html_content`` with every `` bbox=[...]`` attribute removed.

    The match starts at the space before ``bbox=`` and ends at the first ``]``.
    """
    bbox_attribute = re.compile(r' bbox=\[[^\]]*\]')
    return bbox_attribute.sub('', html_content)

def infer_bbox(image):
    """Generate the bbox-annotated HTML for one image with ``model_bbox``.

    Decoding is seeded with the prompt ``<body bbox=[`` and uses sampling
    (``do_sample=True``), so the result depends on torch's RNG state.

    Args:
        image: Image passed to the processor (up to 1024 patches).

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    model_bbox.eval()
    tokenizer = processor.tokenizer
    with torch.no_grad():
        prompt = '<body bbox=['
        prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=True)
        encoding = processor(images=[image], text=[""], max_patches=1024, return_tensors='pt')
        batch = move_to_device(
            {
                # Drop the last token of the encoded prompt (presumably the appended EOS — confirm)
                # so generation continues from the prompt.
                'decoder_input_ids': prompt_ids[..., :-1],
                # The model is loaded in float16, so the patches are cast to match.
                'flattened_patches': encoding['flattened_patches'].half(),
                'attention_mask': encoding['attention_mask'],
            },
            device,
        )

        generated = model_bbox.generate(
            **batch,
            max_new_tokens=2560,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

    return decoded[0]


In [None]:
# Index of the dataset sample the rest of the notebook works on.
idx = 10
# Image of the selected sample.
image = ds[idx]['image']

# Run bbox inference on the image to get the predicted HTML string.
prediction_html = infer_bbox(image)

# Print the token count of the prediction, then display the predicted HTML itself.
print(len(processor.tokenizer.encode(prediction_html)))
prediction_html

In [None]:
# Reference boxes: drawn on a copy of the image from the sample's stored 'bbox' JSON.
aImage = image.copy()
drawBboxOnImage(ImageDraw.Draw(aImage), json.loads(ds[idx]['bbox']))

# Predicted boxes: parsed from the generated HTML and drawn on a second copy.
pBbox = Html2BboxTree(prediction_html, size=image.size)
pImage = image.copy()
drawBboxOnImage(ImageDraw.Draw(pImage), pBbox)

# Reference on the left, prediction on the right.
pair = stickImages([aImage, pImage])
pair

# Style Prediction Test

In [58]:
def predictStyle(image,styleItem):
    """Predict the CSS of a node's children from the node's cropped image.

    Builds a text prompt describing the node and its children, feeds it to
    ``model_style`` as the decoder prefix and parses the generated text.

    Args:
        image: Crop of the page covering ``styleItem``'s bbox.
        styleItem: Dict with ``'type'``, ``'style'``, ``'bbox'`` (x, y, width, height)
            and ``'children'`` (dicts with ``'type'`` and ``'bbox'``).

    Returns:
        List of CSS strings obtained by splitting the model output on ``'%,%'``.
        An empty model answer yields ``['']``.
    """
    parent_x, parent_y = styleItem['bbox'][0], styleItem['bbox'][1]
    # Each child is described as type(dx,dy,width,height), with the offset relative to the
    # parent's origin. The fourth field is the height (bbox[3]); interpolating the whole
    # bbox list there produced a malformed descriptor.
    # NOTE(review): confirm the style model's training prompts use this same layout.
    child_specs = [
        f'{child["type"]}({child["bbox"][0]-parent_x},{child["bbox"][1]-parent_y},{child["bbox"][2]},{child["bbox"][3]})'
        for child in styleItem['children']
    ]
    prompt = f"{styleItem['type']}<{','.join(child_specs)}> %{styleItem['style']}%<%"
    # Drop the last encoded token (presumably the appended EOS — confirm) so the model continues the prompt.
    decoder_input_ids = processor.tokenizer.encode(prompt,return_tensors='pt')[...,:-1]
    encoding = processor(images=[image],text=[""],max_patches=512,return_tensors='pt')
    item = {
        'decoder_input_ids': decoder_input_ids,
        'flattened_patches': encoding['flattened_patches'].half(),
        'attention_mask': encoding['attention_mask']
    }
    item = move_to_device(item,device)
    with torch.no_grad():
        outputs = model_style.generate(**item,max_new_tokens=256,eos_token_id=processor.tokenizer.eos_token_id)
        predictions = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Keep only what follows the second '<' of the decoded text (the prompt itself contains two).
    css = '<'.join(predictions[0].split('<')[2:]).strip()
    # Peel off the wrapping: one leading '%', then one trailing '>' and one trailing '%'.
    if css and css[0] == '%':
        css = css[1:]
    if css and css[-1] == '>':
        css = css[:-1]
    if css and css[-1] == '%':
        css = css[:-1]
    return css.split('%,%')

In [None]:
# Parse the predicted HTML into a bbox tree; this tree is mutated by the cells below
# (styles, text and image references are written into its nodes).
bboxTree = Html2BboxTree(prediction_html, size=image.size)
# Flatten the tree into a list of items; the loops below read 'bbox', 'index', 'type'
# and 'children' from each item. skip_leaf=False presumably keeps leaf nodes in the list — confirm.
indexList = BboxTree2StyleList(bboxTree, skip_leaf=False)

def locateByIndex(bboxTree,index):
    """Return the node reached by following a dash-separated child path.

    Args:
        bboxTree: Root node; every visited node must have a ``'children'`` sequence.
        index: Path such as ``'0-2-1'``. Empty segments are ignored, so ``''``
            returns the root itself.

    Returns:
        The node at the end of the path.
    """
    node = bboxTree
    for segment in index.split('-'):
        if segment:
            node = node['children'][int(segment)]
    return node

# Predict a style for the children of every non-leaf node and store it in bboxTree.
for item in tqdm(indexList):
    bbox = item['bbox']
    index = item['index']

    # Skip leaves and nodes no larger than twice the padding in either dimension.
    if not item['children'] or bbox[2] <= bbox_padding*2 or bbox[3] <= bbox_padding*2:
        continue
    image_crop = image.crop((bbox[0],bbox[1],bbox[0]+bbox[2],bbox[1]+bbox[3]))
    predicted_css = predictStyle(image_crop,item)

    # Pair the predictions with the node's actual children. zip() stops at the shorter
    # sequence, so surplus CSS entries from the model are dropped instead of raising
    # IndexError. The notebook-level sample index `idx` is deliberately not reused as a
    # loop variable here, so re-running earlier cells still addresses the same sample.
    parent = locateByIndex(bboxTree, index)
    for child, css_item in zip(parent['children'], predicted_css):
        child['style'] = css_item

# Apply Text and Images

In [None]:
# OCR reader for the 'ch_sim' and 'en' language models.
reader = easyocr.Reader(['ch_sim','en'])

# Image crops are saved into output_dir inside the loop, so it must exist before the loop starts.
os.makedirs(output_dir, exist_ok=True)

img_idx = 0
for item in tqdm(indexList):
    # Only leaf nodes receive content.
    if len(item['children']):
        continue
    bbox = item['bbox']
    index = item['index']
    if bbox[2] <= 0 or bbox[3] <= 0:
        continue
    target = locateByIndex(bboxTree, index)
    if item['type'] == 'img':
        # Save the exact bbox crop and reference it by file name from the tree node.
        image_crop = image.crop((bbox[0],bbox[1],bbox[0]+bbox[2],bbox[1]+bbox[3]))
        image_crop.save(os.path.join(output_dir,f'{img_idx}.png'))
        target['children'] = [f'{img_idx}.png']
        img_idx += 1
    else:
        # OCR runs on an enlarged greyscale crop around the bbox.
        # NOTE(review): the margin is 5*bbox_padding on the left/top but 10*bbox_padding
        # on the right/bottom — confirm the asymmetry is intended.
        image_crop_text = image.crop((bbox[0]-bbox_padding*5,bbox[1]-bbox_padding*5,bbox[0]+bbox[2]+bbox_padding*10,bbox[1]+bbox[3]+bbox_padding*10)).convert('L')
        result = reader.readtext(np.array(image_crop_text))
        # Join element [1] of every OCR result (the recognised string), one per line.
        text = '\n'.join(entry[1] for entry in result)
        target['children'] = text

In [None]:
# Serialise the annotated tree (styles included) back to HTML.
predicted_html_with_style = BboxTree2Html(bboxTree,style=True)

os.makedirs(output_dir, exist_ok=True)
# Write as UTF-8 explicitly: the OCR step can put Chinese text into the HTML, and the
# platform default encoding is not guaranteed to be able to represent it.
with open(os.path.join(output_dir,'index.html'),'w',encoding='utf-8') as f:
    f.write(predicted_html_with_style)

In [None]:
# Display the flattened item list for inspection.
# NOTE(review): this renders every item; slice it if the output becomes too long.
indexList