In [None]:
import os
import logging
import pandas as pd
import numpy as np
import json
import pickle
import signal
import time
import sys
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
from paddlenlp import Taskflow
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import seaborn as sns
import jieba
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# —— 配置区 —— 
CSV_PATH    = r"D:\CODE_WORLD\No.15CDR_workspace\data\texts\无人驾驶网约车评论数据.csv"
OUTPUT_DIR  = "D:\\CODE_WORLD\\No.15CDR_workspace\\results\\uie_analysis"
OUTPUT_FILE_PREFIX = "autonomous_vehicle_analysis"
DEVICE_ID   = 0      # 如果没有GPU，设为None 或删除 device_id
BATCH_SIZE  = 8
NUM_WORKERS = 2

# —— 断点续传配置 ——
CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
PROGRESS_FILE = os.path.join(CHECKPOINT_DIR, "progress.json")
# 可以通过键盘中断安全退出的标志
INTERRUPT_FLAG = False

# —— 我们扩展的"维度"（schema）如下 —— 
SCHEMA = [
    "主题关键词",     # 核心话题词，如"安全"、"定价"等
    "情感倾向",       # 正面/负面/中性
    "安全关注",       # 用户对行程安全的具体评论
    "舒适度",         # 座椅、车内体验等
    "价格评价",       # 价格合理性、优惠诉求
    "服务体验",       # 车辆清洁、接送速度、司机交互
    "技术可靠性",     # 自动驾驶的稳定性、故障抱怨
    "法规合规",       # 对地方政策、法规遵守的关注
    "隐私保护"        # 个人数据或车内监控的担忧
]

# 定义每个维度的特征词，用于增强提取能力
FEATURE_WORDS = {
    "安全关注": ["安全", "刹车", "紧急", "事故", "风险", "躲避", "危险", "撞车", "路况", "应急"],
    "舒适度": ["舒适", "座椅", "空间", "噪音", "平稳", "颠簸", "温度", "空调", "气味", "环境"],
    "价格评价": ["价格", "费用", "贵", "便宜", "经济", "优惠", "划算", "收费", "性价比", "成本"],
    "服务体验": ["服务", "态度", "等待", "准时", "接送", "卫生", "整洁", "礼貌", "响应", "交互"],
    "技术可靠性": ["技术", "故障", "系统", "稳定", "错误", "反应", "延迟", "精准", "导航", "识别"],
    "法规合规": ["法规", "合规", "政策", "许可", "牌照", "合法", "监管", "规定", "标准", "要求"],
    "隐私保护": ["隐私", "数据", "监控", "摄像头", "记录", "个人信息", "追踪", "保密", "共享", "安全"]
}

logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")

# 设置键盘中断信号处理函数
def handle_interrupt(signum, frame):
    global INTERRUPT_FLAG
    if not INTERRUPT_FLAG:
        logging.warning("\n收到中断信号，将在当前批次处理完成后安全退出...")
        logging.warning("再次按Ctrl+C将强制退出（不推荐）")
        INTERRUPT_FLAG = True
    else:
        logging.error("强制退出！可能会丢失部分数据")
        sys.exit(1)

# 注册信号处理函数
signal.signal(signal.SIGINT, handle_interrupt)
signal.signal(signal.SIGTERM, handle_interrupt)

class ProgressManager:
    """管理分析进度的类，用于实现断点续传"""
    
    # 定义处理阶段
    STAGES = [
        "load_data",           # 加载数据
        "uie_extraction",      # UIE信息抽取
        "comment_level_data",  # 评论级别数据处理
        "dimension_stats",     # 维度统计
        "topic_document",      # 主题-文档矩阵
        "sentiment_analysis",  # 情感分析
        "keyword_analysis",    # 关键词分析
        "visualization_data",  # 可视化数据准备
        "summary"              # 生成摘要
    ]
    
    def __init__(self, checkpoint_dir=CHECKPOINT_DIR, progress_file=PROGRESS_FILE):
        """初始化进度管理器"""
        self.checkpoint_dir = checkpoint_dir
        self.progress_file = progress_file
        self.current_progress = self._load_progress()
        
        # 确保checkpoint目录存在
        os.makedirs(self.checkpoint_dir, exist_ok=True)
    
    def _load_progress(self):
        """加载进度信息"""
        if os.path.exists(self.progress_file):
            try:
                with open(self.progress_file, "r", encoding="utf-8") as f:
                    progress = json.load(f)
                logging.info(f"从 {self.progress_file} 加载进度信息")
                return progress
            except Exception as e:
                logging.warning(f"加载进度信息失败: {e}")
        
        # 如果没有进度文件或加载失败，初始化一个新的进度
        return {
            "stages": {stage: False for stage in self.STAGES},
            "batch_progress": {
                "total_batches": 0,
                "completed_batches": 0
            },
            "last_run": None,
            "processed_ids": []
        }
    
    def save_progress(self):
        """保存当前进度"""
        self.current_progress["last_run"] = pd.Timestamp.now().isoformat()
        with open(self.progress_file, "w", encoding="utf-8") as f:
            json.dump(self.current_progress, f, ensure_ascii=False, indent=2)
    
    def is_stage_completed(self, stage):
        """检查指定阶段是否已完成"""
        return self.current_progress["stages"].get(stage, False)
    
    def mark_stage_complete(self, stage):
        """标记阶段为已完成"""
        if stage in self.STAGES:
            self.current_progress["stages"][stage] = True
            self.save_progress()
            logging.info(f"阶段 [{stage}] 已完成并保存进度")
    
    def update_batch_progress(self, total_batches, completed_batches):
        """更新批处理进度"""
        self.current_progress["batch_progress"] = {
            "total_batches": total_batches,
            "completed_batches": completed_batches
        }
        # 每10个批次保存一次进度，避免频繁IO
        if completed_batches % 10 == 0 or completed_batches == total_batches:
            self.save_progress()
    
    def add_processed_ids(self, ids):
        """添加已处理的评论ID"""
        self.current_progress["processed_ids"].extend(ids)
        
    def get_processed_ids(self):
        """获取已处理的评论ID列表"""
        return set(self.current_progress.get("processed_ids", []))
    
    def reset_progress(self, confirm=False):
        """重置进度（需确认）"""
        if confirm:
            self.current_progress = {
                "stages": {stage: False for stage in self.STAGES},
                "batch_progress": {
                    "total_batches": 0,
                    "completed_batches": 0
                },
                "last_run": None,
                "processed_ids": []
            }
            self.save_progress()
            logging.info("进度已重置")
        else:
            logging.warning("需要确认才能重置进度")

