In [1]:
import os
import sys
import time
import math
import random
import numpy as np
import collections#内置容器 例如数组等
import torch
from torch import nn
import torch.utils.data as Data

# Read the corpus. Each line of the file is treated as one sentence and
# split on whitespace into tokens.
with open('HarryPotter.txt', 'r') as f:
    lines = f.readlines()
raw_dataset = [st.split() for st in lines]  # st = sentence

# Count how often every token (tk) occurs in the whole corpus, then keep
# only the tokens that occur at least 5 times.
counter = collections.Counter(tk for st in raw_dataset for tk in st)
counter = {tk: freq for tk, freq in counter.items() if freq >= 5}

# Vocabulary: position in idx_to_token is the token's index.
idx_to_token = list(counter)
token_to_idx = dict(zip(idx_to_token, range(len(idx_to_token))))
# Replace every in-vocabulary token by its index; out-of-vocabulary tokens are dropped.
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
           for st in raw_dataset]
num_tokens = sum(map(len, dataset))

# Subsampling. Very frequent words usually carry little meaning, and by the
# formula below the more frequent a word is, the more likely it is dropped.
# Strictly speaking this is down-weighting: ultra-frequent words should not be
# removed entirely, only have their influence on training reduced.
def discard(idx):
    """Randomly decide whether one occurrence of word index `idx` is dropped.

    Reads the module-level `counter`, `idx_to_token` and `num_tokens`.
    Returns True when a uniform draw from [0, 1] falls below
    1 - sqrt(1e-4 / (count / num_tokens)); when the square root exceeds 1
    the threshold is negative and the word is never dropped.
    """
    draw = random.uniform(0, 1)
    occurrences = counter[idx_to_token[idx]]
    keep_ratio = math.sqrt(1e-4 / occurrences * num_tokens)
    return draw < 1 - keep_ratio

# Apply subsampling sentence by sentence: every sentence keeps only the token
# indices for which discard() returned False.
subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]

def get_centers_and_contexts(dataset, max_window_size):
    """Collect every center word together with its surrounding context words.

    Args:
        dataset: iterable of sentences, each a list of word indices.
        max_window_size: upper bound (inclusive) of the window radius; the
            actual radius is drawn uniformly from 1..max_window_size for
            every center position.

    Returns:
        (centers, contexts): `centers` is a flat list of word indices and
        `contexts[i]` is the list of words around `centers[i]`, with the
        center itself excluded.
    """
    centers, contexts = [], []
    for sentence in dataset:
        length = len(sentence)
        # A sentence needs at least two words to form one (center, context) pair.
        if length < 2:
            continue
        centers.extend(sentence)
        for pos in range(length):
            radius = random.randint(1, max_window_size)
            first = max(0, pos - radius)
            stop = min(length, pos + 1 + radius)
            contexts.append([sentence[j] for j in range(first, stop) if j != pos])
    return centers, contexts
# Extract, for every word, the words immediately before and after it
# (center word and its context words); the maximum window size passed is 5.
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

# Negative sampling: an approximation used to speed up training.
def get_negatives(all_contexts, sampling_weights, K):
    """Draw K noise (negative) words for every context word of every center word.

    Args:
        all_contexts: list with one context list per center word,
            e.g. [[45, 23], [12, 23, 67], ...]; all_contexts[0] holds the
            context word indices of the first center word.
        sampling_weights: one sampling weight per vocabulary word; position i
            is the weight of word index i.
        K: number of negatives drawn per context word.

    Returns:
        A list parallel to `all_contexts`; entry j holds
        len(all_contexts[j]) * K word indices, none of which occurs in
        all_contexts[j].

    Note:
        The inner loop only stops once enough accepted candidates were seen,
        so it does not terminate if every word with a positive weight is
        itself part of the context.
    """
    # neg_candidates is a pre-generated pool of candidates; i is the read cursor.
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        # Built once per center word instead of once per candidate.
        context_set = set(contexts)
        num_needed = len(contexts) * K
        negatives = []
        while len(negatives) < num_needed:
            if i == len(neg_candidates):
                # Pool exhausted: draw a fresh batch of candidates in one call.
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            # A noise word must not be one of the context words.
            if neg not in context_set:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

