In [None]:
# Install the packages this notebook relies on into the current environment.
# Only setuptools and accelerate carry a version constraint; everything else
# resolves to whatever pip selects at install time.
!pip install bitsandbytes
!pip install tiktoken
!pip install blobfile
!pip install torch
!pip install docling_core
!pip install transformers
!pip install --upgrade pip
# NOTE(review): the build tooling below is installed right before flash-attn,
# presumably because --no-build-isolation makes the flash-attn build use the
# packages already present in the environment — confirm before reordering.
!pip install setuptools==65.5.0 wheel ninja
!pip install flash-attn --no-build-isolation
!pip install "accelerate>=0.26.0"

In [None]:
import os

# Configure PyTorch's CUDA allocator through its environment variable.
# NOTE(review): this cell sits ahead of the model-loading cell, presumably so
# the setting is already in place when PyTorch first touches CUDA — confirm.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

In [None]:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

# Report which device is available: CUDA when present, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using Device:", device)

# bitsandbytes settings: 4-bit NF4 weights with double quantization and
# float16 as the compute dtype.
quant_settings = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
quant_config = BitsAndBytesConfig(**quant_settings)

# Hub identifier of the vision-language model that answers the questions.
model_path = "moonshotai/Kimi-VL-A3B-Instruct"

# Placement is delegated to device_map="auto", so no explicit .to(device)
# call follows (the original author notes bitsandbytes requires this).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=quant_config,
)

# Matching processor for the same checkpoint.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

from transformers import AutoProcessor as DocProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image

# SmolDocling checkpoint used to turn a figure image into doctags.
doc_model_path = "ds4sd/SmolDocling-256M-preview"
doc_processor = DocProcessor.from_pretrained(doc_model_path)

# Request FlashAttention 2 only when CUDA is available; otherwise use the
# eager attention implementation.
if torch.cuda.is_available():
    doc_attn_implementation = "flash_attention_2"
else:
    doc_attn_implementation = "eager"

doc_model = AutoModelForVision2Seq.from_pretrained(
    doc_model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    _attn_implementation=doc_attn_implementation,
)

In [None]:
import gc

# Upper bound on newly generated tokens per answer; structured_qa passes it
# to model.generate as max_new_tokens.
MAX_TOKENS = 512
def generate_doctags(image_path, max_new_tokens=8192):
    """Generate structured document tags from image

    Loads the image, sends it to the SmolDocling model together with the
    instruction "Convert this page to docling." and returns only the newly
    generated text (the prompt tokens are sliced off before decoding).

    Args:
        image_path: Location of the image; passed unchanged to ``load_image``.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``doc_model.generate``. Defaults to 8192, the value that used to
            be hard-coded, so existing callers are unaffected.

    Returns:
        str: The decoded generation with special tokens skipped and
        surrounding whitespace stripped.
    """
    image = load_image(image_path)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Convert this page to docling."}
            ]
        },
    ]
    prompt = doc_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = doc_processor(text=prompt, images=[image], return_tensors="pt").to(doc_model.device)
    generated_ids = doc_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs.input_ids.shape[1]
    trimmed_generated_ids = generated_ids[:, prompt_length:]
    return doc_processor.batch_decode(trimmed_generated_ids, skip_special_tokens=True)[0].strip()
# Maximum number of images whose doctags are kept in memory at once.
DOCTAGS_CACHE_MAX_ENTRIES = 50

# image_path -> doctags string, so several questions about the same figure
# do not re-run the SmolDocling model.
doctags_cache = {}


def _get_cached_doctags(image_path):
    """Return the doctags for image_path, generating and caching them on a miss."""
    if image_path in doctags_cache:
        doctags = doctags_cache[image_path]
    else:
        doctags = generate_doctags(image_path)
        doctags_cache[image_path] = doctags
    # dicts preserve insertion order, so the first key is the oldest entry.
    if len(doctags_cache) > DOCTAGS_CACHE_MAX_ENTRIES:
        doctags_cache.pop(next(iter(doctags_cache)))
    return doctags


