In [1]:
# config env
pkgs_path = "/bohr/pkgs-7x29/v5/pkgs"
llava_lib_path = "/bohr/libb-bg5b/v3/llava"
tsr_model_path = "microsoft/table-structure-recognition-v1.1-all"
model_path = "lmms-lab/llava-onevision-qwen2-7b-si"
cache_path = "/bohr/cach-rxl3/v3/cache"

# pkgs_path = "/personal/pkgs"
# llava_lib_path = "/personal/llava"
# model_path = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
# cache_path = "/personal/cache"

# !pip install {pkgs_path}/*
!cp {llava_lib_path} . -r
import os

# # 提交时可能不能联网，设置成离线模式防止联网失败报错
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
device = "cuda"

In [2]:
from llava.utils import disable_torch_init
import json
import torch
import warnings
from collections import defaultdict
import re
import sglang as sgl
from sglang.lang.chat_template import get_chat_template

warnings.filterwarnings("ignore")

In [3]:
# Generation settings from the earlier LLaVA `model.generate` pipeline; the SGLang code
# below references them only from commented-out code.
args = type('Args', (), {
    "conv_mode": None,
    "sep": ",",
    "temperature": 0,
    "top_p": 1,
    "num_beams": 1,
    "max_new_tokens": 4096
})()

# Option letter -> 0-based index; any other key maps to -1 through the defaultdict factory.
l2i = defaultdict(lambda: -1)
for i, letter in enumerate('ABCDEFGH'):
    l2i[letter] = i
# Subject names in A-H order. The trailing '' is intentional: an unrecognised letter gives
# l2i[...] == -1, and sub_list[-1] then yields an empty category.
sub_list = ('Physics', 'Mathematics', 'ComputerScience', 'QuantitativeBiology', 'QuantitativeFinance',
            'Statistics', 'ElectricalEngineeringandSystemsScience', 'Economics', '')
torch.cuda.empty_cache()
# NOTE(review): helper from llava.utils; presumably turns off default weight initialisation
# to speed up model construction -- verify against the copied llava sources.
disable_torch_init()

# At submission time DATA_PATH_B is set to the hidden test set (A+B leaderboard); its format
# matches the A-leaderboard data but it contains a different number of images (5360).
if os.environ.get('DATA_PATH_B'):
    base_dir = os.environ.get('DATA_PATH_B')
else:
    base_dir = '/bohr/form-recognition-train-b6y2/v4'  # development only: A-leaderboard data mounted for debugging

In [4]:
# Horizontal-rule commands: they separate rows visually but are not rows themselves.
_RULE_RE = re.compile(
    r'\\(?:hline|toprule|midrule|bottomrule|hdashline)(?:\s*\[[^\]]*\])?'
    r'|\\cline\s*\{[^{}]*\}'
    r'|\\cmidrule\s*(?:\[[^\]]*\])?\s*(?:\([^)]*\))?\s*\{[^{}]*\}'
)
# Tabular environments whose column spec is parsed; tabular* / tabularx take a width argument first.
_TABULAR_RE = re.compile(r'\\begin\{(tabular\*?|tabularx)\}')


def _braced_group(text, start):
    """Return (content, end_index) of the {...} group at text[start:], skipping leading whitespace.

    Nested braces are honoured. Returns (None, start) when no complete group starts there.
    """
    pos = start
    while pos < len(text) and text[pos].isspace():
        pos += 1
    if pos >= len(text) or text[pos] != '{':
        return None, start
    depth = 0
    for end in range(pos, len(text)):
        if text[end] == '{':
            depth += 1
        elif text[end] == '}':
            depth -= 1
            if depth == 0:
                return text[pos + 1:end], end + 1
    return None, start


def _count_spec_columns(spec):
    """Count the columns declared by a column spec such as ``|l|p{2cm}|*{3}{c}|``."""
    # Expand *{n}{cols} repetitions first so the repeated columns are counted n times.
    while True:
        rep = re.search(r'\*\s*\{(\d+)\}', spec)
        if rep is None:
            break
        repeated, end = _braced_group(spec, rep.end())
        if repeated is None:
            break
        spec = spec[:rep.start()] + repeated * int(rep.group(1)) + spec[end:]
    # Drop arguments of column types/decorators (p{3cm}, @{}, >{\centering}, S[...]) so the
    # letters inside them are not mistaken for column types; innermost groups first.
    previous = None
    while previous != spec:
        previous = spec
        spec = re.sub(r'\{[^{}]*\}', '', spec)
    spec = re.sub(r'\[[^\]]*\]', '', spec)
    return sum(1 for ch in spec if ch.isalpha())


