In [2]:
import collections
import re
import os
import torch
import d2l.torch as d2l
from typing import List, Tuple

# 文本预处理

In [5]:
# 加载文本
d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',
                           '94646ad1522d915e7b0f9296181140edcf86a4f5')

def read_data_nmt():
    data_dir = d2l.download_extract('fra-eng')
    with open(os.path.join(data_dir, 'fra.txt'), 'r', encoding='utf-8') as f:
        return f.read()

raw_text = read_data_nmt()
print(raw_text[:75])

Go.	Va !
Hi.	Salut !
Run!	Cours !
Run!	Courez !
Who?	Qui ?
Wow!	Ça alors !



In [11]:
# 预处理文本
def preprocess_nmt(raw_text: str) -> str:
    # 使用空格替换不间断空格
    text = raw_text.replace('\u202f', ' ').replace('\xa0', ' ')
    # 使用小写字母替换大写字母
    text = text.lower()
    # 在单词和标点符号之间插入空格
    marks = set(',.?!')
    text = [' ' + char if char in marks and i > 0 and text[i-1] != ' ' else char
            for i, char in enumerate(text)]
    return ''.join(text)

text = preprocess_nmt(raw_text)
print(text[:80])

go .	va !
hi .	salut !
run !	cours !
run !	courez !
who ?	qui ?
wow !	ça alors !


# 文本分词

In [17]:
def tokenize_nmt(text: str, num_examples: int) -> Tuple[List[int], List[int]]:
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if 0 <= num_examples == len(source):
            return source, target
        parts = line.split('\t')
        if len(parts) < 2:
            continue
        source.append(parts[0].split(' '))
        target.append(parts[1].split(' '))
    return source, target

source, target = tokenize_nmt(text, num_examples=-1)
source[:6], target[:6]

([['go', '.'],
  ['hi', '.'],
  ['run', '!'],
  ['run', '!'],
  ['who', '?'],
  ['wow', '!']],
 [['va', '!'],
  ['salut', '!'],
  ['cours', '!'],
  ['courez', '!'],
  ['qui', '?'],
  ['ça', 'alors', '!']])

# 建立词表

In [18]:
src_vocab = d2l.Vocab(source, min_freq=2,
                                  reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = d2l.Vocab(target, min_freq=2,
                                  reserved_tokens=['<pad>', '<bos>', '<eos>'])

# 文本编码

In [19]:
src_corpus = src_vocab[source]
tgt_corpus = tgt_vocab[target]
src_corpus[:6], tgt_corpus[:6]

([[47, 4], [2944, 4], [435, 126], [435, 126], [90, 9], [3664, 126]],
 [[124, 34], [4183, 34], [579, 34], [5850, 34], [39, 7], [35, 386, 34]])

# 构建数据集

In [20]:
# 截断或填充文本序列
def truncate_pad(line: List[int], num_steps: int, padding_token: int) -> List[int]:
    if len(line) > num_steps:
        return line[:num_steps]

SyntaxError: unexpected EOF while parsing (2391291822.py, line 3)