In [3]:
import torch
import torch.nn as nn
import torchviz
import sys; sys.path.insert(0, '../')
from exp import nb_d2l_utils

In [4]:
### For EN Text

In [5]:
import collections
class Vocab(object):  # This class is saved in d2l.
    """Two-way mapping between tokens and integer indices.

    Indices are assigned to the reserved tokens first, then to corpus tokens
    in order of descending frequency (ties broken by the tokens' own sort
    order).

    Args:
        tokens: iterable of hashable, mutually comparable tokens (a string
            is treated as a sequence of characters).
        min_freq: corpus tokens occurring fewer than ``min_freq`` times are
            left out of the vocabulary and therefore map to ``self.unk``.
        use_special_tokens: if True, reserve ``<pad>``, ``<bos>``, ``<eos>``
            and ``<unk>`` at indices 0-3 (exposed as ``self.pad``,
            ``self.bos``, ``self.eos``, ``self.unk``); otherwise only
            ``<unk>`` is reserved at index 0.

    Attributes:
        idx_to_token: list where position ``i`` holds the token with index ``i``.
        token_to_idx: dict mapping each token to its index.
    """

    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        counter = collections.Counter(tokens)
        # Descending frequency, ties resolved by ascending token order.
        token_freqs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
        if use_special_tokens:
            # padding, begin of sentence, end of sentence, unknown
            self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3)
            reserved = ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            reserved = ['<unk>']
        # A corpus that literally contains a reserved token must not add it a
        # second time: that would inflate len(self) and make
        # token_to_idx['<unk>'] disagree with self.unk.
        reserved_set = set(reserved)
        self.idx_to_token = reserved + [
            token for token, freq in token_freqs
            if freq >= min_freq and token not in reserved_set
        ]
        self.token_to_idx = {
            token: idx for idx, token in enumerate(self.idx_to_token)
        }

    def __len__(self):
        """Return the vocabulary size, reserved tokens included."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token, or a (possibly nested) list/tuple of tokens, to indices.

        Tokens that are not in the vocabulary map to ``self.unk``.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        else:
            return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a flat list/tuple of indices, back to tokens."""
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        else:
            return [self.idx_to_token[index] for index in indices]

In [6]:
def en_pre_process(raw_text):
    """Encode English text as a list of character indices.

    The text is lower-cased and every run of whitespace (newlines included)
    is collapsed to a single space before a character-level ``Vocab`` is
    built from it.

    Returns:
        (indices, vocab): the index of every character of the normalised
        text, and the ``Vocab`` used to produce them.
    """
    # str.split() with no argument splits on any whitespace run, so newlines
    # need no separate handling.
    normalized = ' '.join(raw_text.lower().split())
    char_vocab = Vocab(normalized)
    encoded = [char_vocab[ch] for ch in normalized]
    return encoded, char_vocab

