<a href="https://colab.research.google.com/github/Shariar076/notebook-snapshots/blob/main/RAG_chat_inteface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
! pip install gradio
! pip install faiss-gpu
! pip install git+https://github.com/csebuetnlp/normalizer
! pip install -U bitsandbytes

In [None]:
# Kill the kernel process so the Colab runtime restarts and picks up the packages upgraded
# above (e.g. bitsandbytes). Run the remaining cells after the restart.
import os
import signal

os.kill(os.getpid(), signal.SIGKILL)

In [None]:
# `os` must be imported here: the previous cell restarted the kernel.
import os

import huggingface_hub

# Never hardcode the access token in the notebook. Read it from the HF_TOKEN environment
# variable; when it is unset, login() receives None and prompts for the token interactively.
huggingface_hub.login(token=os.environ.get("HF_TOKEN"))

In [None]:
import re
import torch
from typing import List, Tuple
import time
import pandas as pd
import faiss
from normalizer import normalize
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig


In [None]:
class Retriever:
    """Dense document retriever plus English <-> Bangla machine-translation helpers.

    Documents come from one column of a TSV file, are embedded with a
    SentenceTransformer model and searched with an exact L2 FAISS index
    (`IndexFlatL2`). Two BanglaT5 seq2seq models translate between English and
    Bangla.
    """

    def __init__(self, doc_path, doc_col):
        """Load the documents and the embedding / translation models.

        Call `update_embeddings()` before `vector_index_search()`.

        Args:
            doc_path: Path to a tab-separated file containing the documents.
            doc_col: Name of the column in `doc_path` holding the document text.
        """
        df = pd.read_csv(doc_path, sep='\t')
        docs = df[doc_col]
        # Missing values (NaN length compares False) and empty / whitespace-only strings
        # carry no information and would only waste retrieval slots.
        has_text = docs.str.strip().str.len() > 0
        self.document_df = pd.DataFrame({'doc': docs[has_text]}).reset_index(drop=True)
        self.index = None  # built by update_embeddings()
        self.embedding_model = SentenceTransformer("thenlper/gte-large")
        self.bn_mt_model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/banglat5_nmt_en_bn")
        self.bn_mt_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglat5_nmt_en_bn", use_fast=False)
        self.en_mt_model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/banglat5_nmt_bn_en")
        self.en_mt_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglat5_nmt_bn_en", use_fast=False)

    def update_embeddings(self):
        """(Re)build the FAISS index from the current contents of `document_df['doc']`."""
        docs = self.document_df['doc'].tolist()
        embeddings = self.embedding_model.encode(docs)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)

    def vector_index_search(self, user_query, k):
        """Return the documents nearest to `user_query`, best match first.

        Args:
            user_query: Query text to embed and search for.
            k: Maximum number of documents to return.

        Returns:
            NumPy object array of shape (n, 1) - one row per hit, document text in
            column 0 - with n at most min(k, number of indexed documents).

        Raises:
            RuntimeError: if `update_embeddings()` has not been called yet.
        """
        if getattr(self, "index", None) is None:
            raise RuntimeError("Call update_embeddings() before vector_index_search().")
        # FAISS pads missing results with id -1 when k exceeds the index size, and
        # iloc[-1] would silently return the last document; clamp k and drop -1 ids.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return self.document_df.iloc[[]].values
        query_embedding = self.embedding_model.encode([user_query])
        _, ids = self.index.search(query_embedding, k)
        hits = [i for i in ids[0] if i >= 0]
        return self.document_df.iloc[hits].values

    def get_en2bn_text(self, en_text, max_new_tokens=None):
        """Translate English `en_text` to Bangla (see `_translate` for `max_new_tokens`)."""
        return self._translate(self.bn_mt_model, self.bn_mt_tokenizer, en_text, max_new_tokens)

    def get_bn2en_text(self, bn_text, max_new_tokens=None):
        """Translate Bangla `bn_text` to English (see `_translate` for `max_new_tokens`)."""
        return self._translate(self.en_mt_model, self.en_mt_tokenizer, bn_text, max_new_tokens)

    @staticmethod
    def _translate(model, tokenizer, text, max_new_tokens=None):
        """Normalize `text`, translate it with a seq2seq MT model and return the decoded string.

        Args:
            model: Seq2seq model exposing `generate`.
            tokenizer: Matching tokenizer.
            text: Source text.
            max_new_tokens: Optional cap on generated tokens; when None the model's own
                generation config decides the output length.
        """
        input_ids = tokenizer(normalize(text), return_tensors="pt").input_ids
        # NOTE(review): if the model's generation config sets no length, transformers falls
        # back to a short default max_length, which can truncate long translations - pass
        # max_new_tokens in that case.
        gen_kwargs = {} if max_new_tokens is None else {"max_new_tokens": max_new_tokens}
        generated_tokens = model.generate(input_ids, **gen_kwargs)
        # skip_special_tokens removes <pad>, </s> and any other special token in one step.
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()

