In [None]:
import json
import os
import re
from collections import defaultdict
from decimal import Decimal
from typing import List, Dict, Any, Tuple, Optional

import markdown2
import numpy as np
import torch
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer, util

# Cosine-similarity thresholds per rule category.
#   base   - starting threshold
#   adjust - largest extra amount added on top of ``base`` as the compared
#            texts get longer (see ``get_dynamic_threshold``)
# Numeric rule categories use a stricter base and a smaller adjustment.
SIM_THRESHOLD = {
    category: {"base": base, "adjust": adjust}
    for category, base, adjust in (
        ("责任免除", 0.92, 0.02),
        ("术语解释", 0.88, 0.03),
        ("赔付规则", 0.95, 0.01),
        ("保障相关时间", 0.95, 0.01),
    )
}

# Field name -> regex whose first capture group is the field's value.
# extract_fields runs these over text already normalised by preprocess_text.
FIELD_PATTERNS = {
    "责任免除": r"(?:责任免除|免责说明|免责情形)[:：]?\s*([\s\S]*?)(?=(?:保障责任|赔付规则|\n##|\n#|$))",
    "赔付比例": r"(?:赔付比例|给付比例)[:：]?\s*(\d+%|按[\d一二三四五六七八九十]+成)",
    "等待期": r"(?:等待期|观察期)[:：]?\s*(\d+)\s*天",
    # Range separators: ASCII/full-width hyphen, ASCII/full-width tilde, 至 and 到,
    # so "18至60岁" captures "18至60" rather than just "18".
    "投保年龄": r"(?:投保年龄|承保年龄)[:：]?\s*([\d\-－~～至到]+)\s*岁?",
    "保险金额": r"(?:保险金额|基本保额)[:：]?\s*([\d,]+)\s*元",
    "产品名称": r"(?:产品名称|保险产品)[:：]\s*([^\n]+)",
    "保险期间": r"(?:保险期间|保障期限)[:：]\s*([^\n]+)",
    "交费期间": r"(?:交费期间|缴费期限)[:：]\s*([^\n]+)",
    "犹豫期": r"(?:犹豫期|冷静期)[:：]\s*(\d+)\s*天",
    "免赔额": r"(?:免赔额|自付额)[:：]\s*([^\n]+)",
    "赔付次数": r"(?:赔付次数|给付次数)[:：]\s*([^\n]+)"
}

# (regex, replacement) pairs that preprocess_text applies in this order.
# Single line breaks are preserved on purpose: the line-oriented ``[^\n]+``
# captures in FIELD_PATTERNS must stop at the end of the current line instead
# of swallowing the rest of the document.
PREPROCESS_PATTERNS = [
    (r'[^\S\n]+', ' '),     # collapse non-newline whitespace (incl. U+3000 ideographic space) to one space
    (r'\s*\n\s*', '\n'),    # collapse blank lines / whitespace around line breaks to one newline
    (r'[【】]', ''),        # drop Chinese lenticular brackets
    (r'（', '('),           # full-width parentheses -> ASCII, one character for one character
    (r'）', ')'),
    (r'[:：]\s*', ': ')     # unify colons as ASCII colon + single space
]

# Sub-directory name inside a material package -> document type label.
TYPE_MAPPING = dict(
    CLAUSE="条款",
    HEAD_IMG="头图",
    INSURE_NOTICE="投保须知",
    INTRODUCE_IMG="图文说明",
    LIABILITY_EXCLUSION="免责说明",
)

# Keyword looked for in a rule's text -> handler category.
# Order matters: infer_rule_type returns the category of the FIRST keyword
# (in this insertion order) that occurs in the rule.
RULE_TYPE_MAPPING = dict([
    ("基础产品销售信息", "销售信息"),
    ("责任免除", "免责条款"),
    ("赔付规则", "赔付规则"),
    ("保障相关时间", "保障时间"),
    ("投保条款", "投保条件"),
    ("术语解释", "术语定义"),
])

