In [32]:
from transformers import AutoTokenizer, AutoProcessor, VisionEncoderDecoderModel
from transformers.image_utils import load_image

In [33]:
from PIL import Image
import io
import re
import cv2
import numpy as np

In [34]:
# Checkpoint identifier shared by the processor and the model below.
model_name =  "ByteDance/Dolphin"
# The processor bundles image preprocessing and the tokenizer (see processor.tokenizer further down).
processor = AutoProcessor.from_pretrained(model_name,use_fast=True)
# Encoder-decoder model used for both layout parsing and element recognition in this notebook.
model = VisionEncoderDecoderModel.from_pretrained(model_name)

In [35]:
import torch

# Pick the GPU when one is available and move the model there explicitly.
# The model is loaded above without any device argument, so inspecting
# model.device on its own would always select "cpu" and the GPU would never
# be used. Every later cell sends its tensors to DEVICE, so inputs and model
# stay on the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)

In [36]:
# Relative path to the sample page used throughout the notebook.
image_path = "../data/table.jpg"
# NOTE(review): later cells treat `image` as a PIL image (prepare_image documents
# it as such) - confirm load_image returns one for this transformers version.
image = load_image(image_path)

In [37]:
# Preprocess the page image into PyTorch tensors ...
batch_inputs = processor(images=image, return_tensors="pt", padding=True)
# ... and place the pixel tensor on the device chosen above.
batch_pixel_values = batch_inputs.pixel_values.to(DEVICE)

In [38]:
# Instruction for the first pass: ask the model for the page layout in reading order.
prompt = "Parse the reading order of this document."

In [39]:
# Wrap the instruction in the "<s>... <Answer/>" template that is fed to the decoder.
prompts = [f"<s>{prompt} <Answer/>"]
tokenizer = processor.tokenizer
# The template already starts with "<s>", which is presumably why the tokenizer
# is told not to add special tokens of its own.
prompt_inputs = tokenizer(prompts, add_special_tokens=False, return_tensors="pt")

In [40]:
# Move the tokenized prompt and its attention mask to the same device as the pixel tensor.
prompt_ids = prompt_inputs.input_ids.to(DEVICE)
attention_mask = prompt_inputs.attention_mask.to(DEVICE)

In [41]:
import torch
# Decoding settings for the layout pass, gathered in one place so they are easy
# to read and tweak. Sampling is off and a single beam is used.
generation_kwargs = dict(
    pixel_values=batch_pixel_values,
    decoder_input_ids=prompt_ids,
    decoder_attention_mask=attention_mask,
    min_length=1,
    max_length=4096,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[tokenizer.unk_token_id]],
    return_dict_in_generate=True,
    do_sample=False,
    num_beams=1,
    repetition_penalty=1.1,
)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model.generate(**generation_kwargs)
    


In [42]:
# Length of the tokenized prompt (second dimension of prompt_ids).
prompt_length = prompt_ids.shape[1]

In [43]:
# Drop the first prompt_length positions of every sequence, keeping what follows the prompt.
trimmed_generated_ids = outputs.sequences[:, prompt_length:]

In [44]:
# Decode the remaining token ids to text, leaving special tokens out.
seq = processor.batch_decode(trimmed_generated_ids, skip_special_tokens=True)

In [45]:
# Show the decoded layout string(s): "[x1,y1,x2,y2] label" entries in sequence.
seq

[' [0.17,0.90,0.18,0.94] header[0.18,0.05,0.22,0.05] header[0.55,0.05,0.83,0.05] header[0.18,0.06,0.21,0.94] para[0.21,0.07,0.45,0.95] tab[0.52,0.07,0.83,0.10] anno[0.52,0.10,0.83,0.14] list[0.52,0.14,0.83,0.18] list[0.52,0.18,0.83,0.22] list[0.51,0.23,0.83,0.33] para[0.51,0.34,0.83,0.37] para[0.52,0.39,0.83,0.44] list[0.52,0.44,0.83,0.50] list[0.52,0.51,0.83,0.53] list[0.52,0.53,0.83,0.57] list[0.52,0.57,0.83,0.60] list[0.51,0.61,0.83,0.68] para[0.51,0.68,0.83,0.74] para[0.51,0.74,0.83,0.83] para[0.51,0.85,0.62,0.86] sub_sub_sec[0.51,0.86,0.83,0.94] para[0.49,0.96,0.51,0.96] foot']

