训练词向量的尝试
-----
使用skip-gram模型，以及论文中negative sampling（noise contrastive estimation的简化形式）的目标函数

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

# Train on the GPU whenever PyTorch can see one.
USE_CUDA = torch.cuda.is_available()

# Seed every random number generator used in this notebook from one value.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if USE_CUDA:
    torch.cuda.manual_seed(SEED)

# Hyperparameters.
C = 3                    # context window radius: C words on each side of the center word
K = 100                  # negative samples drawn per positive (context) word
NUM_EPOCHS = 2           # full passes over the dataset
MAX_VOCAB_SIZE = 30000   # vocabulary size, including the "<unk>" entry
BATCH_SIZE = 128         # center words per batch
LEARNING_RATE = 0.2      # SGD step size
EMBEDDING_SIZE = 100     # dimensionality of each word vector


- 读取文本文件中文字，创建辞典
- 用UNK表示所有不常见单词
- one-hot表示，word2index & index2word

In [2]:
# Load the corpus and break it into whitespace-separated tokens.
with open("text8.train.txt", "r") as fin:
    text = fin.read().split()

# Keep the (MAX_VOCAB_SIZE - 1) most frequent tokens with their counts; the
# remaining slot is "<unk>", whose count is the number of tokens that fall
# outside the kept vocabulary.
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))

# Index <-> word lookup tables, in the insertion order of `vocab`.
idx_to_word = list(vocab)
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}

In [3]:
# Token counts in vocabulary-index order.
word_counts = np.array(list(vocab.values()), dtype=np.float32)

# Unigram distribution raised to the 3/4 power and renormalised; this is the
# distribution the negative samples are drawn from.
word_freqs = (word_counts / word_counts.sum()) ** (3. / 4.)
word_freqs = word_freqs / word_freqs.sum()

VOCAB_SIZE = len(idx_to_word)
VOCAB_SIZE

30000

dataloader
-------
用dataloader拿到batch-wise数据，打乱数据集等等
- 把之前text转换成对应index
- 每个iter sample一个中心词
- 返回该中心词的context window的词
- 返回该中心词的负采样词

dataloader组件：
- `__len__`：数据集中有多少item
- `__getitem__`：根据给定idx返回item

