In [None]:
%%capture
# Install the notebook's dependencies (including gradio, spaces, torch, transformers
# and pillow, which are imported in the next cell); %%capture hides pip's long output.
# NOTE: flash-attn is not in this list, so FlashAttention-2 is only available if it
# is installed separately.
!pip install transformers-stream-generator huggingface_hub albumentations \
qwen-vl-utils pyvips-binary sentencepiece opencv-python docling-core \
transformers python-docx torchvision supervision matplotlib \
accelerate pdf2image num2words reportlab html2text markdown \
requests pymupdf loguru hf_xet spaces pyvips pillow gradio \
einops httpx numpy click torch peft fpdf timm av
# Hold tight: installation takes around 1-2 minutes.

In [None]:
import os
import sys
import random
import uuid
import json
import time
from threading import Thread
from typing import Iterable

import gradio as gr
import spaces
import torch
from PIL import Image

from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)

from transformers.image_utils import load_image

# --- Configuration ---
# Upper bound and initial value of the "Max New Tokens" slider.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
# Can be overridden through the environment; parsed as an int at import time.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Custom CSS for styling: enlarges the headings rendered inside the elements
# with elem_id "main-title" and "output-title".
css = """
#main-title h1 {
    font-size: 2.3em !important;
}
#output-title h2 {
    font-size: 2.1em !important;
}
"""

import importlib.util

# --- Model Loading ---
# Device used for the model weights and for the processor outputs built in
# generate_response().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# FlashAttention-2 needs both a CUDA device and the separately installed
# `flash_attn` package; when either is missing, use the eager attention kernels
# instead of failing at load time.
_use_flash_attn = (
    device.type == "cuda" and importlib.util.find_spec("flash_attn") is not None
)
# Half precision is intended for the GPU; on CPU it is slow and not supported by
# every op, so fall back to float32 there.
_model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load olmOCR-2-7B-1025
MODEL_ID = "allenai/olmOCR-2-7B-1025"
print(f"Loading model: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=_model_dtype,
    attn_implementation="flash_attention_2" if _use_flash_attn else "eager",
).to(device).eval()
print("Model loaded successfully.")

@spaces.GPU
def generate_response(text: str, image: Image.Image,
                      max_new_tokens: int, temperature: float, top_p: float,
                      top_k: int, repetition_penalty: float):
    """
    Generates responses using the olmOCR model for the given image and text prompt.
    Yields the generated text in a streaming manner.

    Args:
        text: User prompt sent alongside the image.
        image: Uploaded PIL image; when None, a short notice is yielded instead.
        max_new_tokens: Maximum number of tokens to generate.
        temperature, top_p, top_k, repetition_penalty: Sampling settings passed
            straight to ``model.generate`` (sampling is always enabled).

    Yields:
        ``(text, text)`` tuples holding the accumulated output so far, once for
        the raw textbox and once for the Markdown view.

    Raises:
        gr.Error: if generation fails in the background thread.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    # Prepare the messages for the chat template
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ]
    }]

    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True
    ).to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    # UI sliders may deliver floats; generate() expects integer token counts.
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
    }

    generation_errors = []

    def _run_generation():
        """Run model.generate; on failure record the error and unblock the streamer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the caller after streaming stops
            generation_errors.append(exc)
            # Without this, iterating the streamer below would block forever
            # because generate() never sends the end-of-stream signal.
            streamer.end()

    # Run generation in a separate thread so tokens can be yielded as they arrive
    thread = Thread(target=_run_generation, daemon=True)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer

    thread.join()
    if generation_errors:
        raise gr.Error(f"Generation failed: {generation_errors[0]}") from generation_errors[0]

# `steel_blue_theme` is not created by any cell shown here; referencing it directly
# raises NameError. Use it if it exists, otherwise fall back to Gradio's built-in
# Soft theme with a blue/slate palette.
try:
    _ui_theme = steel_blue_theme
except NameError:
    _ui_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")

# Build the demo UI: prompt + image inputs and sampling controls on the left,
# streamed raw text and rendered Markdown on the right.
with gr.Blocks(css=css, theme=_ui_theme) as demo:
    gr.Markdown("# **olmOCR-2-7B Demo**", elem_id="main-title")
    gr.Markdown("This interface uses the `allenai/olmOCR-2-7B-1025` model for Optical Character Recognition.")

    with gr.Row():
        with gr.Column(scale=2):
            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here (e.g., 'Transcribe the text')...")
            image_upload = gr.Image(type="pil", label="Upload Image", height=320)

            image_submit = gr.Button("Submit", variant="primary")

            with gr.Accordion("Advanced Generation Options", open=False):
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)

        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            output_stream = gr.Textbox(label="Raw Output Stream", interactive=False, lines=15, show_copy_button=True)
            with gr.Accordion("Formatted Markdown Output", open=True):
                markdown_output = gr.Markdown(label="Formatted Result")

    # Connect the submit button to the generation function; each yielded
    # (text, text) pair updates both output components.
    image_submit.click(
        fn=generate_response,
        inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[output_stream, markdown_output]
    )

if __name__ == "__main__":
    # Enable request queuing (at most 50 pending requests), then start the
    # server with show_error enabled so failures are reported in the UI.
    demo.queue(
        max_size=50,
    ).launch(
        show_error=True,
    )