In [46]:
# Strip any echoed prompt and leftover "<pad>" / "</s>" text from each decoded
# sequence, then trim surrounding whitespace.
results = [
    decoded.replace(prompts[idx], "").replace("<pad>", "").replace("</s>", "").strip()
    for idx, decoded in enumerate(seq)
]

In [47]:
# Only one image was processed, so the first cleaned sequence is the layout string.
layout_output = results[0]

In [48]:
from dataclasses import dataclass
@dataclass
class ImageDimensions:
    """Pixel sizes of an image before and after it is padded.

    Filled in by prepare_image and read by the coordinate helpers to convert
    between the padded canvas and the original image.
    """

    original_w: int  # width of the image before padding
    original_h: int  # height of the image before padding
    padded_w: int  # width of the padded canvas
    padded_h: int  # height of the padded canvas

In [49]:
def prepare_image(image):
    """Pad an image to a square canvas, keeping the content centred.

    The image is converted to an OpenCV BGR array and surrounded with black
    borders so that width and height both equal the longer original side.

    Args:
        image: PIL image

    Returns:
        tuple: (padded_image, image_dimensions). On failure a black array of
        the input's size is returned together with dimensions that report no
        padding.
    """
    # Keep the caller's object under its own name. The conversion result gets a
    # separate name so the fallback can still read the PIL size even when the
    # failure happens after the conversion.
    pil_image = image
    try:
        # Convert PIL image to OpenCV format
        bgr_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        original_h, original_w = bgr_image.shape[:2]

        # Split the missing rows/columns between both sides; the odd pixel, if
        # any, goes to the bottom/right border.
        max_size = max(original_h, original_w)
        top = (max_size - original_h) // 2
        bottom = max_size - original_h - top
        left = (max_size - original_w) // 2
        right = max_size - original_w - left

        # Apply padding
        padded_image = cv2.copyMakeBorder(
            bgr_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )

        padded_h, padded_w = padded_image.shape[:2]

        dimensions = ImageDimensions(original_w=original_w,
                                     original_h=original_h,
                                     padded_w=padded_w,
                                     padded_h=padded_h)

        return padded_image, dimensions
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        # Create a minimal valid image and dimensions
        h, w = pil_image.height, pil_image.width
        dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
        # Return a black image of the same size
        return np.zeros((h, w, 3), dtype=np.uint8), dimensions

In [50]:
## Preparing the image: pad it to a square and record the original/padded sizes
padded_image, dims = prepare_image(image)

In [51]:
def parse_layout_string(bbox_str):
    """Extract every "[x1,y1,x2,y2] label" entry from a layout string.

    Args:
        bbox_str: Text containing zero or more bracketed boxes, each followed
            by a word-character label.

    Returns:
        list: (coords, label) tuples in order of appearance, where coords is a
        list of four floats.
    """
    box_pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
    return [
        ([float(number) for number in found.groups()[:4]], found.group(5).strip())
        for found in re.finditer(box_pattern, bbox_str)
    ]

