# In [None]:
## -----------------------------
## 1. MODEL
## -----------------------------
# Load the fine-tuned checkpoint with Unsloth, switch it to inference mode,
# and wrap the already-loaded model/tokenizer pair in a LlamaIndex LLM.
import torch
from unsloth import FastLanguageModel

max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None           # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True    # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "../lora_model",  # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model)

from llama_index.llms.huggingface import HuggingFaceLLM

# Stop on the tokenizer's own end-of-sequence token instead of hard-coded ids:
# fixed numeric ids belong to one specific vocabulary and can map to ordinary
# tokens (or to nothing at all) in another model's vocabulary.
_eos_token_id = tokenizer.eos_token_id

llm = HuggingFaceLLM(
    # Keep the prompt budget consistent with the sequence length the model
    # was loaded with above, so LlamaIndex never packs a longer prompt.
    context_window=max_seq_length,
    max_new_tokens=256,
    # Greedy decoding. `temperature` only has an effect when sampling, so it
    # is not passed here; to sample instead, use
    # {"do_sample": True, "temperature": 0.7}.
    generate_kwargs={"do_sample": False},
    device_map="auto",
    stopping_ids=[_eos_token_id] if _eos_token_id is not None else [],
    tokenizer_kwargs={"max_length": max_seq_length},
    # NOTE(review): model_kwargs/tokenizer_kwargs/device_map are presumably
    # only consulted when HuggingFaceLLM loads the model by name; with a
    # preloaded `model`/`tokenizer` they likely have no effect - confirm.
    model_kwargs={"torch_dtype": torch.float16},
    model=model,
    tokenizer=tokenizer,
)

## -----------------------------
## 2. DATASET
## -----------------------------
# Download the medical Q&A dataset, dump it to a CSV file, and read that
# directory back as LlamaIndex documents.
import os

from datasets import load_dataset
document = load_dataset("xDAN-datasets/medical_meadow_wikidoc_patient_information_6k", split="train")

# The export target must exist before writing; create it on a fresh checkout.
os.makedirs("./dataset", exist_ok=True)
document.to_csv("./dataset/rag_data.csv")

from llama_index.core import SimpleDirectoryReader
# Loads every file found in ./dataset (here: the CSV written just above).
documents = SimpleDirectoryReader("./dataset").load_data()

## -----------------------------
## 3. EMBEDDINGS & SETTINGS
## -----------------------------
# Create the embedding model, then store the LLM, the embedding model and the
# chunking configuration on the shared LlamaIndex `Settings` object.
from llama_index.core import Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

Settings.llm = llm
Settings.embed_model = embed_model

# Chunks of 1024 tokens, configured both on Settings and on the explicit
# transformation pipeline.
Settings.chunk_size = 1024
Settings.transformations = [SentenceSplitter(chunk_size=1024)]

## -----------------------------
## 4. BUILD THE INDEX WITH CHROMADB
## -----------------------------
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore

# Create the Chroma-backed vector store, stating where it is persisted on disk
# (persist_dir) and which collection it uses. `from_params` opens a persistent
# Chroma client for that directory and gets or creates the collection.
chroma_store = ChromaVectorStore.from_params(
    collection_name="medical_collection",
    persist_dir="./chroma_db",
)

# A vector store is attached to an index through a StorageContext; passing it
# as a plain `vector_store=` keyword to `from_documents` does not wire it in.
storage_context = StorageContext.from_defaults(vector_store=chroma_store)

# Build the index on top of the Chroma store.
# (In this example all documents are indexed together.)
# NOTE(review): re-running this cell presumably inserts the same documents
# into the persisted collection again - confirm and clear ./chroma_db if so.
index = VectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
    embed_model=embed_model,
    transformations=Settings.transformations,
)

## -----------------------------
## 5. BUILD A BASIC QUERY ENGINE
## -----------------------------
# Ask the same question twice: once through the index's default query engine
# and once through that engine wrapped in a HyDE query transform.
from IPython.display import Markdown, display
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine

query_str = "What causes Alstrom syndrome?"

