In [1]:
import collections
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data

sys.path.append('../code/')
import d2lzh_pytorch as d2l
# Compute device: CUDA when torch reports it as available, otherwise CPU.
# NOTE(review): `devic` looks like a typo for `device`; it is not referenced
# again in the visible cells (train() builds its own device object).
devic=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# The PTB training split must already exist under ../data/ptb/.
assert 'ptb.train.txt' in os.listdir('../data/ptb/')

# Read the file line by line and split every line on whitespace, so that
# raw_dataset is a list of token lists (one inner list per line).
with open('../data/ptb/ptb.train.txt', 'r') as f:
    lines = f.readlines()
raw_dataset = list(map(str.split, lines))

'# sentenses: %d' % len(raw_dataset)

'# sentenses: 42068'

In [3]:
# Show the token count and the first five tokens of the first three lines.
for st in raw_dataset[:3]:
    print('# tokens:',len(st),st[:5])

# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']
# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


In [4]:
# Count every token occurrence across the whole corpus ...
counter = collections.Counter(tk for st in raw_dataset for tk in st)
# ... then keep only tokens seen at least 5 times (plain dict: token -> count).
counter = {tk: freq for tk, freq in counter.items() if freq >= 5}
len(counter)

9858

In [9]:
def discard(idx, t=1e-4):
    """Randomly decide whether one occurrence of token `idx` is dropped.

    Implements frequency-based subsampling: with f = count(token) / num_tokens,
    the occurrence is dropped with probability max(0, 1 - sqrt(t / f)).

    Reads the module-level names `counter` (token -> count), `idx_to_token`
    (index -> token) and `num_tokens` (total token count); they must be
    defined before this function is called.

    Args:
        idx: integer token index into `idx_to_token`.
        t: subsampling threshold; the default 1e-4 reproduces the previously
            hard-coded value.

    Returns:
        True if the occurrence should be discarded, False otherwise.
    """
    # Same operand order as before so results match for the default threshold.
    return random.uniform(0, 1) < 1 - math.sqrt(
        t / counter[idx_to_token[idx]] * num_tokens)

# NOTE: `dataset`, `idx_to_token` and `num_tokens` are defined in the cell
# that follows this one in the file; that cell has to run first.
# Apply discard() independently to every token occurrence of every sentence.
subsampled_dataset = [
    [tk for tk in st if not discard(tk)]
    for st in dataset
]
# The total is random (no seed is set), so it differs between runs, e.g. 375875.
'# tokens: %d' % sum(map(len, subsampled_dataset))


'# tokens: 376061'

In [10]:
# Vocabulary in the iteration order of `counter`; the position is the token index.
idx_to_token = list(counter)
token_to_idx = dict(zip(idx_to_token, range(len(idx_to_token))))
# Map each sentence to token indices, dropping tokens that are not in the vocabulary.
dataset = [
    [token_to_idx[tk] for tk in st if tk in token_to_idx]
    for st in raw_dataset
]
# Total number of in-vocabulary token occurrences.
num_tokens = sum(map(len, dataset))
num_tokens

887100

In [11]:
def compare_counts(token):
    """Report how often `token` occurs before and after subsampling.

    Looks the token up in the module-level `token_to_idx`, then counts its
    index in every sentence of `dataset` and of `subsampled_dataset`.

    Returns:
        A string of the form '# <token>:before=<n>,after=<m>'.
    """
    idx = token_to_idx[token]
    before = sum(st.count(idx) for st in dataset)
    after = sum(st.count(idx) for st in subsampled_dataset)
    return '# %s:before=%d,after=%d' % (token, before, after)

# Occurrences of the very frequent token 'the' before vs. after subsampling.
compare_counts('the')

'# the:before=50770,after=2098'

In [12]:
# Same comparison for the much rarer token 'join'.
compare_counts('join')

'# join:before=45,after=45'

In [13]:
def get_centers_and_contexts(dataset,max_window_size):
    """Extract skip-gram center words and their context words.

    For every position of every sentence with at least two tokens, a window
    radius is drawn uniformly from 1..max_window_size (inclusive) and the
    tokens within that radius, excluding the center position itself, become
    the context of that center.

    Args:
        dataset: iterable of sentences, each a sequence of token indices.
        max_window_size: largest window radius that may be drawn (>= 1).

    Returns:
        (centers, contexts): two parallel lists; contexts[i] is the list of
        context tokens belonging to centers[i].
    """
    centers, contexts = [], []
    for sentence in dataset:
        n_tokens = len(sentence)
        # A center needs at least one neighbour, so shorter sentences are skipped.
        if n_tokens < 2:
            continue
        centers.extend(sentence)
        for pos in range(n_tokens):
            radius = random.randint(1, max_window_size)
            lo = max(0, pos - radius)
            hi = min(n_tokens, pos + 1 + radius)
            contexts.append([sentence[j] for j in range(lo, hi) if j != pos])
    return centers, contexts

