<a href="https://colab.research.google.com/github/PunDin0827/EQ_RAG/blob/main/EQchat3_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import re
import json
import pickle
import chromadb
import unicodedata
import numpy as np
import pandas as pd
import regex as re_u
from datetime import datetime
from rank_bm25 import BM25Okapi
from collections import defaultdict
from transformers import BitsAndBytesConfig
from transformers import AutoModel , AutoTokenizer
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.memory import ChatSummaryMemoryBuffer
from llama_index.core.llms import ChatMessage , MessageRole
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.vector_stores import FilterCondition, FilterOperator
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext , Document
from llama_index.core.vector_stores import MetadataFilter, MetadataFilters, FilterCondition, FilterOperator

## LLM model ##

In [None]:
from llama_index.llms.llama_cpp import LlamaCPP
# Local GGUF model served through llama.cpp.
#
# Sampling options are passed through LlamaCPP's own arguments
# (`temperature`, `max_new_tokens`, `generate_kwargs`) so that they reach the
# completion call. `model_kwargs` is forwarded to the llama.cpp model
# constructor, so only load-time options (context size, GPU offload, batch
# size, thread count) are kept there.
llm = LlamaCPP(
    model_path="Qwen3-14B-Q4_K_M.gguf",
    temperature=0.6,
    max_new_tokens=512,
    generate_kwargs={
        "top_k": 0,
        "top_p": 0.9,
    },
    model_kwargs={
        "n_ctx": 4096,
        "n_gpu_layers": -1,
        "n_batch": 192,
        "n_threads": 32,
    },
)

## vector DB ##

In [None]:

# Vector DB: open the persistent Chroma client rooted at "chromadb", fetch (or
# create) the "RAG_EQ" collection and wrap it for llama_index.
chroma_client = chromadb.PersistentClient(path="chromadb")
chroma_collection = chroma_client.get_or_create_collection(name="RAG_EQ")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

## Data embedding ##