# Plain query engine.
query_engine = index.as_query_engine()
response = query_engine.query(query_str)
display(
    Markdown(f"<b>{response}</b>")
)

# Same engine behind the HyDE transform, keeping the original query as well.
hyde = HyDEQueryTransform(include_original=True)
hyde_query_engine = TransformQueryEngine(
    query_engine=query_engine,
    query_transform=hyde,
)
response = hyde_query_engine.query(query_str)
display(
    Markdown(f"<b>{response}</b>")
)

## -----------------------------
## 6. RAG (Retriever + LLM)
## -----------------------------
# Assemble the query engine by hand from a top-2 vector retriever and a
# response synthesizer, then query it through the HyDE transform.
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever

vector_retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=2,
)
response_synthesizer = get_response_synthesizer()
vector_query_engine = RetrieverQueryEngine(
    retriever=vector_retriever, response_synthesizer=response_synthesizer
)

query_str = "What causes Alstrom syndrome?"

# Reuses the `hyde` transform and `TransformQueryEngine` already in scope.
hyde_query_engine = TransformQueryEngine(
    query_engine=vector_query_engine,
    query_transform=hyde,
)
response = hyde_query_engine.query(query_str)
display(
    Markdown(f"<b>{response}</b>")
)

## -----------------------------
## 7. ADD A SEMANTIC ROUTER
## -----------------------------
# Goal: build a "semantic" router that classifies each question into one of
# two routes:
#   1) "product"  (questions about a product or a specific domain)
#   2) "chitchat" (ordinary conversation)
#
# Normally there would be two different indexes (product_index and
# chitchat_index). This demo reuses the single `index` for both, purely for
# illustration. In practice, put the product-related documents in one index
# and the chitchat (general) documents in another, and let the router decide
# which one to use.

product_index = index   # replace with the index holding product documents
chitchat_index = index  # replace with the index holding chitchat documents

product_query_engine = product_index.as_query_engine()
chitchat_query_engine = chitchat_index.as_query_engine()

##
## 7.1 Create the "tools" for the two routes: product_tool, chitchat_tool
##
from llama_index.core.tools import QueryEngineTool

# `from_defaults` builds the tool metadata from `name`/`description`; the
# plain constructor expects a ready-made metadata object instead.
product_tool = QueryEngineTool.from_defaults(
    query_engine=product_query_engine,
    name="product_tool",
    description="Use this tool for answering questions about PRODUCT domain.",
)
chitchat_tool = QueryEngineTool.from_defaults(
    query_engine=chitchat_query_engine,
    name="chitchat_tool",
    description="Use this tool for normal chitchat or everyday topics.",
)

tools = [product_tool, chitchat_tool]

##
## 7.2 Create the RouterQueryEngine
##
# RouterQueryEngine uses the LLM to route each question to the appropriate
# tool, based on the tool descriptions.
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core.selectors import LLMSingleSelector

router_query_engine = RouterQueryEngine.from_defaults(
    query_engine_tools=tools,
    llm=llm,  # LLM used by the router engine
    # Prompt-based selector that picks exactly one tool per question.
    # There is no "default tool" option: the selector always chooses among
    # the tools it is given.
    selector=LLMSingleSelector.from_defaults(llm=llm),
)

##
## 7.3 Query through the router
##
# Send one product-style and one chitchat-style question through the router
# and render each question/answer pair as Markdown.

# --- product route ---
print("=== Demo: Query về product ===")
product_question = "Give me specifications of the new medical device"
router_response_1 = router_query_engine.query(product_question)
display(
    Markdown(f"<b>Product Route Q:</b> {product_question}")
)
display(
    Markdown(f"<b>Answer:</b> {router_response_1}")
)

# --- chitchat route ---
print("\n=== Demo: Query chitchat ===")
chitchat_question = "How is the weather today?"
router_response_2 = router_query_engine.query(chitchat_question)
display(
    Markdown(f"<b>Chitchat Route Q:</b> {chitchat_question}")
)
display(
    Markdown(f"<b>Answer:</b> {router_response_2}")
)
