In [1]:
import os
import gradio as gr
from dotenv import load_dotenv
from langchain_experimental.text_splitter import SemanticChunker
from langchain_cohere import CohereEmbeddings
from langchain_community.document_loaders import PyPDFium2Loader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from dspy.retrieve.qdrant_rm import QdrantRM
import dspy
from dspy.retrieve.qdrant_rm import QdrantRM
from qdrant_client import QdrantClient

# Load variables from a .env file into the process environment; later cells read
# COHERE_API_KEY from os.environ.
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# Cohere embedding model shared by the semantic chunker, document indexing and
# query embedding further down.
embeddings = CohereEmbeddings(
    model="embed-multilingual-light-v3.0",
    cohere_api_key=os.environ["COHERE_API_KEY"],
)

In [3]:
# Recursive character splitter (chunk_size=2048, chunk_overlap=128); it is applied
# to the output of the semantic chunker when the documents are loaded.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=2048,
    chunk_overlap=128,
    add_start_index=True,
)

In [4]:
# Embedding-driven chunker configured with the "interquartile" breakpoint threshold type.
semantic_splitter = SemanticChunker(
    embeddings=embeddings,
    breakpoint_threshold_type="interquartile",
)

In [5]:
# Load the PDF, merge its pages into a single text, then chunk in two passes:
# semantic chunking first, recursive character splitting second.
loaded_documents = PyPDFium2Loader("./data/saudi_vision2030_ar.pdf").load()
document_text = "\n".join(page.page_content for page in loaded_documents)
documents = list(
    text_splitter.split_documents(
        semantic_splitter.create_documents([document_text])
    )
)

# chunks, doc_id and vectors are positionally aligned; ids are 1-based.
chunks = [document.page_content for document in documents]
doc_id = [position for position, _ in enumerate(documents, start=1)]
vectors = embeddings.embed_documents(chunks)




In [6]:
# Qdrant client opened on the special ":memory:" location.
client = QdrantClient(location=":memory:")

In [7]:
# (Re)create the "data" collection so that re-running this cell starts from an
# empty index instead of failing on an already-existing collection.
client.delete_collection(collection_name="data")
client.create_collection(
    collection_name="data",
    vectors_config=VectorParams(
        # Size the collection from the embeddings actually produced; 384 is the
        # fallback used when there is nothing to index.
        size=len(vectors[0]) if vectors else 384,
        distance=Distance.COSINE,
    ),
)

# Upload the vectors together with their source text. Without a payload the
# collection holds bare vectors, and a search hit cannot be turned back into a passage.
# NOTE(review): "document" is assumed to be the payload field QdrantRM reads by
# default — confirm against the installed dspy version.
client.upload_collection(
    collection_name="data",
    ids=doc_id,
    vectors=vectors,
    payload=[{"document": chunk} for chunk in chunks],
)

In [8]:
# dspy retrieval model bound to the "data" collection of the Qdrant client (k=3).
retriever_model = QdrantRM(
    qdrant_collection_name="data",
    qdrant_client=client,
    k=3,
)

100%|██████████| 77.7M/77.7M [00:06<00:00, 12.3MiB/s]


In [9]:
# Cohere "command-r-plus" language model, keyed by COHERE_API_KEY from the environment.
lm = dspy.Cohere(
    model="command-r-plus",
    api_key=os.environ["COHERE_API_KEY"],
)

In [10]:
# Register the language model and the retrieval model in dspy's global settings.
dspy.settings.configure(rm=retriever_model, lm=lm)

In [11]:
# Retrieve the text of the stored chunks that best match a query.
def get_context(text, limit=3):
    """Return the chunk texts most similar to ``text``.

    The query is embedded with the module-level ``embeddings`` model and searched
    against the "data" collection of the module-level Qdrant ``client``.

    Args:
        text: Query string to embed and search for.
        limit: Maximum number of hits to request (defaults to 3).

    Returns:
        A list of chunk strings, in the order the search returns the hits.
    """
    query_vector = embeddings.embed_query(text)
    hits = client.search(
        collection_name="data", query_vector=query_vector, limit=limit
    )
    # Points are uploaded with ids 1..len(chunks), so a hit id maps back to its
    # source text at chunks[id - 1].
    return [chunks[hit.id - 1] for hit in hits]

