In [2]:
import re
import os
import numpy as np
from collections import Counter
from sklearn.datasets import fetch_20newsgroups

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

In [3]:
newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
docs = newsgroups.data

In [4]:
print(newsgroups.keys())

dict_keys(['data', 'filenames', 'target_names', 'target', 'DESCR'])


In [5]:
# 顯示前 3 篇文章的內容
for i in range(3):
    print(f"文章 {i+1}:\n")
    print(docs[i])
    print("=" * 80)  # 分隔線

文章 1:



I am sure some bashers of Pens fans are pretty confused about the lack
of any kind of posts about the recent Pens massacre of the Devils. Actually,
I am  bit puzzled too and a bit relieved. However, I am going to put an end
to non-PIttsburghers' relief with a bit of praise for the Pens. Man, they
are killing those Devils worse than I thought. Jagr just showed you why
he is much better than his regular season stats. He is also a lot
fo fun to watch in the playoffs. Bowman should let JAgr have a lot of
fun in the next couple of games since the Pens are going to beat the pulp out of Jersey anyway. I was very disappointed not to see the Islanders lose the final
regular season game.          PENS RULE!!!


文章 2:

My brother is in the market for a high-performance video card that supports
VESA local bus with 1-2MB RAM.  Does anyone have suggestions/ideas on:

  - Diamond Stealth Pro Local Bus

  - Orchid Farenheit 1280

  - ATI Graphics Ultra Pro

  - Any other high-performance VLB 

In [6]:
# 顯示前 3 篇文章的類別索引
print("前 3 篇文章的類別索引:", newsgroups.target[:3])

# 顯示對應的類別名稱
print("對應的新興論壇分類:", [newsgroups.target_names[idx] for idx in newsgroups.target[:3]])

前 3 篇文章的類別索引: [10  3 17]
對應的新興論壇分類: ['rec.sport.hockey', 'comp.sys.ibm.pc.hardware', 'talk.politics.mideast']


In [7]:
print(newsgroups.target_names)

['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']


In [8]:
def preprocess(doc):
    doc = doc.lower()
    doc = re.sub(r"[^a-z\s]", "", doc)
    tokens = doc.split()
    return tokens

docs_tokens = [preprocess(doc) for doc in docs]

In [10]:
print(docs_tokens[0])

