In [None]:
import os
import pandas as pd
import jieba
import re
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer
import matplotlib.pyplot as plt
import numpy as np
import logging
from datetime import datetime
import plotly.graph_objects as go
import plotly.express as px
from collections import defaultdict
from sentence_transformers import SentenceTransformer

plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render Chinese labels in matplotlib

# Stopwords: project-specific custom terms, merged at runtime with the
# standard list read from STOPWORDS_PATH.
custom_stopwords = [
    "豆包", "DS", "下载", "文心", "百度", "一言", "内容", "东西", "偷笑", "感觉", "真的", "模型",
    "女套", "男套", "玫瑰", "捂脸", "哈哈", "排位", "正赛", "量子", "男鞋", "雷诺", "男裤", "永久",
    "哈", "哈哈哈", "哈哈哈哈"
]
STOPWORDS_PATH = "stopwords1893.txt"

# Root folder holding the "<platform>_combined.csv" inputs. It must end with a
# separator: on Windows os.path.join("E:", "x") yields the drive-relative path
# "E:x" (relative to the current directory of drive E), not "E:/x".
save_dir = "E:/"
output_dir = os.path.join(save_dir, "sankey_outputs")
os.makedirs(output_dir, exist_ok=True)

platforms = ['dy', 'xhs', 'tieba']
comment_column = 'content'  # column holding the comment text
time_column_map = {  # per-platform timestamp column
    'dy': 'time',
    'xhs': 'time',
    'tieba': 'time'
}
optimal_k_dict = {'dy': 4, 'xhs': 6, 'tieba': 6}  # topic count per platform (BERTopic nr_topics)
time_freq = 'M'  # pandas period alias for one evolution step (calendar month)
min_docs_per_period = 20  # periods with fewer documents are skipped
similarity_threshold = 0.05  # minimum topic similarity (1 - JSD) for a Sankey link

def load_chinese_stopwords(filepath, custom_list):
    """Return the union of ``custom_list`` and the stopwords listed in ``filepath``.

    The file is read as UTF-8 with one word per line; surrounding whitespace is
    stripped and blank lines are ignored. Loading is best-effort: if the file
    is missing or any other ``Exception`` occurs while reading, a warning is
    printed and the words collected so far (at least ``custom_list``) are
    returned.
    """
    words = set(custom_list)
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            words.update(entry for entry in map(str.strip, handle) if entry)
    except FileNotFoundError:
        print(f"警告: 标准停用词文件未找到: '{filepath}'。仅使用自定义停用词。")
    except Exception as e:
        print(f"加载停用词文件 '{filepath}' 时出错: {e}。仅使用自定义停用词。")
    return words

def preprocess_text(text, stop_words_set):
    """Tokenise one comment into Chinese words with stopwords removed.

    Non-string input (e.g. NaN cells from pandas) yields ``[]``. URLs are
    removed, every character outside U+4E00..U+9FA5 becomes a space (so
    digits, Latin text, punctuation and emoji disappear), whitespace runs are
    collapsed, and the rest is segmented with ``jieba.lcut``. Returns ``[]``
    when nothing Chinese remains.
    """
    if not isinstance(text, str):
        return []
    without_urls = re.sub(r"http[s]?://\S+", "", text)
    cjk_only = re.sub(r"[^\u4e00-\u9fa5]", " ", without_urls)
    cleaned = re.sub(r"\s+", " ", cjk_only).strip()
    if not cleaned:
        return []
    # jieba emits whitespace tokens for the spaces inserted above; drop those
    # together with empty tokens and stopwords.
    return [
        token
        for token in jieba.lcut(cleaned)
        if token and token.strip() and token not in stop_words_set
    ]


