In [29]:
import os
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import MarkdownTextSplitter
import pickle

In [30]:
model_path = 'BAAI/bge-large-zh-v1.5'
output_dir = os.path.join('output', f'v1')

In [31]:
from langchain.document_loaders import TextLoader
loader=TextLoader("D:\desktop\code\data\宏利環球貨幣保障計劃 保單條款_2022_05.md", encoding='utf-8')
documents=loader.load()

In [32]:
from dataclasses import dataclass
from typing import List

#配置类 SplitConfig，主要用于文本分割相关的参数配置
@dataclass
class SplitConfig:
    chunk_size: int = 400
    chunk_overlap: int = 40
    separators: List[str] = ('\n\n\n', '\n\n')
    force_split: bool = False
    output_format: str = 'json' 
    cache_dir: str = './cache'

In [33]:
import json
import uuid
import logging
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')

##将文本分块结果保存为json
def save_chunks_as_json(chunks, filepath):
    data = [
        {
            "uuid": chunk.metadata.get('uuid', str(uuid.uuid4())),
            "content": chunk.page_content,
            "metadata": chunk.metadata
        }
        for chunk in chunks
    ]
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

##加载分块
def load_chunks_from_json(filepath):
    from langchain.docstore.document import Document
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)
    chunks = []
    for item in data:
        doc = Document(
            page_content=item['content'],
            metadata=item['metadata']
        )
        chunks.append(doc)
    logging.info(f"Loaded {len(chunks)} chunks from {filepath}")
    return chunks


In [34]:

def split_docs_with_config(documents, config: SplitConfig, cache_name="all_docs"):
    os.makedirs(config.cache_dir, exist_ok=True)
    filename = f"split_{config.chunk_size}_{config.chunk_overlap}.{config.output_format}"
    filepath = os.path.join(config.cache_dir, filename)
    
    if os.path.exists(filepath) and not config.force_split:
        logging.info("Found existing cache. Loading...")
        if config.output_format == 'json':
            return load_chunks_from_json(filepath)
    
    splitter=MarkdownTextSplitter(
        chunk_size=config.chunk_size,
        chunk_overlap=config.chunk_overlap
    )
    chunks = splitter.split_documents(documents)
    for chunk in chunks:
        chunk.metadata['uuid'] = str(uuid.uuid4())

    if config.output_format == 'json':
        save_chunks_as_json(chunks, filepath)
    return chunks


In [35]:
from langchain.docstore.document import Document

def load_multiple_documents_from_dir(directory: str, encoding='utf-8') -> List[Document]:
    docs = []
    for filename in os.listdir(directory):
        if filename.endswith(".md"):
            file_path = os.path.join(directory, filename)
            loader = TextLoader(file_path, encoding=encoding)
            file_docs = loader.load()
            for doc in file_docs:
                doc.metadata['source_file'] = filename
            docs.extend(file_docs)
    logging.info(f"Loaded {len(docs)} documents from {directory}")
    return docs

input_dir = "D:/desktop/code/data"  
documents = load_multiple_documents_from_dir(input_dir)


[INFO] Loaded 7 documents from D:/desktop/code/data


In [36]:
split_config = SplitConfig(
    chunk_size=400,
    chunk_overlap=40,
    separators=['\n\n\n', '\n\n'],
    force_split=False,
    output_format='json',
    cache_dir='./cache'
)
logging.info(f"Splitting documents with chunk_size={split_config.chunk_size}, overlap={split_config.chunk_overlap}")
split_docs = split_docs_with_config(documents, split_config)

[INFO] Splitting documents with chunk_size=400, overlap=40
[INFO] Found existing cache. Loading...
[INFO] Loaded 39 chunks from ./cache\split_400_40.json


In [37]:
from langchain.embeddings import HuggingFaceBgeEmbeddings
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_path,
    model_kwargs={'device': device},
    encode_kwargs={'normalize_embeddings': True}
)

[INFO] Load pretrained SentenceTransformer: BAAI/bge-large-zh-v1.5


device: cpu


In [18]:
# from xinference.client import Client

# # 连接到 Xinference 服务
# client = Client("http://localhost:9997")
# # 加载嵌入模型
# model_uid = client.launch_model(
#     model_name="bge-large-zh-v1.5",
#     model_size_in_billions=None,  
#     quantization=None,
#     model_type="embedding"
# )

