Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions ragflows/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,35 @@ def parse_chunks(doc_ids, run=1):
}

@timeutils.monitor
def parse_chunks_with_check(filename):
def parse_chunks_with_check(filename, doc_id=None):
"""解析文档,并仅当文档解析完毕后才返回

Args:
filename (str): 文件名,非文件路径
doc_id (str): 文档id

Returns:
bool: 是否已上传并解析完毕
"""
doc_item = ragflowdb.get_doc_item_by_name(filename)
if not doc_item:
timeutils.print_log(f'找不到{filename}对应的数据库记录,跳过')
return False

doc_id = doc_item.get('id')
if not doc_id:
timeutils.print_log(f'根据文件名[{filename}]从数据库获取文档id')
doc_item = ragflowdb.get_doc_item_by_name(filename, max_retries=configs.SQL_RETRIES)
if not doc_item:
timeutils.print_log(f'找不到{filename}对应的数据库记录,跳过')
return False

doc_id = doc_item.get('id')

# 开始解析文档
r = parse_chunks(doc_ids=[doc_id], run=1)

if not is_succeed(r):
timeutils.print_log(F'失败 parse_chunks_with_check = {doc_item.get("id")}')
return False

while True:
doc_item = ragflowdb.get_doc_item(doc_id)
doc_item = ragflowdb.get_doc_item_by_id(doc_id, max_retries=configs.SQL_RETRIES)
if not doc_item:
return False

Expand All @@ -159,7 +165,7 @@ def parse_chunks_with_check(filename):

if configs.ENABLE_PROGRESS_LOG:
progress_percent = round(progress * 100, 2)
timeutils.print_log(f"[{filename}]解析进度为:{progress_percent}%")
timeutils.print_log(f"{filename}解析进度为:{progress_percent}%")

if progress == 1:
return True
Expand Down
2 changes: 2 additions & 0 deletions ragflows/configs.demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
# 切片进度查询间隔时间(秒)
PROGRESS_CHECK_INTERVAL = 1

# 查数据库重试次数(单次重试间隔为1秒)
SQL_RETRIES = 0

def get_header():
    """Build the HTTP request headers carrying the configured authorization value."""
    headers = dict(authorization=AUTHORIZATION)
    return headers
29 changes: 24 additions & 5 deletions ragflows/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import glob
import os
from ragflows import api, configs, ragflowdb
from utils import timeutils
from utils import fileutils, timeutils
from pathlib import Path


def get_docs_files() -> list:
Expand Down Expand Up @@ -64,7 +65,6 @@ def get_file_lines(file_path) -> int:
timeutils.print_log(f"打开文件 {file_path} 时出错,错误信息:{e}")
return 0


def main():
"""主函数,处理文档上传和解析"""

Expand All @@ -78,22 +78,40 @@ def main():
if not status:
raise Exception(msg)

# 使用 glob 模块获取所有 .md 文件
# 获取起始文件序号,从1开始计数,更符合非编程用户习惯
user_config_dir = os.path.join(Path.home(), '.ragflow_upload')
os.makedirs(user_config_dir, exist_ok=True)
index_filepath = f"{user_config_dir}/index_{configs.DIFY_DOC_KB_ID}_{configs.KB_NAME}.txt".replace(os.sep, "/")
start_index = int(fileutils.read(index_filepath) or 1)
if start_index < 1:
raise ValueError(f"【起始文件序号】值不能小于1,请改为大于等于1的值,或者删除序号缓存文件:{index_filepath}")

# 使用 glob 模块获取所有文件
doc_files = get_docs_files() or []

file_total = len(doc_files)
if file_total == 0:
raise ValueError(f"在 {configs.DOC_DIR} 目录下没有找到符合要求文档文件")

# 检查start_index是否超过文件总数
if start_index > file_total:
raise ValueError(f"起始文件序号 {start_index} > 文件总数 {file_total},请修改为正确的序号值,或者删除序号缓存文件:{index_filepath}")

# 打印找到的所有 .md 文件
for i in range(file_total):

if i < start_index - 1:
continue

file_path = doc_files[i]
file_path = file_path.replace(os.sep, '/')
filename = os.path.basename(file_path)

timeutils.print_log(f"【{i+1}/{file_total}】正在处理:{file_path}")

# 记录文件序号,从1开始计数
fileutils.save(index_filepath, str(i+1))

# 判断文件行数是否小于 目标值
if need_calculate_lines(file_path):
file_lines = get_file_lines(file_path)
Expand Down Expand Up @@ -133,7 +151,8 @@ def main():

