In [None]:
import os
import json
import requests
import traceback
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
from collections import defaultdict
from Arxiv_Parser.paper_parser import parse_html
from Arxiv_Parser.paper_storage import save_paper_data
from LLM.llm import MultiLLM
from Task_Conductor.prompts import RelevanceTask


class ProcessingConfig:
    """Filesystem layout used by the paper pipeline (simplified variant).

    Instantiating the class creates ``<root_dir>/.status`` and
    ``<root_dir>/papers`` if they do not exist yet.
    """
    DEFAULT_SAVE_DIR = "papers"
    FILENAME_TEMPLATE = "paper_{arxiv_id}.json"

    def __init__(self, root_dir: str = ".", custom_output: Optional[str] = None):
        """Store the resolved root and the optional output override.

        Args:
            root_dir: Base directory; ``~`` is expanded and the path resolved.
            custom_output: Optional file or directory that overrides the
                default output location. Used as given (not expanded).
        """
        self.root_dir = Path(root_dir).expanduser().resolve()
        self.custom_output = None if not custom_output else Path(custom_output)
        self._init_paths()

    def _init_paths(self):
        """Derive the state/output directories and create them (no HTML cache)."""
        self.STATE_DIR = self.root_dir / ".status"
        self.OUTPUT_DIR = self.root_dir / self.DEFAULT_SAVE_DIR

        for directory in (self.STATE_DIR, self.OUTPUT_DIR):
            directory.mkdir(parents=True, exist_ok=True)

    def get_output_path(self, arxiv_id: str) -> Path:
        """Return the path the paper with ``arxiv_id`` is written to.

        An override that is not an existing directory is returned verbatim
        (it is treated as the target file). Otherwise the templated file name
        is placed in the override directory, or in ``OUTPUT_DIR`` when no
        override was configured.
        """
        override = self.custom_output
        if override and not override.is_dir():
            return override

        target_dir = override if override else self.OUTPUT_DIR
        return target_dir / self.FILENAME_TEMPLATE.format(arxiv_id=arxiv_id)

