네거티브 샘플링을 이용한 Word2Vec 구현(Skip-Gram with Negative Sampling, SGNS) [Reference](https://wikidocs.net/69141)

In [1]:
import pandas as pd
import numpy as np
import nltk
from nltk.corpus import stopwords
from sklearn.datasets import fetch_20newsgroups
from tensorflow.keras.preprocessing.text import Tokenizer

In [2]:
# Fetch the 20 Newsgroups corpus with headers, footers and quoted replies
# stripped, shuffled with a fixed seed so the sample order is reproducible.
dataset = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=('headers', 'footers', 'quotes'),
)
documents = dataset.data
print('총 샘플 수 :', len(documents))

총 샘플 수 : 11314


In [3]:
news_df = pd.DataFrame({'document': documents})
# Replace every non-alphabetic character with a space.
# regex=True is passed explicitly: the default of Series.str.replace changed
# to regex=False in pandas 2.0, under which this pattern would be matched as
# a literal string and no character would be removed.
news_df['clean_doc'] = news_df['document'].str.replace(r"[^a-zA-Z]", " ", regex=True)
# Drop short words (three characters or fewer).
news_df['clean_doc'] = news_df['clean_doc'].apply(
    lambda x: ' '.join([w for w in x.split() if len(w) > 3]))
# Lower-case every remaining word.
news_df['clean_doc'] = news_df['clean_doc'].apply(lambda x: x.lower())

  news_df['clean_doc'] = news_df['document'].str.replace("[^a-zA-Z]", " ")


In [4]:
# Report whether any cell in the DataFrame is null/NaN.
news_df.isnull().values.any()

False

In [5]:
# Empty strings are not treated as null by isnull(), so convert them to NaN
# first and then repeat the null check.
news_df.replace("", np.nan, inplace=True)
news_df.isna().to_numpy().any()

True

In [6]:
# Discard every row that contains a NaN, then report how many rows remain.
news_df.dropna(inplace=True)
print('총 샘플 수 :', news_df.shape[0])

총 샘플 수 : 10995


In [7]:
# 불용어를 제거
stop_words = stopwords.words('english')
tokenized_doc = news_df['clean_doc'].apply(lambda x: x.split())
tokenized_doc = tokenized_doc.apply(lambda x: [item for item in x if item not in stop_words])
tokenized_doc = tokenized_doc.to_list()

In [8]:
# 단어가 1개 이하인 샘플의 인덱스를 찾아서 저장하고, 해당 샘플들은 제거.
drop_train = [index for index, sentence in enumerate(tokenized_doc) if len(sentence) <= 1]
tokenized_doc = np.delete(tokenized_doc, drop_train, axis=0)
print('총 샘플 수 :',len(tokenized_doc))

총 샘플 수 : 10940


  arr = asarray(arr)


In [9]:
# Build the vocabulary from the tokenized documents.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(tokenized_doc)

# Forward (word -> index) and reverse (index -> word) lookup tables.
word2idx = tokenizer.word_index
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

# Integer-encode every document using the fitted vocabulary.
encoded = tokenizer.texts_to_sequences(tokenized_doc)

In [10]:
# Show the integer-encoded form of the first two documents.
print(encoded[:2])

[[9, 59, 603, 207, 3278, 1495, 474, 702, 9470, 13686, 5533, 15227, 702, 442, 702, 70, 1148, 1095, 1036, 20294, 984, 705, 4294, 702, 217, 207, 1979, 15228, 13686, 4865, 4520, 87, 1530, 6, 52, 149, 581, 661, 4406, 4988, 4866, 1920, 755, 10668, 1102, 7837, 442, 957, 10669, 634, 51, 228, 2669, 4989, 178, 66, 222, 4521, 6066, 68, 4295], [1026, 532, 2, 60, 98, 582, 107, 800, 23, 79, 4522, 333, 7838, 864, 421, 3825, 458, 6488, 458, 2700, 4730, 333, 23, 9, 4731, 7262, 186, 310, 146, 170, 642, 1260, 107, 33568, 13, 985, 33569, 33570, 9471, 11491]]


In [11]:
# One more than the number of distinct words: the Keras Tokenizer assigns
# word indices starting at 1, so the extra slot covers index 0.
vocab_size = 1 + len(word2idx)
print('단어 집합의 크기 :', vocab_size)