In [52]:
import copy
from typing import List
def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
    """Move box edges outward while they still cut through image content.

    For each edge a transition score is computed on an Otsu-binarised copy of
    the image. An edge scoring above ``threshold`` is stepped outward one pixel
    at a time, for at most ``max_pixels`` steps, until the score drops to the
    threshold or below.

    Args:
        image: BGR image array, or a string path that is read with cv2.imread.
        boxes: List of [x1, y1, x2, y2] boxes in absolute pixel coordinates.
            The values are used directly as array indices.
        max_pixels: Maximum number of one-pixel steps tried per edge.
        threshold: Score at or below which an edge is considered clean.

    Returns:
        list: One adjusted [x1, y1, x2, y2] box per input box.
    """
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]

    # Binarise once per call. The image is not modified anywhere below, so this
    # is the same mask that used to be rebuilt for every single edge check.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    def check_edge(current_box, i, is_vertical):
        """Score one edge: summed np.diff magnitudes along it, divided by their count."""
        edge = current_box[i]
        if is_vertical:
            line = binary[current_box[1] : current_box[3] + 1, edge]
        else:
            line = binary[edge, current_box[0] : current_box[2] + 1]

        # NOTE(review): `binary` is uint8, so np.diff wraps around and a 255 -> 0
        # step contributes 1 while a 0 -> 255 step contributes 255. Left as-is to
        # keep existing scores unchanged - confirm whether this is intended.
        transitions = np.abs(np.diff(line))
        if len(transitions) == 0:
            # A line of at most one pixel has no transitions; report a clean
            # edge instead of dividing by zero and propagating NaN.
            return 0.0
        return np.sum(transitions) / len(transitions)

    # (index into the box, step direction, edge is a vertical line).
    # Directions point away from the box interior, so boxes only ever widen.
    edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]

    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)
        current_box = copy.deepcopy(box)

        # make sure the box is within the image
        current_box[0] = min(max(current_box[0], 0), img_w - 1)
        current_box[1] = min(max(current_box[1], 0), img_h - 1)
        current_box[2] = min(max(current_box[2], 0), img_w - 1)
        current_box[3] = min(max(current_box[3], 0), img_h - 1)

        for i, direction, is_vertical in edges:
            best_score = check_edge(current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for step in range(max_pixels):
                current_box[i] += direction
                if i == 0 or i == 2:
                    current_box[i] = min(max(current_box[i], 0), img_w - 1)
                else:
                    current_box[i] = min(max(current_box[i], 0), img_h - 1)
                score = check_edge(current_box, i, is_vertical)

                if score < best_score:
                    best_score = score
                    best_box = copy.deepcopy(current_box)

                if score <= threshold:
                    break
        new_boxes.append(best_box)

    return new_boxes


In [53]:
from typing import Tuple
def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
    """Map a box from the padded image back onto the original image.

    The padding offsets are recomputed from ``dims`` (half of the size
    difference, rounded down), subtracted from the box, and the result is
    clamped so that it lies inside the original image.

    Args:
        x1, y1, x2, y2: Coordinates in padded image
        dims: Image dimensions object

    Returns:
        tuple: (x1, y1, x2, y2) coordinates in original image. If anything
        goes wrong a message is printed and a box of at most 100x100 pixels
        anchored at the origin is returned.
    """
    try:
        # Calculate padding offsets
        top = (dims.padded_h - dims.original_h) // 2
        left = (dims.padded_w - dims.original_w) // 2

        # Clamp the top-left corner on both sides. Without the upper bound a box
        # that starts in the right/bottom padding would keep x1 >= original_w
        # (or y1 >= original_h) and come back with x2 <= x1 (or y2 <= y1).
        orig_x1 = max(0, min(x1 - left, dims.original_w - 1))
        orig_y1 = max(0, min(y1 - top, dims.original_h - 1))
        orig_x2 = min(dims.original_w, x2 - left)
        orig_y2 = min(dims.original_h, y2 - top)

        # Ensure we have a valid box (width and height > 0)
        if orig_x2 <= orig_x1:
            orig_x2 = min(orig_x1 + 1, dims.original_w)
        if orig_y2 <= orig_y1:
            orig_y2 = min(orig_y1 + 1, dims.original_h)

        return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        # Return safe coordinates
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)