In [4]:
class WordEmbeddingDataset(tud.Dataset):
    """Skip-gram training examples with negative sampling.

    Item ``idx`` is the triple ``(center_word, pos_words, neg_words)``:

    * ``center_word`` -- index of the token at position ``idx``;
    * ``pos_words``   -- indices of the ``2 * window_size`` surrounding tokens;
    * ``neg_words``   -- ``num_negatives`` sampled indices per context word,
      drawn with replacement from ``word_freqs``.

    Args:
        text: sequence of tokens (strings).
        word_to_idx: mapping token -> index; must contain ``"<unk>"``, which
            is used for every token missing from the mapping.
        idx_to_word: sequence mapping index -> token (stored, not used here).
        word_freqs: sampling weights for negative words, one per vocabulary
            index.
        window_size: context words taken on each side of the center word.
            ``None`` (default) uses the module-level constant ``C``.
        num_negatives: negative samples per context word. ``None`` (default)
            uses the module-level constant ``K``.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs,
                 window_size=None, num_negatives=None):
        super(WordEmbeddingDataset, self).__init__()
        # Resolve the sampling hyperparameters once, so __getitem__ does not
        # depend on module globals at call time.
        self.window_size = C if window_size is None else window_size
        self.num_negatives = K if num_negatives is None else num_negatives

        # Encode the corpus as vocabulary indices; out-of-vocabulary tokens
        # map to the "<unk>" index.
        unk_idx = word_to_idx["<unk>"]
        self.text_encoded = torch.LongTensor(
            [word_to_idx.get(word, unk_idx) for word in text]
        )
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)

    def __len__(self):
        """Number of items: one per token of the encoded corpus."""
        return len(self.text_encoded)

    def __getitem__(self, idx):
        """Return ``(center_word, pos_words, neg_words)`` for position ``idx``."""
        center_word = self.text_encoded[idx]

        # Positions of the window around idx, excluding idx itself. The modulo
        # wraps positions that run off either end of the corpus, so every item
        # has exactly 2 * window_size context words.
        w = self.window_size
        pos_indices = list(range(idx - w, idx)) + list(range(idx + 1, idx + w + 1))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]

        # Negatives are sampled independently of the context, so a sampled
        # index can occasionally coincide with a context word.
        neg_words = torch.multinomial(
            self.word_freqs, self.num_negatives * pos_words.shape[0], True
        )

        return center_word, pos_words, neg_words

In [5]:
# Wrap the tokenised corpus in the dataset, then batch and shuffle it using
# four worker processes.
dataset = WordEmbeddingDataset(
    text=text,
    word_to_idx=word_to_idx,
    idx_to_word=idx_to_word,
    word_freqs=word_freqs,
)
dataloader = tud.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)



In [9]:
class EmbeddingModel(nn.Module):
    """Skip-gram model with separate center ("in") and context ("out") embeddings."""

    def __init__(self, vocab_size, embed_size):
        """Create both embedding tables and initialise them uniformly.

        Args:
            vocab_size: number of rows in each embedding table.
            embed_size: dimensionality of each embedding vector.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size

        # Both tables start uniform in [-0.5 / embed_size, 0.5 / embed_size].
        bound = 0.5 / self.embed_size
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        nn.init.uniform_(self.in_embed.weight, -bound, bound)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)
        nn.init.uniform_(self.out_embed.weight, -bound, bound)

    def forward(self, input_labels, pos_labels, neg_labels):
        """Return the negative-sampling loss for each example in the batch.

        Args:
            input_labels: center word indices, [batch_size].
            pos_labels: context word indices, [batch_size, window_size*2].
            neg_labels: negative word indices, [batch_size, window_size*2*K].

        Returns:
            Tensor of shape [batch_size]: the negated sum of
            log-sigmoid(context . center) over the context words and
            log-sigmoid(-(negative . center)) over the negative words.
        """
        # Center vectors as column vectors for the batched products below.
        center = self.in_embed(input_labels).unsqueeze(2)   # [batch_size, embed_size, 1]
        context = self.out_embed(pos_labels)                # [batch_size, window_size*2, embed_size]
        noise = self.out_embed(neg_labels)                  # [batch_size, window_size*2*K, embed_size]

        # Dot product of every context / negative vector with its center vector;
        # the negatives are scored against the negated center vector.
        pos_score = torch.bmm(context, center).squeeze(2)   # [batch_size, window_size*2]
        neg_score = torch.bmm(noise, -center).squeeze(2)    # [batch_size, window_size*2*K]

        log_likelihood = F.logsigmoid(pos_score).sum(1) + F.logsigmoid(neg_score).sum(1)

        return -log_likelihood

    def input_embeddings(self):
        """Return the center-word embedding matrix as a NumPy array on the CPU."""
        return self.in_embed.weight.detach().cpu().numpy()

In [10]:
# Build the skip-gram model and move it to the GPU when one is available.
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
model = model.cuda() if USE_CUDA else model

In [11]:
# Plain SGD over both embedding tables.
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()

        # Only the device transfer is conditional on CUDA.
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()

        # The optimisation step sits outside the `if USE_CUDA` branch so that
        # CPU-only runs train as well. The model returns one loss value per
        # example; average them over the batch before back-propagating.
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        # Report the batch loss every 100 iterations.
        if i % 100 == 0:
            print("epoch", epoch, "iteration", i, loss.item())

