# Dataset Preparation

In [None]:
from glob import glob

# Collect the image files and the caption files from the dataset folder.
image_files, text_files = glob("data/social-posts/*.jpg"), glob("data/social-posts/*.txt")

# Show how many files of each kind were found.
print(len(image_files), len(text_files))

In [None]:
# Peek at the second entry of each listing to see what the paths look like.
print(*(files[1] for files in (image_files, text_files)))

In [6]:
# Build one record per post, pairing each caption text with its image path.
# Files are addressed by number (1.txt / 1.jpg, 2.txt / 2.jpg, ...), so the
# folder is expected to hold a contiguous 1..N series of both kinds.
documents = []

for i in range(1, len(image_files)+1):
    text_file  = f'data/social-posts/{i}.txt'
    image_file = f'data/social-posts/{i}.jpg'

    # The context manager closes the handle right away instead of leaving it
    # to the garbage collector; the explicit encoding keeps the read
    # independent of the platform's default locale.
    with open(text_file, encoding="utf-8") as caption_file:
        text = caption_file.read()

    doc = {"text": text, "image": image_file}
    documents.append(doc)

In [None]:
import random
from PIL import Image

# Choose a random position in the document list and preview that record.
num = random.randrange(len(documents))
picked = documents[num]

# Caption first ...
print(picked["text"])

# ... then the image stored at the record's path.
display(Image.open(picked["image"]))

# Embed Dataset

In [11]:
# !pip install fastembed

from fastembed import TextEmbedding, ImageEmbedding

class EmbedData:
    """Hold the documents together with a text and an image embedding model.

    Attributes set by the constructor:
        documents: The records passed in, stored unchanged.
        text_model / image_model: fastembed ``TextEmbedding`` / ``ImageEmbedding``
            instances built from the given model names.
        text_embed_dim / image_embed_dim: The ``"dim"`` entry of each model's
            description.

    The constructor does not create ``text_embeds`` / ``image_embeds``;
    ``QdrantVDB.upload_embeddings`` reads both, so the caller has to assign
    them after running the embed methods.
    """

    def __init__(self,
                 documents,
                 text_model_name="Qdrant/clip-ViT-B-32-text",
                 image_model_name="Qdrant/clip-ViT-B-32-vision"):
        self.documents = documents

        # Text side: load the model, then look up its embedding size.
        # NOTE(review): `_get_model_description` is an underscore-prefixed
        # (private) fastembed helper and may change between releases — confirm
        # against the installed version.
        self.text_model = TextEmbedding(model_name=text_model_name)
        text_description = self.text_model._get_model_description(text_model_name)
        self.text_embed_dim = text_description["dim"]

        # Image side: same two steps with the vision model.
        self.image_model = ImageEmbedding(model_name=image_model_name)
        image_description = self.image_model._get_model_description(image_model_name)
        self.image_embed_dim = image_description["dim"]

    def embed_texts(self, texts):
        """Run the text model over ``texts`` and return the embeddings as a list."""
        return [embedding for embedding in self.text_model.embed(texts)]

    def embed_images(self, images):
        """Run the image model over ``images`` and return the embeddings as a list."""
        return [embedding for embedding in self.image_model.embed(images)]

In [None]:
embeddata = EmbedData(documents)

embeddata.text_embeds  = embeddata.embed_texts([doc["text"] for doc in documents])

embeddata.image_embeds = embeddata.embed_images([doc["image"] for doc in documents])

In [None]:
import numpy as np

print(np.array(embeddata.text_embeds))

# Define vector store

In [14]:
from qdrant_client import QdrantClient, models

class QdrantVDB:
    """Small wrapper around one Qdrant collection with "image" and "text" vectors."""

    def __init__(self,
                 collection_name,
                 image_dim,
                 text_dim,
                 url="http://localhost:6333"):
        """Store the collection settings and open a gRPC-preferring client.

        Args:
            collection_name: Name of the collection to create / fill.
            image_dim: Size of the "image" named vector.
            text_dim: Size of the "text" named vector.
            url: Address of the Qdrant server.
        """
        self.collection_name = collection_name
        self.image_dim = image_dim
        self.text_dim = text_dim
        self.client = QdrantClient(url=url, prefer_grpc=True)

    def create_collection(self):
        """Create the collection unless it already exists; report either outcome."""
        name = self.collection_name

        # Nothing to do when the collection is already there.
        if self.client.collection_exists(name):
            print(f"Collection '{name}' already exists.")
            return

        print(f"Creating collection '{name}'...")

        # Two named vectors, both compared with cosine distance.
        image_params = models.VectorParams(size=self.image_dim,
                                           distance=models.Distance.COSINE)
        text_params = models.VectorParams(size=self.text_dim,
                                          distance=models.Distance.COSINE)
        self.client.create_collection(
            collection_name=name,
            vectors_config={"image": image_params, "text": text_params},
        )

        print(f"Collection '{name}' created successfully.")

    def upload_embeddings(self, embeddata):
        """Upload one point per document in ``embeddata``.

        Each point uses the document's position as its id, carries the
        matching entries of ``embeddata.text_embeds`` and
        ``embeddata.image_embeds`` as named vectors, and stores the document
        itself as payload.
        """
        print(f"Uploading points to collection '{self.collection_name}'...")

        points = [
            models.PointStruct(
                id=position,
                vector={
                    "text": embeddata.text_embeds[position],
                    "image": embeddata.image_embeds[position],
                },
                payload=document,
            )
            for position, document in enumerate(embeddata.documents)
        ]

        self.client.upload_points(collection_name=self.collection_name, points=points)

        print(f"Uploaded {len(points)} points to collection '{self.collection_name}'.")

