Skip to content

Commit

Permalink
fixed file-based RAG
Browse files Browse the repository at this point in the history
  • Loading branch information
GaiZhenbiao committed Apr 10, 2024
1 parent bb92452 commit 3f4dde0
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions modules/index_func.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import os
import hashlib
import logging
import os

import hashlib
import PyPDF2
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from tqdm import tqdm

from modules.config import local_embedding
from modules.presets import *
from modules.utils import *
from modules.config import local_embedding


def get_documents(file_src):
Expand All @@ -28,8 +32,8 @@ def get_documents(file_src):
if file_type == ".pdf":
logging.debug("Loading PDF...")
try:
from modules.pdf_func import parse_pdf
from modules.config import advance_docs
from modules.pdf_func import parse_pdf

two_column = advance_docs["pdf"].get("two_column", False)
pdftext = parse_pdf(filepath, two_column).text
Expand All @@ -43,12 +47,14 @@ def get_documents(file_src):
metadata={"source": filepath})]
elif file_type == ".docx":
logging.debug("Loading Word...")
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.document_loaders import \
UnstructuredWordDocumentLoader
loader = UnstructuredWordDocumentLoader(filepath)
texts = loader.load()
elif file_type == ".pptx":
logging.debug("Loading PowerPoint...")
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.document_loaders import \
UnstructuredPowerPointLoader
loader = UnstructuredPowerPointLoader(filepath)
texts = loader.load()
elif file_type == ".epub":
Expand Down Expand Up @@ -93,9 +99,6 @@ def construct_index(
separator=" ",
load_from_cache_if_possible=True,
):
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import FAISS

if api_key:
os.environ["OPENAI_API_KEY"] = api_key
else:
Expand All @@ -109,11 +112,9 @@ def construct_index(
index_name = get_file_hash(file_src)
index_path = f"./index/{index_name}"
if local_embedding:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
else:
from langchain.embeddings import OpenAIEmbeddings
if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get(
"OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
Expand All @@ -122,7 +123,7 @@ def construct_index(
model=os.environ["AZURE_EMBEDDING_MODEL_NAME"], openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"], openai_api_type="azure")
if os.path.exists(index_path) and load_from_cache_if_possible:
logging.info(i18n("找到了缓存的索引文件,加载中……"))
return FAISS.load_local(index_path, embeddings)
return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
else:
documents = get_documents(file_src)
logging.debug(i18n("构建索引中……"))
Expand Down

0 comments on commit 3f4dde0

Please sign in to comment.