In [33]:
import fitz
from PIL import Image
import torch
from transformers import pipeline
from pathlib import Path
import re
from typing import List, Dict
from schema import PaperStructure, Figure, Table, Section

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load models (lazy)
vqa_pipe = None
llava_pipe = None

In [None]:
def load_vqa():
    global vqa_pipe
    if vqa_pipe is None:
        vqa_pipe = pipeline(
            "document-question-answering",
            model="impira/layoutlm-document-qa",
            device=0 if torch.cuda.is_available() else -1
        )
    return vqa_pipe

def load_llava():
    global llava_pipe
    if llava_pipe is None:
        llava_pipe = pipeline(
            "image-to-text",
            model="llava-hf/llava-1.5-7b-hf",
            torch_dtype=torch.float16,
            device=0 if torch.cuda.is_available() else -1
        )
    return llava_pipe

In [35]:
def pdf_to_images(pdf_path: str, dpi: int = 200) -> List[Image.Image]:
    doc = fitz.open(pdf_path)
    return [
        Image.frombytes("RGB", [p.width, p.height], p.get_pixmap(matrix=fitz.Matrix(dpi/72, dpi/72)).samples)
        for p in doc
    ]

In [36]:
def extract_title_abstract(pdf_path: str) -> tuple[str, str]:
    vqa = load_vqa()
    image = pdf_to_images(pdf_path)[0]  # First page
    title = vqa(image, question="What is the title of this paper?")[0]['answer']
    abstract = vqa(image, question="Extract the full abstract.")[0]['answer']
    return title.strip(), abstract.strip()

In [43]:
def extract_figures_tables(pdf_path: str, save_dir: str) -> tuple[List[Figure], List[Table]]:
    doc = fitz.open(pdf_path)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)

    figures, tables = [], []
    fig_counter = table_counter = 1

    llava = load_llava()

    for page_num, page in enumerate(doc):
        # Extract images
        image_list = page.get_images(full=True)
        for img_idx, img in enumerate(image_list):
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            img_path = save_dir / f"fig_page{page_num+1}_{img_idx}.png"
            img_path.write_bytes(image_bytes)

            # Describe with LLaVA
            image = Image.open(img_path)
            prompt = "USER: <image>\nDescribe this figure in detail: axes, trends, key insights.\nASSISTANT:"
            desc = llava(prompt, images=image, max_new_tokens=150)[0]["generated_text"].split("ASSISTANT:")[-1].strip()

            # Extract caption from text near image
            caption = extract_caption_near_block(page, page.get_image_bbox(img))

            figures.append(Figure(
                fig_num=f"Figure {fig_counter}",
                description=desc,
                caption=caption,
                img_path=str(img_path)
            ))
            fig_counter += 1

        # Extract tables (via text patterns)
        text = page.get_text("text")
        table_blocks = re.finditer(r"Table \d+[\s\S]*?(?=\n\n|\Z)", text, re.DOTALL)
        for match in table_blocks:
            caption = match.group(0).split("\n")[0]
            content = match.group(0)
            newline = '\n'
            html = f"<table><tr><td>{content.replace(newline, '</td></tr><tr><td>')}</td></tr></table>"

            tables.append(Table(
                table_num=f"Table {table_counter}",
                description=f"Table contains data on: {content[:200]}...",
                caption=caption,
                html=html
            ))
            table_counter += 1

    return figures, tables

In [44]:
def extract_caption_near_block(page, bbox) -> str:
    text_blocks = page.get_text("blocks")
    x0, y0, x1, y1 = bbox
    for block in text_blocks:
        bx0, by0, bx1, by1, text = block[:5]
        if abs(by1 - y0) < 50:  # within 50pt below
            return text.strip()
    return ""

In [45]:
def extract_sections(pdf_path: str) -> List[Section]:
    doc = fitz.open(pdf_path)
    sections = []
    current_heading = None
    current_content = []

    for page in doc:
        blocks = page.get_text("dict")["blocks"]
        for b in blocks:
            if "lines" not in b: continue
            for line in b["lines"]:
                for span in line["spans"]:
                    text = span["text"].strip()
                    if span["size"] > 14 and text.isupper():  # Heading
                        if current_heading:
                            sections.append(Section(heading=current_heading, content=" ".join(current_content)))
                        current_heading = text
                        current_content = []
                    else:
                        current_content.append(text)

    if current_heading:
        sections.append(Section(heading=current_heading, content=" ".join(current_content)))
    return sections

In [46]:
def extract_structure(pdf_path: str, arxiv_id: str, img_dir: str) -> PaperStructure:
    print(f"Extracting structure for {arxiv_id}...")
    title, abstract = extract_title_abstract(pdf_path)
    figures, tables = extract_figures_tables(pdf_path, img_dir)
    sections = extract_sections(pdf_path)

    return PaperStructure(
        title=title,
        abstract=abstract,
        sections=sections,
        figures=figures,
        tables=tables,
        arxiv_id=arxiv_id,
        pdf_path=pdf_path
    )

In [47]:
extract_structure("data/raw_papers/2510.23571v1.pdf", "1234.56789", "output_images")

Extracting structure for 1234.56789...


OSError: impira/layoutlm-document-vqa is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`