### 引入必要的库 ###

In [None]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import Counter
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import chardet
from tqdm import tqdm
from adjustText import adjust_text

### 配置参数，包含数据路径、模型参数等 ###

In [None]:
class Config:
    """Static settings for the pipeline: input file locations and training hyper-parameters."""

    data_folder = "D:\\大三下\\Word2Vec\\Novel_Database"  # folder scanned for novel .txt files
    martial_arts_file = "D:\\大三下\\Word2Vec\\stop\\金庸小说全武功.txt"  # martial-arts term list
    factions_file = "D:\\大三下\\Word2Vec\\stop\\金庸小说全门派.txt"  # faction / sect name list
    characters_file = "D:\\大三下\\Word2Vec\\stop\\金庸小说全人物.txt"  # character name list
    stopwords_file = "D:\\大三下\\Word2Vec\\stop\\stop_words.txt"  # stop-word list


    embedding_dim = 200  # dimensionality of each word vector
    window_size = 5  # context window radius (tokens on each side of the target)
    batch_size = 512  # kept small for CPU training
    num_epochs = 5  # number of passes over the training pairs
    min_count = 5  # minimum frequency for a token to enter the vocabulary
    neg_samples = 5  # negative samples drawn per (target, context) pair

### 检测文件编码 ###

In [None]:
def detect_encoding(file_path):
    """Guess the text encoding of a file using chardet.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        An encoding name that can be passed to ``open(..., encoding=...)``.
        ``'utf-8'`` is returned when the file cannot be read, when detection
        raises, or when chardet has no guess at all.
    """
    try:
        with open(file_path, 'rb') as f:
            result = chardet.detect(f.read())
    except Exception as e:
        print(f"Error detecting encoding for {file_path}: {e}")
        return 'utf-8'  # default to utf-8

    # chardet reports encoding=None for empty or undecidable input; passing
    # None on to open() would silently select the platform default encoding.
    encoding = result.get('encoding') or 'utf-8'

    # Text guessed as GB2312 frequently contains characters that only exist
    # in the larger GBK/GB18030 repertoire, which makes the strict gb2312
    # codec fail mid-file. gb18030 decodes valid GB2312 text the same way.
    if encoding.lower() == 'gb2312':
        return 'gb18030'
    return encoding

### 加载停用词 ###

In [None]:
def load_stopwords(file_path):
    """Load a stop-word list (one word per line) into a set.

    Encodings are tried in order: strict UTF-8 (BOM tolerated), then GBK,
    then whatever ``detect_encoding`` suggests.

    Args:
        file_path: Path of the stop-word file.

    Returns:
        Set of stripped, non-empty lines. An empty set is returned (after
        printing the error) if the file cannot be decoded with any candidate.
    """
    # UTF-8 goes first because it is strict: GBK-encoded Chinese text is
    # practically never valid UTF-8, whereas UTF-8 bytes often decode under
    # GBK without error and yield mojibake that never matches any token.
    for encoding in ('utf-8-sig', 'gbk'):
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                stripped = (line.strip() for line in f)
                return {word for word in stripped if word}
        except UnicodeDecodeError:
            continue

    # Neither common encoding worked; fall back to the detected encoding.
    encoding = detect_encoding(file_path)
    try:
        with open(file_path, 'r', encoding=encoding) as f:
            stripped = (line.strip() for line in f)
            return {word for word in stripped if word}
    except Exception as e:
        print(f"Error loading stopwords from {file_path}: {e}")
    return set()

### 加载特殊词汇（人物、武功、门派） ###