In [12]:
# Signature for the answer generator: takes `context` and `question`, produces `answer`.
# NOTE(review): dspy appears to use a Signature's docstring as the task instructions,
# so the docstring below should be treated as runtime prompt text rather than
# documentation — confirm before rewording it.
# NOTE(review): the docstring is a plain string (not an f-string), so the doubled
# braces `{{` / `}}` stay doubled in the text; presumably left over from a format
# template — TODO confirm whether single braces were intended.
class GenerateAnswerWithContext(dspy.Signature):
    """You are a research assistant. Use the provided document snippets to
        answer the query. Format your response with citations in structured JSON format:
        <response format>
        {{
        "response":"Your response here.",
        "citations":[
            {{
                "title":"Document Title",
                "snippet":"Exact snippet from the document"
            }}]
        }}
        </response format>

    IMPORTANT CITATION RULES:
    1. Each citation MUST be a complete sentence or phrase from the original text.
    2. Citations MUST be VERBATIM and EXACT quotes from the provided documents.
    3. DO NOT use ellipses (...) or any other shortening techniques in citations.
    4. DO NOT paraphrase or modify the original text in any way for citations.
    5. if you need  to use multiple sentences in citations, include them in full.
    6. USE MULTIPLE citations when necessary to fully support your response.
    7. Ensure that each citation DIRECTLY supports a specific part of your response.
    8. if you cannot find relevant information in the provided documents, state this clearly in your response.

    Example of a correct response with multiple, relevant citations:
    {{
    "response": "The mockingbird has unique vocal abilities [1]. It uses these abilities for various purposes, including defending its territory [2].",
    "citations": [
        {{
            "title": "Mockingbird Study",
            "snippet": "The mockingbird is known for its ability to mimic the calls of other birds even mechanical sounds."
        }},
        {{
            "title": "Mockingbird Behavior",
            "snippet": "Mockingbirds use their diverse vocal repertoire to defend their territories from intruders and to attract mates."
        }}
    ]
    }}

    Remember:
    1. Citations must be EXACT, COMPLETE sentences or phrases from the provided text.
    2. Do not modify, shorten, or paraphrase the original text in your citations.
    3. Use multiple citations when necessary to fully support your response.
    4. Ensure each citation is directly relevant to the part of your response that it supports.
    5. if you cannot find relevant information in the provided documents, clearly state this in your response.

    Now, please answer the given query using the provided information and following these guidelines.
    """

    # Input: the document snippets the answer must be based on.
    context = dspy.InputField(
        desc="Use the provided document snippets to answer the query"
    )
    # Input: the user's question.
    question = dspy.InputField()
    # Output: the answer, requested as structured JSON with citations, in Arabic.
    answer = dspy.OutputField(
        desc="Format your response with citations in structured JSON format in arabic language"
    )

In [13]:
# dspy module that answers a question from retrieved context.
class RAG(dspy.Module):
    """Fetch context for a question, then generate an answer with chain-of-thought."""

    def __init__(self, num_passages=3):
        """Create the retrieve and generate sub-modules.

        Args:
            num_passages: ``k`` handed to ``dspy.Retrieve``.
        """
        super().__init__()
        # NOTE: forward() obtains its passages from get_context() rather than from
        # self.retrieve, so num_passages does not influence how many are fetched there.
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswerWithContext)

    def forward(self, question):
        """Answer ``question``.

        Returns:
            A ``dspy.Prediction`` carrying the ``context`` that was used and the
            generated ``answer``.
        """
        passages = get_context(question)
        generated = self.generate_answer(context=passages, question=question)
        return dspy.Prediction(context=passages, answer=generated.answer)

In [14]:
# Single pipeline instance used by the chat callback (3 is the constructor's default).
rag = RAG(num_passages=3)

In [15]:

# Gradio chat callback
def chatbot_interface(user_input, history):
    """Answer one chat message with the RAG pipeline.

    ``history`` is accepted as part of the callback signature but is not used:
    each message is answered on its own.

    Returns:
        The ``answer`` attribute of the pipeline's result for ``user_input``.
    """
    return rag(user_input).answer


# Chat UI around chatbot_interface. The description names Vision 2030 to match
# the indexed document (./data/saudi_vision2030_ar.pdf).
iface = gr.ChatInterface(
    fn=chatbot_interface,
    title="DSPY Chatbot",
    description="Ask me anything about Saudi Arabia Vision 2030",
)

iface.launch()

Running on local URL:  http://127.0.0.1:7865

To create a public link, set `share=True` in `launch()`.


