In [1]:
# config env
# Locations of the offline resources (mounted under /bohr) and the Hugging Face
# model ids loaded later in the notebook.
# pkgs_path = "/bohr/pkgs-7x29/v5/pkgs"
llava_lib_path = "/bohr/libb-bg5b/v3/llava"
# Table Transformer checkpoint used to detect table rows and columns.
tsr_model_path = "microsoft/table-structure-recognition-v1.1-all"

# Small ("help") and large ("main") LLaVA-OneVision checkpoints.
help_model_path = "lmms-lab/llava-onevision-qwen2-0.5b-si"
main_model_path = "lmms-lab/llava-onevision-qwen2-7b-si"
# Used as the Hugging Face cache directory (HF_HOME / HUGGINGFACE_HUB_CACHE below).
cache_path = "/bohr/cach-rxl3/v4/cache"

# Alternative settings for a personal (non-/bohr) environment.
# pkgs_path = "/personal/pkgs"
# llava_lib_path = "/personal/llava"
# model_path = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
# cache_path = "/personal/cache"


# !pip install {pkgs_path}/*
# Copy the llava package into the working directory; presumably so that the
# `from llava...` imports in the next cell resolve to this copy.
!cp {llava_lib_path} . -r
import os

# The submission environment may have no network access; switch the Hugging Face
# libraries to offline mode so that failed network requests do not raise errors.
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
os.environ['HF_HUB_OFFLINE'] = '1'
# NOTE(review): the mirror endpoint presumably only matters if the offline flags
# above are removed.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
# Device that image tensors are moved to before generation.
device = "cuda"

In [2]:
from llava.conversation import Conversation, SeparatorStyle
from llava.utils import disable_torch_init
import json
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
import torch

from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, TableTransformerForObjectDetection

import warnings
from collections import defaultdict
import re
import asyncio

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Generation settings shared by every `model.generate` call, exposed as class
# attributes of an ad-hoc `Args` object (attribute access like argparse results).
# temperature == 0 means greedy decoding; answers are capped at 8 new tokens.
args = type(
    'Args',
    (),
    dict(
        conv_mode=None,
        sep=",",
        temperature=0,
        top_p=1,
        num_beams=1,
        max_new_tokens=8,
    ),
)()

# Option letter -> 0-based index. Any other key yields -1 (and, being a
# defaultdict, is inserted with that value on lookup).
l2i = defaultdict(lambda: -1, {letter: idx for idx, letter in enumerate('ABCDEFGH')})

# Subject names in option-letter order (A..H). The trailing '' is deliberate:
# sub_list[-1] is what an unrecognised letter (index -1) maps to.
sub_list = (
    'Physics',
    'Mathematics',
    'ComputerScience',
    'QuantitativeBiology',
    'QuantitativeFinance',
    'Statistics',
    'ElectricalEngineeringandSystemsScience',
    'Economics',
    '',
)

# Release unused memory cached by PyTorch's CUDA allocator before models are loaded.
torch.cuda.empty_cache()
# LLaVA helper; presumably disables default weight initialisation to speed up
# model loading — TODO confirm against llava.utils.
disable_torch_init()

Loaded LLaVA model: lmms-lab/llava-onevision-qwen2-0.5b-si
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You are using a model of type llava to instantiate a model of type llava_qwen. This is not supported for all configurations of models and can yield errors.
Overwriting config with {'image_aspect_ratio': 'anyres_max_9'}
Loading vision tower: google/siglip-so400m-patch14-384
Model Class: LlavaQwenForCausalLM
Loaded LLaVA model: lmms-lab/llava-onevision-qwen2-7b-si
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You are using a model of type llava to instantiate a model of type llava_qwen. This is not supported for all configurations of models and can yield errors.


In [ ]:
# At submission time the platform sets DATA_PATH_B to the hidden test set
# (A+B leaderboards; same format as the A-leaderboard data, but 5360 images).
# During development, fall back to the A-leaderboard data, which is mounted
# only for debugging. An unset or empty DATA_PATH_B selects the fallback.
base_dir = os.environ.get('DATA_PATH_B') or '/bohr/form-recognition-train-b6y2/v4'