epoch 0 iteration 0 420.04730224609375
epoch 0 iteration 100 275.59674072265625
epoch 0 iteration 200 221.04335021972656
epoch 0 iteration 300 177.56753540039062
epoch 0 iteration 400 160.03225708007812
epoch 0 iteration 500 148.35157775878906
epoch 0 iteration 600 125.90690612792969
epoch 0 iteration 700 104.45404815673828
epoch 0 iteration 800 114.70223236083984
epoch 0 iteration 900 92.68739318847656
epoch 0 iteration 1000 95.55662536621094
epoch 0 iteration 1100 85.60282135009766
epoch 0 iteration 1200 79.15052795410156
epoch 0 iteration 1300 74.55094146728516
epoch 0 iteration 1400 90.07608795166016
epoch 0 iteration 1500 83.55937194824219
epoch 0 iteration 1600 72.25748443603516
epoch 0 iteration 1700 61.23986053466797
epoch 0 iteration 1800 63.942298889160156
epoch 0 iteration 1900 67.76444244384766
epoch 0 iteration 2000 76.70926666259766
epoch 0 iteration 2100 60.37493896484375
epoch 0 iteration 2200 57.94870376586914
epoch 0 iteration 2300 68.48345947265625
epoch 0 iteration 

epoch 0 iteration 19700 31.429759979248047
epoch 0 iteration 19800 31.950511932373047
epoch 0 iteration 19900 33.415313720703125
epoch 0 iteration 20000 31.508468627929688
epoch 0 iteration 20100 32.12083053588867
epoch 0 iteration 20200 33.324302673339844
epoch 0 iteration 20300 32.7859001159668
epoch 0 iteration 20400 32.64169692993164
epoch 0 iteration 20500 31.902429580688477
epoch 0 iteration 20600 31.655282974243164
epoch 0 iteration 20700 32.06268310546875
epoch 0 iteration 20800 32.45343780517578
epoch 0 iteration 20900 32.062889099121094
epoch 0 iteration 21000 32.72632598876953
epoch 0 iteration 21100 32.754676818847656
epoch 0 iteration 21200 32.35794448852539
epoch 0 iteration 21300 31.97311019897461
epoch 0 iteration 21400 31.646499633789062
epoch 0 iteration 21500 32.70805358886719
epoch 0 iteration 21600 31.517066955566406
epoch 0 iteration 21700 32.05213928222656
epoch 0 iteration 21800 32.231590270996094
epoch 0 iteration 21900 32.795021057128906
epoch 0 iteration 2200

epoch 0 iteration 39000 31.167442321777344
epoch 0 iteration 39100 31.624662399291992
epoch 0 iteration 39200 30.875795364379883
epoch 0 iteration 39300 31.25849151611328
epoch 0 iteration 39400 31.413291931152344
epoch 0 iteration 39500 31.875438690185547
epoch 0 iteration 39600 30.957775115966797
epoch 0 iteration 39700 30.921634674072266
epoch 0 iteration 39800 30.707500457763672
epoch 0 iteration 39900 31.445392608642578
epoch 0 iteration 40000 31.45154571533203
epoch 0 iteration 40100 31.226329803466797
epoch 0 iteration 40200 31.53829002380371
epoch 0 iteration 40300 31.153043746948242
epoch 0 iteration 40400 31.606853485107422
epoch 0 iteration 40500 31.278135299682617
epoch 0 iteration 40600 31.116233825683594
epoch 0 iteration 40700 31.54392433166504
epoch 0 iteration 40800 30.91817855834961
epoch 0 iteration 40900 31.233844757080078
epoch 0 iteration 41000 31.472246170043945
epoch 0 iteration 41100 31.23771095275879
epoch 0 iteration 41200 31.34368896484375
epoch 0 iteration 