In [14]:
# Toy corpus with two sentences (7 and 3 tokens) to illustrate the extraction.
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset',tiny_dataset)
# Window radius is drawn at random from {1, 2}, so context sizes vary per center.
for center,context in zip(*get_centers_and_contexts(tiny_dataset,2)):
    print('center',center,'has contexts',context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1]
center 1 has contexts [0, 2, 3]
center 2 has contexts [0, 1, 3, 4]
center 3 has contexts [1, 2, 4, 5]
center 4 has contexts [3, 5]
center 5 has contexts [3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8]
center 8 has contexts [7, 9]
center 9 has contexts [8]


In [15]:
# Center/context pairs for the subsampled corpus, maximum window radius 5.
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

In [16]:
def get_negatives(all_contexts,sampling_weights,K):
    """Sample noise (negative) words for every center word.

    For each context list, len(contexts) * K indices are drawn from
    range(len(sampling_weights)) according to `sampling_weights`; a drawn
    index is rejected when it appears in that context list.

    Args:
        all_contexts: list of context lists (token indices), one per center.
        sampling_weights: relative sampling weight of every token index.
        K: number of noise words to draw per context word.

    Returns:
        A list parallel to `all_contexts`; element i holds the noise indices
        sampled for all_contexts[i].

    NOTE(review): the rejection loop does not terminate if every index with a
    non-zero weight is contained in a context list.
    """
    all_negatives, neg_candidates, i = [], [], 0
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        # Build the exclusion set once per center rather than once per candidate.
        excluded = set(contexts)
        num_wanted = len(contexts) * K
        negatives = []
        while len(negatives) < num_wanted:
            if i == len(neg_candidates):
                # Candidate cache exhausted: draw a fresh batch of 1e5 indices
                # in one call instead of sampling one index at a time.
                i, neg_candidates = 0, random.choices(
                    population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1
            # A noise word must not be one of the context words.
            if neg not in excluded:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

In [17]:
# Noise-sampling weight per token index: its corpus count raised to the power 0.75.
sampling_weights=[counter[w]**0.75 for w in idx_to_token]
# Draw K=5 noise words for every context word of every center.
all_negatives=get_negatives(all_contexts,sampling_weights,5)

In [18]:
# Show the first two centers with their context lists and sampled noise words.
print(all_centers[:2], all_contexts[:2], all_negatives[:2], sep='\n')

[0, 3]
[[3, 5], [0, 5, 6, 8, 10]]
[[53, 4269, 3327, 8436, 6194, 3032, 2091, 9239, 5615, 1607], [482, 1360, 2913, 815, 119, 1742, 3237, 7637, 2726, 7, 1416, 78, 436, 530, 103, 1136, 3154, 3271, 314, 1720, 4752, 84, 828, 2778, 793]]


In [19]:
class MyDataset(torch.utils.data.Dataset):
    """Dataset of (center, contexts, negatives) triples for skip-gram training.

    The three sequences are stored as given and must have the same length;
    item i is the tuple (centers[i], contexts[i], negatives[i]).
    """

    def __init__(self,centers,contexts,negatives):
        # All three sequences are indexed in lock-step, so their sizes must agree.
        assert len(centers)==len(contexts)==len(negatives)
        self.centers, self.contexts, self.negatives = centers, contexts, negatives

    def __getitem__(self,index):
        """Return the triple stored at position `index`."""
        fields = (self.centers, self.contexts, self.negatives)
        return tuple(field[index] for field in fields)

    def __len__(self):
        """Number of center words (equals the number of triples)."""
        return len(self.centers)

In [21]:
def batchify(data):
    """Collate (center, context, negative) triples into padded tensors.

    Each example's context and negative lists are concatenated and padded
    with 0 up to the longest concatenation in the batch.

    Args:
        data: non-empty sequence of (center, context_list, negative_list).

    Returns:
        A 4-tuple of tensors:
        - centers, shape (batch, 1);
        - contexts_negatives, shape (batch, max_len);
        - masks, shape (batch, max_len): 1 for real entries, 0 for padding;
        - labels, shape (batch, max_len): 1 at context positions, 0 elsewhere.
    """
    max_len = max(len(context) + len(negative) for _, context, negative in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        n_ctx = len(context)
        n_valid = n_ctx + len(negative)
        n_pad = max_len - n_valid
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * n_pad)
        masks.append([1] * n_valid + [0] * n_pad)
        labels.append([1] * n_ctx + [0] * (max_len - n_ctx))
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

In [22]:
batch_size = 512
# No worker processes when sys.platform starts with 'win32', otherwise 4.
num_workers = 0 if sys.platform.startswith('win32') else 4

# Use a dedicated name for the torch Dataset. Rebinding `dataset` here would
# shadow the list of index sentences that compare_counts() and the
# subsampling cell read, so re-running those cells would give wrong results.
train_dataset = MyDataset(all_centers,
                          all_contexts,
                          all_negatives)
data_iter = Data.DataLoader(train_dataset, batch_size, shuffle=True,
                            collate_fn=batchify,
                            num_workers=num_workers)
# Print the tensor shapes of the first mini-batch only.
for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks',
                           'labels'], batch):
        print(name, 'shape:', data.shape)
    break


centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


In [23]:
# Demo embedding table: 20 indices, each mapped to a 4-dimensional vector.
embed=nn.Embedding(num_embeddings=20,embedding_dim=4)
embed.weight

Parameter containing:
tensor([[-1.4923, -0.9218, -0.3741, -0.7244],
        [-1.0465,  0.8376,  1.7640, -0.0449],
        [-0.9930, -0.9658,  0.4450, -0.9252],
        [-1.0616, -0.5010, -0.2011, -0.9076],
        [ 1.2408, -0.4449, -0.2709, -1.9615],
        [ 1.9716, -1.2924,  0.9997,  0.6505],
        [-0.7347,  1.4951, -1.1393,  0.7423],
        [-1.5388,  0.8921,  0.6642, -0.5877],
        [-1.2955, -0.2329,  1.9294, -0.7238],
        [ 0.0502, -0.3816, -2.1170, -1.0962],
        [-0.9966, -0.2542, -0.1587, -0.6545],
        [-1.4093,  0.2142,  0.6566,  1.3255],
        [-0.4407, -0.7742,  0.9616,  0.2935],
        [-0.5589,  1.6793, -1.3027, -1.3612],
        [-1.2708,  0.6169, -0.6204, -1.2055],
        [-1.7165,  0.9141,  0.4607,  0.2788],
        [ 1.8599, -0.7802,  1.0194,  0.4894],
        [ 0.0054,  1.0828, -0.8563,  0.5429],
        [-0.3988, -0.3144,  0.4465,  0.5345],
        [ 0.0470, -0.4758, -0.2673,  1.0637]], requires_grad=True)

In [24]:
# Looking up a (2, 3) index tensor appends the embedding dimension to the shape.
x=torch.tensor([[1,2,3],[4,5,6]],dtype=torch.long)
embed(x).shape

torch.Size([2, 3, 4])

In [25]:
# Batched matrix multiplication: (2, 1, 4) times (2, 4, 6) gives (2, 1, 6).
X=torch.ones((2,1,4))
Y=torch.ones((2,4,6))
torch.bmm(X,Y).shape

torch.Size([2, 1, 6])

In [33]:
def skip_gram(center,context_and_negatives,embed_v,embed_u):
    """Skip-gram forward pass: inner products of center and candidate vectors.

    Args:
        center: index tensor of center words; presumably (batch, 1) as
            produced by batchify -- confirm against the caller.
        context_and_negatives: index tensor of context and noise words;
            presumably (batch, max_len).
        embed_v: embedding layer applied to `center`.
        embed_u: embedding layer applied to `context_and_negatives`.

    Returns:
        The batched product of the center embeddings with the transposed
        candidate embeddings, i.e. one score per (center, candidate) pair.
    """
    center_vecs = embed_v(center)
    candidate_vecs = embed_u(context_and_negatives)
    # Swap the last two axes so bmm contracts over the embedding dimension.
    return torch.bmm(center_vecs, candidate_vecs.transpose(1, 2))

In [34]:
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    """Binary cross-entropy on logits with an optional element-wise mask."""

    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss,self).__init__()

    def forward(self,inputs,targets,mask=None):
        """Compute the per-row mean of the (optionally masked) BCE loss.

        Args:
            inputs: logits, shape (batch, len).
            targets: labels with the same shape as `inputs`.
            mask: optional tensor of the same shape used as element-wise
                weight (0 removes an entry's contribution). When omitted,
                every entry has weight 1.

        Returns:
            Tensor of shape (batch,): the mean over dim 1 of the weighted
            element-wise losses. The mean divides by the full row length,
            including masked entries.
        """
        inputs, targets = inputs.float(), targets.float()
        # The signature allows mask=None, so only convert it when given.
        weight = None if mask is None else mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none", weight=weight)
        return res.mean(dim=1)

