## 1. 数据预处理

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import Counter

# Seed every RNG the notebook touches so the demo is reproducible.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# A tiny toy corpus: four short English sentences.
corpus = [
    "i like to eat apples and bananas",
    "i like to watch movies and cartoons",
    "the cat likes to eat fish",
    "john loves to read books about python"
]

# 1) Tokenise: lower-case each sentence and split on whitespace.
tokenized_sentences = [line.lower().split() for line in corpus]

# 2) Flatten the per-sentence token lists into one long token sequence.
all_tokens = [token for sentence in tokenized_sentences for token in sentence]

# 3) Build the vocabulary: unique words in alphabetical order, plus the
#    word -> index and index -> word lookup tables.
word_counter = Counter(all_tokens)
vocab = sorted(word_counter)
vocab_size = len(vocab)
word2idx = dict(zip(vocab, range(vocab_size)))
idx2word = dict(enumerate(vocab))

print(f"Vocabulary size = {vocab_size}")
print("Vocab =", vocab)


Vocabulary size = 20
Vocab = ['about', 'and', 'apples', 'bananas', 'books', 'cartoons', 'cat', 'eat', 'fish', 'i', 'john', 'like', 'likes', 'loves', 'movies', 'python', 'read', 'the', 'to', 'watch']


## 2. 生成 (center, outside) Skip-gram 数据

In [20]:
def make_skipgram_data(tokenized_sentences, word2idx, window_size=2):
    """
    Build the list of (center, outside) index pairs for skip-gram training.

    Args:
        tokenized_sentences: iterable of sentences, each a list of word strings.
        word2idx: mapping from word to integer index; every word in the
            sentences must be a key (a missing word raises KeyError).
        window_size: number of context positions taken on EACH side of the
            center word. The window is clipped at sentence boundaries and
            never crosses from one sentence into the next.

    Returns:
        A list of (center_id, outside_id) tuples, ordered by sentence, then by
        center position, then by context position (left to right).
    """
    pairs = []
    for sentence in tokenized_sentences:
        ids = [word2idx[token] for token in sentence]
        n_tokens = len(ids)
        for pos, center_id in enumerate(ids):
            # Clip the context window to the sentence.
            lo = max(pos - window_size, 0)
            hi = min(pos + window_size + 1, n_tokens)
            # Every position in the window except the center itself.
            pairs.extend(
                (center_id, ids[ctx]) for ctx in range(lo, hi) if ctx != pos
            )
    return pairs

# Context radius: this many words on each side of the center word.
window_size = 2
skipgram_pairs = make_skipgram_data(
    tokenized_sentences,
    word2idx,
    window_size=window_size,
)
pair_count = len(skipgram_pairs)
print(f"Total skip-gram pairs: {pair_count}")
# Show the first few pairs as a quick sanity check.
print("Example pairs (center_idx, outside_idx):", skipgram_pairs[:10])


Total skip-gram pairs: 84
Example pairs (center_idx, outside_idx): [(9, 11), (9, 18), (11, 9), (11, 18), (11, 7), (18, 9), (18, 11), (18, 7), (18, 2), (7, 11)]


其中，每个 (center, outside) 对就是一条训练样本：center 是模型的输入，outside 是要预测的目标。训练时，我们希望模型在给定 center 的情况下，为真实的 outside 词赋予尽可能高的概率。

## 3. 定义模型：SkipGramFullSoftmax


1. **输入 (center_word)**：从 `in_embed` 查到中心词向量 $\mathbf{v}_c$，形状是 $(B, d)$（如果一次喂 batch_size=B）。  
2. **输出 logits**：用 $\mathbf{v}_c$ 与整个 `out_embed.weight` 的转置做一次矩阵乘法，即 `logits = v_c @ out_embed.weight.T`。  
   - `out_embed.weight` 的形状是 $(V, d)$  
   - $\mathbf{v}_c$ 的形状是 $(B, d)$  
   - 矩阵乘法后得到 $(B, V)$ 的 logits 矩阵。  
3. 对这 $(B, V)$ 矩阵做 **CrossEntropyLoss**，目标是 `outside_word_ids`（形状 $(B,)$），表示“正解在词表中哪个词的位置”。