def _split_rows(body):
    r"""Split tabular body text at top-level row terminators (``\\`` or ``\tabularnewline``).

    Line breaks nested inside braces, e.g. ``\makecell{a \\ b}``, do not end a row.
    """
    rows = []
    depth = 0
    start = 0
    pos = 0
    while pos < len(body):
        if depth == 0 and body.startswith('\\\\', pos):
            rows.append(body[start:pos])
            pos += 2
            start = pos
        elif depth == 0 and body.startswith('\\tabularnewline', pos):
            rows.append(body[start:pos])
            pos += len('\\tabularnewline')
            start = pos
        elif body[pos] == '\\':
            # Skip the backslash and the next character: keeps \{ and \} from changing the
            # depth and steps over a nested \\ in one go.
            pos += 2
        else:
            if body[pos] == '{':
                depth += 1
            elif body[pos] == '}' and depth > 0:
                depth -= 1
            pos += 1
    rows.append(body[start:])
    return rows


def count_rows_cols(latex_code):
    r"""Estimate the (rows, cols) of the first tabular environment in generated LaTeX.

    cols is read from the column spec, ignoring braced/bracketed arguments such as the
    width in ``p{2cm}`` and expanding ``*{n}{...}`` repeats. rows counts the top-level
    ``\\``-separated rows that contain something besides horizontal-rule commands, so
    several rows between two ``\hline``s are counted individually.

    Args:
        latex_code: model output expected to contain a ``\begin{tabular}...`` block.

    Returns:
        (num_rows, num_cols). num_cols is 0 when no tabular header is found, in which case
        rows are counted over the whole text. (-1, -1) if the input cannot be parsed
        (e.g. it is not a string).
    """
    try:
        num_cols = 0
        body = latex_code
        header = _TABULAR_RE.search(latex_code)
        if header:
            env = header.group(1)
            pos = header.end()
            # Optional vertical-alignment argument, e.g. \begin{tabular}[t]{ll}.
            align = re.match(r'\s*\[[^\]]*\]', latex_code[pos:])
            if align:
                pos += align.end()
            if env != 'tabular':
                # tabular* / tabularx: skip the width argument that precedes the spec.
                _, pos = _braced_group(latex_code, pos)
            spec, spec_end = _braced_group(latex_code, pos)
            if spec is not None:
                num_cols = _count_spec_columns(spec)
                pos = spec_end
            end = latex_code.find('\\end{' + env + '}', pos)
            # A truncated generation may lack \end{...}; then use everything after the spec.
            body = latex_code[pos:] if end == -1 else latex_code[pos:end]

        num_rows = 0
        for row in _split_rows(body):
            row = re.sub(r'^\s*\[[^\]]*\]', '', row)  # drop the spacing argument of \\[2pt]
            if _RULE_RE.sub('', row).strip():
                num_rows += 1
        return num_rows, num_cols
    except Exception:
        return -1, -1

In [5]:
@sgl.function
def one_image(s, img_path, caption, q3):
    """SGLang program for one table image: transcribe it to LaTeX, then answer two choice questions.

    ``@sgl.function`` binds the program state to the first parameter, which SGLang requires to
    be named ``s``; the remaining parameters are filled positionally from each tuple passed to
    ``one_image.run_batch``.

    Args:
        s: SGLang program state, supplied by the runtime.
        img_path: image file name relative to ``<base_dir>/test_images``. An absolute path also
            works, since ``os.path.join`` discards the earlier components in that case.
        caption: table caption quoted in the first user turn.
        q3: fully formatted multiple-choice question for the final turn.

    Generated variables read back from the state: ``"LaTex"`` (free text), ``"subject"``
    (one of "A".."H") and ``"option"`` (one of "A".."D").
    """
    img_path = os.path.join(base_dir, 'test_images', img_path)
    # Turn 1: image + caption, followed by a canned acknowledgement from the assistant.
    s += sgl.user(
        sgl.image(img_path) +
        f'This is a table image. This is a table image with red borders. The caption of the table is "{caption}".')
    s += sgl.assistant("I have a general understanding of the information in this table.")
    # Turn 2: free-form LaTeX transcription with greedy decoding.
    s += sgl.user("Convert this table to LaTex.")
    s += sgl.assistant(sgl.gen("LaTex", temperature=0.0, top_p=1))
    # Turn 3: subject classification, constrained to the eight option letters.
    s += sgl.user(
        "Based on the provided table, caption and LaTex, select the most relevant subject to the table from (A. Physics, B. Mathematics, C. ComputerScience, D. QuantitativeBiology, E. QuantitativeFinance, F. Statistics, G. ElectricalEngineeringandSystemsScience, H. Economics). Answer with the option's letter from the given choices directly.")
    s += sgl.assistant(
        sgl.gen("subject", choices=["A", "B", "C", "D", "E", "F", "G", "H"], max_tokens=2, temperature=0.0,
                top_p=1))
    # Turn 4: the dataset's multiple-choice question, constrained to A-D.
    s += sgl.user(q3)
    s += sgl.assistant(sgl.gen("option", choices=["A", "B", "C", "D"], max_tokens=2, temperature=0.0, top_p=1))