def init_worker(device_id):
    """在每个子进程中初始化 UIE 模型"""
    global uie
    uie = Taskflow(
        "information_extraction",
        schema=SCHEMA,
        model="uie-mini",  # 使用更大的模型提升提取效果
        device_id=device_id,
        static_mode=True,
        position_prob=0.4,  # 降低阈值，提高召回率
        use_fast=True,
        use_fp16_decoding=True
    )
    logging.info(f"Worker ready on device {device_id}")

def process_batch(batch_data):
    """
    批量处理一组评论：
    batch_data 是元组 (batch_ids, batch_texts)
    返回一个 dict: {
        "ids": [...], 
        "raw": {...},  # 原始UIE结果
        "enhanced": {...},  # 增强后结果
        "by_id": {}  # 按ID组织的详细结果
    }
    """
    batch_ids, batch = batch_data
    # 原始UIE提取
    results = uie(batch)
    
    # 整合结果
    merged = {k: [] for k in SCHEMA}
    result_by_id = {}
    
    for i, item in enumerate(results):
        comment_id = batch_ids[i]
        result_by_id[comment_id] = {
            "text": batch[i],  # 原始评论文本
            "uie_results": item,  # UIE提取结果
            "dimensions": {}  # 各维度的分析结果
        }
        
        for field, ents in item.items():
            merged[field].extend(e["text"] for e in ents)
            result_by_id[comment_id]["dimensions"][field] = [e["text"] for e in ents]
            
    # 增强提取 - 基于特征词和上下文
    enhanced = {k: [] for k in SCHEMA}
    for i, text in enumerate(batch):
        comment_id = batch_ids[i]
        # 对每个维度进行特征词增强
        result_by_id[comment_id]["enhanced_dimensions"] = {}
        
        for dimension, keywords in FEATURE_WORDS.items():
            result_by_id[comment_id]["enhanced_dimensions"][dimension] = []
            
            for keyword in keywords:
                if keyword in text:
                    # 找出关键词所在的上下文
                    start_pos = text.find(keyword)
                    # 提取关键词周围的上下文（前后15个字符）
                    context_start = max(0, start_pos - 15)
                    context_end = min(len(text), start_pos + len(keyword) + 15)
                    context = text[context_start:context_end]
                    enhanced[dimension].append(context)
                    
                    # 记录特征词的位置和上下文
                    result_by_id[comment_id]["enhanced_dimensions"][dimension].append({
                        "keyword": keyword,
                        "position": start_pos,
                        "context": context
                    })
    
    return {
        "ids": batch_ids,
        "raw": merged,
        "enhanced": enhanced,
        "by_id": result_by_id
    }

def generate_topic_document_matrix(merged_results, comments_df):
    """
    生成类似LDA的主题-文档矩阵
    """
    # 将评论ID与主题维度的出现频率关联
    topic_doc_matrix = {}
    
    # 对每个评论，计算各主题维度的相关度
    for comment_id, comment_data in merged_results["by_id"].items():
        topic_doc_matrix[comment_id] = {dimension: 0 for dimension in SCHEMA[2:]}  # 跳过主题关键词和情感倾向
        
        # 基于UIE提取结果计算相关度
        for dimension in SCHEMA[2:]:
            # 考虑原始UIE提取结果
            if dimension in comment_data.get("dimensions", {}) and comment_data["dimensions"][dimension]:
                topic_doc_matrix[comment_id][dimension] += len(comment_data["dimensions"][dimension]) * 1.5  # 原始UIE结果权重更高
            
            # 考虑增强维度
            if dimension in comment_data.get("enhanced_dimensions", {}) and comment_data["enhanced_dimensions"][dimension]:
                topic_doc_matrix[comment_id][dimension] += len(comment_data["enhanced_dimensions"][dimension])
    
    # 计算每个文档的主导主题
    dominant_topics = {}
    for comment_id, topics in topic_doc_matrix.items():
        if any(topics.values()):  # 确保至少有一个主题得分不为0
            dominant_topic = max(topics.items(), key=lambda x: x[1])
            dominant_topics[comment_id] = {
                "topic": dominant_topic[0],
                "score": dominant_topic[1]
            }
        else:
            dominant_topics[comment_id] = {"topic": "未分类", "score": 0}
            
    # 转换为DataFrame
    topic_doc_df = pd.DataFrame.from_dict(topic_doc_matrix, orient="index")
    
    # 合并原始评论文本
    topic_doc_df = pd.DataFrame.from_dict(topic_doc_matrix, orient="index")
    
    # 添加主导主题列
    topic_doc_df["dominant_topic"] = [dominant_topics[idx]["topic"] for idx in topic_doc_df.index]
    topic_doc_df["topic_score"] = [dominant_topics[idx]["score"] for idx in topic_doc_df.index]
    
    # 合并原始评论文本
    if "评论ID" in comments_df.columns:
        topic_doc_df = topic_doc_df.merge(
            comments_df[["评论ID", "评论"]], 
            left_index=True, 
            right_on="评论ID",
            how="left"
        )
    
    # 计算主题间的相似度矩阵
    topic_similarity = pd.DataFrame(
        cosine_similarity(topic_doc_df[SCHEMA[2:]].T),
        index=SCHEMA[2:],
        columns=SCHEMA[2:]
    )
    
    return topic_doc_df, topic_similarity, dominant_topics

def extract_sentiment_analysis(all_results):
    """提取情感分析结果并计算统计信息"""
    sentiment_data = {}
    sentiment_by_topic = {topic: {"正面": 0, "负面": 0, "中性": 0} for topic in SCHEMA[2:]}
    
    # 遍历每条评论，提取情感信息
    for comment_id, data in all_results["by_id"].items():
        # 获取情感倾向
        sentiments = data.get("dimensions", {}).get("情感倾向", [])
        dominant_sentiment = "中性"  # 默认中性
        
        # 判断主要情感倾向
        if sentiments:
            pos_count = sentiments.count("正面") + sentiments.count("积极") + sentiments.count("好评")
            neg_count = sentiments.count("负面") + sentiments.count("消极") + sentiments.count("差评")
            
            if pos_count > neg_count:
                dominant_sentiment = "正面"
            elif neg_count > pos_count:
                dominant_sentiment = "负面"
        
        # 记录该评论的主要情感
        sentiment_data[comment_id] = dominant_sentiment
        
        # 统计不同主题下的情感分布
        # 找到该评论的主导主题
        max_score = 0
        dominant_topic = None
        
        for topic in SCHEMA[2:]:
            topic_score = 0
            if topic in data.get("dimensions", {}):
                topic_score += len(data["dimensions"][topic]) * 1.5
            if topic in data.get("enhanced_dimensions", {}):
                topic_score += len(data["enhanced_dimensions"][topic])
                
            if topic_score > max_score:
                max_score = topic_score
                dominant_topic = topic
        
        # 只有当有主导主题时才统计
        if dominant_topic and max_score > 0:
            sentiment_by_topic[dominant_topic][dominant_sentiment] += 1
    
    return sentiment_data, sentiment_by_topic

