In [30]:
import sys
import pdb
import pprint
import logging
import os
import random

In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import tqdm.auto as tqdm
from pathlib import Path
from argparse import Namespace

import matplotlib.pyplot as plt


In [32]:
seed = 73
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  
np.random.seed(seed)  
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## 数据处理

In [33]:
prefix = "F:/Ai/Data/ml2021spring-hw5/DATA/rawdata/ted2020"
prefix = Path(prefix).absolute()

data_dir = "F:/Ai/Data/ml2021spring-hw5/DATA/rawdata"
dataset_name = 'ted2020'


In [34]:
src_lang = 'en'
tgt_lang = 'zh'

data_prefix = f'{prefix}/train_dev.raw'
test_prefix = f'{prefix}/test.raw'

In [35]:
data_prefix

'F:\\Ai\\Data\\ml2021spring-hw5\\DATA\\rawdata\\ted2020/train_dev.raw'

In [36]:

!head {data_prefix+'.'+src_lang} -n 5
!head {data_prefix+'.'+tgt_lang} -n 5

'head' �����ڲ����ⲿ���Ҳ���ǿ����еĳ���
���������ļ���
'head' �����ڲ����ⲿ���Ҳ���ǿ����еĳ���
���������ļ���


In [37]:
data_prefix

'F:\\Ai\\Data\\ml2021spring-hw5\\DATA\\rawdata\\ted2020/train_dev.raw'

In [38]:
import re

In [39]:
# 定义一个函数，将字符串中的全角字符转换为半角字符
def strQ2B(ustring):
    # 创建一个空列表，用于存储转换后的字符串
    ss = []
    # 遍历输入的字符串
    for s in ustring:
        # 创建一个空字符串，用于存储转换后的字符
        rstring = ""
        # 遍历字符串中的每个字符
        for uchar in s:
            # 获取字符的Unicode编码
            inside_code = ord(uchar)
            # 如果字符的Unicode编码为全角空格，则将其转换为半角空格
            if inside_code == 12288:
                inside_code = 32
            # 如果字符的Unicode编码在65281到65374之间，则将其转换为对应的半角字符
            elif (inside_code >= 65281 and inside_code <= 65374):
                inside_code -= 65248
            # 将转换后的字符添加到rstring中
            rstring += chr(inside_code)
        # 将转换后的字符串添加到ss列表中
        ss.append(rstring)
    # 将ss列表中的字符串连接起来，并返回
    return ''.join(ss)


In [40]:
def clean_s(s, lang):
    # 如果语言是英文
    if lang == 'en':
        # 去掉括号及其内容
        s = re.sub(r"\([^()]*\)", "", s)
        # 去掉连字符
        s = s.replace('-', '')
        # 在标点符号前后加空格
        s = re.sub('([.,;!?()\"])', r' \1 ', s)
    # 如果语言是中文
    elif lang == 'zh':
        # 将全角字符转换为半角字符
        s = strQ2B(s)
        # 去掉括号及其内容
        s = re.sub(r"\([^()]*\)", "", s)
        # 去掉空格
        s = s.replace(' ', '')
        # 去掉破折号
        s = s.replace('—', '')
        # 将中文引号替换为英文引号
        s = s.replace('“', '"')
        s = s.replace('”', '"')
        # 去掉下划线
        s = s.replace('_', '')
        # 在标点符号前后加空格
        s = re.sub('([。,;!?()\"~「」])', r' \1 ', s)
    # 去掉多余的空格
    s = ' '.join(s.strip().split())
    # 返回处理后的字符串
    return s

In [41]:
def len_s(s, lang):
    # 判断语言类型，如果是中文，则返回字符串长度，否则返回字符串中单词的个数
    if lang == 'zh':
        return len(s)
    return len(s.split())

