In [None]:
# System dependency: pdf2image rasterises PDFs through the poppler command-line
# tools, which come from the `poppler-utils` package.
!sudo apt update && sudo apt install -y poppler-utils
# Python dependencies. `timm` is presumably needed by the DETR model's backbone;
# torch and transformers are assumed to be preinstalled in this environment.
# NOTE(review): versions are unpinned — pin them for reproducible re-runs.
!uv pip install pdf2image timm

In [None]:
import PIL
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from pdf2image import convert_from_path
from transformers import AutoImageProcessor
from transformers.models.detr import DetrForSegmentation

# The image processor and the DETR segmentation model must come from the same
# Hugging Face checkpoint, so the id lives in a single constant.
MODEL_ID = "cmarkea/detr-layout-detection"

img_proc = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DetrForSegmentation.from_pretrained(MODEL_ID)

In [None]:
# Rasterise the sample PDF and pick the page to analyse.
# pdf2image returns a list with one PIL image per rendered page.
PDF_PATH = "/resources/data/sample/document-01.pdf"
DPI = 500  # high resolution; lower it (e.g. 200) to cut rendering time and memory
PAGE_INDEX = 0  # which rendered page to run the layout model on

pages = convert_from_path(PDF_PATH, dpi=DPI)
if not pages:
    # Fail with a clear message instead of an opaque IndexError below.
    raise ValueError(f"No pages were rendered from {PDF_PATH}")

# `Image.Image` is the PIL image class; `PIL.Image` is the module, not a type.
img: Image.Image = pages[PAGE_INDEX]
print(
    f"Rendered {len(pages)} page(s) at {DPI} DPI from {PDF_PATH}; "
    f"using page {PAGE_INDEX} with size {img.size} (width, height)"
)


In [None]:
# Run the layout model on the selected page and decode its raw outputs into
# segmentation masks and bounding boxes.

# Score cut-off shared by both post-processing steps (the plotting cell reads it too).
threshold = 0.4

with torch.no_grad():
    # Despite its name, `input_ids` holds the image processor's batch of model
    # inputs for `img` — not text token ids.
    input_ids = img_proc(img, return_tensors="pt")
    output = model(**input_ids)

# PIL reports size as (width, height); the post-processors expect (height, width).
target_sizes = [img.size[::-1]]

# NOTE(review): recent transformers releases may flag `post_process_segmentation`
# as deprecated — confirm against the installed version before upgrading.
segmentation_mask = img_proc.post_process_segmentation(
    output, threshold=threshold, target_sizes=target_sizes
)
bbox_pred = img_proc.post_process_object_detection(
    output, threshold=threshold, target_sizes=target_sizes
)


In [None]:
# Visualise the predictions for the selected page in three panels:
# the original page, the segmentation overlay, and the detected layout boxes.
#
# Boxes are drawn as matplotlib patches rather than onto `img` itself. This cell
# therefore never mutates the page the inference cell consumes and is safe to
# re-run. Patches also render inline (unlike `Image.show()`) and keep a constant
# on-screen width however much the high-DPI page is downscaled for display.

# `post_process_segmentation` returns one dict per image; its "masks" entry is
# expected to be (num_kept, height, width) 0/1 masks resized to `target_sizes`.
instance_masks = segmentation_mask[0]["masks"].cpu().numpy()

# Merge every kept instance into one label map (0 = background, i = i-th instance).
# The overlay then shows all detected regions rather than only the first one, and
# an empty prediction no longer raises IndexError.
mask = np.zeros((img.height, img.width), dtype=np.uint16)
for instance_id, instance_mask in enumerate(instance_masks, start=1):
    mask[instance_mask > 0] = instance_id

fig, ax = plt.subplots(1, 3, figsize=(20, 10))

# Panel 1: the untouched page.
ax[0].imshow(img)
ax[0].set_title("Original Image")

# Panel 2: instance regions tinted over the page. Background is masked out so the
# page stays readable; "nearest" avoids blending label ids at region borders.
ax[1].imshow(img)
if mask.any():
    ax[1].imshow(
        np.ma.masked_equal(mask, 0), alpha=0.5, cmap="jet", interpolation="nearest"
    )
ax[1].set_title("Segmentation Mask")

# Panel 3: bounding boxes with class name and score. `bbox_pred` was already
# filtered by `post_process_object_detection(threshold=threshold)`, so no
# additional score check is needed here.
ax[2].imshow(img)
id2label = model.config.id2label
for box, score, label in zip(
    bbox_pred[0]["boxes"], bbox_pred[0]["scores"], bbox_pred[0]["labels"]
):
    x_min, y_min, x_max, y_max = box.tolist()  # absolute pixel coordinates
    ax[2].add_patch(
        plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            fill=False,
            edgecolor="red",
            linewidth=2,
        )
    )
    label_id = int(label)
    ax[2].text(
        x_min,
        y_min,
        f"{id2label.get(label_id, label_id)} {float(score):.2f}",
        color="white",
        backgroundcolor="red",
        fontsize=7,
        va="bottom",
    )
ax[2].set_title("Detected Layout Boxes")

for axis in ax:
    axis.axis("off")

plt.tight_layout()
plt.show()