def extract_keyword_co_occurrence(all_results):
    """提取关键词共现信息"""
    # 统计各维度关键词出现次数
    keyword_counts = {dimension: Counter() for dimension in SCHEMA[2:]}
    for comment_id, data in all_results["by_id"].items():
        for dimension in SCHEMA[2:]:
            if dimension in data.get("dimensions", {}):
                keywords = data["dimensions"][dimension]
                for kw in keywords:
                    keyword_counts[dimension][kw] += 1
            
            if dimension in data.get("enhanced_dimensions", {}):
                for item in data["enhanced_dimensions"][dimension]:
                    keyword_counts[dimension][item["keyword"]] += 1
    
    # 统计关键词之间的共现关系
    co_occurrence = {}
    
    for comment_id, data in all_results["by_id"].items():
        # 获取评论中出现的所有关键词
        all_keywords = []
        
        for dimension in SCHEMA[2:]:
            if dimension in data.get("dimensions", {}):
                all_keywords.extend([(kw, dimension) for kw in data["dimensions"][dimension]])
            
            if dimension in data.get("enhanced_dimensions", {}):
                all_keywords.extend([(item["keyword"], dimension) for item in data["enhanced_dimensions"][dimension]])
        
        # 统计共现
        for i, (kw1, dim1) in enumerate(all_keywords):
            for kw2, dim2 in all_keywords[i+1:]:
                if kw1 != kw2:  # 避免自己和自己共现
                    key = (kw1, dim1, kw2, dim2)
                    co_occurrence[key] = co_occurrence.get(key, 0) + 1
    
    # 转换为DataFrame格式
    co_occur_records = []
    for (kw1, dim1, kw2, dim2), count in co_occurrence.items():
        if count >= 2:  # 只保留共现次数大于等于2的
            co_occur_records.append({
                "keyword1": kw1,
                "dimension1": dim1,
                "keyword2": kw2,
                "dimension2": dim2,
                "co_occurrence_count": count
            })
    
    co_occur_df = pd.DataFrame(co_occur_records)
    return keyword_counts, co_occur_df

def save_checkpoint(name, data):
    """保存检查点数据"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{name}.pickle")
    with open(checkpoint_path, "wb") as f:
        pickle.dump(data, f)
    logging.info(f"检查点已保存: {checkpoint_path}")

def load_checkpoint(name):
    """加载检查点数据"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{name}.pickle")
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            data = pickle.load(f)
        logging.info(f"已加载检查点: {checkpoint_path}")
        return data
    return None