단어 집합의 크기 : 64277


In [12]:
from tensorflow.keras.preprocessing.sequence import skipgrams

# Negative sampling: build skip-gram pairs and their labels, restricted to
# the first 10 documents so the result can be inspected quickly.
skip_grams = [
    skipgrams(sample, vocabulary_size=vocab_size, window_size=10)
    for sample in encoded[:10]
]

In [13]:
# 첫번째 샘플인 skip_grams[0] 내 skipgrams로 형성된 데이터셋 확인
pairs, labels = skip_grams[0][0], skip_grams[0][1]
for i in range(5):
    print("({:s} ({:d}), {:s} ({:d})) -> {:d}".format(
          idx2word[pairs[i][0]], pairs[i][0], 
          idx2word[pairs[i][1]], pairs[i][1], 
          labels[i]))

(media (702), ignore (1979)) -> 1
(ruin (9470), ejected (10778)) -> 0
(clearly (661), emailing (14645)) -> 0
(media (702), israels (13686)) -> 1
(israeli (442), biased (3278)) -> 1


In [14]:
# Number of documents for which skip-gram data was generated.
print('전체 샘플 수 :',len(skip_grams))

전체 샘플 수 : 10


In [15]:
# 첫번째 뉴스그룹 샘플에 대해서 생긴 pairs와 labels의 개수
print(len(pairs))
print(len(labels))

2220
2220


In [16]:
# Regenerate the skip-gram pairs and labels, this time for every document
# (replaces the 10-document preview bound to the same name).
skip_grams = [skipgrams(sample, vocabulary_size=vocab_size, window_size=10) for sample in encoded]

In [17]:
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Embedding, Reshape, Activation, Input
from tensorflow.keras.layers import Dot
from tensorflow.keras.utils import plot_model
from IPython.display import SVG

In [18]:
embedding_dim = 100

# Embedding table for the center words: one integer index in,
# one embedding_dim-sized vector out.
w_inputs = Input(shape=(1,), dtype='int32')
word_embedding_layer = Embedding(vocab_size, embedding_dim)
word_embedding = word_embedding_layer(w_inputs)

# Separate embedding table for the context words, same dimensions.
c_inputs = Input(shape=(1,), dtype='int32')
context_embedding_layer = Embedding(vocab_size, embedding_dim)
context_embedding = context_embedding_layer(c_inputs)

In [19]:
# Score each (center, context) pair by the dot product of its two embeddings,
# taken along the embedding axis (axis 2).
dot_product = Dot(axes=2)([word_embedding, context_embedding])
# Reshape the score to a single value per pair.
dot_product = Reshape((1,), input_shape=(1, 1))(dot_product)
# Squash the score into (0, 1) so it can be trained against the 0/1 labels.
output = Activation('sigmoid')(dot_product)

# Two-input model: center-word index and context-word index -> probability.
model = Model(inputs=[w_inputs, c_inputs], outputs=output)
model.summary()
model.compile(loss='binary_crossentropy', optimizer='adam')
# Write a diagram of the model graph to model3.png.
plot_model(model, to_file='model3.png', show_shapes=True, show_layer_names=True, rankdir='TB')

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 1)]          0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 1)]          0           []                               
                                                                                                  
 embedding (Embedding)          (None, 1, 100)       6427700     ['input_1[0][0]']                
                                                                                                  
 embedding_1 (Embedding)        (None, 1, 100)       6427700     ['input_2[0][0]']                
                                                                                              

In [20]:
# Train for five epochs. Each document's skip-gram pairs form one batch, and
# the per-batch losses are summed into a per-epoch total.
for epoch in range(1, 6):
    loss = 0
    for doc_pairs, doc_labels in skip_grams:
        # A document that produced no pairs has nothing to train on, and
        # column-indexing an empty pair list would fail.
        if not doc_pairs:
            continue
        # Convert the pair list once and slice out the two columns, instead
        # of transposing it with zip() separately for each column.
        pair_array = np.array(doc_pairs, dtype='int32')
        first_elem = pair_array[:, 0]
        second_elem = pair_array[:, 1]
        labels = np.array(doc_labels, dtype='int32')
        X = [first_elem, second_elem]
        Y = labels
        loss += model.train_on_batch(X, Y)
    print('Epoch :',epoch, 'Loss :',loss)

