In [1]:
!pip install -U langchain==0.2.11 openai==1.37.0 ragas==0.1.11 arxiv==2.1.3 pymupdf==1.24.9 chromadb==0.5.5 wandb==0.17.5 tiktoken==0.7.0 pypdf==4.3.1 sentence_transformers==2.7.0
!pip install rank_bm25

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
[0m

In [2]:
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings, SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.retrievers import ParentDocumentRetriever, BM25Retriever, EnsembleRetriever
from langchain.storage import InMemoryStore
from langchain.llms import ChatGLM
from langchain.chains import RetrievalQA
import pandas as pd

In [3]:
import os
from langchain.document_loaders import pdf, PyPDFLoader

def load_pdf_doucuments(pdf_folder_path: str) -> list:
    base_docs = []
    
    if not os.path.exists(pdf_folder_path):
        raise FileNotFoundError(f"The folder '{pdf_folder_path}' does not exist.")

    for file in os.listdir(pdf_folder_path):
        if file.endswith('.pdf'):
            pdf_path = os.path.join(pdf_folder_path, file)
            print(f"Processing: {pdf_path}")
            try:
                loader = PyPDFLoader(pdf_path)
                pages = loader.load()
                base_docs.extend(pages)
            except Exception as e:
                print(f"Error processing {pdf_path}: {str(e)}")

    return base_docs

In [4]:
from langchain.llms.base import LLM
from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class ChatGLM3_LLM(LLM):
    # 基于本地 ChatGLM3 自定义 LLM 类
    tokenizer : AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path :str):
        # model_path: ChatGLM3 模型路径
        # 从本地初始化模型
        super().__init__()
        print("正在从本地加载模型...")
        # 从本地加载一个预训练的分词器（tokenizer）
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.float16).eval().to("cuda:0") 

        # 从本地加载一个预训练的生成式语言模型
        # self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()
        # 将模型设置为评估模式
        # self.model = self.model.eval()
        print("完成本地模型的加载")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
                  run_manager: Optional[CallbackManagerForLLMRun] = None,
                  **kwargs: Any):
            # 将输入 prompt 编码为 tokens
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            # 设置生成参数，如最大生成 token 数量
            generate_kwargs = {
                "max_new_tokens": kwargs.get("max_new_tokens", 150),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.9),
                "do_sample": True,
                "eos_token_id": self.tokenizer.eos_token_id
            }
    
            # 使用模型生成响应
            output = self.model.generate(**inputs, **generate_kwargs)
            
            # 解码生成的 tokens 为文本
            response = self.tokenizer.decode(output[0], skip_special_tokens=True)
            return response
        
    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
pdf_folder_path = 'dataset'
base_docs = load_pdf_doucuments(pdf_folder_path)
len(base_docs)

Processing: dataset/内向者优势.pdf
Processing: dataset/天才在左疯子在右.pdf
Processing: dataset/爱的艺术.pdf


EOF marker not found


Error processing dataset/爱的艺术.pdf: Stream has ended unexpectedly
Processing: dataset/自卑与超越.pdf
Processing: dataset/路西法效应.pdf


1281

In [6]:
# 初始化语言模型
model_path = os.path.expandvars("$GEMINI_PRETRAIN2/")
primary_qa_llm = ChatGLM3_LLM(model_path)

正在从本地加载模型...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:50<00:00, 12.60s/it]

完成本地模型的加载





In [7]:
# 初始化嵌入模型
EMBEDDING_PATH = os.path.expandvars('$GEMINI_PRETRAIN3/bge-m3')
embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_PATH)

  warn_deprecated(


In [8]:
base_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)

# 1. Base Retriever
base_docs_split = base_splitter.split_documents(base_docs)
base_vectorstore = Chroma.from_documents(base_docs_split, embeddings)
base_retriever = base_vectorstore.as_retriever(search_kwargs={"k": 2})

# 2. Parent Document Retriever (PDR)
vectorstore = Chroma(collection_name="split_parents", embedding_function=embeddings)
store = InMemoryStore()
pdr = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
pdr.add_documents(base_docs)

