# 加1 word2vec

对每个句子，枚举所有距离不超过窗口大小（window_size）的词对 (word1, word2)，作为 (target, context)，训练一个通过 target 预测 context 的模型（Skip-gram），最终得到每个词的词嵌入表示。

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import Counter
import numpy as np
from torch.utils.data import Dataset, DataLoader


class SkipGramDataset(Dataset):
    """Skip-gram training pairs built from tokenized sentences.

    Each item is a ``(target_idx, context_idx)`` tuple of vocabulary indices,
    where the context word lies within ``window_size`` positions of the target
    word in the same sentence.

    :param sentences: list of sentences, each a list of tokens.
    :param window_size: maximum distance between a target word and a context word.
    """

    def __init__(self, sentences, window_size):
        self.sentences = sentences
        self.window_size = window_size
        self.word_to_idx, self.idx_to_word, self.word_count = self.build_vocab()
        self.vocab_size = len(self.word_to_idx)
        self.data = self.build_data()

    def build_vocab(self):
        """Build the vocabulary from ``self.sentences``.

        Indices are assigned in order of first occurrence, because ``Counter``
        (a ``dict`` subclass) preserves insertion order.

        :return: ``(word_to_idx, idx_to_word, word_count)``. ``word_count`` is a
            ``Counter`` of token frequencies over all sentences.
        """
        word_count = Counter()
        for sentence in self.sentences:
            word_count.update(sentence)
        word_to_idx = {word: idx for idx, word in enumerate(word_count)}
        idx_to_word = {idx: word for word, idx in word_to_idx.items()}
        return word_to_idx, idx_to_word, word_count

    def build_data(self):
        """Enumerate every (target, context) index pair in the corpus.

        For each sentence, every word is taken in turn as the target. Every
        other position at most ``window_size`` away, clipped to the sentence
        bounds, yields one context word. A word is never paired with its own
        position, but it can be paired with another occurrence of the same word.

        :return: list of ``(target_idx, context_idx)`` tuples.
        """
        data = []
        for sentence in self.sentences:
            # Map tokens to indices once per sentence rather than once per pair.
            indices = [self.word_to_idx[word] for word in sentence]
            n = len(indices)
            for i, target_idx in enumerate(indices):
                start = max(0, i - self.window_size)
                stop = min(n, i + self.window_size + 1)
                for j in range(start, stop):
                    if j != i:
                        data.append((target_idx, indices[j]))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class SkipGramModel(nn.Module):
    """Skip-gram model: embed a target word, then score every vocabulary word as its context.

    :param vocab_size: number of words in the vocabulary.
    :param embedding_dim: dimensionality of the learned word embeddings.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # Row i of ``embedding.weight`` is the learned vector for word index i.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Projects an embedding to one unnormalized score per vocabulary word.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Score every vocabulary word as a context word for each target index.

        :param x: integer tensor of target word indices (not one-hot vectors).
        :return: raw logits of shape ``x.shape + (vocab_size,)``. These are NOT
            probabilities; apply ``softmax`` to obtain them. ``nn.CrossEntropyLoss``
            expects exactly these unnormalized logits.
        """
        return self.linear(self.embedding(x))


# Example data
sentences = [["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"],
             ["I", "like", "to", "play", "football"]]
window_size = 2  # max distance between a target word and a context word
embedding_dim = 100
batch_size = 64
num_epochs = 10
learning_rate = 0.001

dataset = SkipGramDataset(sentences, window_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = SkipGramModel(dataset.vocab_size, embedding_dim)
# CrossEntropyLoss applies log-softmax internally, so the model emits raw logits
# and the context indices serve directly as class targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (targets, contexts) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = model(targets)
        loss = criterion(outputs, contexts)
        loss.backward()
        optimizer.step()
        print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Extract the learned word vectors (row i corresponds to word index i).
# detach() is the supported replacement for .data; cpu() makes this work
# even if the model has been moved to a GPU.
word_embeddings = model.embedding.weight.detach().cpu().numpy()
for word, idx in dataset.word_to_idx.items():
    print(f'Word: {word}, Embedding: {word_embeddings[idx]}')

Epoch [1/10], Step [1/1], Loss: 2.7484
Epoch [2/10], Step [1/1], Loss: 2.7079
Epoch [3/10], Step [1/1], Loss: 2.6681
Epoch [4/10], Step [1/1], Loss: 2.6289
Epoch [5/10], Step [1/1], Loss: 2.5905
Epoch [6/10], Step [1/1], Loss: 2.5529
Epoch [7/10], Step [1/1], Loss: 2.5159
Epoch [8/10], Step [1/1], Loss: 2.4797
Epoch [9/10], Step [1/1], Loss: 2.4442
Epoch [10/10], Step [1/1], Loss: 2.4095
Word: the, Embedding: [-1.3962811   0.40131408 -1.2093109  -0.57528406 -0.6966133  -1.5688841
 -0.6770535  -1.2728134   0.6957413   0.4542178  -0.448228   -0.21715783
  1.2702729   1.182503   -1.2620773  -1.3823662   1.1253415   1.4812647
 -1.094116    0.69203585 -0.7800522   0.8037099   0.26197556 -0.13500202
 -1.3093109  -0.17332064 -0.24429609 -0.7622649  -1.2912277   0.82824856
  0.52058303 -0.31685656  0.79374754  0.75090057 -0.17067195  1.2342457
  1.4350717  -0.62355816  1.8235614  -0.94761777  0.3556419   0.25830632
 -0.45455575  0.3597888  -0.8148145  -0.25150365  1.2074031   0.37156823
 -0.71