## Negative Sampling
### 빈도가 작은 단어를 좀 더 많이 sampling하기 위한 기법

출처: http://dalpo0814.tistory.com/6 [deeep]

In [30]:
import torch 
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim 
import torch.nn.functional as F

import nltk
import random
import numpy as np
from collections import Counter

# Appending the item to a flat list.
flatten = lambda l: [item for sublist in l for item in sublist]
random.seed(1024)

In [31]:
print(torch.__version__)
print(nltk.__version__)

0.4.1
3.2.3


In [32]:
USE_CUDA = torch.cuda.is_available()

FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

In [50]:
def getBatch(batch_size, train_data):
    random.shuffle(train_data)
    sindex = 0
    eindex = batch_size
    while eindex < len(train_data):
        batch = train_data[sindex: eindex]
        temp = eindex
        eindex = eindex + batch_size
        sindex = temp
        yield batch
    
    if eindex >= len(train_data):
        batch = train_data[sindex:]
        yield batch

In [33]:
# 단어를 index값으로 변경해서 내뱉어줌 
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] 
                    if word2index.get(w) is not None 
                    else word2index["<UNK>"], seq))
    return Variable(LongTensor(idxs))

def prepare_word(word, word2index):
    return Variable(LongTensor([word2index[word]]) 
                    if word2index.get(word) is not None 
                    else LongTensor([word2index["<UNK>"]]))

## Data load and Preprocessing

In [34]:
corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]
corpus = [[word.lower() for word in sent] for sent in corpus]

### Exclude sparse words

In [35]:
word_count= Counter(flatten(corpus))

In [36]:
MIN_COUNT = 3
exclude =[]

In [37]:
for w,c in word_count.items():
    if c < MIN_COUNT:
        exclude.append(w)

### Prepare train data

In [38]:
vocab = list(set(flatten(corpus)) - set(exclude))

In [39]:
word2index = {}

# word2index - word : index 형태 
for vo in vocab:
    # get(vo): vo라는 key에 대응되는 value를 돌려줌. 
    if word2index.get(vo) is None:  
        word2index[vo] = len(word2index)

# index2word - index : word 형태로 변경
index2word = {v:k for k, v in word2index.items()}

In [51]:
WINDOW_SIZE = 5

# window의 길이를 지정 
windows = flatten([list(nltk.ngrams(['<DUMMY>'] * WINDOW_SIZE + c + ['<DUMMY>'] * WINDOW_SIZE, 
                                  WINDOW_SIZE * 2 + 1)) for c in corpus])
train_data = []
for window in windows:
    for i in range(WINDOW_SIZE * 2 + 1):
        if window[i] in exclude or window[WINDOW_SIZE] in exclude:
            continue  # for문의 처음으로 
        if i == WINDOW_SIZE or window[i] == '<DUMMY>':
            continue
        train_data.append((window[WINDOW_SIZE], window[i]))
        
X_p = []
y_p = []

for tr in train_data:
    X_p.append(prepare_word(tr[0], word2index).view(1, -1))
    y_p.append(prepare_word(tr[1], word2index).view(1, -1))
    
train_data = list(zip(X_p, y_p))

In [52]:
len(train_data)

50242

### Build Unigram Distribution**0.75

\begin{equation*}
P(w) = U(w)^{3/4} /Z
\end{equation*}

In [42]:
Z = 0.001 

In [43]:
word_count = Counter(flatten(corpus))
num_total_words = sum([c for w,c in word_count.items() 
                       if w not in exclude])

In [44]:
unigram_table = []

# Unigram_table: (단어가 나온 횟수/전체 단어 개수)^0.75 / Z
for vo in vocab:
    unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))
    

In [45]:
print(len(vocab), len(unigram_table))

478 3500


### Negative Sampling

In [46]:
def negative_sampling(targets, unigram_table, k):
    batch_size = targets.size(0)
    neg_samples = []
    for i in range(batch_size):
        nsample = []
        target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]
        
        while len(nsample) < k:  # num of sampling
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:  
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))
        
    return torch.cat(neg_samples)

## Modeling 