Epoch : 1 Loss : 4623.914303861558
Epoch : 2 Loss : 3660.281063683331
Epoch : 3 Loss : 3499.7574602030218
Epoch : 4 Loss : 3299.0960647948086
Epoch : 5 Loss : 3076.197873252444


In [21]:
import gensim

# First weight array of the model.
# NOTE(review): assumed to be the center-word embedding table (the first
# Embedding layer listed in the model summary) - confirm.
vectors = model.get_weights()[0]
# Take the dimensionality from the weights themselves so the file header
# always matches the rows written below.
embed_size = vectors.shape[1]

# Write the vectors in word2vec text format: a "<count> <dim>" header followed
# by one "<word> <v1> <v2> ..." line per word. The context manager closes the
# file even if a write fails, and the explicit encoding matches the UTF-8
# default that gensim uses when reading the file back.
with open('vectors.txt', 'w', encoding='utf-8') as f:
    f.write('{} {}\n'.format(len(tokenizer.word_index), embed_size))
    for word, i in tokenizer.word_index.items():
        f.write('{} {}\n'.format(word, ' '.join(map(str, vectors[i, :]))))

# Load the exported vectors.
w2v = gensim.models.KeyedVectors.load_word2vec_format('./vectors.txt', binary=False)

In [22]:
# Query the loaded vectors for the words most similar to 'soldiers'.
w2v.most_similar(positive=['soldiers'])

[('wounded', 0.8784220814704895),
 ('slaughtered', 0.8647086024284363),
 ('villages', 0.8550965189933777),
 ('moslem', 0.8526031970977783),
 ('murdered', 0.8524329662322998),
 ('azerbaijani', 0.8472632765769958),
 ('moslems', 0.8447192907333374),
 ('enclave', 0.8427597284317017),
 ('fighting', 0.8418922424316406),
 ('seized', 0.8408921360969543)]

In [23]:
# Query the loaded vectors for the words most similar to 'doctor'.
w2v.most_similar(positive=['doctor'])

[('pain', 0.7204384803771973),
 ('vaginal', 0.7079140543937683),
 ('therapy', 0.6889781355857849),
 ('physician', 0.685503363609314),
 ('sugar', 0.6791753768920898),
 ('patient', 0.677403450012207),
 ('illness', 0.6750391721725464),
 ('bacteria', 0.6703062057495117),
 ('nuts', 0.665992259979248),
 ('prescription', 0.6636449098587036)]

In [24]:
# Query the loaded vectors for the words most similar to 'police'.
w2v.most_similar(positive=['police'])

[('officers', 0.7047038674354553),
 ('democratic', 0.6865219473838806),
 ('minorities', 0.6780165433883667),
 ('patriotism', 0.6680333614349365),
 ('constitution', 0.6674482822418213),
 ('defamation', 0.6622008085250854),
 ('democracy', 0.6578359603881836),
 ('meetings', 0.6574252247810364),
 ('escaped', 0.6563798189163208),
 ('damages', 0.6549358367919922)]

In [25]:
# Query the loaded vectors for the words most similar to 'knife'.
w2v.most_similar(positive=['knife'])

[('homicides', 0.7911495566368103),
 ('practicing', 0.7750158905982971),
 ('humanity', 0.771486222743988),
 ('bosnia', 0.769193708896637),
 ('threat', 0.7680944800376892),
 ('democracy', 0.7670363187789917),
 ('questioned', 0.7663816213607788),
 ('raised', 0.7609885931015015),
 ('witness', 0.7584652304649353),
 ('persecution', 0.7578152418136597)]

In [26]:
# Query the loaded vectors for the words most similar to 'engine'.
w2v.most_similar(positive=['engine'])

[('cylinder', 0.6833519339561462),
 ('rear', 0.6611299514770508),
 ('racing', 0.6588989496231079),
 ('fairing', 0.6584413051605225),
 ('tires', 0.6199708580970764),
 ('compartment', 0.6195544600486755),
 ('seat', 0.6151350140571594),
 ('pedals', 0.6054973006248474),
 ('fork', 0.6045674681663513),
 ('inline', 0.6004648208618164)]