In [21]:
class SkipGramFullSoftmax(nn.Module):
    """Skip-gram model whose loss is a full softmax over the vocabulary.

    The model holds two embedding tables of shape (vocab_size, embed_dim):
    ``in_embed`` for center words and ``out_embed`` for context words.
    ``forward`` scores a center word against every row of ``out_embed`` and
    returns the cross-entropy loss against the true context word.
    """

    def __init__(self, vocab_size, embed_dim):
        """
        Args:
            vocab_size: number of distinct words (rows in each table).
            embed_dim: dimensionality of each word vector.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # Center-word table and context-word table. Looking up index i
        # returns row i of the weight matrix, which is the same result as
        # multiplying a one-hot vector by that matrix.
        self.in_embed = nn.Embedding(vocab_size, embed_dim)
        self.out_embed = nn.Embedding(vocab_size, embed_dim)

        # Replace the default initialisation with U(-0.5, 0.5); the order
        # (in_embed first, then out_embed) fixes how the RNG is consumed.
        for table in (self.in_embed, self.out_embed):
            nn.init.uniform_(table.weight, a=-0.5, b=0.5)

    def forward(self, center_word_ids, outside_word_ids):
        """
        Args:
            center_word_ids: (batch_size,) indices of the center words.
            outside_word_ids: (batch_size,) indices of the true context
                words (the positive targets).

        Returns:
            Scalar tensor: the mean cross-entropy loss over the batch.
        """
        # 1) Look up the center-word vectors: (B, d).
        center_vectors = self.in_embed(center_word_ids)

        # 2) Dot every center vector with every context vector:
        #    (B, d) @ (d, V) -> (B, V) logits.
        scores = center_vectors @ self.out_embed.weight.t()

        # 3) Cross-entropy (log-softmax + NLL) of the (B, V) logits against
        #    the (B,) target indices, averaged over the batch.
        return nn.functional.cross_entropy(
            scores, outside_word_ids, reduction="mean"
        )


## 4. 训练优化

In [25]:
# Dimensionality of the learned word vectors.
embed_dim = 8
model = SkipGramFullSoftmax(vocab_size, embed_dim)

# Plain SGD over both embedding tables.
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 500
# Work on a copy so shuffling never reorders the original pair list.
pairs_list = list(skipgram_pairs)
for epoch in range(num_epochs):
    # Visit the pairs in a fresh random order every epoch.
    random.shuffle(pairs_list)

    total_loss = 0.0
    for center_id, outside_id in pairs_list:
        # One (center, outside) pair per step, i.e. a batch of size 1.
        center_batch = torch.tensor([center_id], dtype=torch.long)
        target_batch = torch.tensor([outside_id], dtype=torch.long)

        optimizer.zero_grad()
        loss = model(center_batch, target_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean per-pair loss for this epoch.
    avg_loss = total_loss / len(pairs_list)
    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss = {avg_loss:.4f}")


Epoch 1/500, Avg Loss = 3.0109
Epoch 2/500, Avg Loss = 2.9995
Epoch 3/500, Avg Loss = 2.9884
Epoch 4/500, Avg Loss = 2.9778
Epoch 5/500, Avg Loss = 2.9674
Epoch 6/500, Avg Loss = 2.9573
Epoch 7/500, Avg Loss = 2.9474
Epoch 8/500, Avg Loss = 2.9377
Epoch 9/500, Avg Loss = 2.9281
Epoch 10/500, Avg Loss = 2.9187
Epoch 11/500, Avg Loss = 2.9094
Epoch 12/500, Avg Loss = 2.9002
Epoch 13/500, Avg Loss = 2.8910
Epoch 14/500, Avg Loss = 2.8818
Epoch 15/500, Avg Loss = 2.8726
Epoch 16/500, Avg Loss = 2.8635
Epoch 17/500, Avg Loss = 2.8543
Epoch 18/500, Avg Loss = 2.8450
Epoch 19/500, Avg Loss = 2.8357
Epoch 20/500, Avg Loss = 2.8263
Epoch 21/500, Avg Loss = 2.8168
Epoch 22/500, Avg Loss = 2.8071
Epoch 23/500, Avg Loss = 2.7974
Epoch 24/500, Avg Loss = 2.7876
Epoch 25/500, Avg Loss = 2.7776
Epoch 26/500, Avg Loss = 2.7674
Epoch 27/500, Avg Loss = 2.7571
Epoch 28/500, Avg Loss = 2.7466
Epoch 29/500, Avg Loss = 2.7360
Epoch 30/500, Avg Loss = 2.7251
Epoch 31/500, Avg Loss = 2.7140
Epoch 32/500, Avg

In [26]:
def get_embedding(model, word):
    """Return the ``in_embed`` vector of ``word`` as a NumPy array.

    ``word`` is translated to its row index through the module-level
    ``word2idx`` table, so an unknown word raises KeyError.
    """
    row = word2idx[word]
    # Detach from autograd before handing the data over to NumPy.
    weights = model.in_embed.weight.detach()
    return weights[row].numpy()

def cosine_sim(a, b):
    """Cosine similarity of vectors ``a`` and ``b``.

    A small constant (1e-9) is added to the denominator so that a zero
    vector yields 0 instead of a division by zero.
    """
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / (norm_product + 1e-9)

def most_similar_words(model, query_word, top_k=3):
    """Return the ``top_k`` vocabulary words most similar to ``query_word``.

    Similarity is the cosine similarity between ``in_embed`` vectors. The
    query word itself is excluded from the candidates.

    Returns:
        A list of (word, similarity) tuples sorted from most to least
        similar, or an empty list when ``query_word`` is not in the
        vocabulary.
    """
    if query_word not in word2idx:
        return []

    anchor = get_embedding(model, query_word)
    scored = [
        (candidate, cosine_sim(anchor, get_embedding(model, candidate)))
        for candidate in vocab
        if candidate != query_word
    ]
    ranked = sorted(scored, key=lambda item: item[1], reverse=True)
    return ranked[:top_k]

# Probe words: print the nearest neighbours of each in embedding space.
test_words = ["i", "eat", "movies", "python", "cat"]
for query in test_words:
    if query in word2idx:
        print(f"\n[Top similar words to '{query}']")
        for candidate, score in most_similar_words(model, query):
            print(f"   {candidate:<8} cos_sim = {score:.4f}")
    else:
        # Words outside the vocabulary have no embedding to compare.
        print(f"Word '{query}' not in vocab, skip.")



[Top similar words to 'i']
   watch    cos_sim = 0.7528
   eat      cos_sim = 0.6275
   john     cos_sim = 0.5144

[Top similar words to 'eat']
   bananas  cos_sim = 0.6609
   i        cos_sim = 0.6275
   cat      cos_sim = 0.4998

[Top similar words to 'movies']
   like     cos_sim = 0.4661
   bananas  cos_sim = 0.3990
   apples   cos_sim = 0.3901

[Top similar words to 'python']
   read     cos_sim = 0.6595
   about    cos_sim = 0.5029
   books    cos_sim = 0.4482

[Top similar words to 'cat']
   eat      cos_sim = 0.4998
   likes    cos_sim = 0.4874
   the      cos_sim = 0.2974