class Worker:
    """Runs the SGLang ``one_image`` program over ``dataset.json`` and writes ``submission.json``.

    Each dataset item produces one submission row: the predicted subject category, the
    table's row/column counts estimated from the generated LaTeX, and the chosen option index.
    """

    def __init__(self, batch_size=1):
        """Load the dataset index.

        Args:
            batch_size: number of items handed to ``one_image.run_batch`` per call
                (default 1, the previously hard-coded value).
        """
        # Each item is read for the keys "image_path", "caption", "question" and "options".
        with open(os.path.join(base_dir, 'dataset.json'), 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.batch_size = batch_size
        self.submission = []

    def run(self):
        """Launch the SGLang runtime, process the whole dataset, then shut the runtime down.

        Shutdown happens in ``finally`` so a failure mid-dataset does not leave the launched
        runtime (with the model loaded) running in the background.
        """
        model_overide_args = {
            "attn_implementation": "eager",
            "multimodal": True,
            "overwrite_config": {
                "image_aspect_ratio": "anyres_max_9"
            }
        }
        runtime = sgl.Runtime(model_path=model_path, model_overide_args=model_overide_args)
        try:
            runtime.endpoint.chat_template = get_chat_template("qwen")
            sgl.set_default_backend(runtime)
            self.process()
        finally:
            runtime.shutdown()

    @staticmethod
    def _question_prompt(item):
        """Format the multiple-choice prompt for one dataset item (reads options[0]..options[3])."""
        question = item["question"]
        options = item["options"]
        return (
            f'Based on the provided table, caption and LaTex, for the question: "{question}", '
            f'select the most correct option from (A. {options[0]}, B. {options[1]}, '
            f'C. {options[2]}, D. {options[3]}). '
            "Answer with the option's letter from the given choices directly."
        )

    def process(self):
        """Run every dataset item through ``one_image`` in batches and write submission.json."""
        pending = []
        for item in self.data:
            # The dataset-relative image path is passed through unchanged: one_image resolves
            # it under base_dir/test_images, and the submission row records it as-is.
            pending.append((item["image_path"], item["caption"], self._question_prompt(item)))
            if len(pending) == self.batch_size:
                self.batch(pending)
                pending = []
        if pending:
            # Flush the final partial batch (non-empty when len(data) % batch_size != 0).
            self.batch(pending)
        with open('submission.json', 'w') as f:
            json.dump(self.submission, f)

    def batch(self, batch_images):
        """Run ``one_image`` on (image_path, caption, question) tuples and record each result."""
        states = one_image.run_batch(batch_images)
        for (img_path, _caption, _question), state in zip(batch_images, states):
            self.clean_out(img_path, state)

    def clean_out(self, img_path, s):
        """Turn one finished SGLang state into a submission row and append it.

        Each field falls back to its own default ('' or -1) independently, so a failure while
        reading one generated variable does not discard the others.
        """
        try:
            rows, cols = count_rows_cols(s["LaTex"])
        except Exception:
            rows, cols = -1, -1
        try:
            # l2i maps an unknown letter to -1, which selects the trailing '' entry of sub_list.
            category = sub_list[l2i[s["subject"][0]]]
        except Exception:
            category = ""
        try:
            answer = l2i[s["option"][0]]
        except Exception:
            answer = -1
        sub_item = {
            "image_path": img_path,
            "category": category,
            "cols": cols,
            "rows": rows,
            "answer": answer,
        }
        print(sub_item)
        self.submission.append(sub_item)

In [6]:
# Entry point: load dataset.json, run inference on every item, and write submission.json.
worker = Worker()
worker.run()