<a href="https://colab.research.google.com/github/Rachnas/vision-RAG/blob/main/5_vectordb_colpali_as_reranker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install Libraries

In [None]:
!pip install colpali-engine==0.3.2
!pip install pdf2image

In [None]:
!sudo apt-get install poppler-utils

### Load the Vision Language Model retriever (ColPali)

In [None]:
import torch
from colpali_engine.models import ColPali, ColPaliProcessor

# ColPali v1.3 checkpoint: the processor handles image/query preprocessing and
# scoring, the model produces the multi-vector embeddings (bfloat16, inference mode).
model_name = "vidore/colpali-v1.3"

processor = ColPaliProcessor.from_pretrained(model_name)

model = ColPali.from_pretrained(
    model_name,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    torch_dtype=torch.bfloat16,
)
model = model.eval()

### Process PDF
*   Define a ColPali wrapper class
*   Create image and query embeddings
*   Score pages with late interaction
*   Display the best-matching page





In [None]:
from pdf2image import convert_from_path
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

In [None]:
class colpali_class():
  """Wrap a ColPali model/processor pair for multi-vector page retrieval.

  Pages and queries are each encoded as a set of vectors. Pages are ranked
  against a query with the processor's late-interaction `score_multi_vector`.
  All methods use the processor/model passed to the constructor.
  """

  def __init__(self, processor, model):
    self.processor = processor
    self.model = model

  def embed_image(self, list_of_images):
    """Render every page of every PDF path and embed each page.

    Args:
      list_of_images: iterable of PDF file paths (rendered with pdf2image).

    Returns:
      list with one float32 CPU tensor per page, in file/page order.
    """
    dataset = []
    for pdf_path in list_of_images:
      pages = convert_from_path(pdf_path)
      # The collate_fn only preprocesses; the batch is moved to the model's
      # device once, inside the loop.
      dataloader = DataLoader(pages, batch_size=1, shuffle=False,
                              collate_fn=self.processor.process_images)
      for batch in tqdm(dataloader):
        with torch.no_grad():
          batch = {k: v.to(self.model.device) for k, v in batch.items()}
          embeddings = self.model(**batch)
        # Detach to CPU/float32 so the vectors can be stored and fed to numpy/FAISS.
        dataset.extend(torch.unbind(embeddings.to("cpu").to(torch.float32)))
    return dataset

  def embed_query(self, query):
    """Embed a list of query strings.

    Returns:
      list with one float32 CPU tensor per query (a multi-vector embedding).
    """
    batch_queries = self.processor.process_queries(query).to(self.model.device)
    with torch.no_grad():
      query_embeddings = self.model(**batch_queries)
    return list(torch.unbind(query_embeddings.to("cpu").to(torch.float32)))

  def score(self, query_embedding, dataset):
    """Late-interaction scores of the query embeddings against page embeddings.

    Returns:
      (scores, matched_pages): `scores` as a numpy array, and the indices of
      the flattened scores sorted best-first. With a single query these are
      indices into `dataset`.
    """
    scores = np.array(self.processor.score_multi_vector(query_embedding, dataset))
    matched_pages = scores.flatten().argsort()[::-1]
    return scores, matched_pages

In [None]:
# Wrap the loaded ColPali model and processor.
colpali_obj = colpali_class(processor, model)

In [None]:
# PDF to index and query.
file_name = "sample_data/AT&T_esg_doc.pdf"

In [None]:
# Embed every page of the PDF: one multi-vector tensor per page.
dataset = colpali_obj.embed_image([file_name])

In [None]:
# Sanity check: number of pages and the shape of one page embedding.
len(dataset), dataset[0].shape

In [None]:
# The processor embeds a batch (list) of queries; here the batch holds one question.
query = ['how much carbon reduction is expected in transportation?']
query_embeddings = colpali_obj.embed_query(query)

In [None]:
# One tensor per query; its rows are (presumably) per-token vectors.
len(query_embeddings), query_embeddings[0].shape

In [None]:
# Late-interaction scores against every page, plus page indices sorted best-first.
scores, matched_pages = colpali_obj.score(query_embeddings, dataset)

In [None]:
scores, matched_pages

In [None]:
# Render the PDF pages as PIL images (used for display and as input to the generator).
images = convert_from_path(file_name)

In [None]:
len(images)

In [None]:
import matplotlib.pyplot as plt

