## word2vec_skip-gram

In [5]:
'''Importing the required packages'''
import random
import collections
import math
import os
import zipfile
import time
import re
import numpy as np
import tensorflow as tf

from matplotlib import pylab
%matplotlib inline

from six.moves import range
from six.moves.urllib.request import urlretrieve

'''Make sure the dataset link is copied correctly'''
dataset_link = 'http://mattmahoney.net/dc/'
zip_file = 'text8.zip'

In [6]:
# 下载有Matt Mahoney 收集和清理的危机百科文章数据集，并将其存储为当前工作目录下的单独文件
def data_download(zip_file):
    """Download the required file"""
    if not os.path.exists(zip_file):
        zip_file, _ = urlretrieve(dataset_link + zip_file, zip_file)
        print('File downloaded successfully!')
    return None
data_download(zip_file)

In [7]:
# 压缩的文本数据集在内部文件夹数据集中提取，稍后将用于训练模型
"""Extracting the dataset in separate folder"""
extracted_folder = 'dataset'

if not os.path.isdir(extracted_folder):
    with zipfile.ZipFile(zip_file) as zf:
        zf.extractall(extracted_folder)
        
with open('dataset/text8') as ft_:
    full_text = ft_.read()

In [8]:
# 由于输入数据的文本中有多个标点符号和其他符号，相同的符号将被替换为带有标点符号名称和符号类型的相应字符
# 有助于让模型单独识别每个标点符号和其他符号并生成向量
def text_processing(ft8_text):
    """Replacing punctuation marks with tokens"""
    ft8_text = ft8_text.lower()
    ft8_text = ft8_text.replace('.', '<period>')
    ft8_text = ft8_text.replace(',', '<comma>')
    ft8_text = ft8_text.replace('"', '<quotation>')
    ft8_text = ft8_text.replace(';', '<semicolon>')
    ft8_text = ft8_text.replace('!', '<exclamation>')
    ft8_text = ft8_text.replace('?', '<question>')
    ft8_text = ft8_text.replace('(', '<paren_l>')
    ft8_text = ft8_text.replace(')', '<paren_r>')
    ft8_text = ft8_text.replace('--', '<hyphen>')
    ft8_text = ft8_text.replace(':', '<colon>')
    ft8_text_tokens = ft8_text.split()
    return ft8_text_tokens

ft_tokens = text_processing(full_text)

In [15]:
# 为了提高所产生的向量表示的质量，建议去除与单词相关的噪音，即输入数据集中词频小于7的单词，因为这些单词没有足够的信息来提供它们的上下文
# 可以通过检查单词数和数据集中的分布来调整此阈值，在此处设为7
"""Shortlisting words with frequency more than 7"""
word_cnt = collections.Counter(ft_tokens)
shortlisted_words = [w for w in ft_tokens if word_cnt[w] > 7]

# 列出数据集中词频最高的几个单词
print(shortlisted_words[:15])

['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including']


In [16]:
# 检查数据集中所有单词的统计信息
print("Total number of shortlisted words : ", len(shortlisted_words))
print("Unique number of shortlisted_words: ", len(set(shortlisted_words)))

Total number of shortlisted words :  16616688
Unique number of shortlisted_words:  53721


In [32]:
# 为了处理语料库中存在的独特单词，我们制作了一组单词和它们在训练数据集中的词频
# 创建一个字典并将单词转换为整数，反之，将整数转换为单词
# 词频最高的单词被赋予最小值0， 其他单词也通过相似方式被赋予数值，从单词转换而来的整数倍存储在一个单独的数组中
def dict_creation(shortlisted_words):
    """The function creates a dictionary of the words present in dataset along with their frequency order"""
    counts = collections.Counter(shortlisted_words)
    vocabulary = sorted(counts, key=counts.get, reverse=True)
    rev_dictionary_ = {ii: word for ii, word in enumerate(vocabulary)}
    # print(rev_dictionary_)
    dictionary_ = {word: ii for ii, word in rev_dictionary_.items()}
    # print(dictionary_)
    return dictionary_, rev_dictionary_

dictionary_, rev_dictionary_ = dict_creation(shortlisted_words)
words_cnt = [dictionary_[word] for word in shortlisted_words]

### skip-gram 模型采用子采样的方法来处理文本中的停止词
### 通过在词频上设置阈值，可以消除所有那些词频较高且中心词周围没有任何重要上下文的单词，这带来了更快的训练速度和更好的词向量表示

### skip-gram论文中给出的概率分数函数，对于训练集中的每个单词，我们将根据以下公式给定的概率来决定是否将其移除
$$ P(w_{i}) = 1-\left( \sqrt\frac{t}{f(w_{i})}\right)$$
### 其中, t是阈值参数，$f(w_{i})$是单词$w_i$在总数据集中的词频

In [38]:
import random
"""Creaing the threshold and performing the subsampling"""
thresh = 0.00005
word_counts = collections.Counter(word_cnt)
total_count = len(words_cnt)
freqs = {word: count/total_count for word, count in word_counts.items()}
p_drop = {word: 1-np.sqrt(thresh/freqs[word]) for word in word_counts}
train_words = [word for word in words_cnt if p_drop[word]<random.random()]

KeyError: 5233