class PaperProcessor:
    """Download an arXiv page, parse and store it, and analyse its sections.

    The instance keeps a small ``_state`` dict (current step, metadata and
    section/reference counts) that ``process`` updates as it goes and that is
    used to build the error report when a run fails.
    """

    VERSION = "2.1"
    ENCODING = 'utf-8'
    # Seconds to wait for the HTTP download; without a timeout the request
    # can block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30

    def __init__(self, config: "ProcessingConfig"):
        """Bind the configuration and create the LLM client.

        Args:
            config: Provides ``get_output_path(arxiv_id)``.
        """
        self.config = config
        self.llm = MultiLLM('deepseek-coder')
        self._state = {
            "current_step": None,
            "metadata": {},
            "stats": {
                "sections": 0,
                "references": 0
            }
        }

    def process(self, url: str, keyword: Optional[str] = None) -> Path:
        """Fetch ``url``, parse it, save the result and optionally score relevance.

        Args:
            url: Address of the paper page to download.
            keyword: When given, the abstract is scored against this keyword
                and the score is printed.

        Returns:
            The path the parsed paper was saved to.

        Raises:
            RuntimeError: On any failure. The message is a JSON document with
                the keys ``step``, ``error_type``, ``message`` and
                ``traceback``; the original exception is chained as the cause.
        """
        try:
            # Record metadata for this run
            self._update_state({"current_step": "init"})
            arxiv_id = self._extract_arxiv_id(url)
            output_path = self.config.get_output_path(arxiv_id)

            self._update_state({
                "metadata": {
                    "source_url": url,
                    "arxiv_id": arxiv_id,
                    "keyword": keyword,
                    "output_path": str(output_path),
                    "timestamp": datetime.now().isoformat()
                }
            })

            # Download the page
            self._update_state({"current_step": "fetch"})
            print("🔄 获取论文内容...", end='', flush=True)
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.encoding = self.ENCODING
            response.raise_for_status()
            print("✅")

            # Parse the page
            self._update_state({"current_step": "parse"})
            print("🔍 解析论文结构...", end='', flush=True)
            paper_data = parse_html(response.text)
            self._update_state({
                "stats": {
                    "sections": len(paper_data.get("sections", [])),
                    "references": len(paper_data.get("references", []))
                }
            })
            print(f"✅ 找到 {self._state['stats']['sections']} 个章节")

            # Persist the parsed result
            self._update_state({"current_step": "save"})
            print(f"💾 保存到：{output_path}...", end='', flush=True)
            save_paper_data(
                paper_data,
                str(output_path),
                encoding=self.ENCODING
            )
            print(f"✅ ({output_path.stat().st_size / 1024:.1f} KB)")

            # Keyword relevance (only when a keyword was supplied)
            if keyword:
                self._update_state({"current_step": "relevance"})
                print(f"🔎 分析关键词 '{keyword}' 相关性...")
                score = self._analyze_relevance(paper_data["abstract"], keyword)
                print(f"⭐ 相关性评分：{score}/1.0")

            self._update_state({"current_step": "done"})
            return output_path

        except Exception as e:
            error_info = {
                "step": self._state.get("current_step"),
                "error_type": type(e).__name__,
                "message": str(e),
                "traceback": traceback.format_exc()
            }
            print(f"\n❌ 处理失败：{error_info['message']}")
            raise RuntimeError(json.dumps(error_info, ensure_ascii=False)) from e

    def _analyze_relevance(self, abstract: str, keyword: str) -> float:
        """Ask the LLM how relevant ``abstract`` is to ``keyword``.

        Builds a ``RelevanceTask`` prompt, sends it through ``self.llm.ask``
        and returns whatever ``task.parse_model_output`` yields.
        """
        task = RelevanceTask(abstract, keyword)
        response = self.llm.ask(task.generate_prompt())
        return task.parse_model_output(response)

    @staticmethod
    def _extract_arxiv_id(url: str) -> str:
        """Return the identifier in the last path segment of ``url``.

        The query string, fragment and trailing slashes are ignored, a
        ``.pdf`` suffix is dropped, and a trailing version marker (``v``
        followed only by digits, e.g. ``v1`` or ``v12``) is removed.
        """
        path = url.split("?", 1)[0].split("#", 1)[0].rstrip("/")
        base_id = path.rsplit("/", 1)[-1]
        if base_id.endswith(".pdf"):
            base_id = base_id[:-len(".pdf")]

        # Only strip "v<digits>" at the very end, so a "v" elsewhere in the
        # identifier is left alone.
        head, sep, version = base_id.rpartition("v")
        if sep and head and version.isdecimal():
            base_id = head
        return base_id

    def _update_state(self, update_dict: dict):
        """Shallow-merge ``update_dict`` into the state (top-level keys are replaced)."""
        self._state.update(update_dict)

    def process_sections_with_buffer(self, paper_data: dict) -> list:
        """Turn ``paper_data['sections']`` into a tree of annotated nodes.

        Every node is a dict with the keys ``section_path`` (1-based dotted
        path such as ``"2.1"``), ``self_text`` (the section's own text without
        its subsections), ``depth`` (0 for top level), ``title_chain``
        (ancestor titles joined with ``" > "``), ``parent_path`` and
        ``children``.

        Returns:
            The top-level nodes; descendants are reachable through
            ``children``.
        """
        flat_nodes = []

        def _walk(sections: list, parent: Optional[dict] = None, depth: int = 0) -> None:
            for idx, section in enumerate(sections, start=1):
                title = section.get('title', '')
                if parent is None:
                    path, title_chain, parent_path = str(idx), title, None
                else:
                    # ``parent`` is the already-built node (not the raw
                    # section), so its path and title chain are available.
                    parent_path = parent['section_path']
                    path = f"{parent_path}.{idx}"
                    title_chain = f"{parent['title_chain']} > {title}"

                node = {
                    "section_path": path,
                    "self_text": section.get('text') or '',
                    "depth": depth,
                    "title_chain": title_chain,
                    "parent_path": parent_path,
                    "children": []
                }
                flat_nodes.append(node)

                subsections = section.get('subsections')
                if subsections:
                    _walk(subsections, parent=node, depth=depth + 1)

        _walk(paper_data.get('sections') or [])
        return self._build_hierarchy(flat_nodes)

    def _build_hierarchy(self, flat_list: list) -> list:
        """Link each node into its parent's ``children`` and return the roots.

        Raises:
            KeyError: If a node names a ``parent_path`` that is not present
                in ``flat_list``.
        """
        node_map = {item['section_path']: item for item in flat_list}
        for item in flat_list:
            if item['parent_path']:
                parent = node_map[item['parent_path']]
                parent['children'].append(item)
        return [item for item in flat_list if item['parent_path'] is None]

    def generate_char_stats(self, buffer: list) -> dict:
        """Compute per-section and overall character statistics.

        Args:
            buffer: Nodes as produced by ``process_sections_with_buffer``.
                Both the tree form (roots only) and a flat list are accepted;
                each node is counted once.

        Returns:
            A dict with ``individual_stats`` (one entry per node, in document
            order), ``cumulative_stats`` (section path -> characters of the
            node plus all of its descendants) and ``summary``. For an empty
            buffer the summary holds zeros and ``None`` for the "longest"
            entries.
        """
        # Flatten to document order, visiting every node exactly once even
        # when a node is present both at the top level and as a child.
        ordered = []
        visited = set()

        def _collect(node: dict) -> None:
            if id(node) in visited:
                return
            visited.add(id(node))
            ordered.append(node)
            for child in node['children']:
                _collect(child)

        for item in buffer:
            _collect(item)

        # Characters of a node including its whole subtree (memoised).
        subtree_totals = {}

        def _subtree_chars(node: dict) -> int:
            key = id(node)
            if key not in subtree_totals:
                subtree_totals[key] = len(node['self_text']) + sum(
                    _subtree_chars(child) for child in node['children']
                )
            return subtree_totals[key]

        individual_stats = []
        cumulative_stats = defaultdict(int)
        for node in ordered:
            total = _subtree_chars(node)
            individual_stats.append({
                "path": node["section_path"],
                "title_chain": node["title_chain"],
                "self_chars": len(node['self_text']),
                "cumulative_chars": total
            })
            cumulative_stats[node["section_path"]] = total

        # Top-level nodes are those that are nobody's child in this buffer;
        # their subtree totals add up to the whole document.
        child_ids = {id(child) for node in ordered for child in node['children']}
        total_cumulative = sum(
            _subtree_chars(node) for node in ordered if id(node) not in child_ids
        )
        total_self = sum(len(node['self_text']) for node in ordered)

        summary = {
            "total_sections": len(ordered),
            "total_self_chars": total_self,
            "total_cumulative_chars": total_cumulative,
            "avg_self_per_section": total_self // len(ordered) if ordered else 0,
            "max_depth": max((node['depth'] for node in ordered), default=0),
            "longest_self_section": (
                max(ordered, key=lambda x: len(x['self_text']))['section_path']
                if ordered else None
            ),
            "longest_cumulative_section": (
                max(cumulative_stats.items(), key=lambda x: x[1])[0]
                if cumulative_stats else None
            )
        }

        return {
            "individual_stats": individual_stats,
            "cumulative_stats": cumulative_stats,
            "summary": summary
        }





