# PDF Content Extraction Tool
This notebook extracts **images, their surrounding labels, and their captions** from PDF files using box detection, expansion, and clustering algorithms.

## Instructions
1. Run the installation cell below
2. Run the remaining code cells in order (Steps 2–5)
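3. Upload your PDF using the file upload widget (processing starts as soon as the upload finishes)
4. Optionally enable debug visualization — tick the checkbox *before* uploading, because it is read when processing starts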
5. Wait for processing to complete
6. Download extracted images and the compiled PDF

## Step 1: Install Dependencies

In [None]:
!pip install pymupdf numpy Pillow scikit-learn ipywidgets reportlab hdbscan

## Step 2: Import Libraries

In [None]:
import os
import numpy as np
import fitz  # PyMuPDF
import ipywidgets as widgets
from ipywidgets import Accordion, VBox, Image as WImage, HTML as WHTML, Button
from IPython.display import display, HTML
from sklearn.preprocessing import StandardScaler
import base64
import tempfile
import shutil
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib import colors
from reportlab.lib.units import inch
from PIL import Image
from functools import partial
from sklearn.cluster import DBSCAN

## Step 3: Define Helper Functions

In [None]:
def draw_wrapped_text(c, text, x, y, max_width, line_height,
                      font="Helvetica-Oblique", font_size=11, color=colors.black):
    """Draw ``text`` centred on ``x``, word-wrapped to ``max_width``.

    Args:
        c: Canvas-like object providing ``setFont``, ``setFillColor``,
            ``stringWidth`` and ``drawCentredString``.
        text: Text to draw; it is split on whitespace and re-flowed.
        x: Horizontal centre of every drawn line.
        y: Baseline of the first line.
        max_width: Maximum width of a line, measured with ``c.stringWidth``.
        line_height: Distance subtracted from ``y`` after each drawn line.
        font: Font name used for measuring and drawing.
        font_size: Font size used for measuring and drawing.
        color: Fill colour applied before drawing.

    Returns:
        The ``y`` coordinate one ``line_height`` below the last drawn line
        (the unchanged ``y`` when ``text`` contains no words).
    """
    c.setFont(font, font_size)
    c.setFillColor(color)

    lines, line = [], ""
    for word in text.split():
        candidate = f"{line} {word}" if line else word
        if c.stringWidth(candidate, font, font_size) <= max_width:
            line = candidate
            continue
        # The word does not fit. Flush the pending line only if it holds text:
        # when the very first word is already wider than max_width there is
        # nothing to flush, and emitting "" would waste one line of space.
        if line:
            lines.append(line)
        line = word
    if line:
        lines.append(line)

    for text_line in lines:
        c.drawCentredString(x, y, text_line)
        y -= line_height
    return y

def boxes_are_close_or_overlap(box1, box2, gap_threshold=10):
    """Tell whether two boxes overlap or sit within ``gap_threshold`` of each other.

    Both boxes only need ``x0``, ``y0``, ``x1`` and ``y1`` attributes. The
    boxes qualify when, on the horizontal and on the vertical axis alike, the
    gap between them does not exceed ``gap_threshold``.
    """
    # Horizontal axis first; a failure here settles the answer.
    if not (box2.x0 <= box1.x1 + gap_threshold and box1.x0 - gap_threshold <= box2.x1):
        return False
    # Vertical axis decides the rest.
    return box2.y0 <= box1.y1 + gap_threshold and box1.y0 - gap_threshold <= box2.y1

def merge_boxes_if_close(box_list, gap_threshold=10):
    """Merge boxes that overlap or lie within ``gap_threshold`` of one another.

    Args:
        box_list: Iterable of rectangles exposing ``x0``, ``y0``, ``x1``,
            ``y1``. It is copied, never modified.
        gap_threshold: Maximum gap (per axis) at which two boxes are still
            considered close, as decided by ``boxes_are_close_or_overlap``.

    Returns:
        A new list of rectangles in which no two entries are close to each
        other. Untouched inputs are returned as-is; merged ones are new
        ``fitz.Rect`` bounding boxes.
    """
    # Work on a copy: callers keep using their list after this call.
    pending = list(box_list)

    # A merge enlarges a box, which can bring it within reach of a box that was
    # examined earlier. Repeat the sweep until one completes without merging so
    # the result does not depend on input order. Every merging sweep shrinks
    # the list by at least one entry, so the loop terminates.
    merged_something = True
    while merged_something:
        merged_something = False
        merged_boxes = []
        for current in pending:
            for i, existing in enumerate(merged_boxes):
                if boxes_are_close_or_overlap(current, existing, gap_threshold):
                    merged_boxes[i] = fitz.Rect(
                        min(current.x0, existing.x0),
                        min(current.y0, existing.y0),
                        max(current.x1, existing.x1),
                        max(current.y1, existing.y1),
                    )
                    merged_something = True
                    break
            else:
                merged_boxes.append(current)
        pending = merged_boxes
    return pending

## Step 4: Define PDFBoxExtractor Class

In [None]:
class PDFBoxExtractor:
    def __init__(self, pdf_path, debug_visualization=False):
        """Open ``pdf_path`` and prepare output locations and tuning values.

        Args:
            pdf_path: Path of the PDF to process; opened with ``fitz.open``.
            debug_visualization: When true, an empty debug document is also
                created for annotated copies of the pages.
        """
        # --- input ---------------------------------------------------------
        self.pdf_path = pdf_path
        self.doc = fitz.open(pdf_path)

        # --- outputs -------------------------------------------------------
        # Scratch directory for image files; the compiled PDF itself is
        # written to the current working directory.
        self.output_folder = tempfile.mkdtemp()
        self.output_pdf = "images_with_captions.pdf"

        # --- debug overlay (attributes exist only when enabled) -------------
        self.debug_visualization = debug_visualization
        if debug_visualization:
            self.debug_pdf = "debug_visualization.pdf"
            self.debug_doc = fitz.open()

        # --- results collected by process_pdf --------------------------------
        self.extracted_images = []  # (image path, caption or None) tuples
        self.debug_pages = []       # PNG bytes of annotated pages

        # --- box-expansion tuning -------------------------------------------
        self.w_margin_factor = 0.10      # first search margin, fraction of image width
        self.h_margin_factor = 0.05      # first search margin, fraction of image height
        self.expansion_step = 1.175      # margin growth factor per iteration
        self.max_expansion_factor = 2.0  # stop once width margin exceeds this many image widths
        self.num_iterations = 7          # upper bound on expansion rounds

    @staticmethod
    def is_near(exp_box, cap_box, factor=2.0):
        vertical_distance = cap_box.y0 - exp_box.y1
        cap_height = cap_box.y1 - cap_box.y0
        threshold = cap_height * factor
        horizontal_overlap = max(0, min(exp_box.x1, cap_box.x1) - max(exp_box.x0, cap_box.x0))
        min_overlap = min(cap_box.width * 0.5, exp_box.width * 0.5)
        return (vertical_distance >= 0 and vertical_distance < threshold and horizontal_overlap > min_overlap)

    def save_image(self, page_num, img_xref):
        img_dict = self.doc.extract_image(img_xref)
        img_data = img_dict["image"]
        img_ext = img_dict["ext"]
        img_filename = f"page_{page_num}_img_{img_xref}.{img_ext}"
        img_path = os.path.join(self.output_folder, img_filename)
        with open(img_path, "wb") as img_file:
            img_file.write(img_data)
        return img_filename

    def draw_boxes(self, page, boxes, color):
        for box in boxes:
            annot = page.add_rect_annot(box)
            annot.set_colors(stroke=color)
            annot.set_border(width=1.0)
            annot.update()

    def expand_image_boxes(self, image_rects, text_blocks, protected_captions=None, debug_page=None):
        if protected_captions is None:
            protected_captions = []

        protected_boxes = [cap_box for _, cap_box in protected_captions]
        expanded_boxes = []

        for image_rect in image_rects:
            x0, y0, x1, y1 = image_rect
            detected_text_boxes = []

            w_margin = (x1 - x0) * self.w_margin_factor
            h_margin = (y1 - y0) * self.h_margin_factor

            for _ in range(self.num_iterations):
                strips = [
                    fitz.Rect(x0, y0 - h_margin, x1, y0),
                    fitz.Rect(x0, y1, x1, y1 + h_margin),
                    fitz.Rect(x0 - w_margin, y0, x0, y1),
                    fitz.Rect(x1, y0, x1 + w_margin, y1),
                ]

                if debug_page:
                    for strip in strips:
                        annot = debug_page.add_rect_annot(strip)
                        annot.set_colors(stroke=(1, 1, 0))
                        annot.set_border(width=0.5)
                        annot.update()

                new_text_detected = False
                for block in text_blocks:
                    text_rect = fitz.Rect(block[:4])
                    if (image_rect.contains(text_rect) or any(text_rect.intersects(p) for p in protected_boxes)):
                        continue
                    if any(strip.intersects(text_rect) for strip in strips):
                        detected_text_boxes.append(text_rect)
                        new_text_detected = True

                if not new_text_detected:
                    break

                w_margin *= self.expansion_step
                h_margin *= self.expansion_step

                if w_margin > (x1 - x0) * self.max_expansion_factor:
                    break

            expanded_rect = image_rect
            for text_rect in detected_text_boxes:
                expanded_rect = fitz.Rect(
                    min(expanded_rect.x0, text_rect.x0),
                    min(expanded_rect.y0, text_rect.y0),
                    max(expanded_rect.x1, text_rect.x1),
                    max(expanded_rect.y1, text_rect.y1),
                )
            expanded_boxes.append(expanded_rect)

        return expanded_boxes

    def cluster_boxes(self, boxes):
        if not boxes:
            return []

        features = []
        for box in boxes:
            cx = (box.x0 + box.x1) / 2
            cy = (box.y0 + box.y1) / 2
            w = box.x1 - box.x0
            h = box.y1 - box.y0
            features.append([cx, cy, w, h])

        features = np.array(features)
        if features.shape[0] == 1:
            return boxes

        scaled = StandardScaler().fit_transform(features)
        db = DBSCAN(eps=0.8, min_samples=1).fit(scaled)
        labels = db.labels_

        clustered_boxes = {}
        for box, label in zip(boxes, labels):
            clustered_boxes.setdefault(label, []).append(box)

        merged_boxes = []
        for group in clustered_boxes.values():
            x0 = min(b.x0 for b in group)
            y0 = min(b.y0 for b in group)
            x1 = max(b.x1 for b in group)
            y1 = max(b.y1 for b in group)
            merged_boxes.append(fitz.Rect(x0, y0, x1, y1))

        return merged_boxes

    def merge_boxes(self, expanded_boxes, captions, caption_info):
        """Combine expanded boxes into figure clusters and attach captions.

        Args:
            expanded_boxes: Rectangles produced by ``expand_image_boxes``.
                The list is not modified.
            captions: ``(caption text, caption rect)`` pairs found on the page.
            caption_info: Mapping from caption text to its rectangle.

        Returns:
            A list of ``{"box": rect, "caption": text_or_None}`` dicts.
            Captioned clusters come first (one per caption, spanning all boxes
            matched to it and clipped so they end at the caption's top edge),
            followed by the clustered boxes that matched no caption.
        """
        # Hand over a copy so the caller's list stays intact regardless of how
        # the merge helper treats its argument; callers draw these boxes later.
        merged_expanded = merge_boxes_if_close(list(expanded_boxes), gap_threshold=10)

        boxes_with_caption = {}
        boxes_without_caption = []
        for exp_box in merged_expanded:
            # Among the captions sitting just below the box, pick the closest.
            best_caption = None
            min_distance = float("inf")
            for cap_text, cap_box in captions:
                if self.is_near(exp_box, cap_box, factor=2.0):
                    distance = cap_box.y0 - exp_box.y1
                    if distance < min_distance:
                        min_distance = distance
                        best_caption = cap_text

            if best_caption:
                boxes_with_caption.setdefault(best_caption, []).append(exp_box)
            else:
                boxes_without_caption.append(exp_box)

        clusters = []
        for cap, group in boxes_with_caption.items():
            caption_box = caption_info.get(cap)
            if caption_box is None:
                # No rectangle known for this caption text; nothing to clip against.
                continue
            combined_y1_raw = max(box.y1 for box in group)
            merged_box = fitz.Rect(
                min(box.x0 for box in group),
                min(box.y0 for box in group),
                max(box.x1 for box in group),
                # Never let the figure region run into its own caption.
                min(combined_y1_raw, caption_box.y0),
            )
            clusters.append({"box": merged_box, "caption": cap})

        for box in self.cluster_boxes(boxes_without_caption):
            clusters.append({"box": box, "caption": None})

        return clusters

    def process_pdf(self):
        """Extract figure regions from every page and compile them into a PDF.

        For each page that contains images the method saves the embedded
        images, finds caption blocks (text starting with "figure" or "fig."),
        expands and merges the image boxes, renders every resulting region to
        a PNG in ``self.output_folder`` and places it with its caption in
        ``self.output_pdf``. With debug visualization enabled, annotated
        copies of the pages are collected and written to ``self.debug_pdf``.

        The results are displayed before returning. The source document, the
        debug document and the temporary output folder are released even when
        processing fails part-way.
        """
        total_pages = len(self.doc)
        progress = widgets.FloatProgress(value=0, min=0, max=total_pages, description="Page:")
        display(progress)

        pdf_canvas = canvas.Canvas(self.output_pdf, pagesize=letter)
        page_width, page_height = letter
        top_margin, bottom_margin, caption_reserve = 70, 50, 80
        max_text_width = page_width - 100
        # Tallest image that fits on an empty page together with its caption.
        max_display_height = page_height - top_margin - bottom_margin - caption_reserve
        y_position = page_height - top_margin
        high_res_matrix = fitz.Matrix(2, 2)
        caption_patterns = ("figure", "fig.")

        try:
            for page_num, page in enumerate(self.doc, start=1):
                debug_page = None
                if self.debug_visualization:
                    debug_page = self.debug_doc.new_page(
                        width=page.rect.width, height=page.rect.height
                    )
                    debug_page.show_pdf_page(debug_page.rect, self.doc, page_num - 1)

                image_rects = []
                for img in page.get_images(full=True):
                    xref = img[0]
                    bbox_list = page.get_image_rects(xref)
                    if not bbox_list:
                        continue
                    # One file per image object, however often it is placed on the page.
                    self.save_image(page_num, xref)
                    image_rects.extend(bbox_list)

                if not image_rects:
                    progress.value = page_num
                    continue

                text_blocks = page.get_text("blocks")
                captions, caption_info = [], {}
                for block in text_blocks:
                    original_text = block[4].strip()
                    if (original_text.lower().startswith(caption_patterns)
                            and original_text not in caption_info):
                        caption_box = fitz.Rect(block[:4])
                        captions.append((original_text, caption_box))
                        caption_info[original_text] = caption_box

                expanded_boxes = self.expand_image_boxes(
                    image_rects, text_blocks, captions, debug_page
                )
                # Pass a copy: expanded_boxes is drawn on the debug page below.
                clusters = self.merge_boxes(list(expanded_boxes), captions, caption_info)

                if debug_page is not None:
                    self.draw_boxes(debug_page, image_rects, (1, 0, 0))
                    self.draw_boxes(debug_page, expanded_boxes, (0, 1, 0))
                    self.draw_boxes(debug_page, [cap_box for _, cap_box in captions], (0, 0, 1))
                    self.draw_boxes(debug_page, [cl["box"] for cl in clusters], (0.5, 0, 0.5))

                for cluster in clusters:
                    box = cluster["box"]
                    caption = cluster["caption"]
                    if box.is_empty:
                        # A zero-area region has nothing to render.
                        continue

                    pix = page.get_pixmap(matrix=high_res_matrix, clip=box)
                    img_filename = f"page_{page_num}_box_{int(box.x0)}_{int(box.y0)}.png"
                    img_path = os.path.join(self.output_folder, img_filename)
                    pix.save(img_path)
                    self.extracted_images.append((img_path, caption))

                    # Best effort: a figure that cannot be placed is reported
                    # and skipped so the remaining figures still get compiled.
                    try:
                        # Only the size is needed; close the file right away.
                        with Image.open(img_path) as img:
                            img_width, img_height = img.size

                        display_width = 4 * inch
                        display_height = display_width * (img_height / img_width)
                        if display_height > max_display_height:
                            # Shrink very tall figures so they stay on the page.
                            shrink = max_display_height / display_height
                            display_width *= shrink
                            display_height = max_display_height

                        if y_position - display_height - caption_reserve < bottom_margin:
                            pdf_canvas.showPage()
                            y_position = page_height - top_margin

                        x_pos = (page_width - display_width) / 2
                        pdf_canvas.drawInlineImage(
                            img_path, x_pos, y_position - display_height,
                            width=display_width, height=display_height,
                        )

                        y_position -= (display_height + 15)
                        caption_text = caption if caption else "(No caption detected)"
                        caption_color = colors.black if caption else colors.darkgray
                        y_position = draw_wrapped_text(
                            pdf_canvas, caption_text, page_width / 2, y_position, max_text_width,
                            14, font="Helvetica-Oblique", font_size=11, color=caption_color
                        )
                        y_position -= 40
                    except Exception as e:
                        print(f"Error inserting image {img_path}: {e}")

                progress.value = page_num

                if debug_page is not None:
                    pix_dbg = debug_page.get_pixmap(matrix=fitz.Matrix(1.5, 1.5))
                    self.debug_pages.append(pix_dbg.tobytes("png"))

            pdf_canvas.save()
            print(f"✅ Images + captions PDF saved as: {self.output_pdf}")

            if self.debug_visualization:
                self.debug_doc.save(self.debug_pdf)
                print(f"✅ Debug visualization saved as: {self.debug_pdf}")
                self._display_debug_results()
            else:
                self._display_extraction_results()
        finally:
            # Release documents and scratch files on success and on failure.
            self.doc.close()
            if self.debug_visualization:
                self.debug_doc.close()
            shutil.rmtree(self.output_folder, ignore_errors=True)

    def _display_extraction_results(self):
        """Show every extracted image with its caption and a download button.

        The images are read from the paths in ``self.extracted_images`` and
        shown in an accordion with one section per image, titled
        "Image <n> of <total>" counting from 1.
        """
        from html import escape  # stdlib; imported here because the file does not import it

        def download_image(img_data, filename, b):
            # Serve the bytes captured at display time: the image files on
            # disk may already be gone when the button is clicked.
            b64 = base64.b64encode(img_data).decode()
            link = HTML(
                f'<a href="data:image/png;base64,{b64}" download="{filename}">'
                f"Download started...</a>"
            )
            display(link)

        items = []
        for img_path, caption in self.extracted_images:
            with open(img_path, "rb") as f:
                img_bytes = f.read()
            wimg = WImage(value=img_bytes, format="png")
            # Caption text comes from the PDF, so escape it before embedding in HTML.
            caption_markup = escape(caption) if caption else "(No caption detected)"
            cap_html = WHTML(f"<b>{caption_markup}</b>")
            btn = Button(description="⬇️ Download",
                         layout=widgets.Layout(width="120px"), button_style="info")

            btn.on_click(partial(download_image, img_bytes, os.path.basename(img_path)))
            items.append(VBox([wimg, cap_html, btn]))

        accordion = Accordion(children=items)
        for index in range(len(items)):
            accordion.set_title(index, f"Image {index + 1} of {len(items)}")
        display(accordion)

    def _display_debug_results(self):
        """Show the annotated debug pages, each with a download button.

        ``self.debug_pages`` holds PNG bytes. Entries are numbered from 1 in
        the order they were collected, and the same number is used for the
        accordion title and the download file name
        (``debug_page_<n>.png``).
        """

        def download_debug(img_bytes, number, b):
            img_name = f"debug_page_{number}.png"
            b64 = base64.b64encode(img_bytes).decode()
            link = HTML(
                f'<a href="data:image/png;base64,{b64}" download="{img_name}">'
                f"Download started...</a>"
            )
            display(link)

        items = []
        for number, dbg_img in enumerate(self.debug_pages, start=1):
            wimg = WImage(value=dbg_img, format="png")
            btn = Button(description="⬇️ Download",
                         layout=widgets.Layout(width="120px"), button_style="info")
            btn.on_click(partial(download_debug, dbg_img, number))
            items.append(VBox([wimg, btn]))

        accordion = Accordion(children=items)
        for index in range(len(items)):
            # Titles use the same 1-based number as the download file name.
            accordion.set_title(index, f"Debug Page {index + 1}")
        display(accordion)

## Step 5: Create UI and Process PDF

In [None]:
# Upload widget plus a checkbox that is read when processing starts.
upload_btn = widgets.FileUpload(accept=".pdf", multiple=False)
debug_checkbox = widgets.Checkbox(value=False, description="Enable debug visualization")
display(widgets.VBox([upload_btn, debug_checkbox]))

def process_uploaded_file(change):
    """Save the uploaded PDF, run the extraction and show a download link.

    Called whenever the upload widget's ``value`` changes. Nothing happens
    when the value is empty (for example after the widget is cleared).
    """
    uploads = upload_btn.value
    if not uploads:
        return

    # ipywidgets 7 exposes a dict keyed by file name; ipywidgets 8 exposes a
    # tuple of per-file records. Both kinds of record carry a "content" entry.
    if isinstance(uploads, dict):
        uploaded_file = next(iter(uploads.values()))
    else:
        uploaded_file = uploads[0]

    file_path = "uploaded_file.pdf"
    with open(file_path, "wb") as f:
        f.write(uploaded_file["content"])

    pdf_extractor = PDFBoxExtractor(file_path, debug_visualization=debug_checkbox.value)
    pdf_extractor.process_pdf()

    # Offer the compiled PDF as an inline data-URI download link.
    with open(pdf_extractor.output_pdf, "rb") as f:
        pdf_data = f.read()
    b64_pdf = base64.b64encode(pdf_data).decode("utf-8")
    download_link = HTML(
        f'<a href="data:application/pdf;base64,{b64_pdf}" '
        f'download="{pdf_extractor.output_pdf}" '
        f'style="padding: 0.5em 1em; background: #007bff; color: white; '
        f'border-radius: 3px; text-decoration: none;">'
        f'Download Images + Captions PDF</a>'
    )
    display(download_link)

upload_btn.observe(process_uploaded_file, names="value")