diff --git a/ragflows/api.py b/ragflows/api.py index c6da75e..ce500f5 100644 --- a/ragflows/api.py +++ b/ragflows/api.py @@ -124,21 +124,27 @@ def parse_chunks(doc_ids, run=1): } @timeutils.monitor -def parse_chunks_with_check(filename): +def parse_chunks_with_check(filename, doc_id=None): """解析文档,并仅当文档解析完毕后才返回 Args: filename (str): 文件名,非文件路径 + doc_id (str): 文档id Returns: bool: 是否已上传并解析完毕 """ - doc_item = ragflowdb.get_doc_item_by_name(filename) - if not doc_item: - timeutils.print_log(f'找不到{filename}对应的数据库记录,跳过') - return False - doc_id = doc_item.get('id') + if not doc_id: + timeutils.print_log(f'根据文件名[{filename}]从数据库获取文档id') + doc_item = ragflowdb.get_doc_item_by_name(filename, max_retries=configs.SQL_RETRIES) + if not doc_item: + timeutils.print_log(f'找不到{filename}对应的数据库记录,跳过') + return False + + doc_id = doc_item.get('id') + + # 开始解析文档 r = parse_chunks(doc_ids=[doc_id], run=1) if not is_succeed(r): @@ -146,7 +152,7 @@ def parse_chunks_with_check(filename): return False while True: - doc_item = ragflowdb.get_doc_item(doc_id) + doc_item = ragflowdb.get_doc_item_by_id(doc_id, max_retries=configs.SQL_RETRIES) if not doc_item: return False @@ -159,7 +165,7 @@ def parse_chunks_with_check(filename): if configs.ENABLE_PROGRESS_LOG: progress_percent = round(progress * 100, 2) - timeutils.print_log(f"[{filename}]解析进度为:{progress_percent}%") + timeutils.print_log(f"{filename}解析进度为:{progress_percent}%") if progress == 1: return True diff --git a/ragflows/configs.demo.py b/ragflows/configs.demo.py index 825068b..6e8704b 100644 --- a/ragflows/configs.demo.py +++ b/ragflows/configs.demo.py @@ -31,6 +31,8 @@ # 切片进度查询间隔时间(秒) PROGRESS_CHECK_INTERVAL = 1 +# 查数据库重试次数(单次重试间隔为1秒) +SQL_RETRIES = 0 def get_header(): return {'authorization': AUTHORIZATION} \ No newline at end of file diff --git a/ragflows/main.py b/ragflows/main.py index 6d2b208..6004911 100644 --- a/ragflows/main.py +++ b/ragflows/main.py @@ -1,7 +1,8 @@ import glob import os from ragflows import api, configs, ragflowdb -from utils import timeutils +from utils import fileutils, timeutils +from pathlib import Path def get_docs_files() -> list: @@ -64,7 +65,6 @@ def get_file_lines(file_path) -> int: timeutils.print_log(f"打开文件 {file_path} 时出错,错误信息:{e}") return 0 - def main(): """主函数,处理文档上传和解析""" @@ -78,22 +78,40 @@ def main(): if not status: raise Exception(msg) - # 使用 glob 模块获取所有 .md 文件 + # 获取起始文件序号,从1开始计数,更符合非编程用户习惯 + user_config_dir = os.path.join(Path.home(), '.ragflow_upload') + os.makedirs(user_config_dir, exist_ok=True) + index_filepath = f"{user_config_dir}/index_{configs.DIFY_DOC_KB_ID}_{configs.KB_NAME}.txt".replace(os.sep, "/") + start_index = int(fileutils.read(index_filepath) or 1) + if start_index < 1: + raise ValueError(f"【起始文件序号】值不能小于1,请改为大于等于1的值,或者删除序号缓存文件:{index_filepath}") + + # 使用 glob 模块获取所有文件 doc_files = get_docs_files() or [] - + file_total = len(doc_files) if file_total == 0: raise ValueError(f"在 {configs.DOC_DIR} 目录下没有找到符合要求文档文件") + # 检查start_index是否超过文件总数 + if start_index > file_total: + raise ValueError(f"起始文件序号 {start_index} > 文件总数 {file_total},请修改为正确的序号值,或者删除序号缓存文件:{index_filepath}") + # 打印找到的所有 .md 文件 for i in range(file_total): + if i < start_index - 1: + continue + file_path = doc_files[i] file_path = file_path.replace(os.sep, '/') filename = os.path.basename(file_path) timeutils.print_log(f"【{i+1}/{file_total}】正在处理:{file_path}") + # 记录文件序号,从1开始计数 + fileutils.save(index_filepath, str(i+1)) + # 判断文件行数是否小于 目标值 if need_calculate_lines(file_path): file_lines = get_file_lines(file_path) @@ -133,7 +151,8 @@ def main(): # 上传成功,开始切片 timeutils.print_log(f'{file_path},开始切片并等待解析完毕') - status = api.parse_chunks_with_check(filename) + doc_id = response.get('data')[0].get('id') if response.get('data') else None + status = api.parse_chunks_with_check(filename, doc_id) timeutils.print_log(file_path, "切片状态:", status, "\n") timeutils.print_log('all done') diff --git a/ragflows/ragflowdb.py b/ragflows/ragflowdb.py index a59e6ba..2da0105 100644 --- a/ragflows/ragflowdb.py +++ b/ragflows/ragflowdb.py @@ -7,6 +7,7 @@ from ragflows import configs from utils.mysqlutils import BaseMySql from utils import timeutils +import time rag_db = None @@ -17,8 +18,8 @@ def reset_connection(): if rag_db: try: rag_db.close_connect() - except: - pass + except Exception as e: + timeutils.print_log(f'reset_connection error: {e}') rag_db = None def get_db(): @@ -35,34 +36,77 @@ def get_db(): @timeutils.monitor def get_doc_list(kb_id): + """ + 根据知识库id获取文档列表 + + :param kb_id: 知识库id + :return: 文档列表 + """ db = get_db() sql = f"select id,name,progress from document where kb_id = '{kb_id}'" doc_ids = db.query_list(sql) return doc_ids -# @timeutils.monitor -def get_doc_item(doc_id): - db = get_db() - sql = f"select id,name,progress from document where id = '{doc_id}'" - results = db.query_list(sql) - return results[0] if results else None +def get_doc_item_by_id(doc_id, max_retries=0, retry_interval=1): + """根据文档id获取文档信息,支持重试""" + sql_str = f"select id,name,progress from document where id = '{doc_id}'" + return _query_doc_item_with_try( + sql_str=sql_str, + max_retries=max_retries, + retry_interval=retry_interval + ) -# @timeutils.monitor -def get_doc_item_by_name(name): - db = get_db() +def get_doc_item_by_name(name, max_retries=0, retry_interval=1): + """根据文档名称获取文档信息,支持重试""" kb_id = configs.DIFY_DOC_KB_ID + if kb_id: - # 这里同时查询kb_id和name,如果document表中的数据量很大,需要增加kb_id和name的组合索引:CREATE INDEX document_kb_id_name ON document(kb_id, name); - sql = f"select id,name,progress from document where kb_id = '{kb_id}' and name = '{name}'" + sql_str = f"select id,name,progress from document where kb_id = '{kb_id}' and name = '{name}'" else: - sql = f"select id,name,progress from document where name = '{name}'" - results = db.query_list(sql) - return results[0] if results else None + sql_str = f"select id,name,progress from document where name = '{name}'" + + return _query_doc_item_with_try( + sql_str=sql_str, + max_retries=max_retries, + retry_interval=retry_interval + ) + +def _query_doc_item_with_try(sql_str, max_retries=0, retry_interval=1): + """ + 根据文档名称获取文档信息,支持重试 + + :param sql_str: 查询语句 + :param max_retries: 最大重试次数,0表示不重试 + :param retry_interval: 重试间隔(秒) + + :return: 文档信息 或 None + """ + db = get_db() + + results = db.query_list(sql_str) + + # 如果有值或者 max_retries为<=0,直接返回查询结果 + if results or max_retries <= 0: + return results[0] if results else None + + # 否则执行重试逻辑 + for attempt in range(1, max_retries + 1): + where_str = sql_str.split('where ')[1] if 'where' in sql_str else "" + timeutils.print_log(f"查询{where_str}无结果,第{attempt}次重试...") + time.sleep(retry_interval) + + results = db.query_list(sql_str) + if results: + return results[0] + + return None def exist(doc_id): - return get_doc_item(doc_id) is not None + """根据文档id判断文档是否存在""" + return get_doc_item_by_id(doc_id) is not None def exist_name(name): + """根据文档名称判断文档是否存在""" return get_doc_item_by_name(name) is not None diff --git a/requirements.txt b/requirements.txt index 31a34b1..5401f40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -requests==2.32.3 +requests==2.32.4 pytz==2024.1 pymysql==1.1.1 \ No newline at end of file diff --git a/scripts/launcher.py b/scripts/launcher.py index 8383c9b..efa52aa 100644 --- a/scripts/launcher.py +++ b/scripts/launcher.py @@ -74,13 +74,14 @@ def __init__(self): self.current_thread = None # 添加线程跟踪变量 self.is_running = False # 添加运行状态标志 self.should_stop = False # 添加停止标志 + self.is_stopping = False # 添加正在停止标志 self.log_handlers = [] # 添加日志处理器列表 self.original_print_log = None # 保存原始的日志打印函数 self.title("RagFlow Upload") - self.geometry("800x660") + self.geometry("800x700") # 版本和仓库信息 - self.version = "v1.0.2" # 版本号 + self.version = "v1.0.3" # 版本号 self.github_repo = "https://github.com/Samge0/ragflow-upload" # GitHub仓库地址 # 自定义图标 @@ -99,6 +100,7 @@ def __init__(self): "DOC_DIR": {"type": str, "label": "文档目录", "default": "your doc dir"}, "DOC_SUFFIX": {"type": str, "label": "文档后缀", "default": "md,txt,pdf,docx"}, "PROGRESS_CHECK_INTERVAL": {"type": int, "label": "切片进度查询间隔", "default": "1"}, + "SQL_RETRIES": {"type": int, "label": "SQL查询重试次数", "default": "1"}, "MYSQL_HOST": {"type": str, "label": "MySQL主机", "default": "localhost"}, "MYSQL_PORT": {"type": int, "label": "MySQL端口", "default": "5455"}, @@ -108,10 +110,12 @@ def __init__(self): "DOC_MIN_LINES": {"type": int, "label": "最小行数", "default": "1"}, "ONLY_UPLOAD": {"type": bool, "label": "仅上传文件", "default": "False"}, "ENABLE_PROGRESS_LOG": {"type": bool, "label": "打印切片进度日志", "default": "True"}, + "UI_START_INDEX": {"type": int, "label": "起始文件序号", "default": "1"}, # 从1开始计数,更符合非编程用户习惯 } self.create_ui() self.load_config() + self.load_index_from_cache() # 加载缓存中的序号 def create_ui(self): # 主框架 @@ -231,12 +235,20 @@ def toggle_run(self): if self.current_thread and self.current_thread.is_alive(): self.log("上一个任务还在运行中,请等待完成或点击停止") return + + # 运行前时将滚动条设置到底部 + self.log_text.see("end") + self.start_run() else: - self.stop_run() + if not self.is_stopping: # 防止重复点击 + self.stop_run() def start_run(self): """开始运行""" + # 保存当前序号到缓存 + self.save_index_to_cache() + self.is_running = True self.run_button.configure( text="停止", @@ -250,10 +262,14 @@ def start_run(self): def stop_run(self): """停止运行""" if self.current_thread and self.current_thread.is_alive(): + self.is_stopping = True # 设置正在停止标志 self.is_running = False self.should_stop = True # 设置停止标志 self.log("正在停止运行...") + # 禁用停止按钮,防止重复点击 + self.run_button.configure(state="disabled") + # 尝试终止线程 try: import ctypes @@ -263,19 +279,30 @@ def stop_run(self): ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), exc) except Exception as e: self.log(f"停止线程时出错: {str(e)}") - finally: - # 无论是否成功停止线程,都更新UI状态 + + # 等待线程真正结束 + def wait_thread_end(): + if self.current_thread: + self.current_thread.join() + # 线程结束后更新UI状态 self.current_thread = None + self.is_stopping = False self.run_button.configure( text="运行", fg_color=["#3B8ED0", "#1F6AA5"], # 默认蓝色 hover_color=["#36719F", "#144870"], # 深蓝色 - text_color="white" # 白色文字 + text_color="white", # 白色文字 + state="normal" # 恢复按钮状态 ) self.set_config_entries_state("normal") + # 重新加载缓存中的序号 + self.load_index_from_cache() self.log("已停止运行") # 清理日志处理器 self.cleanup_log_handlers() + + # 在新线程中等待原线程结束 + threading.Thread(target=wait_thread_end, daemon=True).start() def set_config_entries_state(self, state): """设置配置项的启用/禁用状态""" @@ -361,7 +388,7 @@ def run(): self.log("数据库连接已重置") except Exception as e: self.log(f"数据库连接失败: {str(e)},请检查数据库配置后重试") - self.stop_run() + self.should_stop = True # 设置停止标志 return # 中断执行 # 动态导入主程序 @@ -394,15 +421,31 @@ def check_stop(): if not self.should_stop: # 只有在非停止状态下才显示完成消息 self.log("程序运行完成") - self.stop_run() except Exception as e: self.log(f"运行失败: {str(e)}") self.log("详细错误信息:") self.log(traceback.format_exc()) - self.stop_run() finally: # 确保在任何情况下都清理日志处理器 self.cleanup_log_handlers() + # 确保在任何情况下都更新停止状态 + if self.is_running: # 如果还在运行状态,说明是异常导致的停止 + self.is_running = False + self.should_stop = True + # 更新UI状态 + self.current_thread = None + self.is_stopping = False + self.run_button.configure( + text="运行", + fg_color=["#3B8ED0", "#1F6AA5"], # 默认蓝色 + hover_color=["#36719F", "#144870"], # 深蓝色 + text_color="white", # 白色文字 + state="normal" # 恢复按钮状态 + ) + self.set_config_entries_state("normal") + # 重新加载缓存中的序号 + self.load_index_from_cache() + self.log("已停止运行") # 在新线程中运行上传任务 self.current_thread = threading.Thread(target=run, daemon=True) @@ -421,11 +464,23 @@ def open_config_dir(self): subprocess.run(['open', config_dir]) else: # Linux subprocess.run(['xdg-open', config_dir]) + + def is_scrollbar_at_bottom(self): + """检查滚动条是否在底部""" + current_position = self.log_text.yview()[1] + # 添加一个小的容差值(0.9)来判断是否在底部 + is_at_bottom = current_position >= 0.9 + return is_at_bottom + def log(self, message): # 输出到GUI self.log_text.configure(state="normal") self.log_text.insert("end", f"{get_now_str()} {message}\n") - self.log_text.see("end") + + # 只有当滚动条在底部时才自动滚动 + if self.is_scrollbar_at_bottom(): + self.log_text.see("end") + self.log_text.configure(state="disabled") # 保存到日志文件 log_save_handler.log(message) @@ -466,6 +521,8 @@ def save_config(self): with open(config_path, "w", encoding="utf-8") as f: f.write("# 配置文件(注意:若是手动修改该配置文件,需要重新运行程序才能生效)\n") for key, entry in self.config_entries.items(): + if key.startswith("UI_"): # UI_前缀的配置项表示仅用于ui界面,不不需要保存到configs.py配置文件 + continue if isinstance(entry, ctk.CTkCheckBox): value = bool(entry.get()) else: @@ -486,6 +543,40 @@ def clear_log(self): self.log_text.configure(state="disabled") self.log("日志已清理") + def get_index_cache_path(self): + """获取序号缓存文件路径""" + kb_id = self.config_entries["DIFY_DOC_KB_ID"].get() + kb_name = self.config_entries["KB_NAME"].get() + return os.path.join(get_config_dir(), f"index_{kb_id}_{kb_name}.txt") + + def load_index_from_cache(self): + """从缓存文件加载序号""" + try: + index_path = self.get_index_cache_path() + if os.path.exists(index_path): + with open(index_path, 'r', encoding='utf-8') as f: + index = f.read().strip() + if index: + self.config_entries["UI_START_INDEX"].delete(0, "end") + self.config_entries["UI_START_INDEX"].insert(0, str(index)) + self.log(f"从 {index_path} 文件中读取文件序号: {index}") + except Exception as e: + self.log(f"加载序号缓存失败: {str(e)}") + + def save_index_to_cache(self): + """保存序号到缓存文件""" + try: + index = self.config_entries["UI_START_INDEX"].get() + if index: + index = int(index) + index_path = self.get_index_cache_path() + os.makedirs(os.path.dirname(index_path), exist_ok=True) + with open(index_path, 'w', encoding='utf-8') as f: + f.write(str(index)) + self.log(f"保存文件序号({index})到: {index_path}") + except Exception as e: + self.log(f"保存序号缓存失败: {str(e)}") + if __name__ == "__main__": ctk.set_appearance_mode("dark") ctk.set_default_color_theme("blue") diff --git a/utils/mysqlutils.py b/utils/mysqlutils.py index b95debf..1ced729 100644 --- a/utils/mysqlutils.py +++ b/utils/mysqlutils.py @@ -6,6 +6,8 @@ import pymysql import logging +from utils import timeutils + class BaseMySql(object): conn = None @@ -34,7 +36,7 @@ def __init__(self, host=None, user=None, password=None, database=None, port=None self.conn.commit() except Exception as e: - self.e(e) + timeutils.print_log(f'连接数据库异常: {e}') pass def query_list(self, sql: str) -> list: @@ -51,7 +53,7 @@ def query_list(self, sql: str) -> list: columns = [col[0] for col in cur.description] return [dict(zip(columns, self.parse_encoding(row))) for row in cur.fetchall()] except Exception as e: - self.e(e) + timeutils.print_log(f'query_list 查询数据异常: {e}') return [] def execute(self, sql: str) -> bool: @@ -67,7 +69,7 @@ def execute(self, sql: str) -> bool: self.conn.commit() return True except Exception as e: - self.e(e) + timeutils.print_log(f'execute 执行sql异常,sql = {sql}\n error: {e}') return False def parse_encoding(self, row) -> list: @@ -89,43 +91,15 @@ def close_connect(self) -> None: self.cursor.close() self.conn.close() self.child_close() - self.i('释放数据库连接') + timeutils.print_log(f'close_connect 已关闭数据库连接') except Exception as e: - self.e(e) + timeutils.print_log(f'close_connect 关闭数据库异常: {e}') def child_close(self) -> None: """ 提供给子类处理的关闭操作 """ pass - - def _need_update(self, spider) -> bool: - """ - 判断该爬虫是否需要进行更新操作 - :param spider: - :return: - """ - try: - if not spider or not hasattr(spider, 'NEED_UPDATE'): - return False - self.i(f"是否需要进行更新 spider.NEED_UPDATE={spider.NEED_UPDATE}") - return spider.NEED_UPDATE - except: - return False - - def _get_update_field_list(self, spider) -> list: - """ - 获取需要指定更新的字段 - :param spider: - :return: - """ - try: - if not spider or not hasattr(spider, 'UPDATE_FIELD_LIST'): - return [] - self.i(f"指定更新字段 spider.UPDATE_FIELD_LIST={spider.UPDATE_FIELD_LIST}") - return spider.UPDATE_FIELD_LIST - except: - return [] def i(self, msg): self.logger.info(msg)