In [2]:
import os
from langchain.chat_models import ChatOpenAI

from langchain.schema import SystemMessage, HumanMessage

from langchain.document_loaders import PDFPlumberLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

# Credentials must come from the environment (or a secrets manager), never from
# the notebook source: a key committed here is readable by anyone who can see
# the file or its history.
# NOTE(review): the key that used to be hard-coded in this cell should be
# treated as leaked and revoked.
if not os.environ.get("OPENAI_API_KEY"):
    import warnings

    warnings.warn(
        "OPENAI_API_KEY is not set; export it before starting the kernel, "
        "otherwise the OpenAI-backed cells below will fail."
    )

# Keep the previously used proxy endpoint as the default, but let an
# already-exported OPENAI_API_BASE take precedence.
os.environ.setdefault("OPENAI_API_BASE", "https://api.chatanywhere.com.cn/v1")

# Location of the source document that is indexed below.
dataDir = "../data/"
dataName = "Deep Learning.pdf"

# 1. 载入书本数据

In [3]:
# Load the raw documents only. Chunking is performed once, by the explicit
# RecursiveCharacterTextSplitter in the next section; load_and_split() would
# first chunk with the loader's default splitter, so the text would end up
# being split twice with two different configurations.
loader = PDFPlumberLoader(dataDir + dataName)

pages = loader.load()

KeyboardInterrupt: 

# 1.1. 分段

In [None]:
# Chunking configuration: chunk_size=500 with chunk_overlap=50 between
# neighbouring chunks.
splitter_kwargs = {"chunk_size": 500, "chunk_overlap": 50}
text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)

# Split the loaded documents into the chunks that get embedded below.
docs = text_splitter.split_documents(pages)

# 1.2. 向量化并存入向量数据库

In [None]:
# Embedding client: endpoint and key are read from the environment, and each
# request is given a 60 second timeout.
embedding_settings = dict(
    openai_api_base=os.environ["OPENAI_API_BASE"],
    openai_api_key=os.environ["OPENAI_API_KEY"],
    request_timeout=60,
)
embed_model = OpenAIEmbeddings(**embedding_settings)

# Embed every chunk and store it in the Chroma collection "openai_embed".
vectorstore = Chroma.from_documents(
    collection_name="openai_embed",
    embedding=embed_model,
    documents=docs,
)

In [1]:
# Inspect the vector store built in the previous cell. This raises NameError
# when that cell has not been run in the current kernel.
vectorstore

NameError: name 'vectorstore' is not defined

# 2. 定义Chat与System Prompt

In [6]:
# Chat model used for both the plain and the retrieval-augmented calls; the
# key and endpoint are read from the environment.
openai_settings = {
    "openai_api_key": os.environ["OPENAI_API_KEY"],
    "openai_api_base": os.environ["OPENAI_API_BASE"],
}
chat = ChatOpenAI(model="gpt-3.5-turbo", **openai_settings)

In [9]:
# Candidate relation labels the model is allowed to choose from.
relation_names = (
    "based on",
    "facet of",
    "part of",
    "instance of",
    "subject of",
    "subclass of",
    "use",
)
# Rendered as a comma-separated list of single-quoted labels for the prompt.
relations = ", ".join(f"'{name}'" for name in relation_names)

# System prompt (Chinese): defines the role, the task of ranking the three
# most likely relations between two entities, the output format, and the
# constraints on the answer.
system_prompt = f"""
角色：
你是一个深度学习领域的关系判断专员

任务：
给定两个实体，给定关系列表{relations}，请判断两个实体间可能存在什么关系，返回前3个最有可能的关系

格式：
请以以下格式返回：
(relation1, relation2, relation3)

注意事项：
1. 除返回结果外，不要返回任何其他内容
2. 以关系列表中的关系为准，不要返回其他关系

"""

# 3. 利用检索数据增强Prompt

In [7]:
def augment_prompt(query: str, topk=3):
    """Wrap ``query`` in a prompt that carries retrieved context.

    Args:
        query: The user question; also used as the search text.
        topk: Number of results requested from
            ``vectorstore.similarity_search``.

    Returns:
        The prompt string: an instruction line, the newline-joined
        ``page_content`` of the retrieved results, then the query.
    """
    # Fetch the top-k text fragments for this query.
    hits = vectorstore.similarity_search(query, k=topk)
    source_knowledge = "\n".join(hit.page_content for hit in hits)

    # Assemble the prompt piece by piece.
    return (
        "Using the contexts below, answer the query.\n"
        "\n"
        "  contexts:\n"
        f"  {source_knowledge}\n"
        "\n"
        f"  query: {query}"
    )

# 4. 测试与输出

In [12]:
def chat_RAG(RAG=True, query=NotImplemented, system=None):
    """Send ``query`` to the chat model, optionally with retrieved context.

    Args:
        RAG: When true, the user message is ``augment_prompt(query)``;
            otherwise the raw ``query`` is sent.
        query: The user question. Required; the ``NotImplemented`` default
            only acts as a "not provided" sentinel.
        system: Optional system-prompt text. When given, it is sent as a
            ``SystemMessage`` ahead of the user message. Omitting it keeps
            the previous behavior of sending the user message alone.

    Returns:
        Whatever ``chat(messages)`` returns.

    Raises:
        ValueError: If ``query`` is missing, not a string, or blank.
    """
    # Fail early with a clear message instead of letting the sentinel (or any
    # other non-text value) reach the retriever or the chat model.
    if not isinstance(query, str) or not query.strip():
        raise ValueError("chat_RAG requires a non-empty string `query`")

    messages = []
    if system:
        messages.append(SystemMessage(content=system))

    content = augment_prompt(query) if RAG else query
    messages.append(HumanMessage(content=content))

    return chat(messages)

In [14]:
# Ask the same question with and without retrieval to compare the answers.
query = "what is pool"
for label, use_rag in (("without RAG: ", False), ("with RAG: ", True)):
    print(label, chat_RAG(RAG=use_rag, query=query))

without RAG:  content='A pool is a small to large artificial body of water, typically rectangular in shape, that is designed for swimming, recreational activities, or water sports. Pools can be found in private residences, hotels, resorts, and public recreational facilities. They are often filled with treated or chlorinated water to maintain cleanliness and safety. Pools can range in size and depth, with some featuring additional amenities such as diving boards, water slides, or hot tubs.'
with RAG:  content='Pooling is a form of randomized pooling for building ensembles of convolutional networks with each convolutional network attending to different spatial locations of each feature map.'