In [42]:
def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):
    # 检查前缀为prefix的l1和l2文件是否存在，如果存在则跳过清理过程
    if Path(f'{prefix}.clean.{l1}').exists() and Path(f'{prefix}.clean.{l2}').exists():
        print(f'{prefix}.clean.{l1} & {l2} exists. skipping clean.')
        return
    # 打开前缀为prefix的l1和l2文件，以及清理后的l1和l2文件
    with open(f'{prefix}.{l1}', 'r') as l1_in_f:
        with open(f'{prefix}.{l2}', 'r') as l2_in_f:
            with open(f'{prefix}.clean.{l1}', 'w') as l1_out_f:
                with open(f'{prefix}.clean.{l2}', 'w') as l2_out_f:
                    # 遍历l1文件中的每一行
                    for s1 in l1_in_f:
                        # 去除行首行尾的空格
                        s1 = s1.strip()
                        # 读取l2文件中的对应行
                        s2 = l2_in_f.readline().strip()
                        # 清理l1和l2文件中的内容
                        s1 = clean_s(s1, l1)
                        s2 = clean_s(s2, l2)
                        # 计算l1和l2文件中的内容长度
                        s1_len = len_s(s1, l1)
                        s2_len = len_s(s2, l2)
                        # 如果最小长度大于0，则去除长度小于最小长度的句子
                        if min_len > 0: # remove short sentence
                            if s1_len < min_len or s2_len < min_len:
                                continue
                        # 如果最大长度大于0，则去除长度大于最大长度的句子
                        if max_len > 0: # remove long sentence
                            if s1_len > max_len or s2_len > max_len:
                                continue
                        # 如果比例大于0，则去除长度比例大于比例的句子
                        if ratio > 0: # remove by ratio of length
                            if s1_len/s2_len > ratio or s2_len/s1_len > ratio:
                                continue
                        # 将清理后的l1和l2文件写入对应文件
                        print(s1, file=l1_out_f)
                        print(s2, file=l2_out_f)

In [43]:
clean_corpus(data_prefix, src_lang, tgt_lang)
clean_corpus(test_prefix, src_lang, tgt_lang, ratio=-1, min_len=-1, max_len=-1)

F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020/train_dev.raw.clean.en & zh exists. skipping clean.
F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020/test.raw.clean.en & zh exists. skipping clean.


In [44]:

!head {data_prefix+'.clean.'+src_lang} -n 5
!head {data_prefix+'.clean.'+tgt_lang} -n 5

'head' �����ڲ����ⲿ���Ҳ���ǿ����еĳ���
���������ļ���
'head' �����ڲ����ⲿ���Ҳ���ǿ����еĳ���
���������ļ���


In [45]:
valid_ratio = 0.01
train_ratio = 1 - valid_ratio

In [46]:
# 判断train.clean.src_lang和train.clean.tgt_lang以及valid.clean.src_lang和valid.clean.tgt_lang是否存在
if (prefix / f'train.clean.{src_lang}').exists() and (prefix / f'train.clean.{tgt_lang}').exists() and \
   (prefix / f'valid.clean.{src_lang}').exists() and (prefix / f'valid.clean.{tgt_lang}').exists():
    # 如果存在，则打印提示信息，并跳过split
    print(f'train/valid splits exists. skipping split.')
else:
    # 如果不存在，则计算data_prefix.clean.src_lang文件的行数
    line_num = sum(1 for line in open(f'{data_prefix}.clean.{src_lang}'))
    # 生成一个与行数相同的标签列表
    labels = list(range(line_num))
    # 随机打乱标签列表
    random.shuffle(labels)
    # 遍历src_lang和tgt_lang
    for lang in [src_lang, tgt_lang]:
        # 打开train.clean.lang文件
        train_f = open(os.path.join(data_dir, dataset_name, f'train.clean.{lang}'), 'w')
        # 打开valid.clean.lang文件
        valid_f = open(os.path.join(data_dir, dataset_name, f'valid.clean.{lang}'), 'w')
        # 初始化计数器
        count = 0
        # 遍历data_prefix.clean.lang文件
        for line in open(f'{data_prefix}.clean.{lang}', 'r'):
            # 如果标签列表中的第count个元素除以行数小于train_ratio，则将行写入train.clean.lang文件
            if labels[count]/line_num < train_ratio:
                train_f.write(line)
            # 否则将行写入valid.clean.lang文件
            else:
                valid_f.write(line)
            # 计数器加1
            count += 1
        # 关闭train.clean.lang文件
        train_f.close()
        # 关闭valid.clean.lang文件
        valid_f.close()

train/valid splits exists. skipping split.


In [47]:
## Subword Unit

import sentencepiece as spm

vocab_size = 8000

if (prefix/f'spm{vocab_size}.model').exists():
    print(f'{prefix}/spm{vocab_size}.model exists. skipping spm_train.')