In [7]:
# Consecutive (sequential) sampling
import random
def data_iter_consecutive(corpus_indices, batch_size, num_steps, device=None):
    """Yield (X, Y) mini-batches by consecutive sampling over a token sequence.

    The corpus is cut at a random offset in ``[0, num_steps)``, trimmed so it
    divides evenly into ``batch_size`` rows, and reshaped to
    ``(batch_size, row_len)``. Successive batches are adjacent windows of
    ``num_steps`` columns; ``Y`` is ``X`` shifted one position to the right.

    Args:
        corpus_indices: sequence of integer token indices (must support
            ``len`` and slicing).
        batch_size: number of rows per batch; must be positive.
        num_steps: number of columns (time steps) per batch; must be positive.
        device: optional torch device on which the data tensor is created.

    Yields:
        Tuples ``(X, Y)`` of float32 tensors of shape
        ``(batch_size, num_steps)``. float32 is kept for backward
        compatibility; cast with ``.long()`` where integer indices are needed.
        Nothing is yielded when the corpus is too short for a single batch.

    Raises:
        ValueError: if ``batch_size`` or ``num_steps`` is not positive.
    """
    if batch_size <= 0 or num_steps <= 0:
        raise ValueError(
            'batch_size and num_steps must be positive, got '
            'batch_size=%r, num_steps=%r' % (batch_size, num_steps))

    # Random starting offset so different passes see differently aligned windows.
    offset = int(random.uniform(0, num_steps))
    # Drop trailing tokens so the remainder splits evenly into batch_size rows.
    num_indices = ((len(corpus_indices) - offset) // batch_size) * batch_size
    row_len = num_indices // batch_size
    # One token per row is held back because targets are shifted by 1.
    num_batches = (row_len - 1) // num_steps
    if num_batches <= 0:
        # Too little data for even one (X, Y) pair; reshaping an empty tensor
        # with -1 would raise, so stop cleanly instead.
        return

    # torch.tensor (not the legacy torch.Tensor constructor, which rejects
    # non-CPU devices) so that `device` is actually honoured.
    indices = torch.tensor(
        corpus_indices[offset:(offset + num_indices)],
        dtype=torch.float32, device=device)
    indices = indices.reshape((batch_size, -1))

    for i in range(0, num_batches * num_steps, num_steps):
        X = indices[:, i:(i + num_steps)]
        Y = indices[:, (i + 1):(i + 1 + num_steps)]
        yield X, Y

In [10]:
# Number of time steps per sample. Defined here so the cell runs on a fresh
# kernel; 10 matches the 10-column batches in the saved output of the next cell.
num_steps = 10

# Explicit encoding so decoding does not depend on the platform locale.
# NOTE(review): assumes the corpus file is UTF-8 (or plain ASCII) - confirm.
with open('./data/timemachine.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

# Character-level encoding of the text, then an iterator of 3-row batches.
corpus_indices, vocab = en_pre_process(raw_text)
dataloader = data_iter_consecutive(corpus_indices, 3, num_steps)

In [27]:
# Peek at the first mini-batch only; nothing is printed if the iterator is empty.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    X, Y = first_batch
    print(X)
    print(Y)
    # Decode every row of X back into tokens to eyeball the encoding.
    for row in X:
        print(vocab.to_tokens([int(idx) for idx in row]))

tensor([[24.,  2., 12., 12.,  2., 10.,  1., 38., 16.,  7.],
        [24.,  5.,  3.,  4., 12.,  5.,  3., 19.,  1.,  3.],
        [15.,  5.,  2.,  8., 21.,  1.,  4.,  6., 11.,  1.]])
tensor([[ 2., 12., 12.,  2., 10.,  1., 38., 16.,  7., 10.],
        [ 5.,  3.,  4., 12.,  5.,  3., 19.,  1.,  3.,  7.],
        [ 5.,  2.,  8., 21.,  1.,  4.,  6., 11.,  1., 16.]])
['v', 'e', 'l', 'l', 'e', 'r', ' ', '(', 'f', 'o']
['v', 'i', 't', 'a', 'l', 'i', 't', 'y', ' ', 't']
['c', 'i', 'e', 's', ',', ' ', 'a', 'n', 'd', ' ']


In [28]:
### For CN Text

In [44]:
def cn_pre_process(raw_text):
    """Encode Chinese text as a list of character indices.

    Newlines are replaced by single spaces; no other normalisation is done.
    Characters occurring fewer than 5 times are excluded from the vocabulary
    and are therefore encoded as the unknown index.

    Returns:
        (indices, vocab): the index of every character of the flattened
        text, and the ``Vocab`` used to produce them.
    """
    # Splitting on '\n' and re-joining with ' ' swaps each newline for a space.
    flattened = ' '.join(raw_text.split('\n'))
    char_vocab = Vocab(flattened, min_freq=5)
    encoded = [char_vocab[ch] for ch in flattened]
    return encoded, char_vocab

In [45]:
# Chinese lyrics corpus. The encoding is given explicitly because the
# locale-default codec (e.g. on Windows) may fail on or mis-decode Chinese text.
# NOTE(review): assumes the file is stored as UTF-8 - confirm.
with open('./data/jaychou_lyrics.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

In [46]:
# Character-level encoding of the lyrics, then an iterator of 3-row batches.
# NOTE(review): `num_steps` must already be defined by an earlier cell, and this
# rebinds the `corpus_indices`, `vocab` and `dataloader` built for the EN text.
corpus_indices, vocab = cn_pre_process(raw_text)
dataloader = data_iter_consecutive(
    corpus_indices, batch_size=3, num_steps=num_steps)

In [47]:
# Peek at the first mini-batch only; nothing is printed if the iterator is empty.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    X, Y = first_batch
    print(X)
    print(Y)
    # Decode every row of X back into tokens to eyeball the encoding.
    for row in X:
        print(vocab.to_tokens([int(idx) for idx in row]))

tensor([[1.4000e+01, 1.9000e+01, 2.3400e+02, 4.0000e+00, 2.0000e+02, 2.4000e+01,
         1.0050e+03, 1.0060e+03, 4.2000e+01, 1.0000e+00],
        [1.1000e+01, 6.5000e+01, 1.0290e+03, 0.0000e+00, 6.7000e+01, 2.8700e+02,
         1.0000e+00, 4.1500e+02, 1.1270e+03, 2.0000e+00],
        [2.6100e+02, 7.7000e+02, 8.2000e+01, 4.7000e+01, 2.1000e+01, 2.2000e+01,
         9.0000e+00, 7.2200e+02, 9.8000e+01, 4.3600e+02]])
tensor([[1.9000e+01, 2.3400e+02, 4.0000e+00, 2.0000e+02, 2.4000e+01, 1.0050e+03,
         1.0060e+03, 4.2000e+01, 1.0000e+00, 1.4000e+01],
        [6.5000e+01, 1.0290e+03, 0.0000e+00, 6.7000e+01, 2.8700e+02, 1.0000e+00,
         4.1500e+02, 1.1270e+03, 2.0000e+00, 1.9800e+02],
        [7.7000e+02, 8.2000e+01, 4.7000e+01, 2.1000e+01, 2.2000e+01, 9.0000e+00,
         7.2200e+02, 9.8000e+01, 4.3600e+02, 0.0000e+00]])
['想', '要', '和', '你', '飞', '到', '宇', '宙', '去', ' ']
['\u3000', '像', '欧', '<unk>', '情', '调', ' ', '书', '框', '的']
['两', '块', '空', '地', '那', '就', '是', '勇', '气', '与']
