# Chatbot

<img src="https://raw.githubusercontent.com/Vishesh8/databricks-tests/refs/heads/main/training-images/rag.jpeg" width="1368">

In [0]:
%pip install -qU openai databricks-langchain langchain-chroma pypdf docarray gradio
# NOTE(review): -U upgrades every listed package to its latest release with no version pins —
# consider pinning versions so the notebook stays reproducible.
# Restart the Python process so the packages installed above are the ones that get imported.
dbutils.library.restartPython()

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
botocore 1.34.39 requires urllib3<2.1,>=1.25.4; python_version >= "3.10", but you have urllib3 2.3.0 which is incompatible.
google-api-core 2.18.0 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0.dev0,>=3.19.5, but you have protobuf 5.29.4 which is incompatible.[0m[31m
[0m[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
from openai import OpenAI
import os

# The access token is read from a Databricks secret scope instead of being hardcoded.
DATABRICKS_TOKEN = dbutils.secrets.get(scope="db-field-eng", key="va-pat-token")

# OpenAI client configured with the serving-endpoints base URL and the token as API key.
client = OpenAI(
    base_url="https://e2-demo-field-eng.cloud.databricks.com/serving-endpoints",
    api_key=DATABRICKS_TOKEN,
)

In [0]:
from databricks_langchain import ChatDatabricks
from databricks_langchain import DatabricksEmbeddings
from langchain_chroma import Chroma

In [0]:
# Temperature 0 keeps the generation model's answers low-variance and factual for Q&A.
llm = ChatDatabricks(
    endpoint="databricks-claude-3-7-sonnet",
    temperature=0,
)
embedding = DatabricksEmbeddings(endpoint="databricks-gte-large-en")

# Open the Chroma store persisted under persist_directory.
persist_directory = './data/docs/chroma/'
vectordb = Chroma(
    embedding_function=embedding,
    persist_directory=persist_directory,
)

# Sanity check: number of entries in the underlying collection.
print(vectordb._collection.count())

208


In [0]:
# Build prompt
from langchain.prompts import PromptTemplate

# Prompt text with two placeholders, {context} and {question}.
# It also instructs the model to admit when it does not know, to stay within
# three sentences, and to end every answer with "thanks for asking!".
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer. 
{context}
Question: {question}
Helpful Answer:"""

# Wrap the raw string in a PromptTemplate; it is passed to the RetrievalQA chain
# through chain_type_kwargs.
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

In [0]:
# Run chain
from langchain.chains import RetrievalQA

question = "Is probability a class topic?"

# Single-turn retrieval QA: the vector store retrieves context and the custom
# prompt shapes the answer; source documents are returned alongside the result.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectordb.as_retriever(),
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    return_source_documents=True,
)

result = qa_chain.invoke(question)
result["result"]

'Probability is not a main class topic but is a prerequisite knowledge that students are expected to have. The instructor assumes familiarity with basic probability and statistics, including concepts like random variables, expectation, and variance, but will offer refreshers in discussion sections for those who need it. Thanks for asking!'

## Memory
`ConversationBufferMemory` keeps the chat messages in a history buffer and passes them to the chatbot along with the question every time. We also specify a `memory_key` that lines up with one of the `input_variables` in the prompt. `return_messages` is set to `True` so that the chat history is returned as a list of messages rather than a single string.

In [0]:
from langchain.memory import ConversationBufferMemory

# History is exposed under "chat_history" and returned as a list of messages.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

## ConversationalRetrieval Chain

We'll introduce a new type of chain called `ConversationalRetrievalChain`, which is similar to the `RetrievalQA` chain but also takes a LangChain `memory` argument, allowing the chat history to be used as context when answering follow-up questions. Note that this chain supports different `retrieval` strategies (`self-query`, `compression`, etc.) as well as different `generation` approaches via `chain_type` (e.g., stuff, refine).

`ConversationalRetrievalChain` not only adds `memory` on top of the `RetrievalQA` chain, it also adds a step that takes the chat history along with the new question and condenses them into a standalone question, which is then used for the `retrieval`.

In [0]:
import mlflow
# Turn on MLflow autologging for LangChain so the chain invocations below are traced.
mlflow.langchain.autolog()

In [0]:
from langchain.chains import ConversationalRetrievalChain

# Conversational chain: MMR retrieval over the vector store, with the buffer
# memory supplying the chat history on every call.
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=vectordb.as_retriever(search_type="mmr"),
)

In [0]:
# First turn of the conversation; the chain was built with memory, so no
# chat history is passed explicitly here.
question = "Is probability a class topic?"

result = qa.invoke(question)
# This chain exposes its reply under the 'answer' key.
result['answer']

'Yes, probability is a prerequisite for the class rather than a main topic. The instructor mentions that they assume familiarity with basic probability and statistics, including concepts like random variables, expectation, and variance. For students who need a refresher on probability, they mention that some discussion sections will go over these prerequisites. The class itself (CS229, which appears to be a machine learning course) builds on this probabilistic foundation rather than teaching probability as a primary topic.'

Trace(request_id=tr-b0f246a0b9d64ad792f27ad90d94245e)

In [0]:
# Follow-up turn: "those" only makes sense in light of the previous exchange,
# so this exercises the chain's use of chat history.
question = "why are those prerequesites needed?"

result = qa.invoke(question)
result['answer']

"Based on the context provided, probability and statistics are prerequisites for the CS229 machine learning course because:\n\n1. The course builds on concepts like random variables, expectation, and variance\n2. These statistical concepts are fundamental to understanding machine learning algorithms and models\n3. The instructor mentions that an undergraduate statistics class like Stat 116 at Stanford would be sufficient preparation\n4. While the course offers some refresher material in discussion sections for students who haven't used statistics recently, the core lectures assume this background knowledge\n\nThe course appears to use these statistical foundations when developing concepts like the probabilistic interpretation of linear regression, which is mentioned as leading into logistic regression and other classification algorithms."

Trace(request_id=tr-6d5c4e8a149446f1b0e1bdf61b71b4ab)

In the above trace we see that, even before the retriever step, there is a `chat_history` input coming from the memory (`memory_key="chat_history"`). The first LLM call rephrases the follow-up question using the chat history, which carries the essence of the first question asked. This standalone question then goes through the usual RAG process to generate the response.

## Chatbot App

In [0]:
# Switch LangChain autologging back off before building the chatbot app.
mlflow.langchain.autolog(disable=True)

In [0]:
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.chains import RetrievalQA,  ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader

In [0]:
def load_db(
  file,
  chain_type,
  k,
  chunk_size=1000,
  chunk_overlap=150,
  embedding_endpoint="databricks-gte-large-en",
  llm_endpoint="databricks-claude-3-7-sonnet",
):
  """Load a PDF into an in-memory vector store and return a ConversationalRetrievalChain over it.

  The chain is created without memory: callers pass the chat history with each call.
  Results include the source documents and the generated standalone question.

  file: PDF file path to load on which to do Q&A
  chain_type: type of chain to use for generation
  k: number of chunks to retrieve
  chunk_size: maximum characters per chunk when splitting the PDF text
  chunk_overlap: characters shared between consecutive chunks
  embedding_endpoint: serving endpoint name used to embed the chunks
  llm_endpoint: serving endpoint name of the chat model used for generation

  Raises ValueError when no text chunks can be produced from the file.
  """

  # load documents
  loader = PyPDFLoader(file)
  documents = loader.load()

  # split documents
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
  docs = text_splitter.split_documents(documents)

  # An index with no chunks cannot answer anything; fail here with a clear message
  # rather than later inside the vector store or the retriever.
  if not docs:
    raise ValueError(f"No text could be extracted from {file!r}; nothing to index.")

  # define embedding
  embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)

  # create vector database from data
  db = DocArrayInMemorySearch.from_documents(docs, embeddings)

  # define retriever
  retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": k})

  # create a chatbot chain. Memory is managed externally.
  # temperature=0 keeps answers low-variance.
  qa = ConversationalRetrievalChain.from_llm(
    llm=ChatDatabricks(endpoint=llm_endpoint, temperature=0),
    chain_type=chain_type,
    retriever=retriever,
    return_source_documents=True,
    return_generated_question=True,
  )

  return qa

Note that in the above code we do not pass `memory` to the chain. For the convenience of the GUI implementation, the chat history is managed outside the chain and passed in with every question.

In [0]:
import gradio as gr

class ChatSession:
    """State for one chat session: the active QA chain, the chat history, and
    the retrieval details of the most recent query."""

    def __init__(self):
        # Path (or display name) of the PDF currently indexed.
        self.loaded_file = "./data/docs/cs229_lectures/MachineLearning-Lecture01.pdf"
        # (user_query, bot_answer) tuples, oldest first.
        self.chat_history = []
        # Standalone question the chain generated for the last query.
        self.db_query = ""
        # Source documents returned for the last query.
        self.db_response = []
        # Chain built by load_db with the "stuff" chain type and 4 retrieved chunks.
        self.qa = load_db(self.loaded_file, "stuff", 4)

    def load_file(self, file_obj):
        """Index a newly uploaded PDF and reset the conversation.

        file_obj: ``None`` (keep the current document), a filesystem path
            (``str`` / ``os.PathLike``), or a file-like object. Depending on the
            Gradio version, an upload arrives either as a path string or as a
            temp-file wrapper whose ``.name`` is the path on disk; both are handled.

        Returns a status string naming the loaded file.
        """
        if file_obj is None:
            return f"Loaded File: {self.loaded_file}"

        if isinstance(file_obj, (str, os.PathLike)):
            pdf_path = os.fspath(file_obj)
            display_name = pdf_path
        else:
            name = getattr(file_obj, "name", None)
            if isinstance(name, str) and os.path.isfile(name):
                # Already on disk (e.g. a named temporary file): load it in place.
                pdf_path = name
            else:
                # Pure in-memory stream: spill it to disk so the PDF loader can read it.
                pdf_path = "temp.pdf"
                with open(pdf_path, "wb") as f:
                    f.write(file_obj.read())
            display_name = name if isinstance(name, str) else pdf_path

        # Build the new chain first so a failed load leaves the session untouched.
        self.qa = load_db(pdf_path, "stuff", 4)
        self.loaded_file = display_name
        self.clear_history()
        # Retrieval details from the previous document no longer apply.
        self.db_query = ""
        self.db_response = []
        return f"Loaded File: {self.loaded_file}"

    def process_query(self, query):
        """Send ``query`` to the chain together with the history so far.

        Empty or whitespace-only queries are ignored. Returns the full chat
        history as a list of (user, bot) tuples.
        """
        if not query or not query.strip():
            return self.chat_history
        result = self.qa.invoke({"question": query, "chat_history": self.chat_history})
        # Append new exchange as a tuple (user, bot)
        self.chat_history.append((query, result["answer"]))
        self.db_query = result["generated_question"]
        self.db_response = result["source_documents"]
        return self.chat_history

    def get_db_query(self):
        """Return the last generated standalone question as Markdown."""
        if not self.db_query:
            return "No DB accesses so far."
        return f"**DB Query:** {self.db_query}"

    def get_db_sources(self):
        """Return the last retrieved source documents as a Markdown bullet list."""
        if not self.db_response:
            return "No DB response yet."
        bullets = "".join(f"- {doc}\n" for doc in self.db_response)
        return "**Result of DB lookup:**\n\n" + bullets

    def get_chat_history_text(self):
        """Return the whole conversation rendered as Markdown."""
        if not self.chat_history:
            return "No chat history yet."
        return "".join(
            f"**User:** {query}\n\n**ChatBot:** {answer}\n\n"
            for query, answer in self.chat_history
        )

    def clear_history(self):
        """Forget all previous exchanges (the loaded document is kept)."""
        self.chat_history = []

# Global session instance: a single ChatSession shared by every callback below.
# Constructing it calls load_db on the default PDF.
session = ChatSession()

# Wrapper functions for Gradio callbacks
def gr_load_file(file_obj):
    """Gradio callback: forward the uploaded file to the global session and return its status string."""
    return session.load_file(file_obj)

def gr_process_query(query):
    """Gradio callback: run ``query`` through the global session.

    Returns the updated conversation as a list of (user, bot) pairs, the format
    handed to the gr.Chatbot output.
    """
    return session.process_query(query)

def gr_get_db_query():
    """Gradio callback: return the global session's last DB query as Markdown text."""
    return session.get_db_query()

def gr_get_db_sources():
    """Gradio callback: return the global session's last retrieved sources as Markdown text."""
    return session.get_db_sources()

def gr_get_chat_history():
    """Gradio callback: return the global session's chat history rendered as Markdown text."""
    return session.get_chat_history_text()

def gr_clear_history():
    """Gradio callback: clear the global session's chat history and return a confirmation message."""
    session.clear_history()
    return "Chat history cleared."

In [0]:
# Custom CSS for enhanced styling, handed to gr.Blocks via its css argument.
# The "#chatbot" rule targets the component created with elem_id="chatbot".
custom_css = """
body {
    background-color: #f7f7f7;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
h1 {
    color: #333;
    text-align: center;
}
.gradio-container {
    border-radius: 12px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}
.gr-button {
    background-color: #4CAF50 !important;
    color: white !important;
    border-radius: 8px;
}
#chatbot {
    background-color: #ffffff;
    border-radius: 8px;
    padding: 10px;
}
"""

In [0]:
# Assemble the four-tab UI. All event handlers are the module-level gr_* callbacks,
# which operate on the single global `session`.
with gr.Blocks(css=custom_css, title="ChatWithYourData Bot") as demo:
    gr.Markdown("<h1>ChatWithYourData Bot</h1>")
    with gr.Tabs():
        # --- Conversation Tab ---
        with gr.Tab("Conversation"):
            with gr.Row():
                chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
            with gr.Row():
                query_input = gr.Textbox(placeholder="Enter your question here...", label="Your Query")
                submit_btn = gr.Button("Submit", variant="primary")
            # Clicking the button and submitting the textbox both run the same handler
            # and write its result to the chatbot component.
            submit_btn.click(fn=gr_process_query, inputs=query_input, outputs=chatbot)
            query_input.submit(fn=gr_process_query, inputs=query_input, outputs=chatbot)
        
        # --- Database Tab ---
        with gr.Tab("Database"):
            with gr.Column():
                db_query_box = gr.Markdown(label="DB Query")
                db_sources_box = gr.Markdown(label="DB Sources")
                refresh_db_btn = gr.Button("Refresh Database Info", variant="primary")
            # The lambda returns a 2-tuple, one value per output component, in order.
            refresh_db_btn.click(fn=lambda: (gr_get_db_query(), gr_get_db_sources()),
                                 inputs=[], outputs=[db_query_box, db_sources_box])
        
        # --- Chat History Tab ---
        with gr.Tab("Chat History"):
            with gr.Column():
                chat_history_box = gr.Markdown(label="Chat History")
                refresh_history_btn = gr.Button("Refresh Chat History", variant="primary")
            refresh_history_btn.click(fn=gr_get_chat_history, inputs=[], outputs=chat_history_box)
        
        # --- Configure Tab ---
        with gr.Tab("Configure"):
            with gr.Row():
                file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                load_btn = gr.Button("Load DB", variant="primary")
            load_status = gr.Markdown(label="Load Status")
            # Loading a file rebuilds the session's chain; the status string is shown below the row.
            load_btn.click(fn=gr_load_file, inputs=file_input, outputs=load_status)
            with gr.Row():
                clear_history_btn = gr.Button("Clear History", variant="secondary")
                clear_status = gr.Markdown(label="Clear History Status")
            clear_history_btn.click(fn=gr_clear_history, inputs=[], outputs=clear_status)
            gr.Markdown("Clears chat history. Use to start a new topic.")
            gr.Image("./data/img/convchain.jpg", label="Conversation Chain", show_label=True, elem_id="convchain_img", height=300)

# NOTE(review): share=True presumably publishes a share link reachable from outside the
# workspace — confirm that is intended, since every visitor would use the same global session.
demo.launch(share=True, debug=True)