else:
    spm.SentencePieceTrainer.train(
        input=','.join([f'{prefix}/train.clean.{src_lang}',
                        f'{prefix}/valid.clean.{src_lang}',
                        f'{prefix}/train.clean.{tgt_lang}',
                        f'{prefix}/valid.clean.{tgt_lang}']),
        model_prefix = prefix/f'spm{vocab_size}',
        vocab_size = vocab_size,
        character_coverage = 1,
        model_type = 'unigram',
        input_sentence_size=1e6,
        shuffle_input_sentence=True,
        normalization_rule_name = 'nmt_nfkc_cf'
    )

F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020/spm8000.model exists. skipping spm_train.


In [48]:
spm_model = spm.SentencePieceProcessor(model_file=str(prefix/f'spm{vocab_size}.model'))
in_tag = {
    'train':'train.clean',
    'valid':'valid.clean',
    'test':'test.raw.clean',
}

for split in ['train', 'valid', 'test']:
    for lang in [src_lang, tgt_lang]:
        out_path = prefix/f'{split}.{lang}'
        if out_path.exists():
            print(f"{out_path} exists. skipping spm_encode.")
        else:
            with open(prefix/f'{split}.{lang}', 'w') as out_f:
                with open(prefix/f'{in_tag[split]}.{lang}', 'r') as in_f:
                    for line in in_f:
                        line = line.strip()
                        tok = spm_model.encode(line, out_type=str)
                        print(' '.join(tok), file=out_f)

F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020\train.en exists. skipping spm_encode.
F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020\train.zh exists. skipping spm_encode.
F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020\valid.en exists. skipping spm_encode.
F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020\valid.zh exists. skipping spm_encode.
F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020\test.en exists. skipping spm_encode.
F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020\test.zh exists. skipping spm_encode.


In [49]:
!head {data_dir+'/'+dataset_name+'/train.'+src_lang} -n 5
!head {data_dir+'/'+dataset_name+'/train.'+tgt_lang} -n 5|

'head' �����ڲ����ⲿ���Ҳ���ǿ����еĳ���
���������ļ���
�����﷨����ȷ��


In [50]:
# 替换 fairseq 功能：构建共享词表
binpath = Path('./DATA/data-bin', dataset_name)
binpath.mkdir(parents=True, exist_ok=True)

In [51]:
from collections import Counter
# 从分词后的训练数据构建词表
vocab = Counter()
for lang in [src_lang, tgt_lang]:
    train_file = prefix / f'train.{lang}'
    if train_file.exists():
        with open(train_file, 'r', encoding='utf-8') as f:
            for line in f:
                vocab.update(line.strip().split())

In [52]:
special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
vocab = special_tokens + [word for word, count in vocab.most_common(vocab_size - len(special_tokens))]

# 保存词表
vocab_file = binpath / 'vocab.txt'
with open(vocab_file, 'w', encoding='utf-8') as f:
    for word in vocab:
        f.write(word + '\n')

# 生成索引数据（替代 fairseq 的二进制格式）
word_to_id = {word: idx for idx, word in enumerate(vocab)}
for split in ['train', 'valid', 'test']:
    for lang in [src_lang, tgt_lang]:
        input_file = prefix / f'{split}.{lang}'
        output_file = binpath / f'{split}.ids.{lang}'
        if input_file.exists():
            with open(input_file, 'r', encoding='utf-8') as f_in, \
                    open(output_file, 'w', encoding='utf-8') as f_out:
                for line in f_in:
                    tokens = line.strip().split()
                    ids = [word_to_id.get(token, word_to_id['<unk>']) for token in tokens]
                    f_out.write(' '.join(map(str, ids)) + '\n')

print(f"预处理完成，分词数据保存在 {prefix}，词表和索引数据保存在 {binpath}")

预处理完成，分词数据保存在 F:\Ai\Data\ml2021spring-hw5\DATA\rawdata\ted2020，词表和索引数据保存在 DATA\data-bin\ted2020


In [53]:
import torch

# 检查 CUDA 可用性
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("使用 GPU:")
    print(f"  - GPU 数量: {torch.cuda.device_count()}")
    print(f"  - 当前 GPU: {torch.cuda.get_device_name(0)}")
    print(f"  - CUDA 版本: {torch.version.cuda}")
else:
    device = torch.device('cpu')
    print("使用 CPU")

# 示例：后续使用 device
print(f"计算设备: {device}")

使用 GPU:
  - GPU 数量: 1
  - 当前 GPU: NVIDIA GeForce RTX 4070 Ti SUPER
  - CUDA 版本: 12.4
