# Imports

In [None]:
!pip install -q PyMuPDF>=1.24.0 tabula-py>=2.9.0 faiss-cpu>=1.8.0 numpy>=1.24.0 tqdm>=4.66.0 requests>=2.31.0 python-dotenv>=1.0.0 langchain-text-splitters

In [None]:
import tabula
import faiss
import json
import base64
import fitz # Changed from pymupdf
import requests
import os
import logging
import numpy as np
import warnings
from tqdm import tqdm
from langchain_text_splitters import RecursiveCharacterTextSplitter
from IPython import display

In [None]:
filename = "sample_paper1.pdf"
filepath = os.path.join("data", filename)

# Make sure ./data exists before later cells write extracted text/tables/images into it.
os.makedirs(os.path.dirname(filepath), exist_ok=True)

filepath  # display the resolved path

'data/sample_paper1.pdf'

In [None]:
!pip install pytesseract
import os
import base64
import pytesseract
from PIL import Image
import fitz
import uuid
 # Changed from pymupdf

Collecting pytesseract
  Downloading pytesseract-0.3.13-py3-none-any.whl.metadata (11 kB)
Downloading pytesseract-0.3.13-py3-none-any.whl (14 kB)
Installing collected packages: pytesseract
Successfully installed pytesseract-0.3.13


In [None]:
def _unique_id(prefix="item"):
    return f"{prefix}_{uuid.uuid4().hex[:10]}"

def _block_distance_to_bbox(block_bbox, img_bbox):
    bx0, by0, bx1, by1 = block_bbox
    ix0, iy0, ix1, iy1 = img_bbox
    bcx, bcy = (bx0 + bx1) / 2.0, (by0 + by1) / 2.0
    icx, icy = (ix0 + ix1) / 2.0, (iy0 + iy1) / 2.0
    return ((bcx - icx) ** 2 + (bcy - icy) ** 2) ** 0.5

def _horiz_overlap_fraction(block_bbox, img_bbox):
    """Return fraction of image width overlapped by block (0..1)."""
    bx0, by0, bx1, by1 = block_bbox
    ix0, iy0, ix1, iy1 = img_bbox
    inter_left = max(bx0, ix0)
    inter_right = min(bx1, ix1)
    inter_w = max(0.0, inter_right - inter_left)
    img_w = max(1e-6, ix1 - ix0)
    return inter_w / img_w

def _extract_text_blocks(blocks):
    """Collect ``(bbox, text)`` for every text block (type 0) of a page's get_text("dict") blocks."""
    text_blocks = []
    for block in blocks:
        if block.get("type", 0) != 0:
            continue
        bbox = tuple(block.get("bbox", (0, 0, 0, 0)))
        span_texts = (
            span.get("text", "").strip()
            for line in block.get("lines", [])
            for span in line.get("spans", [])
        )
        text_blocks.append((bbox, " ".join(t for t in span_texts if t)))
    return text_blocks


def _find_caption_below(text_blocks, img_bbox, caption_margin, min_horizontal_overlap_ratio):
    """Pick the caption of one image, or "" if no text block qualifies.

    A candidate must start at most ``caption_margin`` page units below the image's bottom edge
    and horizontally cover at least ``min_horizontal_overlap_ratio`` of the image width. Among
    candidates (nearest first) one starting with "Figure"/"Fig." wins; otherwise the nearest.
    """
    img_bottom = img_bbox[3]
    candidates = []
    for tbbox, ttext in text_blocks:
        gap = tbbox[1] - img_bottom  # vertical gap between image bottom and block top
        if 0 <= gap <= caption_margin and \
                _horiz_overlap_fraction(tbbox, img_bbox) >= min_horizontal_overlap_ratio:
            candidates.append((gap, ttext.strip()))
    candidates.sort(key=lambda cand: cand[0])
    for _, cand_text in candidates:
        if cand_text.lower().startswith(("figure", "fig.", "fig ")):
            return cand_text
    return candidates[0][1] if candidates else ""


def _render_image_block(page, img_bbox):
    """Rasterise the page area under an image block at 2x zoom; return None if rendering fails."""
    try:
        return page.get_pixmap(clip=fitz.Rect(img_bbox), matrix=fitz.Matrix(2, 2))
    except Exception:
        return None


def _ocr_image_file(image_path):
    """OCR a saved image with Tesseract. Best-effort: returns "" and emits a warning on failure."""
    try:
        with Image.open(image_path) as pil_img:
            return pytesseract.image_to_string(pil_img)
    except Exception as exc:
        # e.g. the tesseract binary is not installed; warn instead of silently returning "".
        warnings.warn(f"OCR failed for an extracted image ({type(exc).__name__}); ocr_text left empty.")
        return ""


