In [1]:
# ---- Environment configuration ----
# Active resource locations used by the rest of the notebook.
pkgs_path = "/bohr/pkgs-7x29/v1/pkgs"
llava_lib_path = "/bohr/libb-bg5b/v1/llava"
cache_path = "/bohr/cache-3bi6/v1/cache"
model_path = "/bohr/1111-ggy7/v1/table-llava-v1.5-7b"
# Alternative checkpoint (disabled):
# model_path = "/bohr/1111-oxpj/v2/llava-v1.6-vicuna-7b"

# Alternative locations under /personal (disabled):
# pkgs_path = "/personal/pkgs"
# llava_lib_path = "/personal/llava"
# model_path = "/personal/model/llava-v1.6-vicuna-7b"
# cache_path = "/personal/cache"

# System prompt handed to the conversation template; split over several
# literals purely for readability (the resulting string is unchanged).
sym_prompt = (
    "A conversation between a data analyst and an AI assistant who is an expert in processing and interpreting table data. "
    "The AI assistant should provide precise and concise answers based on the table contents, demonstrating expertise in understanding and analyzing tabular data. "
    "The assistant should only provide necessary information, avoiding any extra explanations or unrelated content."
)

# One-time setup commands (disabled): install the packages and copy the
# LLaVA library and the cache into the working directory.
# !pip install {pkgs_path}/*
# !cp {llava_lib_path} . -r
# !cp {cache_path} . -r
import os

# Network access may be unavailable at submission time, so switch the
# Hugging Face libraries to offline mode (avoids errors from failed network
# calls) and point their cache locations at the local ./cache directory.
os.environ.update({
    'TRANSFORMERS_OFFLINE': '1',
    'HF_DATASETS_OFFLINE': '1',
    'HF_HUB_OFFLINE': '1',
    "HUGGINGFACE_HUB_CACHE": "./cache",
    "HF_HOME": "./cache",
})

In [2]:
from llava.constants import (
    IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
)
from llava.conversation import Conversation
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)
from llava.conversation import SeparatorStyle

import torch
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Inference settings, exposed as attributes of a throwaway `Args` instance.
args = type('Args', (), dict(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    conv_mode=None,
    sep=",",
    temperature=0,
    top_p=1,
    num_beams=1,
    max_new_tokens=8,
))()

# Free cached CUDA memory and call LLaVA's disable_torch_init() before the
# checkpoint is loaded.
torch.cuda.empty_cache()
disable_torch_init()

# Load tokenizer, model and image processor strictly from local files.
model_name = get_model_name_from_path(args.model_path)
(tokenizer, model, image_processor, context_len) = load_pretrained_model(
    args.model_path,
    args.model_base,
    model_name,
    local_files_only=True,
    cache_dir="./cache",
    # use_flash_attn=True,
    # load_4bit=True
)

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.33s/it]
Some weights of the model checkpoint at /bohr/1111-ggy7/v1/table-llava-v1.5-7b were not used when initializing LlavaLlamaForCausalLM: ['model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.weight', 'model.vision_t

In [4]:
import os
import json
from PIL import Image

# At submission time DATA_PATH_B points to the hidden test set (leaderboards
# A+B); it has the same format as the leaderboard-A data but a different size
# (5360 images). When the variable is unset or empty, fall back to the
# leaderboard-A dataset, which is mounted only during development for debugging.
base_dir = os.environ.get('DATA_PATH_B') or '/bohr/form-recognition-train-b6y2/v4'

In [5]:
# Load the dataset description from <base_dir>/dataset.json into `data`.
# The encoding is given explicitly so that decoding non-ASCII text does not
# depend on the platform's default locale encoding.
with open(os.path.join(base_dir, 'dataset.json'), 'r', encoding='utf-8') as f:
    data = json.load(f)

# with open(os.path.join(base_dir, 'sample_submission.json'), 'r') as f:
#     sub = json.load(f)

In [10]:

def one_image(img_path, qs_list):
    """Ask several questions about one image in a single multi-turn conversation.

    Args:
        img_path: Path of the image file to open.
        qs_list: Sequence of question strings, asked in order. Every reply is
            written back into the conversation before the next question is
            asked, so later prompts contain the earlier turns.

    Returns:
        list[str]: The decoded, stripped model reply for each question, in
        the same order as ``qs_list`` (empty when ``qs_list`` is empty).
    """
    # Convert inside the context manager so the underlying file handle is
    # closed deterministically; `convert` returns an independent image.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    image_sizes = [image.size]
    images_tensor = process_images([image], image_processor, model.config)[0].unsqueeze(0).half().cuda()

    conv = Conversation(
        system=sym_prompt,
        roles=["Human", "Assistant"],
        messages=[],
        offset=2,
        sep_style=SeparatorStyle.SINGLE,
        sep="###",
    )
    out_list = []
    with torch.inference_mode():
        for turn, qs in enumerate(qs_list):
            # Only one image tensor is passed to `generate`, so the image
            # placeholder must appear exactly once in the conversation: on the
            # first turn. Prepending it to every question would leave the
            # accumulated prompt with more placeholders than images.
            if turn == 0:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
            conv.append_message(conv.roles[0], qs)
            # Open assistant slot; filled with the reply after generation.
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

            output_ids = model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=image_sizes,
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=args.max_new_tokens,
                use_cache=True,
            )
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            # Record the reply so that the next turn's prompt includes it.
            conv.messages[-1][1] = outputs
            out_list.append(outputs)
            print(outputs)
    return out_list