epoch 0 iteration 58200 30.706558227539062
epoch 0 iteration 58300 31.13748550415039
epoch 0 iteration 58400 31.113882064819336
epoch 0 iteration 58500 30.739486694335938
epoch 0 iteration 58600 31.544513702392578
epoch 0 iteration 58700 31.581371307373047
epoch 0 iteration 58800 30.84398078918457
epoch 0 iteration 58900 31.27242660522461
epoch 0 iteration 59000 31.35492706298828
epoch 0 iteration 59100 30.84623146057129
epoch 0 iteration 59200 31.046512603759766
epoch 0 iteration 59300 30.528730392456055
epoch 0 iteration 59400 30.640151977539062
epoch 0 iteration 59500 31.05780029296875
epoch 0 iteration 59600 30.754901885986328
epoch 0 iteration 59700 31.187000274658203
epoch 0 iteration 59800 31.03858184814453
epoch 0 iteration 59900 30.965194702148438
epoch 0 iteration 60000 31.16712188720703
epoch 0 iteration 60100 31.34661102294922
epoch 0 iteration 60200 31.07785987854004
epoch 0 iteration 60300 31.223705291748047
epoch 0 iteration 60400 31.130245208740234
epoch 0 iteration 605

epoch 0 iteration 77500 30.781715393066406
epoch 0 iteration 77600 30.806690216064453
epoch 0 iteration 77700 30.592628479003906
epoch 0 iteration 77800 30.855663299560547
epoch 0 iteration 77900 30.90161895751953
epoch 0 iteration 78000 31.25902557373047
epoch 0 iteration 78100 30.31134605407715
epoch 0 iteration 78200 30.42319107055664
epoch 0 iteration 78300 30.322107315063477
epoch 0 iteration 78400 30.592681884765625
epoch 0 iteration 78500 31.467742919921875
epoch 0 iteration 78600 31.068496704101562
epoch 0 iteration 78700 30.674373626708984
epoch 0 iteration 78800 30.815509796142578
epoch 0 iteration 78900 30.679763793945312
epoch 0 iteration 79000 30.54802131652832
epoch 0 iteration 79100 31.039348602294922
epoch 0 iteration 79200 30.228870391845703
epoch 0 iteration 79300 30.901844024658203
epoch 0 iteration 79400 31.085819244384766
epoch 0 iteration 79500 30.969383239746094
epoch 0 iteration 79600 30.572956085205078
epoch 0 iteration 79700 30.885210037231445
epoch 0 iteratio

epoch 0 iteration 96800 30.806224822998047
epoch 0 iteration 96900 30.312118530273438
epoch 0 iteration 97000 30.5284366607666
epoch 0 iteration 97100 31.10184097290039
epoch 0 iteration 97200 30.65213394165039
epoch 0 iteration 97300 31.05136489868164
epoch 0 iteration 97400 30.781587600708008
epoch 0 iteration 97500 31.06682586669922
epoch 0 iteration 97600 30.30845069885254
epoch 0 iteration 97700 30.90546417236328
epoch 0 iteration 97800 30.45550537109375
epoch 0 iteration 97900 30.450990676879883
epoch 0 iteration 98000 30.513946533203125
epoch 0 iteration 98100 30.738269805908203
epoch 0 iteration 98200 30.61899185180664
epoch 0 iteration 98300 30.453258514404297
epoch 0 iteration 98400 30.665693283081055
epoch 0 iteration 98500 31.04178237915039
epoch 0 iteration 98600 30.640182495117188
epoch 0 iteration 98700 30.78666114807129
epoch 0 iteration 98800 30.541728973388672
epoch 0 iteration 98900 30.233760833740234
epoch 0 iteration 99000 30.829368591308594
epoch 0 iteration 99100

epoch 0 iteration 115700 30.725175857543945
epoch 0 iteration 115800 30.898651123046875
epoch 0 iteration 115900 30.19784927368164
epoch 0 iteration 116000 30.74707794189453
epoch 0 iteration 116100 30.566856384277344
epoch 0 iteration 116200 31.202136993408203
epoch 0 iteration 116300 31.210840225219727
epoch 0 iteration 116400 30.528684616088867
epoch 0 iteration 116500 30.58077049255371
epoch 0 iteration 116600 30.700977325439453
epoch 0 iteration 116700 30.027191162109375
epoch 0 iteration 116800 30.449188232421875
epoch 0 iteration 116900 30.654565811157227
epoch 0 iteration 117000 30.960594177246094
epoch 0 iteration 117100 30.54000473022461
epoch 0 iteration 117200 29.96695899963379
epoch 0 iteration 117300 30.61639404296875
epoch 0 iteration 117400 30.572998046875
epoch 0 iteration 117500 30.526063919067383
epoch 0 iteration 117600 30.73918342590332
epoch 0 iteration 117700 30.959842681884766
epoch 0 iteration 117800 31.137895584106445
epoch 0 iteration 117900 30.72686386108398