def process_images(doc, page, page_num, base_dir, items,
                   caption_margin=24, nearby_margin=60, max_nearby_blocks=3,
                   enable_ocr=True, min_width_px=80, min_height_px=40,
                   min_horizontal_overlap_ratio=0.2, source_path=None):
    """Extract the images of one PDF page with caption + OCR metadata and append them to ``items``.

    Each image block of ``page.get_text("dict")`` is rasterised from its bounding box at 2x zoom,
    saved as PNG under ``<base_dir>/images/`` and recorded as
    ``{"page", "type": "image", "path", "image" (base64 PNG), "metadata"}``.

    Args:
        doc: The open PyMuPDF document (kept in the signature for callers; not needed here).
        page: The PyMuPDF page being processed.
        page_num: 0-based page index, used in file names and item records.
        base_dir: Output root directory.
        items: List that receives one dict per kept image.
        caption_margin: Max vertical gap below the image for a caption candidate.
        nearby_margin, max_nearby_blocks: Accepted for backward compatibility; currently unused
            (nearby-text extraction is disabled).
        enable_ocr: Run Tesseract OCR on each saved image.
        min_width_px, min_height_px: Rendered images smaller than this are skipped (logos/icons).
        min_horizontal_overlap_ratio: Minimum fraction of the image width a caption must cover.
        source_path: PDF path used to name the files; defaults to the notebook-level ``filepath``.
    """
    images_dir = os.path.join(base_dir, "images")
    os.makedirs(images_dir, exist_ok=True)
    pdf_name = os.path.basename(source_path if source_path is not None else filepath)

    blocks = page.get_text("dict").get("blocks", [])
    text_blocks = _extract_text_blocks(blocks)

    for block in blocks:
        if block.get("type", 0) != 1:  # only image blocks
            continue
        img_bbox = tuple(block.get("bbox", (0.0, 0.0, page.rect.width, page.rect.height)))

        # In dict mode the block's "image" entry holds raw image bytes, not an xref record,
        # so the image is obtained by rendering its bbox region of the page instead.
        pix = _render_image_block(page, img_bbox)
        if pix is None:
            continue

        w_px, h_px = pix.width, pix.height
        if w_px < min_width_px or h_px < min_height_px:
            continue  # small decorative image (logo / icon)

        image_name = os.path.join(images_dir, f"{pdf_name}_image_{page_num}_{uuid.uuid4().hex[:5]}.png")
        pix.save(image_name)
        with open(image_name, "rb") as fh:
            encoded_image = base64.b64encode(fh.read()).decode("utf8")

        caption = _find_caption_below(text_blocks, img_bbox, caption_margin, min_horizontal_overlap_ratio)
        ocr_text = _ocr_image_file(image_name) if enable_ocr else ""

        items.append({
            "page": page_num,
            "type": "image",
            "path": image_name,
            "image": encoded_image,
            "metadata": {
                "item_id": _unique_id("img"),
                "caption": caption,
                "ocr_text": ocr_text,
                "img_bbox": tuple(img_bbox),
                "pixel_size": (w_px, h_px),
            },
        })


In [None]:
# Create the directories
def create_directories(base_dir):
    """Create the ``images/``, ``text/`` and ``tables/`` output folders under ``base_dir``.

    Existing folders are left untouched, so calling this repeatedly is safe.
    """
    for subfolder in ("images", "text", "tables"):
        target = os.path.join(base_dir, subfolder)
        os.makedirs(target, exist_ok=True)

# Process tables
def _table_to_text(table):
    """Render a DataFrame as " | "-joined lines: header row (if it has real names), then data rows."""
    import pandas as pd  # tabula returns pandas DataFrames, so pandas is available here

    def _cell(value):
        # Missing cells become "" instead of the literal string "nan".
        return "" if pd.api.types.is_scalar(value) and pd.isna(value) else str(value)

    # tabula/pandas name header-less columns "Unnamed: N"; those carry no information.
    header = [col if isinstance(col, str) and not col.startswith("Unnamed:") else "" for col in table.columns]
    lines = [" | ".join(header)] if any(header) else []
    lines.extend(" | ".join(_cell(v) for v in row) for row in table.itertuples(index=False, name=None))
    return "\n".join(lines)


def process_tables(doc, page_num, base_dir, items, source_path=None):
    """Extract the tables of one PDF page with tabula, save them as text and record them in ``items``.

    Each table is rendered as " | "-separated lines (column header first, when present) and
    written to ``<base_dir>/tables/<pdf>_table_<page>_<idx>.txt``. Extraction errors are
    reported and the page is skipped, so one bad page does not abort the whole run.

    Args:
        doc: The open PyMuPDF document (kept in the signature for callers; tabula reads the file).
        page_num: 0-based page index (tabula pages are 1-based).
        base_dir: Output root directory.
        items: List that receives ``{"page", "type": "table", "text", "path"}`` per table.
        source_path: PDF to read; defaults to the notebook-level ``filepath``.
    """
    pdf_path = source_path if source_path is not None else filepath
    tables_dir = os.path.join(base_dir, "tables")
    try:
        tables = tabula.read_pdf(pdf_path, pages=page_num + 1, multiple_tables=True)
        if not tables:
            return
        os.makedirs(tables_dir, exist_ok=True)
        for table_idx, table in enumerate(tables):
            table_text = _table_to_text(table)
            table_file_name = os.path.join(
                tables_dir, f"{os.path.basename(pdf_path)}_table_{page_num}_{table_idx}.txt"
            )
            with open(table_file_name, "w", encoding="utf-8") as f:
                f.write(table_text)
            items.append({"page": page_num, "type": "table", "text": table_text, "path": table_file_name})
    except Exception as e:
        print(f"Error extracting tables from page {page_num}: {str(e)}")

