In [1]:
# config env
pkgs_path = "/bohr/pkgs-7x29/v5/pkgs"
llava_lib_path = "/bohr/libb-bg5b/v3/llava"
tsr_model_path = "microsoft/table-structure-recognition-v1.1-all"
model_path = "lmms-lab/llava-onevision-qwen2-7b-si"
cache_path = "/bohr/cach-rxl3/v3/cache"

# pkgs_path = "/personal/pkgs"
# llava_lib_path = "/personal/llava"
# model_path = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
# cache_path = "/personal/cache"

# !pip install {pkgs_path}/*
!cp {llava_lib_path} . -r
import os

# # 提交时可能不能联网，设置成离线模式防止联网失败报错
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
device = "cuda"

In [2]:
from llava.conversation import Conversation, SeparatorStyle
from llava.utils import disable_torch_init
import json
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX
import torch

from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
from llava.constants import DEFAULT_IMAGE_TOKEN
import multiprocessing
import warnings
from collections import defaultdict

import re

# Suppress all Python warnings raised from this point on.
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# Ad-hoc attribute bag holding the generation settings.
args = type('Args', (), dict(
    conv_mode=None,
    sep=",",
    temperature=0,
    top_p=1,
    num_beams=1,
    max_new_tokens=4096,
))()

# Answer letter -> zero-based index; every other key falls back to -1.
l2i = defaultdict(lambda: -1)
l2i.update((letter, index) for index, letter in enumerate('ABCDEFGH'))

# Subject names in letter order. The trailing '' is what index -1 (an unknown
# letter) selects.
sub_list = (
    'Physics',
    'Mathematics',
    'ComputerScience',
    'QuantitativeBiology',
    'QuantitativeFinance',
    'Statistics',
    'ElectricalEngineeringandSystemsScience',
    'Economics',
    '',
)

torch.cuda.empty_cache()
disable_torch_init()

# On submission, DATA_PATH_B points at the hidden test set (leaderboards A+B):
# same format as the leaderboard-A data but a different size (5360 images).
# Otherwise fall back to the leaderboard-A data, which is only mounted during
# development for debugging.
base_dir = os.environ.get('DATA_PATH_B') or '/bohr/form-recognition-train-b6y2/v4'