epoch 1 iteration 15500 30.691686630249023
epoch 1 iteration 15600 30.660526275634766
epoch 1 iteration 15700 30.516407012939453
epoch 1 iteration 15800 30.598569869995117
epoch 1 iteration 15900 30.459779739379883
epoch 1 iteration 16000 30.54584503173828
epoch 1 iteration 16100 30.620586395263672
epoch 1 iteration 16200 31.110332489013672
epoch 1 iteration 16300 30.346738815307617
epoch 1 iteration 16400 30.346363067626953
epoch 1 iteration 16500 30.441162109375
epoch 1 iteration 16600 30.77596664428711
epoch 1 iteration 16700 30.506441116333008
epoch 1 iteration 16800 30.362079620361328
epoch 1 iteration 16900 30.28512191772461
epoch 1 iteration 17000 30.721858978271484
epoch 1 iteration 17100 30.419780731201172
epoch 1 iteration 17200 29.95083999633789
epoch 1 iteration 17300 30.846927642822266
epoch 1 iteration 17400 30.3343563079834
epoch 1 iteration 17500 30.721464157104492
epoch 1 iteration 17600 30.98794174194336
epoch 1 iteration 17700 30.546293258666992
epoch 1 iteration 178

epoch 1 iteration 34800 30.273649215698242
epoch 1 iteration 34900 30.85088348388672
epoch 1 iteration 35000 30.548595428466797
epoch 1 iteration 35100 30.697668075561523
epoch 1 iteration 35200 29.96678924560547
epoch 1 iteration 35300 30.598512649536133
epoch 1 iteration 35400 30.41854476928711
epoch 1 iteration 35500 30.663429260253906
epoch 1 iteration 35600 30.718162536621094
epoch 1 iteration 35700 30.74629020690918
epoch 1 iteration 35800 30.94312286376953
epoch 1 iteration 35900 30.772201538085938
epoch 1 iteration 36000 30.57705307006836
epoch 1 iteration 36100 29.932506561279297
epoch 1 iteration 36200 30.25347900390625
epoch 1 iteration 36300 30.224102020263672
epoch 1 iteration 36400 30.32636070251465
epoch 1 iteration 36500 30.79202651977539
epoch 1 iteration 36600 30.83405876159668
epoch 1 iteration 36700 30.941354751586914
epoch 1 iteration 36800 30.731544494628906
epoch 1 iteration 36900 30.432418823242188
epoch 1 iteration 37000 30.082162857055664
epoch 1 iteration 371

epoch 1 iteration 54100 30.679401397705078
epoch 1 iteration 54200 29.977523803710938
epoch 1 iteration 54300 30.19614028930664
epoch 1 iteration 54400 30.79931640625
epoch 1 iteration 54500 30.525028228759766
epoch 1 iteration 54600 30.587718963623047
epoch 1 iteration 54700 30.139379501342773
epoch 1 iteration 54800 30.284883499145508
epoch 1 iteration 54900 30.578536987304688
epoch 1 iteration 55000 30.146982192993164
epoch 1 iteration 55100 30.95645523071289
epoch 1 iteration 55200 30.547496795654297
epoch 1 iteration 55300 30.520761489868164
epoch 1 iteration 55400 30.312183380126953
epoch 1 iteration 55500 30.89084243774414
epoch 1 iteration 55600 29.842571258544922
epoch 1 iteration 55700 31.106464385986328
epoch 1 iteration 55800 30.435546875
epoch 1 iteration 55900 30.039281845092773
epoch 1 iteration 56000 30.13681411743164
epoch 1 iteration 56100 30.70728874206543
epoch 1 iteration 56200 30.588886260986328
epoch 1 iteration 56300 30.336261749267578
epoch 1 iteration 56400 30

