In [1]:
from utils import NoisyTextDataset
from transformers import BartConfig, BartForConditionalGeneration, PreTrainedTokenizerFast
from sklearn.model_selection import train_test_split
import os 
import pandas as pd 
from tqdm import tqdm
from datasets import Dataset



from transformers import set_seed

# Paths and reproducibility knobs, kept together so they are easy to change.
TOKENIZER_DIR = "../user_data/bart_tokenizer"
TRAIN_CSV = "../user_data/step0/resampled_monolingual_data.csv"
SEED = 42

# Seed all RNGs *before* building the model. The weights below are randomly
# initialised, and Seq2SeqTrainer only seeds later, when it is constructed.
set_seed(SEED)

# Load the project tokenizer.
# The checkpoint declares a `BartTokenizer` class. Loading it as
# `PreTrainedTokenizerFast` is what produces the class-mismatch warning in the output.
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
print("tokenizer is done!")

# Small BART configuration for training from scratch.
# NOTE: this is smaller than facebook/bart-base (d_model=768, 12 heads).
# encoder/decoder_ffn_dim are not set here, so they keep BartConfig's defaults.
# NOTE(review): those defaults are large relative to d_model=512 — confirm this is intended.
config = BartConfig(
    # len(tokenizer) also counts tokens added on top of the base vocabulary;
    # `tokenizer.vocab_size` does not, and an added-token id >= vocab_size would
    # index past the embedding table.
    vocab_size=len(tokenizer),
    # Longest sequence the model can embed. NoisyTextDataset must not produce
    # longer inputs/labels (TODO confirm its max length).
    max_position_embeddings=128,
    encoder_layers=6,
    decoder_layers=6,
    encoder_attention_heads=8,
    decoder_attention_heads=8,
    d_model=512,
    pad_token_id=tokenizer.pad_token_id,  # id of <pad>
    bos_token_id=tokenizer.bos_token_id,  # id of <s>
    eos_token_id=tokenizer.eos_token_id,  # id of </s>
    # BART starts the decoder from </s> and shifts labels right using this id.
    # BartConfig's default is a hard-coded 2, which is only right if this custom
    # tokenizer happens to map </s> to 2.
    decoder_start_token_id=tokenizer.eos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
)

# Instantiate BART from the config with fresh random weights (no pretrained checkpoint).
model = BartForConditionalGeneration(config)
print("model is done!")

# Load the resampled monolingual corpus; only its `text` column is used.
resampled_monolingual_data = pd.read_csv(TRAIN_CSV)
texts = resampled_monolingual_data["text"].tolist()
print(f"texts: {len(texts)}")

# Wrap the raw texts in the project's dataset class.
# The noising/corruption scheme presumably lives in utils.NoisyTextDataset — see there.
dataset = NoisyTextDataset(texts, tokenizer)
print("dataset is done!")

  from .autonotebook import tqdm as notebook_tqdm
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


tokenizer is done!
model is done!
texts: 4419969
dataset is done!


In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import os


# Root directory for this step's artefacts (checkpoints, logs, final model).
output_dir = "../user_data/step0"
# Stable location for the final weights. Checkpoints under results/ are rotated by save_total_limit.
final_model_dir = os.path.join(output_dir, "final_model")

# Training arguments.
training_args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(output_dir, "results"),  # checkpoint directory
    logging_dir=os.path.join(output_dir, "logs"),    # log directory
    logging_strategy="steps",                        # log every `logging_steps` optimiser steps
    logging_steps=1000,
    save_strategy="epoch",                           # write a checkpoint at the end of each epoch
    learning_rate=5e-5,                              # peak learning rate
    per_device_train_batch_size=128,                 # training batch size per device
    weight_decay=0.01,                               # weight decay
    num_train_epochs=3,                              # number of training epochs
    bf16=True,                                       # bf16 (mixed) precision; needs bf16-capable hardware
    save_total_limit=3,                              # keep at most 3 checkpoints on disk
)

print("训练参数已设置完成！")

# Build the trainer.
# No data_collator is passed, so the Trainer falls back to its default collator.
# This presumably relies on NoisyTextDataset already returning equal-length
# `input_ids`/`labels`. TODO confirm that, and confirm that padded label positions
# are -100 so padding is excluded from the loss.
trainer = Seq2SeqTrainer(
    model=model,             # BART model built in the previous cell
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,     # NOTE: newer transformers releases rename this to `processing_class`
)

trainer.train()

# Explicitly persist the final model (and the tokenizer given to the Trainer), so
# the "saved" message below is true regardless of checkpoint rotation.
trainer.save_model(final_model_dir)

print("模型训练完成并已保存！")

2024-10-08 17:50:10.393925: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-08 17:50:10.414502: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-08 17:50:10.420617: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-08 17:50:10.436972: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


训练参数已设置完成！


Step,Training Loss
1000,3.2422
2000,0.9568
3000,0.7426
4000,0.643
5000,0.5821
6000,0.5389
7000,0.5076
8000,0.4677
9000,0.4137
10000,0.359


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