['i', 'am', 'sure', 'some', 'bashers', 'of', 'pens', 'fans', 'are', 'pretty', 'confused', 'about', 'the', 'lack', 'of', 'any', 'kind', 'of', 'posts', 'about', 'the', 'recent', 'pens', 'massacre', 'of', 'the', 'devils', 'actually', 'i', 'am', 'bit', 'puzzled', 'too', 'and', 'a', 'bit', 'relieved', 'however', 'i', 'am', 'going', 'to', 'put', 'an', 'end', 'to', 'nonpittsburghers', 'relief', 'with', 'a', 'bit', 'of', 'praise', 'for', 'the', 'pens', 'man', 'they', 'are', 'killing', 'those', 'devils', 'worse', 'than', 'i', 'thought', 'jagr', 'just', 'showed', 'you', 'why', 'he', 'is', 'much', 'better', 'than', 'his', 'regular', 'season', 'stats', 'he', 'is', 'also', 'a', 'lot', 'fo', 'fun', 'to', 'watch', 'in', 'the', 'playoffs', 'bowman', 'should', 'let', 'jagr', 'have', 'a', 'lot', 'of', 'fun', 'in', 'the', 'next', 'couple', 'of', 'games', 'since', 'the', 'pens', 'are', 'going', 'to', 'beat', 'the', 'pulp', 'out', 'of', 'jersey', 'anyway', 'i', 'was', 'very', 'disappointed', 'not', 'to', '

建立詞彙表

In [15]:
# 建立詞彙表
all_tokens = [token for doc in docs_tokens for token in doc]
vocab_counter = Counter(all_tokens)
vocab = list(vocab_counter.keys())
vocab_size = len(vocab)

# 建立 word2idx 和 idx2word 對應
word2idx = {word: i for i, word in enumerate(vocab)}
idx2word = {i: word for word, i in word2idx.items()}

print(f"詞彙表大小: {vocab_size}")

詞彙表大小: 121462


建構skip-gram訓練數據

In [16]:
import random

def generate_skipgram_pairs(docs_tokens, window_size=2, num_neg_samples=5):
    positive_pairs = []
    negative_pairs = []
    
    for tokens in docs_tokens:
        indices = [word2idx[word] for word in tokens if word in word2idx]
        
        for i, target in enumerate(indices):
            window_start = max(i - window_size, 0)
            window_end = min(i + window_size + 1, len(indices))
            
            # 取得 context 單詞
            context_words = indices[window_start:i] + indices[i+1:window_end]
            for ctx in context_words:
                positive_pairs.append((target, ctx))
                
                # 生成負樣本
                for _ in range(num_neg_samples):
                    neg_word = random.randint(0, vocab_size - 1)
                    negative_pairs.append((target, neg_word))
    
    return positive_pairs, negative_pairs

# 產生 skip-gram 訓練數據
positive_pairs, negative_pairs = generate_skipgram_pairs(docs_tokens, window_size=2, num_neg_samples=5)
print(f"正樣本數量: {len(positive_pairs)}, 負樣本數量: {len(negative_pairs)}")

正樣本數量: 12893026, 負樣本數量: 64465130


建立DataLoader

In [18]:
# 轉換成 PyTorch 張量
positive_pairs_tensor = torch.tensor(positive_pairs, dtype=torch.long)
negative_pairs_tensor = torch.tensor(negative_pairs, dtype=torch.long)

# 使用 TensorDataset 和 DataLoader
batch_size = 1024
# 正樣本 DataLoader
positive_dataset = TensorDataset(positive_pairs_tensor)
positive_data_loader = DataLoader(positive_dataset, batch_size=batch_size, shuffle=True)

# 負樣本 DataLoader
negative_dataset = TensorDataset(negative_pairs_tensor)
negative_data_loader = DataLoader(negative_dataset, batch_size=batch_size, shuffle=True)

定義 Word2Vec Skip-Gram模型

In [19]:
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
    
    def forward(self, center_word, context_word):
        v_pvt = self.embedding(center_word)  # 主要單詞向量
        v_ctx = self.embedding(context_word)  # 背景單詞向量
        return torch.sum(v_pvt * v_ctx, dim=1)  # 內積計算相似度

定義損失函數

In [20]:
def negative_sampling_loss(pos_scores, neg_scores):
    """
    pos_scores: 正樣本的點積結果
    neg_scores: 負樣本的點積結果
    """
    pos_loss = -F.logsigmoid(pos_scores).mean()  # 第一項
    neg_loss = -F.logsigmoid(-neg_scores).mean()  # 第二項
    return pos_loss + neg_loss

訓練模型

In [21]:
# 設定超參數
embedding_dim = 100
learning_rate = 0.01
num_epochs = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkipGramModel(vocab_size, embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 訓練迴圈
for epoch in range(num_epochs):
    total_loss = 0.0

    # 使用 zip() 同時迭代兩個 DataLoader
    for (pos_batch,), (neg_batch,) in zip(positive_data_loader, negative_data_loader):
        pos_batch, neg_batch = pos_batch.to(device), neg_batch.to(device)
        center_word_pos, context_word_pos = pos_batch[:, 0], pos_batch[:, 1]
        center_word_neg, context_word_neg = neg_batch[:, 0], neg_batch[:, 1]

        # 正樣本分數
        pos_scores = model(center_word_pos, context_word_pos)

        # 負樣本分數
        neg_scores = model(center_word_neg, context_word_neg)

        # 計算損失
        loss = negative_sampling_loss(pos_scores, neg_scores)

        # 反向傳播與優化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

print("訓練完成！")

Epoch 1/5, Loss: 20232.0747
Epoch 2/5, Loss: 14660.2711
Epoch 3/5, Loss: 14233.7238
Epoch 4/5, Loss: 14030.1474
Epoch 5/5, Loss: 13931.2454
訓練完成！


測試嵌入向量效果

In [22]:
# 取得單詞的嵌入向量
word_embeddings = model.embedding.weight.data.cpu().numpy()

# 隨機選 10 個單詞，顯示其對應的嵌入向量
sample_words = random.sample(vocab, 10)
for word in sample_words:
    idx = word2idx[word]
    print(f"{word}: {word_embeddings[idx][:5]} ...")  # 顯示前 5 個數值

tibetian: [-1.4375777  -0.68293995  2.1688416   0.4099418   0.43794662] ...
psyhtjsmipscccnottinghamacuk: [-0.4609     -0.74916387 -0.06740743 -2.334864   -1.9351561 ] ...
separatist: [-1.0169411  -2.9384568  -0.52474445 -1.0964781  -0.19436616] ...
academics: [-1.014023    0.35530466  0.8775246   0.08616684  0.8812692 ] ...
mjglzprscqdgaqvpgtitcvcnnj: [-0.25967306  0.90539557  0.7625903  -1.0720445   0.6960263 ] ...
blank: [ 0.1894634  -0.39105284  0.04489758 -0.11747424  0.15941003] ...
repainted: [ 0.5745938  -1.3974309  -0.7704798  -0.08036944 -1.07997   ] ...
cbts: [-3.7467105   1.158229   -2.5395732   0.55880266  0.28207058] ...
futility: [ 2.3441868  1.1798515  3.284471  -1.591921   0.4799493] ...
qatar: [-0.819624  -1.1575304  3.0840986 -1.0137874  1.5982919] ...
