In [1]:
from docs_preprocess import DocumentProcessor
from VecStore import VectorStore
from langchain.llms import HuggingFacePipeline
from pathlib import Path


In [2]:
from typing import List, Optional, Any
from dataclasses import dataclass, field

import torch
from langchain_core.language_models.llms import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM

class DeepSeekLLM(LLM):
    """LangChain ``LLM`` wrapper around a local Hugging Face causal-LM chat model.

    The incoming prompt string is wrapped as a single user turn, rendered with the
    tokenizer's chat template and passed to ``model.generate``. Only the newly
    generated tokens are decoded and returned.

    Fields:
        model: a loaded ``transformers`` causal LM (uses ``.device``, ``.config``
            and ``.generate``).
        tokenizer: the matching tokenizer (uses ``apply_chat_template``/``decode``).
        max_new_tokens, temperature, do_sample, top_p, top_k: default generation
            settings; each can be overridden per call through ``**kwargs``.
    """

    model: Any
    tokenizer: Any

    # ==== default generation hyper-parameters ====
    max_new_tokens: int = 512
    temperature: float = 0.7
    do_sample: bool = True
    top_p: float = 0.9
    top_k: int = 50

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain reports for this LLM implementation.
        return "deepseek_hf"

    def _ensure_pad_token(self) -> None:
        """Fall back to EOS as pad token when the tokenizer / model config define none."""
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if getattr(self.model.config, "pad_token_id", None) is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def __post_init__(self):
        # NOTE: LangChain LLMs are pydantic models, not dataclasses, so this hook is
        # never invoked automatically. Kept for backward compatibility; the pad-token
        # fix-up is applied lazily from _call via _ensure_pad_token().
        self._ensure_pad_token()

    # -------------------------------------------------------------
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: full prompt text; sent to the model as one user message.
            stop: optional stop strings; the output is cut at the earliest one found.
            **kwargs: per-call overrides for ``max_new_tokens``, ``do_sample``,
                ``temperature``, ``top_p`` and ``top_k``; other keys are ignored.

        Returns:
            The decoded continuation, special tokens removed, whitespace stripped.
        """
        self._ensure_pad_token()

        # 1) prompt -> chat-formatted input_ids on the model's device
        messages = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)

        # One unpadded sequence: every position is a real token, so attend to all.
        # (Masking by `!= pad_token_id` would hide genuine tokens whenever the pad id
        # is shared with EOS/BOS.)
        attention_mask = torch.ones_like(input_ids)

        # 2) generation settings (per-call kwargs override instance fields)
        do_sample = kwargs.get("do_sample", self.do_sample)
        gen_kwargs = dict(
            max_new_tokens=kwargs.get("max_new_tokens", self.max_new_tokens),
            do_sample=do_sample,
            pad_token_id=self.tokenizer.pad_token_id,
            attention_mask=attention_mask,
        )
        if do_sample:
            # Sampling knobs only apply when sampling; with greedy decoding they are
            # unused and only produce transformers warnings.
            gen_kwargs.update(
                temperature=kwargs.get("temperature", self.temperature),
                top_p=kwargs.get("top_p", self.top_p),
                top_k=kwargs.get("top_k", self.top_k),
            )

        with torch.no_grad():
            outputs = self.model.generate(input_ids, **gen_kwargs)

        # Decode only the newly generated tokens (skip the echoed prompt).
        text = self.tokenizer.decode(
            outputs[0][input_ids.shape[1]:],
            skip_special_tokens=True,
        )

        # 3) manual stop-word truncation; successive cuts leave the text ending
        #    before the earliest stop string. Empty strings are skipped because
        #    str.split("") raises ValueError.
        if stop:
            for s in stop:
                if s and s in text:
                    text = text.split(s, 1)[0]

        return text.strip()


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os

# Local DeepSeek chat checkpoint. Set DEEPSEEK_MODEL_DIR to point elsewhere
# instead of editing this machine-specific default path.
model_dir = os.environ.get(
    "DEEPSEEK_MODEL_DIR", "/home/lyus4/yuheng/All_in_LLM/deepseek-llm-7b-chat"
)

# trust_remote_code=True executes Python code shipped with the checkpoint;
# only use it with trusted model directories.
tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype="auto",   # keep the dtype stored in the checkpoint
    device_map="auto",    # let accelerate place the weights on available devices
)

# temperature=0.2 (class default is 0.7) makes sampling noticeably less random.
llm = DeepSeekLLM(
    model=model,
    tokenizer=tok,
    max_new_tokens=512,
    temperature=0.2,
    do_sample=True,
    top_p=0.95,
    top_k=40,
)

[2025-06-14 18:07:18,421] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/lyus4/anaconda3/envs/rag_env/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/lyus4/anaconda3/envs/rag_env/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.44s/it]
We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.


In [4]:
# ── Path configuration ──
MODEL_PATH = "/home/lyus4/yuheng/All_in_LLM/all-MiniLM-L6-v2"   # embedding model directory
INDEX_PATH = Path("..") / "vector_store" / "faiss_ivfflat_100"  # FAISS index; no ".index" suffix needed
VECTORS_PATH = Path("..") / "index" / "vectors.npy"
DOC_JSON_PATH = Path("output_chunks.json")                      # document chunks + metadata

In [5]:
# Open the FAISS-backed store, attach the document chunks, and print its status.
# NOTE(review): describe() reports 0 for "向量数" while the index holds 18343 vectors;
# VECTORS_PATH from the config cell is not loaded here — confirm this is intended.
vs = VectorStore(db_path=INDEX_PATH, model_path=MODEL_PATH)
vs.load_documents_and_metadata(json_path=DOC_JSON_PATH)
vs.describe()


[INFO] 初始化 VectorStore -> ../vector_store/faiss_ivfflat_100
[INFO] 索引加载成功: ../vector_store/faiss_ivfflat_100
[INFO] 类型: IndexIVFFlat, 维度: 384, 数量: 18343
[INFO] 加载文档 18343 条

[INFO] VectorStore 状态描述：
- 文档数: 18343
- 向量数: 0
- 索引类型: IndexIVFFlat
- 向量维度: 384
- 向量总数: 18343


In [6]:
# ── 依赖 ─────────────────────────────────────────────────────────────
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    Runnable, RunnableLambda, RunnableWithMessageHistory
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.chat_history import InMemoryChatMessageHistory
import textwrap
from utils import format_context_grouped, expand_acronyms

# Minimal summarisation prompt: condense the conversation so far into 1-2 sentences.
summary_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a note-taking assistant.\n"
            "Rewrite the following conversation so far into 1-2 concise sentences, "
            "focusing only on facts or user preferences we should remember.\n\n"
            "{conversation}",
        )
    ]
)
summary_chain = summary_prompt | llm | StrOutputParser()

# Per-session state, keyed by session_id.
history_store: dict = {}   # session_id -> InMemoryChatMessageHistory
summary_store: dict = {}   # session_id -> long-term summary string

# Memory-management knobs (read by trim_messages / update_summary).
MAX_HISTORY_EVENTS = 4          # keep at most this many raw messages in the window
MAX_TURNS_PER_SESSION = 4       # plan A: reset the session after this many Q/A turns
MAX_SUMMARY_CHARS = 100         # plan B: summaries longer than this (chars) get re-compressed
RESET_PRESERVE_SUMMARY = False  # on reset: True keeps the summary, False clears it




In [7]:
def history_factory(session_id: str) -> InMemoryChatMessageHistory:
    """Return the chat history for ``session_id``, creating an empty one on first use.

    Histories are kept in the module-level ``history_store`` dict, so repeated calls
    return the same object until something replaces that dict entry.
    """
    try:
        return history_store[session_id]
    except KeyError:
        fresh = InMemoryChatMessageHistory()
        history_store[session_id] = fresh
        return fresh

# ---------- A. Refresh the long-term summary every turn ----------
def update_summary(session_id: str, new_messages):
    """Fold the latest messages into the session's long-term summary.

    The previous summary plus the text of ``new_messages`` is condensed by
    ``summary_chain``. If the result exceeds ``MAX_SUMMARY_CHARS`` it is passed
    through the chain once more. This is best-effort: on an LLM failure the old
    summary is kept, and the error is printed rather than silently dropped.

    Args:
        session_id: key into ``summary_store``.
        new_messages: iterable of message objects exposing ``.content``
            (e.g. the latest human/AI pair).
    """
    prev = summary_store.get(session_id, "")
    # str() guards against non-string message content (e.g. structured content lists).
    delta = "\n".join(str(m.content) for m in new_messages).strip()
    conversation = f"{prev}\n{delta}".strip() if prev else delta
    if not conversation:
        return  # nothing to summarise; skip a pointless LLM call

    try:
        summary = summary_chain.invoke({"conversation": conversation})
    except Exception as exc:  # best-effort: keep the previous summary
        print(f"[WARN] summary update failed for session {session_id}: {exc!r}")
        return

    # ---- Too long: compress once more. The same prompt is reused, so the second
    # pass is not guaranteed to get under the limit. ----
    if len(summary) > MAX_SUMMARY_CHARS:
        try:
            summary = summary_chain.invoke({"conversation": summary})
        except Exception as exc:  # fall back to the over-long version
            print(f"[WARN] summary re-compression failed for session {session_id}: {exc!r}")

    summary_store[session_id] = summary


# ---------- B. 仅裁剪短期窗口 ------------------------------------  # ← CHANGED
def _msg_role(msg) -> str:
    """return 'human' / 'ai' / 'system'... 兼容 .role 和 .type"""
    return getattr(msg, "role", getattr(msg, "type", "")).lower()

turn_counter: dict[str, int] = {}  # session_id -> completed Q/A turns since the last reset


def trim_messages(session_id: str):
    """Post-turn housekeeping for one session's short-term memory.

    1. Count a completed turn when there are at least two messages and the newest
       one is an AI reply.
    2. Once ``MAX_TURNS_PER_SESSION`` turns are reached, replace the session's
       history with a fresh one and zero its counter. Unless ``RESET_PRESERVE_SUMMARY``
       is set, the summary is also cleared. Then stop.
    3. Otherwise keep only the last ``MAX_HISTORY_EVENTS`` messages, moving the cut
       forward by one when it would start on a non-human message.
    """
    history = history_factory(session_id)
    window = history.messages

    # --- 1. turn counting (before any trimming) ---
    completed = turn_counter.get(session_id, 0)
    last_is_ai = len(window) >= 2 and _msg_role(window[-1]) in ("ai", "assistant")
    if last_is_ai:
        completed += 1
        turn_counter[session_id] = completed

    # --- 2. hard reset after too many turns ---
    if completed >= MAX_TURNS_PER_SESSION:
        print(f"[INFO] session {session_id} reaches {completed} turns, resetting…")
        history_store[session_id] = InMemoryChatMessageHistory()
        turn_counter[session_id] = 0
        if not RESET_PRESERVE_SUMMARY:
            summary_store[session_id] = ""
        return  # old window discarded; nothing left to trim

    # --- 3. sliding-window trim (try to start on a human message) ---
    overflow = len(window) - MAX_HISTORY_EVENTS
    if overflow > 0:
        cut = overflow if _msg_role(window[overflow]) in ("human", "user") else overflow + 1
        window[:] = window[cut:]


In [8]:
template_test = """
<Role>
You are a 5G wireless communication expert.

