In [19]:
import os
import gradio as gr
from dotenv import load_dotenv
from langchain_experimental.text_splitter import SemanticChunker
from langchain_cohere import CohereEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from dspy.retrieve.qdrant_rm import QdrantRM
import dspy
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from qdrant_client import QdrantClient
import numpy as np
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
import json
import logging
import re
from fuzzywuzzy import fuzz

# Read a .env file, if one is found, into the process environment so that the
# os.environ["COHERE_API_KEY"] lookups in later cells can succeed.
load_dotenv()

True

In [2]:
# Initialize Embeddings
# One Cohere embedding client is shared by the semantic chunker, the document
# indexing step, and query-time cosine retrieval. The API key comes from the
# COHERE_API_KEY environment variable; if it is unset this line raises KeyError.
embeddings = CohereEmbeddings(
    cohere_api_key=os.environ["COHERE_API_KEY"], model="embed-multilingual-light-v3.0"
)

In [3]:
# Initialize Text Splitter
# Second-stage splitter: it is applied to the semantic chunks to cap their size.
# chunk_size / chunk_overlap are presumably measured in characters (the
# splitter's default length function) - confirm if a custom one is ever passed.
# add_start_index=True asks the splitter to record each chunk's start offset
# in its metadata.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=2048, chunk_overlap=128, add_start_index=True
)

In [4]:
# Initialize Semantic Chunker
# First-stage splitter: uses the shared Cohere embeddings to decide where to
# break the text. "interquartile" selects how the breakpoint threshold is
# computed (see the SemanticChunker docs for the exact rule).
# NOTE: building chunks with this splitter calls the embedding API.
semantic_splitter = SemanticChunker(
    embeddings=embeddings, breakpoint_threshold_type="interquartile"
)

In [5]:
# Load documents
# Path of the Markdown export to index. The default is the original author's
# local path; set VISION2030_MD_PATH (in the environment or in .env) to run the
# notebook on another machine without editing this cell.
source_path = os.environ.get(
    "VISION2030_MD_PATH",
    "/Users/hassn-/Desktop/dspy-chatbot-with-citation/data/saudi_vision2030_ar.pdf.md",
)

loaded_documents = UnstructuredMarkdownLoader(source_path).load()

# Join everything the loader returned into one string so the semantic chunker
# sees the document as a whole.
document_text = "\n".join([doc.page_content for doc in loaded_documents])

# Two-stage chunking: split semantically first, then cap each piece with the
# character-based splitter.
documents = text_splitter.split_documents(
    semantic_splitter.create_documents([document_text])
)

# Plain-text view of the chunks; used for BM25, cosine retrieval and citation checks.
chunks = [doc.page_content for doc in documents]

# 1-based integer ids, one per chunk, used as Qdrant point ids.
doc_id = list(range(1, len(documents) + 1))

# One embedding per chunk, in the same order as `chunks` (calls the Cohere API).
vectors = embeddings.embed_documents(chunks)

In [6]:
# Initialize Qdrant client
# ":memory:" runs Qdrant in-process without persistence, so the collection has
# to be rebuilt whenever the kernel restarts.
client = QdrantClient(":memory:")

In [7]:
# Create Qdrant collection
# Drop any previous "data" collection first so re-running this cell starts clean.
client.delete_collection(collection_name="data")
client.create_collection(
    collection_name="data",
    # `size` must equal the dimension of the vectors produced by `embeddings`.
    vectors_config=VectorParams(size=384, distance=Distance.COSINE),
)

# Upload data to Qdrant
# Store each chunk's text next to its vector. Without a payload the collection
# holds bare vectors only, so a retriever reading from it has no passage text to
# return. The "document" key is the payload field QdrantRM reads by default.
client.upload_collection(
    collection_name="data",
    ids=doc_id,
    vectors=vectors,
    payload=[{"document": chunk} for chunk in chunks],
)

