In [None]:
import os
import torch
import faiss
import logging
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from Libraries import A0_MyUtils as A0, A1_TextProcess as A1, A2_PdfProcess as A2
from Libraries import B1_ExtractData as B1, B2_MergeData as B2, B3_GetStructures as B3
from Libraries import B4_ChunkMaster as B4, B5_ChunkFlex as B5, B6_ChunkFixed as B6
from Libraries import C1_CreateSchema as C1, C2_Embedding as C2, C3_CheckConstruct as C3
from Libraries import D0_FaissConvert as D0, D1_Search as D1, D2_Rerank as D2, D3_Respond as D3
from Config import Widgets, Configs

In [None]:
# Create the configuration form widgets; the returned list is read back by
# Configs.WidgetValues in the configuration cell below.
# NOTE(review): the values depend on interactive widget state, so a plain
# "Run All" uses whatever the form holds at that moment — confirm the defaults.
widgets_list = Widgets.create_name_form()

In [None]:
# Timestamped, INFO-level logging for the whole notebook.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Environment switches, set in one place:
#   HF_HUB_DISABLE_SYMLINKS_WARNING - flag read by the Hugging Face Hub client
#   CUDA_LAUNCH_BLOCKING / TORCH_USE_CUDA_DSA - CUDA debugging flags
os.environ.update({
    "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    "CUDA_LAUNCH_BLOCKING": "1",
    "TORCH_USE_CUDA_DSA": "1",
})

# Not read by any cell visible in this notebook; kept for code that may use it.
force_download = True

In [None]:
# Snapshot the widget values once; everything below is read from this mapping.
config = Configs.WidgetValues(widgets_list)

# --- Paths -------------------------------------------------------------------
data_foler = config["data_folder"]  # original (misspelled) name, kept for existing users
data_folder = data_foler            # correctly spelled alias; prefer this in new code
dcmt_path = config["dcmt_path"]
base_folder = config["base_folder"]
base_path = config["base_path"]
extracted_path = config["extracted_path"]
merged_path = config["merged_path"]
struct_path = config["struct_path"]
chunks_base = config["chunks_base"]
chunks_segment = config["chunks_segment"]
schema_ex_path = config["schema_ex_path"]
embedding_path = config["embedding_path"]
torch_path = config["torch_path"]
faiss_path = config["faiss_path"]
mapping_path = config["mapping_path"]
mapping_data = config["mapping_data"]

# --- Keys, models and switches -----------------------------------------------
FILE_TYPE = config["FILE_TYPE"]
DATA_KEY = config["DATA_KEY"]
EMBE_KEY = config["EMBE_KEY"]
SWITCH = config["SWITCH"]
EMBEDD_MODEL = config["EMBEDD_MODEL"]
# NOTE: "SEARCH_EGINE" is the key name used by the config mapping; it is a
# different variable from SEARCH_ENGINE defined at the end of this cell.
SEARCH_EGINE = config["SEARCH_EGINE"]
RERANK_MODEL = config["RERANK_MODEL"]
RESPON_MODEL = config["RESPON_MODEL"]
MERGE = config["MERGE"]
API_KEY = config["API_KEY"]

# --- Chunking parameters -----------------------------------------------------
WORD_LIMIT = config["WORD_LIMIT"]
LEVEL_INPUT = config["LEVEL_INPUT"]
LEVEL_VALUES = config["LEVEL_VALUES"]

# Last entry of LEVEL_VALUES, or None when the list is empty/missing.
Contents = LEVEL_VALUES[-1] if LEVEL_VALUES else None

# FAISS index class (flat inner-product index), independent of SEARCH_EGINE.
SEARCH_ENGINE = faiss.IndexFlatIP

## Prepare

### Device

