## 1. Document ingestion

We first download the paper and break it into manageable units. The code below fetches the PDF using `requests` and uses `fitz` (the import name of PyMuPDF) to iterate through pages. Each page’s text is split into fixed-size chunks of 500 characters so that it can be embedded; page numbers are stored zero-based. Images are also extracted and stored for optional multimodal retrieval.

> **Note**: If downloading fails due to network restrictions, manually place the PDF in the working directory and set `pdf_path` accordingly.


In [None]:

import os
import io
import requests
import fitz
from PIL import Image

# Download the Transformer paper
pdf_url = "https://arxiv.org/pdf/1706.03762.pdf"
pdf_path = "attention.pdf"

# Fetch the paper only if it isn’t already present
if not os.path.exists(pdf_path):
    response = requests.get(pdf_url, headers={"User-Agent": "Mozilla/5.0"})
    if response.status_code == 200:
        with open(pdf_path, "wb") as f:
            f.write(response.content)
    else:
        raise RuntimeError(f"Failed to download PDF (status {response.status_code}). Please download it manually and place it at {pdf_path}.")

# Parse the PDF
doc = fitz.open(pdf_path)
chunks = []  # list of {page, text}
images = []  # list of {page, image}
chunk_size = 500  # characters per chunk
for page_num in range(doc.page_count):
    page = doc[page_num]
    text = page.get_text().strip()
    # Create chunks of roughly `chunk_size` characters
    for i in range(0, len(text), chunk_size):
        chunk_text = text[i:i + chunk_size]
        chunks.append({"page": page_num, "text": chunk_text})
    # Extract images from the page
    for img_index, img_info in enumerate(page.get_images(full=True)):
        xref = img_info[0]
        base_image = doc.extract_image(xref)
        image_bytes = base_image["image"]
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        images.append({"page": page_num, "image": pil_image})

doc.close()
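
To make the chunking step concrete, here is a minimal sketch of the same slicing logic applied to a toy string. The helper name `chunk_text` is hypothetical and not part of the notebook; it only mirrors the inner loop above. Note that slices are cut at fixed character offsets, so a chunk can end mid-word, and the last chunk is usually shorter.

```python
def chunk_text(text, chunk_size=500):
    """Split text into consecutive, non-overlapping slices of at most chunk_size characters."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


# 30 characters standing in for the text of one page
sample = "abcdefghij" * 3
pieces = chunk_text(sample, chunk_size=12)

for n, piece in enumerate(pieces):
    # Print the chunk index, its length, and its content
    print(n, len(piece), repr(piece))
```

With a 30-character input and a chunk size of 12, this yields two full chunks and one shorter trailing chunk; concatenating the pieces gives back the original text.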


## 2. Embedding with Jina CLIP V1

[Jina CLIP V1](https://huggingface.co/jinaai/jina-clip-v1) is a multimodal embedding model that maps both text and images into the same vector space. It can be loaded through the Hugging Face `transformers` library by passing `trust_remote_code=True`.

The model’s remote code adds `encode_text` and `encode_image` methods, which the code below treats as returning NumPy arrays. We wrap these calls in helper functions so the rest of the pipeline remains agnostic to the underlying library.


In [None]:
import torch
from transformers import AutoModel
import numpy as np
from collections import defaultdict

# ---- Load Jina CLIP ----
clip_model = AutoModel.from_pretrained("jinaai/jina-clip-v1", trust_remote_code=True)
clip_model.eval()

USE_CUDA = torch.cuda.is_available()
USE_FP16 = False  # flip to True only if GPU hits memory limits

if USE_CUDA:
    clip_model = clip_model.to("cuda")
    if USE_FP16:
        clip_model = clip_model.half()

# Build inputs from Cell 1 outputs
texts = [c["text"] for c in chunks]  # from Cell 1
image_list = [im["image"] for im in images] if images else []

# ---- Build page→images map for later use (Vision model) ----
page_to_images = defaultdict(list)
for im in images:
    page_to_images[im["page"]].append(im["image"])

# ---- Simple batching helper ----
def batched(seq, n):
    for i in range(0, len(seq), n):
        yield seq[i:i+n]

# ---- Encode helpers (return numpy arrays) ----
@torch.no_grad()
def encode_text(text_list):
    # Jina-CLIP remote code returns numpy arrays
    return clip_model.encode_text(text_list)

@torch.no_grad()
def encode_images(pil_images):
    return clip_model.encode_image(pil_images)

# ---- Run encoders in batches ----
TEXT_BATCH = 64
IMG_BATCH = 16

text_embeds = []
for b in batched(texts, TEXT_BATCH):
    if b:  # skip empties
        text_embeds.append(encode_text(b))
text_embeddings = np.vstack(text_embeds) if text_embeds else np.zeros((0, 768), dtype=np.float32)

image_embeds = []
for b in batched(image_list, IMG_BATCH):
    if b:
        image_embeds.append(encode_images(b))
image_embeddings = np.vstack(image_embeds) if image_embeds else np.zeros((0, 768), dtype=np.float32)

# ---- L2-normalize for cosine similarity ----
def l2_normalize(a, eps=1e-12):
    if a.size == 0:
        return a
    norms = np.linalg.norm(a, axis=1, keepdims=True)
    return a / np.maximum(norms, eps)

text_embeddings = l2_normalize(text_embeddings)
image_embeddings = l2_normalize(image_embeddings)
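
The reason for the normalization step can be checked on two small vectors: once both are scaled to unit length, a plain dot product equals their cosine similarity. The sketch below uses only the standard library so it runs without the model; the helper names (`l2_normalize_vec`, `dot`, `cosine`) are illustrative and not part of the notebook.

```python
import math


def l2_normalize_vec(v, eps=1e-12):
    """Scale a vector to unit length; eps guards against division by zero."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Cosine similarity computed directly from the un-normalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a = [3.0, 4.0]
b = [4.0, 3.0]
ua, ub = l2_normalize_vec(a), l2_normalize_vec(b)

print(cosine(a, b))  # cosine from the raw vectors
print(dot(ua, ub))   # same value from a dot product of the unit vectors
```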



## 3. Storing embeddings in Chroma

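The [Chroma](https://docs.trychroma.com/) vector database is used to store our embeddings. Note that `chromadb.Client()` creates an in-memory instance, so the collection is lost when the kernel restarts; use `chromadb.PersistentClient(path=...)` if the index should survive on disk. Chroma performs approximate nearest-neighbour search to quickly retrieve the most relevant chunks. After creating a collection, we insert the text and image embeddings along with metadata. Each entry uses a unique ID so we can trace results back to the original page or chunk.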


In [None]:
import chromadb
from math import ceil


db_client = chromadb.Client()

collection = db_client.get_or_create_collection(name="transformer_paper")

# --- Prepare inputs ---
texts = [c["text"] for c in chunks]
text_ids = [f"text_{i}" for i in range(len(texts))]
text_metadatas = [{"page": c["page"], "type": "text", "idx": i} for i, c in enumerate(chunks)]


try:
    _ = text_embeddings
except NameError:
    text_embeddings = encode_text(texts)

# Convert to list-of-lists for Chroma
text_vectors = [row.tolist() for row in text_embeddings]

# --- Add in batches to avoid payload issues ---
BATCH = 512
for b in range(ceil(len(texts) / BATCH)):
    s = b * BATCH
    e = min((b + 1) * BATCH, len(texts))
    if s < e:
        collection.add(
            ids=text_ids[s:e],
            embeddings=text_vectors[s:e],
            metadatas=text_metadatas[s:e],
            documents=texts[s:e],
        )

if images:
    pil_images = [im["image"] for im in images]
    image_ids = [f"img_{i}" for i in range(len(images))]
    image_metadatas = [{"page": im["page"], "type": "image", "idx": i} for i, im in enumerate(images)]


    try:
        _ = image_embeddings
    except NameError:
        image_embeddings = encode_images(pil_images)

    image_vectors = [row.tolist() for row in image_embeddings]

    for b in range(ceil(len(pil_images) / BATCH)):
        s = b * BATCH
        e = min((b + 1) * BATCH, len(pil_images))
        if s < e:
            collection.add(
                ids=image_ids[s:e],
                embeddings=image_vectors[s:e],
                metadatas=image_metadatas[s:e],
                documents=[""] * (e - s),  # no text docs for images
            )
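
The embeddings were normalized for cosine similarity, but the collection above is created without naming a distance metric. Assuming it falls back to a squared-L2 distance (which I believe is Chroma's default, though this is worth checking against your installed version), the ranking is unchanged, because for unit vectors the squared distance equals `2 - 2 * cosine`. The sketch below checks this identity on made-up vectors using only the standard library; the IDs and values are illustrative.

```python
import math


def unit(v):
    """Return v scaled to unit length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def squared_l2(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


query = unit([1.0, 2.0, 2.0])
candidates = {
    "text_0": unit([1.0, 2.0, 2.1]),
    "text_1": unit([2.0, 1.0, 0.0]),
    "text_2": unit([-1.0, 0.5, 0.0]),
}

for name, vec in candidates.items():
    # The two printed numbers agree up to rounding
    print(name, round(squared_l2(query, vec), 4), round(2 - 2 * dot(query, vec), 4))

# Ascending distance and descending cosine give the same order
by_distance = sorted(candidates, key=lambda k: squared_l2(query, candidates[k]))
by_cosine = sorted(candidates, key=lambda k: -dot(query, candidates[k]))
print(by_distance == by_cosine)
```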



## 4. Retrieval

Retrieval is handled by querying the Chroma collection with a new embedding. Given a query string, we compute its embedding using Jina CLIP and request the top-`k` nearest neighbours. The function below returns the matching documents, metadata and distances.


In [None]:
import numpy as np

def _l2_normalize(v, eps=1e-12):
    if v.ndim == 1:
        n = np.linalg.norm(v) or eps
        return v / max(n, eps)
    n = np.linalg.norm(v, axis=1, keepdims=True)
    return v / np.maximum(n, eps)

def retrieve(query, top_k=3, where=None):
    """
    Retrieve most relevant text chunks for a query.
    - Normalizes query embedding to match index normalization.
    - Optional `where` filter; defaults to text-only.
    """
    if where is None:
        where = {"type": "text"}

    # Encode + normalize query
    q_emb = encode_text([query])[0]  # (768,)
    q_emb = _l2_normalize(q_emb)

    # Bound top_k by collection size
    count = collection.count()
    k = min(top_k, max(count, 0))

    if k == 0:
        return {"documents": [], "metadatas": [], "distances": []}

    res = collection.query(
        query_embeddings=[q_emb.tolist()],
        n_results=k,
        include=["documents", "metadatas", "distances"],
        where=where
    )
    return res

def show_results(res, max_chars=160):
    """Pretty-print results for quick sanity checks."""
    if not res or not res.get("documents"):
        return
    docs = res["documents"][0]
    metas = res["metadatas"][0]
    dists = res["distances"][0]
    for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists), 1):
        snippet = (doc[:max_chars] + "…") if doc and len(doc) > max_chars else (doc or "")
        page = meta.get("page", "?")
        # Show rank, source page, distance, and a short preview of the chunk
        print(f"{i}. page={page} distance={dist:.4f} | {snippet!r}")
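
The result dictionary is nested: each field holds one inner list per query, which is why the code above indexes `[0]`. As a rough illustration, the sketch below flattens such a dictionary into one row per hit. `flatten_hits` is a hypothetical helper, and `mock_res` is a hand-written stand-in for what `retrieve()` returns, with made-up values.

```python
def flatten_hits(res):
    """Turn a query result (one inner list per query) into a flat list of rows for the first query."""
    if not res or not res.get("documents"):
        return []
    docs = res["documents"][0]
    metas = res["metadatas"][0]
    dists = res["distances"][0]
    return [
        {"page": meta.get("page"), "distance": dist, "text": doc}
        for doc, meta, dist in zip(docs, metas, dists)
    ]


# Hand-written stand-in for a retrieve() result; texts and distances are placeholders
mock_res = {
    "documents": [["chunk A text", "chunk B text"]],
    "metadatas": [[{"page": 0, "type": "text", "idx": 0},
                   {"page": 5, "type": "text", "idx": 41}]],
    "distances": [[0.42, 0.57]],
}

rows = flatten_hits(mock_res)
for row in rows:
    print(row)
```

The empty dictionary that `retrieve()` returns for an empty collection flattens to an empty list, so callers do not need a special case.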


## 5. Generation with Phi-3 Vision

The [Phi-3 Vision](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) model is a multimodal generative model that takes image input together with a long text prompt. Images are referenced in the prompt through numbered placeholders such as `<|image_1|>`. It expects input in a special chat format:

```
<|user|>
<|image_1|>
{prompt}<|end|>
<|assistant|>

```

An optional multi-turn format can also be used, but for retrieval-augmented generation we only need a single turn. When loading the model with `transformers`, be sure to pass `trust_remote_code=True` and set `_attn_implementation="eager"` to disable flash attention, which is needed on CPU and on setups where flash attention is unavailable. A small blank image is supplied if no relevant image was retrieved.
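
For illustration, the single-turn layout shown above can be assembled with plain string formatting. This is only a sketch: `build_phi3_prompt` is a hypothetical helper, the notebook's own code relies on the tokenizer's chat template instead, and how several image placeholders are handled depends on the model's processor.

```python
def build_phi3_prompt(user_text, num_images=1):
    """Assemble a single-turn prompt with numbered image placeholders (numbering starts at 1)."""
    tags = [f"<|image_{i}|>" for i in range(1, num_images + 1)]
    # Placeholders come first, one per line, followed by the user text
    body = "\n".join(tags + [user_text])
    return f"<|user|>\n{body}<|end|>\n<|assistant|>\n"


prompt = build_phi3_prompt("What does the figure show?", num_images=1)
print(prompt)
```

Building the string by hand is mainly useful for inspecting what the model receives; in the pipeline itself, the chat template is the safer choice because it tracks the format the tokenizer expects.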


In [None]:
# === Phi-3 Vision: robust hybrid offload + real-image selection + generation (memory-hygiene) ===
import os, gc, torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image

# ---------------------------
# 0) Clean up prior models if cell re-run
# ---------------------------
if 'phi3_model' in globals():
    try:
        phi3_model.to('cpu')
    except Exception:
        pass
    try:
        del phi3_model
    except Exception:
        pass
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

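# Note: this setting is read when the CUDA allocator initializes, so it only
# takes effect if set before the first CUDA allocation in this process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"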

# ---------------------------
# 1) Load Phi-3 Vision (robust hybrid offload)
# ---------------------------
phi3_id   = "microsoft/Phi-3-vision-128k-instruct"
cuda_cap  = "12GiB"
off_dir   = "offload_phi3"
os.makedirs(off_dir, exist_ok=True)

def _clear_cuda():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Try multiple key styles for max_memory to satisfy different HF/Accelerate versions
def _try_load_phi3_with_keys():
    gpu_idx = 0
    try:
        if torch.cuda.is_available():
            gpu_idx = torch.cuda.current_device()
    except Exception:
        pass

    candidate_keys = [
        gpu_idx,                   # e.g., 0 (int)
        f"cuda:{gpu_idx}",         # e.g., "cuda:0"
        f"{gpu_idx}",              # e.g., "0"
        "cuda",                    # generic
    ]

    last_err = None
    for key in candidate_keys:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                phi3_id,
                trust_remote_code=True,
                torch_dtype=(torch.float16 if torch.cuda.is_available() else "auto"),
                device_map="auto",
                max_memory={key: cuda_cap, "cpu": "64GiB"},
                offload_folder=off_dir,
                low_cpu_mem_usage=True,
                _attn_implementation="eager",
            )
            return model
        except Exception as e:
            last_err = e
            _clear_cuda()  # free any partial allocations before next attempt

    # Fallback: no max_memory hint
    try:
        model = AutoModelForCausalLM.from_pretrained(
            phi3_id,
            trust_remote_code=True,
            torch_dtype=(torch.float16 if torch.cuda.is_available() else "auto"),
            device_map="auto",
            offload_folder=off_dir,
            low_cpu_mem_usage=True,
            _attn_implementation="eager",
        )
        return model
    except Exception as e:
        raise last_err or e

phi3_model = _try_load_phi3_with_keys()
phi3_model.eval()

phi3_processor = AutoProcessor.from_pretrained(phi3_id, trust_remote_code=True)
if getattr(phi3_processor.tokenizer, "pad_token_id", None) is None:
    phi3_processor.tokenizer.pad_token = phi3_processor.tokenizer.eos_token

blank_image = Image.new("RGB", (224, 224), color=(255, 255, 255))

# ---------------------------
# 2) Utilities
# ---------------------------
def _truncate_context(context: str, max_chars: int = 6000) -> str:
    return context[:max_chars]

def _l2_normalize(v, eps=1e-12):
    v = np.asarray(v, dtype=np.float32)
    n = np.linalg.norm(v)
    return v / max(n, eps)

# Ensure page->images map exists
try:
    page_to_images
except NameError:
    from collections import defaultdict
    page_to_images = defaultdict(list)
    try:
        for im in images:
            page_to_images[im["page"]].append(im["image"])
    except Exception:
        pass

# ---------------------------
# 3) Image selection
# ---------------------------
def select_images_for_query(query: str, top_k_images: int = 3, prefer_text_pages: bool = True):
    selected = []

    # Prefer images from pages of top text hits
    if prefer_text_pages:
        try:
            res = retrieve(query, top_k=5, where={"type": "text"})
            metas = (res or {}).get("metadatas") or []
            if metas and isinstance(metas[0], list):
                metas = metas[0]
            pages = [m.get("page") for m in metas if isinstance(m, dict) and "page" in m]
            for p in pages:
                for img in page_to_images.get(p, []):
                    if len(selected) < top_k_images:
                        selected.append(img)
        except Exception:
            pass

    # Fallback: cross-modal query into Chroma's image vectors
    if len(selected) < 1:
        try:
            q = encode_text([query])[0]
            q = _l2_normalize(q)
            res_img = collection.query(
                query_embeddings=[q.tolist()],
                n_results=top_k_images,
                include=["metadatas"],
                where={"type": "image"}
            )
            metas = (res_img or {}).get("metadatas") or []
            if metas and isinstance(metas[0], list):
                metas = metas[0]
            for m in metas:
                p = m.get("page")
                if p in page_to_images:
                    for img in page_to_images[p]:
                        if len(selected) < top_k_images:
                            selected.append(img)
        except Exception:
            pass

    return selected

# ---------------------------
# 4) Core: generate_answer (Vision-only) with self-cleanup
# ---------------------------
def generate_answer(
    query: str,
    top_k: int = 3,
    max_tokens: int = 200,
):
    """
    Vision-only Phi-3 answer generation with real PDF figures.
    Frees CUDA memory at the end of the call.
    """
    import gc as _gc
    inputs = None
    output_ids = None
    imgs = None
    try:
        # Retrieve supporting chunks
        hits = retrieve(query, top_k=top_k) or {}
        docs = hits.get("documents") or []

        # Flatten list-of-lists -> list[str]
        context_chunks = [d for lst in docs for d in (lst if isinstance(lst, list) else [lst]) if isinstance(d, str)]
        context = "\n\n".join(context_chunks) if context_chunks else "No relevant context."
        context = _truncate_context(context)

        # Select up to 3 relevant images
        imgs = select_images_for_query(query, top_k_images=3, prefer_text_pages=True)
        if not imgs:
            imgs = [blank_image]

        # Build chat with matching <|image_i|> tags
        image_tags = " ".join(f"<|image_{i+1}|>" for i in range(len(imgs)))
        messages = [{
            "role": "user",
            "content": f"{image_tags}\nContext:\n{context}\n\nQuestion: {query}\nAnswer based solely on the context."
        }]
        prompt = phi3_processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Tokenize + move to model device
        inputs = phi3_processor(text=prompt, images=imgs, return_tensors="pt")
        device = next(phi3_model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate (use_cache=False avoids DynamicCache issue)
        with torch.no_grad():
            output_ids = phi3_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,  # greedy decoding; temperature is unused in this mode, so it is not set
                eos_token_id=phi3_processor.tokenizer.eos_token_id,
                pad_token_id=phi3_processor.tokenizer.pad_token_id,
                use_cache=False,
            )

        # Strip prompt and decode
        output_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        response = phi3_processor.batch_decode(
            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0].strip()
        return response

    finally:
        # drop large tensors and clear cache each call
        try:
            del output_ids
        except Exception:
            pass
        try:
            if isinstance(inputs, dict):
                for k in list(inputs.keys()):
                    try: del inputs[k]
                    except Exception: pass
            del inputs
        except Exception:
            pass
        try:
            del imgs
        except Exception:
            pass
        _gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# (Optional) quick sanity info
#try:
#    from pprint import pprint
#except Exception:
#    pass


In [None]:
# Minimal smoke test
answer = generate_answer(
    "What is the main idea of the Transformer paper?",
    top_k=1,
    max_tokens=60
)
print(answer)


In [None]:
answer = generate_answer(
    "What is the purpose of positional encoding in the Transformer model?",
    top_k=2,
    max_tokens=80
)
print(answer)


In [None]:
# This prompt requires a GPU with higher VRAM (an A100 works well).
answer = generate_answer(
    "Describe the sinusoidal positional encoding figure and what the color bands mean.",
    top_k=2,
    max_tokens=120
)
print(answer)


## 6. Conclusion

This notebook demonstrates how to build a multimodal retrieval-augmented generation (RAG) pipeline from the ground up without relying on high-level frameworks. It parses a PDF, embeds text and figures using Jina CLIP, stores and retrieves vectors through ChromaDB, and finally uses Phi-3 Vision to generate grounded answers.
The notebook employs hybrid offloading and memory-efficient loading to handle large models on limited hardware. It runs smoothly on a T4 GPU for smaller prompts, but for complex or larger inputs, an A100 GPU is recommended. If you encounter resource or download issues, try pre-installing dependencies and ensuring adequate GPU or CPU memory.