计算设备: cuda:0


In [54]:
config = Namespace(
    datadir = "./DATA/data-bin/ted2020",
    savedir = "./checkpoints/rnn",
    source_lang = "en",
    target_lang = "zh",
    
    # cpu threads when fetching & processing data.
    num_workers=2,  
    # batch size in terms of tokens. gradient accumulation increases the effective batchsize.
    max_tokens=8192,
    accum_steps=2,
    
    # the lr s calculated from Noam lr scheduler. you can tune the maximum lr by this factor.
    lr_factor=2.,
    lr_warmup=4000,
    
    # clipping gradient norm helps alleviate gradient exploding
    clip_norm=1.0,
    
    # maximum epochs for training
    max_epoch=30,
    start_epoch=1,
    
    # beam size for beam search
    beam=5, 
    # generate sequences of maximum length ax + b, where x is the source length
    max_len_a=1.2, 
    max_len_b=10,
    # when decoding, post process sentence by removing sentencepiece symbols.
    post_process = "sentencepiece",
    
    # checkpoints
    keep_last_epochs=5,
    resume=None, # if resume from checkpoint name (under config.savedir)
    
    # logging
    use_wandb=False,
)

In [55]:
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level="INFO", # "DEBUG" "WARNING" "ERROR"
    stream=sys.stdout,
)
proj = "hw5.seq2seq"
logger = logging.getLogger(proj)
if config.use_wandb:
    import wandb
    wandb.init(project=proj, name=Path(config.savedir).stem, config=config)

In [56]:
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
from torch.utils.data import DataLoader
from argparse import Namespace

# 配置
config = Namespace(
    datadir="./DATA/data-bin/ted2020",
    source_lang="en",
    target_lang="zh",
    num_workers=2,
    max_tokens=8192,
)

# 确认 sentencepiece 模型
spm_model_path = prefix / f'spm{vocab_size}.model'
if not spm_model_path.exists():
    raise FileNotFoundError(f"找不到 {spm_model_path}，请先训练 sentencepiece 模型")

# 加载 tokenizer
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(spm_model_path),
    bos_token="<bos>",
    eos_token="<eos>",
    unk_token="<unk>",
    pad_token="<pad>",
)

# 加载数据集
data_files = {
    "train": {
        config.source_lang: str(Path(config.datadir) / f'train.{config.source_lang}'),
        config.target_lang: str(Path(config.datadir) / f'train.{config.target_lang}'),
    },
    "valid": {
        config.source_lang: str(Path(config.datadir) / f'valid.{config.source_lang}'),
        config.target_lang: str(Path(config.datadir) / f'valid.{config.target_lang}'),
    },
    "test": {
        config.source_lang: str(Path(config.datadir) / f'test.{config.source_lang}'),
        config.target_lang: str(Path(config.datadir) / f'test.{config.target_lang}'),
    },
}

dataset = load_dataset("text", data_files=data_files)

# 数据处理
def process_data(examples):
    src_texts = examples[config.source_lang]
    tgt_texts = examples[config.target_lang]
    encodings = tokenizer(
        src_texts,
        text_target=tgt_texts,
        max_length=config.max_tokens // 2,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    for key in ["input_ids", "labels"]:
        seq_len = encodings[key].shape[1]
        padded_len = (seq_len + 7) // 8 * 8
        if seq_len < padded_len:
            encodings[key] = torch.nn.functional.pad(
                encodings[key], (0, padded_len - seq_len), value=tokenizer.pad_token_id
            )
    encodings["attention_mask"] = (encodings["input_ids"] != tokenizer.pad_token_id).long()
    return encodings

# 处理数据集
processed_dataset = dataset.map(
    process_data,
    batched=True,
    remove_columns=[config.source_lang, config.target_lang],
)

# 创建 DataLoader
def collate_fn(batch):
    return {
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

dataloaders = {
    split: DataLoader(
        processed_dataset[split],
        batch_size=16,
        shuffle=(split == "train"),
        num_workers=config.num_workers,
        collate_fn=collate_fn,
    )
    for split in ["train", "valid", "test"] if split in processed_dataset
}

print(f"数据集加载完成：{dataloaders.keys()}")

Exception: stream did not contain valid UTF-8

In [None]:
spm_model_path

WindowsPath('DATA/rawdata/ted2020/spm8000.model')