In [None]:
# Clone the dots.ocr repository and make it the working directory.
# Guarded so the cell can be re-run in the same kernel: skip the clone when the
# checkout already exists, and skip the cd when we are already inside it
# (a bare `%cd dots.ocr` from inside the checkout would look for a nested dir).
import os

if os.path.basename(os.getcwd()) != "dots.ocr":
    if not os.path.isdir("dots.ocr"):
        !git clone https://github.com/rednote-hilab/dots.ocr.git
    %cd dots.ocr

In [None]:
# Delete flash-attn from requirements.
# Every line of requirements.txt that mentions "flash-attn" is removed in place,
# so the editable install below does not try to build it. The model in this
# notebook is loaded with attn_implementation="sdpa" instead, so flash-attn is not needed.
# NOTE: `sed -i` without a suffix argument is GNU sed syntax (Linux/Colab);
# BSD/macOS sed needs `sed -i ''`.
!sed -i '/flash-attn/d' requirements.txt

In [None]:
# Install dependencies.
# `%pip` (not `!pip`) installs into the environment of the running kernel;
# `!pip` runs whatever `pip` is first on PATH, which may belong to a different
# Python and leave the packages unimportable from this notebook.
%pip install torch torchvision torchaudio
# Editable install of the cloned repo (current directory) and its
# requirements.txt dependencies.
%pip install -e .

In [None]:
# Download dots.ocr model.
# Run the script with the kernel's own interpreter (sys.executable) rather than
# whatever `python3` is on PATH, so it sees the same installed packages the
# kernel imports. The weights presumably land in ./weights/DotsOCR, the path
# the inference cell loads from (TODO confirm against tools/download_model.py).
import sys

!"{sys.executable}" tools/download_model.py

In [None]:
import os
import warnings

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from qwen_vl_utils import process_vision_info

# --- Configuration: everything a reader might want to tune ---
model_path = "./weights/DotsOCR"     # local directory the model and processor load from
image_path = "/content/example.jpg"  # Or any other image path
prompt = """Extract texts from image"""  # Your prompt here
# Upper bound on generated tokens. 128 keeps CPU runs short, but it will cut off
# the text of most full pages -- raise it for dense documents.
MAX_NEW_TOKENS = 128

# Fail fast on a missing local image *before* the slow model load below.
# Anything containing "://" (e.g. http/https/file URIs) or starting with "data:"
# is not a plain local path, so it is left to process_vision_info to resolve.
if (
    "://" not in image_path
    and not image_path.startswith("data:")
    and not os.path.exists(image_path)
):
    raise FileNotFoundError(f"Image not found: {image_path!r}")

# --- Load the model with CPU-specific configurations ---
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    # PyTorch's built-in scaled-dot-product attention instead of
    # "flash_attention_2": flash-attn was stripped from requirements.txt,
    # so it is not available in this environment.
    attn_implementation="sdpa",
    # bfloat16 halves memory versus float32. On CPUs without native bf16
    # support it may run slowly; torch.float32 is the alternative if RAM allows.
    torch_dtype=torch.bfloat16,
    device_map="cpu",        # place every weight on the CPU
    trust_remote_code=True,  # executes Python code stored with the checkpoint; use only trusted weights
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

# --- Prepare the chat message: one image followed by the text prompt ---
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": prompt},
        ],
    }
]

# --- Standard processing steps ---
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
# A no-op with device_map="cpu", but keeps the cell correct if the model is
# later placed on another device.
inputs = inputs.to(model.device)

# --- Inference ---
generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
# generate() returns prompt + completion; keep only the newly generated tokens.
generated_ids_trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)

# Generation that used the whole token budget was almost certainly cut off.
if any(len(ids) >= MAX_NEW_TOKENS for ids in generated_ids_trimmed):
    warnings.warn(
        f"Output reached max_new_tokens={MAX_NEW_TOKENS} and is probably "
        "truncated; increase MAX_NEW_TOKENS for the full text."
    )

# Batch of one: print the decoded string itself so newlines in the OCR text
# render, instead of the escaped list repr.
print(output_text[0])