In [None]:
# Point the pipeline at the local research_papers folder, then download,
# parse and store one paper while scoring its relevance to "AI".
config = ProcessingConfig(root_dir=r"C:\Users\Inuyasha\Programs\Python\AIGC\Arxiv_Reviewer/research_papers")
processor = PaperProcessor(config)

result_path = processor.process("https://arxiv.org/html/2501.00092v1", keyword="AI")


In [None]:
# Enhanced save function (intended for paper_storage.py)
def save_paper_data(data: dict, filename: str, encoding='utf-8'):
    """Write ``data`` to ``filename`` as indented, non-ASCII-preserving JSON.

    Missing parent directories are created first. Characters the chosen
    encoding cannot represent are replaced rather than raising.
    """
    target = Path(filename)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('w', encoding=encoding, errors='replace') as handle:
        json.dump(data, handle, ensure_ascii=False, indent=2)

In [None]:
# Initialise the configuration and processor (user-supplied root directory)
config = ProcessingConfig(
    root_dir=r"C:\Users\Inuyasha\Programs\Python\AIGC\Arxiv_Reviewer/research_papers"
)
processor = PaperProcessor(config)

# Load the previously processed paper
paper_path = config.get_output_path("2501.00092")  # adjust to the actual arxiv_id
with open(paper_path, 'r', encoding='utf-8') as f:
    paper_data = json.load(f)