# Process text chunks
def process_text_chunks(text, text_splitter, page_num, base_dir, items, source_path=None):
    """Split one page's text into chunks, save each chunk to disk and record it in ``items``.

    Args:
        text: Raw text of the page.
        text_splitter: Object with a ``split_text(str) -> list[str]`` method
            (here a LangChain RecursiveCharacterTextSplitter).
        page_num: 0-based page index, used in file names and item records.
        base_dir: Output root; chunks are written to ``<base_dir>/text/``.
        items: List that receives one ``{"page", "type": "text", "text", "path"}`` dict per chunk.
        source_path: PDF path used to name the files; defaults to the notebook-level ``filepath``.
    """
    text_dir = os.path.join(base_dir, "text")
    os.makedirs(text_dir, exist_ok=True)
    pdf_name = os.path.basename(source_path if source_path is not None else filepath)
    for i, chunk in enumerate(text_splitter.split_text(text)):
        text_file_name = os.path.join(text_dir, f"{pdf_name}_text_{page_num}_{i}.txt")
        # Explicit UTF-8: extracted PDF text often contains non-ASCII symbols (e.g. subscripts)
        # that the platform default encoding may be unable to write.
        with open(text_file_name, "w", encoding="utf-8") as f:
            f.write(chunk)
        items.append({"page": page_num, "type": "text", "text": chunk, "path": text_file_name})


In [None]:
import shutil
import os

# Remove tables extracted by a previous run so stale files do not linger.
# NOTE(review): despite its name, `image_dir` points at data/tables, not data/images.
image_dir = os.path.join("data", "tables")

if not os.path.exists(image_dir):
    print(f"Directory does not exist: {image_dir}")
else:
    shutil.rmtree(image_dir)
    print(f"Successfully deleted the directory: {image_dir}")


Successfully deleted the directory: data/tables


In [None]:
base_dir = "data"
doc = fitz.open(filepath)  # PyMuPDF document (formerly pymupdf.open); kept open for later use
num_pages = len(doc)

# Make sure data/images, data/text and data/tables exist before anything is written.
create_directories(base_dir)

# Chunks of at most 250 characters, with 50 characters shared between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50, length_function=len)
items = []

# For every page: tables (tabula), then text chunks, then images -- all appended to `items`.
for page_num, page in enumerate(tqdm(doc, total=num_pages, desc="Processing PDF pages")):
    text = page.get_text()
    process_tables(doc, page_num, base_dir, items)
    process_text_chunks(text, text_splitter, page_num, base_dir, items)
    process_images(doc, page, page_num, base_dir, items)  # doc is passed explicitly

Nov 24, 2025 3:34:15 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:18 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>
Nov 24, 2025 3:34:18 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:23 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:28 PM org.apache.pdfbox.pdmodel.font.PDSimpleFont toUnicode
Nov 24, 2025 3:34:28 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>
Nov 24, 2025 3:34:28 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:32 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:39 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>
Nov 24, 2025 3:34:39 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:44 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:50 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>

Nov 24, 2025 3:34:55 PM org.apache.pdfbox.pdmodel.font.PDType1Font <init>
Nov 24, 2025 3:34:56 PM or

In [None]:
import json

# Visual check of image extraction: path, metadata and the decoded image for every image item.
for item in items:
    if item["type"] != "image":
        continue
    print(f"Image Path: {item['path']}")
    print("Metadata:")
    print(json.dumps(item['metadata'], indent=2))
    png_bytes = base64.b64decode(item['image'])
    display.display(display.Image(png_bytes))
    print("---\n")

# Generating Multimodal Embeddings

In [None]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
# from langchain.chat_models import init_chat_model
# from langchain.prompts import PromptTemplate
# from langchain.schema.messages import HumanMessage
from sklearn.metrics.pairwise import cosine_similarity
import os
import base64
import io

In [None]:
### Initialize the CLIP model for unified text/image embeddings
# Model and processor must come from the same checkpoint; weights are downloaded on first use.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
# eval() switches to inference mode; the trailing ';' suppresses the very long model repr.
clip_model.eval();

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [None]:
# helper
from io import BytesIO

def _l2_normalize_np(x: np.ndarray, eps=1e-12):
    """Return ``x`` scaled to unit L2 norm; vectors with norm below ``eps`` are returned unchanged."""
    norm = np.linalg.norm(x)
    if norm < eps:
        return x
    return x / norm


def _decode_b64_image(image_b64):
    """Decode a base64 string into an RGB PIL image (the source file handle is closed)."""
    with Image.open(BytesIO(base64.b64decode(image_b64))) as img:
        return img.convert("RGB")