class EnhancedInsuranceRiskDetector:
    """Check that the documents of one insurance material package agree with each other.

    A material package is a directory whose sub-directories (the keys of
    ``TYPE_MAPPING``) contain Markdown files. ``parse_material`` turns them into
    documents carrying plain text and regex-extracted fields; ``detect_risk``
    maps a rule to a handler (via ``RULE_TYPE_MAPPING``) that compares those
    documents. Every handler returns ``True`` when no inconsistency is found.
    """

    # Multipliers for Chinese magnitude suffixes in monetary amounts.
    _AMOUNT_UNITS = {"千": 1000, "万": 10000, "亿": 100000000}
    # Chinese numerals accepted in "按N成" payout ratios; index + 1 is the value.
    _CN_NUMERALS = "一二三四五六七八九十"
    # Age range such as "18-60", "18至60", "18~60岁"; the upper bound is optional.
    _AGE_RANGE_RE = re.compile(r'(\d+)\s*[-－—~～至到]?\s*(\d+)?\s*岁?')
    # "健康告知: ..." / "健康状况: ..." - captures the rest of that line only.
    _HEALTH_RE = re.compile(r'(?:健康告知|健康状况)[:：]\s*([^\n]+)')
    # "term: definition" at the start of the text or at the start of a line.
    _TERM_RE = re.compile(r'(?:^|\n)([^\n：:]+?)\s*[:：]\s*([^\n]+)')

    def __init__(self, model_path: str):
        """Load the sentence-embedding model and register the rule handlers.

        Args:
            model_path: Model name or local path passed to ``SentenceTransformer``.
                The model cache folder is ``<cwd>/model_cache``.
        """
        cache_dir = os.path.join(os.getcwd(), "model_cache")
        os.makedirs(cache_dir, exist_ok=True)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用设备: {self.device}")

        self.model = SentenceTransformer(
            model_path,
            cache_folder=cache_dir,
            device=self.device
        )
        self.model.to(self.device)

        # Handler category (values of RULE_TYPE_MAPPING) -> consistency check.
        self.rule_handlers = {
            "销售信息": self.check_sales_info,
            "免责条款": self.check_liability_exclusion,
            "赔付规则": self.check_payment_rule,
            "保障时间": self.check_time_rule,
            "投保条件": self.check_insurance_terms,
            "术语定义": self.check_term_definition
        }

        # Compile the normalisation regexes once; preprocess_text applies them in order.
        self.preprocess_regex = [(re.compile(pat), repl) for pat, repl in PREPROCESS_PATTERNS]

    def preprocess_text(self, text: str) -> str:
        """Apply every PREPROCESS_PATTERNS substitution in order and strip the result."""
        for regex, repl in self.preprocess_regex:
            text = regex.sub(repl, text)
        return text.strip()

    def _preprocess_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of *doc* with its text and field values normalised."""
        processed = doc.copy()
        processed["text"] = self.preprocess_text(doc["text"])
        processed["fields"] = {k: self.preprocess_text(v) for k, v in doc["fields"].items()}
        return processed

    def detect_risk(self, material_data: Dict[str, Any], rule: str) -> bool:
        """Check one rule against a parsed material package.

        Args:
            material_data: Output of ``parse_material`` (``material_id`` and ``documents``).
            rule: Free-text rule; its category is inferred by ``infer_rule_type``.

        Returns:
            The handler's verdict (``True`` = consistent). Rules whose category has
            no handler are reported and treated as passing; packages without any
            recognised document are reported and treated as failing.
        """
        rule_type = self.infer_rule_type(rule)
        handler = self.rule_handlers.get(rule_type)

        if handler is None:
            print(f"警告：未找到匹配的规则处理器，规则类型: {rule_type}，规则内容: {rule}")
            return True

        known_types = set(TYPE_MAPPING.values())
        relevant_docs = [
            self._preprocess_document(doc)
            for doc in material_data["documents"]
            if doc["type"] in known_types
        ]

        if not relevant_docs:
            print(f"警告：素材包 {material_data['material_id']} 中缺少关键素材类型")
            return False

        return handler(relevant_docs)

    def _collect_field_values(self, docs: List[Dict[str, Any]], fields: List[str]) -> Dict[str, set]:
        """Map each field in *fields* to the set of distinct normalised values found in *docs*."""
        field_values: Dict[str, set] = defaultdict(set)
        for doc in docs:
            for field in fields:
                value = doc["fields"].get(field, "")
                if not value:
                    continue
                norm_value = self.normalize_field(field, value)
                if norm_value:
                    field_values[field].add(norm_value)
        return field_values

    @staticmethod
    def _report_conflicts(title: str, field_values: Dict[str, set]) -> bool:
        """Print every field with more than one distinct value under *title*.

        Returns:
            ``True`` when every field has at most one value, ``False`` otherwise.
        """
        conflicts = [(field, values) for field, values in field_values.items() if len(values) > 1]
        if not conflicts:
            return True
        print(title)
        for field, values in conflicts:
            print(f"  {field}: {values}")
        return False

    def _pairwise_similarity(self, texts: List[str]) -> np.ndarray:
        """Return the cosine-similarity matrix (len(texts) x len(texts)) of the texts' embeddings."""
        embeddings = self.model.encode(
            texts,
            convert_to_tensor=True,
            device=self.device
        )
        return util.cos_sim(embeddings, embeddings).cpu().numpy()

    @staticmethod
    def _min_cross_document_similarity(scores: np.ndarray, owners: np.ndarray) -> float:
        """Return the weakest best-match similarity between different documents.

        Args:
            scores: Square similarity matrix over all segments.
            owners: ``owners[i]`` is the index of the document segment ``i`` came from.

        For every segment and every *other* document, the similarity of that
        segment's closest counterpart in the other document is taken; the result
        is the minimum of those best matches. Segments of the same document are
        never compared with each other. When each document is a single segment,
        this equals the minimum off-diagonal similarity.
        """
        best_matches = []
        for doc_id in np.unique(owners):
            in_doc = owners == doc_id
            # Row = a segment from another document; max over this document's segments.
            best = scores[~in_doc][:, in_doc].max(axis=1)
            best_matches.append(best.min())
        return float(min(best_matches))

    def check_sales_info(self, docs: List[Dict[str, Any]]) -> bool:
        """Product name, insurance period, payment period and sum insured must agree across documents."""
        field_values = self._collect_field_values(docs, ["产品名称", "保险期间", "交费期间", "保险金额"])
        return self._report_conflicts("销售信息冲突发现:", field_values)

    def check_liability_exclusion(self, docs: List[Dict[str, Any]]) -> bool:
        """Exclusion-clause texts of different documents must be semantically similar.

        Texts longer than 200 characters are split into segments. Segments are
        compared only across documents: each segment's closest counterpart in
        every other document is found, and the weakest of those best matches
        must reach the dynamic threshold. Comparing all segment pairs (including
        different parts of the same clause) would make any long clause fail.
        """
        exclusion_texts = [
            doc["fields"]["责任免除"]
            for doc in docs
            if doc["fields"].get("责任免除", "").strip()
        ]

        if len(exclusion_texts) < 2:
            return True

        # Keep the per-document grouping so segments know which document they belong to.
        doc_segments = [
            self.split_text(text, max_length=200) if len(text) > 200 else [text]
            for text in exclusion_texts
        ]
        segments = [seg for segs in doc_segments for seg in segs]
        owners = np.array([doc_idx for doc_idx, segs in enumerate(doc_segments) for _ in segs])

        avg_length = sum(len(seg) for seg in segments) / len(segments)
        threshold = self.get_dynamic_threshold("责任免除", avg_length)

        scores = self._pairwise_similarity(segments)
        min_similarity = self._min_cross_document_similarity(scores, owners)
        print(f"责任免除内容最小相似度: {min_similarity:.4f}，动态阈值: {threshold:.4f}")

        return min_similarity >= threshold

    def check_payment_rule(self, docs: List[Dict[str, Any]]) -> bool:
        """Payout ratio, deductible and payout count must agree across documents."""
        field_values = self._collect_field_values(docs, ["赔付比例", "免赔额", "赔付次数"])
        return self._report_conflicts("赔付规则冲突发现:", field_values)

    def check_time_rule(self, docs: List[Dict[str, Any]]) -> bool:
        """Waiting period and cooling-off period (in days) must agree across documents."""
        field_values: Dict[str, set] = defaultdict(set)
        for doc in docs:
            for field in ("等待期", "犹豫期"):
                match = re.search(r'\d+', doc["fields"].get(field, ""))
                if match:
                    # int() drops leading zeros so "090" and "90" compare equal.
                    field_values[field].add(str(int(match.group(0))))
        return self._report_conflicts("保障时间冲突发现:", field_values)

    def check_insurance_terms(self, docs: List[Dict[str, Any]]) -> bool:
        """Insurable age range and health-notice text must agree across documents."""
        age_values = set()
        health_values = set()

        for doc in docs:
            age_str = doc["fields"].get("投保年龄", "")
            match = self._AGE_RANGE_RE.search(age_str) if age_str else None
            if match:
                min_age = match.group(1)
                # A single age ("60岁") is treated as a one-point range.
                max_age = match.group(2) or min_age
                age_values.add(f"{min_age}-{max_age}")

            # One health-notice value per matching line.
            for health in self._HEALTH_RE.finditer(doc["text"]):
                value = health.group(1).strip()
                if value:
                    health_values.add(value)

        return self._report_conflicts(
            "投保条件冲突发现:",
            {"投保年龄": age_values, "健康告知": health_values},
        )

    def check_term_definition(self, docs: List[Dict[str, Any]]) -> bool:
        """Definitions of the same term must be semantically similar across documents.

        "term: definition" lines are grouped by term. Only terms with at least
        two distinct definitions longer than 10 characters are compared, and
        every such group must reach the dynamic threshold. Definitions of
        different terms are never compared with each other.
        """
        definitions_by_term: Dict[str, set] = defaultdict(set)
        for doc in docs:
            for match in self._TERM_RE.finditer(doc["text"]):
                term, definition = (part.strip() for part in match.groups())
                if len(definition) > 10:  # only consider substantive definitions
                    definitions_by_term[term].add(definition)

        consistent = True
        for term in sorted(definitions_by_term):
            definitions = sorted(definitions_by_term[term])
            if len(definitions) < 2:
                continue  # identical (or single) definitions are trivially consistent

            texts = [f"{term}: {definition}" for definition in definitions]
            avg_length = sum(len(t) for t in texts) / len(texts)
            threshold = self.get_dynamic_threshold("术语解释", avg_length)

            scores = self._pairwise_similarity(texts)
            np.fill_diagonal(scores, 1.0)
            min_similarity = float(np.min(scores))
            print(f"术语定义最小相似度 [{term}]: {min_similarity:.4f}，动态阈值: {threshold:.4f}")

            if min_similarity < threshold:
                consistent = False

        return consistent

    def split_text(self, text: str, max_length: int = 200) -> List[str]:
        """Split *text* into segments of at most *max_length* characters.

        The text is cut after Chinese sentence-ending punctuation (。！？；), and
        whole sentences are packed greedily into segments. A single sentence
        longer than *max_length* is hard-wrapped into *max_length*-sized pieces,
        so no segment exceeds the limit.
        """
        segments: List[str] = []
        current = ""

        for sentence in re.split(r'(?<=[。！？；])', text):
            if max_length > 0 and len(sentence) > max_length:
                pieces = [sentence[i:i + max_length] for i in range(0, len(sentence), max_length)]
            else:
                pieces = [sentence]

            for piece in pieces:
                if len(current) + len(piece) <= max_length:
                    current += piece
                else:
                    if current:
                        segments.append(current)
                    current = piece

        if current:
            segments.append(current)

        return segments

    def get_dynamic_threshold(self, field: str, avg_length: float) -> float:
        """Return ``base + adjust * min(1, avg_length / 300)`` using SIM_THRESHOLD[field].

        Longer texts (relative to a 300-character reference) get up to the full
        ``adjust`` added on top of the base threshold.

        Raises:
            KeyError: if *field* has no SIM_THRESHOLD entry.
        """
        base = SIM_THRESHOLD[field]["base"]
        adjust = SIM_THRESHOLD[field]["adjust"]

        length_factor = min(1.0, avg_length / 300)
        return base + adjust * length_factor

    @classmethod
    def _normalize_amount(cls, value: str) -> str:
        """Normalise a money amount such as "10,000元", "1万元" or "0.5万" to "<yuan>元".

        The first number in *value* is used; a following 千/万/亿 scales it.
        Returns "" when *value* contains no number.
        """
        compact = value.replace(",", "").replace("，", "")
        match = re.search(r'(\d+(?:\.\d+)?)\s*([千万亿]?)', compact)
        if not match:
            return ""
        amount = Decimal(match.group(1)) * cls._AMOUNT_UNITS.get(match.group(2), 1)
        # normalize() drops trailing zeros; 'f' formatting avoids exponent notation.
        return f"{format(amount.normalize(), 'f')}元"

    def normalize_field(self, field: str, value: str) -> str:
        """Return a canonical form of *value* so equivalent spellings compare equal.

        - 保险金额 / 免赔额: amount in yuan, e.g. "10,000" and "1万元" -> "10000元".
        - 投保年龄: keep digits and range separators only.
        - 赔付比例: "N成" (Arabic or Chinese numeral) -> "N*10%", e.g. "按八成" -> "80%";
          other forms keep digits, %, 成 and Chinese numerals.
        - Any other field is returned unchanged.
        """
        if field in ("保险金额", "免赔额"):
            return self._normalize_amount(value)
        if field == "投保年龄":
            return re.sub(r'[^\d\-～]', '', value)
        if field == "赔付比例":
            compact = re.sub(rf'[^\d%成{self._CN_NUMERALS}]', '', value)
            match = re.fullmatch(rf'(\d+|[{self._CN_NUMERALS}])成', compact)
            if match:
                numeral = match.group(1)
                tenths = int(numeral) if numeral.isdigit() else self._CN_NUMERALS.index(numeral) + 1
                return f"{tenths * 10}%"
            return compact
        return value

    def infer_rule_type(self, rule: str) -> str:
        """Return the category of the first RULE_TYPE_MAPPING keyword found in *rule*, else "其他"."""
        return next(
            (rule_type for keyword, rule_type in RULE_TYPE_MAPPING.items() if keyword in rule),
            "其他",
        )

    def parse_material(self, material_path: str) -> Dict[str, Any]:
        """Parse every Markdown file of a material package.

        Only sub-directories named in TYPE_MAPPING are read (recursively). A file
        that fails to parse is reported and skipped. Directory entries are visited
        in sorted order so results are deterministic across platforms.

        Returns:
            ``{"material_id": <dir name>, "documents": [{"type", "path", "text", "fields"}, ...]}``
        """
        material_data = {
            "material_id": os.path.basename(material_path),
            "documents": []
        }

        print(f"正在解析素材包: {material_data['material_id']}")

        for sub_dir in sorted(os.listdir(material_path)):
            sub_dir_full = os.path.join(material_path, sub_dir)
            if sub_dir not in TYPE_MAPPING or not os.path.isdir(sub_dir_full):
                continue

            doc_type = TYPE_MAPPING[sub_dir]
            print(f"  发现文档类型: {doc_type}")

            for root, dirs, files in os.walk(sub_dir_full):
                dirs.sort()  # os.walk honours in-place ordering of dirs
                for file in sorted(files):
                    if not file.lower().endswith('.md'):
                        continue
                    md_path = os.path.join(root, file)
                    try:
                        text = self.parse_markdown(md_path)
                        fields = self.extract_fields(text, doc_type)
                    except Exception as e:  # best effort: one bad file must not abort the package
                        print(f"    解析文件失败: {file}, 错误: {e}")
                        continue
                    material_data["documents"].append({
                        "type": doc_type,
                        "path": md_path,
                        "text": text,
                        "fields": fields
                    })

        return material_data

    def parse_markdown(self, file_path: str) -> str:
        """Read a UTF-8 Markdown file and return its plain text.

        ``get_text()`` keeps the whitespace/newlines between block elements
        (paragraphs, headings, list items) while joining inline runs such as
        "**标签**：值" without a break, so "label: value" lines stay intact for
        the line-oriented FIELD_PATTERNS.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            md_content = f.read()
        html = markdown2.markdown(md_content)
        return BeautifulSoup(html, 'lxml').get_text().strip()

    def extract_fields(self, text: str, doc_type: str) -> Dict[str, str]:
        """Extract the first non-empty value of each FIELD_PATTERNS field from *text*.

        *text* is normalised with ``preprocess_text`` first. ``doc_type`` is
        currently unused and kept for interface compatibility.
        """
        fields: Dict[str, str] = {}
        text = self.preprocess_text(text)

        for field, pattern in FIELD_PATTERNS.items():
            try:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    # First non-empty capture group of this match, if any.
                    value = next((g for g in match.groups() if g), "").strip()
                    if value:
                        fields[field] = value
                        break
            except re.error as e:
                print(f"字段提取错误 - {field}: {e}")

        return fields

def _match_material_dir(material_id: str, material_ids: List[str]) -> Optional[str]:
    """Pick the package directory for *material_id*.

    An exact directory name wins; otherwise the shortest (then alphabetically
    first) directory starting with the stripped ID is returned, or ``None``.
    """
    if material_id in material_ids:
        return material_id
    # NOTE(review): rstrip('as') removes any trailing run of 'a'/'s' characters,
    # not a literal "as" suffix; kept from the original - confirm the ID format.
    prefix = material_id.rstrip('as')
    candidates = [m for m in material_ids if m.startswith(prefix)]
    return min(candidates, key=lambda m: (len(m), m)) if candidates else None


def enhanced_process_all_materials(materials_dir: str, data_jsonl_path: str, model_path: str) -> None:
    """Run every rule of *data_jsonl_path* against its material package and save the results.

    Args:
        materials_dir: Directory holding the material packages (sub-directories
            whose names start with ``m_``).
        data_jsonl_path: JSON Lines file; each record needs ``material_id`` and
            ``rule`` and may carry ``rule_id``. Blank lines are ignored.
        model_path: Model passed to ``EnhancedInsuranceRiskDetector``.

    Results are written as JSON Lines to ``enhanced_检测结果.jsonl`` next to
    *data_jsonl_path*. A task whose package is missing or whose check raises is
    recorded with ``result = False`` and a reason instead of aborting the batch.
    """
    detector = EnhancedInsuranceRiskDetector(model_path)

    with open(data_jsonl_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
        tasks = [json.loads(line) for line in f if line.strip()]

    print(f"已加载 {len(tasks)} 条检测规则")

    # Sorted so package matching is deterministic regardless of listdir order.
    material_ids = sorted(
        item for item in os.listdir(materials_dir)
        if item.startswith("m_") and os.path.isdir(os.path.join(materials_dir, item))
    )

    print(f"发现 {len(material_ids)} 个素材包")

    # Parsing is independent of the rule, so each package is parsed only once.
    parsed_materials: Dict[str, Dict[str, Any]] = {}
    results = []
    for task in tasks:
        material_id = task["material_id"]
        matched_material = _match_material_dir(material_id, material_ids)

        if matched_material is None:
            print(f"警告：未找到匹配的素材包: {material_id}")
            task["result"] = False
            task["reason"] = "未找到匹配的素材包"
            results.append(task)
            continue

        material_path = os.path.join(materials_dir, matched_material)

        try:
            print(f"\n处理任务: {task.get('rule_id', '无规则ID')} - {task['rule']}")
            print(f"素材包: {material_id}，实际路径: {material_path}")

            if material_path not in parsed_materials:
                parsed_materials[material_path] = detector.parse_material(material_path)
            result = detector.detect_risk(parsed_materials[material_path], task["rule"])

            task["result"] = bool(result)
            task["reason"] = "素材内容一致" if result else "素材内容存在冲突"
            results.append(task)

            print(f"检测结果: {'合规' if result else '违规'}")

        except Exception as e:  # per-task boundary: record the failure, keep going
            print(f"处理素材包失败: {material_id}, 错误: {e}")
            task["result"] = False
            task["reason"] = f"处理失败: {str(e)}"
            results.append(task)

    output_path = os.path.join(os.path.dirname(data_jsonl_path), "enhanced_检测结果.jsonl")
    with open(output_path, 'w', encoding='utf-8') as f_out:
        for res in results:
            output_data = {
                "material_id": res["material_id"],
                "rule_id": res.get("rule_id"),
                "rule": res["rule"],
                "result": res["result"],
                "reason": res.get("reason", "")
            }
            f_out.write(json.dumps(output_data, ensure_ascii=False) + "\n")

    print(f"\n检测完成，结果已保存至: {output_path}")

if __name__ == "__main__":
    import traceback

    # NOTE(review): machine-specific paths; adjust for the local environment.
    materials_dir = r"C:\Users\27782\Desktop\测试 A 集\materials"
    data_jsonl = r"C:\Users\27782\Desktop\测试 A 集\data.jsonl"
    model_path = r"C:\Users\27782\Desktop\BERT\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"

    print("开始加载语义模型...")
    try:
        enhanced_process_all_materials(materials_dir, data_jsonl, model_path)
    except Exception as e:
        print(f"程序执行失败: {e}")
        # The one-line message alone hides where the failure happened.
        traceback.print_exc()