# 3. Ensemble Retriever (ER)
bm25_retriever = BM25Retriever.from_documents(base_docs_split)
bm25_retriever.k = 3
chroma_retriever = base_vectorstore.as_retriever(search_kwargs={"k": 3})
er = EnsembleRetriever(
    retrievers=[bm25_retriever, chroma_retriever],
    weights=[0.75, 0.25]
)

  warn_deprecated(


In [9]:
from langchain.prompts import ChatPromptTemplate

# Create QA chains
template = """你是一个专业心理咨询师，请结合你查阅到的知识回复以下问题。

### Context Information
{context}

### Question
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

In [10]:
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from operator import itemgetter

def create_qa_chain(retriever):
    return (
        {"context": itemgetter("question") | retriever,
         "question": itemgetter("question")
        }
        | RunnablePassthrough.assign(
            context=itemgetter("context")
          )
        | {
             "response": prompt | primary_qa_llm,
             "context": itemgetter("context"),
          }
    )

base_qa = create_qa_chain(base_retriever)
pdr_qa = create_qa_chain(pdr)
er_qa = create_qa_chain(er)

In [11]:
questions_and_answers = [
    {
        "question": "心理咨询师，我觉得我的胸闷症状越来越严重了，这让我很害怕。",
        "ground_truth": "我能理解你的感受，首先我们要明确你的症状并不是生理问题，而是心理问题。我们可以尝试找出引发你胸闷的心理原因。"
    },
    {
        "question": "您好，最近我总是因为一些小事睡不着，心里着急，第二天感觉疲乏无力。而且我还经常担心自己生病。",
        "ground_truth": "我建议您尝试服用抗焦虑药物，并密切关注病情变化。同时，在咨询过程中，我们可以学习如何有效地应对焦虑情绪，减轻躯体症状。"
    },
    {
        "question": "你好，心理咨询师，我觉得我最近的情绪很不好，总是容易发脾气。",
        "ground_truth": "你很关心父母的感受，这是一个很好的品质。但我们都知道，失败是成功的垫脚石，每个人都会经历失败。你觉得，如果你失败了，父母会怎么样呢？"
    },
    {
        "question": "您好，最近我一直想减肥，但总是没有行动力。我在美容院花了几万块钱，但就是不去。我希望通过催眠提高我的行动力。",
        "ground_truth": "您好，感谢您来咨询。我了解到您想通过催眠来提高行动力以实现减肥的目标。请您谈谈您的生活和工作状况，以及您认为行动力不足的原因。"
    }
]

In [12]:
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Helper function to safely get chunks
def safe_get_chunks(retriever, question):
    try:
        chunks = retriever.invoke(question)
        if not chunks:
            logger.warning(f"No chunks returned for question: {question}")
            return []
        return chunks
    except Exception as e:
        logger.error(f"Error retrieving chunks for question '{question}': {str(e)}")
        return []

# Helper function to safely execute QA
def safe_qa_invoke(qa_chain, question):
    try:
        result = qa_chain.invoke({"question": question})
        return result.get("response", "No response generated")
    except Exception as e:
        logger.error(f"Error in QA chain for question '{question}': {str(e)}")
        return "Error in generating response"

# Execute queries and save results
results = []
for idx, qa in enumerate(questions_and_answers, 1):
    question = qa["question"]
    ground_truth = qa["ground_truth"]
    
    logger.info(f"Processing question {idx}: {question}")

    # For base_retriever
    base_chunks = safe_get_chunks(base_retriever, question)
    base_result = safe_qa_invoke(base_qa, question)
    
    # For PDR
    pdr_chunks = safe_get_chunks(pdr, question)
    pdr_result = safe_qa_invoke(pdr_qa, question)
    
    # For ER
    er_chunks = safe_get_chunks(er, question)
    er_result = safe_qa_invoke(er_qa, question)
    
    results.append({
        "index": idx,
        "question": question,
        "ground_truth_answer": ground_truth,
        "base_retriever_chunks_size1000_overlap100_k2": str(base_chunks),
        "base_retriever_answer_size1000_overlap100_k2": base_result,
        "PDR_chunks_psize1500_csize200": str(pdr_chunks),
        "PDR_answer_psize1500_csize200": pdr_result,
        "ER_chunks_size1000_overlap100_k3_w75": str(er_chunks),
        "ER_answer_size1000_overlap100_k3_w75": er_result,
    })

    logger.info(f"Completed processing question {idx}")

# Save results to CSV
df = pd.DataFrame(results)
df.to_csv("rag-result.csv", index=False)
logger.info("Results saved to rag-result.csv")

INFO:__main__:Processing question 1: 心理咨询师，我觉得我的胸闷症状越来越严重了，这让我很害怕。
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
INFO:__main__:Completed processing question 1
INFO:__main__:Processing question 2: 您好，最近我总是因为一些小事睡不着，心里着急，第二天感觉疲乏无力。而且我还经常担心自己生病。
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
INFO:__main__:Completed processing question 2
INFO:__main__:Processing question 3: 你好，心理咨询师，我觉得我最近的情绪很不好，总是容易发脾气。
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
INFO:__main__:Completed processi