epoch 1 iteration 75700 30.676776885986328
epoch 1 iteration 75800 30.405399322509766
epoch 1 iteration 75900 30.471755981445312
epoch 1 iteration 76000 30.297996520996094
epoch 1 iteration 76100 30.530921936035156
epoch 1 iteration 76200 30.477127075195312
epoch 1 iteration 76300 30.28673553466797
epoch 1 iteration 76400 30.314481735229492
epoch 1 iteration 76500 30.695531845092773
epoch 1 iteration 76600 30.20577621459961
epoch 1 iteration 76700 30.06511688232422
epoch 1 iteration 76800 30.604860305786133
epoch 1 iteration 76900 30.53544807434082
epoch 1 iteration 77000 29.977458953857422
epoch 1 iteration 77100 30.225908279418945
epoch 1 iteration 77200 29.954730987548828
epoch 1 iteration 77300 30.36444854736328
epoch 1 iteration 77400 30.576398849487305
epoch 1 iteration 77500 30.207794189453125
epoch 1 iteration 77600 30.48060417175293
epoch 1 iteration 77700 30.225421905517578
epoch 1 iteration 77800 30.443702697753906
epoch 1 iteration 77900 30.80235481262207
epoch 1 iteration 

epoch 1 iteration 95000 30.479736328125
epoch 1 iteration 95100 30.245372772216797
epoch 1 iteration 95200 30.21898651123047
epoch 1 iteration 95300 30.707273483276367
epoch 1 iteration 95400 30.35696792602539
epoch 1 iteration 95500 29.892318725585938
epoch 1 iteration 95600 30.292282104492188
epoch 1 iteration 95700 30.73837661743164
epoch 1 iteration 95800 30.345216751098633
epoch 1 iteration 95900 30.231830596923828
epoch 1 iteration 96000 30.928693771362305
epoch 1 iteration 96100 30.184249877929688
epoch 1 iteration 96200 30.063932418823242
epoch 1 iteration 96300 30.209854125976562
epoch 1 iteration 96400 29.79553985595703
epoch 1 iteration 96500 30.398820877075195
epoch 1 iteration 96600 30.50404167175293
epoch 1 iteration 96700 30.49762725830078
epoch 1 iteration 96800 30.507919311523438
epoch 1 iteration 96900 30.805419921875
epoch 1 iteration 97000 29.88398551940918
epoch 1 iteration 97100 30.358407974243164
epoch 1 iteration 97200 30.439685821533203
epoch 1 iteration 97300 

epoch 1 iteration 114000 30.010465621948242
epoch 1 iteration 114100 30.24431610107422
epoch 1 iteration 114200 30.412555694580078
epoch 1 iteration 114300 30.221567153930664
epoch 1 iteration 114400 30.525554656982422
epoch 1 iteration 114500 30.471879959106445
epoch 1 iteration 114600 30.350749969482422
epoch 1 iteration 114700 30.045495986938477
epoch 1 iteration 114800 29.996543884277344
epoch 1 iteration 114900 30.635860443115234
epoch 1 iteration 115000 30.08211898803711
epoch 1 iteration 115100 30.46127700805664
epoch 1 iteration 115200 29.834745407104492
epoch 1 iteration 115300 30.24089813232422
epoch 1 iteration 115400 30.071285247802734
epoch 1 iteration 115500 30.577749252319336
epoch 1 iteration 115600 30.71854019165039
epoch 1 iteration 115700 30.53062629699707
epoch 1 iteration 115800 30.590423583984375
epoch 1 iteration 115900 29.96701431274414
epoch 1 iteration 116000 30.693218231201172
epoch 1 iteration 116100 30.141908645629883
epoch 1 iteration 116200 29.62376785278