In [None]:
# Point the vector store at the collection, sized to match the embedding models.
vector_db = QdrantVDB(
    collection_name="linkedin-posts",
    image_dim=embeddata.image_embed_dim,
    text_dim=embeddata.text_embed_dim,
)

# Make sure the collection exists, then push every document's vectors into it.
vector_db.create_collection()
vector_db.upload_embeddings(embeddata)

# Define Retriever class

In [19]:
class Retriever:
    """Look up stored posts for a text query."""

    def __init__(self, vector_db, embeddata):
        """Keep the vector store wrapper and the embedding helper for later searches."""
        self.vector_db, self.embeddata = vector_db, embeddata

    def search(self, query, limit=3):
        """Embed ``query`` as text and search the collection's "image" vector.

        Args:
            query: The text to embed.
            limit: Maximum number of hits to ask for.

        Returns:
            Whatever the client's ``search`` call returns; the request asks
            for the "image" and "text" payload fields.
        """
        # Only the first embedding produced for the query is used.
        embeddings = list(self.embeddata.embed_texts(query))
        query_embedding = embeddings[0]

        store = self.vector_db
        return store.client.search(
            collection_name=store.collection_name,
            query_vector=("image", query_embedding),
            with_payload=["image", "text"],
            limit=limit,
        )

In [None]:
# Try the retriever on a sample question and show the single best match.
query = "What are some examples of Graph-based clustering algorithms?"

result = Retriever(vector_db, embeddata).search(query, limit=1)

for hit in result:
    payload = hit.payload
    print(payload["text"])
    display(Image.open(payload["image"]))

# RAG Class

In [27]:
import ollama

class RAG:
    """Answer questions by retrieving posts and prompting an Ollama chat model.

    The captions of the retrieved posts become the textual context of the
    prompt, and the image of the first hit is attached to the chat message.
    """

    def __init__(self,
                 retriever,
                 llm_name="llama3.2-vision"):
        """Store the collaborators and the prompt template.

        Args:
            retriever: Object exposing ``search(query)`` that returns a
                sequence of hits, each with a ``payload`` mapping holding
                "text" and "image".
            llm_name: Model name passed to ``ollama.chat``.
        """
        self.llm_name = llm_name
        self.retriever = retriever
        # Filled in with ``.format(context=..., query=...)``.
        self.qa_prompt_tmpl_str = """Context information is below.
                                     ---------------------
                                     {context}
                                     ---------------------

                                     Some images may also be available to you
                                     for answering the question better. You have
                                     to understand those images thoroughly and
                                     extract all relevant information that might
                                     help you answer the query better.

                                     ---------------------

                                     Given the context information above I want you
                                     to think step by step to answer the query in a
                                     crisp manner, in case you don't know the
                                     answer say 'I don't know!'

                                     ---------------------

                                     Query: {query}

                                     ---------------------
                                     Answer: """

    def generate_context(self, query):
        """Retrieve posts for ``query`` and join their captions.

        Returns:
            tuple: ``(context, result)`` where ``context`` is the hits'
            "text" payloads joined by a ``---`` separator line (an empty
            string when there are no hits) and ``result`` is the retriever's
            return value, unchanged.
        """
        result = self.retriever.search(query)
        captions = [hit.payload["text"] for hit in result]

        return "\n\n---\n\n".join(captions), result

    def _build_messages(self, query):
        """Build the chat message list for ``query`` (prompt plus optional image)."""
        context, result = self.generate_context(query=query)

        prompt = self.qa_prompt_tmpl_str.format(context=context,
                                                query=query)

        message = {"role": "user", "content": prompt}

        # Attach the first hit's image only when something was retrieved;
        # indexing an empty result would raise IndexError.
        if result:
            message["images"] = [result[0].payload['image']]

        return [message]

    def query(self, query):
        """Ask the model to answer ``query`` using the retrieved context.

        Returns:
            The value returned by ``ollama.chat``.
        """
        messages = self._build_messages(query)

        response = ollama.chat(model=self.llm_name, messages=messages)

        return response

# Using RAG

In [24]:
# Wire the pipeline together: the retriever searches the vector store and the
# RAG object feeds what it finds to the model.
retriever = Retriever(vector_db=vector_db, embeddata=embeddata)
rag = RAG(retriever=retriever)

In [None]:
# Ask the pipeline a question about clustering algorithms.
query = """What are some examples of
           Graph-based clustering algorithms?"""

response = rag.query(query)

# Show the reply; without this the cell computes an answer and displays nothing.
print(response["message"]["content"])

In [None]:
# Ask the pipeline a question about Python performance.
query = """Are there any ways to
           speed up native Python code?"""

response = rag.query(query)

# Show the reply; without this the cell computes an answer and displays nothing.
print(response["message"]["content"])

In [None]:
# Ask the pipeline for a worked explanation of the kernel trick.
query = """What is the mathematics behind the
           kernel trick? Show me a step-by-step
           explanation with the polynomial
           kernel and two 2D vectors."""

response = rag.query(query)

# Show the reply; without this the cell computes an answer and displays nothing.
print(response["message"]["content"])