def structured_qa(image_path, question, caption):
    """Generate answer with structured doctags context

    Builds a prompt from the image's doctags, its caption and the question,
    runs the vision-language model on the image plus that prompt, and extracts
    the final answer from the generated text.

    Args:
        image_path: Path of the figure image; used as the doctags cache key
            and opened with ``Image.open``.
        question: Question text inserted into the prompt.
        caption: Figure caption inserted into the prompt.

    Returns:
        str: The stripped text following the last "FINAL_ANSWER:" marker in
        the model output, or the whole stripped output when the marker is
        absent.
    """
    doctags = _get_cached_doctags(image_path)

    # Prepare prompt with strict answer format
    structured_prompt = f"""### Context ###
Image Structure Analysis:
{doctags}

Image Caption:
{caption}

### Question ###
{question}

### Instructions ###
1. Analyze the doctags structure from the image
2. Identify relevant sections
3. Think step by step with the image, question and doctags as context
4. Provide ONLY the final answer in this exact format:
REASONING_STEPS: [your reasoning steps here]
FINAL_ANSWER: [your concise answer here]"""

    messages = [
        {"role": "user", "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": structured_prompt}
        ]}
    ]

    # Render the chat template once and reuse the result; it was previously
    # rendered twice with the first result thrown away.
    text = processor.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    # Image.open keeps the file open until it is closed; the context manager
    # releases the handle as soon as the processor has produced its tensors,
    # so a long evaluation loop does not accumulate open files.
    with Image.open(image_path) as image:
        inputs = processor(
            images=image,
            text=text,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    # Decode only the continuation, not the echoed prompt tokens.
    response = processor.batch_decode(
        generated_ids[:, inputs.input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    # Extract final answer
    final_answer = response.split("FINAL_ANSWER:")[-1].strip()
    return final_answer

In [None]:
# !git clone https://github.com/HJYao00/Mulberry.git
# %cd Mulberry

In [None]:
# from PIL import Image
# import torch
# from transformers import (
#     AutoProcessor as DocProcessor,
#     AutoModelForVision2Seq,
#     LlavaNextProcessor,
#     LlavaNextForConditionalGeneration,
#     BitsAndBytesConfig
# )
# from transformers.image_utils import load_image

# # Quantization configuration for efficient inference
# quant_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True
# )

# # Initialize SmolDocling model for document structure analysis
# doc_model_path = "ds4sd/SmolDocling-256M-preview"
# doc_processor = DocProcessor.from_pretrained(doc_model_path)
# doc_model = AutoModelForVision2Seq.from_pretrained(
#     doc_model_path,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     # _attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager",
# )

# # Initialize Mulberry model
# mulberry_model_path = 'HuanjinYao/Mulberry_llava_8b'
# mulberry_processor = LlavaNextProcessor.from_pretrained(mulberry_model_path)
# # mulberry_processor.num_additional_image_tokens = 1

# mulberry_model = LlavaNextForConditionalGeneration.from_pretrained(
#     mulberry_model_path,
#     quantization_config=quant_config,
#     device_map='auto'
# )

# mulberry_processor.num_additional_image_tokens = 1

In [None]:
# def generate_doctags(image_path):
#     """Generate structured document tags from image"""
#     image = load_image(image_path)
#     messages = [
#         {
#             "role": "user",
#             "content": [
#                 {"type": "image"},
#                 {"type": "text", "text": "Convert this page to docling."}
#             ]
#         },
#     ]
#     prompt = doc_processor.apply_chat_template(messages, add_generation_prompt=True)
#     inputs = doc_processor(text=prompt, images=[image], return_tensors="pt").to(doc_model.device)
#     generated_ids = doc_model.generate(**inputs, max_new_tokens=8192)
#     prompt_length = inputs.input_ids.shape[1]
#     trimmed_generated_ids = generated_ids[:, prompt_length:]
#     return doc_processor.batch_decode(trimmed_generated_ids, skip_special_tokens=True)[0].strip()
# doctags_cache = {}
# def output_process(answer):
#     """Clean up model output to extract relevant parts"""
#     if "<s>" in answer:
#         answer = answer.replace("<s>", "").strip()
#     if "[/INST]" in answer:
#         answer = answer.split("[/INST]")[1].strip()
#     elif "ASSISTANT:" in answer:
#         answer = answer.split("ASSISTANT:")[1].strip()
#     elif "assistant\n" in answer:
#         answer = answer.split("assistant\n")[1].strip()
#     elif "<|end_header_id|>\n\n" in answer:
#         answer = answer.split("<|end_header_id|>\n\n")[2].strip()

#     if "</s>" in answer:
#         answer = answer.split("</s>")[0].strip()
#     elif "<|im_end|>" in answer:
#         answer = answer.split("<|im_end|>")[0].strip()
#     elif "<|eot_id|>" in answer:
#         answer = answer.split("<|eot_id|>")[0].strip()
#     return answer

# def structured_qa_with_mulberry(image_path, question, only_output_final_answer=False):
#     """Generate answer with structured doctags context using Mulberry model"""
#     # Generate doctags

#     if image_path in doctags_cache:
#         doctags = doctags_cache[image_path]
#     else:
#         doctags = generate_doctags(image_path)
#         doctags_cache[image_path] = doctags
#     if len(doctags_cache) > 50:
#         doctags_cache.pop(next(iter(doctags_cache)))

#     # doctags = generate_doctags(image_path)

#     # Prepare prompt with doctags and structured format
#     structured_prompt = f"""Generate an image description based on the question.
# Then, provide a rationale to analyze the question.
# Next, generate a step-by-step reasoning process to solve the problem. Ensure the steps are logical and concise.
# Finally, provide a concise summary of the final answer in the following format: 'The final answer is: xxx'. If the question is multiple-choice, provide the options along with their content. If it is free-form, directly present the final result. Do not provide any explanation.

# ### Context ###
# Image Structure Analysis:
# {doctags}

# Format your response with the following sections, separated by ###:
# ### Image Description:
# ### Rationales:
# ### Let's think step by step.
# ### Step 1:
# ### Step 2:
# ...
# ### The final answer is:

# {question}"""

#     # Get Mulberry VLM response
#     images = [Image.open(image_path).convert("RGB")]
#     content = [
#         {"type": 'text', "text": structured_prompt},
#         {"type": "image"}
#     ]

#     conversation = [
#         {
#             "role": "user",
#             "content": content,
#         }
#     ]

#     # Generate response
#     prompt = mulberry_processor.apply_chat_template(
#         conversation, add_generation_prompt=True
#     )
#     inputs = mulberry_processor(prompt, images, return_tensors="pt").to(
#         mulberry_model.device
#     )

#     kwargs = dict(
#         do_sample=False,
#         temperature=0.2,
#         max_new_tokens=512,
#         top_p=None,
#         num_beams=1,
#         repetition_penalty=1.0
#     )

#     output = mulberry_model.generate(**inputs, **kwargs)
#     answer = mulberry_processor.decode(output[0], skip_special_tokens=True)
#     answer = output_process(answer)

#     # Extract only final answer if requested
#     if only_output_final_answer:
#         if len(answer.split('### The final answer is:')) == 2:
#             answer = answer.split('### The final answer is:')[-1].strip()

#     return answer


In [None]:
import zipfile

zip_path = "./SPIQA_testA_Images.zip"
extract_path = ""

# Unpack every member of the archive; with an empty extract_path the member
# paths are resolved relative to the current working directory.
with zipfile.ZipFile(zip_path, mode='r') as zip_ref:
    zip_ref.extractall(path=extract_path)

In [None]:
import numpy as np
import pandas as pd
import json
import os

file_path = "./SPIQA_testA.json"
with open(file_path, 'r', encoding='utf-8') as file:
    text = json.load(file)

# One row per (paper, question): the question, its reference answer, the
# figure it refers to and that figure's caption.
cols = ['paper', 'question', 'answer', 'reference_figure', 'reference_figure_caption']
data = [
    [
        paper,
        qa['question'],
        qa['answer'],
        qa['reference'],
        paper_entry['all_figures'][qa['reference']]['caption'],
    ]
    for paper, paper_entry in text.items()
    for qa in paper_entry['qa']
]

test_df = pd.DataFrame(data, columns=cols)
test_df['generated_answer'] = np.nan

# Resume support: drop the questions already answered in an existing results CSV.
existing_csv = "./spiqa-caption-512.csv"
cols.append('generated_answer')
if os.path.exists(existing_csv) and os.path.getsize(existing_csv) > 0:
    # The inference loop appends rows without a header line, so the file must
    # be read with header=None; the default would swallow the first answered
    # row as column names and that question would be generated again.
    # dtype=str / keep_default_na=False keep every field exactly as written so
    # the keys below compare equal to the values in test_df.
    # (If a file does carry a header row it is read as one extra data row,
    # which matches no real question and is therefore harmless.)
    existing_df = pd.read_csv(
        existing_csv,
        header=None,
        names=cols,
        dtype=str,
        keep_default_na=False,
    )
else:
    # First run: nothing has been answered yet.
    existing_df = pd.DataFrame(columns=cols)
print(len(existing_df))

# Identify a row by paper + question + figure rather than by question text
# alone, so an identical question asked about another paper or figure is not
# skipped by mistake.
key_cols = ['paper', 'question', 'reference_figure']
answered_keys = set(existing_df[key_cols].astype(str).itertuples(index=False, name=None))
existing_questions = set(existing_df["question"])  # kept for any later cell that inspects it
pending_mask = np.array(
    [key not in answered_keys
     for key in test_df[key_cols].astype(str).itertuples(index=False, name=None)],
    dtype=bool,
)
test_df = test_df.loc[pending_mask].reset_index(drop=True)

test_df["image_path"] = "./SPIQA_testA_Images/" + test_df["paper"] + "/" + test_df["reference_figure"]
# test_df['exists'] = test_df['image_path'].apply(lambda x: os.path.exists(x))
# test_df=test_df[test_df['exists']]
print(len(test_df))

In [None]:
from tqdm import tqdm
import csv
import gc
from IPython.display import clear_output

output_csv = './spiqa-caption-512.csv'

# clear_output() wipes earlier error messages from the cell, so failures are
# also collected here and summarised once the loop has finished.
failed_rows = []

for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
    try:
        with torch.no_grad():
            generated_answer = structured_qa(
                row['image_path'], row['question'], row['reference_figure_caption']
            )
            # or structured_qa_with_mulberry(image_path, question, only_output_final_answer=True)

        # Append immediately so finished rows survive a crash or interrupt.
        # UTF-8 is written explicitly because the resume cell reads the file
        # back with pandas, which decodes as UTF-8 by default.
        with open(output_csv, 'a', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow([
                row['paper'],
                row['question'],
                row['answer'],
                row['reference_figure'],
                row['reference_figure_caption'],
                generated_answer,
            ])

        clear_output(wait=True)

    except Exception as e:
        # Best effort: record the failure and continue with the next row.
        failed_rows.append((i, row['image_path'], row['question'], str(e)))
        print(f"Error on row {i} (image_path={row['image_path']}, question={row['question']}): {e}")

    finally:
        # Drop Python references first, then release cached CUDA blocks.
        gc.collect()
        torch.cuda.empty_cache()

if failed_rows:
    print(f"{len(failed_rows)} row(s) failed:")
    for failed_index, failed_path, failed_question, failed_error in failed_rows:
        print(f"  row {failed_index} (image_path={failed_path}, question={failed_question}): {failed_error}")