In [54]:
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    """Turn one normalized box into pixel boxes on the padded and original image.

    Steps: scale to padded-image pixels, clamp, widen the edges with
    adjust_box_edges, clamp again, push the top edge below the previous box if
    the two overlap, then map the result back to the original image.

    Args:
        coords: Normalized coordinates [x1, y1, x2, y2]
        padded_image: Padded image
        dims: Image dimensions object
        previous_box: Previous box coordinates for overlap adjustment

    Returns:
        tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
    """
    try:
        canvas_w, canvas_h = dims.padded_w, dims.padded_h

        def _clamp_box(left, top, right, bottom):
            """Keep the box on the canvas and at least one pixel wide and high."""
            left = max(0, min(left, canvas_w - 1))
            top = max(0, min(top, canvas_h - 1))
            right = max(0, min(right, canvas_w))
            bottom = max(0, min(bottom, canvas_h))
            if right <= left:
                right = min(left + 1, canvas_w)
            if bottom <= top:
                bottom = min(top + 1, canvas_h)
            return left, top, right, bottom

        # Normalized -> absolute pixels on the padded canvas
        x1, y1, x2, y2 = _clamp_box(
            int(coords[0] * canvas_w),
            int(coords[1] * canvas_h),
            int(coords[2] * canvas_w),
            int(coords[3] * canvas_h),
        )

        # Widen the edges, then re-validate because the adjustment may move them
        x1, y1, x2, y2 = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])[0]
        x1, y1, x2, y2 = _clamp_box(x1, y1, x2, y2)

        # If this box intersects the previous one, start it where that one ends
        if previous_box is not None:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            overlaps_horizontally = x1 < prev_x2 and x2 > prev_x1
            overlaps_vertically = y1 < prev_y2 and y2 > prev_y1
            if overlaps_horizontally and overlaps_vertically:
                y1 = min(prev_y2, canvas_h - 1)
                if y2 <= y1:
                    y2 = min(y1 + 1, canvas_h)

        new_previous_box = [x1, y1, x2, y2]

        orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(x1, y1, x2, y2, dims)

        return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        # Fall back to a fixed box near the origin
        fallback_w = min(100, dims.original_w)
        fallback_h = min(100, dims.original_h)
        return 0, 0, 100, 100, 0, 0, fallback_w, fallback_h, [0, 0, 100, 100]


In [55]:
import os
def save_figure_to_local(pil_crop, save_dir, image_name, reading_order):
    """Save a cropped figure as a PNG under ``<save_dir>/markdown/figures``.

    Args:
        pil_crop: PIL Image object of the cropped figure
        save_dir: Base directory to save results
        image_name: Name of the source image/document
        reading_order: Reading order of the figure in the document

    Returns:
        str: Filename of the saved figure (without directory). If saving
        fails, the error is printed and a name ending in ``_error.png`` is
        returned instead; no file exists under that name.
    """
    try:
        # Create the figures directory if it doesn't exist; without this the
        # save below fails on every fresh output directory.
        figures_dir = os.path.join(save_dir, "markdown", "figures")
        os.makedirs(figures_dir, exist_ok=True)

        # Generate figure filename
        figure_filename = f"{image_name}_figure_{reading_order:03d}.png"
        figure_path = os.path.join(figures_dir, figure_filename)

        # PNG is lossless, so no quality setting is passed.
        pil_crop.save(figure_path, format="PNG")

        return figure_filename

    except Exception as e:
        print(f"Error saving figure: {str(e)}")
        # Return a fallback filename
        return f"{image_name}_figure_{reading_order:03d}_error.png"