# 上传成功,开始切片
timeutils.print_log(f'{file_path},开始切片并等待解析完毕')
status = api.parse_chunks_with_check(filename)
doc_id = response.get('data')[0].get('id') if response.get('data') else None
status = api.parse_chunks_with_check(filename, doc_id)
timeutils.print_log(file_path, "切片状态:", status, "\n")

timeutils.print_log('all done')
Expand Down
78 changes: 61 additions & 17 deletions ragflows/ragflowdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ragflows import configs
from utils.mysqlutils import BaseMySql
from utils import timeutils
import time


rag_db = None
Expand All @@ -17,8 +18,8 @@ def reset_connection():
if rag_db:
try:
rag_db.close_connect()
except:
pass
except Exception as e:
timeutils.print_log(f'reset_connection error: {e}')
rag_db = None

def get_db():
Expand All @@ -35,34 +36,77 @@ def get_db():

@timeutils.monitor
def get_doc_list(kb_id):
    """
    Fetch the documents that belong to one knowledge base.

    :param kb_id: knowledge base id; it is interpolated into the SQL text as-is
    :return: the result of ``query_list`` for the id/name/progress query
    """
    connection = get_db()
    query = f"select id,name,progress from document where kb_id = '{kb_id}'"
    return connection.query_list(query)

# @timeutils.monitor
def get_doc_item(doc_id):
db = get_db()
sql = f"select id,name,progress from document where id = '{doc_id}'"
results = db.query_list(sql)
return results[0] if results else None
def get_doc_item_by_id(doc_id, max_retries=0, retry_interval=1):
    """Look up a single document row by its id, with optional retries.

    The retry behaviour is delegated to ``_query_doc_item_with_try``;
    ``max_retries`` and ``retry_interval`` are forwarded unchanged.
    """
    query = f"select id,name,progress from document where id = '{doc_id}'"
    retry_options = {
        'max_retries': max_retries,
        'retry_interval': retry_interval,
    }
    return _query_doc_item_with_try(sql_str=query, **retry_options)

# @timeutils.monitor
def get_doc_item_by_name(name):
db = get_db()
def get_doc_item_by_name(name, max_retries=0, retry_interval=1):
"""根据文档名称获取文档信息,支持重试"""
kb_id = configs.DIFY_DOC_KB_ID

if kb_id:
# 这里同时查询kb_id和name,如果document表中的数据量很大,需要增加kb_id和name的组合索引:CREATE INDEX document_kb_id_name ON document(kb_id, name);
sql = f"select id,name,progress from document where kb_id = '{kb_id}' and name = '{name}'"
sql_str = f"select id,name,progress from document where kb_id = '{kb_id}' and name = '{name}'"
else:
sql = f"select id,name,progress from document where name = '{name}'"
results = db.query_list(sql)
return results[0] if results else None
sql_str = f"select id,name,progress from document where name = '{name}'"

return _query_doc_item_with_try(
sql_str=sql_str,
max_retries=max_retries,
retry_interval=retry_interval
)

def _query_doc_item_with_try(sql_str, max_retries=0, retry_interval=1):
    """
    Run a document query and return its first row, retrying while it is empty.

    :param sql_str: complete SQL statement to execute
    :param max_retries: number of extra attempts made after the first query
        comes back empty; 0 (or a negative value) disables retrying
    :param retry_interval: seconds to sleep before each retry

    :return: the first row of the result, or None when every attempt was empty
    """
    db = get_db()

    results = db.query_list(sql_str)
    if results:
        return results[0]

    # Nothing found and retrying is disabled: report "no row" immediately.
    if max_retries <= 0:
        return None

    # The condition text is only used in the retry log line, so compute it once.
    # partition() yields "" when the statement contains no "where " separator,
    # whereas split('where ')[1] raised IndexError for e.g. "where\n...".
    where_str = sql_str.partition('where ')[2]

    for attempt in range(1, max_retries + 1):
        timeutils.print_log(f"查询{where_str}无结果,第{attempt}次重试...")
        time.sleep(retry_interval)

        results = db.query_list(sql_str)
        if results:
            return results[0]

    return None

def exist(doc_id):
return get_doc_item(doc_id) is not None
"""根据文档id判断文档是否存在"""
return get_doc_item_by_id(doc_id) is not None

def exist_name(name):
    """Report whether a document with the given name can be found."""
    doc_item = get_doc_item_by_name(name)
    return doc_item is not None


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
requests==2.32.3
requests==2.32.4
pytz==2024.1
pymysql==1.1.1
Loading