In [None]:
# Compiled regular expressions used by Alarm_metadata() while indexing:
# equipment name / equipment type, and one pattern per alarm-code family
# (MR, Err, AI combo, AI single, FANUC, generic alarm).
# The pattern bodies appear as `###` in this copy of the file.
PAT_EQ = re.compile(###)
PAT_EQ_type = re.compile(###)
PAT_MR_code = re.compile(###)
PAT_Err = re.compile(###)
PAT_AI_COMBO = re.compile(###)
PAT_AI_SINGLE = re.compile(###)
PAT_FANUC_code = re.compile(###)
PAT_alarm = re.compile(###)
def norm(s: str) -> str:
    """Return *s* in canonical code form.

    Every ``x``/``X`` and every en/em dash becomes ``-``, the full-width
    full stop becomes ``.``, and the result is upper-cased and stripped of
    surrounding whitespace.
    """
    table = str.maketrans({"x": "-", "X": "-", "．": ".", "–": "-", "—": "-"})
    return s.translate(table).upper().strip()

def Alarm_tittle(meta:dict):
    """Annotate *meta* in place with alarm summary fields and return it.

    Reads ``meta["AlarmCode"]`` (missing or falsy is treated as no codes) and
    sets:

    * ``has_alarm``    -- True when at least one code is present.
    * ``alarm_family`` -- sorted unique prefixes (the text before the first
      ``#``) of the codes that contain a ``#``.
    """
    alarm_codes = meta.get("AlarmCode", []) or []
    meta["has_alarm"] = bool(alarm_codes)
    # With no codes the set is empty, so this also yields [] for that case.
    meta["alarm_family"] = sorted(
        {code.split("#", 1)[0] for code in alarm_codes if "#" in code}
    )
    return meta


def Alarm_metadata(filepath):
    """Split a text file into chunks and build one Document per alarm code.

    The file is split on ``---``; empty chunks are dropped. For each chunk the
    equipment name (first PAT_EQ match, upper-cased) and equipment type (first
    PAT_EQ_type match) are extracted, then every alarm code found is normalised
    to ``FAMILY#VALUE`` (MR, ERR, AI, Alarm, FANUC).

    Args:
        filepath: Path of the UTF-8 text file to parse.

    Returns:
        A list of ``Document`` objects. A chunk with N distinct codes yields N
        documents sharing the same text, one per code in sorted order; a chunk
        with no code yields a single document with ``has_alarm`` 0 and empty
        ``alarm_code`` / ``code_family``.
    """
    with open(filepath, encoding="utf8") as f:
        raw = f.read()

    documents = []
    for chunk in (part.strip() for part in raw.split("---")):
        if not chunk:
            continue

        # ---------- equipment name / type ----------
        eq_match = PAT_EQ.search(chunk)
        eq_name = eq_match.group(1).upper() if eq_match else ""
        eq_type = next(
            (hit.group(1) for hit in PAT_EQ_type.finditer(chunk)), ""
        )

        # ---------- alarm codes, normalised to FAMILY#VALUE ----------
        codes = set()
        codes.update(
            f"MR#{hit.group(1).upper()}" for hit in PAT_MR_code.finditer(chunk)
        )
        codes.update(
            f"ERR#{norm(hit.group(1))}" for hit in PAT_Err.finditer(chunk)
        )
        for hit in PAT_AI_COMBO.finditer(chunk):
            # A combo match lists several numbers; keep the purely numeric parts.
            codes.update(
                f"AI#{part}"
                for part in re.split(r'[,\s、/＋\+]+', hit.group(1))
                if part.isdigit()
            )
        codes.update(
            f"AI#{hit.group(1)}" for hit in PAT_AI_SINGLE.finditer(chunk)
        )
        # Upper-cased so the stored value matches parse_filters(), which
        # upper-cases the Alarm code extracted from the user's question.
        codes.update(
            f"Alarm#{hit.group(1).upper()}" for hit in PAT_alarm.finditer(chunk)
        )
        codes.update(
            f"FANUC#{hit.group(1).upper()}-{hit.group(2)}"
            for hit in PAT_FANUC_code.finditer(chunk)
        )

        # ---------- one document per code (or one code-less document) ----------
        for code in sorted(codes) or [""]:
            meta = {
                "eq_name": eq_name,
                "eq_type": eq_type,
                "has_alarm": 1 if code else 0,
                "alarm_code": code,
                # Text before the first '#': AI / MR / ERR / Alarm / FANUC, or "".
                "code_family": code.split("#", 1)[0],
            }
            documents.append(Document(text=chunk, metadata=meta))

    return documents


In [None]:
# Parse "AlarmCode.txt"; the returned list is not assigned to a name here.
Alarm_metadata("AlarmCode.txt")

In [None]:
# Parse "alldata.txt" and keep the resulting Document list in `nodes`.
nodes = Alarm_metadata("alldata.txt")

## Embedding ##

In [None]:

# # 3. embedding model
# embed_model = HuggingFaceEmbedding(model_name = "QWEN3_4B_embedding",
#                            device = "cuda" ,
#                            embed_batch_size=32,
#                            model_kwargs={"torch_dtype": "float16"},)
# # 4. 建索引
# index = VectorStoreIndex(nodes, embed_model = embed_model, storage_context = storage_context)

# index.storage_context.persist()

# # 5. 查詢資料筆數
# ids = chroma_collection.get()["ids"]
# print(f"RAG_EQ collection 筆數: {len(ids)}")

## Inference ##

In [None]:
# Equipment alias table read by get_EQ(): each key is a canonical equipment
# name and each value is a list of alternative spellings for it (get_EQ
# concatenates `[key] + value`). Entries appear as `###` in this copy.
EQ_dict = {
   ###
}



def _eq_key(name):
    """Return *name* upper-cased with '-', '_' and spaces removed."""
    return name.upper().replace("-", "").replace("_", "").replace(" ", "")


def get_EQ(input):
    """Resolve an equipment name or alias to its canonical name.

    The lookup ignores case, hyphens, underscores and spaces on both the input
    and the names stored in ``EQ_dict``, so a canonical name or alias that
    contains ``-`` can still be matched.

    Args:
        input: Equipment name as typed by the user.

    Returns:
        The matching key of ``EQ_dict``, or *input* unchanged when nothing
        matches.
    """
    key = _eq_key(input)
    for main, aliases in EQ_dict.items():
        # The canonical name is itself an accepted spelling.
        if any(_eq_key(candidate) == key for candidate in [main, *aliases]):
            return main
    return input

In [None]:
# Print the name get_EQ() resolves for a few sample inputs.
print(get_EQ("###"))
print(get_EQ("###"))
print(get_EQ("###"))
print(get_EQ("###"))
print(get_EQ("###"))

In [None]:
# Regular expressions used by parse_filters() on the user's question:
# equipment name / type and one pattern per alarm-code family.
# The pattern bodies appear as `###` in this copy of the file.
PAT_EQ = re.compile(###)
PAT_EQ_type = re.compile(###)
PAT_MR = re.compile(###)
PAT_ERR_FAMILY = ###
PAT_ERR_CODE = re.compile(###)
PAT_AI = re.compile(###)
PAT_alarm = re.compile(###)
PAT_FANUC = re.compile(###)

def norm(s:str)->str:
    """Canonicalise a code string.

    Replaces ``x``/``X`` and en/em dashes with ``-`` and the full-width full
    stop with ``.``, then upper-cases and strips surrounding whitespace.
    """
    replacements = (("x", "-"), ("X", "-"), ("．", "."), ("–", "-"), ("—", "-"))
    for old, new in replacements:
        s = s.replace(old, new)
    return s.upper().strip()

def in_filter(key: str , values:list[str]) ->MetadataFilter|None:
    """Build an IN filter on *key* from the truthy entries of *values*.

    Returns None when no truthy value remains, so callers can skip the filter.
    """
    kept = list(filter(None, values))
    if kept:
        return MetadataFilter(key=key, value=kept, operator=FilterOperator.IN)
    return None

def parse_filters(question: str) -> MetadataFilters | None:
    """Derive metadata filters from the equipment names and codes in *question*.

    Three groups are extracted -- equipment names (canonicalised through
    get_EQ and upper-cased), equipment types, and alarm codes normalised to
    ``FAMILY#VALUE``. Each non-empty group becomes one IN filter; the groups
    are combined with AND.

    Returns:
        A ``MetadataFilters`` over the groups that matched, or None when the
        question contains none of them.
    """
    # Equipment names: several may be mentioned -> one IN group.
    resolved = (get_EQ(hit.group(1)) for hit in PAT_EQ.finditer(question))
    eq_names = [name.upper() for name in resolved if name]

    # Equipment types: several may be mentioned -> one IN group.
    eq_types = [hit.group(1) for hit in PAT_EQ_type.finditer(question)]

    # Alarm codes of every family, normalised to "FAMILY#VALUE".
    alarm_codes = [
        f"MR#{hit.group(1).upper()}" for hit in PAT_MR.finditer(question)
    ]
    alarm_codes += [
        f"ERR#{norm(hit.group(1))}" for hit in PAT_ERR_CODE.finditer(question)
    ]
    alarm_codes += [f"AI#{hit.group(1)}" for hit in PAT_AI.finditer(question)]
    alarm_codes += [
        f"FANUC#{hit.group(1).upper()}-{hit.group(2)}"
        for hit in PAT_FANUC.finditer(question)
    ]
    alarm_codes += [
        f"Alarm#{hit.group(1).upper()}" for hit in PAT_alarm.finditer(question)
    ]

    candidate_groups = (
        in_filter("eq_name", eq_names),
        in_filter("eq_type", eq_types),
        in_filter("alarm_code", alarm_codes),
    )
    # Only groups that actually matched take part in the AND.
    active = [group for group in candidate_groups if group]
    if not active:
        return None
    return MetadataFilters(filters=active, condition=FilterCondition.AND)

## BM25 ##

In [None]:
# ASCII token: a run of letters/digits, optionally continued by further runs
# joined with a single '-', '_', '.', ':', '/' or '#'.
RE_TOKEN = re.compile(r"[A-Za-z0-9]+(?:[-_.:/#][A-Za-z0-9]+)*" , re.I)
# Run of characters in U+4E00-U+9FFF or U+3400-U+4DBF (CJK ideograph ranges).
RE_CJK = re.compile(r"[\u4E00-\u9FFF\u3400-\u4DBF]+")
def nfkc(s:str)->str:
    """Return *s* in Unicode NFKC normal form."""
    normalized = unicodedata.normalize("NFKC", s)
    return normalized

def norm_units(s:str)->str:
    """Insert a space between a digit and a directly following run of letters or '%'."""
    return re.sub(
        r"(?<=\d)([A-Za-z%]+)",
        lambda hit: " " + hit.group(1),
        s,
    )

def cleanup(s:str)->str:
    """Remove Unicode "Other" (category C*) characters except newline and tab.

    This drops control, format, surrogate, private-use and unassigned code
    points (for example NUL or zero-width space) while keeping ``\\n`` and
    ``\\t``.

    Implemented with ``unicodedata`` rather than the set-difference pattern
    ``[\\p{C}--[\\n\\t]]``: set operations in the ``regex`` module are only
    available under its VERSION1 behaviour, which is not the default, so the
    pattern would not be read as "category C minus newline/tab".
    """
    kept_controls = "\n\t"
    return "".join(
        ch
        for ch in s
        if ch in kept_controls or unicodedata.category(ch)[0] != "C"
    )

def punctunify(s:str)->str:
    """Map full-width / CJK punctuation to its ASCII counterpart, one character each."""
    wide = "，。；：【】（）％／－～"
    narrow = ",.;:[]()%/-~"
    mapping = {ord(src): dst for src, dst in zip(wide, narrow)}
    return s.translate(mapping)

def code_family_variants(tok: str):
    """Expand an alarm-code token into lower-cased search variants.

    *tok* must be a family name (err / mr / ai / alarm, any case), an optional
    single separator from ``-_:# `` and a number that may carry ``.``, ``x`` or
    ``-`` separated parts. Anything else yields ``[]``.

    Returns:
        ``[family#number, family#core, family, number, core]`` where *core* is
        the part of the number before its first separator.
    """
    parsed = re.match(
        r"(err|mr|ai|alarm)[-_:# ]?([0-9]+(?:[.x-][0-9]+)*)$", tok, re.I
    )
    if parsed is None:
        return []
    family, number = (part.lower() for part in parsed.groups())
    core, *_ = re.split(r"[.x-]", number)
    return [f"{family}#{number}", f"{family}#{core}", family, number, core]


def tokenize(text:str , add_trigram: bool = False):
    """Tokenise mixed ASCII / CJK text for BM25.

    The text is first normalised (NFKC, cleanup, punctuation unification, unit
    spacing, whitespace collapsing). Then:

    * every RE_TOKEN match is emitted as-is, followed by its
      code_family_variants() expansion;
    * in the gaps between those matches, each RE_CJK run is emitted as
      overlapping character bigrams, plus trigrams when *add_trigram* is set.

    All ASCII tokens come first in the returned list, CJK n-grams after them.
    """
    s = norm_units(punctunify(cleanup(nfkc(text))))
    s = re.sub(r"\s+", " ", s).strip()

    def ngrams(segment, size):
        return [segment[i:i + size] for i in range(len(segment) - size + 1)]

    tokens = []
    ascii_spans = []
    for hit in RE_TOKEN.finditer(s):
        word = hit.group(0)
        tokens.append(word)
        tokens.extend(code_family_variants(word))
        ascii_spans.append(hit.span())

    # Walk the gaps between ASCII tokens; the sentinel span covers the tail.
    gap_start = 0
    for start, end in [*ascii_spans, (len(s), len(s))]:
        if gap_start < start:
            for cjk_hit in RE_CJK.finditer(s[gap_start:start]):
                segment = re.sub(r"\s+", "", cjk_hit.group(0))
                tokens.extend(ngrams(segment, 2))
                if add_trigram and len(segment) >= 3:
                    tokens.extend(ngrams(segment, 3))
        gap_start = end
    return tokens


# Build the BM25 index over `nodes` and persist it next to the notebook.
# `corpus_tokens` is initialised here but not filled in this cell.
corpus_tokens = []
# Node ids, in the same order as the token lists in `corpus`.
doc_ids = []
# One token list per node (input of BM25Okapi).
corpus   = []
# Node id -> raw text of that node.
text_map = {}

for n in nodes:
    # Use `node_id` when the object has it, otherwise `id_`, otherwise None.
    nid = getattr(n, "node_id", getattr(n, "id_", None))
    text = getattr(n, "text", "")
    doc_ids.append(nid)
    tok = tokenize(text)
    corpus.append(tok)
    text_map[nid] = text

bm25 = BM25Okapi(corpus)

# Persist the index together with the id order and the text lookup.
with open("bm25_index.pkl", "wb") as f:
    pickle.dump({"bm25": bm25, "doc_ids": doc_ids, "text_map": text_map}, f)


In [None]:
# Embedding model used to encode queries against the stored vectors.
embed_model = HuggingFaceEmbedding(
    model_name="QWEN3_4B_embedding",
    device="cuda",
    embed_batch_size=32,
    model_kwargs={"torch_dtype": "float16"},
)

# Re-open the "RAG_EQ" Chroma collection and load the index from it.
chroma_client = chromadb.PersistentClient(path="chromadb")
chroma_collection = chroma_client.get_or_create_collection(name="RAG_EQ")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    storage_context=storage_context,
    embed_model=embed_model,
)

In [None]:
txt_path = "llm_log.txt"


def log_txt(user_id, user_query, node_texts, node_scores, llm_response,llm_memory_prompt=None):
    """Append one query/answer record to the text log at ``txt_path``.

    Args:
        user_id: Identifier written in the record header.
        user_query: The question as asked.
        node_texts: Texts of the retrieved nodes.
        node_scores: Scores paired position-by-position with *node_texts*
            (extra items on either side are ignored).
        llm_response: The answer written under "LLM Response".
        llm_memory_prompt: Content written under "LLM Full Prompt".
    """
    times = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    nodes_block = "\n\n".join(
        f"Node {rank} [Score: {score}]:\n{text}"
        for rank, (text, score) in enumerate(zip(node_texts, node_scores), start=1)
    )
    log_str = f"""==== User ID: {user_id} ====
Timestamp: {times}
User Query: {user_query}
=========================
Node Texts:
{nodes_block}
=========================
LLM Full Prompt:
{llm_memory_prompt}
=========================
LLM Response: {llm_response}
----------------------------

"""
    with open(txt_path, mode="a", encoding="utf-8") as log_file:
        log_file.write(log_str)

In [None]:
try:
    from llama_index.core.schema import TextNode, NodeWithScore
except Exception:
    from llama_index.schema import TextNode, NodeWithScore  # 舊版相容

# Query-time configuration.
# `memory_prompt` is the instruction handed to the chat memory as its
# summarize prompt.
memory_prompt = (
    """你是一位機械維修工程師，請將以下使用者的對話整理成重點摘要，內容包含\n
    1.明確記錄使用者提到的設備型號與異常代碼\n
    2.明確記錄已經討論與檢查過的問題\n
    3.已經提供的建議與步驟\n
    4.後續追問需求\n
    5.請用條列式摘要\n
    """
)

# Conversation memory; it uses `llm` and `memory_prompt` for summarising.
# NOTE(review): confirm that `max_token_limit` and `full_message_token_limit`
# are accepted by the installed llama_index version.
memory = ChatSummaryMemoryBuffer(
    llm=llm,
    token_limit=2048,
    max_token_limit=4096,
    full_message_token_limit=1024,
    summarize_prompt=memory_prompt
)

# Answer-format instruction appended to every question before synthesis.
prompt = """你是一位機械維修工程師，嚴格禁止輸出：推理步驟、內部分析、重複內容、英文\n
        必須輸出兩段，且都要有內容：
        1.【檢索到的正確內容】：逐條摘錄證據的關鍵事實。
        2.【可能的原因與處理方向】：根據證據條列原因；每條給出檢查→處置步驟。
        """

# Load the persisted BM25 index: the scorer, the node-id order matching its
# corpus, and the node-id -> text lookup used to rebuild nodes on demand.
# NOTE(security): pickle.load can execute arbitrary code; only load a
# "bm25_index.pkl" that this notebook produced itself.
with open("bm25_index.pkl", "rb") as f:
    bm25_pack = pickle.load(f)
bm25 = bm25_pack["bm25"]
doc_ids = bm25_pack["doc_ids"]
# Restore the lookup from the pickle instead of relying on `nid` / `text`
# variables left behind by another cell.
text_map = bm25_pack["text_map"]
# NOTE(review): K is not referenced in the visible code -- confirm whether it
# is still needed.
K = 60
node_cache = {}  # nid -> Node, filled lazily by ensure_node()
user_id = 1

def get_node(obj):
    """Return ``obj.node`` when *obj* has that attribute, otherwise *obj* itself."""
    try:
        return obj.node
    except AttributeError:
        return obj

def get_nid(obj):
    """Return the id of the node behind *obj*.

    *obj* is unwrapped with get_node(); the id is taken from ``node_id`` when
    present, else from ``id_``, else None.
    """
    target = get_node(obj)
    if hasattr(target, "node_id"):
        return target.node_id
    return getattr(target, "id_", None)

def ensure_node(nid):
    """Return a node object for *nid*, or None when it cannot be resolved.

    Lookup order: ``node_cache``, then ``index.docstore``, then a fresh
    ``TextNode`` built from ``text_map``. A resolved node is stored in
    ``node_cache`` before being returned.
    """
    cached = node_cache.get(nid)
    if cached is not None:
        return cached

    try:
        resolved = index.docstore.get_node(nid)
    except Exception:
        # Best effort: any docstore failure falls through to text_map.
        resolved = None

    if resolved is None:
        if nid not in text_map:
            return None
        resolved = TextNode(text=text_map[nid], id_=nid)

    node_cache[nid] = resolved
    return resolved

# ===== Interactive retrieval / answer loop =====
# For every question three candidate sets are collected (metadata-filtered
# vector search, plain vector search, BM25), merged by node id, ranked by
# "best raw score + source bonus", and the top nodes are handed to the LLM.
RETRIEVE_TOP_K = 50          # candidates requested from each retriever
FINAL_TOP_K = 5              # nodes passed to the synthesizer
NO_ANSWER_BELOW = 0.50       # best score under this -> no answer
LOW_CONFIDENCE_BELOW = 0.70  # best score under this -> answer from 2 nodes only
NO_DATA_MESSAGE = "查無相關資料，請補充查詢條件，如:機台編碼、異常原因或更細節的錯誤描述"

# Bonus added to a node's raw score according to which retrievers returned it.
SOURCE_BONUS = {
    "hybrid_bm25": 1.6,
    "filter_bm25": 1.3,
    "embedding_bm25": 1.2,
    "hybrid": 1.0,
    "filter": 0.6,
    "bm25": 0.5,
    "embedding": 0.0,
}

# Built once: neither object depends on the individual question.
synthesizer = get_response_synthesizer(response_mode="compact", llm=llm)
retriever_emb = index.as_retriever(similarity_top_k=RETRIEVE_TOP_K)

while True:
    user_input = input("請輸入問題（輸入 exit 結束）：")
    if not user_input.strip():
        print("請輸入有效問題")
        continue
    if user_input.strip().lower() in ("exit", "quit"):
        break

    question = user_input
    meta_filters = parse_filters(question)

    # ===== Vector retrieval: metadata-filtered (if any filter) + unfiltered =====
    if meta_filters:
        nodes_filter = index.as_retriever(
            similarity_top_k=RETRIEVE_TOP_K, filters=meta_filters
        ).retrieve(question)
    else:
        nodes_filter = []
    nodes_emb = retriever_emb.retrieve(question)

    scores_by_id = defaultdict(float)  # nid -> best raw score seen so far
    base_nodes_by_id = {}              # nid -> node object (insertion ordered)
    node_source_type = {}              # nid -> key of SOURCE_BONUS

    for nws in nodes_filter:
        nid = get_nid(nws)
        if not nid:
            continue
        base_nodes_by_id[nid] = get_node(nws)
        node_source_type[nid] = "filter"
        raw = float(getattr(nws, "score", 0.0) or 0.0)
        scores_by_id[nid] = max(scores_by_id[nid], raw)

    for nws in nodes_emb:
        nid = get_nid(nws)
        if not nid:
            continue
        base_nodes_by_id[nid] = get_node(nws)
        seen_by_filter = node_source_type.get(nid) in ("filter", "hybrid")
        node_source_type[nid] = "hybrid" if seen_by_filter else "embedding"
        raw = float(getattr(nws, "score", 0.0) or 0.0)
        scores_by_id[nid] = max(scores_by_id[nid], raw)

    # ===== BM25 =====
    bm25_scores = bm25.get_scores(tokenize(question))
    # Documents with a score of 0 share no term with the question; they are
    # not hits and must not receive a BM25 source bonus.
    # NOTE(review): raw BM25 scores are unbounded while the vector scores are
    # typically much smaller, so max-merging lets BM25 dominate the ranking
    # and the thresholds above -- consider normalising or rank fusion.
    bm25_hits = [
        (doc_ids[i], float(bm25_scores[i]))
        for i in np.argsort(bm25_scores)[::-1][:RETRIEVE_TOP_K]
        if bm25_scores[i] > 0
    ]

    missing = []  # BM25 ids for which no node object could be obtained
    for nid, bm25_score in bm25_hits:
        node = ensure_node(nid)
        if node is None:
            missing.append(nid)
            continue
        scores_by_id[nid] = max(scores_by_id[nid], bm25_score)
        # Keep the retriever's node object when one already exists.
        base_nodes_by_id.setdefault(nid, node)
        previous = node_source_type.get(nid)
        if previous is None:
            node_source_type[nid] = "bm25"
        elif "bm25" not in previous:
            node_source_type[nid] = previous + "_bm25"

    if not base_nodes_by_id:
        print("查無任何相關資料\n" + NO_DATA_MESSAGE)
        continue

    # ===== Final ranking: best raw score + bonus for the source combination =====
    final_scores = {
        nid: scores_by_id[nid]
        + SOURCE_BONUS.get(node_source_type.get(nid, "embedding"), 0.0)
        for nid in base_nodes_by_id
    }
    ranked_ids = sorted(final_scores, key=final_scores.get, reverse=True)
    top_nodes = [
        NodeWithScore(node=base_nodes_by_id[nid], score=final_scores[nid])
        for nid in ranked_ids[:FINAL_TOP_K]
    ]

    # ===== Answer, gated on the best score =====
    top0_score = top_nodes[0].score
    if top0_score < NO_ANSWER_BELOW:
        response = NO_DATA_MESSAGE
        print("查無任何相關資料\n", response)
    elif top0_score < LOW_CONFIDENCE_BELOW:
        # str(): memory and the log below need text, not a Response object.
        response = str(
            synthesizer.synthesize(query=question + "\n" + prompt, nodes=top_nodes[:2])
        )
        print("目前查詢條件命中資料的相似度偏低，請補充查詢條件，如:機台編碼、異常原因或更細節的錯誤描述，以下為最相關兩筆資料供參考 :\n", response)
    else:
        response = str(
            synthesizer.synthesize(query=question + "\n" + prompt, nodes=top_nodes)
        )
        print(response)

    # ===== Memory and logging =====
    memory.put_messages([
        ChatMessage(role=MessageRole.USER, content=user_input),
        ChatMessage(role=MessageRole.ASSISTANT, content=response),
    ])

    # Debug output of the memory state.
    print(memory.get_all())
    print(memory.chat_store)
    print(memory)

    log_txt(
        user_id=user_id,
        user_query=question,
        node_texts=[getattr(get_node(n), "text", "") for n in top_nodes],
        node_scores=[n.score for n in top_nodes],
        llm_memory_prompt=memory.get(),
        llm_response=response,
    )

    # ===== Report BM25 hits that could not be resolved to a node =====
    if missing:
        print(f"[warn] 有 {len(missing)} 個 BM25 nid 不在 id_to_node（列前 5 個）: {missing[:5]}")
