In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter

In [2]:
flatten = lambda l:[item for sublist in l for item in sublist]
random.seed(1024)

In [3]:
USE_CUDA = torch.cuda.is_available()
gpus=[0]
torch.cuda.set_device(gpus[0])

FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

In [4]:
def getBatch(batch_size, train_data):
    random.shuffle(train_data)
    sindex=0
    eindex = batch_size
    
    while eindex < len(train_data):
        batch = train_data[sindex:eindex]
        temp = eindex
        eindex = eindex + batch_size
        sindex = temp
        yield batch
    
    if eindex >= len(train_data):
        batch = train_data[sindex:]
        yield batch

### Data load and Preprocessing

In [7]:
corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]
corpus = [[word.lower() for word in sent] for sent in corpus] 
# 소문자 변형 

### Exclude sparse words

In [47]:
word_count = Counter(flatten(corpus))
print(len(word_count))

2607


In [48]:
MIN_COUNT = 3
exclude = []

In [49]:
for w, c in word_count.items(): #word_count dict에서 key, value값
    if c < MIN_COUNT:     
        exclude.append(w)      #value값이 3개(MIN_COUNT) 미만인 word_count의 key값 리스트
print(exclude[:5])
print(len(exclude))

['[', 'moby', 'dick', 'herman', 'melville']
2129


### Prepare train data

In [50]:
vocab = list(set(flatten(corpus))-set(exclude))
# 단어 빈도 수가 3개 이상인 단어만 추출(3개 미만 제거)
print(vocab[:5])

['quantity', 'fixed', 'soul', 'gathered', 'world']


In [51]:
word2index = {}
for vo in vocab:    #vocab 단어들로 word2index 단어사전 만들기(word:index)
    if word2index.get(vo) is None:
        word2index[vo] = len(word2index)
        
index2word = {v:k for k,v in word2index.items()}  # (index:word)

In [52]:
WINDOW_SIZE = 5
windows = flatten([list(nltk.ngrams(['<DUMMY>']*WINDOW_SIZE + c + ['<DUMMY>']*WINDOW_SIZE,
                                   WINDOW_SIZE*2+1)) for c in corpus])

In [53]:
windows[0]

('<DUMMY>',
 '<DUMMY>',
 '<DUMMY>',
 '<DUMMY>',
 '<DUMMY>',
 '[',
 'moby',
 'dick',
 'by',
 'herman',
 'melville')

In [55]:
train_data = []

for window in windows:
    for i in range(WINDOW_SIZE*2+1): # 양쪽 window_size + 중심단어
        if window[i] in exclude or window[WINDOW_SIZE] in exclude:
            continue  #min_count
        if i == WINDOW_SIZE or window[i] == '<DUMMY>':
            continue #중심단어 이거나 dummy일 경우 
        train_data.append((window[WINDOW_SIZE], window[i])) #(중심단어, 주변단어)

In [57]:
print(train_data[:5])

[('(', 'supplied'), ('(', 'by'), ('(', 'a'), ('(', 'late'), ('supplied', '(')]


In [58]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w:word2index[w] 
                    if word2index.get(w) is not None else word2index['<UNK>'],seq))
    return Variable(LongTensor(idxs))

In [59]:
def prepare_word(word, word2index):
    return Variable(LongTensor([word2index[word]]) 
                   if word2index.get(word) is not None else LongTensor([word2index['<UNK>']]))

In [60]:
X_p=[]
y_p=[]

In [61]:
for tr in train_data:
    X_p.append(prepare_word(tr[0], word2index).view(1,-1))
    y_p.append(prepare_word(tr[1], word2index).view(1,-1))

In [68]:
print(train_data[0])
print(X_p[0], "'",train_data[0][0],"'")
print(y_p[0], "'",train_data[0][1],"'")

('(', 'supplied')
Variable containing:
 277
[torch.cuda.LongTensor of size 1x1 (GPU 0)]
 ' ( '
Variable containing:
 32
[torch.cuda.LongTensor of size 1x1 (GPU 0)]
 ' supplied '


In [72]:
train_data = list(zip(X_p, y_p))
print(len(train_data))

50242


### Build Unigram Distribution**0.75