loss=SigmoidBinaryCrossEntropyLoss()

In [35]:
pred=torch.tensor([[1.5,0.3,-1,2],[1.1,-0.6,2.2,0.4]])
label=torch.tensor([[1,0,0,0],[1,1,0,0]])
# The last entry of the second row is masked out (treated as padding).
mask=torch.tensor([[1,1,1,1],[1,1,1,0]])
# loss() averages over the full row length; multiplying by that length and
# dividing by the number of unmasked entries gives the mean over valid entries.
loss(pred,label,mask)*mask.shape[1]/mask.float().sum(dim=1)

tensor([0.8740, 1.2100])

In [36]:
def sigmd(x):
    """Return -log(sigmoid(x)), i.e. log(1 + exp(-x)), for a real number x.

    Uses the form max(-x, 0) + log1p(exp(-|x|)) so that large negative
    inputs do not overflow math.exp.
    """
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))

# Hand-check of the masked loss above. Note that 1 - sigmoid(x) = sigmoid(-x),
# so a label of 0 contributes sigmd(-logit).
print('%.4f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4))
print('%.4f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))

0.8740
1.2100


In [37]:
embed_size = 100
# nn.Sequential is used only as a container for two embedding tables of
# vocabulary size x embed_size; train() passes net[0] and net[1] to
# skip_gram() as embed_v (center words) and embed_u (context/noise words).
net = nn.Sequential(
    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),
    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)
)


In [38]:
def train(net, lr, num_epochs):
    """Train the two embedding tables in `net` with Adam.

    Iterates over the module-level `data_iter`, scores each batch with
    skip_gram(net[0], net[1]) and minimises the masked binary cross-entropy
    given by the module-level `loss`. Prints the device once and the mean
    batch loss plus elapsed time after every epoch.

    Args:
        net: indexable container whose items 0 and 1 are the center and
            context embedding layers.
        lr: learning rate passed to torch.optim.Adam.
        num_epochs: number of passes over `data_iter`.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        epoch_start = time.time()
        loss_total, num_batches = 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = (t.to(device) for t in batch)

            pred = skip_gram(center, context_negative, net[0], net[1])

            # The mask keeps padded positions from influencing the loss;
            # rescaling by row length / valid count averages over real entries.
            per_example = loss(pred.view(label.shape), label, mask)
            valid_counts = mask.float().sum(dim=1)
            # Mean loss over the examples of this batch.
            batch_loss = (per_example * mask.shape[1] / valid_counts).mean()

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_total += batch_loss.cpu().item()
            num_batches += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, loss_total / num_batches, time.time() - epoch_start))


In [51]:
# Train for 10 epochs with learning rate 0.001.
train(net, 0.001, 10)

train on cuda
epoch 1, loss 0.25, time 7.99s
epoch 2, loss 0.24, time 8.00s
epoch 3, loss 0.24, time 8.00s
epoch 4, loss 0.24, time 7.99s
epoch 5, loss 0.23, time 7.99s
epoch 6, loss 0.23, time 7.98s
epoch 7, loss 0.23, time 7.99s
epoch 8, loss 0.23, time 8.01s
epoch 9, loss 0.23, time 7.98s
epoch 10, loss 0.23, time 7.99s


In [53]:
def get_similar_tokens(query_token, k, embed):
    """Print the k tokens whose embeddings are most cosine-similar to a query.

    Uses the module-level `token_to_idx` and `idx_to_token` to map between
    tokens and rows of the embedding matrix. Also prints the shape of the
    embedding matrix and of the query vector.

    Args:
        query_token: token to look up; must be a key of `token_to_idx`.
        k: number of neighbours to print.
        embed: embedding layer whose weight matrix is searched.
    """
    W = embed.weight.data
    print(W.shape)
    x = W[token_to_idx[query_token]]
    print(x.shape)
    dots = torch.matmul(W, x)
    norms_sq = torch.sum(W * W, dim=1) * torch.sum(x * x)
    # The 1e-9 term is added for numerical stability.
    cos = dots / (norms_sq + 1e-9).sqrt()
    # Ask for k+1 results because the best match is expected to be the query itself.
    neighbours = torch.topk(cos, k=k+1)[1].cpu().numpy()
    for i in neighbours[1:]:  # skip the input word
        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))

# Three nearest neighbours of 'chip' in the center-word embedding table net[0].
get_similar_tokens('chip', 3, net[0])


torch.Size([9858, 100])
torch.Size([100])
cosine sim=0.481: intel
cosine sim=0.451: nec
cosine sim=0.428: oliver