In [None]:
def load_special_words(file_path):
    """Load a special-term list (characters, martial arts, factions) into a set.

    Encodings are tried in order: strict UTF-8 (BOM tolerated), then GBK,
    then whatever ``detect_encoding`` suggests.

    Args:
        file_path: Path of a text file with one term per line.

    Returns:
        Set of stripped, non-empty lines. An empty set is returned (after
        printing the error) if the file cannot be decoded with any candidate.
    """
    # UTF-8 goes first because it is strict: GBK-encoded Chinese text is
    # practically never valid UTF-8, whereas UTF-8 bytes often decode under
    # GBK without error and yield mojibake terms that match nothing.
    for encoding in ('utf-8-sig', 'gbk'):
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                stripped = (line.strip() for line in f)
                return {term for term in stripped if term}
        except UnicodeDecodeError:
            continue

    # Neither common encoding worked; fall back to the detected encoding.
    encoding = detect_encoding(file_path)
    try:
        with open(file_path, 'r', encoding=encoding) as f:
            stripped = (line.strip() for line in f)
            return {term for term in stripped if term}
    except Exception as e:
        print(f"Error loading special words from {file_path}: {e}")
    return set()

### 读取小说并预处理 ###

In [None]:
def load_and_preprocess_novels(folder_path, characters, martial_arts, factions, stopwords):
    """Read every ``.txt`` file in a folder and turn it into one flat token list.

    Known special terms are padded with spaces so each becomes its own token;
    the text is then split into runs of word characters and stop words are
    removed.

    Args:
        folder_path: Directory containing the novel text files.
        characters: Set of character names.
        martial_arts: Set of martial-arts terms.
        factions: Set of faction names.
        stopwords: Collection of tokens to discard.

    Returns:
        List of tokens from all files, in sorted-filename order. Files that
        cannot be read are reported and skipped.
    """
    # Build the special-term pattern once; it does not depend on the file.
    # Empty strings are dropped and an empty term set skips the substitution
    # entirely: an empty alternation matches at every position and would
    # break the text into single characters.
    special_words = {term for term in (characters | martial_arts | factions) if term}
    special_pattern = None
    if special_words:
        # Longest terms first so a long name wins over a shorter prefix of it.
        alternation = '|'.join(map(re.escape, sorted(special_words, key=len, reverse=True)))
        special_pattern = re.compile(f'({alternation})')
    token_pattern = re.compile(r'[\w\u4e00-\u9fff]+')

    all_words = []
    # Sorted so the token stream (and the vocabulary indices derived from it)
    # does not depend on the filesystem's listing order.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.txt'):
            continue
        file_path = os.path.join(folder_path, filename)
        encoding = detect_encoding(file_path)
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                text = f.read()
            if special_pattern is not None:
                text = special_pattern.sub(r' \1 ', text)
            # NOTE(review): this splits on non-word characters only; no
            # Chinese word segmentation is applied to the remaining text.
            tokens = token_pattern.findall(text)
            all_words.extend(token for token in tokens if token not in stopwords)
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
    return all_words

### 构建词汇表（确保索引从0连续） ###

In [None]:
def build_vocab(all_words, min_count):
    """Build word<->index mappings for tokens seen at least ``min_count`` times.

    Indices are contiguous from 0 and follow the order in which the kept
    words first appear in ``all_words``.

    Args:
        all_words: Iterable of tokens.
        min_count: Minimum occurrence count for a token to be kept.

    Returns:
        Tuple ``(vocab, idx_to_word)``: word -> index and index -> word dicts.
    """
    frequencies = Counter(all_words)
    kept = [token for token, seen in frequencies.items() if seen >= min_count]

    vocab = dict(zip(kept, range(len(kept))))
    idx_to_word = dict(enumerate(kept))

    print(f"Vocab size: {len(vocab)} (min_count={min_count})")
    return vocab, idx_to_word

### 负采样 Word2Vec 模型 ###

