In [22]:
import os
import random
import string

import jieba
import numpy as np
import pandas as pd
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
# Directory holding one .txt file per novel. Set the NOVEL_DIR environment
# variable to point elsewhere instead of editing this hard-coded default.
base_path = os.environ.get("NOVEL_DIR", r"D:\Time_base\NLPwork\work2")

# All .txt files in the directory. Sorted because os.listdir order is arbitrary,
# and the file order decides label order and which files get the remainder samples.
novel_files = sorted(
    os.path.join(base_path, fname)
    for fname in os.listdir(base_path)
    if fname.endswith('.txt')
)

# Experiment parameters
K_values = [20, 100, 500, 1000, 3000]  # paragraph length in tokens
T_values = [8, 24, 48]                 # number of LDA topics
units = ['word', 'char']               # tokenization unit

In [None]:
def clean_and_tokenize(input_text, segmentation_mode='word'):
    """Strip download-site boilerplate from a novel and split it into tokens.

    Args:
        input_text: Raw text of a novel.
        segmentation_mode: 'word' segments with ``jieba.lcut``; 'char' keeps only
            individual CJK Unified Ideographs (U+4E00..U+9FFF).

    Returns:
        List of token strings with punctuation and whitespace removed.

    Raises:
        ValueError: If ``segmentation_mode`` is neither 'word' nor 'char'.
    """
    # ASCII punctuation + common full-width/Chinese punctuation + whitespace
    # (including the ideographic space U+3000). ASCII alone let tokens such as
    # '，' / '。' / '“' through word segmentation, where they counted toward K.
    punctuation_chars = set(
        string.punctuation
        + '（）【】…—\n\r\t '
        + '，。！？；：、“”‘’《》〈〉「」『』〔〕～·\u3000'
    )
    # Watermark / ad fragments removed before tokenizing. Note that '.' strips
    # every ASCII full stop, and 'com' strips that substring anywhere.
    unwanted_phrases = ['www', 'cr173', 'com', '下载站', '电子书', '免费', 'txt', '.']

    processed_text = input_text
    for phrase in unwanted_phrases:
        processed_text = processed_text.replace(phrase, '')

    if segmentation_mode == 'word':
        # jieba can emit multi-character punctuation/whitespace runs ('……',
        # '\r\n', '  '), so drop any token made up entirely of such characters.
        return [
            w for w in jieba.lcut(processed_text)
            if w.strip() and not all(ch in punctuation_chars for ch in w)
        ]
    if segmentation_mode == 'char':
        # The CJK range check already excludes all punctuation and whitespace.
        return [c for c in processed_text if '\u4e00' <= c <= '\u9fff']
    raise ValueError(
        f"segmentation_mode must be 'word' or 'char', got {segmentation_mode!r}"
    )

In [26]:
def _read_text_file(filename, encodings=('utf-8', 'gb18030')):
    """Return the file's text decoded with the first encoding that works, or None."""
    last_error = None
    for encoding in encodings:
        try:
            with open(filename, 'r', encoding=encoding) as file_obj:
                return file_obj.read()
        except UnicodeDecodeError as decode_error:
            # Wrong encoding guess: try the next one.
            last_error = decode_error
        except OSError as os_error:
            # Missing/unreadable file: another encoding will not help.
            last_error = os_error
            break
    print(f"Failed to read {filename}: {last_error}")
    return None