# model = client.get_model(model_uid)

# # 定义一个自定义的嵌入类，使用 Xinference 模型生成嵌入
# class XinferenceEmbeddings:
#     def __init__(self, model):
#         self.model = model

#     def embed_documents(self, texts):
#         embeddings = []
#         for text in texts:
#             result = self.model.create_embedding(input=[text])
#             embeddings.append(result["data"][0]["embedding"])
#         return embeddings

#     def embed_query(self, text):
#         result = self.model.create_embedding(input=[text])
#         return result["data"][0]["embedding"]

# class CallableXinferenceEmbeddings:
#     def __init__(self, xinference_embedder):
#         self.embedder = xinference_embedder

#     def __call__(self, text):
#         return self.embedder.embed_query(text)

# # 使用方式
# embeddings = CallableXinferenceEmbeddings(XinferenceEmbeddings(model))

# # embeddings = XinferenceEmbeddings(model)

In [38]:
def attach_metadata_to_faiss_db(vector_db, metadata_path):
    with open(metadata_path, 'rb') as f:
        metadatas = pickle.load(f)
    for i, doc in enumerate(vector_db.docstore._dict.values()):
        doc.metadata = metadatas[i]

In [39]:

def get_vector_db(docs, store_path, force_rebuild=False):
    index_path = os.path.join(store_path, "faiss_index")
    metadata_path = os.path.join(store_path, "faiss_metadata.pkl")

    if not os.path.exists(index_path) or not os.path.exists(metadata_path):
        force_rebuild = True

    if force_rebuild:
        os.makedirs(store_path, exist_ok=True)
        vector_db = FAISS.from_documents(docs, embedding=embeddings)
        vector_db.save_local(index_path)

        # 保存 metadatas（Langchain 的 FAISS 默认不会持久化 metadata）
        with open(metadata_path, 'wb') as f:
            pickle.dump([doc.metadata for doc in docs], f)
        
    else:
        vector_db = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
        attach_metadata_to_faiss_db(vector_db, metadata_path)
    return vector_db


In [41]:
vector_db = get_vector_db(split_docs, store_path=os.path.join(output_dir, 'FAISS', 'bge_large_v1.5'))

In [42]:

def update_faiss_vector_db(new_docs, store_path, embeddings):
    """
    向已有 FAISS 向量数据库中添加新文档并更新持久化索引和 metadata。
    
    参数：
        new_docs (List[Document]): 新文档列表
        store_path (str): 向量数据库的持久化路径
        embeddings: 已初始化的 embedding 模型（如 HuggingFaceEmbeddings）
    """
    index_path = os.path.join(store_path, "faiss_index")
    metadata_path = os.path.join(store_path, "faiss_metadata.pkl")

    # 1. 加载现有向量数据库或初始化新库
    if os.path.exists(index_path):
        vector_db = FAISS.load_local(index_path, embeddings,allow_dangerous_deserialization=True)
    else:
        os.makedirs(store_path, exist_ok=True)
        vector_db = FAISS.from_documents([], embedding=embeddings)

    # 2. 添加新文档并保存
    vector_db.add_documents(new_docs)
    vector_db.save_local(index_path)

    # 3. metadata 更新
    if os.path.exists(metadata_path):
        with open(metadata_path, 'rb') as f:
            metadatas = pickle.load(f)
    else:
        metadatas = []

    metadatas.extend([doc.metadata for doc in new_docs])
    with open(metadata_path, 'wb') as f:
        pickle.dump(metadatas, f)

    print(f"[INFO] 添加了 {len(new_docs)} 条新文档，向量库已更新并保存。")


In [43]:
vector_db = update_faiss_vector_db(split_docs, store_path=os.path.join(output_dir, 'FAISS', 'bge_large_v1.5'), embeddings=embeddings)

[INFO] 添加了 39 条新文档，向量库已更新并保存。