In [None]:
class LlmChatBot:
    """Retrieval-augmented Llama 2 chat bot with Bangla <-> English translation.

    Non-English (Bangla) user messages are translated to English, relevant documents
    are retrieved with `Retriever`, a Llama 2 chat prompt is built, and the English
    answer is translated back to Bangla when the user wrote in Bangla.
    """

    # Any character from the Bengali Unicode block (U+0980-U+09FF).
    _BENGALI_CHARS = re.compile(r"[\u0980-\u09FF]")

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        doc_path: str = "/content/vat_qna_dataset.tsv",
        doc_col: str = "answer_en",
        num_context_docs: int = 3,
    ):
        """
        Initialize the chatbot: load the 4-bit quantized LLM and build the retriever index.

        Args:
            model_name (str): Name or path of the Llama 2 model to use
            doc_path (str): TSV file with the documents used as retrieval context
            doc_col (str): Column of `doc_path` holding the document text
            num_context_docs (int): Number of documents retrieved per user message
        """
        print("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        # Sampling parameters; max_length bounds prompt + generated tokens together.
        self.max_length = 2048
        self.temperature = 0.7
        self.top_p = 0.95
        self.num_context_docs = num_context_docs

        # Retriever over the document collection; the index is built once up front.
        self.doc_retriever = Retriever(doc_path, doc_col)
        self.doc_retriever.update_embeddings()

        # The full system prompt is used when there is no history; the short one afterwards.
        self.system_prompt_full = (
            "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
            "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
            "If you don't know the answer to a question, please don't share false information. "
            "Please use the provided context to answer questions. If you need to refer to the context use proper quotation."
        )
        self.system_prompt_partial = (
            "Please continue the conversation using the provided contexts. "
            "If you need to refer to the context use proper quotation."
        )

    def build_prompt(self, messages: List[Tuple[str, str]], context: List[str] = None) -> str:
        """
        Build a Llama 2 chat prompt from conversation history and context documents.

        Args:
            messages (List[Tuple[str, str]]): (role, content) tuples, starting with a user
                turn and ending with the current user message; role is "user" or "assistant"
            context (List[str], optional): Context passages, either plain strings or the
                one-column rows returned by `Retriever.vector_index_search`

        Returns:
            str: Formatted prompt for the model (begins with a literal "<s>")
        """
        system_prompt = self.system_prompt_full if len(messages) == 1 else self.system_prompt_partial
        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

        if context is not None and len(context) > 0:
            prompt += "Here is some relevant context to help answer the question:\n\n"
            for i, doc in enumerate(context, 1):
                # A plain string must not be indexed: doc[0] would be its first character.
                text = doc if isinstance(doc, str) else doc[0]
                prompt += f"Context {i}:\n{text.strip()}\n\n"
            prompt += "Please use this context to provide an informed answer.\n\n"

        for role, content in messages:
            content = self._to_english(content)
            if role == "user":
                prompt += f"{content} [/INST] "
            else:  # assistant
                prompt += f"{content} </s><s>[INST] "

        return prompt

    def is_english(self, text):
        """Return True when `text` contains no Bengali-script characters.

        The bot only translates between Bangla and English, so anything that is not Bangla
        is treated as English. This keeps ordinary English punctuation (':', '%', '*', '/',
        curly quotes, ...) from being routed through the Bangla->English model.
        """
        return self._BENGALI_CHARS.search(text) is None

    def _to_english(self, text: str) -> str:
        """Return `text` unchanged if it is English, else its Bangla->English translation."""
        return text if self.is_english(text) else self.doc_retriever.get_bn2en_text(text)

    @staticmethod
    def _history_to_messages(history, max_turns: int = 2) -> List[Tuple[str, str]]:
        """Flatten the last `max_turns` exchanges of a Gradio chat history into (role, content) tuples.

        Accepts both Gradio history formats: (user, assistant) pairs and
        {"role": ..., "content": ...} dicts.
        """
        if not history:
            return []
        if isinstance(history[0], dict):
            recent = [
                (m["role"], m["content"])
                for m in history[-2 * max_turns:]
                if m.get("role") in ("user", "assistant") and isinstance(m.get("content"), str)
            ]
            # build_prompt expects the conversation to open with a user turn.
            while recent and recent[0][0] != "user":
                recent.pop(0)
            return recent
        messages = []
        for user_msg, assistant_msg in history[-max_turns:]:
            messages.extend([("user", user_msg), ("assistant", assistant_msg)])
        return messages

    def generate_response(
        self,
        message: str,
        history: List[Tuple[str, str]],
    ) -> str:
        """
        Generate a response for the given message and conversation history.

        Documents relevant to the (English) message are retrieved and added to the prompt.

        Args:
            message (str): Current user message (English or Bangla)
            history: Previous conversation as supplied by Gradio, either (user, assistant)
                pairs or {"role": ..., "content": ...} dicts

        Returns:
            str: Model's response, translated to Bangla when `message` is not English
        """
        message_en = self._to_english(message)
        messages = self._history_to_messages(history)
        messages.append(("user", message_en))

        docs = self.doc_retriever.vector_index_search(message_en, k=self.num_context_docs)
        print("=" * 100)
        print(docs)
        print("=" * 100)
        prompt = self.build_prompt(messages, context=docs)
        print("=" * 100)
        print(prompt)
        print("=" * 100)

        # The prompt already starts with a literal "<s>" (BOS); letting the tokenizer add
        # its own BOS as well would feed the model a doubled BOS token.
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=self.max_length,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens rather than splitting the whole text on
        # "[/INST]", which breaks if the model itself emits that marker.
        response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        if not self.is_english(message):
            response = self.doc_retriever.get_en2bn_text(response)
        return response

In [None]:
import gradio as gr

def create_chat_interface(model_name) -> "gr.ChatInterface":
    """
    Create and configure the Gradio chat interface.

    Args:
        model_name (str): Name or path of the Llama 2 model to use

    Returns:
        gr.ChatInterface: Configured (not yet launched) Gradio chat interface
    """
    # Initialize chatbot (loads the LLM and builds the retrieval index).
    chatbot = LlmChatBot(model_name)

    # Highlight user messages. NOTE(review): Gradio's internal CSS class names change between
    # releases - confirm `.message.user` still matches the installed version.
    custom_css = """
    .message.user{
        background-color: #2563eb !important;
        color: white !important;
        border-radius: 15px !important;
        padding: 15px !important;
        margin: 10px 0 !important;
    }

    """

    # CSS is passed as a constructor argument, alongside the theme, instead of being
    # patched onto the object after construction.
    return gr.ChatInterface(
        fn=chatbot.generate_response,
        title="Llama 2 Chat",
        description="Chat with Llama 2 - A powerful language model by Meta",
        theme=gr.themes.Soft(),
        css=custom_css,
        examples=[
            ["কে ভ্যাট দেয়?"],
            ["বাংলাদেশে মূসক কবে থেকে চালু হয়েছে?"],
            ["আমি কেন ভ্যাট দেব?"]
        ],
    )


def main(model_name: str = "meta-llama/Llama-2-7b-chat-hf"):
    """Build the chat interface for `model_name` and launch it with a public share link.

    Args:
        model_name (str): Hugging Face model id to serve (e.g. "google/gemma-2-2b-it").
    """
    interface = create_chat_interface(model_name)

    # launch() is called exactly once. It returns a tuple (app, local_url, share_url), not the
    # interface, so chaining another .launch() onto it raises AttributeError.
    interface.launch(
        share=True,  # public gradio.live URL
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,  # in Colab this blocks the cell and streams errors/logs
    )


if __name__ == "__main__":
    main()

Loading model and tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://1f12fafbf9e7f0eed8.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


[['The Value Added Tax was introduced in Bangladesh on 1 July 1991. The new law on VAT will come into force with effect from 1 July 2017.']
 ['The rates of Tax applicable in Bangladesh are as follows:\n\nThe standard rate of VAT in Bangladesh is 15%.\nExport 0%.\nTurnover Tax applicable to turnover tax payer up to Taka 8 million is 3% (VAT is not applicable to them).\nSupplementary Duty at different rates on luxury goods and various services.\nNo other lower rates.']
 ['There is no geographical limitation for registration under the VAT law. Under the new procedure, no Trade Licence or any such document shall also be required for obtaining registration. Registration will be granted on the basis of Postal Codes. Therefore, registration can be obtained from any place in Bangladesh.']]
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.If a question does not make any sense, or is not factually coherent, explain why

AttributeError: 'TupleNoPrint' object has no attribute 'launch'

In [None]:
def is_english(text):
    """Return True when `text` contains no Bengali-script (U+0980-U+09FF) characters.

    Anything that is not Bangla is treated as English, so ordinary English punctuation
    such as ':', '%' or '/' no longer causes a false "not English" result.
    """
    return re.search(r"[\u0980-\u09FF]", text) is None

is_english("who?")

True