In [None]:
class SkipGramNeg(nn.Module):
    """Skip-gram model trained with a negative-sampling objective.

    Holds two embedding tables of identical shape: ``input_embeddings`` for
    target words and ``output_embeddings`` for context / negative words.
    """

    def __init__(self, vocab_size, embedding_dim):
        """Create both embedding tables with small uniform initial weights.

        Args:
            vocab_size: Number of rows in each embedding table.
            embedding_dim: Length of each embedding vector.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.input_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.output_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Initialise weights in [-0.5/dim, 0.5/dim].
        bound = 0.5 / embedding_dim
        nn.init.uniform_(self.input_embeddings.weight, -bound, bound)
        nn.init.uniform_(self.output_embeddings.weight, -bound, bound)

    def forward(self, target, context, neg_samples):
        """Compute the mean negative-sampling loss for a batch.

        Args:
            target: Integer tensor of shape (batch,) with target word indices.
            context: Integer tensor of shape (batch,) with context word indices.
            neg_samples: Integer tensor of shape (batch, k) with negative indices.

        Returns:
            Scalar tensor: ``-(logsigmoid(pos) + mean_k logsigmoid(-neg))``
            averaged over the batch.

        Raises:
            ValueError: If any index is >= ``vocab_size``.
        """
        # Range check up front so the error names the vocabulary size.
        if (target >= self.vocab_size).any() or (context >= self.vocab_size).any() or (
                neg_samples >= self.vocab_size).any():
            raise ValueError(f"Index out of range (vocab_size={self.vocab_size})")

        # Positive pairs
        emb_target = self.input_embeddings(target)  # (batch, emb_dim)
        emb_context = self.output_embeddings(context)  # (batch, emb_dim)
        pos_score = torch.sum(emb_target * emb_context, dim=1)  # (batch,)
        pos_loss = F.logsigmoid(pos_score)

        # Negative pairs
        neg_emb = self.output_embeddings(neg_samples)  # (batch, k, emb_dim)
        # Squeeze only the trailing singleton axis. A bare squeeze() would also
        # drop the batch axis when batch == 1 (or the k axis when k == 1) and
        # make the mean over dim=1 below fail.
        neg_score = torch.bmm(neg_emb, emb_target.unsqueeze(2)).squeeze(2)  # (batch, k)
        neg_loss = F.logsigmoid(-neg_score).mean(dim=1)

        return -(pos_loss + neg_loss).mean()

### 准备训练数据（严格检查索引范围） ###

In [None]:
def prepare_training_data(all_words, vocab, window_size):
    """Generate (target, context) index pairs for skip-gram training.

    For every in-vocabulary token, each in-vocabulary token within
    ``window_size`` positions on either side yields one pair.
    Out-of-vocabulary tokens still occupy a position in the window; they are
    simply not emitted.

    Args:
        all_words: List of tokens (must support indexing).
        vocab: Mapping from word to integer index.
        window_size: Number of positions considered on each side of the target.

    Returns:
        ``np.int32`` array of shape (num_pairs, 2); column 0 holds the target
        index, column 1 the context index. The shape is (0, 2) when no pair
        is produced.
    """
    data = []
    vocab_size = len(vocab)
    total = len(all_words)
    for i, target_word in enumerate(tqdm(all_words, desc="Preparing data")):
        target_idx = vocab.get(target_word)
        if target_idx is None:
            continue
        start = max(0, i - window_size)
        end = min(total, i + window_size + 1)

        # Walk the window by position instead of building two list slices
        # and concatenating them for every token.
        for j in range(start, end):
            if j == i:
                continue
            context_idx = vocab.get(all_words[j])
            # The range test is a defensive double check on the index.
            if context_idx is not None and context_idx < vocab_size:
                data.append((target_idx, context_idx))

    print(f"Generated {len(data)} training pairs")
    # reshape keeps the two-column layout even when data is empty; without it
    # an empty list becomes a 1-D array and column indexing downstream breaks.
    return np.array(data, dtype=np.int32).reshape(-1, 2)

### 生成负样本 ###

In [None]:
def generate_neg_samples(context, vocab_size, neg_samples):
    """Draw negative word indices for each context word, uniformly at random.

    Every negative is drawn independently and uniformly from the vocabulary
    with the row's own context index excluded.

    Args:
        context: 1-D array-like of context indices, expected in
            ``[0, vocab_size)``.
        vocab_size: Size of the vocabulary.
        neg_samples: Number of negatives per context word.

    Returns:
        ``np.int32`` array of shape (len(context), neg_samples).

    Raises:
        ValueError: If negatives are requested but the vocabulary has fewer
            than two words, so no index other than the context word exists.
    """
    context = np.asarray(context).reshape(-1, 1)
    if context.shape[0] == 0 or neg_samples == 0:
        return np.empty((context.shape[0], neg_samples), dtype=np.int32)
    if vocab_size < 2:
        # A rejection loop would spin forever here.
        raise ValueError(
            f"Cannot draw negative samples from a vocabulary of size {vocab_size}")

    # Draw from the vocab_size - 1 admissible values, then shift every draw
    # that is >= the context index up by one. This maps uniformly onto
    # {0..vocab_size-1} minus the context index, without any retry loop.
    draws = np.random.randint(0, vocab_size - 1, size=(context.shape[0], neg_samples))
    draws = draws + (draws >= context)
    return draws.astype(np.int32)

### 训练函数 ###

In [None]:
def train_word2vec(model, train_data, vocab_size, batch_size, num_epochs, device, neg_samples=None):
    """Train the skip-gram model with Adam on (target, context) pairs.

    Args:
        model: Module called as ``model(targets, contexts, negatives)`` that
            returns a scalar loss.
        train_data: Integer array of shape (num_pairs, 2). It is shuffled
            IN PLACE at the start of every epoch.
        vocab_size: Vocabulary size, used when drawing negative samples.
        batch_size: Number of pairs per optimisation step.
        num_epochs: Number of passes over ``train_data``.
        device: Torch device on which the model and batches are placed.
        neg_samples: Negatives per pair; defaults to ``Config.neg_samples``.
    """
    if neg_samples is None:
        neg_samples = Config.neg_samples

    model.to(device)
    optimizer = optim.Adam(model.parameters())

    num_pairs = len(train_data)
    if num_pairs == 0:
        # Nothing to iterate over; also avoids dividing by zero below.
        print("No training pairs available; skipping training")
        return
    num_batches = (num_pairs + batch_size - 1) // batch_size

    for epoch in range(num_epochs):
        total_loss = 0.0
        np.random.shuffle(train_data)

        for i in tqdm(range(0, num_pairs, batch_size), desc=f"Epoch {epoch + 1}/{num_epochs}"):
            batch = train_data[i:i + batch_size]
            context_idx = batch[:, 1]
            targets = torch.tensor(batch[:, 0], dtype=torch.long).to(device)
            contexts = torch.tensor(context_idx, dtype=torch.long).to(device)
            # Negatives are sampled from the NumPy batch directly, so the
            # context tensor is never copied back from the device.
            negatives = torch.tensor(
                generate_neg_samples(context_idx, vocab_size, neg_samples),
                dtype=torch.long).to(device)

            optimizer.zero_grad()
            loss = model(targets, contexts, negatives)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Each loss is already a per-batch mean, so the epoch average divides
        # by the number of batches, not by the number of pairs.
        print(f"Epoch {epoch + 1}, Avg Loss: {total_loss / num_batches:.4f}")

### 可视化（自动调整标签密度） ###

In [None]:
def visualize_embeddings(embeddings, words, title, dim=2):
    """Project embeddings with PCA and draw a labelled 2-D or 3-D scatter plot.

    The figure is saved as ``word2vec_{dim}d.png`` in the working directory
    and then shown.

    Args:
        embeddings: Array of shape (n_words, emb_dim); a single 1-D vector is
            treated as one row.
        words: Labels, one per embedding row.
        title: Figure title.
        dim: Number of PCA components; must be 2 or 3.

    Raises:
        ValueError: If ``dim`` is neither 2 nor 3.
    """
    if dim not in (2, 3):
        raise ValueError(f"dim must be 2 or 3, got {dim}")

    embeddings = np.asarray(embeddings)
    # Make sure embeddings is a 2-D array.
    if embeddings.ndim == 1:
        embeddings = embeddings.reshape(1, -1)
    # PCA cannot extract more components than min(n_samples, n_features);
    # bail out with a message instead of failing inside the fit.
    if embeddings.ndim != 2 or min(embeddings.shape) < dim:
        print(f"Not enough embeddings to visualize in {dim}D (shape={embeddings.shape}); skipping")
        return

    # Use a font that can render Chinese labels.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign renderable

    reduced = PCA(n_components=dim).fit_transform(embeddings)

    fig = plt.figure(figsize=(15, 12))
    if dim == 2:
        plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=50)

        # Label every point, then let adjust_text push labels apart.
        texts = []
        for i, (x, y) in enumerate(zip(reduced[:, 0], reduced[:, 1])):
            texts.append(plt.text(x, y, words[i], fontsize=9,
                                  bbox=dict(facecolor='white', alpha=0.7, edgecolor='none')))
        adjust_text(texts, arrowprops=dict(arrowstyle='->', color='red', lw=0.5),
                    lim=50,  # fewer iterations: the label set is large and memory-hungry
                    force_text=0.1,  # weaker repulsion between labels
                    force_points=0.1)
    else:
        ax = plt.axes(projection='3d')
        ax.scatter3D(reduced[:, 0], reduced[:, 1], reduced[:, 2], alpha=0.6, s=50)

        # Label only every 10th point to keep the 3-D plot readable.
        for i in range(0, len(words), 10):
            ax.text(reduced[i, 0], reduced[i, 1], reduced[i, 2], words[i],
                    fontsize=9, bbox=dict(facecolor='white', alpha=0.5))

    plt.title(title, fontsize=14, pad=20)
    plt.tight_layout()
    plt.savefig(f"word2vec_{dim}d.png", dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Load stop words and the special term lists.
    stopwords = load_stopwords(Config.stopwords_file)
    characters = load_special_words(Config.characters_file)
    martial_arts = load_special_words(Config.martial_arts_file)
    factions = load_special_words(Config.factions_file)

    # Read and preprocess the novels.
    all_words = load_and_preprocess_novels(Config.data_folder, characters, martial_arts, factions, stopwords)

    # Build the vocabulary.
    vocab, idx_to_word = build_vocab(all_words, Config.min_count)

    # Prepare (target, context) training pairs.
    train_data = prepare_training_data(all_words, vocab, Config.window_size)

    # Initialise the model.
    vocab_size = len(vocab)
    model = SkipGramNeg(vocab_size, Config.embedding_dim)

    # Train the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_word2vec(model, train_data, vocab_size, Config.batch_size, Config.num_epochs, device)

    # Collect embeddings for character names. Filter by vocabulary BEFORE
    # applying the cap, and iterate in sorted order: slicing a set directly
    # picks an arbitrary, run-dependent subset, most of which may be absent
    # from the vocabulary.
    max_characters = 1000  # cap chosen because of limited memory when plotting
    all_character_names = [name for name in sorted(characters) if name in vocab][:max_characters]
    if all_character_names:
        row_indices = torch.tensor([vocab[name] for name in all_character_names], dtype=torch.long)
        all_character_embeddings = model.input_embeddings.weight.detach().cpu()[row_indices].numpy()
    else:
        all_character_embeddings = np.empty((0, Config.embedding_dim), dtype=np.float32)

    # Visualise the character embeddings in 2-D.
    visualize_embeddings(all_character_embeddings, all_character_names, "Character Embeddings (2D)", dim=2)

    # Visualise the character embeddings in 3-D.
    visualize_embeddings(all_character_embeddings, all_character_names, "Character Embeddings (3D)", dim=3)