In [8]:
# Initialize Retriever
# dspy retriever over the "data" collection, returning 3 passages per query.
# NOTE(review): RAG.forward below selects chunks itself (BM25 + cosine) and
# never calls this retriever, so it is currently only registered, not used.
retriever_model = QdrantRM(qdrant_collection_name="data", qdrant_client=client, k=3)

In [9]:
# Initialize LLM
# Cohere "command-r-plus" generates the answers. It reuses the same
# COHERE_API_KEY environment variable as the embedding client (KeyError if unset).
lm = dspy.Cohere(model="command-r-plus", api_key=os.environ["COHERE_API_KEY"])

In [10]:
# Configure dspy module
# Register the language model and the retriever as dspy's global defaults, so
# modules created later (dspy.ChainOfThought, dspy.Retrieve) use them without
# being passed explicitly.
dspy.settings.configure(lm=lm, rm=retriever_model)

In [11]:
# Define helper functions
def select_relevant_chunks_bm25(query, chunks, top_n):
    """Return the `top_n` chunks that score highest against `query` under BM25.

    Args:
        query: The user question as a string.
        chunks: List of chunk strings to rank.
        top_n: Maximum number of chunks to return. ``None`` returns every chunk,
            ranked; zero or a negative value returns an empty list.

    Returns:
        A list of chunk strings, best match first. Empty when `chunks` is empty.
    """
    # BM25Okapi averages document length over the corpus size, so building it on
    # an empty corpus fails with a division by zero; answer "nothing" instead.
    if not chunks or (top_n is not None and top_n <= 0):
        return []

    tokenized_query = word_tokenize(query.lower())
    tokenized_chunks = [word_tokenize(chunk.lower()) for chunk in chunks]
    bm25 = BM25Okapi(tokenized_chunks)
    scores = bm25.get_scores(tokenized_query)
    # argsort is ascending; reverse it to get best-first, then keep the top_n.
    top_n_indices = np.argsort(scores)[::-1][:top_n]
    return [chunks[i] for i in top_n_indices]


def select_relevant_chunks_cosine(query, chunks, vectors, top_n):
    """Return the `top_n` chunks whose embeddings are closest to the query's.

    Args:
        query: The user question as a string.
        chunks: List of chunk strings.
        vectors: Embeddings for `chunks`, in the same order (one row per chunk).
        top_n: Maximum number of chunks to return. ``None`` returns every chunk,
            ranked; zero or a negative value returns an empty list.

    Returns:
        A list of chunk strings, most similar first. Empty when there is nothing
        to rank.
    """
    # Bail out before embedding the query: with nothing to compare against,
    # cosine_similarity would raise and the embedding API call would be wasted.
    if not chunks or len(vectors) == 0 or (top_n is not None and top_n <= 0):
        return []

    query_embedding = embeddings.embed_query(query)
    # cosine_similarity expects 2-D inputs; wrap the query and take the single row back.
    cosine_similarities = cosine_similarity([query_embedding], vectors)[0]
    top_n_indices = np.argsort(cosine_similarities)[::-1][:top_n]
    return [chunks[i] for i in top_n_indices]


def merge_top_chunks(bm25_chunks, cosine_chunks):
    """Combine both result lists into one, dropping repeats.

    BM25 results come first, followed by cosine results; each chunk keeps the
    position of its first occurrence.
    """
    merged = []
    already_added = set()
    for chunk in bm25_chunks + cosine_chunks:
        if chunk in already_added:
            continue
        already_added.add(chunk)
        merged.append(chunk)
    return merged


def normalize_text(text):
    """Lower-case `text`, strip punctuation, and collapse whitespace runs.

    Every character that is neither a word character nor whitespace is removed;
    the remaining words are joined with single spaces.
    """
    lowered = text.lower()
    without_punctuation = re.sub(r"[^\w\s]", "", lowered)
    words = without_punctuation.split()
    return " ".join(words)