In [7]:
def clean_out(image_path, out_list):
    """Turn the four raw model replies for one image into a submission record.

    Args:
        image_path: Value stored unchanged under the ``"image_path"`` key.
        out_list: Indexable with at least four entries, in this order:
            category index, row count, column count, answer option.

    Returns:
        dict: ``{"image_path", "category", "cols", "rows", "answer"}``.
        ``category`` is always one of the eight subject names and ``answer``
        is always in ``0..3``. Any reply that cannot be parsed (or is out of
        range, for category and answer) is replaced by a random value:
        category index in ``0..7``, rows in ``6..16``, cols in ``4..14``,
        answer in ``0..3``.

    Raises:
        IndexError: If ``out_list`` has fewer than four entries.
    """
    response1 = out_list[0]
    response21 = out_list[1]
    response22 = out_list[2]
    response3 = out_list[3]
    sub_list = ('Physics', 'Mathematics', 'ComputerScience', 'QuantitativeBiology', 'QuantitativeFinance', 'Statistics','ElectricalEngineeringandSystemsScience', 'Economics')

    # The model is asked for the serial number of the subject, not its name.
    try:
        category_idx = int(response1)
    except (TypeError, ValueError):
        category_idx = -1
    # Reject indices outside the tuple: a too-large value would raise
    # IndexError below and a negative one would silently wrap around.
    if not 0 <= category_idx < len(sub_list):
        category_idx = random.randint(0, len(sub_list) - 1)

    try:
        rows = int(response21)
    except (TypeError, ValueError):
        rows = random.randint(6, 16)

    try:
        cols = int(response22)
    except (TypeError, ValueError):
        cols = random.randint(4, 14)

    # Only the first character of the reply is interpreted as the option
    # number; IndexError covers an empty reply.
    try:
        answer = int(response3[0])
    except (IndexError, TypeError, ValueError):
        answer = -1
    if not 0 <= answer <= 3:
        answer = random.randint(0, 3)

    sub_item = {
        "image_path": image_path,
        "category": sub_list[category_idx],
        "cols": cols,
        "rows": rows,
        "answer": answer,
    }
    return sub_item

In [8]:
# submission = []

# for item in data:
#     image_path = os.path.join(base_dir, 'test_images', item["image_path"])
#     qs_list = [
#         f'This table caption: "{item["caption"]}". Based on the provided table and description, select the most relevant subject from ([0]Physics, [1]Mathematics, [2]ComputerScience, [3]QuantitativeBiology, [4]QuantitativeFinance, [5]Statistics, [6]ElectricalEngineeringandSystemsScience, [7]Economics), provide the serial number:',
#         'How many rows are in this table? Provide an exact integer:',
#         'How many cols are in this table? Provide an exact integer:',
#         f'Question: "{item["question"]}"\nOptions:\n[0] "{item["options"][0]}"\n[1] "{item["options"][1]}"\n[2] "{item["options"][2]}"\n[3] "{item["options"][3]}"\nSelect the correct option by providing the serial number: (0, 1, 2, or 3)'
#     ]
#     out_list = one_image(image_path, qs_list)
#     sub_item = clean_out(item["image_path"], out_list)
#     submission.append(sub_item)

# with open('submission.json', 'w') as f:
#     json.dump(submission, f)

In [11]:
# Debug run: process the 10 dataset entries starting at index `idx`, print the
# prompts, raw replies and cleaned record for each, show the image, and write
# the collected records to submission.json.
# NOTE(review): only this 10-item slice ends up in submission.json; the loop
# over the full `data` list exists only in commented-out form in an earlier
# cell -- confirm which one is meant for the final submission.
submission = []

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
idx = 100
for item in data[idx: idx + 10]:
    image_path = os.path.join(base_dir, 'test_images', item["image_path"])
    # Four questions per image, in the order clean_out() expects:
    # subject index, row count, column count, multiple-choice answer.
    qs_list = [
        f'This table caption: "{item["caption"]}". Based on the provided table and description, select the most relevant subject from ([0]Physics, [1]Mathematics, [2]ComputerScience, [3]QuantitativeBiology, [4]QuantitativeFinance, [5]Statistics, [6]ElectricalEngineeringandSystemsScience, [7]Economics), provide the serial number:',
        'How many rows are in this table? Provide an exact integer:',
        'How many cols are in this table? Provide an exact integer:',
        f'Question: "{item["question"]}"\nOptions:\n[0] "{item["options"][0]}"\n[1] "{item["options"][1]}"\n[2] "{item["options"][2]}"\n[3] "{item["options"][3]}"\nSelect the correct option by providing the serial number: (0, 1, 2, or 3)'
    ]
    out_list = one_image(image_path, qs_list)
    # The record stores the dataset-relative path, not the joined absolute one.
    sub_item = clean_out(item["image_path"], out_list)
    submission.append(sub_item)
    print(qs_list)
    print(out_list)
    print(sub_item)
    img = mpimg.imread(image_path)

    # Display the image
    plt.imshow(img)
    plt.axis('off')  # hide the axes
    plt.show()

with open('submission.json', 'w') as f:
    json.dump(submission, f)

Error: 