def main():
    # 创建进度管理器
    progress_mgr = ProgressManager()
    
    # 创建输出目录
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    
    # 1. 读取 CSV 并提取评论列
    df = None
    comments = []
    comment_ids = []
    
    if not progress_mgr.is_stage_completed("load_data"):
        logging.info("开始加载数据...")
        df = pd.read_csv(CSV_PATH, encoding="utf-8")
        if "评论ID" not in df.columns:
            df["评论ID"] = [f"comment_{i}" for i in range(len(df))]
        
        comments = df["评论"].dropna().astype(str).tolist()
        comment_ids = df["评论ID"].tolist()
        logging.info(f"已加载 {len(comments)} 条评论")
        
        # 保存原始数据的副本
        df.to_csv(os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_raw_data.csv"), index=False, encoding="utf-8")
        
        # 保存数据检查点
        save_checkpoint("raw_data", {"df": df, "comments": comments, "comment_ids": comment_ids})
        progress_mgr.mark_stage_complete("load_data")
    else:
        logging.info("加载数据阶段已完成，从检查点恢复...")
        checkpoint_data = load_checkpoint("raw_data")
        if checkpoint_data:
            df = checkpoint_data["df"]
            comments = checkpoint_data["comments"]
            comment_ids = checkpoint_data["comment_ids"]
            logging.info(f"已从检查点恢复 {len(comments)} 条评论数据")
        else:
            logging.error("无法加载数据检查点，重新开始加载数据...")
            # 重置进度并重新运行该阶段
            progress_mgr.current_progress["stages"]["load_data"] = False
            return main()

    # 2. 切分成多个批次并进行UIE抽取
    all_results = {
        "ids": [],
        "raw": {k: [] for k in SCHEMA},
        "enhanced": {k: [] for k in SCHEMA},
        "by_id": {}
    }
    
    if not progress_mgr.is_stage_completed("uie_extraction"):
        # 检查是否有缓存结果
        cache_path = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_uie_results.pickle")
        if os.path.exists(cache_path):
            logging.info(f"找到缓存结果，从 {cache_path} 加载")
            with open(cache_path, "rb") as f:
                all_results = pickle.load(f)
            logging.info(f"已加载 {len(all_results['by_id'])} 条评论的分析结果")
            
            # 检查是否所有评论都已处理
            processed_ids = set(all_results["by_id"].keys())
            all_ids = set(comment_ids)
            unprocessed_ids = all_ids - processed_ids
            
            # 如果有未处理的评论，继续处理
            if unprocessed_ids:
                logging.info(f"发现 {len(unprocessed_ids)} 条评论尚未处理，继续处理...")
                unprocessed_indices = [i for i, cid in enumerate(comment_ids) if cid in unprocessed_ids]
                unprocessed_comments = [comments[i] for i in unprocessed_indices]
                unprocessed_comment_ids = [comment_ids[i] for i in unprocessed_indices]
                
                # 切分未处理的部分
                batches = [
                    (unprocessed_comment_ids[i:i+BATCH_SIZE], unprocessed_comments[i:i+BATCH_SIZE])
                    for i in range(0, len(unprocessed_comments), BATCH_SIZE)
                ]
                
                # 继续处理未处理的批次
                with ProcessPoolExecutor(
                    max_workers=NUM_WORKERS,
                    initializer=init_worker,
                    initargs=(DEVICE_ID,)
                ) as executor:
                    # 使用tqdm显示进度
                    completed_batches = 0
                    total_batches = len(batches)
                    progress_mgr.update_batch_progress(total_batches, completed_batches)
                    
                    batch_iter = executor.map(process_batch, batches)
                    for result in tqdm(batch_iter, total=total_batches, desc="继续UIE提取"):
                        all_results["ids"].extend(result["ids"])
                        progress_mgr.add_processed_ids(result["ids"])
                        
                        for k in SCHEMA:
                            all_results["raw"][k].extend(result["raw"].get(k, []))
                            all_results["enhanced"][k].extend(result["enhanced"].get(k, []))
                        all_results["by_id"].update(result["by_id"])
                        
                        # 更新进度
                        completed_batches += 1
                        progress_mgr.update_batch_progress(total_batches, completed_batches)
                        
                        # 每处理10个批次保存一次中间结果
                        if completed_batches % 10 == 0 or completed_batches == total_batches:
                            # 保存UIE结果
                            with open(cache_path, "wb") as f:
                                pickle.dump(all_results, f)
                            logging.info(f"UIE分析中间结果已保存到缓存：{cache_path}")
                        
                        # 检查中断标志
                        if INTERRUPT_FLAG:
                            logging.warning("检测到中断请求，安全退出处理...")
                            # 保存当前进度
                            with open(cache_path, "wb") as f:
                                pickle.dump(all_results, f)
                            logging.info(f"当前进度已保存，下次可继续处理")
                            return
            
            # 所有评论处理完成，标记阶段完成
            progress_mgr.mark_stage_complete("uie_extraction")
        else:
            logging.info("未找到缓存结果，开始运行UIE分析")
            # 切分成多个批次
            batches = [
                (comment_ids[i:i+BATCH_SIZE], comments[i:i+BATCH_SIZE])
                for i in range(0, len(comments), BATCH_SIZE)
            ]

            # 并行抽取
            with ProcessPoolExecutor(
                max_workers=NUM_WORKERS,
                initializer=init_worker,
                initargs=(DEVICE_ID,)
            ) as executor:
                # 使用tqdm显示进度
                completed_batches = 0
                total_batches = len(batches)
                progress_mgr.update_batch_progress(total_batches, completed_batches)
                
                batch_iter = executor.map(process_batch, batches)
                for result in tqdm(batch_iter, total=total_batches, desc="UIE提取"):
                    all_results["ids"].extend(result["ids"])
                    progress_mgr.add_processed_ids(result["ids"])
                    
                    for k in SCHEMA:
                        all_results["raw"][k].extend(result["raw"].get(k, []))
                        all_results["enhanced"][k].extend(result["enhanced"].get(k, []))
                    all_results["by_id"].update(result["by_id"])
                    
                    # 更新进度
                    completed_batches += 1
                    progress_mgr.update_batch_progress(total_batches, completed_batches)
                    
                    # 每处理10个批次保存一次中间结果
                    if completed_batches % 10 == 0 or completed_batches == total_batches:
                        # 保存UIE结果
                        with open(cache_path, "wb") as f:
                            pickle.dump(all_results, f)
                        logging.info(f"UIE分析中间结果已保存到缓存：{cache_path}")
                    
                    # 检查中断标志
                    if INTERRUPT_FLAG:
                        logging.warning("检测到中断请求，安全退出处理...")
                        # 保存当前进度
                        with open(cache_path, "wb") as f:
                            pickle.dump(all_results, f)
                        logging.info(f"当前进度已保存，下次可继续处理")
                        return
                
                # UIE分析完成，保存结果并标记阶段完成
                with open(cache_path, "wb") as f:
                    pickle.dump(all_results, f)
                logging.info(f"UIE分析结果已保存到缓存：{cache_path}")
                progress_mgr.mark_stage_complete("uie_extraction")
    else:
        # 从缓存加载UIE结果
        cache_path = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_uie_results.pickle")
        if os.path.exists(cache_path):
            logging.info(f"加载UIE分析结果从缓存: {cache_path}")
            with open(cache_path, "rb") as f:
                all_results = pickle.load(f)
            logging.info(f"已加载 {len(all_results['by_id'])} 条评论的分析结果")
        else:
            logging.error("无法找到UIE分析结果缓存，但进度显示已完成，将重置进度")
            # 重置进度并重新运行该阶段
            progress_mgr.current_progress["stages"]["uie_extraction"] = False
            return main()

    # 4.1 提取详细的评论级别数据
    if not progress_mgr.is_stage_completed("comment_level_data"):
        logging.info("开始生成评论级别详细数据...")
        try:
            comment_level_data = []
            for comment_id, data in all_results["by_id"].items():
                row = {
                    "评论ID": comment_id,
                    "评论文本": data.get("text", ""),
                }
                
                # 添加各维度的提取结果
                for dimension in SCHEMA:
                    extracted = data.get("dimensions", {}).get(dimension, [])
                    row[f"{dimension}_提取结果"] = "|".join(extracted) if extracted else ""
                    
                    # 添加增强特征的信息
                    enhanced = data.get("enhanced_dimensions", {}).get(dimension, [])
                    row[f"{dimension}_增强特征"] = "|".join([item.get("keyword", "") for item in enhanced]) if enhanced else ""
                
                comment_level_data.append(row)
            
            # 保存评论级别的详细数据
            comment_level_df = pd.DataFrame(comment_level_data)
            comment_level_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_comment_level_data.csv")
            comment_level_df.to_csv(comment_level_file, index=False, encoding="utf-8")
            logging.info(f"评论级别数据已保存到: {comment_level_file}")
            
            # 保存检查点和标记阶段完成
            save_checkpoint("comment_level_data", comment_level_df)
            progress_mgr.mark_stage_complete("comment_level_data")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成评论级别数据时出错: {e}")
    else:
        logging.info("评论级别数据生成阶段已完成")

    # 4.2 生成维度频次统计
    if not progress_mgr.is_stage_completed("dimension_stats"):
        logging.info("开始生成维度频次统计...")
        try:
            dimension_stats = {}
            for dimension in SCHEMA:
                # 统计原始UIE提取
                raw_count = Counter(all_results["raw"].get(dimension, []))
                # 统计增强特征
                enhanced_count = Counter()
                for comment_id, data in all_results["by_id"].items():
                    for item in data.get("enhanced_dimensions", {}).get(dimension, []):
                        enhanced_count[item.get("keyword", "")] += 1
                
                # 合并两种统计
                combined_count = raw_count.copy()
                for k, v in enhanced_count.items():
                    combined_count[k] = combined_count.get(k, 0) + v
                
                dimension_stats[dimension] = {
                    "raw": dict(raw_count.most_common(20)),
                    "enhanced": dict(enhanced_count.most_common(20)),
                    "combined": dict(combined_count.most_common(20))
                }
            
            # 保存维度统计
            dimension_stats_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_dimension_stats.json")
            with open(dimension_stats_file, "w", encoding="utf-8") as f:
                json.dump(dimension_stats, f, ensure_ascii=False, indent=2)
            logging.info(f"维度统计数据已保存到: {dimension_stats_file}")
            
            # 保存检查点和标记阶段完成
            save_checkpoint("dimension_stats", dimension_stats)
            progress_mgr.mark_stage_complete("dimension_stats")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成维度统计时出错: {e}")
    else:
        logging.info("维度统计生成阶段已完成")

    # 4.3 生成类似于LDA的主题-文档矩阵
    topic_doc_df = None
    topic_similarity = None
    dominant_topics = None
    if not progress_mgr.is_stage_completed("topic_document"):
        logging.info("开始生成主题-文档矩阵...")
        try:
            # 生成类似于LDA的主题-文档矩阵
            topic_doc_df, topic_similarity, dominant_topics = generate_topic_document_matrix(all_results, df)
            
            # 保存主题-文档矩阵
            topic_doc_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_topic_document_matrix.csv")
            topic_doc_df.to_csv(topic_doc_file, index=True, encoding="utf-8")
            logging.info(f"主题-文档矩阵已保存到: {topic_doc_file}")
            
            # 保存主题相似度矩阵
            similarity_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_topic_similarity.csv")
            topic_similarity.to_csv(similarity_file, encoding="utf-8")
            logging.info(f"主题相似度矩阵已保存到: {similarity_file}")
            
            # 主题分布统计
            topic_distribution = Counter([info["topic"] for info in dominant_topics.values()])
            topic_distribution_df = pd.DataFrame({
                "主题": list(topic_distribution.keys()),
                "评论数量": list(topic_distribution.values())
            })
            distribution_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_topic_distribution.csv")
            topic_distribution_df.to_csv(distribution_file, index=False, encoding="utf-8")
            logging.info(f"主题分布统计已保存到: {distribution_file}")
            
            # 保存检查点和标记阶段完成
            save_checkpoint("topic_document", {
                "topic_doc_df": topic_doc_df,
                "topic_similarity": topic_similarity,
                "dominant_topics": dominant_topics,
                "topic_distribution": topic_distribution
            })
            progress_mgr.mark_stage_complete("topic_document")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成主题-文档矩阵时出错: {e}")
    else:
        logging.info("主题-文档矩阵生成阶段已完成")
        # 从检查点加载数据
        topic_doc_data = load_checkpoint("topic_document")
        if topic_doc_data:
            topic_doc_df = topic_doc_data.get("topic_doc_df")
            topic_similarity = topic_doc_data.get("topic_similarity")
            dominant_topics = topic_doc_data.get("dominant_topics")

    # 4.4 情感分析结果
    sentiment_data = None
    sentiment_by_topic = None
    if not progress_mgr.is_stage_completed("sentiment_analysis"):
        logging.info("开始生成情感分析结果...")
        try:
            # 情感分析结果
            sentiment_data, sentiment_by_topic = extract_sentiment_analysis(all_results)
            
            # 保存评论级别情感分析
            sentiment_df = pd.DataFrame({
                "评论ID": list(sentiment_data.keys()),
                "情感倾向": list(sentiment_data.values())
            })
            sentiment_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_sentiment_analysis.csv")
            sentiment_df.to_csv(sentiment_file, index=False, encoding="utf-8")
            logging.info(f"情感分析结果已保存到: {sentiment_file}")
            
            # 保存主题-情感矩阵
            sentiment_topic_df = pd.DataFrame(sentiment_by_topic).T
            sentiment_topic_df.index.name = "主题"
            sentiment_topic_df.reset_index(inplace=True)
            sentiment_topic_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_sentiment_by_topic.csv")
            sentiment_topic_df.to_csv(sentiment_topic_file, index=False, encoding="utf-8")
            logging.info(f"主题-情感矩阵已保存到: {sentiment_topic_file}")
            
            # 保存检查点和标记阶段完成
            save_checkpoint("sentiment_analysis", {
                "sentiment_data": sentiment_data,
                "sentiment_by_topic": sentiment_by_topic
            })
            progress_mgr.mark_stage_complete("sentiment_analysis")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成情感分析结果时出错: {e}")
    else:
        logging.info("情感分析结果生成阶段已完成")
        # 从检查点加载数据
        sentiment_data_obj = load_checkpoint("sentiment_analysis")
        if sentiment_data_obj:
            sentiment_data = sentiment_data_obj.get("sentiment_data")
            sentiment_by_topic = sentiment_data_obj.get("sentiment_by_topic")

    # 4.5 关键词共现分析
    keyword_counts = None
    co_occur_df = None
    if not progress_mgr.is_stage_completed("keyword_analysis"):
        logging.info("开始生成关键词分析结果...")
        try:
            # 关键词共现分析
            keyword_counts, co_occur_df = extract_keyword_co_occurrence(all_results)
            
            # 保存关键词频率
            for dimension, counts in keyword_counts.items():
                counts_df = pd.DataFrame({
                    "关键词": list(counts.keys()),
                    "频次": list(counts.values())
                })
                counts_df.sort_values("频次", ascending=False, inplace=True)
                keyword_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_{dimension}_keywords.csv")
                counts_df.to_csv(keyword_file, index=False, encoding="utf-8")
                logging.info(f"维度 '{dimension}' 关键词分析已保存到: {keyword_file}")
            
            # 保存关键词共现
            co_occur_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_keyword_co_occurrence.csv")
            co_occur_df.to_csv(co_occur_file, index=False, encoding="utf-8")
            logging.info(f"关键词共现分析已保存到: {co_occur_file}")
            
            # 保存检查点和标记阶段完成
            save_checkpoint("keyword_analysis", {
                "keyword_counts": keyword_counts,
                "co_occur_df": co_occur_df
            })
            progress_mgr.mark_stage_complete("keyword_analysis")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成关键词分析结果时出错: {e}")
    else:
        logging.info("关键词分析结果生成阶段已完成")
        # 从检查点加载数据
        keyword_data = load_checkpoint("keyword_analysis")
        if keyword_data:
            keyword_counts = keyword_data.get("keyword_counts")
            co_occur_df = keyword_data.get("co_occur_df")

    # 4.6 生成可视化基础数据
    if not progress_mgr.is_stage_completed("visualization_data"):
        logging.info("开始生成可视化基础数据...")
        try:
            # 确保我们有需要的所有数据
            if not all([dominant_topics, sentiment_data, sentiment_by_topic, keyword_counts, topic_similarity]):
                logging.warning("缺少生成可视化数据所需的某些组件，尝试从检查点恢复...")
                
                # 尝试从检查点加载缺失的数据
                topic_doc_data = load_checkpoint("topic_document")
                if topic_doc_data and not dominant_topics:
                    dominant_topics = topic_doc_data.get("dominant_topics")
                    topic_similarity = topic_doc_data.get("topic_similarity")
                
                sentiment_data_obj = load_checkpoint("sentiment_analysis")
                if sentiment_data_obj and not sentiment_data:
                    sentiment_data = sentiment_data_obj.get("sentiment_data")
                    sentiment_by_topic = sentiment_data_obj.get("sentiment_by_topic")
                
                keyword_data = load_checkpoint("keyword_analysis")
                if keyword_data and not keyword_counts:
                    keyword_counts = keyword_data.get("keyword_counts")
            
            # 检查是否已恢复所需的所有数据
            if not all([dominant_topics, sentiment_data, sentiment_by_topic, keyword_counts, topic_similarity]):
                logging.error("无法恢复生成可视化数据所需的所有组件，请重新运行前面的分析阶段")
                # 将前面的阶段标记为未完成
                progress_mgr.current_progress["stages"]["topic_document"] = False
                progress_mgr.current_progress["stages"]["sentiment_analysis"] = False
                progress_mgr.current_progress["stages"]["keyword_analysis"] = False
                progress_mgr.save_progress()
                return main()
            
            # 生成主题分布统计
            topic_distribution = Counter([info["topic"] for info in dominant_topics.values()])
            
            # 生成可视化基础数据
            visualization_data = {
                "topic_distribution": {
                    "labels": list(topic_distribution.keys()),
                    "values": list(topic_distribution.values())
                },
                "sentiment_overall": {
                    "labels": ["正面", "负面", "中性"],
                    "values": [
                        list(sentiment_data.values()).count("正面"),
                        list(sentiment_data.values()).count("负面"),
                        list(sentiment_data.values()).count("中性")
                    ]
                },
                "topic_sentiment": {
                    dim: {
                        "labels": ["正面", "负面", "中性"],
                        "values": [data["正面"], data["负面"], data["中性"]]
                    }
                    for dim, data in sentiment_by_topic.items()
                },
                "topic_similarity": topic_similarity.to_dict() if hasattr(topic_similarity, 'to_dict') else {},
                "top_keywords": {
                    dim: list(counts.most_common(10))
                    for dim, counts in keyword_counts.items()
                },
            }
            
            # 保存可视化数据
            viz_file = os.path.join(OUTPUT_DIR, f"{OUTPUT_FILE_PREFIX}_visualization_data.json")
            with open(viz_file, "w", encoding="utf-8") as f:
                json.dump(visualization_data, f, ensure_ascii=False, indent=2)
            logging.info(f"可视化基础数据已保存到: {viz_file}")
            
            # 保存检查点和标记阶段完成
            save_checkpoint("visualization_data", visualization_data)
            progress_mgr.mark_stage_complete("visualization_data")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成可视化基础数据时出错: {e}")
    else:
        logging.info("可视化基础数据生成阶段已完成")

    # 5. 创建数据摘要文件
    if not progress_mgr.is_stage_completed("summary"):
        logging.info("开始生成数据摘要文件...")
        try:
            summary_path = os.path.join(OUTPUT_DIR, "README.md")
            with open(summary_path, "w", encoding="utf-8") as f:
                f.write("# 无人驾驶网约车评论数据分析结果\n\n")
                f.write(f"## 分析时间: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
                f.write(f"## 数据来源: {CSV_PATH}\n")
                f.write(f"## 评论总数: {len(comments)}\n\n")
                
                f.write("## 文件说明\n\n")
                f.write("本目录包含以下数据文件：\n\n")
                f.write(f"1. `{OUTPUT_FILE_PREFIX}_raw_data.csv` - 原始评论数据\n")
                f.write(f"2. `{OUTPUT_FILE_PREFIX}_comment_level_data.csv` - 评论级别的详细分析结果\n")
                f.write(f"3. `{OUTPUT_FILE_PREFIX}_dimension_stats.json` - 各维度关键词统计\n")
                f.write(f"4. `{OUTPUT_FILE_PREFIX}_topic_document_matrix.csv` - 主题-文档矩阵（类LDA结果）\n")
                f.write(f"5. `{OUTPUT_FILE_PREFIX}_topic_similarity.csv` - 主题间相似度矩阵\n")
                f.write(f"6. `{OUTPUT_FILE_PREFIX}_topic_distribution.csv` - 主题分布统计\n")
                f.write(f"7. `{OUTPUT_FILE_PREFIX}_sentiment_analysis.csv` - 评论级别情感分析结果\n")
                f.write(f"8. `{OUTPUT_FILE_PREFIX}_sentiment_by_topic.csv` - 主题-情感分布矩阵\n")
                f.write(f"9. `{OUTPUT_FILE_PREFIX}_*_keywords.csv` - 各维度关键词频率\n")
                f.write(f"10. `{OUTPUT_FILE_PREFIX}_keyword_co_occurrence.csv` - 关键词共现分析\n")
                f.write(f"11. `{OUTPUT_FILE_PREFIX}_visualization_data.json` - 用于可视化的整合数据\n\n")
                
                f.write("## 可视化建议\n\n")
                f.write("基于生成的数据，您可以创建以下可视化图表：\n\n")
                f.write("1. **主题分布图** - 使用饼图或柱状图展示 `topic_distribution.csv`\n")
                f.write("2. **情感分析** - 使用环形图展示整体情感分布，使用堆叠柱状图展示各主题下的情感分布\n")
                f.write("3. **主题相似度热力图** - 使用 `topic_similarity.csv` 创建热力图\n")
                f.write("4. **关键词词云** - 使用各维度的关键词频率数据创建词云\n")
                f.write("5. **关键词共现网络图** - 使用 `keyword_co_occurrence.csv` 创建网络关系图\n")
                f.write("6. **主题-评论矩阵可视化** - 使用 `topic_document_matrix.csv` 创建热力图或散点图\n\n")
                
                f.write("要进行可视化，推荐使用 Python 中的 Matplotlib, Seaborn, Plotly 或 Tableau, Power BI 等工具。\n")
            
            logging.info(f"数据摘要文件已生成: {summary_path}")
            
            # 标记阶段完成
            progress_mgr.mark_stage_complete("summary")
            
            # 检查中断标志
            if INTERRUPT_FLAG:
                logging.warning("检测到中断请求，安全退出处理...")
                return
        except Exception as e:
            logging.error(f"生成数据摘要文件时出错: {e}")
    else:
        logging.info("数据摘要生成阶段已完成")
    
    # 所有阶段都已完成
    logging.info("分析完成！所有结果已保存到: {}".format(OUTPUT_DIR))
    
    # 打印数据文件列表
    logging.info("\n生成的数据文件：")
    for file in os.listdir(OUTPUT_DIR):
        if file.startswith(OUTPUT_FILE_PREFIX):
            logging.info(f" - {file}")
            
    logging.info("\n断点续传检查点文件：")
    for file in os.listdir(CHECKPOINT_DIR):
        logging.info(f" - {file}")

def print_progress_info():
    """打印当前进度信息"""
    if os.path.exists(PROGRESS_FILE):
        with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
            progress = json.load(f)
            
        print("\n当前分析进度:")
        print("="*50)
        
        # 打印阶段完成情况
        for stage, completed in progress["stages"].items():
            status = "✓ 已完成" if completed else "✗ 未完成"
            print(f"{stage:<20}: {status}")
        
        # 打印批处理进度
        if "batch_progress" in progress:
            batch_info = progress["batch_progress"]
            total = batch_info.get("total_batches", 0)
            completed = batch_info.get("completed_batches", 0)
            if total > 0:
                percent = (completed / total) * 100
                print(f"\n批处理进度: {completed}/{total} ({percent:.1f}%)")
        
        # 打印最后运行时间
        if progress.get("last_run"):
            print(f"\n上次运行时间: {progress['last_run']}")
        
        print("="*50)
    else:
        print("\n未找到进度文件，尚未开始分析或进度文件丢失")

if __name__ == "__main__":
    # 如果指定了查看进度参数
    if len(sys.argv) > 1 and sys.argv[1] == "--progress":
        print_progress_info()
    # 如果指定了重置进度参数
    elif len(sys.argv) > 1 and sys.argv[1] == "--reset":
        confirm = input("确定要重置所有分析进度吗？这将导致重新开始所有分析。(y/n): ")
        if confirm.lower() == 'y':
            pm = ProgressManager()
            pm.reset_progress(confirm=True)
            print("分析进度已重置！")
        else:
            print("已取消重置操作")
    else:
        # 正常运行分析
        try:
            main()
            if INTERRUPT_FLAG:
                print("\n分析已暂停，可以通过运行同一脚本继续处理")
                print("查看当前进度: python uie_judge.py --progress")
            else:
                print("\n所有分析已成功完成！")
        except KeyboardInterrupt:
            print("\n程序被中断，下次可继续运行")
        except Exception as e:
            logging.error(f"发生错误: {str(e)}")
            import traceback
            traceback.print_exc()


平台无gpu，使用本地电脑运行  

![Image Name](https://cdn.kesci.com/upload/svqmhm2wnz.jpg?imageView2/0/w/960/h/960)  


In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# ========== 嵌入数据 ==========
data = {
    '关键词': ['便宜', '价格', '优惠', '贵', '科技',
             '安全', '事故', '危险', '路况', '刹车',
             '技术', '系统', '稳定', '故障', '导航',
             '服务', '礼貌', '态度', '接送', '卫生',
             '牌照', '要求', '法规', '合法', '规定',
             '空调', '空间', '环境', '舒适', '温度',
             '安全', '共享', '数据', '摄像头', '监控'],
    '频率': [1844, 800, 779, 698, 385,
           1757, 998, 478, 90, 81,
           968, 330, 171, 155, 89,
           725, 143, 117, 44, 42,
           260, 135, 107, 93, 58,
           159, 115, 80, 31, 27,
           1757, 330, 290, 119, 100],
    '维度': ['价格评价'] * 5 + ['安全关注'] * 5 + ['技术可靠性'] * 5 +
           ['服务体验'] * 5 + ['法规合规'] * 5 + ['舒适度'] * 5 + ['隐私保护'] * 5
}

df = pd.DataFrame(data)

# ========== 可视化 ==========
plt.figure(figsize=(12, 8))
sns.set(style="whitegrid")

# 用不同颜色区分维度
palette = sns.color_palette("Set2", n_colors=df['维度'].nunique())

sns.barplot(data=df, x='频率', y='关键词', hue='维度', dodge=False, palette=palette)

plt.title("各维度下高频关键词分布（Top 5）", fontsize=16)
plt.xlabel("关键词频率", fontsize=13)
plt.ylabel("关键词", fontsize=13)
plt.legend(title="关键词维度", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0.0, flags=flags)
 

In [5]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties

# —— 1. 嵌入数据 ——
data = {
    '关键词': ['便宜','价格','优惠','贵','科技',
             '安全','事故','危险','路况','刹车',
             '技术','系统','稳定','故障','导航',
             '服务','礼貌','态度','接送','卫生',
             '牌照','要求','法规','合法','规定',
             '空调','空间','环境','舒适','温度',
             '安全','共享','数据','摄像头','监控'],
    '频率': [1844,800,779,698,385,
           1757,998,478,90,81,
           968,330,171,155,89,
           725,143,117,44,42,
           260,135,107,93,58,
           159,115,80,31,27,
           1757,330,290,119,100],
    '维度': ['价格评价']*5 + ['安全关注']*5 + ['技术可靠性']*5 +
           ['服务体验']*5 + ['法规合规']*5 + ['舒适度']*5 + ['隐私保护']*5
}
df = pd.DataFrame(data)

# —— 2. 字体设置 ——
font_path = "/home/mw/input/ch491739173/Ubuntu_18.04_SimHei.ttf"
font_prop = FontProperties(fname=font_path)

plt.rcParams['font.sans-serif'] = [font_prop.get_name()]
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size']       = 14
plt.rcParams['axes.titlesize']  = 20
plt.rcParams['axes.labelsize']  = 16
plt.rcParams['legend.fontsize'] = 14

# —— 3. 绘图 —— 
plt.figure(figsize=(14,10))
sns.set(style="whitegrid")
palette = sns.color_palette("Set2", df['维度'].nunique())

sns.barplot(
    data=df, x='频率', y='关键词',
    hue='维度', dodge=False, palette=palette
)

plt.title("各维度下高频关键词分布（Top 5）", fontproperties=font_prop)
plt.xlabel("关键词频率", fontproperties=font_prop)
plt.ylabel("关键词", fontproperties=font_prop)

# 强制坐标刻度用中文字体
plt.xticks(fontproperties=font_prop, fontsize=14)
plt.yticks(fontproperties=font_prop, fontsize=14)
# 强制图例用中文字体
plt.legend(prop=font_prop, title="维度", title_fontsize=16, loc='upper right')

plt.tight_layout()
plt.show()


  font.set_text(s, 0.0, flags=flags)
  font.set_text(s, 0, flags=flags)


In [15]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties

# ———— 1. 数据嵌入 ————
data = {
    '关键词': [
        '便宜','价格','优惠','贵','科技',
        '安全','事故','危险','路况','刹车',
        '技术','系统','稳定','故障','导航',
        '服务','礼貌','态度','接送','卫生',
        '牌照','要求','法规','合法','规定',
        '空调','空间','环境','舒适','温度',
        '安全','共享','数据','摄像头','监控'
    ],
    '频率': [
        1844, 800, 779, 698, 385,
        1757, 998, 478,  90,  81,
         968, 330, 171, 155,  89,
         725, 143, 117,  44,  42,
         260, 135, 107,  93,  58,
         159, 115,  80,  31,  27,
        1757, 330, 290, 119, 100
    ],
    '维度': (
        ['价格评价'] * 5 +
        ['安全关注'] * 5 +
        ['技术可靠性'] * 5 +
        ['服务体验'] * 5 +
        ['法规合规'] * 5 +
        ['舒适度'] * 5 +
        ['隐私保护'] * 5
    )
}
df = pd.DataFrame(data)

# ———— 2. 中文字体 & 全局样式 ————
font_path = "/home/mw/input/ch491739173/Ubuntu_18.04_SimHei.ttf"
font_prop = FontProperties(fname=font_path)

plt.rcParams['font.sans-serif']    = [font_prop.get_name()]  # 指定中文字体
plt.rcParams['axes.unicode_minus'] = False                   # 解决负号显示问题
plt.rcParams.update({
    'axes.titlesize': 22,
    'axes.labelsize': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22
})

# ———— 3. FacetGrid 分面条形图 ————
sns.set(style="whitegrid")
g = sns.FacetGrid(
    df,
    col='维度',
    col_wrap=4,
    sharex=False,
    sharey=False,
    height=3.5,
    aspect=1.2
)
g.map_dataframe(sns.barplot, x='频率', y='关键词', palette='viridis', orient='h')

# ———— 4. 子图调整：数值标签 + 轴标签 + 强制 y 轴中文 ————
max_freq = df['频率'].max()
for ax in g.axes.flatten():
    # 为每个条形添加数值标签
    for p in ax.patches:
        w = p.get_width()
        ax.text(
            w + max_freq*0.01,
            p.get_y() + p.get_height()/2,
            f"{int(w)}",
            va='center',
            fontproperties=font_prop,
            fontsize=12
        )
    # x 轴标签
    ax.set_xlabel('频率', fontproperties=font_prop, fontsize=16)
    # y 轴刻度标签手动指定中文字体
    y_labels = [t.get_text() for t in ax.get_yticklabels()]
    ax.set_yticklabels(y_labels, fontproperties=font_prop, fontsize=14)

# ———— 5. 子图标题 & 布局 ————
g.set_titles(col_template='{col_name}', fontproperties=font_prop, size=18)
plt.tight_layout(h_pad=2, w_pad=1)
plt.show()


In [5]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import font_manager

# 加载中文字体
font_path = "/home/mw/input/ch491739173/Ubuntu_18.04_SimHei.ttf"
chinese_font = font_manager.FontProperties(fname=font_path, size=14)

# 创建数据
data = {
    '': ['安全关注', '舒适度', '价格评价', '服务体验', '技术可靠性', '法规合规', '隐私保护'],
    '安全关注': [0.9999999999999943, 0.036495967547236004, 0.07490051068157604, 0.05396010202700443, 0.07523861763866395, 0.034797975927453276, 0.5865885533261261],
    '舒适度': [0.036495967547236004, 0.9999999999999969, 0.0698092513502067, 0.055743124472060986, 0.019619050328084912, 0.011802025077377874, 0.028334600756232894],
    '价格评价': [0.07490051068157604, 0.0698092513502067, 0.9999999999999997, 0.10206529180241004, 0.10552236197584401, 0.03140387248548638, 0.06972175771755759],
    '服务体验': [0.05396010202700443, 0.055743124472060986, 0.10206529180241004, 1.0000000000000038, 0.035337001810151196, 0.016291080476171286, 0.033449352934405506],
    '技术可靠性': [0.07523861763866395, 0.019619050328084912, 0.10552236197584401, 0.035337001810151196, 1.0000000000000004, 0.020962386250275113, 0.06286354999654796],
    '法规合规': [0.034797975927453276, 0.011802025077377874, 0.03140387248548638, 0.016291080476171286, 0.020962386250275113, 1.0000000000000056, 0.018799235342177655],
    '隐私保护': [0.5865885533261261, 0.028334600756232894, 0.06972175771755759, 0.033449352934405506, 0.06286354999654796, 0.018799235342177655, 1.0000000000000016]
}

df = pd.DataFrame(data)
df.set_index('', inplace=True)

# 创建上三角掩码
mask = np.triu(np.ones_like(df, dtype=bool))

# 设置图形
plt.figure(figsize=(12, 10), dpi=300)

# 创建热力图
heatmap = sns.heatmap(
    df, 
    annot=True, 
    fmt=".2f", 
    cmap="YlGnBu",
    mask=mask, 
    vmin=0, 
    vmax=1,
    linewidths=0.8,
    linecolor='white',
    annot_kws={
        "size": 14,
        "color": "black",
        "fontproperties": chinese_font  # 为注释文本指定字体
    },
    cbar_kws={
        "shrink": 0.8,
        "label": "相似度"
    }
)

# 自定义标题
title = plt.title('自动驾驶网约车评论主题相似度矩阵', pad=25, fontsize=20, fontweight='bold')
title.set_fontproperties(font_manager.FontProperties(fname=font_path, size=20))

# 自定义坐标轴标签
for label in heatmap.get_xticklabels():
    label.set_fontproperties(chinese_font)
    label.set_size(14)
    
for label in heatmap.get_yticklabels():
    label.set_fontproperties(chinese_font)
    label.set_size(14)

# 自定义颜色条标签
cbar = heatmap.collections[0].colorbar
cbar.ax.yaxis.label.set_fontproperties(chinese_font)
cbar.ax.yaxis.label.set_size(14)
for label in cbar.ax.get_yticklabels():
    label.set_fontproperties(chinese_font)
    label.set_size(12)

# 添加边框
for _, spine in heatmap.spines.items():
    spine.set_visible(True)
    spine.set_linewidth(1.5)

# 调整布局
plt.tight_layout()

# 保存图形
plt.savefig("similarity_matrix_improved.jpg", dpi=300, bbox_inches='tight', quality=95)
plt.show()