In [1]:
import os
from langchain.chat_models import ChatOpenAI

from langchain.schema import SystemMessage, HumanMessage

from langchain.document_loaders import PDFPlumberLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

# Credentials are taken from the environment so that no secret is stored in the
# notebook. Export the key before starting the kernel, for example:
#   export OPENAI_API_KEY="sk-..."
# NOTE(review): a key was previously committed on this line; treat it as leaked
# and revoke/rotate it.
if not os.environ.get("OPENAI_API_KEY"):
    print(
        "WARNING: OPENAI_API_KEY is not set. Export it before running the "
        "cells below; they read it from os.environ."
    )

# Endpoint for the OpenAI-compatible API. A value already present in the
# environment takes precedence over this default.
os.environ.setdefault("OPENAI_API_BASE", "https://api.chatanywhere.com.cn/v1")

# Directory and file name of the PDF that is loaded and indexed below.
dataDir = "../data/"
dataName = "Deep Learning.pdf"

# 1. 载入书本数据

In [2]:
# Load the PDF without chunking it here. `load_and_split()` would pre-chunk the
# text with LangChain's default splitter, and the explicit
# RecursiveCharacterTextSplitter later in this notebook would then split those
# chunks a second time. Using `load()` keeps chunking in one place.
loader = PDFPlumberLoader(os.path.join(dataDir, dataName))

pages = loader.load()

## 1.1. 分段

In [3]:
# Chunk the loaded documents into pieces of at most 500 characters, with 50
# characters of overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(pages)

## 1.2. 向量化并存入向量数据库

In [4]:
# Embedding client; endpoint and key are read from the environment variables
# configured in the first cell. Requests use a timeout value of 60.
embed_model = OpenAIEmbeddings(
    openai_api_base=os.environ["OPENAI_API_BASE"],
    openai_api_key=os.environ["OPENAI_API_KEY"],
    request_timeout=60,
)

# Embed every chunk in `docs` and store the vectors in the Chroma collection
# named "openai_embed".
vectorstore = Chroma.from_documents(
    docs,
    embed_model,
    collection_name="openai_embed",
)

# 2. 定义Chat与System Prompt

In [5]:
# Chat model client used for every completion request in this notebook.
# The key and endpoint come from the same environment variables as the
# embedding client.
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
)

In [6]:
# Closed vocabulary of relation labels, rendered as a comma-separated list of
# single-quoted names so it can be interpolated into the prompt text.
relations = ", ".join(
    f"'{label}'"
    for label in (
        "based on",
        "facet of",
        "part of",
        "instance of",
        "subject of",
        "subclass of",
        "use",
    )
)

# System prompt that asks the model to rank the three most likely relations
# between two entities, using only the labels listed in `relations`.
# NOTE(review): no visible cell passes `system_prompt` to the chat model;
# confirm whether it is meant to be sent as a SystemMessage.
system_prompt = f"""
Role:
You are a relationship judgment specialist in the field of deep learning

Missions:
Given two entities, given the list of relations {relations}, determine what relationships are possible between the two entities, returning the top 3 most likely relationships

Format:
Please return in the following format:
(relation1, relation2, relation3)

Note:
1. Do not return anything other than the result
2. Use the relationship in the relationship list. Do not return other relationships

"""

# 3. 利用检索数据增强Prompt

In [7]:
def augment_prompt(query: str, topk=3):
    """Wrap `query` in a prompt that carries retrieved context.

    Args:
        query: The user question; also used as the similarity-search text.
        topk: Number of chunks requested from `vectorstore`.

    Returns:
        A prompt string holding the retrieved chunks' `page_content`, joined
        by newlines, followed by the original query.
    """
    # Fetch the `topk` chunks most similar to the query.
    hits = vectorstore.similarity_search(query, k=topk)
    source_knowledge = "\n".join(hit.page_content for hit in hits)

    # Assemble the final prompt: instruction, contexts, then the query.
    return f"""Using the contexts below, answer the query.

  contexts:
  {source_knowledge}

  query: {query}"""

# 4. 测试与输出

In [9]:
def chat_RAG(RAG=True, query=NotImplemented):
    """Send `query` to the chat model, optionally grounded in retrieved context.

    Args:
        RAG: When truthy, the query is first expanded by `augment_prompt`,
            which adds text retrieved from the vector store. When falsy, the
            raw query is sent unchanged.
        query: The user question. It is effectively required: the placeholder
            default is kept only so the signature stays unchanged, and it is
            rejected by the validation below.

    Returns:
        The value returned by `chat(...)` for a single human message.

    Raises:
        TypeError: If `query` is not a string (including when it is omitted).
        ValueError: If `query` is empty or contains only whitespace.
    """
    # Validate before any retrieval or API call, so a missing query fails with
    # a clear message instead of an opaque error deep inside the library.
    if not isinstance(query, str):
        raise TypeError(
            f"chat_RAG() requires `query` to be a str, got {type(query).__name__}"
        )
    if not query.strip():
        raise ValueError("chat_RAG() requires a non-empty `query`")

    content = augment_prompt(query) if RAG else query
    return chat([HumanMessage(content=content)])

In [10]:
# Ask the same question with and without retrieval to compare the answers.
query = "what is MLP"
for label, use_rag in (("without RAG: ", False), ("with RAG: ", True)):
    print(label, chat_RAG(RAG=use_rag, query=query))

without RAG:  content='MLP can refer to several different things:\n\n1. MLP can stand for "My Little Pony," a franchise of toys, TV shows, and movies featuring colorful pony characters.\n\n2. MLP can also refer to "Master Limited Partnership," a type of business structure in the United States that combines the tax benefits of a partnership with the liquidity of publicly traded securities.\n\n3. MLP can also mean "Multilayer Perceptron," which is a type of artificial neural network used in machine learning and deep learning algorithms.\n\nIt\'s important to specify the context when using the term MLP to avoid confusion.'
with RAG:  content='MLP stands for Multi-Layer Perceptron.'
