<a href="https://colab.research.google.com/github/Hung369/LLM-Engineering/blob/main/RAG_Reranking_LLAMA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Llama 3.1 70B  Stark chatbot**
### Author: Nguyễn Mạnh Hùng.
### Date: 16/02/2025

# Installation

Install the ***Ollama*** server and serve **Llama-3.1-70B-Instruct** locally from the *Linux terminal* if you have a *Colab Pro/Pro+* subscription.

If you don't have a Pro/Pro+ subscription, please run the cell below.

In [None]:
# https://pypi.org/project/colab-xterm/ - install the terminal extension for Colab
!pip install colab-xterm
%load_ext colabxterm

In [None]:
%xterm
 # Run these 3 commands in the terminal:
 # curl https://ollama.ai/install.sh | sh
 # ollama serve & ollama pull llama3.1:70b
 # ollama pull llama3.1:8b

Install langchain packages

In [1]:
!pip install langchain-community
!pip install -U langchain-ollama
!pip install langchain-chroma
!pip install rich
!pip install transformers -U

Collecting langchain-community
  Downloading langchain_community-0.3.17-py3-none-any.whl.metadata (2.4 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.7.1-py3-none-any.whl.metadata (3.5 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting python-dotenv>=0.21.0 (from pydantic-settings<3.0.0,>=2.4.0->langchain-community)
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB



In [23]:
!pip install rank-bm25


Collecting rank-bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank-bm25
Successfully installed rank-bm25-0.2.2


# Inference
## After installation, let's start with the inference code

Initialize the system prompt

In [2]:
# System prompt template for the final answer step. `{context}` is deliberately
# left as a placeholder here: ChatBot parses this string as a prompt template
# and supplies "context" when it invokes the QA chain.
prompt = """You are Stark, an AI assistant, responding only in English with concise replies.
Reject unrelated languages politely.

{context}"""

Write an inference class for the chatbot, with inference and chat memory

In [3]:
import os
import glob
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_ollama import OllamaEmbeddings
from langchain_chroma import Chroma

from langchain.retrievers.contextual_compression import ContextualCompressionRetriever

from rich.console import Console
from rich.markdown import Markdown

In [4]:
class Document:
    """Minimal text container exposing ``page_content`` and ``metadata``.

    Args:
        page_content: The text held by this document.
        metadata: Optional mapping describing the document. Any falsy value
            (``None`` or an empty mapping) is replaced by a new empty dict.
    """

    def __init__(self, page_content, metadata=None):
        # A fresh dict per instance avoids sharing one default between documents.
        self.metadata = metadata if metadata else {}
        self.page_content = page_content

In [5]:
def load_documents(directory):
    """Read every ``*.txt`` file directly inside ``directory`` into a ``Document``.

    Files are decoded as UTF-8. Each document's ``metadata["doc_type"]`` is the
    file's base name with its extension removed.

    Args:
        directory: Folder to scan; sub-folders are not searched.

    Returns:
        A list of ``Document`` objects, one per matching file, in the order
        ``glob.glob`` yields the paths. Empty when nothing matches.
    """
    def _read_as_document(txt_path):
        # Read the whole file first, then derive the doc type from its name.
        with open(txt_path, "r", encoding="utf-8") as handle:
            text = handle.read()
        stem, _extension = os.path.splitext(os.path.basename(txt_path))
        return Document(page_content=text, metadata={"doc_type": stem})

    pattern = os.path.join(directory, "*.txt")
    return [_read_as_document(txt_path) for txt_path in glob.glob(pattern)]

In [6]:
def print_markdown_console(markdown_string):
    """Render ``markdown_string`` as Markdown on a newly created rich ``Console``."""
    Console().print(Markdown(markdown_string))

In [7]:
# Load every .txt file found in the local "data" directory.
documents = load_documents(directory="data")
print(f"Total number of documents: {len(documents)}")

# Split documents into small overlapping chunks (chunk_size=175, chunk_overlap=20)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=175, chunk_overlap=20)
chunks = text_splitter.split_documents(documents)

print(f"Total number of chunks: {len(chunks)}")
# doc_type comes from each file's name (set in load_documents), so this lists the loaded files.
print(f"Document types found: {set(doc.metadata['doc_type'] for doc in documents)}")


Total number of documents: 5
Total number of chunks: 96
Document types found: {'Toraripi', 'Loss_Cut', 'Micro Strategy', 'Config_Concept', 'Semi_Auto'}


In [8]:
# Embed the chunks with the llama3.1:8b model through Ollama and index them in a
# Chroma vector store configured with ./vector_db as its persist directory.
# NOTE(review): re-running this cell against an existing ./vector_db may add the
# same chunks again - confirm before relying on the document count.
embeddings = OllamaEmbeddings(model="llama3.1:8b")
vector_store = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory="./vector_db")

In [9]:
# Sanity check on the index size. `_collection` is a private attribute of the
# Chroma wrapper, so this line may break with other library versions.
print(f"Vectorstore created with {vector_store._collection.count()} documents")
# Wrap the store as a retriever; no search options are passed, so defaults apply.
retriever = vector_store.as_retriever()

Vectorstore created with 96 documents


In [33]:
from typing import List
from pydantic import BaseModel, Field
from langchain.schema import BaseChatMessageHistory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain_ollama import ChatOllama
from typing import List
from langchain.docstore.document import Document
from rank_bm25 import BM25Okapi

# -----------------------------------------------------------------------------------
# 1. Define in-memory chat history storage
# -----------------------------------------------------------------------------------
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """Stores the conversation messages in memory.

    Messages are kept as plain strings in ``messages``; nothing is written
    outside this object.
    """
    # Declared with default_factory so each instance starts with its own list.
    messages: List[str] = Field(default_factory=list)

    def add_messages(self, messages: List[str]) -> None:
        """Append ``messages`` to the stored history, keeping their order."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Drop all stored messages by rebinding ``messages`` to an empty list."""
        self.messages = []

# -----------------------------------------------------------------------------------
# 2. Define the ChatBot class that orchestrates the RAG re-ranking flow
# -----------------------------------------------------------------------------------
class ChatBot:
    """
    A chatbot that:
      - modifies the user query for better retrieval,
      - retrieves candidate documents,
      - re-ranks them with BM25,
      - and finally answers using the top-ranked docs.

    Every exchange is also recorded in a per-session, in-memory history
    (see ``get_session_history`` / ``show_chat_history``). The stored history
    is not fed back into any prompt.
    """
    def __init__(
        self,
        model_name: str = "llama3.1:8b",
        system_prompt: str = "You are a helpful assistant that answers questions based on the provided context.",
        retriever = None,  # We'll assign the actual retriever from outside
        **kwargs
    ):
        """
        Build the chat model and the prompt chains.

        Args:
            model_name: Model identifier passed to ``ChatOllama``.
            system_prompt: System message for the final answer step. It is
                parsed as a prompt template, so it may reference the
                ``{context}`` and ``{user_question}`` variables supplied by
                ``chat``.
            retriever: Object exposing ``get_relevant_documents(str)``. When
                falsy, ``chat`` answers with an empty context.
            **kwargs: Forwarded unchanged to ``ChatOllama`` (e.g. ``base_url``).
        """
        # Maintain session-based histories in a dictionary
        self.store = {}

        # Create an Ollama-based chat model
        self.chat_model = ChatOllama(model=model_name, **kwargs)

        # Store the retriever
        self.retriever = retriever

        # ------------------ Step 1: "modify" chain (refine user query) ------------------
        self.modify_prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(
                "You are a system that refines user queries for better search results."
            ),
            HumanMessagePromptTemplate.from_template("{user_question}")
        ])
        self.modify_chain = LLMChain(llm=self.chat_model, prompt=self.modify_prompt)

        # ------------------ Step 3: "ranking" chain (re-rank retrieved docs) -----------
        # NOTE: this LLM-based ranking chain is constructed but not called by
        # chat(), which re-ranks with BM25 instead. It is kept as a public
        # attribute so existing code that references it keeps working.
        self.ranking_prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(
                "You are a system that re-ranks documents by relevance to the user query."
            ),
            HumanMessagePromptTemplate.from_template(
                "User query:\n{user_question}\n\nDocuments:\n{documents}\n\n"
                "Please return a sorted list of document indices in order of relevance."
            )
        ])
        self.ranking_chain = LLMChain(llm=self.chat_model, prompt=self.ranking_prompt)

        # ------------------ Step 4: "chat" chain (final answer) -------------------------
        self.qa_prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_prompt),
            HumanMessagePromptTemplate.from_template(
                "Question: {user_question}\n\n"
                "Relevant context:\n{context}\n\n"
                "Answer concisely and helpfully."
            ),
        ])
        self.qa_chain = LLMChain(llm=self.chat_model, prompt=self.qa_prompt)

    def get_session_history(self, session_id: str) -> "InMemoryHistory":
        """Returns the session history for a given session_id, creating it on first use."""
        if session_id not in self.store:
            self.store[session_id] = InMemoryHistory()
        return self.store[session_id]

    # -----------------------------------------------------------------------------------
    # 3. The main chat method that executes the entire RAG flow
    # -----------------------------------------------------------------------------------
    def chat(self, user_message: str, session_id: str, top_k: int = 3) -> dict:
        """
        Executes the 4-step pipeline:
          1) modify user query
          2) retrieve docs
          3) re-rank docs
          4) produce final answer using top-ranked docs

        Args:
            user_message: The question exactly as typed by the user.
            session_id: Key of the session whose history receives this exchange.
            top_k: Maximum number of re-ranked documents placed in the context
                (defaults to 3; negative values are treated as 0).

        Returns:
            The mapping returned by the QA chain. The answer itself is under
            the ``'text'`` key, which is what callers are expected to read.
        """
        # ------------------ Step 1: Modify user query -----------------------------------
        modified_question = self.modify_chain.invoke({"user_question": user_message})
        # The chain result is a mapping; the generated text sits under 'text'.
        modified_question_str = modified_question['text']

        # ------------------ Step 2: Retrieve documents ----------------------------------
        # Without a retriever the pipeline still answers, just with no context.
        documents = (
            self.retriever.get_relevant_documents(modified_question_str)
            if self.retriever else []
        )

        # ------------------ Step 3: Re-rank documents -----------------------------------
        # _bm25_rerank returns [] for an empty candidate list, so no guard is needed here.
        ranked_indices = self._bm25_rerank(modified_question_str, documents)
        limit = max(top_k, 0)
        top_docs = [documents[i] for i in ranked_indices[:limit]]
        context = "\n\n".join(doc.page_content for doc in top_docs)

        # ------------------ Step 4: Generate final answer (chat) ------------------------
        # The original (unmodified) user message is what the model is asked to answer.
        final_answer = self.qa_chain.invoke({
            "user_question": user_message,
            "context": context
        })

        # ------------------ Store the exchange in session history -----------------------
        # Record only the answer text, not the whole result mapping (which also
        # carries the inputs and would bloat the history).
        session_history = self.get_session_history(session_id)
        session_history.add_messages([
            f"User: {user_message}",
            f"Assistant: {final_answer['text']}"
        ])

        return final_answer

    def _bm25_rerank(self, query, documents: List["Document"]) -> List[int]:
        """
        Re-ranks the candidate documents using BM25 scoring.

        Args:
            query: The search query, either as a plain string or as a chain
                result mapping holding the string under ``'text'``.
            documents: Candidates exposing ``page_content``.

        Returns:
            Document indices sorted in descending order of BM25 score; ties
            keep their original relative order. Empty when ``documents`` is empty.
        """
        # BM25 cannot be fitted on an empty corpus, so short-circuit.
        if not documents:
            return []

        # Whitespace tokenisation for both corpus and query.
        corpus = [doc.page_content.split() for doc in documents]
        bm25 = BM25Okapi(corpus)

        # Accept both the raw string and the {'text': ...} mapping form.
        query_text = query['text'] if isinstance(query, dict) else query
        scores = bm25.get_scores(query_text.split())

        # Sort indices based on scores in descending order
        return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

    def show_chat_history(self, session_id: str):
        """
        Prints out the conversation history for debugging or display purposes.

        Prints one stored message per line, or a notice when the session id
        has no history yet.
        """
        if session_id in self.store:
            for msg in self.store[session_id].messages:
                print(msg)
        else:
            print("No chat history found for this session.")


In [37]:
# Example usage
# Answer with the 70B model, using the Stark system prompt and the retriever
# built from the vector store. base_url is not a ChatBot parameter; it travels
# through **kwargs to ChatOllama.
bot = ChatBot(
    model_name="llama3.1:70b",
    system_prompt=prompt,
    retriever=retriever,
    base_url="http://localhost:11434"  # Hosted URL
)

Use the `chat` method with a message to start the conversation

In [38]:
# chat() returns the QA chain's result mapping; its 'text' entry is the answer,
# rendered here as Markdown.
response = bot.chat("hello, who are you?", session_id="unique_session_id_123")
print_markdown_console(response['text'])

In [42]:
# Domain question; reuses the same session id, so the exchange is appended to
# that session's history.
response = bot.chat("What is Loss Cut?", session_id="unique_session_id_123")
print_markdown_console(response['text'])

In [40]:
# Longer, open-ended question in the same session.
response = bot.chat("In what scenarios is the use of Micro Strategy considered appropriate or beneficial?", session_id="unique_session_id_123")
print_markdown_console(response['text'])

In [44]:
# Another question in the same session; only response['text'] is displayed.
response = bot.chat("What is Stop Hunting?", session_id="unique_session_id_123")
print_markdown_console(response['text'])