In [56]:
def chat(prompt, image):
    """Run the model on one image or a batch of images with the given prompt(s).

    Args:
        prompt: String or list of prompts. A single string is reused for every
            image when a list of images is given.
        image: PIL Image or list of PIL Images

    Returns:
        The generated text as a string for a single image, or a list of
        strings (one per image) for a list. If anything fails, the error is
        printed and empty text of the matching shape is returned.
    """
    # Normalise to parallel lists before entering the try block so the error
    # handler below can always tell which return shape the caller expects.
    is_single = not isinstance(image, list)
    if is_single:
        images = [image]
        prompts = [prompt]
    else:
        images = image
        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)

    try:
        # Only the images go through the processor; the prompts are tokenized
        # separately below because they need the "<s>... <Answer/>" template.
        inputs = processor(
            images=images,
            return_tensors="pt",
            padding=True
        )

        # Move inputs to device and match the model's precision
        pixel_values = inputs.pixel_values.to(DEVICE)
        if hasattr(model, 'dtype') and model.dtype == torch.float16:
            pixel_values = pixel_values.half()

        # Prepare prompts
        formatted_prompts = [f"<s>{p} <Answer/>" for p in prompts]
        prompt_inputs = processor.tokenizer(
            formatted_prompts,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False
        )
        prompt_ids = prompt_inputs.input_ids.to(DEVICE)
        # Prompts of different lengths are padded to a common length, so the
        # mask is passed along to mark which positions are real tokens.
        prompt_attention_mask = prompt_inputs.attention_mask.to(DEVICE)

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                pixel_values=pixel_values,
                decoder_input_ids=prompt_ids,
                decoder_attention_mask=prompt_attention_mask,
                min_length=1,
                max_length=4096,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
                do_sample=False,
                num_beams=1,
                repetition_penalty=1.1,
                temperature=1.0
            )

        # Decode with special tokens kept; they are removed textually below
        sequences = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)

        # Remove the echoed prompt and the padding/end markers
        results = []
        for sequence, input_prompt in zip(sequences, formatted_prompts):
            cleaned = sequence.replace(input_prompt, "").replace("<pad>", "").replace("</s>", "").strip()
            results.append(cleaned)

        return results[0] if is_single else results

    except Exception as e:
        print(f"Error in chat function: {str(e)}")
        if is_single:
            return ""
        return [""] * len(images)

In [28]:
def process_element_batch(elements, model, prompt, max_batch_size):
    """Recognise a list of same-type elements by sending their crops to chat in batches.

    Args:
        elements: Dicts with "crop", "label", "bbox" and "reading_order" keys.
        model: Not used by this function; recognition goes through chat.
        prompt: Instruction applied to every crop.
        max_batch_size: Crops per chat call; a falsy value means 4.

    Returns:
        list: One dict per recognised crop with "label", "bbox", "text" and
        "reading_order".
    """
    step = max_batch_size or 4
    collected = []

    for start in range(0, len(elements), step):
        chunk = elements[start:start + step]
        texts = chat(prompt, [entry["crop"] for entry in chunk])

        # Pair each returned text with the element at the same position
        for position, text in enumerate(texts):
            entry = chunk[position]
            collected.append({
                "label": entry["label"],
                "bbox": entry["bbox"],
                "text": text.strip(),
                "reading_order": entry["reading_order"],
            })

    return collected

In [57]:
def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None):
    """Crop every layout element and recognise its content.

    The layout string is parsed into boxes. Each box is converted to pixel
    coordinates and cropped from the padded image. Crops labelled "fig" are
    saved to disk, crops labelled "tab" are parsed as tables, and every other
    crop is read as text. Crops of 3 pixels or fewer in either dimension are
    skipped.

    Args:
        layout_results: Layout string of "[x1,y1,x2,y2] label" entries.
        padded_image: Padded BGR image the boxes refer to.
        dims: Image dimensions object.
        model: Forwarded to process_element_batch.
        max_batch_size: Forwarded to process_element_batch.
        save_dir: Base directory for saved figures.
        image_name: Name used as prefix for saved figure files.

    Returns:
        list: Result dicts for all elements, sorted by "reading_order".
    """
    parsed_layout = parse_layout_string(layout_results)

    text_elements = []
    table_elements = []
    figure_results = []
    previous_box = None
    reading_order = 0

    for bbox, label in parsed_layout:
        try:
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
                bbox, padded_image, dims, previous_box
            )

            cropped = padded_image[y1:y2, x1:x2]
            usable = cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3
            if usable:
                # Both branches need the crop as an RGB PIL image
                pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))

                if label == "fig":
                    figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order)

                    # Figures are referenced by relative path, not recognised
                    figure_results.append(
                        {
                            "label": label,
                            "text": f"![Figure](figures/{figure_filename})",
                            "figure_path": f"figures/{figure_filename}",
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "reading_order": reading_order,
                        }
                    )
                else:
                    element_info = {
                        "crop": pil_crop,
                        "label": label,
                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                        "reading_order": reading_order,
                    }
                    target = table_elements if label == "tab" else text_elements
                    target.append(element_info)

            # Counted for every box that was processed without error, cropped or not
            reading_order += 1

        except Exception as e:
            print(f"Error processing bbox with label {label}: {str(e)}")
            continue

    recognition_results = list(figure_results)

    # Text first, then tables, each group recognised with its own instruction
    pending_groups = (
        (text_elements, "Read text in the image."),
        (table_elements, "Parse the table in the image."),
    )
    for group, instruction in pending_groups:
        if group:
            recognition_results.extend(process_element_batch(group, model, instruction, max_batch_size))

    # Restore document order across figures, text and tables
    recognition_results.sort(key=lambda x: x.get("reading_order", 0))

    return recognition_results