# Build the annotated section nodes
buffer = processor.process_sections_with_buffer(paper_data)

# Inspect the first two nodes
for item in buffer[:2]:
    print(f"路径：{item['section_path']}")
    print(f"标题链：{item['title_chain']}")
    # Nodes store their text under 'self_text'; no 'full_text' key is produced.
    print(f"文本长度：{len(item['self_text'])}字符")
    print("="*50)

# Generate the statistics report
stats = processor.generate_char_stats(buffer)

# Console output. The PaperProcessor defined in this notebook has no
# print_stat_report method, so fall back to dumping the summary when the
# method is unavailable instead of failing with AttributeError.
report = getattr(processor, "print_stat_report", None)
if callable(report):
    report(stats)
else:
    for key, value in stats["summary"].items():
        print(f"{key}: {value}")

# Raw per-section entries (for programmatic use)
print("\n独立章节统计示例:")
for item in stats["individual_stats"][:2]:
    # The per-section character count is stored under 'self_chars'.
    print(f"路径 {item['path']} | 标题链: {item['title_chain'][:20]}... | 字符数: {item['self_chars']:,}")


In [None]:
import re
from typing import List, Dict

# Reload the previously processed paper from disk
paper_path = config.get_output_path("2501.00092")  # adjust to the actual arxiv_id
with paper_path.open('r', encoding='utf-8') as f:
    paper_data = json.load(f)

def extract_content_summary(paper_data: dict) -> List[str]:
    """List the abstract and every section text with its character count.

    The result contains, in order: one entry for the abstract (previewed to
    its first 100 characters), one entry per section that has text (labelled
    with its 1-based dotted number and carrying the full stripped text), and
    a final line with the total number of characters counted.
    """

    def _numbered_texts(sections: list, prefix: str = ""):
        """Yield ``(label, stripped_text)`` for each section with text, depth first."""
        for number, section in enumerate(sections, start=1):
            label = f"{prefix}{number}"

            raw_text = section.get('text')
            if raw_text:
                yield label, raw_text.strip()

            children = section.get('subsections')
            if children:
                yield from _numbered_texts(children, f"{label}.")

    entries = []
    char_total = 0

    # Abstract: counted in full, shown truncated
    raw_abstract = paper_data.get('abstract')
    if raw_abstract:
        abstract = raw_abstract.strip()
        ellipsis = '...' if len(abstract) > 100 else ''
        entries.append(f"[Abstract] ({len(abstract)}字)\n{abstract[:100]}{ellipsis}\n")
        char_total += len(abstract)

    # Sections: counted and shown in full
    for label, body in _numbered_texts(paper_data.get('sections') or []):
        entries.append(f"[{label}] ({len(body)}字)\n{body}\n")
        char_total += len(body)

    # Grand total
    entries.append("=" * 16 + f"\n总字数: {char_total:,}字")

    return entries