In [ ]:
def clean_out(image_path, out_list):
    """Convert a model's raw answers for one image into a submission record.

    Args:
        image_path: image path, copied verbatim into the record.
        out_list: raw answer strings in question order:
            [shape ("rows, cols"), subject letter (A-H), answer letter (A-D)].
            Missing trailing entries are treated as empty answers.

    Returns:
        dict with keys "image_path", "category", "cols", "rows", "answer".
        Unparseable fields fall back to -1 (rows/cols/answer) or '' (category).
    """
    # Pad so that a short or empty answer list degrades to "unparseable" instead of IndexError.
    shape_text, subject_text, answer_text = (list(out_list or []) + ['', '', ''])[:3]

    matches = re.findall(r"\d+", shape_text or '')
    if len(matches) >= 2:
        # The prompt asks for "rows, cols"; anything after the first two numbers is ignored.
        rows, cols = int(matches[0]), int(matches[1])
    elif len(matches) == 1:
        # A single number is interpreted as a square table.
        rows = cols = int(matches[0])
    else:
        rows = cols = -1

    sub_item = {
        "image_path": image_path,
        # Index -1 selects the trailing '' of sub_list for an unrecognised subject.
        "category": sub_list[_option_index(subject_text, 'ABCDEFGH')],
        "cols": cols,
        "rows": rows,
        # Only four options (A-D) are offered for the multiple-choice question.
        "answer": _option_index(answer_text, 'ABCD'),
    }
    return sub_item


def _option_index(text, letters):
    """Return the index in `letters` of the first standalone option letter in `text`, or -1.

    Accepts free-form replies such as "B", "B.", "(B)" or "B. Mathematics"; letters
    embedded in words (e.g. the "C" of "ComputerScience") are not matched.
    """
    match = re.search(r"\b([%s])\b" % letters, text or '')
    return letters.index(match.group(1)) if match else -1

