In [1]:
## -----------------------------
## 1. MODEL
## -----------------------------
import torch
from unsloth import FastLanguageModel

max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None           # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True    # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "../lora_model",  # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model)

from llama_index.llms.huggingface import HuggingFaceLLM
llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"temperature": 0.7, "do_sample": False},
    device_map="auto",
    stopping_ids=[50278, 50279, 50277, 1, 0],
    tokenizer_kwargs={"max_length": 4096},
    model_kwargs={"torch_dtype": torch.float16},
    model=model,
    tokenizer=tokenizer,
)


## -----------------------------
## 2. DATASET
## -----------------------------
from datasets import load_dataset
document = load_dataset("xDAN-datasets/medical_meadow_wikidoc_patient_information_6k", split="train")
document.to_csv("./dataset/rag_data.csv")

from llama_index.core import SimpleDirectoryReader
documents = SimpleDirectoryReader("./dataset").load_data()

## -----------------------------
## 3. EMBEDDINGS & SETTINGS
## -----------------------------
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

from llama_index.core import Settings
Settings.llm = llm
Settings.chunk_size = 1024
Settings.embed_model = embed_model

from llama_index.core.node_parser import SentenceSplitter
Settings.transformations = [SentenceSplitter(chunk_size=1024)]

## -----------------------------
## 4. TẠO INDEX VỚI CHROMADB
## -----------------------------
from llama_index.core import VectorStoreIndex
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore

# Tạo ChromaVectorStore, chỉ rõ nơi lưu (persist_directory) và tên collection
chroma_client = chromadb.EphemeralClient()
chroma_collection = chroma_client.create_collection("example_collection")

# Set up the ChromaVectorStore and StorageContext
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

# Tạo Index, sử dụng ChromaVectorStore
# (Ví dụ này, tất cả documents đều được index chung)
index = VectorStoreIndex.from_documents(
    documents,
    embed_model=embed_model,
    transformations=Settings.transformations,
    vector_store=vector_store
)

## -----------------------------
## 5. TẠO QUERY ENGINE CƠ BẢN
## -----------------------------
from IPython.display import Markdown, display
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine

query_str = "What causes Alstrom syndrome?"
query_engine = index.as_query_engine()
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

hyde = HyDEQueryTransform(include_original=True)
hyde_query_engine = TransformQueryEngine(query_engine, hyde)
response = hyde_query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

## -----------------------------
## 6. RAG (Retriever + LLM)
## -----------------------------
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever

vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=2)
response_synthesizer = get_response_synthesizer()

vector_query_engine = RetrieverQueryEngine(
    retriever=vector_retriever,
    response_synthesizer=response_synthesizer,
)

query_str = "What causes Alstrom syndrome?"
hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)
response = hyde_query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))

## -----------------------------
## 7. Semantic Splliter
## -----------------------------
# from llama_index.core.node_parser import (
#     SentenceSplitter,
#     SemanticSplitterNodeParser,
# )

# splitter = SemanticSplitterNodeParser(
#     buffer_size=1, breakpoint_percentile_threshold=95, embed_model=embed_model
# )

# # also baseline splitter
# base_splitter = SentenceSplitter(chunk_size=512)
# nodes = splitter.get_nodes_from_documents(documents)
# print(f"Number of nodes: {len(nodes)}")
# print(nodes[1].get_content())

# base_nodes = base_splitter.get_nodes_from_documents(documents)

# from llama_index.core import VectorStoreIndex
# from llama_index.core.response.notebook_utils import display_source_node

# vector_index = VectorStoreIndex(nodes)
# query_engine = vector_index.as_query_engine()

# base_vector_index = VectorStoreIndex(base_nodes)
# base_query_engine = base_vector_index.as_query_engine()

## -----------------------------
## 8. Semantic Router
## -----------------------------
from llama_index.core.query_engine.router_query_engine import RouterQueryEngine
from llama_index.core import Document

product_docs = [
    Document(content="Our product A has a battery life of 12 hours."),
    Document(content="Product B is made of aluminum, very lightweight."),
    # ...
]

chitchat_docs = [
    Document(content="Hello, how are you today?"),
    Document(content="What do you think about the weather?"),
    # ...
]

from llama_index.core import VectorStoreIndex
# Giả sử bạn đã có embed_model

product_index = VectorStoreIndex.from_documents(product_docs, embed_model=embed_model)
chitchat_index = VectorStoreIndex.from_documents(chitchat_docs, embed_model=embed_model)
product_query_engine = product_index.as_query_engine()
chitchat_query_engine = chitchat_index.as_query_engine()

product_tool = product_index.as_query_engine()
response = product_tool.query("What is the battery life of product A?")
print(response)
chitchat_tool = chitchat_index.as_query_engine()
response = chitchat_tool.query("How are you today?")
print(response)

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth: Fast Llama patching release 2024.5
   \\   /|    GPU: NVIDIA GeForce RTX 3070 Ti. Max memory: 7.779 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.26.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth 2024.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.



Using the latest cached version of the dataset since xDAN-datasets/medical_meadow_wikidoc_patient_information_6k couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/golde/.cache/huggingface/datasets/xDAN-datasets___medical_meadow_wikidoc_patient_information_6k/default/0.0.0/e5fb4f4032e8d812a3d14d6dd886f530eb42a766 (last modified on Fri Sep 20 13:44:26 2024).
Creating CSV from Arrow format: 100%|██████████| 6/6 [00:00<00:00, 37.15ba/s]


<b>
The cause of Alstrom syndrome is unknown. It is thought to be caused by a mutation in the ALMS1 gene. This gene provides instructions for making a protein that is involved in the development and function of the nervous system. Mutations in this gene can lead to the development of Alstrom syndrome.</b>

<b>
The cause of Alstrom syndrome is unknown. It is thought to be caused by a mutation in the ALMS1 gene. This gene provides instructions for making a protein that is involved in the development and function of the nervous system. Mutations in this gene can lead to the development of Alstrom syndrome.</b>

<b>
The cause of Alstrom syndrome is unknown. It is thought to be caused by a mutation in the ALMS1 gene. This gene provides instructions for making a protein that is involved in the development and function of the nervous system. Mutations in this gene can lead to the development of Alstrom syndrome.</b>

Some nodes are missing content, skipping them...
Some nodes are missing content, skipping them...
Empty Response
Empty Response