def parse_time(time_str):
    """Parse one raw timestamp cell into a datetime, or ``pd.NaT`` if unusable.

    Numbers (Python or NumPy scalars) are read as a Unix epoch: in seconds when
    the value lies between 1990-01-01 and now + 5 years, otherwise in
    milliseconds when the value lies in that same window scaled by 1000. Any
    other number gives ``pd.NaT``.

    Strings are stripped and tried against the explicit formats below, then
    handed to ``pd.to_datetime``; failures give ``pd.NaT``.

    NOTE(review): ``pd.to_datetime(..., unit=...)`` produces naive UTC times,
    while string inputs keep whatever zone they were written in. Confirm the
    sources agree before mixing them.
    """
    if pd.isna(time_str):
        return pd.NaT

    if isinstance(time_str, (int, float, np.integer, np.floating)):
        lower = 631152000  # 1990-01-01T00:00:00Z
        upper = datetime.now().timestamp() + 31536000 * 5  # roughly now + 5 years
        try:
            if lower < time_str < upper:
                return pd.to_datetime(time_str, unit='s')
            if lower * 1000 < time_str < upper * 1000:
                # 13-digit epoch values are milliseconds, not seconds.
                return pd.to_datetime(time_str, unit='ms')
        except (ValueError, OSError, OverflowError):
            return pd.NaT
        return pd.NaT

    if isinstance(time_str, str):
        time_str = time_str.strip()
        # Only formats that carry a year: year-less strings such as "03-15 12:30"
        # are left to pd.to_datetime below.
        formats = (
            "%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S", "%Y-%m-%d %H:%M",
            "%Y/%m/%d %H:%M", "%Y-%m-%d", "%Y/%m/%d", "%Y年%m月%d日 %H:%M",
            "%Y年%m月%d日",
        )
        for fmt in formats:
            try:
                return datetime.strptime(time_str, fmt)
            except ValueError:
                continue
        try:
            return pd.to_datetime(time_str)
        except (ValueError, TypeError, OverflowError):
            return pd.NaT

    return pd.NaT