# Sampling weight of every vocabulary word: its corpus count raised to the power 0.75.
sampling_weights = [counter[w]**0.75 for w in idx_to_token]
# Draw 5 negatives for each context word of each center word.
all_negatives = get_negatives(all_contexts, sampling_weights, 5)
#--------------------------------------------------------------------------------------------------------------------------------------------------
class MyDataset(torch.utils.data.Dataset):
    """Dataset whose items are (center, context list, negative list) triples."""

    def __init__(self, centers, contexts, negatives):
        # The three sequences are parallel: entry i of each one belongs to
        # the same center word, so their lengths must agree.
        size = len(centers)
        assert size == len(contexts) and len(contexts) == len(negatives)
        self.centers, self.contexts, self.negatives = centers, contexts, negatives

    def __len__(self):
        """Number of samples, i.e. number of center words."""
        return len(self.centers)

    def __getitem__(self, index):
        """Return the center word, its contexts and its negatives at `index`."""
        sample = (self.centers[index], self.contexts[index], self.negatives[index])
        return sample
def batchify(data):
    """Collate (center, context, negative) samples into padded batch tensors.

    Args:
        data: non-empty sequence of (center, context, negative) triples where
            `context` and `negative` are lists of word indices.

    Returns:
        A 4-tuple of tensors:
          centers            - shape (batch, 1)
          contexts_negatives - shape (batch, max_len): context words followed
                               by negatives, right-padded with 0
          masks              - shape (batch, max_len): 1 on real entries
                               (contexts and negatives), 0 on padding
          labels             - shape (batch, max_len): 1 on context words only
        where max_len is the largest len(context) + len(negative) in `data`.
    """
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        pad = max_len - cur_len
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * pad)
        masks.append([1] * cur_len + [0] * pad)
        labels.append([1] * len(context) + [0] * (max_len - len(context)))
    # Build the tensors once, after all rows are collected, rather than
    # rebuilding them on every loop iteration.
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

batch_size = 256
# DataLoader requires num_workers >= 0 and raises ValueError otherwise, so the
# previous value of -1 broke every non-Windows platform. Use 0 (load in the
# main process) on Windows and 4 worker processes elsewhere.
num_workers = 0 if sys.platform.startswith('win32') else 4

# NOTE: this rebinds `dataset`, which until here held the index-encoded sentences.
dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,
                            collate_fn=batchify,  # pad and stack samples via batchify
                            num_workers=num_workers)
# Sanity check: print the tensor shapes of the first batch only.
for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
        print(name, 'shape:', data.shape)
    break

centers shape: torch.Size([256, 1])
contexts_negatives shape: torch.Size([256, 60])
masks shape: torch.Size([256, 60])
labels shape: torch.Size([256, 60])


In [7]:
# Binary cross-entropy loss computed on logits, with optional masking.
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    """Per-sample masked sigmoid binary cross-entropy."""

    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets, mask=None):
        """Compute the masked loss for each row.

        Args:
            inputs: logits, shape (batch, len).
            targets: labels with the same shape as `inputs`.
            mask: optional weights with the same shape as `inputs`; entries
                with weight 0 do not contribute. When omitted, every entry
                counts with weight 1.

        Returns:
            Tensor of shape (batch,): the weighted element-wise loss summed
            over dim 1 and divided by the sum of the mask over dim 1. A row
            whose mask sums to 0 yields NaN (0 / 0).
        """
        inputs, targets = inputs.float(), targets.float()
        if mask is None:
            # The signature allows omitting the mask; treat all entries as valid.
            mask = torch.ones_like(inputs)
        else:
            mask = mask.float()
        # 1. Element-wise binary cross-entropy, weighted by the mask.
        res = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none", weight=mask)
        # 2. Average over the valid entries of each sample.
        return res.sum(dim=1) / mask.sum(dim=1)

# Loss instance that train() calls on every batch.
loss = SigmoidBinaryCrossEntropyLoss()

def sigmd(x):
    """Return -log(sigmoid(x)) for a scalar `x`.

    Uses the identity -log(sigmoid(x)) = max(-x, 0) + log(1 + exp(-|x|)).
    The exponent is never positive, so exp() cannot overflow; the naive
    form -log(1 / (1 + exp(-x))) raises OverflowError for large negative x.
    """
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))

embed_size = 200  # dimensionality of each word vector
# Two embedding tables with one row per vocabulary word. train() passes
# net[0] to skip_gram as embed_v (center words) and net[1] as embed_u
# (context and negative words).
net = nn.Sequential(nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),  # center-word embedding
                    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size))  # target-word embedding