![](https://nbviewer.jupyter.org/github/DSKSD/DeepNLP-models-Pytorch/blob/master/images/02.skipgram-objective.png)

borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf

### Weight Initialization

###  * _Xavier initialization_
![](https://render.githubusercontent.com/render/math?math=Var%28W%29%20%3D%20%5Cfrac%7B2%7D%7Bn_%7Bin%7D%20%2B%20n_%7Bout%7D%7D&mode=display)

![](https://render.githubusercontent.com/render/math?math=High%28W%29%20%3D%20%5Csqrt%7B%5Cfrac%7B6%7D%7Bn_%7Bin%7D%20%2B%20n_%7Bout%7D%7D%7D%5C%2C%2C%5C%2C%20Low%28W%29%20%3D%20-%5Csqrt%7B%5Cfrac%7B6%7D%7Bn_%7Bin%7D%20%2B%20n_%7Bout%7D%7D%7D&mode=display)

In [47]:
class SkipgramNegSampling(nn.Module):
    
    def __init__(self, vocab_size, projection_dim):
        super(SkipgramNegSampling, self).__init__()
        # embedding matrix 생성 
        self.embedding_v = nn.Embedding(vocab_size, projection_dim)  # center word embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim)  # out word embedding 
        self.logsigmoid = nn.LogSigmoid()
        
        # Xavier init
        initrange = (2.0 / (vocab_size + projection_dim))**0.5
        # init 
        self.embedding_v.weight.data.uniform_(-initrange, initrange)
        self.embedding_u.weight.data.uniform_(-0.0, 0.0)
        
    def forward(self, center_words, target_words, negative_words):
        center_embeds = self.embedding_v(center_words) 
        target_embeds = self.embedding_u(target_words)
        neg_embeds = -self.embedding_u(negative_words)
        
        # bmm = Batch matrix multiplication
        positive_score = target_embeds.bmm(center_embeds.transpose(1,2)).squeeze(2)
        negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(negs.size(0), -1)
        
        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)
        # negative log likelihood
        return -torch.mean(loss) 
    
    def prediction(self, inputs):
        embeds = self.embedding_v(inputs)
        return embeds

## Train

In [53]:
EMBEDDING_SIZE = 30
BATCH_SIZE = 256
EPOCH = 100
NEG = 10  # num of negative sampling

losses = []
model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)
optimizer = optim.Adam(model.parameters(), lr=0.001)

if USE_CUDA:
    model = model.cuda()

In [55]:
for epoch in range(EPOCH):
    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):
        
        inputs, targets = zip(*batch)
        
        inputs = torch.cat(inputs) # B x 1
        targets = torch.cat(targets) # B x 1
        negs = negative_sampling(targets, unigram_table, NEG)
        model.zero_grad()

        loss = model(inputs, targets, negs)
        
        loss.backward()
        optimizer.step()
    
        losses.append(loss.data.item())
        
    if epoch % 10 == 0:
        print("Epoch : %d, mean_loss : %.02f" % (epoch, np.mean(losses)))
        losses = []

Epoch : 0, mean_loss : 1.06
Epoch : 10, mean_loss : 0.85
Epoch : 20, mean_loss : 0.79
Epoch : 30, mean_loss : 0.74
Epoch : 40, mean_loss : 0.71
Epoch : 50, mean_loss : 0.69
Epoch : 60, mean_loss : 0.67
Epoch : 70, mean_loss : 0.65
Epoch : 80, mean_loss : 0.64
Epoch : 90, mean_loss : 0.63


### Test

In [59]:
def word_similarity(target, vocab):
    target_V = model.prediction(prepare_word(target, word2index))
    similarities = []
    
    for i in range(len(vocab)):
        if vocab[i] == target:
            continue
            
        vector = model.prediction(prepare_word(list(vocab)[i], word2index))
        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]
        similarities.append([vocab[i], cosine_sim])
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]

In [62]:
test = random.choice(list(vocab))
test

'seas'

In [63]:
word_similarity(test, vocab)

[['those', 0.7124468684196472],
 ['because', 0.6647225618362427],
 ['island', 0.6219084858894348],
 ['strong', 0.6061983704566956],
 ['boats', 0.5999205708503723],
 ['red', 0.5541212558746338],
 ['order', 0.5343049764633179],
 ['seen', 0.5275437235832214],
 ['many', 0.5269707441329956],
 ['hand', 0.5078719258308411]]