# --- Jensen-Shannon divergence based topic similarity
def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(p || q) in nats.

    ``p`` and ``q`` are equal-length non-negative weight vectors, used as given
    (no re-normalisation). Entries with ``p_i <= 0`` contribute nothing,
    following the convention 0 * log(0 / q) = 0. Non-positive ``q`` entries
    are floored to 1e-10 so the result stays finite instead of ``inf``.
    """
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    epsilon = 1e-10
    q = np.where(q > 0, q, epsilon)
    support = p > 0
    return np.sum(p[support] * np.log(p[support] / q[support]))

def js_divergence(p, q):
    """Return the Jensen-Shannon divergence between two weight vectors, in bits.

    The inputs are copied and the shorter one is zero-padded to a common
    length. Each is then normalised to sum to 1; a vector whose sum is <= 1e-9
    is replaced by the uniform distribution. The result is (up to floating
    point) in [0, 1]: 0 for identical distributions and 1 for disjoint ones.
    Empty inputs give 0.0. The caller's arrays are never modified.
    """
    # np.array copies; np.asarray would alias a float64 ndarray argument, and
    # normalising in place would then overwrite the caller's data.
    p = np.array(p, dtype=np.float64)
    q = np.array(q, dtype=np.float64)
    max_len = max(len(p), len(q))
    if max_len == 0:
        return 0.0
    p = np.pad(p, (0, max_len - len(p)))
    q = np.pad(q, (0, max_len - len(q)))
    p_sum = np.sum(p)
    p = p / p_sum if p_sum > 1e-9 else np.ones_like(p) / max_len
    q_sum = np.sum(q)
    q = q / q_sum if q_sum > 1e-9 else np.ones_like(q) / max_len
    m = 0.5 * (p + q)
    jsd = 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))
    return jsd / np.log(2)  # nats -> bits

def get_topic_word_dist(topic_model, topic_id, combined_vocab):
    """Return topic ``topic_id``'s word weights as a vector aligned to ``combined_vocab``.

    ``topic_model.get_topic(topic_id)`` must return ``(word, weight)`` pairs,
    or a falsy value when the topic does not exist. Each of the topic's
    weights is divided by their total, so the topic's own words sum to 1.
    Every other vocabulary slot keeps a 1e-10 floor. A missing topic, or a
    topic whose weights do not sum to a positive value, yields the all-floor
    vector. Words absent from ``combined_vocab`` are ignored.
    """
    word_dist = np.full(len(combined_vocab), 1e-10)
    topic = topic_model.get_topic(topic_id)
    if not topic:
        return word_dist
    total_weight = sum(weight for _, weight in topic)
    if total_weight <= 0:
        return word_dist
    # First position of each word, for O(1) lookups (same result as list.index).
    position = {}
    for idx, word in enumerate(combined_vocab):
        position.setdefault(word, idx)
    for word, weight in topic:
        idx = position.get(word)
        if idx is not None:
            word_dist[idx] = weight / total_weight
    return word_dist

def calculate_topic_similarity_bertopic(model1, model2, k1, k2):
    """Return a (k1, k2) matrix of topic similarities ``1 - JSD`` (bits), clipped to [0, 1].

    Topic ids ``0..k1-1`` of ``model1`` are compared with ids ``0..k2-1`` of
    ``model2`` over the union of all their topic words. An id for which
    ``get_topic`` returns nothing gets similarity 0 across its whole row or
    column. This happens, for example, when BERTopic produced fewer topics
    than requested. Without that rule, two missing topics would both collapse
    to the same uniform placeholder vector and be reported as identical (1.0).
    """
    topics1 = [model1.get_topic(tid) or [] for tid in range(k1)]
    topics2 = [model2.get_topic(tid) or [] for tid in range(k2)]
    # Shared vocabulary: every word that appears in any compared topic.
    combined_vocab = sorted({word for topic in topics1 + topics2 for word, _ in topic})

    # Build each topic's distribution once; None marks a topic that does not exist.
    dists1 = [get_topic_word_dist(model1, i, combined_vocab) if topics1[i] else None
              for i in range(k1)]
    dists2 = [get_topic_word_dist(model2, j, combined_vocab) if topics2[j] else None
              for j in range(k2)]

    similarity_matrix_js = np.zeros((k1, k2))
    for i, dist1 in enumerate(dists1):
        if dist1 is None:
            continue
        for j, dist2 in enumerate(dists2):
            if dist2 is None:
                continue
            similarity = 1.0 - js_divergence(dist1, dist2)
            similarity_matrix_js[i, j] = max(0, min(1, similarity))
    return similarity_matrix_js

def plot_sankey(platform_name, sankey_data, all_nodes, period_labels, k_value, similarity_threshold, output_dir,
                save_svg=False):
    """Show a topic-evolution Sankey diagram for one platform and save it to ``output_dir``.

    Parameters:
        platform_name: platform key, used in the title and file names.
        sankey_data: DataFrame with ``source_idx``, ``target_idx`` and ``value``
            columns. The indices point into ``all_nodes``; ``value`` is the
            link similarity.
        all_nodes: node labels. Labels of the form "<period>_T<topic>" are
            coloured by topic index; any other label is drawn grey with "?".
        period_labels: currently unused; kept for interface compatibility.
        k_value: topic count shown in the title. Colours cycle when it exceeds
            the palette size.
        similarity_threshold: shown in the title and file names only.
        output_dir: directory receiving the saved figure(s).
        save_svg: also export a static SVG via ``fig.write_image``. This needs
            the optional kaleido package and was unreliable in this project's
            environment, so it is opt-in. An interactive HTML file (no extra
            dependency) is always written.
    """
    if sankey_data.empty:
        print(f"平台 {platform_name} 没有足够的数据绘制Sankey图。")
        return
    topic_color_palette = px.colors.qualitative.Vivid
    if k_value > len(topic_color_palette):
        print(f"警告: 平台 {platform_name} 的 K 值 ({k_value}) 大于预选颜色板大小。颜色将会循环使用。")

    node_colors = []
    node_display_labels = []
    node_hover_labels = []
    for node_label in all_nodes:
        try:
            parts = node_label.split('_T')
            period = parts[0]
            topic_index = int(parts[-1])
            color = topic_color_palette[topic_index % len(topic_color_palette)]
            node_colors.append(color)
            node_display_labels.append(str(topic_index))
            node_hover_labels.append(f"Period: {period}<br>Topic: {topic_index}")
        except (ValueError, IndexError, KeyError):
            node_colors.append('grey')
            node_display_labels.append('?')
            node_hover_labels.append(node_label)

    link = dict(
        source=sankey_data['source_idx'].tolist(),
        target=sankey_data['target_idx'].tolist(),
        value=sankey_data['value'].tolist(),
        color='rgba(180, 180, 180, 0.35)',
        hovertemplate='Similarity: %{value:.3f}<extra></extra>'
    )
    node = dict(
        pad=12,
        thickness=18,
        line=dict(color="black", width=0.6),
        label=node_display_labels,
        color=node_colors,
        customdata=node_hover_labels,
        hovertemplate='<b>%{customdata}</b><extra></extra>',
    )
    fig = go.Figure()
    fig.add_trace(go.Sankey(
        arrangement='snap',
        node=node,
        link=link,
    ))
    fig.update_layout(
        title=dict(
            text=f"<b>{platform_name.upper()} Platform: Topic Evolution (K={k_value}, Min Similarity={similarity_threshold})</b>",
            font=dict(size=16, family="Arial, sans-serif", color='black'),
            x=0.5,
            xanchor='center'
        ),
        font=dict(size=11, family="Arial, sans-serif"),
        margin=dict(l=30, r=30, t=70, b=30),
        plot_bgcolor='white',
        height=600
    )
    fig.show()

    base_name = f"{platform_name}_k{k_value}_sim{similarity_threshold}_bertopic_evolution"
    # HTML export is built into plotly, so the figure is always persisted
    # even when static-image export is unavailable.
    html_filepath = os.path.join(output_dir, base_name + ".html")
    try:
        fig.write_html(html_filepath)
        print(f"  已保存交互式 HTML: {html_filepath}")
    except OSError as e:
        print(f"\n[!] 保存 HTML 时出错: {e}\n")

    if not save_svg:
        return
    svg_filepath = os.path.join(output_dir, base_name + ".svg")
    try:
        fig.write_image(svg_filepath, format='svg', width=1200, height=600)
        print(f"  成功保存 SVG 图像: {svg_filepath}")
    except (ImportError, ValueError) as e:
        print(f"\n[!] 错误：无法保存 SVG 图像。")
        print(f"    请确保已安装 'kaleido' 包 (pip install -U kaleido)。")
        print(f"    具体错误: {e}\n")
    except Exception as e:
        print(f"\n[!] 保存 SVG 时发生未知错误: {e}\n")

if __name__ == "__main__":
    # Pipeline per platform: load CSV -> parse and filter timestamps ->
    # tokenise -> group into `time_freq` periods -> fit one BERTopic model per
    # period -> link topics of consecutive periods whose similarity reaches
    # `similarity_threshold` -> draw one Sankey diagram per platform.
    jieba.setLogLevel(logging.ERROR)
    print("正在加载停用词...")
    stop_words = load_chinese_stopwords(STOPWORDS_PATH, custom_stopwords)
    all_platforms_sankey_data = {}
    # Sentence-embedding model shared by every BERTopic fit. It is loaded on
    # first use rather than once per period, which repeated the same
    # expensive load for every period of every platform.
    embedding_model = None

    for platform in platforms:
        print(f"\n{'='*20} 开始处理平台: {platform} {'='*20}")
        file_path = os.path.join(save_dir, f"{platform}_combined.csv")
        time_col = time_column_map[platform]
        optimal_k = optimal_k_dict.get(platform)
        if optimal_k is None:
            print(f"错误: 平台 {platform} 未在 optimal_k_dict 中找到对应的 K 值。跳过此平台。")
            continue
        print(f"平台 {platform} 使用 K = {optimal_k}")

        # 1. Load data; drop empty comments and unparseable times; restrict the date range.
        try:
            print("正在加载数据文件:")
            df = pd.read_csv(file_path, encoding='utf-8', low_memory=False)
            print(f"成功加载数据，共 {len(df)} 条记录。")
            if comment_column not in df.columns or time_col not in df.columns:
                print(f"错误: 文件 '{file_path}' 缺少必需的列 ('{comment_column}' 或 '{time_col}')。")
                continue
            df.dropna(subset=[comment_column], inplace=True)
            print(f"移除空评论后剩余: {len(df)} 条记录。")
            if df.empty:
                continue
            print(f"正在解析时间列 '{time_col}'...")
            original_time_count = len(df)
            df['datetime'] = df[time_col].apply(parse_time)
            df.dropna(subset=['datetime'], inplace=True)
            valid_time_count = len(df)
            print(f"时间解析完成。有效时间记录: {valid_time_count} (移除了 {original_time_count - valid_time_count} 条无效时间)")
            if df.empty:
                print(f"平台 {platform} 没有有效的带时间戳的评论数据。")
                continue
            df.sort_values('datetime', inplace=True)
            df.reset_index(drop=True, inplace=True)
            # Study window: [start_date, end_date)
            start_date = pd.to_datetime("2024-02-01")
            end_date = pd.to_datetime("2025-04-01")
            df = df[(df['datetime'] >= start_date) & (df['datetime'] < end_date)].copy()
            if df.empty:
                print(f"平台 {platform} 在指定时间范围 ({start_date.date()} to {end_date.date()}) 内没有数据，跳过。")
                continue
            print(f"时间范围筛选 ({start_date.date()} to {end_date.date()}) 后剩余: {len(df)} 条记录。")
        except FileNotFoundError:
            print(f"错误: 文件未找到 '{file_path}'。")
            continue
        except pd.errors.EmptyDataError:
            print(f"错误: 文件 '{file_path}' 为空。")
            continue
        except Exception as e:
            print(f"加载或初步处理文件 '{file_path}' 时发生错误: {e}。")
            continue

        # 2. Tokenise every comment once; step 4 reuses these tokens.
        print("开始进行文本预处理...")
        df['processed_text'] = df[comment_column].apply(lambda x: preprocess_text(x, stop_words))
        # .copy() so the column assignment in step 3 targets an independent
        # frame, not a view (avoids pandas' SettingWithCopyWarning).
        df = df[df['processed_text'].apply(len) > 0].copy()
        print(f"文本预处理完成。有效文档数量: {len(df)}")
        if df.empty:
            print(f"平台 {platform} 在预处理后没有有效的文本数据。")
            continue

        # 3. Group into periods and keep those with enough documents.
        print(f"按时间频率 '{time_freq}' 对数据进行分组...")
        df['time_period'] = df['datetime'].dt.to_period(time_freq)
        grouped_data = df.groupby('time_period')
        valid_periods = {period: group for period, group in grouped_data if len(group) >= min_docs_per_period}
        print(f"共有 {len(grouped_data)} 个原始时间段，过滤后剩余 {len(valid_periods)} 个有效时间段 (文档数 >= {min_docs_per_period})。")
        if len(valid_periods) < 2:
            print(f"平台 {platform} 的有效时间段少于 2 个，无法进行演化分析。")
            continue

        # 4. Fit one BERTopic model per period.
        print(f"为每个有效时间段训练 BERTopic 模型 (K={optimal_k})...")
        period_models = {}
        sorted_periods = sorted(valid_periods.keys())
        for period in sorted_periods:
            period_df = valid_periods[period]
            print(f"\n====== 调试：时间段 {period} ======")
            print(f"{period} 原始评论前5条：")
            print(period_df[comment_column].head(5).tolist())
            print(f"{period} period 原始评论数量: {len(period_df)}")

            # Tokens from step 2: every list is non-empty and has no
            # whitespace-only tokens, so no further filtering is needed.
            seg_list = period_df['processed_text'].tolist()
            print(f"{period} period 分词后样例前5条：")
            print(seg_list[:5])

            # The vectorizer below splits on whitespace, so join tokens with spaces.
            valid_texts = [" ".join(words) for words in seg_list]
            print(f"{period} 分词拼接后非空条数: {len(valid_texts)}")
            print(f"{period} 分词拼接后样例前5条: {valid_texts[:5]}")
            if not valid_texts:
                print(f"  跳过时间段 {period}: 分词拼接后无有效文本（全被过滤或全是停用词）")
                continue
            try:
                if embedding_model is None:
                    embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
                vectorizer_model = CountVectorizer(tokenizer=lambda x: x.split(), token_pattern=None)
                topic_model = BERTopic(
                    embedding_model=embedding_model,
                    nr_topics=optimal_k,
                    vectorizer_model=vectorizer_model,
                    calculate_probabilities=True,
                    verbose=False,
                    min_topic_size=5
                )
                topics, probs = topic_model.fit_transform(valid_texts)
                # Number of topics actually produced, excluding topic id -1.
                actual_k = sum(1 for tid in topic_model.get_topic_info()['Topic'] if tid != -1)
                if topic_model.get_topics():
                    period_models[period] = {
                        'model': topic_model,
                        'topics': topics,
                        'probs': probs,
                        'actual_k': actual_k
                    }
                    print(f"    时间段 {period} 的BERTopic模型训练完成。主题数: {actual_k}")
                else:
                    print(f"    警告: {period} 没有学出有效主题。")
            except Exception as e:
                print(f"  处理时间段 {period} 时出错: {e}")

        # 5. Similarity between topics of consecutive trained periods.
        print("计算相邻时间段的主题相似度...")
        sankey_links = []
        valid_trained_periods = [p for p in sorted_periods if p in period_models]
        if len(valid_trained_periods) < 2:
            print(f"平台 {platform} 训练成功的模型不足 2 个时间段，无法计算演化。")
            continue
        for period1_key, period2_key in zip(valid_trained_periods, valid_trained_periods[1:]):
            print(f"  比较: {period1_key} -> {period2_key}")
            model1_data = period_models[period1_key]
            model2_data = period_models[period2_key]
            # Compare only the topics each model actually produced. It can be
            # fewer than optimal_k (e.g. 3 for K=4), and iterating over missing
            # topic ids would create phantom nodes and links.
            k1 = model1_data['actual_k']
            k2 = model2_data['actual_k']
            try:
                similarity_matrix = calculate_topic_similarity_bertopic(
                    model1_data['model'], model2_data['model'], k1, k2
                )
                for src_topic_idx in range(similarity_matrix.shape[0]):
                    for tgt_topic_idx in range(similarity_matrix.shape[1]):
                        similarity = similarity_matrix[src_topic_idx, tgt_topic_idx]
                        if similarity >= similarity_threshold:
                            sankey_links.append({
                                'source': f"{period1_key}_T{src_topic_idx}",
                                'target': f"{period2_key}_T{tgt_topic_idx}",
                                'value': similarity
                            })
            except Exception as e:
                print(f"    计算 {period1_key} 和 {period2_key} 之间相似度时出错: {e}")

        if not sankey_links:
            print(f"平台 {platform} 没有计算出高于阈值 {similarity_threshold} 的主题相似度链接。无法生成Sankey图。")
            continue

        # 6. Sankey data: nodes ordered by (period, topic index), links as node indices.
        sankey_df = pd.DataFrame(sankey_links)
        all_nodes = sorted(set(sankey_df['source'].tolist() + sankey_df['target'].tolist()),
                           key=lambda x: (str(x.split('_T')[0]), int(x.split('_T')[1])))
        period_labels_obj = sorted({pd.Period(node.split('_T')[0], freq=time_freq) for node in all_nodes})
        period_labels_str = [str(p) for p in period_labels_obj]
        node_map = {node_label: i for i, node_label in enumerate(all_nodes)}
        sankey_df['source_idx'] = sankey_df['source'].map(node_map)
        sankey_df['target_idx'] = sankey_df['target'].map(node_map)
        all_platforms_sankey_data[platform] = {
            'sankey_df': sankey_df,
            'all_nodes': all_nodes,
            'period_labels': period_labels_str,
            'k_value': optimal_k
        }
        print(f"平台 {platform} 的 Sankey 数据准备完成。共 {len(all_nodes)} 个节点，{len(sankey_df)} 个链接。")

    print(f"\n{'='*20} 开始绘制 Sankey 图 {'='*20}")
    if not all_platforms_sankey_data:
        print("没有可用于绘制 Sankey 图的数据。")
    else:
        for platform, data in all_platforms_sankey_data.items():
            print(f"\n正在绘制平台 {platform} 的 Sankey 图...")
            plot_sankey(platform,
                        data['sankey_df'],
                        data['all_nodes'],
                        data['period_labels'],
                        data['k_value'],
                        similarity_threshold,
                        output_dir)

    print("\n所有平台处理完毕。")

正在加载停用词...

平台 dy 使用 K = 4
正在加载数据文件:
成功加载数据，共 22431 条记录。
移除空评论后剩余: 22213 条记录。
正在解析时间列 'time'...
时间解析完成。有效时间记录: 22213 (移除了 0 条无效时间)
时间范围筛选 (2024-02-01 to 2025-04-01) 后剩余: 19800 条记录。
开始进行文本预处理...
文本预处理完成。有效文档数量: 17632
按时间频率 'M' 对数据进行分组...
共有 14 个原始时间段，过滤后剩余 14 个有效时间段 (文档数 >= 20)。
为每个有效时间段训练 BERTopic 模型 (K=4)...

2024-02 原始评论前5条：
['Ai不能和第五人格比[捂脸]', '华为自带的在手机里干嘛非得搞这个', '它能算出双色球开奖结果吗[呲牙][呲牙][呲牙]', '刚刚接到通知，手机被淘汰了[微笑]', '能算出今晚排例五的开奖结果吗？']
2024-02 period 原始评论数量: 370
2024-02 period 分词后样例前5条：
[['第五', '人格'], ['华为', '自带', '手机', '里', '干嘛', '搞'], ['能算出', '双色球', '开奖', '呲', '牙', '呲', '牙', '呲', '牙'], ['刚刚', '接到', '通知', '手机', '淘汰', '微笑'], ['能算出', '今晚', '排例', '开奖']]
2024-02 period 分词后全空文本数: 0
2024-02 分词拼接后非空条数: 370
2024-02 分词拼接后样例前5条: ['第五 人格', '华为 自带 手机 里 干嘛 搞', '能算出 双色球 开奖 呲 牙 呲 牙 呲 牙', '刚刚 接到 通知 手机 淘汰 微笑', '能算出 今晚 排例 开奖']
    时间段 2024-02 的BERTopic模型训练完成。主题数: 3

2024-03 原始评论前5条：
['[烟花][烟花][烟花][烟花][烟花]👿👿👿👿', '与其说是AI，不如说是人类网络现有资料整合工具[尬笑]', '三星A100是一款较早期的手机型号，通常不支持直接刷入DOS操作系统。如果你想在三星A100手机上运行中文DOS程序，可以尝试使

  成功保存 SVG 图像。

正在绘制平台 xhs 的 Sankey 图...


  成功保存 SVG 图像。

正在绘制平台 tieba 的 Sankey 图...


  成功保存 SVG 图像。

所有平台处理完毕。