In [44]:
# 测试查询与正确答案
test_queries = [
    ("保单持有人何时可行使终期红利锁定权益？", "从终期红利锁定周年日（第 15 个保单周年日或其后每个保单周年日 ）起计 31 日内及受保人在世期间，保单持有人可行使终期红利锁定权益"),
    ("保单货币转换需要满足哪些条件？", "保单货币转换申请需满足以下条件：保单持有人必须在货币转换周年日（第 3 个保单周年日或其后每个保单周年日）起计 31 日内递交申请；同一个保单年度内未曾递交过转换申请；申请一旦递交不可撤回或更改；保单的名义金额在转换后必须不少于公司厘定的最低金额；保单持有人必须在公司批准申请前偿还全部欠款"),
]

# 对应的正确答案（也可以是文档列表，与你的检索文档匹配）
correct_answers = [
    "从终期红利锁定周年日（第 15 个保单周年日或其后每个保单周年日 ）起计 31 日内及受保人在世期间，保单持有人可行使终期红利锁定权益",
    "保单货币转换申请需满足以下条件：保单持有人必须在货币转换周年日（第 3 个保单周年日或其后每个保单周年日）起计 31 日内递交申请；同一个保单年度内未曾递交过转换申请；申请一旦递交不可撤回或更改；保单的名义金额在转换后必须不少于公司厘定的最低金额；保单持有人必须在公司批准申请前偿还全部欠款"
]


In [45]:

def test_retrieval_effectiveness(vector_db, test_queries, correct_answers, k=3):
    precision_scores = []
    recall_scores = []

    for query, correct_answer in zip(test_queries, correct_answers):
        # 对每个查询进行相似性检索
        results = vector_db.similarity_search(query[0], k=k)

        # 获取检索到的文档文本
        retrieved_docs = [doc.page_content for doc in results]

        # 计算相关性指标
        precision = calculate_precision(retrieved_docs, correct_answer)
        recall = calculate_recall(retrieved_docs, correct_answer)

        precision_scores.append(precision)
        recall_scores.append(recall)

    # 输出最终的精确度和召回率
    avg_precision = sum(precision_scores) / len(precision_scores)
    avg_recall = sum(recall_scores) / len(recall_scores)

    print(f"Average Precision: {avg_precision:.4f}")
    print(f"Average Recall: {avg_recall:.4f}")

def calculate_precision(retrieved_docs, correct_answer):
    # 判断检索文档中与正确答案的匹配程度（可以设定阈值，简单地判断字符串匹配）
    return int(any(correct_answer in doc for doc in retrieved_docs))

def calculate_recall(retrieved_docs, correct_answer):
    # 计算召回率：检索到的正确答案是否出现在返回的文档中
    return int(any(correct_answer in doc for doc in retrieved_docs))


In [46]:
def calculate_precision(retrieved_docs, correct_answer):
    relevant_count = 0
    for doc in retrieved_docs:
        for keyword in correct_answer.split():
            if keyword in doc:
                relevant_count += 1
                break
    return relevant_count / len(retrieved_docs) if len(retrieved_docs) > 0 else 0

def calculate_recall(retrieved_docs, correct_answer):
    # 标记是否找到相关文档
    found_relevant = False
    for doc in retrieved_docs:
        for keyword in correct_answer.split():
            if keyword in doc:
                found_relevant = True
                break
        if found_relevant:
            break
    return 1 if found_relevant else 0

def test_retrieval_effectiveness(vector_db, test_queries, correct_answers, k=3):
    precision_scores = []
    recall_scores = []

    for query, correct_answer in zip(test_queries, correct_answers):
        results = vector_db.similarity_search(query[0], k=k)
        print(f"Query: {query[0]}")
        print(f"Correct Answer: {correct_answer}")
        print(f"Retrieved Documents: {[doc.page_content[:100] for doc in results]}")
        retrieved_docs = [doc.page_content for doc in results]
        precision = calculate_precision(retrieved_docs, correct_answer)
        recall = calculate_recall(retrieved_docs, correct_answer)
        precision_scores.append(precision)
        recall_scores.append(recall)
    avg_precision = sum(precision_scores) / len(precision_scores)
    avg_recall = sum(recall_scores) / len(recall_scores)
    print(f"Average Precision: {avg_precision:.4f}")
    print(f"Average Recall: {avg_recall:.4f}")

In [47]:
# 执行测试
test_retrieval_effectiveness(vector_db, test_queries, correct_answers, k=3)


AttributeError: 'NoneType' object has no attribute 'similarity_search'