import json
import torch
from langchain.docstore.document import Document
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Step 1: Load the RAG corpus from JSON.
# The file is read as UTF-8 explicitly. open()'s default encoding is
# platform-dependent (e.g. cp1252 on Windows), and the corpus is presumably
# Vietnamese text, as the sample query below is. The default could fail to
# decode those non-ASCII characters or silently mis-decode them.
with open("rag_format.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Step 2: Turn every JSON record into a LangChain Document.
# The record's "text" is the content that gets embedded and searched. The
# id/article/clause/title fields ride along as metadata so that retrieved
# chunks can be traced back to their place in the source.
def _to_document(record):
    """Build a Document from one JSON record (raises KeyError on a missing field)."""
    text = record["text"]
    metadata = {key: record[key] for key in ("id", "article", "clause", "title")}
    return Document(page_content=text, metadata=metadata)


documents = [_to_document(record) for record in data]

# Step 3: Embedding model that vectorises both the stored documents and the
# incoming queries.
# NOTE(review): no model_name is given, so the library's default
# sentence-transformers model is used. It is presumably English-oriented;
# the corpus and queries are Vietnamese, so a multilingual embedding model
# may retrieve noticeably better. Verify.
embeddings_model = HuggingFaceEmbeddings()

# Step 4: Embed the documents into a Chroma collection persisted on disk.
# Each document is stored under its own JSON "id" (Chroma requires string
# ids). Re-running this step then upserts the same records rather than
# appending another duplicate copy of the corpus to the persisted collection.
# Without this, the retriever would return the same chunk several times.
# Assumes the "id" values are unique; Chroma rejects duplicate ids within a
# single batch.
vectorstore = Chroma.from_documents(
    documents=documents,
    embedding=embeddings_model,
    ids=[str(doc.metadata["id"]) for doc in documents],
    persist_directory="./chroma.db",
)

# Step 5: Fetch the Qwen2.5-0.5B causal language model together with its
# matching tokenizer from the Hugging Face Hub (or the local cache).
# NOTE(review): this is the base checkpoint, not the "-Instruct" variant.
# A base model may follow the QA prompt poorly. Confirm this is intended.
model_id = "Qwen/Qwen2.5-0.5B"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    AutoModelForCausalLM.from_pretrained(model_id),
)

# Step 6: Choose where the generation pipeline runs.
# transformers' pipeline(device=...) takes a CUDA ordinal (0 = first GPU) or
# -1 for CPU. Only a single device is used: when several GPUs are present,
# the first one is selected. Spreading the model across GPUs would need
# Accelerate (device_map) instead.
gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
device = 0 if gpu_count > 0 else -1

# Step 7: Build a text-generation pipeline on the device chosen above.
# - max_new_tokens caps the length of each generated answer.
# - return_full_text=False makes the pipeline return only the newly
#   generated text. HuggingFacePipeline passes "generated_text" through as
#   the LLM output. With the default (True), the entire retrieval prompt,
#   including every retrieved context chunk, would be echoed back as part of
#   the printed answer.
# - clean_up_tokenization_spaces=False is passed explicitly to silence the
#   tokenization-spaces warning from transformers.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=100,
    device=device,
    clean_up_tokenization_spaces=False,
    return_full_text=False,
)

# Wrap the pipeline so LangChain can drive it as an LLM.
hf = HuggingFacePipeline(pipeline=pipe)

# Step 8: Assemble the retrieval-augmented QA chain.
# The retriever pulls the most similar chunks from the Chroma store, and the
# chain hands them to the wrapped Qwen LLM as context. Source documents are
# returned alongside the answer so that each reply can be checked against
# its evidence.
retriever = vectorstore.as_retriever()
qa_chain = RetrievalQA.from_chain_type(
    llm=hf,
    retriever=retriever,
    return_source_documents=True,
)

# Step 9: Smoke-test the RAG system with a sample question.
# ("What is a drug substance according to Vietnam's Pharmacy Law?")
query = "Dược chất là gì theo luật dược Việt Nam?"
# Chain.invoke replaces the deprecated Chain.__call__. It returns the output
# dict: "result" holds the answer, and "source_documents" holds the retrieved
# chunks (present because the chain is built with return_source_documents=True).
result = qa_chain.invoke({"query": query})

# Print the answer, followed by the documents it was grounded on.
print("Answer:", result["result"])
print("Source Documents:")
for doc in result["source_documents"]:
    print(f"Metadata: {doc.metadata}")
    print(f"Content: {doc.page_content}\n")