Skip to content

Commit

Permalink
openai besed RAG uses text-embedding-3-large now.
Browse files Browse the repository at this point in the history
  • Loading branch information
GaiZhenbiao committed Apr 10, 2024
1 parent 0c4dc56 commit fccd3de
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
56 changes: 44 additions & 12 deletions modules/index_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
def get_documents(file_src):
from langchain.schema import Document
from langchain.text_splitter import TokenTextSplitter

text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=30)

documents = []
Expand All @@ -43,41 +44,61 @@ def get_documents(file_src):
pdfReader = PyPDF2.PdfReader(pdfFileObj)
for page in tqdm(pdfReader.pages):
pdftext += page.extract_text()
texts = [Document(page_content=pdftext,
metadata={"source": filepath})]
texts = [Document(page_content=pdftext, metadata={"source": filepath})]
elif file_type == ".docx":
logging.debug("Loading Word...")
from langchain.document_loaders import \
UnstructuredWordDocumentLoader

loader = UnstructuredWordDocumentLoader(filepath)
texts = loader.load()
elif file_type == ".pptx":
logging.debug("Loading PowerPoint...")
from langchain.document_loaders import \
UnstructuredPowerPointLoader

loader = UnstructuredPowerPointLoader(filepath)
texts = loader.load()
elif file_type == ".epub":
logging.debug("Loading EPUB...")
from langchain.document_loaders import UnstructuredEPubLoader

loader = UnstructuredEPubLoader(filepath)
texts = loader.load()
elif file_type == ".xlsx":
logging.debug("Loading Excel...")
text_list = excel_to_string(filepath)
texts = []
for elem in text_list:
texts.append(Document(page_content=elem,
metadata={"source": filepath}))
elif file_type in [".jpg", ".jpeg", ".png", ".heif", ".heic", ".webp", ".bmp", ".gif", ".tiff", ".tif"]:
raise gr.Warning(i18n("不支持的文件: ") + filename + i18n(",请使用 .pdf, .docx, .pptx, .epub, .xlsx 等文档。"))
texts.append(
Document(page_content=elem, metadata={"source": filepath})
)
elif file_type in [
".jpg",
".jpeg",
".png",
".heif",
".heic",
".webp",
".bmp",
".gif",
".tiff",
".tif",
]:
raise gr.Warning(
i18n("不支持的文件: ")
+ filename
+ i18n(",请使用 .pdf, .docx, .pptx, .epub, .xlsx 等文档。")
)
else:
logging.debug("Loading text file...")
from langchain.document_loaders import TextLoader

loader = TextLoader(filepath, "utf8")
texts = loader.load()
except Exception as e:
import traceback

logging.error(f"Error loading file: {filename}")
traceback.print_exc()

Expand Down Expand Up @@ -113,17 +134,28 @@ def construct_index(
index_path = f"./index/{index_name}"
if local_embedding:
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
model_name="sentence-transformers/distiluse-base-multilingual-cased-v2"
)
else:
if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get(
"OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
embeddings = OpenAIEmbeddings(
openai_api_base=os.environ.get("OPENAI_API_BASE", None),
openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key),
model="text-embedding-3-large",
)
else:
embeddings = OpenAIEmbeddings(deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"], openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
model=os.environ["AZURE_EMBEDDING_MODEL_NAME"], openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"], openai_api_type="azure")
embeddings = OpenAIEmbeddings(
deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"],
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
model=os.environ["AZURE_EMBEDDING_MODEL_NAME"],
openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
openai_api_type="azure",
)
if os.path.exists(index_path) and load_from_cache_if_possible:
logging.info(i18n("找到了缓存的索引文件,加载中……"))
return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
return FAISS.load_local(
index_path, embeddings, allow_dangerous_deserialization=True
)
else:
documents = get_documents(file_src)
logging.debug(i18n("构建索引中……"))
Expand Down
4 changes: 2 additions & 2 deletions modules/models/ChuanhuAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
from langchain.text_splitter import TokenTextSplitter
from langchain.tools import StructuredTool, Tool
from langchain.vectorstores.base import VectorStoreRetriever
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic.v1 import BaseModel, Field

from ..index_func import construct_index
Expand Down Expand Up @@ -217,6 +216,7 @@ def ask_url(self, url, question):
embeddings = OpenAIEmbeddings(
openai_api_key=self.api_key,
openai_api_base=os.environ.get("OPENAI_API_BASE", None),
model="text-embedding-3-large",
)

# create vectorstore
Expand Down

0 comments on commit fccd3de

Please sign in to comment.