In [16]:
# Dependencies: install these from a terminal before running the notebook:
#   pip install -U byaldi pdf2image qwen-vl-utils transformers

In [None]:
# brew install poppler
!pdftoppm -v
import subprocess
result = subprocess.run(['pdftoppm', '-v'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print(result.stdout.decode())
print(result.stderr.decode())


In [None]:
import os

# Product name -> local path of its assembly-manual PDF.
pdfs = {
    "MALM": os.path.join("data", "MALM.pdf"),
    "BILLY": os.path.join("data", "BILLY.pdf"),
    "BOAXEL": os.path.join("data", "BOAXEL.pdf"),
    "ADILS": os.path.join("data", "ADILS.pdf"),
    "MICKE": os.path.join("data", "MICKE.pdf"),
}

# Stop at the first missing manual; report every manual found before it.
for name, path in pdfs.items():
    if os.path.isfile(path):
        print(f"Found {name} at {path}")
    else:
        raise FileNotFoundError(f"{name} PDF not found at {path}")

# The verified paths can be handed straight to pdf2image (and later byaldi).
from pdf2image import convert_from_path

# Convert each PDF; the key is the PDF's 0-based position in `pdfs`.
all_images = {}
for doc_id, name in enumerate(pdfs):
    pdf_path = pdfs[name]
    pages = convert_from_path(pdf_path)
    all_images[doc_id] = pages
    print(f"{name}: {len(pages)} pages converted")

In [None]:
import os
from pdf2image import convert_from_path


def convert_pdfs_to_images(pdf_folder):
    """Convert every ``.pdf`` file found directly in ``pdf_folder``.

    Args:
        pdf_folder: Directory to scan with ``os.listdir``.

    Returns:
        A dict keyed by the 0-based position of each PDF in the filtered
        directory listing; each value is whatever ``convert_from_path``
        returns for that file. An empty dict when no name ends in ``.pdf``.
    """
    # The suffix match is case-sensitive, so e.g. "X.PDF" is ignored.
    pdf_names = filter(lambda entry: entry.endswith(".pdf"), os.listdir(pdf_folder))
    return {
        doc_id: convert_from_path(os.path.join(pdf_folder, pdf_name))
        for doc_id, pdf_name in enumerate(pdf_names)
    }


# Convert every PDF under data/ into page images, keyed by an integer doc_id.
# NOTE(review): the doc_ids follow os.listdir order; the retrieval index must
# number the same files in the same order for page lookups to match — confirm.
all_images = convert_pdfs_to_images("data/")

In [None]:
import matplotlib.pyplot as plt

# Preview at most the first 8 pages of the first document. Slicing keeps the
# cell from raising IndexError when that document has fewer than 8 pages.
preview_pages = all_images[0][:8]

# squeeze=False always yields a 2-D axes array, even for a single page.
fig, axes = plt.subplots(1, len(preview_pages), figsize=(15, 10), squeeze=False)

for ax, img in zip(axes.flat, preview_pages):
    ax.imshow(img)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import torch

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Running on", device)


In [None]:
import torch
from byaldi import RAGMultiModalModel

# Prefer CUDA, then Apple MPS, then CPU, so the retriever uses a GPU whenever
# one is available (matches the device cell earlier in the notebook).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Loading on", device)

docs_retrieval_model = RAGMultiModalModel.from_pretrained(
    "vidore/colpali-v1.2",
    device=device,
)

In [None]:
# Index everything under data/ as "image_index", replacing any index of the
# same name; the source documents are not stored alongside the index.
docs_retrieval_model.index(
    input_path="data/",
    index_name="image_index",
    store_collection_with_index=False,
    overwrite=True,
)

In [None]:
# Question to ask about the indexed manuals.
text_query = "How many people are needed to assemble the Malm?"

# Retrieve the 3 best-matching pages; the bare expression displays the hits.
results = docs_retrieval_model.search(text_query, k=3)
results

In [9]:
def get_grouped_images(results, all_images):
    """Look up the page image for every retrieval hit.

    Args:
        results: Iterable of hits, each indexable by ``"doc_id"`` and
            ``"page_num"``.
        all_images: Mapping of ``doc_id`` to that document's list of pages.

    Returns:
        A list with one page image per hit, in the order of ``results``.
    """
    # page_num is 1-indexed while doc_id is 0-indexed, hence the "- 1".
    # Source: https://github.com/AnswerDotAI/byaldi?tab=readme-ov-file#searching
    return [
        all_images[hit["doc_id"]][hit["page_num"] - 1]
        for hit in results
    ]


# Page images for the current search hits, in the order of `results`.
grouped_images = get_grouped_images(results, all_images)

In [None]:
import matplotlib.pyplot as plt

# One panel per retrieved page. Sizing the grid from `grouped_images` avoids
# an IndexError when the search returned fewer than three pages.
# squeeze=False always yields a 2-D axes array, even for a single page.
fig, axes = plt.subplots(1, len(grouped_images), figsize=(15, 10), squeeze=False)

for ax, img in zip(axes.flat, grouped_images):
    ax.imshow(img)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from qwen_vl_utils import process_vision_info
import torch

# Load the Qwen2-VL instruct checkpoint with bfloat16 weights.
vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
)

# Prefer CUDA, then Apple Silicon's MPS backend, then CPU, which is the same
# order the input-preparation cell uses, so model and inputs agree.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model to the chosen device and switch it to inference mode.
vl_model.to(device).eval()

In [None]:
# Lower and upper pixel bounds handed to the processor.
min_pixels = 224 * 224
max_pixels = 1024 * 1024

vl_model_processor = Qwen2VLProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

In [13]:
# A single user turn: one image entry per retrieved page, then the question.
# Building the entries from `grouped_images` (instead of indexing [0], [1],
# [2]) avoids an IndexError when the search returned fewer than three pages.
chat_template = [
    {
        "role": "user",
        "content": [
            *({"type": "image", "image": img} for img in grouped_images),
            {"type": "text", "text": text_query},
        ],
    }
]

In [14]:
# Render the chat messages with the processor's chat template; tokenize=False
# and add_generation_prompt=True are passed through unchanged.
text = vl_model_processor.apply_chat_template(
    chat_template,
    tokenize=False,
    add_generation_prompt=True,
)

In [None]:
import torch
from qwen_vl_utils import process_vision_info

# Device preference: CUDA first, then Apple's MPS backend, otherwise the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# process_vision_info returns the image inputs plus a second value that is
# not needed here.
image_inputs, _ = process_vision_info(chat_template)

# Combine the rendered prompt and the images into one batch of PyTorch tensors.
inputs = vl_model_processor(text=[text], images=image_inputs, padding=True, return_tensors="pt")

# Move every entry to the chosen device; the result is a plain dict.
inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

# The model has to sit on the same device as its inputs.
vl_model.to(device).eval()

In [23]:
# Generate a continuation of the prompt, capped at 500 new tokens.
generated_ids = vl_model.generate(
    **inputs,
    max_new_tokens=500,
)

In [24]:
# Each generated sequence starts with its prompt; keep only what follows it.
generated_ids_trimmed = [
    full_sequence[len(prompt_sequence):]
    for prompt_sequence, full_sequence in zip(inputs["input_ids"], generated_ids)
]

# Turn the remaining token ids back into text, dropping special tokens.
output_text = vl_model_processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

In [None]:
# Only one prompt was submitted, so the first decoded entry is the answer.
print(output_text[0])

In [26]:
import torch
from qwen_vl_utils import process_vision_info

def answer_with_multimodal_rag(
    vl_model,
    docs_retrieval_model,
    vl_model_processor,
    all_images,
    text_query,
    top_k=3,
    max_new_tokens=200,
    device=None,
):
    """Answer ``text_query`` from the pages the retriever ranks highest.

    Args:
        vl_model: Vision-language model; it is moved to ``device`` and put in
            eval mode.
        docs_retrieval_model: Retriever whose ``search(query, k=...)`` yields
            hits indexable by ``"doc_id"`` and ``"page_num"``.
        vl_model_processor: Processor used to render the chat template, build
            the model inputs and decode the generated ids.
        all_images: Mapping of ``doc_id`` to that document's list of pages.
        text_query: The question to answer.
        top_k: Number of pages to retrieve.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device to run on. When ``None`` (the default) it is chosen
            automatically: CUDA, then MPS, then CPU. Previously the function
            silently depended on a global ``device`` variable.

    Returns:
        The list returned by ``batch_decode``; a single prompt is submitted,
        so the answer is its first element.
    """
    # 1) pick the device unless the caller supplied one
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
    print(f"[INFO] running on {device}")

    # 2) ensure the VL model is on that device
    vl_model.to(device).eval()

    # 3) retrieve the top-K pages
    results = docs_retrieval_model.search(text_query, k=top_k)
    grouped_images = get_grouped_images(results, all_images)

    # 4) build the chat template: every retrieved page, then the question
    chat_template = [
        {
            "role": "user",
            "content": (
                [{"type": "image", "image": img} for img in grouped_images]
                + [{"type": "text",  "text": text_query}]
            )
        }
    ]

    # 5) tokenize text + process images
    text = vl_model_processor.apply_chat_template(
        chat_template,
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(chat_template)
    inputs = vl_model_processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt"
    )

    # 6) move all tensor inputs to the chosen device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # 7) generate and decode
    generated_ids = vl_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # each generated sequence starts with its prompt; strip those tokens
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output = vl_model_processor.batch_decode(
        trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output

In [None]:
# Run the whole retrieve-then-generate pipeline for a new question.
output_text = answer_with_multimodal_rag(
    vl_model=vl_model,
    docs_retrieval_model=docs_retrieval_model,
    vl_model_processor=vl_model_processor,
    all_images=all_images,                       # dict of page images, looked up by doc_id
    text_query="How do I assemble the Micke desk?",
    top_k=3,
    max_new_tokens=500,
)
# The function returns a list of decoded strings; show the first one.
print(output_text[0])