In [None]:
# Report what PyTorch can see before choosing a device.
print("CUDA supported:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
if not torch.cuda.is_available():
    print("CUDA not available.")
else:
    # Details of GPU 0 plus the CUDA/cuDNN versions PyTorch reports.
    print("Current GPU name:", torch.cuda.get_device_name(0))
    print("CUDA device capability:", torch.cuda.get_device_capability(0))
    print("CUDA version (PyTorch):", torch.version.cuda)
    print("cuDNN version:", torch.backends.cudnn.version())

# Prefer the GPU when present, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Local directory checked first when loading the embedding model.
cached_path = "../Models"

def load_auto_model(path_or_name, device):
    """Load a tokenizer/model pair with the transformers Auto* classes.

    Args:
        path_or_name: Local directory or model identifier passed to
            ``from_pretrained``.
        device: Target device the model is moved to.

    Returns:
        ``(tokenizer, model)``. ``(None, None)`` when loading raises
        ``OSError``/``FileNotFoundError``. When a ``RuntimeError`` occurs the
        pair is loaded again with the model placed on the CPU.

    Raises:
        Exception: any other error is printed and re-raised.
    """
    def _load_on(target):
        # Tokenizer first, then the model moved to the requested target.
        loaded_tokenizer = AutoTokenizer.from_pretrained(path_or_name)
        loaded_model = AutoModel.from_pretrained(path_or_name).to(target)
        return loaded_tokenizer, loaded_model

    try:
        return _load_on(device)
    except (OSError, FileNotFoundError) as err:
        print("❌ Model files missing:", err)
        return None, None
    except RuntimeError as err:
        # Retry from scratch on the CPU.
        print("⚠️ GPU issue, fallback to CPU:", err)
        return _load_on("cpu")
    except Exception as err:
        print("❌ Unexpected error:", err)
        raise

def load_sentence_model(path_or_name, device):
    """Load a SentenceTransformer model.

    Args:
        path_or_name: Local directory or model identifier.
        device: Target device; passed to the constructor as ``str(device)``.

    Returns:
        The model, ``None`` when loading raises ``OSError``/``FileNotFoundError``,
        or a CPU-placed model when the first attempt raises ``RuntimeError``.

    Raises:
        Exception: any other error is printed and re-raised.
    """
    try:
        encoder = SentenceTransformer(path_or_name, device=str(device))
    except (OSError, FileNotFoundError) as err:
        print("❌ Model files missing:", err)
        encoder = None
    except RuntimeError as err:
        # Second attempt, this time explicitly on the CPU.
        print("⚠️ GPU issue, fallback to CPU:", err)
        encoder = SentenceTransformer(path_or_name, device="cpu")
    except Exception as err:
        print("❌ Unexpected error:", err)
        raise
    return encoder

# Select Model
# Reset both names first so a re-run never silently reuses a model left over
# from earlier kernel state, and so `tokenizer` is defined in every mode.
tokenizer, model = None, None

# Try the local cache first (when the directory exists), then the configured model.
_model_sources = [cached_path, EMBEDD_MODEL] if os.path.exists(cached_path) else [EMBEDD_MODEL]

if SWITCH in ("Auto Model", "Sentence Transformer"):
    for _model_source in _model_sources:
        print(f"ℹ️ {SWITCH}: {_model_source}")
        if SWITCH == "Auto Model":
            tokenizer, model = load_auto_model(_model_source, device)
        else:
            model = load_sentence_model(_model_source, device)
        if model is not None:
            break
    if model is None:
        # Both loaders return None when model files are missing.
        print(f"⚠️ No embedding model could be loaded from: {_model_sources}")
else:
    # Previously this case left `model` undefined and failed later with a NameError.
    print(f"⚠️ Unknown SWITCH value {SWITCH!r}; no embedding model loaded")

# NOTE(review): this is the requested device; the loaders may have fallen back
# to the CPU after a RuntimeError — confirm if the distinction matters downstream.
print(f"✅ Using: {device}")

### Assets

In [None]:
# Asset files consumed by the extractor; all live under one folder.
assets = "../Assets/"
exceptions_path = assets + "ex.exceptions.json"
markers_path = assets + "ex.markers.json"
status_path = assets + "ex.status.json"

### Extract Data

In [None]:
# Extractor configured with the three asset files defined in the Assets cell.
# NOTE(review): the meaning of proper_name_min_count=10 is defined in
# B1_ExtractData — confirm before tuning.
dataExtractor = B1.B1Extractor(
    exceptions_path,
    markers_path,
    status_path,
    proper_name_min_count=10,
)

In [None]:
# Run the extractor on the document at dcmt_path and persist the result as JSON.
extracted_data = dataExtractor.extract(dcmt_path)
A0.write_json(extracted_data, extracted_path, indent=1)

# Merge the extracted data (lines -> paragraphs, per the helper's name) and
# persist it; merged_path is read again by the structure analyzer below.
merged_data = B2.mergeLinesToParagraphs(extracted_data)
A0.write_json(merged_data, merged_path, indent=1)

### Get Struct

In [None]:
# Structure analyzer over the merged JSON file written by the extraction step.
analyzer = B3.StructureAnalyzer(merged_path, verbose=True)

In [None]:
# Step 1: collect the markers found by the analyzer.
markers = analyzer.extract_markers()

# Step 2: build structures from the markers and print them for inspection.
structures = analyzer.build_structures(markers)
print(A0.jsonl_convert(structures))

# Step 3: de-duplicate the structures and print the result.
dedup = analyzer.deduplicate(structures)
print(A0.jsonl_convert(dedup))

# Step 4: select the top structure(s), extend them using the de-duplicated
# set, and print the final result.
top = analyzer.select_top(dedup)
topext = analyzer.extend_top(top, dedup)
print(A0.json_convert(topext, pretty=True))

# Persist the final structure; struct_path is consumed by the chunk builder.
A0.write_json(topext, struct_path, indent=2)

### Chunks

In [None]:
# Chunk builder fed by the structure file and the merged paragraphs file.
# NOTE(review): output_path is a hard-coded file name, while the chunks are
# also written to the configured chunks_base in the next cell — confirm
# whether this should come from the config instead.
builder = B4.ChunkBuilder(
    struct_path=struct_path,
    merged_path=merged_path,
    output_path="Data_HNMU_Regulations_Chunks.json"
)

In [None]:
# Build the chunks and persist them to the configured base chunk file.
chunks = builder.build()
A0.write_json(chunks, chunks_base, indent=2)

In [None]:
# Write the same `chunks` object to chunks_segment, the file read by the
# schema-extraction and embedding cells below.
A0.write_json(chunks, chunks_segment, indent=2)

### Schema Extract

In [None]:
# Schema extractor for the chunk JSON.
# NOTE(review): list_policy="first" is interpreted by C1_CreateSchema — confirm its meaning there.
schemaEx = C1.JSONSchemaExtractor(list_policy="first", verbose=True)

In [None]:
if not os.path.exists(chunks_segment):
    print(f"{chunks_segment} does not exist")
else:
    # Generate the schema file only when it is not there yet, then load it.
    if not os.path.exists(schema_ex_path):
        schemaEx.schemaRun(chunks_segment, schema_path=schema_ex_path)
    chunksSchema = A0.read_json(schema_ex_path)
    print(chunksSchema)

### Embedding

In [None]:
# Embedding runner around the model loaded in the Device section.
# The device now follows the `device` selected earlier instead of a hard-coded
# "cuda:0", so the cell also works on CPU-only machines; on a GPU machine the
# value is still exactly "cuda:0".
Embedding = C2.JSONEmbedding(
    model=model,
    device="cuda:0" if device.type == "cuda" else "cpu",
    batch_size=32,
    show_progress=False,
    flatten_mode="split"
)

In [None]:
if not os.path.exists(chunks_segment):
    print(f"{chunks_segment} does not exist")
else:
    # Embeddings are computed only when torch_path does not exist yet.
    if not os.path.exists(torch_path):
        Embedding.embeddingRun(
            json_path=chunks_segment,
            schema_path=schema_ex_path,
            torch_path=torch_path,
            data_key=DATA_KEY,
            embe_key=EMBE_KEY,
        )

    # Print a check of the stored file, whether new or pre-existing.
    C3.print_json(DATA_KEY, EMBE_KEY, torch_path)

### Convert to Faiss

In [None]:
# Convert the stored torch embeddings to a FAISS index, once.
if not os.path.exists(torch_path):
    print(f"{torch_path} does not exist")
elif os.path.exists(faiss_path):
    print(f"{faiss_path} alredy existed")
else:
    D0.convert_pt_to_faiss(
        torch_path=torch_path,
        faiss_path=faiss_path,
        mapping_path=mapping_path,
        mapping_data=mapping_data,
        data_key=DATA_KEY,
        nlist=100,
        use_pickle=False,
    )

## RAG - Main

### Funcs

In [None]:
def searchRun(user_question):
    """Run the FAISS search for one question.

    Forwards the notebook-level configuration (MERGE, EMBEDD_MODEL, index and
    mapping paths, device) to ``D1.search_faiss_index`` with ``k=2``,
    ``min_score=5`` and ``batches=False``.

    Args:
        user_question: The query text.

    Returns:
        Whatever ``D1.search_faiss_index`` returns.
    """
    return D1.search_faiss_index(
        MERGE=MERGE,
        query=user_question,
        embedd_model=EMBEDD_MODEL,
        faiss_path=faiss_path,
        mapping_path=mapping_path,
        mapping_data=mapping_data,
        device=device,
        k=2,
        min_score=5,
        batches=False,
    )

def rerankRun(user_question, preliminary_results):
    """Rerank search hits and build the context string from them.

    Args:
        user_question: The query text.
        preliminary_results: Hits from the search step, forwarded as ``results``.

    Returns:
        ``(reranked_results, context)`` where ``context`` is the ``'text'``
        field of every reranked item joined by blank lines. When the reranker
        returns ``None`` or an empty result, ``context`` is ``""`` and the
        falsy result is passed through unchanged so callers can test it.
    """
    reranked_results = D2.rerank_results(
        query=user_question,
        results=preliminary_results,
        reranker_model=RERANK_MODEL,
        device=device,
        k=5,
        batches=False,
    )
    # `or []` keeps a None result from raising TypeError in the join.
    context = '\n\n'.join(item['text'] for item in (reranked_results or []))
    return reranked_results, context

def respondRun(user_question, prompt, context="", doc=False):
    """Ask the response model for an answer.

    Args:
        user_question: The query text.
        prompt: Prompt template text.
        context: Retrieved context forwarded to the model (default ``""``).
        doc: Forwarded unchanged as ``doc``.

    Returns:
        Whatever ``D3.respond_naturally`` returns.
    """
    request = dict(
        user_question=user_question,
        context=context,
        prompt=prompt,
        responser_model=RESPON_MODEL,
        score_threshold=0.85,
        max_results=3,
        doc=doc,
        gemini_api_key=API_KEY,
    )
    return D3.respond_naturally(**request)

### Main

In [None]:
# Selects the prompt template file: Prompts/<prompt_type>_Prompt.txt
prompt_type = "Docs"

# Scripted questions replayed by chatRun in place of interactive input.
queries = [
    "Quy chế này quy định những gì và áp dụng cho đối tượng nào",
    "Sinh viên có thể được thi lại bao nhiêu lần?",
]

def chatRun():
    """Replay the scripted `queries` through search -> rerank -> respond.

    Reads the prompt template for `prompt_type`, then handles each query in
    turn and prints every intermediate result. The loop ends when the queries
    are exhausted, when an exit word is read, or on KeyboardInterrupt.

    Each query is handled independently: a stage runs only when the previous
    stage produced a result, so a query with no hits can neither raise
    UnboundLocalError nor print the previous query's answer.
    """
    prompt_path = f"Prompts/{prompt_type}_Prompt.txt"
    with open(prompt_path, "r", encoding="utf-8") as f:
        prompt = f.read()

    print("<< Enter 'exit', 'quit', 'escape', 'bye' or Press ESC to exit >>")
    print("Chatbot: Hello there! I'm here to help you!\n\n")

    exit_words = ("exit", "quit", "escape", "bye", "")

    query = -1
    while True:
        try:
            query += 1
            # Once the scripted queries run out, simulate the user typing "exit".
            user_input = queries[query] if query < len(queries) else "exit"

            # user_input = input("You: ")

            # Check for exit before doing any text processing.
            if user_input.strip().lower() in exit_words:
                print("Chatbot: Goodbye!")
                break

            user_question = A0.preprocess_text(user_input)
            print(f"Query: {user_question}")

            # Step 1: Search
            preliminary_results = searchRun(user_question)

            if not preliminary_results:
                print("Không tìm thấy thông tin!")
            else:
                # Step 2: Rerank
                print(preliminary_results)
                reranked_results, context = rerankRun(user_question, preliminary_results)

                if not reranked_results:
                    print("Rerank thất bại!")
                else:
                    # Step 3: Generate Response
                    print(reranked_results)
                    print(f"\n Context:\n {context}")
                    # NOTE(review): the reranked `context` is printed but an empty
                    # context is sent to the model (with doc=True). Kept as in the
                    # original; confirm whether `context=context` was intended.
                    response = respondRun(user_question, prompt, context="", doc=True)

                    # Step 4: Print Response
                    if response:
                        print(f"\nYou: {user_question}")
                        print(f"Chatbot: {response}\n\n")
                    else:
                        print("LLM không phản hồi!")

            print("=" * 200)
            print("\n\n")

        except KeyboardInterrupt:
            print("\nChatbot: Goodbye!")
            break

In [None]:
# chatRun()