python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed

In [1]:
# Locations of the two sides of the parallel corpus (German file, English file).
de_path = "data/de-en/europarl-v7.de-en.de"
en_path = "data/de-en/europarl-v7.de-en.en"

In [2]:
# The data loaded here has to be split into a training set and a validation set; use roughly 1/10 for validation.
import re

def read_file(filepath):
    """Read a UTF-8 text file and return its lines, lightly tokenized.

    For every line, each character that is neither a word character nor
    whitespace is surrounded by spaces, runs of whitespace are collapsed into
    a single space, and leading/trailing whitespace is removed. Blank lines
    are kept as empty strings, so the returned list has one entry per line of
    the file.

    Args:
        filepath: Path of the text file to read.

    Returns:
        list[str]: The normalized lines, in file order.
    """
    def normalize(text):
        # Pad punctuation/symbols with spaces so they become separate tokens.
        spaced = re.sub(r'([^\w\s])', r' \1 ', text)
        # Squeeze the whitespace introduced above (and any already present).
        return re.sub(r'\s+', ' ', spaced).strip()

    with open(filepath, encoding='utf-8') as handle:
        return [normalize(raw_line) for raw_line in handle]

In [3]:
# Load the German side first, then the English side. Later cells pair the two
# lists by position, so the two counts printed below are expected to match.
de_raw, en_raw = (read_file(corpus_path) for corpus_path in (de_path, en_path))

print(len(de_raw), len(en_raw))

1920209 1920209


In [4]:
en_raw[21]

''

In [5]:
en_raw[-1]

'( The sitting was closed at 10 . 50 a . m . )'

In [6]:
de_raw[1]

'Ich erkläre die am Freitag , dem 17 . Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen , wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe , daß Sie schöne Ferien hatten .'

In [7]:
import random

def train_val_split(data, val_rate=0.1):
    """Split a sequence into a training part and a validation part.

    The split is purely positional and does no shuffling: the first
    ``int(len(data) * val_rate)`` items form the validation part and the
    remaining items form the training part. Calling it with the same
    ``val_rate`` on two equally long sequences therefore cuts both at the same
    index.

    Args:
        data: A sliceable sequence (for example a list or a tuple).
        val_rate: Fraction of the items to use for validation; must lie in
            the closed interval [0, 1].

    Returns:
        tuple: ``(tra_data, val_data)``, each a slice of ``data``.

    Raises:
        ValueError: If ``val_rate`` is outside [0, 1]. Without this check a
            negative rate would silently slice from the end of ``data`` and a
            rate above 1 would silently leave the training part empty.
    """
    if not 0 <= val_rate <= 1:
        raise ValueError(
            "val_rate must be between 0 and 1, got {!r}".format(val_rate)
        )

    cut = int(len(data) * val_rate)
    return data[cut:], data[:cut]

In [8]:
# Cut both languages at the same position so that the pairs stay aligned.
en_tra, en_val = train_val_split(en_raw, val_rate=0.1)
de_tra, de_val = train_val_split(de_raw, val_rate=0.1)

print(len(en_tra), len(en_val))
print(len(de_tra), len(de_val))

# Drop every sentence pair in which either side is an empty string, then
# unzip the surviving (de, en) pairs back into two parallel tuples.
de_tra, en_tra = zip(*filter(all, zip(de_tra, en_tra)))
de_val, en_val = zip(*filter(all, zip(de_val, en_val)))

print(len(en_tra), len(en_val))
print(len(de_tra), len(de_val))

1728189 192020
1728189 192020
1717885 191035
1717885 191035


In [9]:
import os
def write_to_txt_file(file_path, file_name, str_list):
    """Write a sequence of strings to a UTF-8 text file, one string per line.

    The target directory is created (including missing parents) when it does
    not exist yet. An existing file with the same name is overwritten.

    Args:
        file_path: Directory that will contain the file. An empty string
            means the current working directory.
        file_name: Name of the file inside ``file_path``.
        str_list: Iterable of strings; a newline is appended to each one.
    """
    # os.makedirs('') raises, so only create a directory when one was given.
    if file_path:
        # exist_ok=True avoids the race between checking for the directory
        # and creating it.
        os.makedirs(file_path, exist_ok=True)

    full_path = os.path.join(file_path, file_name)

    # The with-statement closes the file even if writing fails part-way.
    with open(full_path, mode='w', encoding='utf-8') as f:
        f.writelines(s + '\n' for s in str_list)

In [10]:
# Persist the filtered splits as four text files in one folder; the English
# files are written first, then the German ones.
folder = "data/de-en/raw"
for file_name, sentences in (
    ("tra_en.txt", en_tra),
    ("val_en.txt", en_val),
    ("tra_de.txt", de_tra),
    ("val_de.txt", de_val),
):
    write_to_txt_file(folder, file_name, sentences)

# Start building the corpus and the vocabulary dictionaries

In [11]:
# Special tokens. They are not referenced again in this cell; the index notes
# below come from the original author and should be confirmed against the
# IndexDictionary implementation.
PAD_TOKEN = '<PAD>'  # padding token; expected to be encoded as 0
UNK_TOKEN = '<UNK>'  # unknown-word token; expected to be encoded as 1
START_TOKEN = '<StartSent>'  # start-of-sentence token; expected to be encoded as 2
END_TOKEN = '<EndSent>'  # end-of-sentence token; expected to be encoded as 3

from os.path import dirname, abspath, join, exists
import os
BASE_DIR = os.getcwd()  # notebook stand-in for the script-style lookup of the current directory

# Inputs are the files written by the previous cell: English is the source
# language and German is the target language.
train_source="data/de-en/raw/tra_en.txt"
train_target="data/de-en/raw/tra_de.txt"
val_source="data/de-en/raw/val_en.txt"
val_target="data/de-en/raw/val_de.txt"
save_data_dir="data/de-en/processed"

from datasets import TranslationDataset,TokenizedTranslationDataset,IndexedInputTargetTranslationDataset

TranslationDataset.prepare(train_source, train_target, val_source, val_target, save_data_dir)  # process the raw files and generate the output files
translation_dataset = TranslationDataset(save_data_dir, 'train')  # load the 'train' split
# translation_dataset_on_the_fly = TranslationDatasetOnTheFly('train')  # same thing: loads the 'train' split
# share_dictionary=False


from dictionaries import IndexDictionary
from utils.pipe import source_tokens_generator,target_tokens_generator

# Build one dictionary per language from the tokenized 'train' split, then
# save both into save_data_dir so that later cells can reload them.
tokenized_dataset = TokenizedTranslationDataset(save_data_dir, 'train')  # tokenized view of the 'train' split
source_generator = source_tokens_generator(tokenized_dataset)
source_dictionary = IndexDictionary(source_generator, mode='source')
target_generator = target_tokens_generator(tokenized_dataset)
target_dictionary = IndexDictionary(target_generator, mode='target')

source_dictionary.save(save_data_dir)
target_dictionary.save(save_data_dir)

In [12]:
# Display the vocabulary sizes of the source and target dictionaries built above.
source_dictionary.vocabulary_size,target_dictionary.vocabulary_size

(92349, 326269)

In [13]:
# Reload the saved dictionaries. Once the preprocessing above has been run
# once, this is the only cell that needs to be executed to get them back.
# NOTE(review): vocabulary_size=38000 presumably caps each vocabulary — confirm in IndexDictionary.load.
source_dictionary = IndexDictionary.load(save_data_dir, mode='source',vocabulary_size=38000)
target_dictionary = IndexDictionary.load(save_data_dir, mode='target',vocabulary_size=38000)

In [14]:
# NOTE(review): presumably writes the index-encoded dataset files into save_data_dir
# using the two dictionaries loaded above — confirm in datasets.py.
IndexedInputTargetTranslationDataset.prepare(save_data_dir, source_dictionary, target_dictionary)

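# In practice most of these outputs are not really needed; the files that actually matter are the vocabulary-*.txt files. Having those is enough.
## Running the cells in order up to this point is all that is required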

```
python train.py --data_dir=data/de-en/processed --save_config=checkpoints/de-en_config.json --save_checkpoint=checkpoints/de-en_model.pth --save_log=logs/de-en.log --positional_encoding --layers_count=4 --heads_count=4 --epochs=300 --batch_size=32 --vocabulary_size=38000


```