# Show the highest-scoring page of the full-document ColPali ranking.
best_page = images[matched_pages[0]]
fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(best_page)
ax.axis("off")
plt.show()

### Vector DB
- Initialize a FAISS vector DB


In [None]:
!pip install faiss-cpu
!pip install langchain_community
!pip install PyPDF2

In [None]:
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS

In [None]:
# Exact (brute-force) L2 index over 128-dimensional vectors; one row per page patch.
index = faiss.IndexFlatL2(128)

# NOTE(review): embed_query returns a list of multi-vector tensors, not one flat
# vector per text. Text-based search through this store is therefore likely to
# fail; only the *_by_vector search methods are used below.
vector_store = FAISS(
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
    embedding_function=colpali_obj.embed_query,
)

### Add embeddings to FAISS Vector DB
*   Create metadata
*   Add document to DB with metadata

In [None]:
!mkdir sample_data/data

In [None]:
import os

import PyPDF2
from PIL import Image

# Save a resized copy of every rendered page as its own file; all_images keeps
# the paths in page order. makedirs makes this cell safe to run on its own.
os.makedirs("sample_data/data", exist_ok=True)
all_images = []
for i, img in enumerate(images):
  page_name = "sample_data/data/"+str(i)+".pdf"
  img = img.resize((800,800),Image.LANCZOS)
  img.save(page_name)
  all_images.append(page_name)

all_embeddings = list(dataset)

# The rendered pages, the page embeddings and the PDF itself must agree on the
# page count. Otherwise the page numbers stored in the metadata would be wrong.
pdfReader = PyPDF2.PdfReader(file_name)
total_pages = len(pdfReader.pages)
if not (total_pages == len(all_images) == len(all_embeddings)):
  raise ValueError(
      f"{file_name}: page count mismatch (PDF: {total_pages}, "
      f"rendered: {len(all_images)}, embedded: {len(all_embeddings)})"
  )

# Flatten to one vector-store entry per patch, building the (text, vector)
# tuples and their metadata in the same loop so they stay aligned.
# Rows are appended page by page, so page p occupies a contiguous run of rows
# (get_doc_embeddings relies on this layout).
doc_stem = file_name.split("/")[-1].split(".")[0]
list_of_tuple = []
metadata = []
uids = []
for page_idx, (img_path, embd) in enumerate(zip(all_images, all_embeddings)):
  page_name = doc_stem+"_"+str(page_idx)+".pdf"
  for patch_num, patch_vec in enumerate(embd):
    uid = len(uids)
    list_of_tuple.append((img_path, patch_vec))
    metadata.append({"file_name":file_name, "page_name": page_name, "patch_num":patch_num,"uid":uid})
    uids.append(uid)

In [None]:
# Insert every (page-file path, patch vector) pair with its metadata, in page order.
# Row i then holds patch i % 1030 of page i // 1030, which get_doc_embeddings relies on.
# This assumes the index was empty beforehand (this call appends).
ids = vector_store.add_embeddings(text_embeddings= list_of_tuple, metadatas=metadata, ids = uids)

In [None]:
#index_to_docstore_id = vector_store.index_to_docstore_id
#uid_to_del = []
#for i in range(0,len(index_to_docstore_id)):
#  uid_to_del.append(vector_store.docstore._dict[index_to_docstore_id[i]].metadata['uid'])
#vector_store.delete(ids=uid_to_del)

### Do similarity search


*   Match each query token embedding with its nearest patches in the DB
*   Collect the page names of those patches and sort them by frequency of occurrence



In [None]:
# For every row of the query's multi-vector embedding, record the page names
# of its 3 nearest patches in the vector store.
page_name_list = []
for token_vec in query_embeddings[0]:
    hits = vector_store.similarity_search_by_vector(token_vec.tolist(), k=3)
    page_name_list.extend(hit.metadata['page_name'] for hit in hits)

In [None]:
from collections import Counter

# Vote: count how many token-level hits landed on each page, then rank pages by
# that count. most_common() breaks ties by first appearance, so the order is
# deterministic (iterating a set of strings is not, because of hash randomization).
page_dict = Counter(page_name_list)
sorted_page_dict = dict(page_dict.most_common())
top_pages = list(sorted_page_dict.keys())
top_pages

### Rerank vector DB output using colpali late interaction
*  Get page embeddings from DB
*  Call score function of colpali