def verify_citations(response_dict, top_chunks):
    """Check that every cited snippet appears in the retrieved chunks.

    Snippets and chunks are compared after `normalize_text`. A snippet matches
    when it is a substring of a chunk, or when its fuzzy partial ratio against
    a chunk is above 90.

    Args:
        response_dict: Parsed model response; citations are read from its
            "citations" key as a list of dicts with a "snippet" entry.
        top_chunks: The chunk strings that were given to the model.

    Returns:
        A ``(all_verified, citations)`` tuple. ``(False, [])`` when there are no
        citations; ``(False, citations)`` when any citation has a missing or
        blank snippet, or when any snippet cannot be matched.
    """
    citations = response_dict.get("citations", [])
    if not citations:
        print("Error: No citations found in response")
        return False, []

    # Validate every snippet before doing any matching work. A blank snippet is
    # treated like a missing one: it carries no evidence.
    for citation in citations:
        cited_text = citation.get("snippet", "")
        if not isinstance(cited_text, str) or not cited_text.strip():
            print(f"Error: Missing snippet in citation: {citation}")
            return False, citations

    normalized_chunks = [normalize_text(chunk) for chunk in top_chunks]

    matches = []
    for citation in citations:
        normalized_citation = normalize_text(citation["snippet"])

        # A punctuation-only snippet normalizes to "", and "" is a substring of
        # every chunk; without this guard it would count as verified.
        if not normalized_citation:
            matches.append(False)
            continue

        match_found = any(
            normalized_citation in chunk
            or fuzz.partial_ratio(normalized_citation, chunk) > 90
            for chunk in normalized_chunks
        )
        matches.append(match_found)
    return all(matches), citations

In [12]:
# Define dspy signature
# NOTE(review): the class docstring and the `desc` strings below are runtime
# text, not documentation - dspy is understood to send a Signature's docstring
# to the LM as the task instructions (confirm against the dspy docs), so any
# edit to them changes the prompt.
class Generate_answer(dspy.Signature):
    """Answer the questions with citations"""

    # Receives the full system prompt (citation rules plus retrieved chunks)
    # that RAG.forward builds.
    context = dspy.InputField(desc="system prompt with context and instructions")
    # The user's question, passed through unchanged by RAG.forward.
    question = dspy.InputField(desc="question to answer")
    # Raw LM output; RAG.forward reads it as text and extracts the response
    # and the cited snippets from it.
    answer = dspy.OutputField(desc="JSON formatted answer with citations")

