##### 1. Model

In [None]:
import torch
from unsloth import FastLanguageModel

# Loading options for the fine-tuned checkpoint.
max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None           # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True    # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "../lora_model",  # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
# Switch the loaded model into Unsloth's inference mode before generating.
FastLanguageModel.for_inference(model)

# Smoke test: generate a short completion to confirm the checkpoint loaded.
# The encoded prompt is moved to the model's device (a CPU tensor fed to a
# GPU-resident model fails in generate), and the attention mask is passed
# explicitly instead of being left for generate() to guess.
encoded_prompt = tokenizer("What causes Alstrom syndrome?", return_tensors="pt").to(model.device)
input_ids = encoded_prompt.input_ids
outputs = model.generate(
    input_ids,
    attention_mask=encoded_prompt.attention_mask,
    max_new_tokens=50,
)
print(tokenizer.decode(outputs[0]))

from llama_index.llms.huggingface import HuggingFaceLLM
# Wrap the already-loaded model/tokenizer so LlamaIndex can use them as its LLM.
#
# - The context window follows `max_seq_length`, the limit the model was
#   loaded with, instead of a separate hard-coded 4096.
# - Decoding is greedy (`do_sample=False`); a `temperature` is not passed
#   because it has no effect without sampling and only triggers a warning.
# - Stopping ids come from this tokenizer's own EOS token. The previous
#   hard-coded ids belonged to a different model's vocabulary and could map
#   to ordinary tokens here, cutting generations short.
# - `device_map` / `model_kwargs` are omitted: they only apply when the
#   wrapper loads a model by name, and a loaded model is passed in here.
llm = HuggingFaceLLM(
    model=model,
    tokenizer=tokenizer,
    context_window=max_seq_length,
    max_new_tokens=256,
    generate_kwargs={"do_sample": False},
    stopping_ids=(
        [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
    ),
    tokenizer_kwargs={"max_length": max_seq_length},
)

##### 2. Dataset

In [None]:
from datasets import load_dataset
import os

from llama_index.core import SimpleDirectoryReader

# Make sure the export directory exists before writing into it, so the cell
# also works on a fresh checkout where ./dataset has not been created yet.
os.makedirs("./dataset", exist_ok=True)

# Download the training split of the patient-information dataset and dump it
# to a CSV file inside the directory that the reader below scans.
document = load_dataset("xDAN-datasets/medical_meadow_wikidoc_patient_information_6k", split="train")
document.to_csv("./dataset/rag_data.csv")

# Load every file found in ./dataset as LlamaIndex documents.
documents = SimpleDirectoryReader("./dataset").load_data()

##### 3. Embedding & Settings

In [3]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings
from llama_index.core.node_parser import SentenceSplitter

# Embedding model used for both indexing and querying.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Register the global defaults: the LLM wrapper and the embedding model.
Settings.llm = llm
Settings.embed_model = embed_model

# Chunking configuration: 1024-sized chunks produced by a sentence splitter.
Settings.chunk_size = 1024
Settings.transformations = [SentenceSplitter(chunk_size=1024)]

##### 4. Index with ChromaDB

In [None]:
from llama_index.core import VectorStoreIndex, StorageContext
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore

# Open (or create) the on-disk Chroma database and one collection per route.
chroma_client = chromadb.PersistentClient(path='./chroma_db')
product_collection = chroma_client.get_or_create_collection("product_store")
chitchat_collection = chroma_client.get_or_create_collection("chitchat_store")

# Wrap each collection in a ChromaVectorStore. The collections above already
# live in ./chroma_db, so no separate persist_dir is passed: the previous
# "./chromadb/..." values pointed somewhere else and were not what determined
# where the data is stored.
product_store = ChromaVectorStore(
    collection_name="product_store",
    chroma_collection=product_collection,
)

chitchat_store = ChromaVectorStore(
    collection_name="chitchat_store",
    chroma_collection=chitchat_collection,
)

# One storage context per store, used when building the indexes.
product_storage_context = StorageContext.from_defaults(vector_store=product_store)
chitchat_storage_context = StorageContext.from_defaults(vector_store=chitchat_store)

# Build the index on top of the Chroma-backed storage context.
# (In this example all loaded documents are indexed together.)
# The target vector store is selected through `storage_context`; the extra
# `vector_store=` keyword was redundant and has been dropped.
# NOTE(review): the collection is persistent, so re-running this cell
# presumably inserts the same documents again — confirm and guard if needed.
product_index = VectorStoreIndex.from_documents(
    documents,
    storage_context=product_storage_context,
    embed_model=embed_model,
    transformations=Settings.transformations,
)

from llama_index.core import Document

# Small hand-written corpus for the chitchat route.
# The document body must be passed as `text`; `content` is not a Document
# field, so the sentences were not stored as the document text before.
chitchat_docs = [
    Document(text="Hello, how are you today?"),
    Document(text="What do you think about the weather?"),
    Document(text="Hey, have you watched any good movies lately?"),
    Document(text="What's your favorite hobby?"),
    Document(text="How's your day going?"),
    # Add more chitchat documents here...
]

# Index the chitchat corpus into its own Chroma collection. The store is
# selected through `storage_context`, so no separate `vector_store=` keyword
# is needed.
chitchat_index = VectorStoreIndex.from_documents(
    chitchat_docs,
    storage_context=chitchat_storage_context,
    embed_model=embed_model,
    transformations=Settings.transformations,
)

##### 5. Semantic Router Selector

In [None]:
from llama_index.core.tools import ToolMetadata
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.query_engine.router_query_engine import RouterQueryEngine
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from IPython.display import display, Markdown

# Route descriptions expressed as tool metadata.
# (A plain list of description strings is the alternative form.)
choices = [
    ToolMetadata(
        name="product",
        description="Return medical patient information related to the query. ",
    ),
    ToolMetadata(name="chitchat", description="Chitchat"),
]

# Medical-information route: query engine plus the tool that describes it.
product_query_engine = product_index.as_query_engine()
product_tool = QueryEngineTool.from_defaults(
    query_engine=product_query_engine,
    description="Return medical patient information related to the query. ",
)

# Chitchat route: query engine plus the tool that describes it.
chitchat_query_engine = chitchat_index.as_query_engine()
chitchat_tool = QueryEngineTool.from_defaults(
    query_engine=chitchat_query_engine,
    description="Return chitchat responses related to the query. ",
)

query_engine_tools = [product_tool, chitchat_tool]

# Ask the LLM-based selector which tool fits a sample question and print
# its decision. The selector is given the metadata of the tools above.
selector = LLMSingleSelector.from_defaults()
_metadatas = [tool.metadata for tool in query_engine_tools]
selector_result = selector.select(_metadatas, query="What causes Alstrom syndrome?")
print(selector_result.selections)

##### 6. Semantic Router Query Engine

In [None]:
from llama_index.core.query_engine.router_query_engine import RouterQueryEngine
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from IPython.display import display, Markdown


# Medical-information route.
product_query_engine = product_index.as_query_engine()
product_tool = QueryEngineTool.from_defaults(
    query_engine=product_query_engine,
    description="Return medical patient information related to the query. ",
)

# Chitchat route.
chitchat_query_engine = chitchat_index.as_query_engine()
chitchat_tool = QueryEngineTool.from_defaults(
    query_engine=chitchat_query_engine,
    description="Return chitchat responses related to the query. ",
)

# Router that lets an LLM-based single selector pick one of the two tools.
query_engine = RouterQueryEngine(
    selector=LLMSingleSelector.from_defaults(),
    query_engine_tools=[product_tool, chitchat_tool],
)

# Example of querying the router directly:
# response = query_engine.query("Hello, how are you today?")

# Wrap the router with a HyDE query transform (include_original=True) and
# run a sample medical question through it, rendering the answer in bold.
hyde = HyDEQueryTransform(include_original=True)
product_hyde_query = TransformQueryEngine(query_engine, hyde)
response = product_hyde_query.query("What causes Alstrom syndrome?")
display(Markdown(f"<b>{response}</b>"))


# list_tool = QueryEngineTool.from_defaults(
#     query_engine=product_index,
#     description="Useful for summarization questions related to the data source",
# )
# vector_tool = QueryEngineTool.from_defaults(
#     query_engine=chitchat_index,
#     description="Useful for retrieving specific context related to the data source",
# )

# # initialize router query engine (single selection, pydantic)
# query_engine = RouterQueryEngine(
#     selector=LLMSingleSelector.from_defaults(),
#     query_engine_tools=[
#         list_tool,
#         vector_tool,
#     ],
# )

# hyde_query_engine = TransformQueryEngine(
#     query_engine=query_engine,
#     query_transform=hyde,
# )
# hyde_query_engine.query("What causes Alstrom syndrome?")

##### 7. RAG (Retriever + LLM)

In [None]:
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever

# This cell reuses `hyde`, `TransformQueryEngine`, `display` and `Markdown`
# from the router cell above; run that cell first.
query_str = "What causes Alstrom syndrome?"

# Retrieval side: top-2 similarity search over the product index.
vector_retriever = VectorIndexRetriever(
    index=product_index,
    similarity_top_k=2,
)

# Generation side: default response synthesizer.
response_synthesizer = get_response_synthesizer()

# Combine retriever and synthesizer into a query engine, then apply the
# HyDE transform in front of it.
vector_query_engine = RetrieverQueryEngine(
    retriever=vector_retriever,
    response_synthesizer=response_synthesizer,
)
hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)

response = hyde_query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

##### 8. Semantic Splitter

In [None]:
# from llama_index.core.node_parser import (
#     SentenceSplitter,
#     SemanticSplitterNodeParser,
# )

# splitter = SemanticSplitterNodeParser(
#     buffer_size=1, breakpoint_percentile_threshold=95, embed_model=embed_model
# )

# # also baseline splitter
# base_splitter = SentenceSplitter(chunk_size=512)
# nodes = splitter.get_nodes_from_documents(documents)
# print(f"Number of nodes: {len(nodes)}")
# print(nodes[1].get_content())

# base_nodes = base_splitter.get_nodes_from_documents(documents)

# from llama_index.core import VectorStoreIndex
# from llama_index.core.response.notebook_utils import display_source_node

# vector_index = VectorStoreIndex(nodes)
# query_engine = vector_index.as_query_engine()

# base_vector_index = VectorStoreIndex(base_nodes)
# base_query_engine = base_vector_index.as_query_engine()