def create_samples_from_files(file_list, segment_length, text_unit='word', sample_count=1000):
    """Cut fixed-length token segments from each novel, labelled by file name.

    ``sample_count`` is split as evenly as possible across ``file_list``; the
    first ``sample_count % len(file_list)`` files get one extra segment.
    Segments are taken at evenly spaced offsets. If a file yields too few
    distinct segments, random duplicates of its own segments are added.
    Unreadable files and files shorter than ``segment_length`` are skipped,
    so the result may hold fewer than ``sample_count`` samples.

    Args:
        file_list: Paths of the novel text files.
        segment_length: Number of tokens per segment (K).
        text_unit: Tokenization unit passed to ``clean_and_tokenize``.
        sample_count: Total number of segments requested.

    Returns:
        ``(sample_data, label_list)``: space-joined token segments and the
        file's base name (without extension) for each segment.
    """
    sample_data = []
    label_list = []
    num_files = len(file_list)
    if num_files == 0:
        return sample_data, label_list
    base_sample_per_file, remainder_samples = divmod(sample_count, num_files)

    for idx, filename in enumerate(file_list):
        required_samples = base_sample_per_file + (1 if idx < remainder_samples else 0)
        if required_samples <= 0:
            # More files than requested samples: this file gets no quota, and
            # dividing by zero below would crash.
            continue

        file_content = _read_text_file(filename)
        if file_content is None:
            continue

        token_list = clean_and_tokenize(file_content, text_unit)
        if len(token_list) < segment_length:
            print(f"Insufficient tokens in {filename} (needed: {segment_length})")
            continue

        # Spread segment start offsets evenly over the usable range.
        last_start = len(token_list) - segment_length
        sampling_interval = max(1, last_start // required_samples)
        extracted_segments = []
        for pos in range(0, last_start + 1, sampling_interval):
            current_segment = ' '.join(token_list[pos:pos + segment_length])
            if current_segment.strip():
                extracted_segments.append(current_segment)
                if len(extracted_segments) >= required_samples:
                    break

        # Short texts: pad with random duplicates to reach this file's quota.
        while extracted_segments and len(extracted_segments) < required_samples:
            extracted_segments.append(random.choice(extracted_segments))

        label = os.path.splitext(os.path.basename(filename))[0]
        sample_data.extend(extracted_segments)
        label_list.extend([label] * len(extracted_segments))

    return sample_data, label_list

In [24]:

from sklearn.svm import SVC  # 添加这行代码
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Global Matplotlib font setup so the Chinese titles/axis labels render.
# font.sans-serif is an ordered preference list, so listing a Windows CJK font
# (SimHei) and a macOS one (Arial Unicode MS) works on either system without
# editing; DejaVu Sans (bundled with Matplotlib) is the last resort.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
# Render minus signs as an ASCII hyphen; CJK fonts such as SimHei may lack the
# Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
def evaluate_lda_rbf(X_topics, labels, n_splits=10, test_size=100):
    """Score an RBF-kernel SVM on LDA topic features over repeated random splits.

    Args:
        X_topics: Row-indexable feature matrix (here the document-topic matrix
            produced by LDA); must support integer-array row indexing.
        labels: One class label per row of ``X_topics``.
        n_splits: Number of train/test draws made by ``ShuffleSplit``.
        test_size: Passed through to ``ShuffleSplit`` as the test split size.

    Returns:
        ``(mean, std)`` of the per-split accuracies (numpy ``mean``/``std``).
    """
    y = np.array(labels)
    splitter = ShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=42)

    def _score_split(train_rows, test_rows):
        # A fresh classifier per split so nothing carries over between splits.
        model = SVC(kernel='rbf', random_state=42)
        model.fit(X_topics[train_rows], y[train_rows])
        return accuracy_score(y[test_rows], model.predict(X_topics[test_rows]))

    fold_scores = [_score_split(tr, te) for tr, te in splitter.split(X_topics)]
    return np.mean(fold_scores), np.std(fold_scores)

In [None]:
def _print_dataset_summary(dataset, labels, n_examples=5):
    """Print per-novel segment counts and up to n_examples (label, text) examples."""
    novel_count = {}
    for label in labels:
        novel_count[label] = novel_count.get(label, 0) + 1
    print("段落数量：", novel_count)

    print("样例段落及其标签：")
    # Slicing (not range(5)) so short or empty datasets do not raise IndexError.
    for label, text in zip(labels[:n_examples], dataset[:n_examples]):
        print(f"标签：{label}")
        print(f"内容：{text}")
        print("=" * 40)