def _clip_features(extract, **processor_kwargs):
    """Run ``clip_processor`` then one CLIP feature extractor; return a flat float32 numpy vector."""
    inputs = clip_processor(return_tensors="pt", **processor_kwargs)
    # Follow the model's device: another cell may have moved the shared model to GPU.
    device = getattr(clip_model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        feats = extract(**inputs)
    return feats.squeeze().detach().cpu().numpy().astype(np.float32).reshape(-1)


def generate_multimodal_embeddings_clip(prompt=None, image=None, output_embedding_length=512,
                                        image_weight=0.1, text_weight=0.9, normalize=True):
    """Embed text, an image, or both with CLIP into one vector.

    Text is truncated by the tokenizer to CLIP's maximum context (77 positions for this
    checkpoint). When both inputs are given, each embedding is L2-normalised and they are
    blended as ``image_weight * image + text_weight * text`` (weights rescaled to sum to 1),
    then normalised again.

    All embeddings are unit length by default. This matters because the vectors share one
    L2 index: mixing unit-length combined vectors with raw CLIP features (norm ~10) made
    image items effectively unretrievable.

    Args:
        prompt (str): Text prompt; "" or None means "no text".
        image (str): Base64-encoded image; "" or None means "no image".
        output_embedding_length (int): Unused; CLIP ViT-B/32 always yields 512 dims.
            Kept for signature compatibility.
        image_weight (float): Weight of the image embedding in the combined vector.
        text_weight (float): Weight of the text embedding in the combined vector.
        normalize (bool): L2-normalise single-modality embeddings (combined ones always are).

    Returns:
        list[float]: The embedding.

    Raises:
        ValueError: If neither a prompt nor an image is given, or the weights sum to <= 0.
    """
    has_text = bool(prompt)
    has_image = bool(image)
    if not has_text and not has_image:
        raise ValueError("Please provide either a text prompt, base64 image, or both as input.")

    # Text and image are processed separately so that an empty prompt never leaks
    # text tensors into get_image_features (which does not accept them).
    text_vec = (
        _clip_features(clip_model.get_text_features, text=prompt, padding=True, truncation=True)
        if has_text else None
    )
    image_vec = (
        _clip_features(clip_model.get_image_features, images=_decode_b64_image(image))
        if has_image else None
    )

    if has_text and has_image:
        total = image_weight + text_weight
        if total <= 0:
            raise ValueError("image_weight + text_weight must be positive.")
        combined = (image_weight / total) * _l2_normalize_np(image_vec) \
            + (text_weight / total) * _l2_normalize_np(text_vec)
        return _l2_normalize_np(combined).tolist()

    vec = text_vec if has_text else image_vec
    return (_l2_normalize_np(vec) if normalize else vec).tolist()

In [None]:
import re
import unicodedata

def combine_image_metadata(metadata, max_length=250):
    """
    Merge an image's caption and OCR text into one cleaned string for embedding.

    Args:
        metadata (dict): may contain "caption" and "ocr_text"; any other key is ignored.
        max_length (int): upper bound on the returned length; truncation prefers the last
                          space before the limit so words are not cut.

    Returns:
        str: "<caption> . <ocr text>" after cleanup, or "" when neither field has content.
    """
    if not (metadata and isinstance(metadata, dict)):
        return ""

    # Caption goes first: it is the concise, human-written label.
    pieces = [
        value.strip()
        for value in (metadata.get("caption") or "", metadata.get("ocr_text") or "")
        if isinstance(value, str) and value.strip()
    ]
    if not pieces:
        return ""

    # NFKC first (ligatures / compatibility characters), then regex cleanup in a fixed order:
    # dashes, repeated dots/underscores/commas, control characters, whitespace runs.
    text = unicodedata.normalize("NFKC", " . ".join(pieces))
    cleanup_rules = (
        (r"[–—]{1,}", "-"),
        (r"\.{2,}", "."),
        (r"[_]{2,}", "_"),
        (r"[,]{2,}", ","),
        (r"[\x00-\x1f\x7f-\x9f]", " "),
        (r"\s+", " "),
    )
    for pattern, replacement in cleanup_rules:
        text = re.sub(pattern, replacement, text)

    # Trim stray edge punctuation, then collapse immediately repeated words ("Fig Fig" -> "Fig").
    text = text.strip(" \t\n\r\f\v-–—:;,.").strip()
    text = re.sub(r"\b(\w+)(?:\s+\1\b){1,}", r"\1", text, flags=re.IGNORECASE)

    if len(text) <= max_length:
        return text
    boundary = text.rfind(" ", 0, max_length)
    return text[:max_length] if boundary == -1 else text[:boundary]

In [None]:
# Count the number of each type of item
item_counts = {
    'text': sum(1 for item in items if item['type'] == 'text'),
    'table': sum(1 for item in items if item['type'] == 'table'),
    'image': sum(1 for item in items if item['type'] == 'image')
}

# Initialize counters
counters = dict.fromkeys(item_counts.keys(), 0)

# Generate embeddings for all items
with tqdm(
    total=len(items),
    desc="Generating embeddings",
    bar_format=(
        "{l_bar}{bar}| {n_fmt}/{total_fmt} "
        "[{elapsed}<{remaining}, {rate_fmt}{postfix}]"
    )
) as pbar:

    for item in items:
        item_type = item['type']
        counters[item_type] += 1

        if item_type in ['text', 'table']:
            # For text or table, use the formatted text representation
            if item_type == 'table' and not item['text']:
                # Skip generating embedding if table text is empty
                pbar.set_postfix_str(f"Text: {counters['text']}/{item_counts['text']}, Table: {counters['table']}/{item_counts['table']} (skipped empty), Image: {counters['image']}/{item_counts['image']}")
                pbar.update(1)
                continue
            item['embedding'] = generate_multimodal_embeddings_clip(prompt=item['text'])
        else:
            # For images, use the base64-encoded image data
            metadata_str = combine_image_metadata(item['metadata'])
            item['embedding'] = generate_multimodal_embeddings_clip(prompt = metadata_str, image = item['image'])

        # Update the progress bar
        pbar.set_postfix_str(f"Text: {counters['text']}/{item_counts['text']}, Table: {counters['table']}/{item_counts['table']}, Image: {counters['image']}/{item_counts['image']}")
        pbar.update(1)

Generating embeddings: 100%|██████████| 277/277 [00:40<00:00,  6.87it/s, Text: 249/249, Table: 10/10, Image: 18/18]


In [None]:
# Only items that received an embedding go into the index (empty tables were skipped).
# Row i of the index corresponds to items_with_embeddings[i].
items_with_embeddings = [item for item in items if 'embedding' in item]

# FAISS expects a C-contiguous float32 matrix of shape (n_items, dim).
all_embeddings = np.asarray([item['embedding'] for item in items_with_embeddings], dtype=np.float32)

# Derive the dimension from the data; 512 (CLIP ViT-B/32) is only the fallback when nothing was embedded.
embedding_vector_dimension = all_embeddings.shape[1] if all_embeddings.ndim == 2 else 512

# Exact (brute-force) L2 index; a freshly created index is already empty, so no reset() is needed.
index = faiss.IndexFlatL2(embedding_vector_dimension)
if all_embeddings.ndim == 2:
    index.add(all_embeddings)

In [None]:
# --- Save the Embeddings and Index to Disk ---
np.save('document_embeddings.npy', all_embeddings)
faiss.write_index(index, 'document_index.faiss')

print("Embeddings and FAISS index have been saved locally.")

Embeddings and FAISS index have been saved locally.


In [None]:
## Loading saved embeddings

# # --- Load the Embeddings and Index from Disk ---
# loaded_embeddings = np.load('document_embeddings.npy')
# loaded_index = faiss.read_index('document_index.faiss')

# print("Embeddings and FAISS index have been loaded.")
# print(f"Loaded embeddings shape: {loaded_embeddings.shape}")

In [None]:
# Count the number of embeddings for each type
embedding_counts = {
    'text': sum(1 for item in items_with_embeddings if item['type'] == 'text'),
    'table': sum(1 for item in items_with_embeddings if item['type'] == 'table'),
    'image': sum(1 for item in items_with_embeddings if item['type'] == 'image')
}

print("Total embeddings by type:")
print(f"Text embeddings: {embedding_counts['text']}")
print(f"Table embeddings: {embedding_counts['table']}")
print(f"Image embeddings: {embedding_counts['image']}")

Total embeddings by type:
Text embeddings: 249
Table embeddings: 10
Image embeddings: 18


In [None]:
def get_query_embedding_single(prompt=None, image_b64=None):
    """Embed a query (text and/or base64 image) and return it as a flat float32 numpy vector."""
    raw_embedding = generate_multimodal_embeddings_clip(prompt=prompt, image=image_b64)
    return np.array(raw_embedding, dtype=np.float32).ravel()

# ----------------------------
# 2) Content preview loaders
# ----------------------------
def load_text_snippet(item, max_chars=400):
    """
    Return a cleaned preview string (up to max_chars characters) for a retrieved item.

    Text comes from item["text"] for text items; otherwise from the file at item["path"].
    If that file is missing, unreadable or empty, it falls back to item["text"].
    Cleaning: NFKC normalisation, en/em dash -> "-", control characters removed, whitespace
    collapsed; truncation happens at the last space before max_chars when possible.
    """
    raw = ""
    if item.get("type") == "text" and item.get("text"):
        raw = item["text"]
    else:
        p = item.get("path")
        if p and os.path.exists(p):
            try:
                with open(p, "r", encoding="utf8", errors="ignore") as fh:
                    raw = fh.read()
            except Exception:
                raw = ""
        if not raw:
            raw = item.get("text", "") or ""

    if not raw:
        return ""

    # NFKC already expands typographic ligatures (U+FB01 'ﬁ' -> "fi", U+FB02 'ﬂ' -> "fl"),
    # so no manual ligature table is needed; it leaves en/em dashes alone, so map those here.
    txt = unicodedata.normalize("NFKC", raw)
    txt = txt.replace("\u2013", "-").replace("\u2014", "-")

    # Remove control characters, then collapse whitespace runs to single spaces
    txt = re.sub(r"[\x00-\x1f\x7f-\x9f]", " ", txt)
    txt = re.sub(r"\s+", " ", txt).strip()

    # Truncate without cutting a word
    if len(txt) <= max_chars:
        return txt
    cut = txt.rfind(" ", 0, max_chars)
    if cut == -1:
        return txt[:max_chars]
    return txt[:cut]

def load_table_preview(item, max_rows=5):
    """Return up to ``max_rows`` lines of a table item as a preview string.

    Sources, in order:
      * a ``.csv`` file (item["csv_path"] or item["path"]) -> parsed with pandas, re-emitted as CSV;
      * any other existing file (e.g. the " | "-separated .txt written by process_tables) ->
        its first ``max_rows`` raw lines (parsing those with read_csv would mangle them);
      * no file -> item["text_preview"] if set, else the first ``max_rows`` lines of item["text"].
    Returns "" when nothing is available or the file cannot be read.
    """
    max_rows = max(0, max_rows)
    table_path = item.get("csv_path") or item.get("path")
    if table_path and os.path.exists(table_path):
        if str(table_path).lower().endswith(".csv"):
            try:
                import pandas as pd
                df = pd.read_csv(table_path, nrows=max_rows, dtype=str)
                return df.fillna("").to_csv(index=False)
            except Exception:
                pass  # fall back to the raw-line preview below
        try:
            with open(table_path, "r", encoding="utf8", errors="ignore") as fh:
                return "".join(line for _, line in zip(range(max_rows), fh))
        except Exception:
            return ""

    preview = item.get("text_preview")
    if preview:
        return preview
    return "\n".join((item.get("text") or "").splitlines()[:max_rows])

def load_image_info(item):
    """Collect the saved image path plus its caption and OCR text (from item["metadata"])."""
    meta = item.get("metadata", {}) or {}
    info = {"image_path": item.get("path")}
    info["caption"] = meta.get("caption", "")
    info["ocr_text"] = meta.get("ocr_text", "")
    return info

# ----------------------------
# 3) Normalize & safety helpers
# ----------------------------
def _ensure_modality(item):
    """Return one of 'text','image','table' based on explicit type or content fallback."""
    t = item.get("type", "")
    if isinstance(t, str) and t.lower() in ("text","image","table"):
        return t.lower()
    # fallback heuristics
    if item.get("text") or item.get("text_preview"):
        return "text"
    if item.get("csv_path"):
        return "table"
    if item.get("path") and str(item.get("path")).lower().endswith((".png", ".jpg", ".jpeg", ".tiff", ".bmp")):
        return "image"
    # default fallback
    return "text"

# ----------------------------
# 4) Analyze one query (L2 index, single-vector embed)
# ----------------------------
def analyze_one_query(query, index, items_with_embeddings, k=10, image_b64=None):
    """
    Embed ``query`` (optionally with an image), search ``index`` for the top ``k`` neighbours
    and attach a content preview to each hit.

    Position i of the index must correspond to ``items_with_embeddings[i]``.

    Returns:
      {
        'query': str,
        'k': int,
        'counts': {'text':n, 'image':n, 'table':n, 'missing':n},
        'hits': [ {'rank':int,'idx':int,'dist':float,'type':str,'item':dict,'content':...}, ... ]
      }
    """
    query_vec = get_query_embedding_single(prompt=query, image_b64=image_b64)
    query_vec = query_vec.astype(np.float32).reshape(1, -1)
    dist_rows, idx_rows = index.search(query_vec, k)

    counts = {"text": 0, "image": 0, "table": 0, "missing": 0}
    preview_loaders = {"text": load_text_snippet, "table": load_table_preview}
    hits = []

    ranked = enumerate(zip(idx_rows[0].tolist(), dist_rows[0].tolist()), start=1)
    for rank, (idx, dist) in ranked:
        if int(idx) < 0:
            continue  # FAISS pads with -1 when fewer than k vectors are available
        try:
            item = items_with_embeddings[int(idx)]
        except Exception:
            counts["missing"] += 1
            hits.append({"rank": rank, "idx": idx, "dist": float(dist), "type": None, "item": None, "content": None})
            continue

        modality = _ensure_modality(item)
        counts[modality] = counts.get(modality, 0) + 1
        content = preview_loaders.get(modality, load_image_info)(item)

        hits.append({
            "rank": rank,
            "idx": int(idx),
            "dist": float(dist),
            "type": modality,
            "item": item,
            "content": content,
        })

    return {"query": query, "k": k, "counts": counts, "hits": hits}


In [None]:
query = "In the 3D surfaces showing how conversion and C₂ yield change with temperature and space velocity, at what combination of conditions does the system tend to achieve the highest methane conversion?"

analysis = analyze_one_query(query, index, items_with_embeddings, k=10)

# Print modality counts (text/image/table)
print("\n=== Retrieval Summary ===")
print("Query:", analysis["query"])
print("Top-k:", analysis["k"])
print("Counts:", analysis["counts"])

# Print each retrieved result. dist is the squared L2 distance from IndexFlatL2 (smaller = closer).
print("\n=== Top-k Retrieved Chunks ===\n")
for h in analysis["hits"]:
    print(f"Rank {h['rank']} | idx={h['idx']} | type={h['type']} | dist={h['dist']:.2f}")
    content = h['content']  # preview already built by analyze_one_query

    if h['type'] == "image":
        print("Preview (image):")
        print("Caption:", content["caption"])
        print("OCR:", (content["ocr_text"] or "")[:300])
        # display actual image from base64 (`display` is imported in the first imports cell)
        display.display(display.Image(data=base64.b64decode(h['item']['image'])))

    elif h['type'] == "text":
        print("Preview (text):")
        print(content)

    elif h['type'] == "table":
        print("Preview (table):")
        print(content)

    else:
        print("(index position has no matching item)")

    print("-" * 70)



=== Retrieval Summary ===
Query: In the 3D surfaces showing how conversion and C₂ yield change with temperature and space velocity, at what combination of conditions does the system tend to achieve the highest methane conversion?
Top-k: 10
Counts: {'text': 10, 'image': 0, 'table': 0, 'missing': 0}

=== Top-k Retrieved Chunks ===

Rank 1 | idx=89 | type=text | dist=26.26
Preview (text):
response surfaces based on the data shown in Fig. 4. We see that the highest methane conversions are found at the lowest space velocities combined with the highest temperature. Less obvious
----------------------------------------------------------------------
Rank 2 | idx=165 | type=text | dist=30.16
Preview (text):
descriptors for heterogeneous catalysts is sparse. Several groups published successful examples of the application of DFT simulations to describe reaction networks. Application of the d-band center as a descriptor is also referred to in many occa-
-------------------------------------------

In [None]:
import os, base64, numpy as np
from io import BytesIO
from PIL import Image
import torch

# ---------- helpers ----------
def _to_np(tensor_or_list):
    a = np.array(tensor_or_list, dtype=np.float32)
    return a.reshape(-1)

def _l2_norm_np(v):
    v = np.array(v, dtype=np.float32).reshape(-1)
    n = np.linalg.norm(v)
    if n < 1e-12:
        return v
    return v / (n + 1e-12)

def _squared_l2(a, b):
    a = np.array(a, dtype=np.float32).reshape(-1)
    b = np.array(b, dtype=np.float32).reshape(-1)
    diff = a - b
    return float(np.dot(diff, diff))

# ---------- compute image embedding (on-the-fly) ----------
def compute_image_embedding_for_item_onfly(item, device='cpu', force=False, normalize=True):
    """
    Compute the CLIP image embedding for `item` and cache it in item['image_embedding'].

    The image is read from item['image'] (base64-encoded bytes) when present and
    decodable, otherwise from item['path'] (an image file on disk).

    Args:
        item: dict describing one extracted element; mutated in place.
        device: torch device (str or torch.device). When it names a CUDA device
            and CUDA is available, the global `clip_model` is moved there.
        force: recompute even if item['image_embedding'] is already set.
        normalize: L2-normalize the result. This also applies to cached vectors,
            so the return value honours the flag regardless of how it was stored.

    Returns:
        1-D float32 numpy vector, or None if no image could be loaded.
    """
    # Reuse a cached embedding. Re-normalizing an already-unit vector is a no-op,
    # so this is safe whether or not the cached vector was stored normalized.
    if (not force) and item.get('image_embedding') is not None:
        cached = _to_np(item['image_embedding'])
        return _l2_norm_np(cached) if normalize else cached

    pil_img = None
    if item.get('image'):  # base64 stored
        try:
            raw = base64.b64decode(item['image'])
            pil_img = Image.open(BytesIO(raw)).convert("RGB")
        except Exception as e:
            print("base64->PIL failed:", e)
            pil_img = None

    if pil_img is None and item.get('path') and os.path.exists(item['path']):
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory image that outlives the `with` block.
            with Image.open(item['path']) as src:
                pil_img = src.convert("RGB")
        except Exception as e:
            print("path->PIL failed:", e)
            pil_img = None

    if pil_img is None:
        # nothing to embed
        return None

    clip_inputs = clip_processor(images=pil_img, return_tensors="pt")
    if torch.cuda.is_available() and str(device).startswith('cuda'):
        clip_model.to(device)
    # Send inputs to wherever the model weights actually live. A model moved to
    # GPU by an earlier call then still works when device='cpu' is passed here.
    model_device = next(clip_model.parameters()).device
    clip_inputs = {k: v.to(model_device) for k, v in clip_inputs.items()}

    with torch.no_grad():
        img_feats = clip_model.get_image_features(**clip_inputs).squeeze().cpu().numpy()

    if normalize:
        img_feats = _l2_norm_np(img_feats)
    item['image_embedding'] = img_feats.tolist()
    return img_feats

# ---------- compute text/caption embedding (on-the-fly) ----------
def compute_text_embedding_for_item_onfly(item, device='cpu', force=False, normalize=True):
    """
    Compute the CLIP text embedding for `item` and cache it in item['text_embedding'].

    Text source:
      * 'text' and 'table' items with a non-empty item['text'] use that text;
      * anything else (e.g. image items) uses metadata caption + OCR text.

    Args:
        item: dict describing one extracted element; mutated in place.
        device: torch device (str or torch.device). When it names a CUDA device
            and CUDA is available, the global `clip_model` is moved there.
        force: recompute even if item['text_embedding'] is already set.
        normalize: L2-normalize the result. This also applies to cached vectors.

    Returns:
        1-D float32 numpy vector, or None if there is no usable text.
    """
    # Reuse a cached embedding. Re-normalizing a unit vector is a no-op.
    if (not force) and item.get('text_embedding') is not None:
        cached = _to_np(item['text_embedding'])
        return _l2_norm_np(cached) if normalize else cached

    if item.get('type') in ('text', 'table') and item.get('text'):
        txt = item['text']
    else:
        md = item.get('metadata') or {}
        # `or ''` also covers keys that are present but explicitly None.
        caption = md.get('caption') or ''
        ocr_text = md.get('ocr_text') or ''
        txt = (caption + " " + ocr_text).strip()

    # Whitespace-only text carries no signal; treat it like missing text.
    if not txt.strip():
        return None

    # truncation=True cuts text beyond the tokenizer's max length, so long
    # chunks are embedded from their beginning only.
    clip_inputs = clip_processor(text=txt, return_tensors="pt", padding=True, truncation=True)
    if torch.cuda.is_available() and str(device).startswith('cuda'):
        clip_model.to(device)
    # Match the inputs to the model's actual device (it may already be on GPU).
    model_device = next(clip_model.parameters()).device
    clip_inputs = {k: v.to(model_device) for k, v in clip_inputs.items()}

    with torch.no_grad():
        txt_feats = clip_model.get_text_features(**clip_inputs).squeeze().cpu().numpy()

    if normalize:
        txt_feats = _l2_norm_np(txt_feats)
    item['text_embedding'] = txt_feats.tolist()
    return txt_feats

# ---------- compute fused embedding (if you used averaging) ----------
def compute_fused_embedding_for_item_onfly(item, image_weight=0.6, text_weight=0.4, normalize=True,
                                           device='cpu', force=False, store=True):
    """
    Build a fused embedding as a weighted average of the item's normalized CLIP
    image and text embeddings.

    If only one modality is available, that embedding is used on its own.

    Args:
        item: dict describing one extracted element. The per-modality embedding
            caches on it are filled in as a side effect.
        image_weight, text_weight: relative weights; they are rescaled to sum to 1.
        normalize: L2-normalize the fused vector (applied in every branch).
        device: forwarded to the per-modality embedding functions.
        force: forwarded; recompute per-modality embeddings even if cached.
        store: when True (the original behaviour), write the result to
            item['embedding']. Pass False to avoid overwriting an embedding
            that an index may already depend on.

    Returns:
        1-D float32 numpy vector, or None if neither modality is available.

    Raises:
        ValueError: if image_weight + text_weight == 0 and both modalities are present.
    """
    img = compute_image_embedding_for_item_onfly(item, device=device, force=force, normalize=True)
    txt = compute_text_embedding_for_item_onfly(item, device=device, force=force, normalize=True)
    if img is None and txt is None:
        return None
    if img is None:
        fused = txt
    elif txt is None:
        fused = img
    else:
        total = image_weight + text_weight
        if total == 0:
            raise ValueError("image_weight + text_weight must be non-zero")
        fused = (image_weight / total) * img + (text_weight / total) * txt
    if normalize:
        fused = _l2_norm_np(fused)
    if store:
        item['embedding'] = fused.tolist()
    return fused

# ---------- driver: compute and print squared L2 distances for given image indices ----------
def debug_image_item_distances(query, items, image_item_indices, use_normalize=True, squared=True):
    """
    Compare a text query against selected image items using three embeddings each.

    For every index in `image_item_indices`, computes the distance from the CLIP
    text embedding of `query` to:
      * the item's image embedding,
      * its caption/OCR text embedding,
      * its fused embedding.
    It prints the rows sorted by fused distance, closest first.

    Args:
        query: text query string.
        items: list of item dicts. Missing embeddings are computed and cached on them.
        image_item_indices: indices into `items` to inspect.
        use_normalize: L2-normalize query and item embeddings before comparing.
        squared: when True, report squared L2 (the metric FAISS IndexFlatL2
            returns); when False, report plain Euclidean L2.

    Returns:
        List of dicts with keys 'idx', 'path', 'caption_preview', 'd_img',
        'd_txt' and 'd_fused', sorted by 'd_fused'. A missing embedding yields
        float('inf') for that distance.
    """
    # query embedding (text only)
    q_emb = _to_np(generate_multimodal_embeddings_clip(prompt=query, image=None))
    if use_normalize:
        q_emb = _l2_norm_np(q_emb)

    def _dist(emb):
        # inf sorts missing embeddings to the end.
        if emb is None:
            return float('inf')
        d = _squared_l2(q_emb, emb)
        return d if squared else d ** 0.5

    results = []
    for idx in image_item_indices:
        it = items[idx]
        # compute on the fly if missing
        img_emb = compute_image_embedding_for_item_onfly(it, normalize=use_normalize)
        txt_emb = compute_text_embedding_for_item_onfly(it, normalize=use_normalize)
        if it.get('embedding') is not None:
            # NOTE(review): assumes the stored item['embedding'] lives in the same
            # space as the CLIP query embedding — confirm against the indexing code.
            fused_emb = _to_np(it['embedding'])
            if use_normalize:
                fused_emb = _l2_norm_np(fused_emb)
        else:
            fused_emb = compute_fused_embedding_for_item_onfly(it, normalize=use_normalize)

        results.append({
            'idx': idx,
            'path': it.get('path'),
            'caption_preview': ((it.get('metadata') or {}).get('caption','') or '')[:140],
            'd_img': _dist(img_emb),
            'd_txt': _dist(txt_emb),
            'd_fused': _dist(fused_emb)
        })

    # sort by fused distance ascending (closer first)
    results = sorted(results, key=lambda x: x['d_fused'])
    metric_label = "squared_L2" if squared else "L2"
    for r in results:
        print(f"Item idx={r['idx']} path={r['path']}")
        print(f"  {metric_label}: image={r['d_img']:.6f}  caption={r['d_txt']:.6f}  fused={r['d_fused']:.6f}")
        print("  caption preview:", r['caption_preview'])
        print("-"*60)

    return results


In [None]:
# choose a text query that should be answered from a figure
query = "The workflow diagram in the paper shows multiple coloured blocks arranged in a loop. Which key model-building tasks are highlighted as the “empirical” components in that workflow"

# inspect the first N image items in document order (not ranked by relevance)
max_images_to_inspect = 5
image_item_indices = [i for i, it in enumerate(items) if it.get('type') == 'image'][:max_images_to_inspect]

# run debug: prints squared L2 distances (image / caption / fused), closest fused first
debug_res = debug_image_item_distances(query, items, image_item_indices, use_normalize=True)


Item idx=51 path=data/images/sample_paper1.pdf_image_1_b02f1.png
  squared_L2: image=1.429907  caption=0.508454  fused=0.847324
  caption preview: Fig. 1 Flow diagram representing a typical work ﬂ ow for the devel- opment of (heterogeneous) catalysts.
------------------------------------------------------------
Item idx=76 path=data/images/sample_paper1.pdf_image_2_eeca5.png
  squared_L2: image=1.416431  caption=0.555809  fused=0.877407
  caption preview: Fig. 3 Modeling methods used in catalysis research. The empirical methods are highlighted in green on the right hand side of the diagram. Th
------------------------------------------------------------
Item idx=120 path=data/images/sample_paper1.pdf_image_4_3fcfe.png
  squared_L2: image=1.508568  caption=0.920199  fused=1.060468
  caption preview: Fig. 4 Raw performance data for the oxidative coupling of methane over a Mn-promoted Na 2 WO 4 /SiO 2 catalyst. The vertical axis denotes th
-------------------------------------------------