<Goal>
Answer the question using the information in the context below.
If the context is insufficient, reply exactly: **"I don't know"**.

<Memory>
{summary}

<Context>
{context}

<Question>
{question}

<Instructions>
1. Explain simply and clearly, as if to a non-expert.  
2. Give the reference.

<Answer>
"""
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", template_test),
    MessagesPlaceholder("history"),
    ("user", "<Question>\n{question}")
])


In [9]:
def build_inputs(inputs: dict):
    """Enrich one chain input with retrieved context and the session summary.

    Expects ``inputs`` to contain ``"input"`` (the current question) and
    ``"session_id"`` (passed explicitly by ``ask``). ``"history"`` (a list of
    messages) is optional.

    Returns:
        A new dict with every original key plus ``context``, ``question`` and
        ``summary``, which ``chat_prompt`` consumes.
    """
    question = inputs["input"]
    session_id = inputs["session_id"]  # supplied alongside the question by ask()
    history_msgs = inputs.get("history", [])

    # A. Retrieval query: earlier human turns + current question, so follow-ups such
    #    as "And why is it important?" are retrieved together with their subject.
    #    Role detection reuses _msg_role so it matches trim_messages.
    hist_text = "\n".join(
        m.content for m in history_msgs
        if _msg_role(m) in ("human", "user")
    )
    combined = f"{hist_text}\n{question}" if hist_text else question
    expanded_query = expand_acronyms(combined)

    # B. Retrieve the top-3 chunks and format them as prompt context.
    top_k_results = vs.search(expanded_query, k=3, score_mode="reciprocal")
    ctx = format_context_grouped(top_k_results, with_metadata=True, with_score=True)

    # C. Inject the long-term summary ("" when none exists yet).
    summary = summary_store.get(session_id, "")
    return {
        **inputs,  # keep every incoming field (e.g. "history") for downstream steps
        "context": ctx,
        "question": question,
        "summary": summary,
    }


In [10]:
# Retrieval step: turns the raw input dict into prompt variables.
context_retriever = RunnableLambda(build_inputs)

# inputs dict -> prompt -> LLM -> plain string
base_chain: Runnable = context_retriever | chat_prompt | llm | StrOutputParser()

# Per-session memory wrapper: the current question is read from "input" and
# stored past messages are injected under "history".
chatbot = RunnableWithMessageHistory(
    base_chain,
    history_factory,
    input_messages_key="input",
    history_messages_key="history",
)


# ── 6. Single-turn entry point: invoke → update summary → trim ─────────
def ask(session_id: str, user_question: str) -> str:
    """Run one Q/A turn for ``session_id`` and return the model's answer.

    After the chatbot call, the last two history messages are folded into the
    long-term summary. The short-term window and turn limit are then maintained.
    """
    payload = {"input": user_question, "session_id": session_id}
    run_config = {"configurable": {"session_id": session_id}}
    answer = chatbot.invoke(payload, config=run_config)

    latest_pair = history_factory(session_id).messages[-2:]
    update_summary(session_id, latest_pair)  # refresh the summary every turn
    trim_messages(session_id)                # window size & turn-count control
    return answer


# ── 7. Printing helpers ───────────────────────────────────────────
def wrap_text(text, width=120):
    """Wrap ``text`` to ``width`` columns while keeping its own line breaks.

    ``textwrap.fill`` on the whole string replaces newlines with spaces, which merges
    the paragraphs and list items of an answer into one block. Instead, each line is
    wrapped separately and blank lines are kept. ``None`` yields "".
    """
    if text is None:
        return ""
    return "\n".join(
        textwrap.fill(line, width=width) if line.strip() else ""
        for line in str(text).splitlines()
    )

def print_chat_state(session_id: str):
    """Print the session's current message window as Q/A pairs, then its long-term summary.

    Pairs are numbered in reverse, so the most recent pair is number 1.
    """
    msgs = history_factory(session_id).messages
    print("\n==== 最近窗口中的消息 ====")

    for idx in range(0, len(msgs), 2):
        q = wrap_text(msgs[idx].content)
        a = wrap_text(msgs[idx + 1].content) if idx + 1 < len(msgs) else "(pending)"
        # Reverse numbering, rounded up so the last (possibly pending) pair is 1
        # even when the window holds an odd number of messages.
        turn_no = (len(msgs) - idx + 1) // 2
        print(f"\n◉ Q-{turn_no}: {q}\n◎ A-{turn_no}: {a}")

    print("\n---- 长期摘要 ----")
    # After a reset summary_store holds "", so fall back on any falsy value.
    print(wrap_text(summary_store.get(session_id) or "(empty)"))



In [None]:
# Multi-turn demo: ask follow-up questions in one session and show the memory
# state (window + summary) after every turn. Re-running this cell continues the
# same session, because history_store / summary_store persist in the kernel.
sid = "user_42"
demo_questions = (
    "What is beam management?",
    "And why is it important?",
    "Which 3GPP spec defines it?",
    "what is ofdma",
    "what is RB",
)

for q in demo_questions:
    print(f"\n🧑‍💬 {q}")
    reply = ask(sid, q)
    print(f"🤖 {wrap_text(reply)}")

    print_chat_state(sid)
    print("=" * 100)


🧑‍💬 What is beam management?
🤖 Beam management refers to the process of controlling and optimizing the transmission and reception of data in a wireless
network by directing radio signals towards specific directions or points. In 5G wireless communication, beam management
is used to improve the quality and efficiency of data transmission by focusing signals on specific beams or directions,
which can help reduce interference and increase the speed and reliability of data transmission.  Beam management is an
important aspect of 5G wireless communication as it enables the network to better utilize the available frequency
spectrum and improve the overall performance of the network. By directing signals towards specific beams, 5G networks
can reduce interference and improve the overall quality of service for users.  The beam management process involves
several steps, including beamforming, beam alignment, and beam tracking. Beamforming involves using advanced signal
processing techniques to