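Assemble the RAG chat UI using ipywidgets: model selection, prompt building from a Jinja2 chat template, a tool-calling loop (live data and document retrieval), streaming LLM responses, and final rendering. Greeting detection, meta-question handling, and query rewriting are imported but currently commented out.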

In [1]:
import sys, os, uuid, re, time, json, torch
import nbimporter

from collections import deque
from html import escape
from ipywidgets import widgets
from IPython.display import display, HTML
from jinja2 import Environment, FileSystemLoader
from typing import List, Dict

In [2]:
# Project root setup: make the `modules` package importable whether the kernel
# was started inside the notebooks folder (package one level up) or directly
# in the folder that contains it.
notebooks_dir = os.getcwd()
project_root_candidate = os.path.abspath(os.path.join(notebooks_dir, os.pardir))

# The parent directory is tried first; the working directory is the fallback.
_modules_root = next(
    (
        candidate
        for candidate in (project_root_candidate, notebooks_dir)
        if os.path.isdir(os.path.join(candidate, "modules"))
    ),
    None,
)
if _modules_root is not None and _modules_root not in sys.path:
    sys.path.insert(0, _modules_root)

In [3]:
from retriever_logic import getRetriever

from modules.utils import sanitize_input, rewrite_user_query, build_context, token_is_valid, audit_response, log_interaction
from modules.qa import handle_greeting, handle_meta_question
from modules.model_manager import ModelManager



In [4]:
# Inject the stylesheet used by the HTML that format_conversation produces:
# per-role message colours, the three-dot "typing" blink animation, and the
# font/padding of the #conversation container.
display(HTML("""
<style>
.message { margin: 5px 0; }
.user { color: #2c3e50; }
.assistant { color: #2980b9; }
.typing { color: #2980b9; }
.typing span { opacity: 0; animation: blink 1.5s infinite; color: #2980b9; }
.typing span:nth-child(1) { animation-delay: 0s; }
.typing span:nth-child(2) { animation-delay: 0.5s; }
.typing span:nth-child(3) { animation-delay: 1s; }
@keyframes blink { 0% { opacity: 0; } 50% { opacity: 1; } 100% { opacity: 0; } }
#conversation { font-family: Tahoma, sans-serif; padding: 10px; }
</style>
"""))

In [5]:
# Load the model registry: a mapping of model key -> settings. on_model_change
# reads "path", "prompt_template_key" and optional "params" from each entry.
# JSON is UTF-8 by specification; decode it explicitly instead of relying on
# the platform's locale encoding (which breaks non-ASCII keys on some systems).
with open('../data/models.json', 'r', encoding='utf-8') as f:
    models = json.load(f)

# Module-level model state shared by the UI callbacks.
model_manager = ModelManager()
llm = None                  # active model handle; None until a model is loaded
prompt_template_key = None  # template name for the active model; None until loaded
model_dropdown = widgets.Dropdown(options=list(models.keys()), description='انتخاب مدل:', style={'description_width': 'initial'})

def on_model_change(change):
    """Dropdown observer: load the newly selected model and publish it globally.

    Args:
        change: Change dict from the widget observer; ``change['new']`` is the
            selected key into ``models``.

    Side effects:
        Rebinds the module-level ``llm`` and ``prompt_template_key``.
    """
    global llm, prompt_template_key

    # Drop the previous handles first so a failed lookup or load never leaves
    # the old model handle in place.
    llm = None
    prompt_template_key = None

    selected_key = change['new']
    entry = models[selected_key]
    weights_path = entry["path"]
    prompt_template_key = entry["prompt_template_key"]

    model_manager.load_model(weights_path, selected_key, **entry.get("params", {}))
    llm = model_manager.get_current_model()


# Reload the model whenever the dropdown selection changes.
model_dropdown.observe(on_model_change, names='value')
# First key in the registry; only referenced by the disabled eager load below.
initial_model = list(models.keys())[0]
# NOTE(review): the eager load is disabled, so `llm` stays None until the
# dropdown value actually changes. Presumably this avoids loading a model at
# notebook start-up — confirm, since the preselected entry cannot be loaded
# without switching to another entry first.
#on_model_change({'new': initial_model})

In [6]:
# Shared conversation state read and written by the UI callbacks.
# NOTE(review): the deque is bounded at 100 entries, so once it fills up the
# oldest entry is evicted — that is the system prompt, which is appended
# first. Confirm whether dropping it is intended.
history = deque(maxlen=100)
# Flag polled by the generation loops; on_stop sets it to True.
stop_generation = False
# Whether recorded "thinking" text is rendered; also passed to the prompt template.
enable_thinking = True

In [7]:
def get_live_data():
    """Tool: return the current local time and date plus a weather string.

    Returns:
        A single Persian sentence containing the time (HH:MM:SS), the date
        (YYYY-MM-DD) and the weather. The weather value is a hard-coded
        placeholder, not a real lookup.
    """
    import datetime

    current = datetime.datetime.now()
    weather = "آفتابی"  # Placeholder
    return f"زمان: {current:%H:%M:%S}، تاریخ: {current:%Y-%m-%d}، آب و هوا: {weather}"

# Build the document retriever once at module level; get_any_data reuses it on every call.
retriever = getRetriever()
def get_any_data(query: str):
    """Tool: retrieve documents for ``query`` and return their text as one string.

    Side effects:
        Writes the HTML rendering returned by ``build_context`` into the
        ``retrieved_context`` widget.

    Args:
        query: Search text handed to the retriever unchanged (the query
            rewriting step via ``rewrite_user_query`` is currently disabled).

    Returns:
        The ``page_content`` of every retrieved document, joined by single spaces.
    """
    docs = retriever.get_relevant_documents(query)

    # Only the HTML half of build_context's result is used here.
    _, rendered_context = build_context(docs)
    retrieved_context.value = rendered_context

    return " ".join(doc.page_content for doc in docs)

No sentence-transformers model found with name HooshvareLab/bert-fa-base-uncased. Creating a new one with mean pooling.


FAISS retriever (k=100) ready.


In [8]:
def parse_assistant_response(response: str):
    """Split a raw model response into thinking text, tool calls and the answer.

    Args:
        response: Full text produced by the model for one completion.

    Returns:
        A ``(thinking, tool_calls, final_answer)`` tuple:
          * thinking: stripped body of the first ``<think>...</think>`` block,
            or ``""`` when there is none.
          * tool_calls: one decoded JSON value per ``<tool_call>`` block;
            blocks whose body is not valid JSON are skipped.
          * final_answer: stripped text following the last ``</think>`` or
            ``</tool_call>`` tag (the whole response when neither occurs).
    """
    think_block = re.search(r'<think>(.*?)</think>', response, re.DOTALL)
    thinking = think_block.group(1).strip() if think_block else ""

    tool_calls = []
    for payload in re.findall(r'<tool_call>\s*(.*?)\s*</tool_call>', response, re.DOTALL):
        try:
            tool_calls.append(json.loads(payload))
        except json.JSONDecodeError:
            # Best effort: a malformed call is ignored rather than aborting the parse.
            continue

    # re.split always yields at least one element, so [-1] is safe.
    final_answer = re.split(r'</think>|</tool_call>', response)[-1].strip()

    return thinking, tool_calls, final_answer

In [9]:
def flatten_history(history: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Project the chat history onto plain ``role``/``content`` messages.

    Tool turns are re-labelled as user turns with their content wrapped in
    ``<tool_response>`` tags; every other turn keeps its role. Only the
    ``role`` and ``content`` keys survive (missing ones default to ``""``),
    so extra keys such as ``thinking`` are dropped.

    Args:
        history: Iterable of turn dicts.

    Returns:
        A new list with one message dict per input turn, in the same order.
    """
    def _as_chat_message(turn: Dict[str, str]) -> Dict[str, str]:
        role = turn.get("role", "")
        content = turn.get("content", "")
        if role == "tool":
            return {"role": "user", "content": f"<tool_response>\n{content}\n</tool_response>"}
        return {"role": role, "content": content}

    return [_as_chat_message(turn) for turn in history]

In [10]:
# Display name interpolated into the system prompt.
name = "Persian Rag Assistant"

# System prompt: greeting handling, mandatory single tool choice for anything
# else, the exact fallback reply, and the Persian-only output rule.
system_message = f"""
Your name is {name}, a Retrieval-Augmented Generation (RAG) assistant. Follow these rules precisely:

1. All user inputs will be in Persian. For each user message:
   a. If it is a simple greeting (e.g., 'سلام', 'خوبی؟'), respond with a brief Persian greeting (e.g., 'سلام! چطور می‌توانم کمک کنم؟') and do NOT call any tool.
   b. Otherwise, choose exactly one of the tools defined below and invoke it. Do NOT generate any other content.
   c. If you cannot find a suitable tool or there is not enough information, reply exactly: 'نمی‌دونم.'
2. OUTPUT RULE: All your final replies must be written fully in Persian (Persian script). Do NOT include any English words, Latin letters, or transliterations.
"""

# Install the system prompt as the first history entry. Re-running this cell
# refreshes the existing system turn instead of appending a duplicate, so the
# cell is idempotent.
_system_turn = {"role": "system", "content": system_message}
if history and history[0].get("role") == "system":
    history[0] = _system_turn
else:
    history.appendleft(_system_turn)

# Tool catalogue passed to the prompt template as `tools`. The "name" values
# are the ones on_submit dispatches on (get_live_data / get_any_data).
# NOTE(review): get_any_data advertises an optional "index" parameter, but the
# dispatcher only reads "query" — confirm whether "index" should be honoured.
tools = [
    {
        "name": "get_live_data",
        "description": """
        Retrieve live data such as:
          • Current time and date
          • Weather in a specified city
        Examples (Persian → tool):
          - 'الان ساعت چنده؟' → get_live_data
          - 'هوا امروز توی تهران چه‌جوریه؟' → get_live_data
        Non-examples:
          - 'در فایل PDF بخش سوم را پیدا کن' → NOT get_live_data
        """,
        "parameters": {
            "query": {
                "type": "string",
                "description": "Persian question asking for live data."
            }
        }
    },
    {
        "name": "get_any_data",
        "description": """
        Retrieve data from uploaded documents or a knowledge base. Use this tool when:
          • The user asks for content that must be looked up in a Persian document or index.
        Examples (Persian → tool):
          - 'در فایلِ PDF فصل سوم را پیدا کن و خلاصه‌اش را بگو.' → get_any_data
          - 'لیست ایمیل شرکت را از فایل اکسل استخراج کن.' → get_any_data
        Non-examples:
          - 'الان تاریخ چنده؟' → NOT get_any_data
        Parameters (if supporting multiple indices):
          • index: name of the document index or database to search (optional)
        """,
        "parameters": {
            "query": {
                "type": "string",
                "description": "Persian question for retrieving from documents."
            },
            "index": {
                "type": "string",
                "description": "Optional: name of the document index or customer-specific database."
            }
        }
    }
]

In [16]:
def build_prompt_templates(history: List[Dict[str, str]]) -> str:
    """Render the chat prompt for the active model from its Jinja2 template.

    The template named by the module-level ``prompt_template_key`` is loaded
    from ``../templates`` and rendered with the tool catalogue, the flattened
    history, ``add_generation_prompt=True`` and the current ``enable_thinking``
    flag.

    Args:
        history: Sequence of turn dicts (the module-level conversation history).

    Returns:
        The rendered prompt text.

    Raises:
        RuntimeError: If no model (and therefore no template) is selected.
    """
    if prompt_template_key is None:
        raise RuntimeError("No model is loaded: select a model before building a prompt.")

    # The original literal was "'../templates" (stray apostrophe), which points
    # at a directory that does not exist.
    template_dir = "../templates"
    env = Environment(loader=FileSystemLoader(template_dir), autoescape=False)
    # Keep non-ASCII (Persian) text readable instead of \u-escaping it.
    env.filters["tojson"] = lambda value: json.dumps(value, sort_keys=False, ensure_ascii=False)
    template = env.get_template(prompt_template_key)
    context = {
        "tools": tools,
        "messages": flatten_history(history),
        "add_generation_prompt": True,
        "enable_thinking": enable_thinking,
    }
    return template.render(context)

In [12]:
def format_conversation(history, is_typing=False, current_response=""):
    """Render the conversation as the HTML shown in the conversation widget.

    User and assistant turns are rendered; turns with any other role (system,
    tool) are skipped. All text is HTML-escaped and newlines become ``<br>``.

    Args:
        history: Iterable of turn dicts with ``role`` and ``content`` keys;
            assistant turns may carry a ``thinking`` string.
        is_typing: Show the animated "thinking" indicator. It is only shown
            while ``current_response`` is empty.
        current_response: Partial assistant text being streamed, rendered
            after the recorded turns.

    Returns:
        An HTML string wrapped in a right-to-left ``#conversation`` div.
    """
    def _as_html(text):
        # chr(10) is "\n"; escape first so model/user text cannot inject markup.
        return escape(text).replace(chr(10), '<br>')

    parts = ["<div id='conversation' dir='rtl' lang='fa'>"]
    for turn in history:
        role = turn.get('role')
        if role == 'user':
            parts.append(f"<div class='message user'><b>یوزر:</b> {_as_html(turn.get('content', ''))}</div>")
        elif role == 'assistant':
            # Assistant turns always carry a 'thinking' key, often as an empty
            # string; only render the row when there is text to show.
            thinking = turn.get('thinking')
            if enable_thinking and thinking:
                parts.append(f"<div class='message assistant'><b>فکر کردن:</b> {_as_html(thinking)}</div>")
            parts.append(f"<div class='message assistant'><b>اسیستنت:</b> {_as_html(turn.get('content', ''))}</div>")
    if current_response:
        parts.append(f"<div class='message assistant'><b>اسیستنت:</b> {_as_html(current_response)}</div>")
    if is_typing and not current_response:
        parts.append("<div class='typing'>اسیستنت در حال فکر کردن<span>.</span><span>.</span><span>.</span></div>")
    parts.append("</div>")
    return "".join(parts)

In [13]:
def simulate_streaming(response, delay=0.1):
    """Reveal ``response`` word by word in the conversation widget.

    After each word the widget is re-rendered and the function sleeps for
    ``delay`` seconds. If the module-level ``stop_generation`` flag is set,
    a stop marker is appended and the reveal ends early.

    Args:
        response: Text to reveal; it is split on whitespace, so original
            spacing and line breaks are not preserved.
        delay: Pause between words, in seconds.

    Returns:
        The text that was shown (words joined by single spaces, plus the stop
        marker if interrupted), stripped of surrounding whitespace.
    """
    rendered = ""
    for word in response.split():
        if stop_generation:
            rendered = f"{rendered} [توقف شد]"
            break
        rendered = f"{rendered}{word} "
        conversation_widget.value = format_conversation(history, current_response=rendered.strip())
        time.sleep(delay)
    return rendered.strip()

In [14]:
def _stream_completion(prompt):
    """Stream one completion for ``prompt`` and return the raw response text.

    While streaming, the conversation widget shows the text outside
    ``<tool_call>...</tool_call>`` blocks; while inside a block the typing
    indicator is shown instead. Generation stops early (with a stop marker
    appended to the returned text) once ``stop_generation`` is set.

    NOTE(review): tags are searched per streamed chunk, so an opening tag that
    is split across two chunks is not hidden from the live view. The returned
    raw text is unaffected.
    """
    tool_open_tag = "<tool_call>"
    tool_close_tag = "</tool_call>"
    response = ""
    visible_response = ""
    tool_buffer = ""
    in_tool_call = False

    for completion in llm.create_completion(
        prompt=prompt,
        max_tokens=512,
        temperature=0.8,
        top_p=0.95,
        top_k=40,
        repeat_penalty=1.1,
        stream=True,
        min_p=0,
    ):
        if stop_generation:
            response += " [توقف شد]"
            break
        token = completion["choices"][0]["text"]
        response += token

        if not in_tool_call:
            idx = token.find(tool_open_tag)
            if idx == -1:
                visible_response += token
            else:
                # Show what precedes the tag; buffer the rest as tool-call text.
                visible_response += token[:idx]
                in_tool_call = True
                tool_buffer = token[idx:]
            conversation_widget.value = format_conversation(history, current_response=visible_response)
        else:
            tool_buffer += token
            idx_close = tool_buffer.find(tool_close_tag)
            if idx_close != -1:
                # Tool call finished: anything after the closing tag is visible again.
                visible_response += tool_buffer[idx_close + len(tool_close_tag):]
                in_tool_call = False
                tool_buffer = ""
                conversation_widget.value = format_conversation(history, current_response=visible_response)
            else:
                conversation_widget.value = format_conversation(history, is_typing=True)

    return response


def _run_tool(tool_call):
    """Execute one parsed tool call and return its result as text.

    Failures inside the tool are reported as an error string so the model can
    see them on the next turn instead of aborting the whole exchange.
    """
    tool_name = tool_call.get("name")
    arguments = tool_call.get("arguments", {})
    try:
        if tool_name == "get_live_data":
            result = get_live_data()
        elif tool_name == "get_any_data":
            result = get_any_data(arguments.get("query", ""))
        else:
            result = "the tool was not correctly called"
        return f"{result}"
    except Exception as e:
        return f"خطا در اجرای ابزار: {e}"


def on_submit(button):
    """Submit-button handler: run one user turn through the tool-calling loop.

    Flow: sanitize the input, record the user turn, build the prompt, then
    repeatedly stream a completion and execute any tool calls it contains (at
    most five rounds) until the model answers without calling a tool. The
    final answer, an error message, or the iteration-limit message is recorded
    as the assistant turn together with the collected thinking text.

    The submit/stop buttons are always restored on exit, including when prompt
    building fails; previously such a failure left the submit button disabled.

    Args:
        button: The clicked widget (unused; required by the on_click signature).
    """
    global stop_generation
    user_input = text_input.value.strip()
    if not user_input:
        return

    try:
        user_question = sanitize_input(user_input)
    except ValueError as e:
        conversation_widget.value = format_conversation(history, current_response=f"خطا: {e}")
        return

    # Without a loaded model there is neither an LLM nor a prompt template.
    # Report it without touching the history or the typed text.
    if llm is None or prompt_template_key is None:
        conversation_widget.value = format_conversation(
            history, current_response="خطا: هنوز مدلی بارگذاری نشده است."
        )
        return

    text_input.value = ""
    history.append({"role": "user", "content": user_question})
    conversation_widget.value = format_conversation(history, is_typing=True)

    # TODO: the canned greeting / meta-question shortcuts (handle_greeting,
    # handle_meta_question replayed through simulate_streaming) are disabled;
    # every message currently goes through the model.

    submit_button.disabled = True
    stop_button.layout.display = ''
    stop_generation = False

    all_thinking = []
    final_answer = ""
    tool_iteration = 0
    max_tool_iterations = 5

    try:
        try:
            prompt = build_prompt_templates(history)
            while tool_iteration < max_tool_iterations:
                prompt_output.value = "<b>پرامپت نهایی:</b><br>" + escape(prompt).replace(chr(10), "<br>")
                response = _stream_completion(prompt)

                thinking, tool_calls, final_answer = parse_assistant_response(response)
                if thinking:
                    all_thinking.append(thinking)
                if not tool_calls:
                    break  # plain answer: the turn is complete

                for tool_call in tool_calls:
                    history.append({'role': 'tool', 'content': _run_tool(tool_call)})
                prompt = build_prompt_templates(history)
                tool_iteration += 1
            else:
                # Loop ended without a plain answer: the tool budget is exhausted.
                final_answer = "حداکثر تعداد فراخوانی ابزار رسیده است."
        except Exception as e:
            final_answer = f"خطا: {e}"

        history.append({'role': 'assistant', 'content': final_answer, 'thinking': "\n".join(all_thinking)})
        conversation_widget.value = format_conversation(history)

        # TODO: the post-answer audit and logging step (audit_response,
        # log_interaction) and the responce_output panel update are disabled.
    finally:
        submit_button.disabled = False
        stop_button.layout.display = 'none'

In [15]:
def on_clear(button):
    """Clear-button handler: reset the conversation to just the system prompt.

    The system turn is looked up by role rather than assumed to sit at index
    0: the history is bounded, so the system turn may have been evicted, and
    the history may also be empty. In either case it is rebuilt from the
    module-level ``system_message``.

    Args:
        button: The clicked widget (unused; required by the on_click signature).
    """
    system_turn = next((turn for turn in history if turn.get("role") == "system"), None)
    if system_turn is None:
        system_turn = {"role": "system", "content": system_message}

    history.clear()
    history.append(system_turn)

    conversation_widget.value = format_conversation(history)
    retrieved_context.value = ""
    prompt_output.value = ""
    responce_output.value = ""

def on_stop(button):
    """Stop-button handler: raise the flag that the generation loops poll.

    Args:
        button: The clicked widget (unused; required by the on_click signature).
    """
    global stop_generation
    stop_generation = True

def on_unload_model(button):
    """Unload-button handler: forget the active model and ask the manager to unload it.

    The module-level ``llm`` and ``prompt_template_key`` are reset before the
    manager is called, so they are cleared even if unloading raises.

    Args:
        button: The clicked widget (unused; required by the on_click signature).
    """
    global llm, prompt_template_key
    llm = prompt_template_key = None
    model_manager.unload_model()

# Output panels: the rendered conversation, the retrieved context written by
# get_any_data, the last prompt sent to the model, and a final-response panel.
# NOTE(review): responce_output is only ever cleared (on_clear); the code that
# filled it is commented out in on_submit.
conversation_widget = widgets.HTML(value="", layout=widgets.Layout(width='99%', height='300px', overflow='auto', padding='20px'))
retrieved_context = widgets.HTML(value="", layout=widgets.Layout(width='99%', height='200px', overflow='auto', border='2px solid #ccc', padding='20px'))
prompt_output = widgets.HTML(value="", layout=widgets.Layout(width='99%', height='200px', overflow='auto', border='2px solid #ccc', padding='20px'))
responce_output = widgets.HTML(value="", layout=widgets.Layout(width='99%', height='100px', overflow='auto', border='2px solid #ccc', padding='20px'))

# Input controls. The stop button starts hidden; on_submit shows it while a
# response is being generated.
text_input = widgets.Text(value='', placeholder='سوال خود را اینجا تایپ کنید', layout=widgets.Layout(width='70%', direction='rtl'))
submit_button = widgets.Button(description="ارسال", button_style='success')
clear_button = widgets.Button(description="پاک کردن تاریخچه", button_style='warning')
stop_button = widgets.Button(description="توقف", button_style='danger', layout=widgets.Layout(display='none'))
unload_button = widgets.Button(description="X", button_style='danger', layout=widgets.Layout(width='40px'))
thinking_toggle = widgets.ToggleButton(value=True, description='نمایش فکر کردن', button_style='info')

# Wire each button to its handler.
submit_button.on_click(on_submit)
clear_button.on_click(on_clear)
stop_button.on_click(on_stop)
unload_button.on_click(on_unload_model)

def on_thinking_toggle(change):
    """Toggle observer: store the new "show thinking" state and re-render the conversation.

    Args:
        change: Change dict from the widget observer; only ``change['new']`` is read.
    """
    global enable_thinking
    enable_thinking = change['new']
    conversation_widget.value = format_conversation(history)

# Re-render the conversation whenever the thinking toggle flips.
thinking_toggle.observe(on_thinking_toggle, names='value')

# Page layout, top to bottom: conversation, control row, retrieved context,
# last prompt, final-response panel.
interface = widgets.VBox([
    conversation_widget,
    widgets.HBox([unload_button, model_dropdown, text_input, submit_button, clear_button, stop_button, thinking_toggle], layout=widgets.Layout(justify_content='flex-start', direction='rtl', width='99%')),
    retrieved_context,
    prompt_output,
    responce_output
])
display(interface)

# When a CUDA device is present, empty PyTorch's CUDA cache and print the
# memory still allocated on device 0 (in MB).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"حافظه GPU پاک شد: {torch.cuda.memory_allocated(0)/1e6:.2f} MB")

VBox(children=(HTML(value='', layout=Layout(height='300px', overflow='auto', padding='20px', width='99%')), HB…

حافظه GPU پاک شد: 0.00 MB
