In [1]:
import nltk
import numpy as np

# Tokenizer models needed by sent_tokenize / word_tokenize.
# Newer NLTK releases (3.9 onward) load the Punkt sentence model from the
# pickle-free 'punkt_tab' resource, while older releases use 'punkt'; fetch
# both so a fresh kernel works with either version. A list (not a generator)
# is passed to all() so that both downloads are always attempted.
all([nltk.download(resource) for resource in ('punkt', 'punkt_tab')])

[nltk_data] Downloading package punkt to /Users/tippy/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [3]:
from nltk.tokenize import word_tokenize, sent_tokenize
from collections import defaultdict

In [5]:
class TheLittlePrinceDataset:
    """Sentence/token corpus read from the *The Little Prince* text file.

    Typical use: construct the dataset, call ``build_vocab`` and then
    ``convert_tokens_to_ids`` to get one list of token ids per sentence.
    """

    def __init__(self, tokenizer=True, path='datas/TheLittlePrince.txt'):
        """Read the corpus and optionally tokenize it.

        Args:
            tokenizer: if True, lower-case the text, split it into sentences
                (``self.sentences``) and each sentence into word tokens
                (``self.tokens``). If False, only the raw text is stored in
                ``self.text``. ``build_vocab`` and ``convert_tokens_to_ids``
                need ``self.tokens`` and cannot be used in that case.
            path: location of the corpus text file. The default is the path
                this notebook has always used.
        """
        # Decode explicitly instead of relying on the platform default
        # encoding (e.g. cp1252/gbk on Windows), which can fail on or
        # mis-decode non-ASCII characters.
        # NOTE(review): assumes the corpus file is UTF-8 encoded. Confirm this.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()

        if tokenizer:
            self.sentences = sent_tokenize(text.lower())
            self.tokens = [word_tokenize(sent) for sent in self.sentences]
        else:
            self.text = text

    def build_vocab(self, min_freq=1):
        """Count token frequencies and build the token <-> id mappings.

        Ids 0 and 1 are reserved for '<pad>' and '<unk>'. Kept tokens are
        numbered from 2 in order of decreasing frequency. Ties keep the order
        in which the tokens first appear in the corpus, because ``sorted`` is
        stable.

        Args:
            min_freq: frequency threshold. Only tokens that occur strictly
                MORE than ``min_freq`` times are kept, so the default of 1
                drops words that appear exactly once.
                NOTE(review): the name suggests an inclusive bound (``>=``).
                The strict comparison is kept so existing ids stay unchanged.

        Sets ``self.frequency``, ``self.word2idx`` and ``self.idx2word``.
        """
        frequency = defaultdict(int)
        for sent in self.tokens:
            for word in sent:
                frequency[word] += 1
        self.frequency = frequency

        self.word2idx = {'<unk>': 1, '<pad>': 0}
        self.idx2word = {1: '<unk>', 0: '<pad>'}

        # Tokens are visited in decreasing frequency order. Once one falls to
        # or below the threshold, every remaining token does too.
        for token, freq in sorted(frequency.items(), key=lambda x: -x[1]):
            if freq <= min_freq:
                break
            idx = len(self.word2idx)
            self.word2idx[token] = idx
            self.idx2word[idx] = token

    def convert_tokens_to_ids(self, drop_single_word=True):
        """Map every tokenized sentence to a list of vocabulary ids.

        Tokens missing from the vocabulary are mapped to the '<unk>' id.

        Args:
            drop_single_word: if True, skip sentences with fewer than two
                tokens. Such a sequence has no (center, context) pair, so no
                loss can be computed from it.

        Returns:
            The list of id sequences, also stored in ``self.token_ids``.
        """
        unk_id = self.word2idx['<unk>']
        self.token_ids = []
        for sent in self.tokens:
            token_ids = [self.word2idx.get(word, unk_id) for word in sent]

            if drop_single_word and len(token_ids) < 2:
                continue
            self.token_ids.append(token_ids)

        return self.token_ids
    
# Build the corpus pipeline: read and tokenize the text, build the vocabulary,
# then turn each sentence into a list of token ids. The keyword arguments
# spell out the defaults that are in effect.
dataset = TheLittlePrinceDataset(tokenizer=True)
dataset.build_vocab(min_freq=1)
sentences = dataset.convert_tokens_to_ids(drop_single_word=True)

In [6]:
# Skip-gram training pairs: every (center_id, context_id) pair whose positions
# lie at most `window_size` tokens apart within the same sentence.
# (numpy is already imported as `np` in the first cell.)
window_size = 2
data = []
for sent in sentences:
    n_tokens = len(sent)
    for i, center_word in enumerate(sent):
        # Clamp the window to the sentence bounds instead of generating
        # out-of-range positions and discarding them afterwards.
        start = max(0, i - window_size)
        stop = min(n_tokens, i + window_size + 1)
        for j in range(start, stop):
            if j != i:  # a word is not its own context
                data.append([center_word, sent[j]])

# A fixed integer dtype plus reshape(-1, 2) keeps the result a (num_pairs, 2)
# id array even when no pairs were produced. Without them, np.array([]) would
# be a (0,) float64 array.
data = np.array(data, dtype=np.int64).reshape(-1, 2)
data.shape, data

((77380, 2),
 array([[  4,  17],
        [  4,  20],
        [ 17,   4],
        ...,
        [127,   3],
        [  3,  84],
        [  3, 127]]))