# Forward computation of the skip-gram model.
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    """Score every context/negative word against its center word.

    Args:
        center: index tensor fed to `embed_v`.
        contexts_and_negatives: index tensor fed to `embed_u`.
        embed_v: embedding applied to the center words.
        embed_u: embedding applied to the context and negative words.

    Returns:
        Batched matrix product of the center vectors with the transposed
        context/negative vectors (dims 1 and 2 swapped before torch.bmm).
    """
    center_vecs = embed_v(center)
    target_vecs = embed_u(contexts_and_negatives)
    # Swap the last two dims so bmm takes the dot product over the embedding axis.
    target_vecs_t = target_vecs.permute(0, 2, 1)
    return torch.bmm(center_vecs, target_vecs_t)

def train(net, lr, num_epochs):
    """Train the two embedding tables in `net` with Adam.

    Iterates over the module-level `data_iter`, scores each batch with
    skip_gram(net[0], net[1]) and minimises the module-level `loss`.
    Prints the device once and, after every epoch, the mean batch loss and
    the elapsed wall-clock time.

    Args:
        net: container whose items 0 and 1 are the center-word and
            context-word embeddings.
        lr: learning rate handed to torch.optim.Adam.
        num_epochs: number of passes over `data_iter`.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        epoch_start = time.time()
        loss_total, num_batches = 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = (t.to(device) for t in batch)
            pred = skip_gram(center, context_negative, net[0], net[1])
            # pred is (batch, 1, len); reshape to the label layout before the loss.
            batch_loss = loss(pred.view(label.shape), label, mask).mean()
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            loss_total += batch_loss.item()
            num_batches += 1
        mean_loss = loss_total / num_batches
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, mean_loss, time.time() - epoch_start))

# Train with learning rate 0.02 for 50 epochs.
train(net, 0.02, 50)



train on cuda
epoch 1, loss 4.77, time 23.27s
epoch 2, loss 1.78, time 23.85s
epoch 3, loss 0.83, time 28.94s
epoch 4, loss 0.48, time 27.75s
epoch 5, loss 0.32, time 26.47s
epoch 6, loss 0.24, time 25.35s
epoch 7, loss 0.20, time 21.84s
epoch 8, loss 0.18, time 28.71s
epoch 9, loss 0.17, time 29.38s
epoch 10, loss 0.17, time 24.25s
epoch 11, loss 0.17, time 25.66s
epoch 12, loss 0.17, time 24.50s
epoch 13, loss 0.16, time 24.99s
epoch 14, loss 0.17, time 21.73s
epoch 15, loss 0.17, time 21.90s
epoch 16, loss 0.17, time 21.81s
epoch 17, loss 0.17, time 23.18s
epoch 18, loss 0.17, time 21.62s
epoch 19, loss 0.17, time 22.07s
epoch 20, loss 0.17, time 21.75s
epoch 21, loss 0.17, time 22.04s
epoch 22, loss 0.17, time 21.81s
epoch 23, loss 0.17, time 21.62s
epoch 24, loss 0.17, time 21.42s
epoch 25, loss 0.17, time 21.96s
epoch 26, loss 0.17, time 21.14s
epoch 27, loss 0.17, time 22.19s
epoch 28, loss 0.17, time 22.37s
epoch 29, loss 0.17, time 21.77s
epoch 30, loss 0.17, time 22.17s
epoch

In [9]:
# Try out the trained model.
def get_similar_tokens(query_token, k, embed):
    """Print the k tokens whose embeddings are most cosine-similar to `query_token`.

    Looks the token up in the module-level `token_to_idx` and prints names
    from `idx_to_token`. The top k + 1 matches are taken and the first one is
    skipped, on the expectation that it is the query token itself.

    Args:
        query_token: token string present in `token_to_idx`.
        k: number of neighbours to print.
        embed: embedding layer whose `weight` rows are the word vectors.
    """
    weights = embed.weight.data
    query_vec = weights[token_to_idx[query_token]]

    dot_products = torch.matmul(weights, query_vec)
    # 1e-9 keeps the denominator non-zero for all-zero vectors.
    norm_products = (torch.sum(weights * weights, dim=1) * torch.sum(query_vec * query_vec) + 1e-9).sqrt()
    cos = dot_products / norm_products
    _, best = torch.topk(cos, k=k+1)
    neighbours = best.cpu().numpy()[1:]
    for i in neighbours:
        print('余弦相似度 = %.3f: %s' % (cos[i], (idx_to_token[i])))
        
# Print the 5 nearest neighbours of 'Jordan' using the first embedding table, net[0].
get_similar_tokens('Jordan', 5, net[0])

余弦相似度 = 0.284: and
余弦相似度 = 0.271: silver
余弦相似度 = 0.258: small
余弦相似度 = 0.256: History
余弦相似度 = 0.246: Goyle