In [73]:
Z = 0.001

In [82]:
word_count = Counter(flatten(corpus))
num_total_words = sum([c for w,c in word_count.items() if w not in exclude])
# 단어 빈도 수가 3개 미만이 아닌 단어들의 빈도 합
print(num_total_words)

7798


In [97]:
unigram_table = []

for vo in vocab:
    unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))
#(단어 빈도수/전체 단어개수)^0.75/0.001
print(len(vocab), len(unigram_table))

478 3500


### Negative Sampling

In [98]:
def negative_sampling(targets, unigram_table, k):
    batch_size = targets.size(0)
    neg_samples = []
    for i in range(batch_size):
        nsample = []
        target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]
        while len(nsample) < k: #num of sampling
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index: #랜덤으로 뽑은 단어 타겟단어 같으면
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).view(1,-1))
        
    return torch.cat(neg_samples)

### Modeling

In [99]:
class SkipgramNegSampling(nn.Module):
    
    def __init__(self, vocab_size, projection_dim):
        super(SkipgramNegSampling, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding
        self.logsigmoid = nn.LogSigmoid()
        
        initrange = (2.0 / (vocab_size + projection_dim))**0.5 #Xaver init
        self.embedding_v.weight.data.uniform_(-initrange, initrange) #init
        self.embedding_u.weight.data.uniform_(-0.0,0.0) #init
        
    def forward(self, center_words, target_words, negative_words):
        center_embeds = self.embedding_v(center_words) #B x 1 x D
        target_embeds = self. embedding_u(target_words) #B x 1 x D
        neg_embeds = -self.embedding_u(negative_words) #B x K x D
        
        positive_score = target_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2)
        negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1,2)).
                                   squeeze(2),1).view(negs.size(0),-1) # B x K -> B x 1
        
        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)
        
        return -torch.mean(loss)
    
    def prediction(self, inputs):
        embeds = self.embedding_v(inputs)
        
        return embeds

### Train

In [100]:
EMBEDDING_SIZE = 30
BATCH_SIZE = 256
EPOCH = 100
NEG = 10 #num of negative sampling

In [91]:
losses = []
model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [93]:
for epoch in range(EPOCH):
    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):
        
        inputs, targets = zip(*batch)
        
        inputs = torch.cat(inputs) # B x 1
        targets = torch.cat(targets) # B X 1
        negs = negative_sampling(targets, unigram_table, NEG)
        model.zero_grad()
        
        loss = model(inputs, targets, negs)
        
        loss.backward()
        optimizer.step()
        
        losses.append(loss.data.tolist()[0])
    if epoch % 10 == 0:
        print('Epoch: %d, mean_loss: %0.2f' %(epoch, np.mean(losses)))
        losses=[]

Epoch: 0, mean_loss: 1.06
Epoch: 10, mean_loss: 0.86
Epoch: 20, mean_loss: 0.79
Epoch: 30, mean_loss: 0.74
Epoch: 40, mean_loss: 0.71
Epoch: 50, mean_loss: 0.69
Epoch: 60, mean_loss: 0.67
Epoch: 70, mean_loss: 0.65
Epoch: 80, mean_loss: 0.64
Epoch: 90, mean_loss: 0.63


### Test

In [94]:
def word_similarity(target, vocab):
    target_V = model.prediction(prepare_word(target, word2index))
    similarities = []
    for i in range(len(vocab)):
        if vocab[i] == target:
            continue
            
        vector = model.prediction(prepare_word(list(vocab)[i], word2index))
        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]
        similarities.append([vocab[i], cosine_sim])
    return sorted(similarities, key=lambda x:x[1], reverse=True)[:10] #상위 10개 

In [95]:
test = random.choice(list(vocab))
test

'may'

In [96]:
word_similarity(test, vocab)

[['seen', 0.6772878170013428],
 ['however', 0.6718274354934692],
 ['order', 0.6715513467788696],
 ['must', 0.5928581357002258],
 ['passengers', 0.5621551871299744],
 ['passage', 0.5597257018089294],
 ['rather', 0.5570704340934753],
 ['themselves', 0.5388443470001221],
 ['never', 0.5291361808776855],
 ['nigh', 0.5260406732559204]]