In [4]:
def _extract_column_spec(latex_code):
    """Return the column spec of the first ``tabular`` environment, or ''.

    Braces are matched by depth, so nested arguments such as ``p{3cm}`` do not
    cut the spec short. An optional ``[pos]`` argument before the spec is
    skipped. Unbalanced braces yield ''.
    """
    opening = re.search(r'\\begin\{tabular\}(?:\[[^\]]*\])?\{', latex_code)
    if opening is None:
        return ''
    start = opening.end()
    depth = 1
    for pos in range(start, len(latex_code)):
        char = latex_code[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return latex_code[start:pos]
    return ''


def _count_spec_columns(spec):
    """Count the column letters in a tabular column spec.

    Letters inside braces (widths, ``@{...}`` material) are ignored, and
    ``*{n}{cols}`` repetitions contribute ``n`` times the columns of ``cols``.
    """
    columns = 0

    def _tally_repeat(match):
        # Count arithmetically instead of expanding the string, so a huge
        # repeat count cannot blow up memory.
        nonlocal columns
        columns += int(match.group(1)) * _count_spec_columns(match.group(2))
        return ''

    remainder = re.sub(
        r'\*\s*\{\s*(\d+)\s*\}\s*\{((?:[^{}]|\{[^{}]*\})*)\}',
        _tally_repeat,
        spec,
    )
    depth = 0
    for char in remainder:
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
        elif depth == 0 and char.isalpha():
            columns += 1
    return columns


def count_rows_cols(latex_code):
    """Estimate the number of rows and columns of a LaTeX ``tabular``.

    Columns are counted from the column spec of the first
    ``\\begin{tabular}{...}`` (e.g. ``|l|c|c|`` -> 3); 0 if no spec is found.
    Rows are counted as the ``\\hline``-separated segments that contain a cell
    separator ``&`` or a ``\\rule``.

    Args:
        latex_code: LaTeX source of the table.

    Returns:
        ``(num_rows, num_cols)``, or ``(-1, -1)`` if the input cannot be
        processed (for example, it is not a string).
    """
    try:
        num_cols = _count_spec_columns(_extract_column_spec(latex_code))
        segments = latex_code.split(r'\hline')
        num_rows = sum(1 for segment in segments if '&' in segment or '\\rule' in segment)
        return num_rows, num_cols
    except Exception:
        return -1, -1
    

def clean_out(image_path, outputs):
    """Turn the model's raw answer text into one submission record.

    The text is searched for a JSON object (first ``{`` to last ``}``) that is
    expected to hold the keys ``"LaTex"``, ``"subject"`` and ``"option"``.

    Args:
        image_path: Image identifier copied into the record.
        outputs: Decoded text produced by the model.

    Returns:
        A dict with keys ``image_path``, ``category``, ``cols``, ``rows`` and
        ``answer``. If no JSON object is found, or it cannot be parsed or
        lacks the expected fields, the placeholder record (empty category and
        -1 for the numeric fields) is returned instead.
    """
    fallback = {
        "image_path": image_path,
        "category": "",
        "cols": -1,
        "rows": -1,
        "answer": -1,
    }

    # Greedy match with DOTALL: spans from the first '{' to the last '}'.
    match = re.search(r'{.*}', outputs, re.DOTALL)
    if match is None:
        return fallback

    try:
        data = json.loads(match.group(0))
        rows, cols = count_rows_cols(data["LaTex"])
        # Only the first character of each answer is used as the letter; an
        # unknown letter maps to -1, which selects the '' entry of sub_list.
        return {
            "image_path": image_path,
            "category": sub_list[l2i[data["subject"][0]]],
            "cols": cols,
            "rows": rows,
            "answer": l2i[data["option"][0]],
        }
    except Exception:
        # Best effort: a malformed answer must not abort the whole run.
        return fallback


In [5]:
class Worker:
    """Two-stage inference pipeline over the items of ``dataset.json``.

    A child process runs the table-structure-recognition model on every image,
    draws the detected boxes in red and pushes the annotated image through a
    multiprocessing queue. The parent process consumes the queue, prompts the
    LLaVA model for each item and writes all records to ``submission.json``.
    """

    def __init__(self, show_images=False):
        """Load the dataset description and create the hand-over queue.

        Args:
            show_images: If true, open every annotated image in an external
                viewer after it has been processed (debugging aid).
        """
        with open(os.path.join(base_dir, 'dataset.json'), 'r', encoding='utf-8') as f:
            self.data = json.load(f)
            # self.data = list(json.load(f))[:2]
        self.show_images = show_images
        self.main_input = multiprocessing.Queue()

    def run(self):
        """Start the TSR child process and run the main loop in this process."""
        # NOTE(review): a bound method is used as the Process target, which
        # presumably relies on the fork start method - confirm before running
        # on a spawn-based platform.
        producer = multiprocessing.Process(target=self.tsr_process)
        producer.start()
        try:
            self.main_process()
        except BaseException:
            # The consumer is gone, so stop the producer instead of letting it
            # keep filling a queue nobody reads.
            producer.terminate()
            raise
        finally:
            producer.join()

    def tsr_process(self):
        """Producer: annotate each image with detected table rows/columns.

        Puts ``(image, rows, cols, item)`` on the queue for every dataset item
        and always finishes with a ``None`` sentinel, even when an error
        occurs, so the consumer never blocks forever.
        """
        try:
            tsr_img_processor = AutoImageProcessor.from_pretrained(tsr_model_path)
            tsr_img_processor.size = {'height': 384, 'width': 384}
            tsr_model = TableTransformerForObjectDetection.from_pretrained(tsr_model_path)
            label2id = tsr_model.config.label2id
            label_row = label2id['table row']
            label_col = label2id['table column']
            for item in self.data:
                path = os.path.join(base_dir, 'test_images', item["image_path"])
                with Image.open(path) as raw_image:
                    image = raw_image.convert("RGB")
                inputs = tsr_img_processor(images=image, return_tensors="pt")
                # Inference only: skip autograd bookkeeping.
                with torch.no_grad():
                    outputs = tsr_model(**inputs)

                # (height, width) of each image in the batch
                target_sizes = torch.tensor([image.size[::-1]])
                results = tsr_img_processor.post_process_object_detection(
                    outputs, threshold=0.6, target_sizes=target_sizes)[0]
                draw = ImageDraw.Draw(image)
                rows = 0
                cols = 0
                for label, box in zip(results["labels"], results["boxes"]):
                    label, box = label.item(), box.tolist()
                    draw.rectangle(box, outline="red", width=1)
                    if label == label_row:
                        rows += 1
                    elif label == label_col:
                        cols += 1
                self.main_input.put((image, rows, cols, item))
        finally:
            self.main_input.put(None)

    @staticmethod
    def _build_prompt(item, image_token):
        """Build the user prompt for one dataset item.

        Args:
            item: Dataset entry with the keys ``caption``, ``question`` and
                ``options`` (at least four options).
            image_token: Placeholder token that marks the image position.

        Returns:
            The prompt text. It ends with a JSON example wrapped in backticks.
        """
        latex = r'\begin{tabular}{|c|c|c|} \hline 1 & 2 & 3 \\ \hline 4 & 5 & 6 \\ \hline \end{tabular}'
        options = item["options"]
        # The example is assembled from plain strings so its literal braces are
        # not parsed as f-string fields. json.dumps escapes the LaTeX
        # backslashes, which makes the example itself valid JSON.
        example = (
            '{\n'
            f'  "LaTex": {json.dumps(latex)},\n'
            '  "subject": "B",\n'
            '  "option": "C"\n'
            '}'
        )
        return (
            f'{image_token}\n This is a table image with red borders. '
            f'The caption of the table is "{item["caption"]}". Following are three tasks:\n'
            '1. Convert this table to LaTex.\n'
            '2. Based on the provided table, caption and LaTex, select the most relevant subject '
            'to the table from (A. Physics, B. Mathematics, C. ComputerScience, '
            'D. QuantitativeBiology, E. QuantitativeFinance, F. Statistics, '
            'G. ElectricalEngineeringandSystemsScience, H. Economics).\n'
            f'3. Based on the provided table, caption and LaTex, for the question: "{item["question"]}", '
            f'select the most correct option from (A. {options[0]}, B. {options[1]}, '
            f'C. {options[2]}, D. {options[3]}).\n'
            f'- Answer in the Json format. Example: `{example}`'
        )

    def main_process(self):
        """Consumer: query the LLaVA model per item and write the submission.

        Reads from the queue until the ``None`` sentinel arrives, then dumps
        all collected records to ``submission.json``.
        """
        tokenizer, model, image_processor, _ = load_pretrained_model(
            model_path, None, "llava_qwen", device_map="auto",
            attn_implementation='sdpa',
            # load_8bit=True,
            # load_4bit=False,
            multimodal=True,
            overwrite_config={"image_aspect_ratio": "anyres_max_9"},
        )
        submission = []
        while True:
            payload = self.main_input.get()
            if payload is None:
                break

            # rows/cols from the TSR stage are received but not used here.
            image, rows, cols, item = payload
            image_sizes = [image.size]
            image_tensors = [
                process_images([image], image_processor, model.config)[0].to(
                    dtype=torch.float16, device=device)
            ]
            qs = self._build_prompt(item, DEFAULT_IMAGE_TOKEN)
            conv = Conversation(
                # Written on one line so no source indentation leaks into the
                # system prompt.
                system="<|im_start|>system\nYou are a helpful assistant.",
                roles=["<|im_start|>user", "<|im_start|>assistant"],
                version="qwen",
                messages=[],
                offset=0,
                sep_style=SeparatorStyle.CHATML,
                sep="<|im_end|>",
            )
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(
                prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0).to(device)
            print(prompt)
            output_ids = model.generate(
                input_ids,
                images=image_tensors,
                image_sizes=image_sizes,
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=args.max_new_tokens,
                use_cache=True,
            )
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            sub_item = clean_out(item["image_path"], outputs)
            print(outputs)
            print("MAIN", sub_item)
            if self.show_images:
                image.show()
            submission.append(sub_item)
        with open('submission.json', 'w') as f:
            json.dump(submission, f)

In [6]:
# Build the worker (this reads dataset.json from base_dir) and run the whole
# pipeline; the results end up in submission.json.
worker = Worker()
worker.run()

Loaded LLaVA model: lmms-lab/llava-onevision-qwen2-7b-si
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You are using a model of type llava to instantiate a model of type llava_qwen. This is not supported for all configurations of models and can yield errors.
Overwriting config with {'image_aspect_ratio': 'anyres_max_9'}
Loading vision tower: google/siglip-so400m-patch14-384
Loading checkpoint shards: 100%|██████████| 4/4 [00:09<00:00,  2.40s/it]
Model Class: LlavaQwenForCausalLM
<|im_start|>system
                    You are a helpful assistant.<|im_end|>
<|im_start|>user
<image>
 This is a table image with red borders. The caption of the table is "Finding the MD4-39 preimages for 500 randomly generated 128-bit Boolean vectors. Instances with preimages are satisfiable, while those with no preimages are unsatisfiable.". Following are three tasks:
            1. Convert this table to LaTex.
            2. Based on th

Error: 