In [13]:
# Define dspy module
class RAG(dspy.Module):
    """Question answering over the indexed chunks with verbatim-citation checks.

    For each question the module picks chunks with BM25 and with embedding
    cosine similarity, merges them, asks the LM for a JSON answer with
    citations, and then checks that every cited snippet occurs verbatim in one
    of the chunks it supplied.
    """

    def __init__(self, num_passages=3, chunks_per_retriever=2):
        """Build the sub-modules.

        Args:
            num_passages: `k` for the dspy retriever stored on `self.retrieve`.
            chunks_per_retriever: How many chunks to take from BM25 and from
                cosine retrieval (each) when answering a question.
        """
        super().__init__()
        # Kept so the attribute stays available; forward() does its own hybrid
        # selection and does not call this retriever.
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.chunks_per_retriever = chunks_per_retriever
        self.generate_answer = dspy.ChainOfThought(Generate_answer)

    def forward(self, question):
        """Answer `question` and return a dspy.Prediction.

        The prediction always carries `context`, `answer` and `citations`. It
        additionally carries `warning` when some citation could not be found
        in the supplied chunks, or `error` when no answer could be extracted
        from the LM output.
        """
        # Reads the notebook-level `chunks` and `vectors` built in the loading cell.
        bm25_chunks = select_relevant_chunks_bm25(
            question, chunks, self.chunks_per_retriever
        )
        cosine_chunks = select_relevant_chunks_cosine(
            question, chunks, vectors, self.chunks_per_retriever
        )
        top_chunks = merge_top_chunks(bm25_chunks, cosine_chunks)

        context = "\n".join(top_chunks)
        system_prompt = f"""You are a research assistant. Use the provided document snippets to
        answer the query. Format your response with citations in structured JSON format:
        <response format>
        {{
        "response":"Your response here.",
        "citations":[
            {{
                "title":"Document Title",
                "snippet":"Exact snippet from the document"
            }}]
        }}
        </response format>

        IMPORTANT CITATION RULES:
        1. Each citation MUST be a complete sentence or phrase from the original text. 
        2. Citations MUST be VERBATIM and EXACT quotes from the provided documents. 
        3. DO NOT use ellipses (...) or any other shortening techniques in citations.
        4. DO NOT paraphrase or modify the original text in any way for citations.
        5. If you need to use multiple sentences in citations, include them in full. 
        6. USE MULTIPLE citations when necessary to fully support your response. 
        7. Ensure that each citation DIRECTLY supports a specific part of your response. 
        8. If you cannot find relevant information in the provided documents, state this clearly in your response. 

        Here are the relevant documents for your query:
        {context}

        Remember:
        1. Citations must be EXACT, COMPLETE sentences or phrases from the provided text.
        2. Do not modify, shorten, or paraphrase the original text in your citations.
        3. Use multiple citations when necessary to fully support your response.
        4. Ensure each citation is directly relevant to the part of your response that it supports.
        5. If you cannot find relevant information in the provided documents, clearly state this in your response.

        Now, please answer the given query using the provided information and following these guidelines.
        """

        prediction = self.generate_answer(context=system_prompt, question=question)

        # Log the raw response for debugging
        logging.info("Raw LLM response: %s", prediction.answer)

        # Pull the answer text and cited snippets out of the (possibly
        # incomplete) LM output.
        extracted_answer, extracted_citations = self.extract_answer_and_citations(
            prediction.answer
        )

        if not extracted_answer:
            logging.error("Failed to extract answer from LLM response.")
            return dspy.Prediction(
                context=context,
                answer="Failed to generate a valid response.",
                citations=[],
                error="Invalid response from LLM.",
            )

        citation_check, verified_citations = self.verify_citations(
            extracted_citations, top_chunks
        )
        if citation_check:
            return dspy.Prediction(
                context=context,
                answer=extracted_answer,
                citations=verified_citations,
            )

        # Some snippet was not found verbatim: still answer, but flag it and
        # return only the citations that did verify.
        logging.warning("Citations could not be verified.")
        return dspy.Prediction(
            context=context,
            answer=extracted_answer,
            citations=verified_citations,
            warning="Citations could not be verified.",
        )

    @staticmethod
    def _parse_json_object(text):
        """Parse the outermost ``{...}`` span of `text` as a dict, or return None."""
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            return None
        try:
            # strict=False tolerates raw newlines inside string values, which
            # LMs commonly emit.
            payload = json.loads(text[start : end + 1], strict=False)
        except ValueError:
            return None
        return payload if isinstance(payload, dict) else None

    @staticmethod
    def _decode_json_string(raw):
        """Decode JSON escapes in a regex-captured string body; return it as-is if malformed."""
        try:
            return json.loads(f'"{raw}"', strict=False)
        except ValueError:
            return raw

    def extract_answer_and_citations(self, text):
        """Extract the answer text and the cited snippets from raw LM output.

        Three strategies are tried, from strictest to most lenient:
        1. parse the JSON object embedded in `text`;
        2. regex for a properly closed ``"response": "..."`` string;
        3. for output that was cut off mid-string, take everything after
           ``"response": "``.

        Args:
            text: The raw LM output.

        Returns:
            ``(answer, citations)``. `answer` is a string, or None when nothing
            usable was found. `citations` is a list of
            ``{"title": ..., "snippet": ...}`` dicts (possibly empty).
        """
        if not isinstance(text, str) or not text.strip():
            return None, []

        payload = self._parse_json_object(text)
        if payload is not None:
            answer = payload.get("response")
            if isinstance(answer, str) and answer.strip():
                citations = []
                raw_citations = payload.get("citations")
                if isinstance(raw_citations, list):
                    for item in raw_citations:
                        if not isinstance(item, dict):
                            continue
                        snippet = item.get("snippet")
                        if isinstance(snippet, str) and snippet:
                            title = item.get("title")
                            if not isinstance(title, str) or not title:
                                title = "Document"
                            citations.append({"title": title, "snippet": snippet})
                return answer.strip(), citations

        # Regex fallback. The string pattern stops at the first unescaped quote,
        # so the answer no longer swallows the citations that follow it.
        closed = re.search(r'"response"\s*:\s*"((?:[^"\\]|\\.)*)"', text, re.DOTALL)
        if closed:
            answer = self._decode_json_string(closed.group(1)).strip()
        else:
            # No closing quote: the response was truncated. Keep what is there.
            truncated = re.search(r'"response"\s*:\s*"(.+)', text, re.DOTALL)
            answer = truncated.group(1).strip().rstrip('",') if truncated else None

        # Citations may be missing entirely if the output was cut off early.
        # Escapes are decoded so snippets can be compared verbatim with chunks.
        citations = [
            {"title": "Document", "snippet": self._decode_json_string(raw)}
            for raw in re.findall(
                r'"snippet"\s*:\s*"((?:[^"\\]|\\.)+)"', text, re.DOTALL
            )
        ]

        return answer or None, citations

    def verify_citations(self, citations, chunks):
        """Keep the citations whose snippet occurs verbatim in some chunk.

        Args:
            citations: List of dicts with a "snippet" entry.
            chunks: The chunk strings that were given to the LM.

        Returns:
            ``(all_verified, verified_citations)`` where `all_verified` is True
            only when every citation was found (trivially True for an empty
            list).
        """
        verified_citations = []
        for citation in citations:
            snippet = citation.get("snippet", "")
            # An empty snippet is a substring of everything; never count it.
            if snippet and any(snippet in chunk for chunk in chunks):
                verified_citations.append(citation)
            else:
                logging.warning("Citation not found in chunks: %s", snippet)
        return len(verified_citations) == len(citations), verified_citations

