In [1]:
import torch
from torch import nn
from d2l import torch as d2l
import os
import collections

读取数据集

In [8]:
def read_data_nmt():
    """Load the English-French dataset.

    Defined in :numref:`sec_utils`"""
    data_dir = d2l.download_extract('fra-eng')
    with open(os.path.join(data_dir, 'fra.txt'), 'r', encoding='utf-8') as f:
        return f.read()
    
data_mnt = read_data_nmt()
print(data_mnt[0:75])

Go.	Va !
Hi.	Salut !
Run!	Cours !
Run!	Courez !
Who?	Qui ?
Wow!	Ça alors !



数据集预处理

In [9]:
def preprocess_nmt(text):
    """Preprocess the English-French dataset.

    Defined in :numref:`sec_utils`"""
    def no_space(char, prev_char):
        return char in set(',.!?') and prev_char != ' '

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Insert space between words and punctuation marks
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

pre_data = preprocess_nmt(data_mnt)
print(pre_data[0:75])

go .	va !
hi .	salut !
run !	cours !
run !	courez !
who ?	qui ?
wow !	ça al


In [22]:
def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset.

    Defined in :numref:`sec_utils`"""
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if num_examples and i > num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target

source, target = tokenize_nmt(pre_data, 600)
print(len(source), len(target))
print(source[0:5])
print(target[0:5])

601 601
[['go', '.'], ['hi', '.'], ['run', '!'], ['run', '!'], ['who', '?']]
[['va', '!'], ['salut', '!'], ['cours', '!'], ['courez', '!'], ['qui', '?']]


定义词表

In [41]:
class Vocab:
    """Vocabulary for text."""
    def __init__(self, tokens=[], min_freq=0, reserved_tokens=[]):
        """Defined in :numref:`sec_text-sequence`"""
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        # The list of unique tokens
        self.idx_to_token = list(sorted(set(['<unk>'] + reserved_tokens + [
            token for token, freq in self.token_freqs if freq >= min_freq])))
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]

    @property
    def unk(self):  # Index for the unknown token
        return self.token_to_idx['<unk>']
    
src_vocab = Vocab(source, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = Vocab(target, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
print(src_vocab.to_tokens(1))

,


In [47]:
def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences.

    Defined in :numref:`sec_utils`"""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad

def build_array_nmt(lines, vocab, num_steps):
    """Transform text sequences of machine translation into minibatches.

    Defined in :numref:`sec_utils`"""
    lines = [vocab[l] for l in lines]
    lines = [l + [vocab['<eos>']] for l in lines]
    array = d2l.tensor([truncate_pad(
        l, num_steps, vocab['<pad>']) for l in lines])
    # print(array[0])
    valid_len = d2l.reduce_sum(
        d2l.astype(array != vocab['<pad>'], d2l.int32), 1)
    return array, valid_len

num_steps = 10
src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
print(src_array[0], src_valid_len[0])

tensor([59,  2,  4,  5,  5,  5,  5,  5,  5,  5])
tensor([188,   0,   4,   5,   5,   5,   5,   5,   5,   5])
tensor([59,  2,  4,  5,  5,  5,  5,  5,  5,  5]) tensor(3)


In [48]:
batch_size, num_steps = 64, 10
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

In [52]:
src, src_valid_len, tgt, tgt_valid_len = next(iter(train_iter))
print(src[0], src_valid_len[0])
print(tgt[0], tgt_valid_len[0])

tensor([ 81, 144,   2,   4,   5,   5,   5,   5,   5,   5]) tensor(4)
tensor([100, 117, 171,   6,  56,   2,   4,   5,   5,   5]) tensor(7)


In [73]:
print(" ".join(src_vocab.to_tokens(src[0][:src_valid_len[0] - 1])))
print(" ".join(tgt_vocab.to_tokens(tgt[0][:tgt_valid_len[0] - 1])))

i stood .
je me suis <unk> debout .


In [74]:
state = ([1, 2, 3], [4, 5, 6])
state

([1, 2, 3], [4, 5, 6])

In [75]:
l1, l2 = state
l1[0] = 100
state

([100, 2, 3], [4, 5, 6])

In [2]:
torch.arange((10), dtype=torch.float32)[None, :]

tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])