In [58]:
# Run element recognition on the parsed layout. Passing no batch size lets
# process_element_batch use its default; figures are saved under "data".
recognition_results = process_elements(
    layout_output,
    padded_image,
    dims,
    model,
    max_batch_size=None,
    save_dir="data",
    image_name="table",
)


In [61]:
def generate_readme_from_results(recognition_results):
    """Build a Markdown README from recognised document elements.

    Only the labels 'header', 'para', 'list' and 'tab' are used; entries with
    any other label, or with blank text, are left out. The first header longer
    than 10 characters becomes the title. Output is grouped by type (title,
    other headers, paragraphs, list items, tables) rather than kept in reading
    order.

    Args:
        recognition_results: List of dictionaries containing OCR results with labels and text

    Returns:
        str: Formatted README content
    """
    used_labels = ('header', 'para', 'list', 'tab')
    grouped = {name: [] for name in used_labels}
    title = ""

    for entry in recognition_results:
        kind = entry['label']
        content = entry['text'].strip()

        if not content or kind not in used_labels:
            continue

        grouped[kind].append(content)
        # Assume the first longer header is the document title
        if kind == 'header' and not title and len(content) > 10:
            title = content

    parts = []

    if title:
        parts.append(f"# {title}\n")

    # Headers other than the title become sections
    parts.extend(f"## {heading}\n" for heading in grouped['header'] if heading != title)

    if grouped['para']:
        parts.append("## Description\n")
        parts.extend(f"{paragraph}\n\n" for paragraph in grouped['para'])

    if grouped['list']:
        parts.append("## Key Points\n")
        parts.extend(f"- {point}\n" for point in grouped['list'])
        parts.append("\n")

    if grouped['tab']:
        parts.append("## Tables\n")
        parts.extend(f"{table}\n\n" for table in grouped['tab'])

    return "".join(parts)


In [63]:
# Render the recognised elements as Markdown and print the result.
readme_text = generate_readme_from_results(recognition_results)
print(readme_text)

# Crépet et a
## Table 1
## International Journal of Hygiene and Environmental Health 222 (2019) 291–306
## Description
Table 1
Description of consumption and concentration data for nine different European countries, n total – number of individuals in the overall consumption survey, n – number of individuals included in this study
(adults 18-64 years old, children 11-15 years old), N – number of substances in steatosis CAG after matching with contamination date; the number in brackets indicates the number of substances with

Substances were coded using the ParamCodes from the harmonised
European Standard Sample Description 1 format SDD1 ( EFSA, 2010 ) .
Substances were removed if no ParamCode coding for pesticides, no
NOAL or no LOAEL (copper compounds) were available. Some sub-
stances sharing the same residue definition (benalaxyl-M and benalaxyl,
cypermethrin and alpha-cypermethrin, metam and dazomet, metaloxyl-
M and metalaxyl, triadimefon and triadimenol) were presented together
i