In [None]:
# Third-party imports
from typing_extensions import List, TypedDict
import faiss
import torch
# NOTE: shadowed by the langchain_huggingface HuggingFacePipeline import below.
from langchain import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.prompts import PromptTemplate
from langgraph.graph import START, StateGraph
from langchain_huggingface.llms import HuggingFacePipeline
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline
)

In [None]:
# Local directory holding the merged fine-tuned checkpoint; tokenizer and weights are
# loaded from the same place. NOTE: absolute, machine-specific path.
MODEL_DIR = r"D:\Backend_insurance\Algorithm\Fine_tuning\merged_16bit"
# Upper bound on the number of tokens generated per call.
MAX_NEW_TOKENS = 200

# trust_remote_code=True executes Python shipped with the checkpoint — only use with trusted files.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)

# Passing `load_in_4bit=True` straight to `from_pretrained` is deprecated in transformers;
# the supported route is a BitsAndBytesConfig. The 4-bit compute dtype is set to float16 to
# match `torch_dtype` (bitsandbytes otherwise defaults to float32 compute, which is slower).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config,
)

# Wrap the model in a transformers text-generation pipeline, then in LangChain's LLM interface.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=MAX_NEW_TOKENS)
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
# Rebuild the text-generation pipeline and its LangChain wrapper from the already-loaded
# `model` and `tokenizer`, with the same arguments the model-loading cell uses.
generation_kwargs = dict(model=model, tokenizer=tokenizer, max_new_tokens=200)
pipe = pipeline("text-generation", **generation_kwargs)
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
# Sentence-transformers model used to embed both the PDF chunks and the incoming questions.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

In [None]:
# Embed a throwaway string once to learn the embedding width, then create an empty
# exact-search (flat) L2 FAISS index plus an in-memory docstore to hold the chunk texts.
probe_vector = embeddings.embed_query("hello world")
embedding_dim = len(probe_vector)
index = faiss.IndexFlatL2(embedding_dim)

empty_docstore = InMemoryDocstore()
vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=empty_docstore,
    index_to_docstore_id={},
)

In [None]:
# Source PDF for the knowledge base. NOTE: absolute, machine-specific path.
PDF_PATH = r'D:\jincheng_project\RAG\pdf\t2.pdf'
loader = PyPDFLoader(PDF_PATH)
# With PyPDFLoader's default mode this yields one Document per PDF page.
documents = loader.load()

In [None]:
# Split the pages into chunks of at most 1000 characters, consecutive chunks sharing up to
# 200 characters; `add_start_index` records each chunk's offset in its page as metadata.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    add_start_index=True,
)
all_splits = text_splitter.split_documents(documents)

In [None]:
# Index chunks: embed every chunk and add it to the FAISS store.
# NOTE: not idempotent — re-running this cell adds the same chunks again under new ids.
if not all_splits:
    # e.g. an image-only/scanned PDF with no extractable text. Without this check the
    # empty list would fail inside FAISS with a much less informative error.
    raise ValueError("No text chunks were extracted from the PDF; check the file path and content.")
_ = vector_store.add_documents(documents=all_splits)

In [None]:
# Prompt for the RAG chain: `{context}` receives the retrieved chunk text and `{question}`
# the user's question (both filled in via `prompt.format` in `generate`).
# Rule 1 previously contradicted itself for contact details: it required strict use of the
# context, yet allowed paraphrasing and filling gaps from the model's own knowledge — an
# open invitation to invent phone numbers/addresses. Contact details are now quoted verbatim
# or reported as missing; the permissive behavior is kept only for other questions.
template = """
请根据以下规则回答问题：

### 规则：
1. **如果问题涉及电话号码或地址**，必须严格使用提供的上下文信息。  
   - 电话号码和地址必须原样引用，不得改写或编造。  
   - 若上下文中无相关信息，请明确说明上下文中未提供该信息，不要根据自身知识补充。  
2. **其他问题**：  
   - 若上下文中有答案，可参考但用自己的话回答。  
   - 若上下文中无相关信息，可根据自身知识补充。  

### 上下文：
{context}

### 问题：
{question}

### 回答：
"""

prompt = PromptTemplate.from_template(template)

In [None]:
# Define state for application
class State(TypedDict):
    """Shared state passed through the LangGraph nodes (retrieve -> generate)."""
    question: str  # user question, supplied by the caller of `graph.invoke`
    context: List[Document]  # chunks written by the `retrieve` node
    answer: str  # model output written by the `generate` node


def retrieve(state: State):
    """Fetch the single stored chunk most similar to the question.

    Returns a partial state update ``{"context": [...]}``. When the store
    returns no hits (e.g. nothing has been indexed), a one-element placeholder
    document with the text "无相关信息" ("no relevant information") is used
    instead, so the generation step always has some context to format.
    """
    hits = vector_store.similarity_search(state["question"], k=1)
    context = hits if hits else [Document(page_content="无相关信息")]
    return {"context": context}


def generate(state: State):
    """Answer the question with the LLM, using the retrieved chunks as context.

    The chunks' texts are joined with blank lines, substituted into `prompt`
    together with the question, and sent to `llm`.

    Returns:
        A partial state update ``{"answer": str}`` containing only the newly
        generated text (the echoed prompt removed), stripped of surrounding
        whitespace.
    """
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    formatted_prompt = prompt.format(question=state["question"], context=docs_content)

    # `llm(...)` (LLM.__call__) is deprecated in LangChain; `invoke` is the supported entry point.
    raw_output = llm.invoke(formatted_prompt)

    # Defensive: tolerate a raw pipeline-style dict as well as the usual string.
    if isinstance(raw_output, dict):
        raw_output = raw_output.get("generated_text", "")

    # A transformers "text-generation" pipeline returns prompt + completion by default
    # (return_full_text=True), and the LangChain wrapper may pass that through unchanged.
    # Remove only a *leading* copy of the prompt so generated text is never altered.
    if raw_output.startswith(formatted_prompt):
        raw_output = raw_output[len(formatted_prompt):]

    return {"answer": raw_output.strip()}

# Compile the application graph: START -> retrieve -> generate.
graph_builder = StateGraph(State)
# add_sequence registers both nodes (named after the functions) and the retrieve -> generate edge.
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

In [None]:
# Run the whole RAG graph on a sample question and print just the answer section.
response = graph.invoke({"question": "我遇到保险问题，我该联系谁以及如何联系"})

# This parsing assumes the model emits a "### 专业答案" section, but the prompt ends with
# "### 回答：", so the header is not guaranteed. Indexing `[1]` after the split would raise
# IndexError whenever it is missing — fall back to the whole answer instead.
answer_text = response["answer"]
answer_marker = "### 专业答案"
if answer_marker in answer_text:
    answer_text = answer_text.split(answer_marker, 1)[1]
# Keep only the text before any further "###" section the model may append.
print(answer_text.split("###")[0].strip())