def _plot_accuracy_curves(acc_data, units, T_values):
    """Draw one accuracy-vs-T figure per unit, with one dashed line per paragraph length K."""
    for current_unit in units:
        unit_results = acc_data.get(current_unit, {})
        if not unit_results:
            continue  # every K was skipped for this unit; an empty figure is noise
        fig, ax = plt.subplots(figsize=(9, 7))
        for K_val in sorted(unit_results):
            mean_acc_list = [unit_results[K_val].get(T_val, 0) for T_val in T_values]
            ax.plot(T_values, mean_acc_list, marker='D', label=f'K={K_val}', linestyle="--")

        ax.set_title(f"Acc unit='{current_unit}'")
        ax.set_xlabel("主题数量")
        ax.set_ylabel("分类准确率")
        ax.set_ylim([0, 1.0])
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure after display when looping over units


def run_experiment(novel_files, K_values, T_values, units, total=1000, n_splits=10, test_size=100):
    """Grid-search LDA topic count T and paragraph length K for novel classification.

    For each tokenization unit and each K, builds ``total`` labelled segments,
    vectorizes them with bag-of-words counts, fits LDA with each T, and scores
    an RBF-SVM on the topic features with ``evaluate_lda_rbf``. Any (unit, K)
    whose dataset has fewer than ``total`` segments is skipped with a warning.
    Finally plots accuracy vs. T for each unit.

    Args:
        novel_files: Paths of the novel .txt files (one class per file).
        K_values: Paragraph lengths in tokens.
        T_values: LDA topic counts.
        units: Tokenization units ('word' / 'char').
        total: Number of segments requested per (unit, K) dataset.
        n_splits: Number of random train/test splits per evaluation.
        test_size: Test split size forwarded to ``evaluate_lda_rbf``.

    Returns:
        ``(results, acc_data)``: a list of ``(unit, K, T, mean_acc, std_acc)``
        tuples, and a nested dict ``acc_data[unit][K][T] = mean_acc``.
    """
    results = []
    acc_data = {}

    for unit in units:
        for K in K_values:
            print(f"\nUNIT {unit},LENGTH K={K}")
            # The sampler's keyword is `sample_count`; passing `total=` raised TypeError.
            dataset, labels = create_samples_from_files(novel_files, K, unit, sample_count=total)

            _print_dataset_summary(dataset, labels)

            if not dataset or len(dataset) < total:
                print(f"警告：生成的样本数量不足，当前样本数为 {len(dataset)}")
                continue

            # Segments are space-joined tokens. For 'char', allow 1-character
            # tokens. NOTE(review): the default pattern used for 'word' (\b\w\w+\b)
            # drops single-character Chinese words; confirm that is intended.
            if unit == 'char':
                vectorizer = CountVectorizer(token_pattern=r'(?u)\b\w+\b')
            else:
                vectorizer = CountVectorizer()
            X = vectorizer.fit_transform(dataset)
            print(f"CountVectorizer 特征矩阵大小：{X.shape}")

            for T in T_values:
                # NOTE(review): LDA is fit on all segments, test splits included,
                # before cross-validation. It is unsupervised, but still transductive.
                lda = LatentDirichletAllocation(n_components=T, random_state=42)
                X_topics = lda.fit_transform(X)

                mean_acc, std_acc = evaluate_lda_rbf(X_topics, labels, n_splits=n_splits, test_size=test_size)
                results.append((unit, K, T, mean_acc, std_acc))
                acc_data.setdefault(unit, {}).setdefault(K, {})[T] = mean_acc

                print(f" K: {K}, T: {T} ACC {mean_acc:.4f}")

    _plot_accuracy_curves(acc_data, units, T_values)
    return results, acc_data

In [None]:
import string
# Run the full grid. The configuration (base_path, novel_files, K_values,
# T_values, units) is defined once in the first cell. Re-declaring it here
# silently overrode any change made there.
assert novel_files, f"no .txt files found in {base_path}"
results, acc_data = run_experiment(novel_files, K_values, T_values, units)