In [14]:
# Initialize the RAG 
# Single pipeline instance, built with the class defaults and shared by every
# chat request handled below.
rag = RAG()

In [15]:
## chatbot interface
def chatbot_interface(user_input, history):
    """Gradio chat callback: answer `user_input` and show the supporting citations.

    Args:
        user_input: The message the user just sent.
        history: Previous turns supplied by Gradio; unused, each question is
            answered independently.

    Returns:
        The reply text: the answer, followed by a numbered list of cited
        snippets and a warning line when the pipeline produced them.
    """
    response = rag(user_input)
    reply = response.answer

    # `citations` / `warning` are optional fields on the prediction; getattr
    # with a default keeps this safe when a field was not set.
    citations = getattr(response, "citations", None) or []
    snippets = [
        citation.get("snippet")
        for citation in citations
        if isinstance(citation, dict) and citation.get("snippet")
    ]
    if snippets:
        numbered = "\n".join(
            f'{index}. "{snippet}"' for index, snippet in enumerate(snippets, 1)
        )
        reply += "\n\nCitations:\n" + numbered

    warning = getattr(response, "warning", None)
    if warning:
        reply += f"\n\nWarning: {warning}"

    return reply


# Build the Gradio chat UI around the callback and start the local web server.
iface = gr.ChatInterface(
    fn=chatbot_interface,
    title="DSPY Chatbot",
    # The indexed document is Saudi Vision 2030 (the original text said "2023").
    description="Ask me anything about Saudi Arabia Vision 2030",
)

iface.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