def split_long_paragraphs(content_list: List[str], max_length: int = 8000, overlap: int = 500) -> List[Dict]:
    """Split every item longer than ``max_length`` into overlapping chunks.

    Items that already fit are passed through unchanged. Longer items are cut
    preferably at line breaks, then at sentence ends, and only as a last
    resort in the middle of a run of text. No chunk is longer than
    ``max_length`` characters, and every chunk after the first starts with
    the trailing ``overlap`` characters of the chunk before it.

    Args:
        content_list: Texts to process.
        max_length: Maximum number of characters per chunk (must be >= 1).
        overlap: Number of characters repeated at the start of each follow-up
            chunk. Values below 0 are treated as 0 and values above
            ``max_length // 2`` are capped there so every chunk makes progress.

    Returns:
        One dict per chunk with ``content`` and ``word_count`` (the character
        count). Chunks that come from a split item additionally carry
        ``part`` (1-based) and ``total_parts``.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be a positive integer")

    carry = max(0, min(overlap, max_length // 2))
    # New text added to a chunk must leave room for the carried-over overlap.
    piece_limit = max_length - carry
    # Split right after CJK sentence terminators, or after ASCII ones when
    # they are followed by whitespace (avoids cutting inside "3.14").
    sentence_end = re.compile(r'(?<=[。！？])|(?<=[.!?])(?=\s)')

    def _atomic_pieces(text: str) -> List[str]:
        """Cut ``text`` into pieces of at most ``piece_limit`` characters.

        Separators are kept inside the pieces, so joining them restores the
        original text exactly.
        """
        pieces = []
        for line in text.splitlines(keepends=True):
            if len(line) <= piece_limit:
                pieces.append(line)
                continue
            for sentence in sentence_end.split(line):
                # Hard-split whatever is still too long (empty strings
                # produce no slices and are dropped here as well).
                for start in range(0, len(sentence), piece_limit):
                    pieces.append(sentence[start:start + piece_limit])
        return pieces

    def _pack(pieces: List[str]) -> List[str]:
        """Greedily merge pieces into chunks, seeding each with the overlap."""
        chunks = []
        current = []
        current_length = 0
        for piece in pieces:
            if current and current_length + len(piece) > max_length:
                finished = ''.join(current)
                chunks.append(finished)
                seed = finished[-carry:] if carry else ''
                current = [seed] if seed else []
                current_length = len(seed)
            current.append(piece)
            current_length += len(piece)
        if current:
            chunks.append(''.join(current))
        return chunks

    result = []
    for item in content_list:
        if len(item) <= max_length:
            result.append({"content": item, "word_count": len(item)})
            continue

        split_parts = _pack(_atomic_pieces(item))
        for idx, part in enumerate(split_parts, start=1):
            result.append({
                "content": part,
                "word_count": len(part),
                "part": idx,
                "total_parts": len(split_parts)
            })

    return result
# Build the per-section content summary
content_summary = extract_content_summary(paper_data)

# (Debug aid, disabled: print the length and text of every entry before
# splitting, e.g. `print(len(entry), entry)` for each entry in content_summary.)

# Split over-long entries, then report how many chunks there are and how
# long each one is.
split_result = split_long_paragraphs(content_summary)
print(len(split_result))
print("分割后每个元素的长度:")
for chunk in split_result:
    chunk_text = chunk['content']
    print(len(chunk_text))

60
分割后每个元素的长度:
123
7551
7597
6051
2543
7969
7178
7526
7958
3748
7389
7702
7582
814
3009
2936
5114
7539
5148
7758
7469
6276
7334
8005
7528
6816
7539
7191
8010
5851
7661
8074
7341
7780
7647
7930
7250
7372
7899
2645
7162
7862
7976
5336
2009
6391
7821
3739
7823
7410
1589
7703
6813
7250
5595
7851
1166
7718
7665
30
