In [None]:
import os
import torch
import faiss
import logging
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from RAGLibrary import Widgets
from RAGLibrary import CheckConstruct, CreateSchema, FaissConvert, Embedding, Search, Rerank, Respond

In [None]:
# Build the configuration form. The returned sequence is indexed positionally
# (widgets_list[0] .. widgets_list[8]) by the configuration cell further down.
widgets_list = Widgets.create_name_form()

In [None]:
# Notebook-wide runtime settings.

# Environment flag for the Hugging Face Hub (disables its symlink warning).
os.environ.update({"HF_HUB_DISABLE_SYMLINKS_WARNING": "1"})

# Timestamped, INFO-level log records for the whole notebook.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Forwarded as `force_download=` to the `from_pretrained` calls in the
# model-loading cell.
force_download = True

In [None]:
""" DEFINE """

# Unpack the form widgets (positional order as returned by the form builder).
data   = widgets_list[0] #HBox 1
keys   = widgets_list[1] #HBox 2
choose = widgets_list[2] #HBox 3

embedd_model = widgets_list[3]
search_egine = widgets_list[4]
rerank_model = widgets_list[5]
respon_model = widgets_list[6]
API_drop     = widgets_list[7]
button_box   = widgets_list[8]

# HBox 1
file_name = data.children[0]
file_type = data.children[1]

# HBox 2
data_key = keys.children[0]
embe_key = keys.children[1]

# HBox 3
switch_model = choose.children[0]
merge_otp    = choose.children[1]
# NOTE(review): this reads the same child as `merge_otp`, so `path_end` always
# equals `merge`. Looks like a copy/paste slip (children[2]?) -- confirm against
# the form layout before changing it.
path_end_val = choose.children[1]

# Snapshot the current widget values. Re-run this cell after changing the form,
# otherwise the rest of the notebook keeps using the old values.
data_folder   = file_name.value
file_type_val = file_type.value

data_key_val  = data_key.value
embe_key_val  = embe_key.value

API_key_val = API_drop.value
switch      = switch_model.value
merge       = merge_otp.value
path_end    = path_end_val.value

embedding_model = embedd_model.value
searching_egine = search_egine.value
reranking_model = rerank_model.value
responing_model = respon_model.value


# Derived file locations: ../Data/<folder>/<type>_<folder>_*
base_path = f"../Data/{data_folder}/{file_type_val}_{data_folder}"

json_file_path = f"{base_path}_Database.json"
schema_ex_path = f"{base_path}_Schema.json"
# NOTE(review): the path embeds the raw `merge` widget value, while the embedding
# step receives `MERGE` (forced to "no_Merge" for non-"Data" file types below).
# Confirm the file name is meant to ignore that override.
embedding_path = f"{base_path}_Embeds_{merge}"

torch_path  = f"{embedding_path}.pt"
faiss_path  = f"{embedding_path}.faiss"
mapping_path = f"{embedding_path}_mapping.json"
mapping_data = f"{embedding_path}_map_data.json"

# Upper-case aliases used by the pipeline cells.
FILE_TYPE    = file_type_val
DATA_KEY     = data_key_val
EMBE_KEY     = embe_key_val
SWITCH       = switch
EMBEDD_MODEL = embedding_model
SEARCH_EGINE = searching_egine
RERANK_MODEL = reranking_model
RESPON_MODEL = responing_model

# The merge option is only honoured for the "Data" file type.
if FILE_TYPE == "Data":
    MERGE = merge
else:
    MERGE = "no_Merge"

API_KEY = API_key_val

SEARCH_ENGINE = faiss.IndexFlatIP

# Summary of the effective configuration.
print("\n")
print(f"Embedder: {EMBEDD_MODEL}")
print(f"Searcher: {SEARCH_EGINE}")
print(f"Reranker: {RERANK_MODEL}")
print(f"Responer: {RESPON_MODEL}")
print(f"Data Key: {DATA_KEY}")
print(f"Embe Key: {EMBE_KEY}")
print(f"Database: {json_file_path}")
print(f"Torch   : {torch_path}")
print(f"Faiss   : {faiss_path}")
print(f"Mapping : {mapping_path}")
print(f"Map Data: {mapping_data}")
print(f"Schema  : {schema_ex_path}")
print(f"Model   : {SWITCH}")
print(f"Merge   : {MERGE}")
# Never echo the raw key: cell output is stored inside the .ipynb file and is
# easily committed or shared. Only report whether a key is present.
api_key_status = "<set, hidden>" if API_KEY else "<not set>"
print(f"API Key : {api_key_status}")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the embedding model selected by SWITCH. Loading errors propagate to the
# caller unchanged (the previous `try/except ...: raise` wrappers did the same).
if SWITCH == "Auto Model":
    tokenizer = AutoTokenizer.from_pretrained(EMBEDD_MODEL, force_download=force_download)
    model = AutoModel.from_pretrained(EMBEDD_MODEL, force_download=force_download)
    model = model.to(device)
    print("Model and tokenizer loaded successfully")