In [ ]:
class Worker:
    """Pipeline that answers the three questions for every table image in dataset.json.

    Three coroutines, started by `run`, cooperate:

    * `tsr_process` runs the Table Transformer to count table rows/columns and draws
      the detected boxes in red onto the RGB copy of each image;
    * `help_process` asks the small ("help") LLaVA model the shape, subject and
      multiple-choice questions for each image;
    * `main_process` pairs both results item by item, optionally re-asks the questions
      with the large ("main") LLaVA model, and writes submission.json.

    Both producers walk `self.data` in the same order, so results are paired
    first-in-first-out. All model calls are blocking; the coroutines only interleave
    at their `await` points, which keeps few decoded images waiting in memory.
    """

    def __init__(self, use_main_model=False):
        """Load the dataset index, the help LLaVA model and the TSR model.

        Args:
            use_main_model: also load the main LLaVA model (main_model_path) and use it
                to produce the final answers. When False (default), the help model's
                answers are submitted directly.
        """
        with open(os.path.join(base_dir, 'dataset.json'), 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        # FIFO buffers of results waiting for their partner from the other producer;
        # the pairing logic keeps at most one of them non-empty.
        self.tsr_result = []
        self.help_result = []
        # Paired (help, tsr) results consumed by main_process; None marks end of stream.
        self.main_input = asyncio.Queue()
        # Number of producers (tsr_process / help_process) that have finished.
        self._producers_done = 0

        self.help_tokenizer, self.help_model, self.help_image_processor, _ = load_pretrained_model(
            help_model_path, None, "llava_qwen", device_map="auto",
            attn_implementation='sdpa',
            multimodal=True,
            overwrite_config={"image_aspect_ratio": "anyres_max_9"},
        )

        # The main model is optional; main_process falls back to the help answers without it.
        self.tokenizer = self.model = self.image_processor = None
        if use_main_model:
            self.tokenizer, self.model, self.image_processor, _ = load_pretrained_model(
                main_model_path, None, "llava_qwen", device_map="auto",
                attn_implementation='sdpa',
                multimodal=True,
                overwrite_config={"image_aspect_ratio": "anyres_max_9"},
            )

        self.tsr_img_processor = AutoImageProcessor.from_pretrained(tsr_model_path)
        # Resize TSR inputs using the help image processor's size[0].
        self.tsr_img_processor.size['shortest_edge'] = self.help_image_processor.size[0]
        self.tsr_model = TableTransformerForObjectDetection.from_pretrained(tsr_model_path)
        label2id = self.tsr_model.config.label2id
        self.label_row = label2id['table row']
        self.label_col = label2id['table column']

    async def run(self):
        """Run the three stages concurrently until submission.json has been written.

        If any stage raises, the other stages are cancelled (so main_process does not
        wait forever for a sentinel that will never arrive) and the error is re-raised.
        """
        # Reset the pipeline state so run() can be called more than once.
        self.tsr_result = []
        self.help_result = []
        self.main_input = asyncio.Queue()
        self._producers_done = 0
        tasks = [
            asyncio.create_task(self.help_process()),
            asyncio.create_task(self.main_process()),
            asyncio.create_task(self.tsr_process()),
        ]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            for task in tasks:
                task.cancel()
            raise

    async def _producer_finished(self):
        """Mark one producer as done; send the end-of-stream sentinel after the second."""
        self._producers_done += 1
        # Each producer emits one result per item and every result is paired on arrival,
        # so once both have finished nothing is left to pair.
        if self._producers_done == 2:
            await self.main_input.put(None)

    async def tsr_process(self):
        """Count table rows and columns for every item with the Table Transformer.

        Detections kept by post-processing (threshold=0.6) are drawn in red on the RGB
        image. Each (image, rows, cols) result is paired with the oldest waiting help
        result, or buffered in self.tsr_result until one arrives.
        """
        for item in self.data:
            path = os.path.join(base_dir, 'test_images', item["image_path"])
            # Close the file handle once the pixels have been converted.
            with Image.open(path) as raw_image:
                image = raw_image.convert("RGB")
            inputs = self.tsr_img_processor(images=image, return_tensors="pt")
            target_sizes = torch.tensor([image.size[::-1]])  # (height, width) of each image in the batch
            # Inference only: do not build an autograd graph for every image.
            with torch.inference_mode():
                outputs = self.tsr_model(**inputs)
                results = self.tsr_img_processor.post_process_object_detection(
                    outputs, threshold=0.6, target_sizes=target_sizes)[0]
            draw = ImageDraw.Draw(image)
            rows = 0
            cols = 0
            for label, box in zip(results["labels"], results["boxes"]):
                label, box = label.item(), box.tolist()
                draw.rectangle(box, outline="red", width=1)
                if label == self.label_row:
                    rows += 1
                elif label == self.label_col:
                    cols += 1
            tsr_item = (image, rows, cols)
            if self.help_result:
                await self.main_input.put((self.help_result.pop(0), tsr_item))
            else:
                self.tsr_result.append(tsr_item)
            # Yield to the other stages; nothing else in this loop suspends.
            await asyncio.sleep(0)
        await self._producer_finished()

    async def help_process(self):
        """Ask the help LLaVA model the shape, subject and multiple-choice questions.

        Each (item, qs_list, out_list) result is paired with the oldest waiting TSR
        result, or buffered in self.help_result until one arrives.
        """
        for item in self.data:
            path = os.path.join(base_dir, 'test_images', item["image_path"])
            caption = item["caption"]
            with Image.open(path) as raw_image:
                image = raw_image.convert("RGB")
            image_sizes = [image.size]
            images = [image]
            image_tensors = [
                process_images(images, self.help_image_processor, self.help_model.config)[0].to(dtype=torch.float16,
                                                                                                device=device)]
            conv = Conversation(
                system="""<|im_start|>system
                                You are a helpful assistant. Provide only an option's letter or an integer for each question, without any additional explanation.""",
                roles=["<|im_start|>user", "<|im_start|>assistant"],
                version="qwen",
                messages=[
                    ["<|im_start|>user",
                     f'{DEFAULT_IMAGE_TOKEN}\n This is a table image. The caption of the table is "{caption}".'],
                    ["<|im_start|>assistant",
                     "I have a general understanding of the information in this table."]
                ],
                offset=0,
                sep_style=SeparatorStyle.CHATML,
                sep="<|im_end|>",
            )
            qs_list = [
                f'Based on the provided table, what is its shape? Answer with two positive integers for rows and columns, separated by a comma:',
                f"""Based on the provided table and caption, select the most relevant subject from (A. Physics, B. Mathematics, C. ComputerScience, D. QuantitativeBiology, E. QuantitativeFinance, F. Statistics, G. ElectricalEngineeringandSystemsScience, H. Economics). Answer with the option's letter from the given choices directly.""",
                f"""Based on the provided table and caption, for the question: "{item["question"]}", select the most correct option from (A. {item["options"][0]}, B. {item["options"][1]}, C. {item["options"][2]}, D. {item["options"][3]}). Answer with the option's letter from the given choices directly."""
            ]
            out_list = self.one_image(self.help_model, self.help_tokenizer, image_tensors, image_sizes, conv, qs_list)
            print("HELP:", out_list)
            # Carry the dataset item along so main_process knows the image_path.
            help_item = (item, qs_list, out_list)
            if self.tsr_result:
                await self.main_input.put((help_item, self.tsr_result.pop(0)))
            else:
                self.help_result.append(help_item)
            # Yield to the other stages; nothing else in this loop suspends.
            await asyncio.sleep(0)
        await self._producer_finished()

    async def main_process(self):
        """Consume paired results and write the cleaned answers to submission.json.

        With the main model loaded, each question is re-asked with it; otherwise the
        help model's answers are used as-is. Stops at the None sentinel.
        """
        submission = []
        while True:
            pair = await self.main_input.get()
            if pair is None:
                break
            (item, qs_list, out_list), (image, rows, cols) = pair
            if self.model is not None:
                out_list = self._ask_main_model(item["caption"], qs_list, out_list, image, rows, cols)
                print("MAIN:", out_list)
            submission.append(clean_out(item["image_path"], out_list))
        with open('submission.json', 'w') as f:
            json.dump(submission, f)

    def _ask_main_model(self, caption, qs_list, help_out_list, image, rows, cols):
        """Re-ask qs_list with the main model, hinting the TSR shape and the help answers.

        Returns the main model's raw answer strings, one per question.
        """
        image_tensors = [
            process_images([image], self.image_processor, self.model.config)[0].to(dtype=torch.float16,
                                                                                   device=device)]
        conv = Conversation(
            system="""<|im_start|>system
                        You are a helpful assistant. Provide only an option's letter or an integer for each question, without any additional explanation.""",
            roles=["<|im_start|>user", "<|im_start|>assistant"],
            version="qwen",
            messages=[
                ["<|im_start|>user",
                 f'{DEFAULT_IMAGE_TOKEN}\n This is a table image with red borders. The table shape might be ({rows}, {cols}) but could vary. The caption of the table is "{caption}". Besides that, for the following questions, the answer from the llava-0.5 model is {help_out_list}, which you can use as a reference.'],
                ["<|im_start|>assistant", "I have a general understanding of the information in this table."]
            ],
            offset=0,
            sep_style=SeparatorStyle.CHATML,
            sep="<|im_end|>",
        )
        return self.one_image(self.model, self.tokenizer, image_tensors, [image.size], conv, qs_list)

    def one_image(self, model, tokenizer, image_tensors, image_sizes, conv, qs_list):
        """Ask every question in qs_list about one image as a multi-turn conversation.

        `conv` is mutated: each question and its answer are appended, so later
        questions see earlier answers in the prompt.

        Returns:
            The decoded, stripped answer for each question, in order.
        """
        out_list = []
        with torch.inference_mode():
            for qs in qs_list:
                conv.append_message(conv.roles[0], qs)
                # Placeholder assistant turn; filled in with the answer below.
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                input_ids = tokenizer_image_token(
                    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

                output_ids = model.generate(
                    input_ids,
                    images=image_tensors,
                    image_sizes=image_sizes,
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                )
                outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
                conv.messages[-1][-1] = outputs
                out_list.append(outputs)
        return out_list

In [ ]:
# Load the models and run the whole pipeline; top-level `await` works because
# IPython/Jupyter runs cells inside an event loop.
worker = Worker()
await worker.run()