In [None]:
def get_doc_embeddings(page_names_list, patches_per_page=1030, store=None):
  """Rebuild the multi-vector embedding of each named page from the FAISS index.

  Pages were added to the index in page order with `patches_per_page`
  consecutive vectors each. Page p therefore occupies rows
  [p * patches_per_page, (p + 1) * patches_per_page).

  Args:
    page_names_list: page names of the form "<doc>_<page>.pdf".
    patches_per_page: vectors stored per page (1030 for the embeddings
      indexed in this notebook).
    store: vector store whose `index` is read; defaults to the global
      `vector_store`.

  Returns:
    float tensor of shape (len(page_names_list), patches_per_page, dim).

  Raises:
    ValueError: if `page_names_list` is empty.
  """
  if store is None:
    store = vector_store
  doc_vectors = []
  for page_name in page_names_list:
    # "<doc>_3.pdf" -> 3
    page_num = int(page_name.split("_")[-1].split(".")[0])
    doc_vectors.append(
        store.index.reconstruct_n(page_num * patches_per_page, patches_per_page))

  if not doc_vectors:
    raise ValueError("page_names_list is empty; nothing to reconstruct")
  return torch.from_numpy(np.stack(doc_vectors))

In [None]:
# Rebuild the full multi-vector embedding of every candidate page from the index.
doc_embd = get_doc_embeddings(top_pages)
print(doc_embd.shape)

In [None]:
# Rerank the candidates with ColPali late interaction; matched_pages now indexes top_pages.
scores, matched_pages = colpali_obj.score(query_embeddings, doc_embd)

In [None]:
scores, matched_pages

In [None]:
# Best page name after reranking.
final_page_name = top_pages[matched_pages[0]]
final_page_name

In [None]:
import matplotlib.pyplot as plt

# Show the page chosen by the reranker; final_page_name has the form "<doc>_<page>.pdf".
best_page_idx = int(final_page_name.split("_")[-1].split(".")[0])

fig, ax = plt.subplots(figsize=(16, 12))
ax.imshow(images[best_page_idx])
ax.axis("off")
plt.show()

### Generation model
*   Model setup
*   Call model with reranker output



In [None]:
!pip install qwen-vl-utils==0.0.08

In [None]:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# Qwen2.5-VL 3B instruct: answers the question from the retrieved page image.
gen_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

# Let transformers choose the dtype and place the weights on the available device(s).
gen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    gen_model_id, device_map="auto", torch_dtype="auto"
)
# Default processor (chat template + image preprocessing) for the same checkpoint.
gen_processor = AutoProcessor.from_pretrained(gen_model_id)

In [None]:
# Saved single-page file paths and the page name chosen by the reranker.
all_images, final_page_name

In [None]:
def get_page_numbers(all_images, page_name):
  """Map a page name "<doc>_<n>.pdf" to (n as int, path of the saved page file).

  `all_images` is accepted for call compatibility but is not read.
  """
  page_token = page_name.rsplit("_", 1)[-1].split(".", 1)[0]
  saved_path = "sample_data/data/" + page_token + ".pdf"
  return int(page_token), saved_path

In [None]:
# Page index of the reranked best page, plus the path of its saved single-page file.
page_num,local_page_name = get_page_numbers(all_images, final_page_name)
print(page_num)

In [None]:
# Chat input for the generator: the reranked best page plus the user question.
# `query` is the one-element list given to ColPali, so pass its string. The
# list itself would be stringified into the prompt as "['...']".
messages = [
    {"role": "user",
     "content": [
         {"type": "image",
          "image": images[page_num],
          "resized_height": 800,
          "resized_width": 800,
         },
        {"type": "text", "text": query[0]}]},
]

In [None]:
# Preparation for inference: render the chat prompt and collect the vision inputs.
text = gen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

image_inputs, video_inputs = process_vision_info(messages)
inputs = gen_processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
# Follow wherever device_map="auto" placed the model instead of assuming CUDA exists.
inputs = inputs.to(gen_model.device)

In [None]:

# Inference: generate an answer, then keep only the newly generated tokens.
import torch

with torch.no_grad():
  generated_ids = gen_model.generate(**inputs, max_new_tokens=64)

# Each generated sequence starts with its prompt tokens; strip them before decoding.
generated_ids_trimmed = []
for prompt_ids, full_ids in zip(inputs.input_ids, generated_ids):
  generated_ids_trimmed.append(full_ids[len(prompt_ids):])

output_text = gen_processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
print(output_text)