elif SWITCH == "Sentence Transformer":
    # NOTE(review): this loads a local copy from "../../cached_model" instead of
    # EMBEDD_MODEL and does not move it to `device` explicitly (the earlier form
    # was `SentenceTransformer(EMBEDD_MODEL).to(device)`). Confirm this is intended.
    model = SentenceTransformer("../../cached_model")
    print("SentenceTransformer loaded successfully")
else:
    # Make the misconfiguration visible and avoid silently reusing a `model`
    # left over from an earlier run of this cell.
    model = None
    logging.warning(
        "Unrecognised model type %r: no embedding model was loaded "
        "(expected 'Auto Model' or 'Sentence Transformer').",
        SWITCH,
    )

print(f"Using: {device}")

In [None]:
# Generate the schema file from the JSON database unless it is already on disk.
if not os.path.exists(json_file_path):
    print(f"{json_file_path} does not exist")
elif os.path.exists(schema_ex_path):
    print(f"{schema_ex_path} alredy existed")
else:
    CreateSchema.create_schema(json_file_path, schema_ex_path)

In [None]:
# Run the embedding step only when the JSON database exists and no file is
# present at torch_path yet (an existing file is treated as a cache).
if not os.path.exists(json_file_path):
    print(f"{json_file_path} does not exist")
elif os.path.exists(torch_path):
    print(f"{torch_path} alredy existed")
else:
    Embedding.json_embeddings(
        MERGE,
        json_file_path,
        torch_path,
        schema_ex_path,
        model,
        device,
        DATA_KEY,
        EMBE_KEY,
        batches=False,
    )

In [None]:
# Inspect the file at torch_path, but only if it exists; nothing is printed otherwise.
if os.path.exists(torch_path):
    CheckConstruct.print_json(DATA_KEY, torch_path)

In [None]:
# Convert the file at torch_path into a FAISS index (plus mapping files) unless
# a file already exists at faiss_path.
if not os.path.exists(torch_path):
    print(f"{torch_path} does not exist")
elif os.path.exists(faiss_path):
    print(f"{faiss_path} alredy existed")
else:
    FaissConvert.convert_pt_to_faiss(
        torch_path,
        faiss_path,
        mapping_path,
        mapping_data,
        DATA_KEY,
        nlist=100,
        use_pickle=False,
    )

In [None]:
""" MAIN """

# System prompts for the responder.
# NOTE(review): both prompts are read from the same file, so `natr_prompt` is
# always identical to `docs_prompt`. This looks like a copy/paste slip -- point
# the second read at the intended "natural" prompt file once its name is confirmed.
with open("Prompts/Docs_Prompt.txt", "r", encoding="utf-8") as file1:
    docs_prompt = file1.read()

with open("Prompts/Docs_Prompt.txt", "r", encoding="utf-8") as file2:
    natr_prompt = file2.read()

# Inputs (after strip + lower-casing) that end the conversation; the empty
# string is included so blank input also exits.
EXIT_COMMANDS = {"exit", "quit", "escape", "bye", ""}

print("<< Enter 'exit', 'quit', 'escape', 'bye' or Press ESC to exit >>")
print("Chatbot: Hello there! I'm here to help you!")

# Canned query used while interactive input is disabled (see the loop below).
user_input = "Quy định về đào tạo đại học tại trường Thủ đô Hà Nội"

while True:
    try:
        # Interactive mode is currently switched off; re-enable the next line to
        # prompt the user on every turn.
        # user_input = input("You: ")

        # Check for an exit command first so exit words and empty input are
        # never sent through text preprocessing.
        if user_input.strip().lower() in EXIT_COMMANDS:
            print("Chatbot: Goodbye!")
            break

        user_question = Embedding.preprocess_text(user_input)
        print(f"Query: {user_question}")

        # Step 1: Search
        preliminary_results = Search.search_faiss_index(
            query=user_question,
            embedd_model=EMBEDD_MODEL,
            faiss_path=faiss_path,
            mapping_path=mapping_path,
            mapping_data=mapping_data,
            device=device,
            k=10,
            batches=False,
        )
        print(preliminary_results)

        # Step 2: Rerank
        reranked_results = Rerank.rerank_results(
            query=user_question,
            results=preliminary_results,
            reranker_model=RERANK_MODEL,
            device=device,
            k=5,
            batches=False,
        )
        print(reranked_results)

        # Use the document-grounded prompt only when reranking returned something.
        doc = bool(reranked_results)
        system_prompt = docs_prompt if doc else natr_prompt

        # Step 3: Generate Response
        response, filtered_results = Respond.respond_naturally(
            user_question=user_question,
            results=reranked_results,
            system_prompt=system_prompt,
            responser_model=RESPON_MODEL,
            score_threshold=0.85,
            max_results=3,
            doc=doc,
            gemini_api_key=API_KEY,
        )

        print(f"\nYou: {user_question}")
        print(f"Chatbot: {response}")

        # With interactive input disabled, answer the canned query once and
        # let the next iteration terminate the loop.
        user_input = "exit"

    except KeyboardInterrupt:
        print("\nChatbot: Goodbye!")
        break