In [None]:
!pip install transformers datasets evaluate
!pip uninstall accelerate
!pip install accelerate
!pip install sacrebleu

# !pip uninstall sentencepiece
# !pip install sentencepiece
!pip install --upgrade urllib3

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so that the
# dataset/model paths used later in the notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from datasets import load_dataset, load_from_disk
import pandas as pd
from transformers import AutoTokenizer, EncoderDecoderModel
import evaluate
import numpy as np
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments,DataCollatorForSeq2Seq

# Directory on the mounted Drive; passed as `cache_dir` to the dataset loads
# below and reused as the location for CSVs and processed datasets later on.
data_path = "/content/drive/MyDrive/Dataset/"
# WMT18 zh-en: first 5% of the train split and first 50% of the validation split.
train_dataset = load_dataset("wmt18", "zh-en", split='train[:5%]', cache_dir=data_path)
val_dataset = load_dataset("wmt18", "zh-en", split='validation[:50%]', cache_dir=data_path)

Downloading data files:   0%|          | 0/10 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/113M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/98.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/167M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/107M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/100M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/99.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/150M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/38.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/10 [00:00<?, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split:   0%|          | 0/25160346 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2001 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3981 [00:00<?, ? examples/s]

In [None]:
# Split every {'en': ..., 'zh': ...} translation record into two parallel
# per-language lists, once for the training set and once for validation.

_train_pairs = train_dataset['translation']
en_texts = [pair['en'] for pair in _train_pairs]
zh_texts = [pair['zh'] for pair in _train_pairs]

_val_pairs = val_dataset["translation"]
val_en_texts = [pair['en'] for pair in _val_pairs]
val_zh_texts = [pair['zh'] for pair in _val_pairs]

In [None]:
# Assemble the parallel sentence lists into two-column frames (en, zh).
ds_train = pd.DataFrame(dict(en=en_texts, zh=zh_texts))
ds_val = pd.DataFrame(dict(en=val_en_texts, zh=val_zh_texts))

# Preview the first training rows, Chinese column first and then English.
for column in ('zh', 'en'):
    print(ds_train[column].head())

0                                        1929年还是1989年?
1    巴黎-随着经济危机不断加深和蔓延，整个世界一直在寻找历史上的类似事件希望有助于我们了解目前正...
2    一开始，很多人把这次危机比作1982年或1973年所发生的情况，这样得类比是令人宽心的，因为...
3    如今人们的心情却是沉重多了，许多人开始把这次危机与1929年和1931年相比，即使一些国家政...
4                    目前的趋势是，要么是过度的克制（欧洲），要么是努力的扩展（美国）。
Name: zh, dtype: object
0                                        1929 or 1989?
1    PARIS – As the economic crisis deepens and wid...
2    At the start of the crisis, many people likene...
3    Today, the mood is much grimmer, with referenc...
4    The tendency is either excessive restraint (Eu...
Name: en, dtype: object


In [None]:
# Multilingual BERT checkpoint; the same name is passed below as both the
# encoder and the decoder of the encoder-decoder model.
model_checkpoint = "bert-base-multilingual-cased"
#model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"

# Tokenizer for the checkpoint, created with model_max_length=128.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=128)
#tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=150)
# Build the encoder-decoder from two copies of the checkpoint. The load log
# reports the decoder's cross-attention weights as newly initialized, i.e.
# they are not taken from the checkpoint.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(model_checkpoint, model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.output.dense.bia

In [None]:
# Token-length limits passed as `max_length` when tokenizing the Chinese
# source (encoder) and the English target (decoder).
encoder_max_length = 512
decoder_max_length = 128
# encoder_max_length = 256
# decoder_max_length = 64

def process_data_to_model_inputs(batch):
    """Tokenize a batch of zh/en sentence pairs into encoder-decoder training features.

    Args:
        batch: Mapping with a "zh" list (source sentences) and an "en" list
            (target sentences), as supplied by ``Dataset.map(..., batched=True)``.

    Returns:
        The same ``batch`` with these keys added:
          * ``input_ids`` / ``attention_mask``: source tokens padded/truncated
            to ``encoder_max_length``.
          * ``decoder_attention_mask``: attention mask of the target tokens.
          * ``labels``: target token ids padded/truncated to
            ``decoder_max_length``, with padding positions set to -100.

    Notes:
        ``decoder_input_ids`` is deliberately NOT emitted. Supplying an
        unshifted copy of the labels as decoder input lets the decoder see the
        token it is asked to predict, so the training loss collapses to ~0
        while generation stays useless. When only ``labels`` is given,
        ``EncoderDecoderModel`` builds the decoder inputs itself by shifting
        the labels to the right.
    """
    # Plain Python lists (no return_tensors) for both calls, so the two
    # tokenizer outputs are handled consistently by Dataset.map.
    inputs = tokenizer(batch["zh"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["en"], padding="max_length", truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_attention_mask"] = outputs.attention_mask

    # -100 is the index ignored by the cross-entropy loss; masking the padding
    # keeps the model from being trained (and scored) on [PAD] tokens.
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_id else token for token in label_ids]
        for label_ids in outputs.input_ids
    ]

    return batch


In [None]:
# Write both frames to CSV on Drive (without the index column), then reload
# them through the `datasets` CSV loader. No `split` is passed here, unlike
# the later reload cell that passes split='train'.
ds_train.to_csv(data_path + "wmt18_train.csv", index=False)
ds_val.to_csv(data_path + "wmt18_val.csv", index=False)

ds_train = load_dataset('csv', data_files=data_path + "wmt18_train.csv")
ds_val = load_dataset('csv', data_files=data_path + "wmt18_val.csv")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Number of leading rows to keep for the training / validation experiments.
subset_size = 1200
val_subset_size = 500

# Re-read the CSVs written earlier and keep only the first rows of each.
df_train = pd.read_csv(data_path + "wmt18_train.csv")
df_val = pd.read_csv(data_path + "wmt18_val.csv")

ds_train = df_train.iloc[:subset_size]
ds_val = df_val.iloc[:val_subset_size]

In [None]:
# Overwrite the same CSV files with the current ds_train / ds_val frames and
# reload each one as a single Dataset by requesting split='train'.
ds_train.to_csv(data_path + "wmt18_train.csv", index=False)
ds_val.to_csv(data_path + "wmt18_val.csv", index=False)

ds_train = load_dataset('csv', data_files=data_path + "wmt18_train.csv",split='train')
ds_val = load_dataset('csv', data_files=data_path + "wmt18_val.csv",split='train')

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Sanity check: count missing (NaN) cells per column in both datasets.

df_train = ds_train.to_pandas()
df_val = ds_val.to_pandas()

# isna() flags null cells; summing the flags gives one count per column.
missing_values_train = df_train.isna().sum()
missing_values_val = df_val.isna().sum()

print(missing_values_train, missing_values_val)

en    0
zh    0
dtype: int64 en    0
zh    0
dtype: int64


In [None]:
# Tokenize the training set in batches with process_data_to_model_inputs and
# drop the raw text columns so only the model features remain.
tk_train = ds_train.map(
    process_data_to_model_inputs,
    batched=True,
    remove_columns=["en", "zh"]
)

# Same processing for the validation set.
tk_val = ds_val.map(
    process_data_to_model_inputs,
    batched=True,
    remove_columns=["en", "zh"]
)

Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [None]:
# save the processed train dataset
tk_train.save_to_disk(data_path + "processed_wmt18_train")

# save the processed validation dataset
tk_val.save_to_disk(data_path + "processed_wmt18_val")

Saving the dataset (0/1 shards):   0%|          | 0/1200 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

In [None]:
# load the processed training dataset
tk_train = load_from_disk(data_path + "processed_wmt18_train")

# load the processed validation dataset
tk_val = load_from_disk(data_path + "processed_wmt18_val")

In [None]:
# Display the processed training dataset summary (feature names and row count).
tk_train

Dataset({
    features: ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'labels'],
    num_rows: 1200
})

In [None]:
# tk_val['train']

# has_none = any(item is None for item in tk_val)
# has_empty_string = any(item =="" for item in tk_val)

# if has_none:
#     print("containing nan")
# else:
#     print("no nan")

# if has_empty_string:
#     print("nan string yes")
# else:
#     print("no nan string")

no nan
no nan string


In [None]:
from datasets import load_metric
# SacreBLEU metric loaded through the `evaluate` library; compute_metrics
# below calls metric.compute(...) on the decoded predictions and references.
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    """Trim whitespace from decoded texts and shape them for the BLEU metric.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple where ``preds`` is a list of stripped
        strings and ``labels`` is a list of one-element lists, each holding
        the stripped reference for the prediction at the same position.
    """
    stripped_preds = [text.strip() for text in preds]
    # One reference per prediction, wrapped in its own list.
    wrapped_labels = [[text.strip()] for text in labels]
    return stripped_preds, wrapped_labels

def compute_metrics(eval_preds):
    """Compute corpus BLEU and mean generated length for an evaluation pass.

    Args:
        eval_preds: ``(preds, labels)`` pair of token-id arrays. ``preds`` may
            be a tuple, in which case its first element is used.

    Returns:
        Dict with ``"bleu"`` (the sacrebleu score) and ``"gen_len"`` (mean
        number of non-padding tokens per prediction), both rounded to 4
        decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is a padding/ignore marker, not a vocabulary id. Replace it with
    # the pad token in BOTH arrays before decoding; otherwise batch_decode can
    # fail on the negative id and gen_len would count the filler as tokens.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Generated length = number of non-padding positions in each prediction.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result


Downloading builder script:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

In [None]:
# vocab = tokenizer.get_vocab()

#  if "[START]" not in vocab:
#     new_id = len(vocab)
#     vocab["[START]"] = new_id
#     tokenizer.add_token("[START]")

# start_token_id = vocab.get("[START]", tokenizer.unk_token_id)

# model.config.decoder_start_token_id = start_token_id

## Training

In [None]:
# Special-token ids are taken from the tokenizer: [CLS] starts decoding,
# [SEP] ends a sequence and [PAD] is the padding id.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# Top-level vocab size is copied from the encoder's config.
model.config.vocab_size = model.config.encoder.vocab_size
# Generation-related settings stored on the model config (length bounds,
# n-gram repetition limit, early stopping, length penalty, beam count).
# NOTE(review): max_length here (256) differs from decoder_max_length (128)
# used for the training targets, and the training table reports Gen Len of
# about 256 -- confirm this is intended.
model.config.max_length = 256
model.config.min_length = 12
# model.config.max_length = 256
# model.config.min_length = 24
model.config.no_repeat_ngram_size = 3
model.config.early_stopping = True
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [None]:
# Per-device batch size shared by training and evaluation.
batch_size = 4
args = Seq2SeqTrainingArguments(
    # Evaluate with model.generate() so BLEU is computed on generated text.
    predict_with_generate=True,
    evaluation_strategy = "epoch",
    #evaluation_strategy = "steps",
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    num_train_epochs = 3,
    # The string "none" disables all reporting integrations. Passing the
    # Python value None does NOT disable them: it falls back to reporting to
    # every installed integration.
    report_to = "none",
    output_dir = "./",
    logging_steps=2,
    #save_steps=10,
    # Effectively disables intermediate checkpoints for this short run.
    save_steps=999999,
    # Only takes effect when evaluation_strategy is "steps".
    eval_steps=4,
)

# Collates tokenized features into PyTorch tensors for each batch.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, return_tensors="pt")

In [None]:
# Wire the model, training arguments, tokenized datasets, collator and the
# BLEU-based compute_metrics into a Seq2SeqTrainer, then run training.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    compute_metrics=compute_metrics,
    train_dataset=tk_train,
    eval_dataset=tk_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    # train_dataset_key="train",
    # eval_dataset_key="train",
    #for decoder input
    #train_decoder_input_ids=tk_train['decoder_input_ids'],
    #eval_decoder_input_ids=tk_val['decoder_input_ids']
)
trainer.train()



Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,0.0001,6.761288,0.0049,256.0
2,0.0,6.971948,0.0048,254.728
3,0.0,7.046183,0.0046,254.092


TrainOutput(global_step=900, training_loss=0.0003880182137236766, metrics={'train_runtime': 4590.0679, 'train_samples_per_second': 0.784, 'train_steps_per_second': 0.196, 'total_flos': 2209419307008000.0, 'train_loss': 0.0003880182137236766, 'epoch': 3.0})

In [None]:
# Save the fine-tuned model to Drive.
trainer.save_model("/content/drive/MyDrive/Dataset/")

# from_pretrained returns a NEW model object rather than updating `model` in
# place, so the result must be bound to a name. Calling it without assignment
# only loads the checkpoint, prints its repr and throws it away.
model = EncoderDecoderModel.from_pretrained("/content/drive/MyDrive/Dataset/")

In [None]:
import matplotlib.pyplot as plt
# with max_length of 128
# Hard-coded per-epoch metrics. They differ from the training table printed
# above; per the note, they come from a run with max_length of 128.
train_loss = [0.011600, 0.009200, 0.005900]
validation_loss = [8.611526, 9.142328, 9.494763]
bleu_scores = [0.010400, 0.014800, 0.031800]
gen_lengths = [126.500000, 128.000000, 128.000000]

# Plot against 1-based epoch numbers so the "Epoch" axis reads 1, 2, 3
# instead of the list indices 0, 1, 2.
epochs = list(range(1, len(train_loss) + 1))

fig, axs = plt.subplots(2, 1, figsize=(5, 6))  # create fig

# draw the loss
axs[0].plot(epochs, train_loss, label='Train Loss', marker='o')
axs[0].plot(epochs, validation_loss, label='Validation Loss', marker='o')
axs[0].legend()
axs[0].set_title('Training and Validation Loss')
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('Loss')
axs[0].set_xticks(epochs)

# draw the bleu scores and gen len
# BLEU (~0.01) and generated length (~128) differ by orders of magnitude, so
# the length goes on a secondary y-axis; on a shared axis BLEU looks flat.
ax_bleu = axs[1]
ax_len = ax_bleu.twinx()
bleu_line = ax_bleu.plot(epochs, bleu_scores, label='BLEU Score', marker='o', color='tab:blue')
len_line = ax_len.plot(epochs, gen_lengths, label='Generated Sequence Length', marker='o', color='tab:orange')
# One combined legend for the lines living on the two y-axes.
ax_bleu.legend(handles=bleu_line + len_line)
ax_bleu.set_title('BLEU Score and Generated Sequence Length')
ax_bleu.set_xlabel('Epoch')
ax_bleu.set_ylabel('BLEU Score')
ax_len.set_ylabel('Generated Sequence Length')
ax_bleu.set_xticks(epochs)

plt.tight_layout()
plt.show()

## Evaluation

In [None]:
# First 5% of the WMT18 zh-en test split, cached under data_path.
test_dataset = load_dataset("wmt18", "zh-en", split='test[:5%]', cache_dir=data_path)

# Separate each translation record into parallel English / Chinese lists.
_test_pairs = test_dataset['translation']
test_en_texts = [pair['en'] for pair in _test_pairs]
test_zh_texts = [pair['zh'] for pair in _test_pairs]

# Two-column frame (en, zh) holding the test sentences.
ds_test = pd.DataFrame(dict(en=test_en_texts, zh=test_zh_texts))

In [None]:
# Display the test DataFrame.
ds_test

Unnamed: 0,en,zh
0,"Last week, the broadcast of period drama “Beau...",上周，古装剧《美人私房菜》临时停播，意外引发了关于国产剧收视率造假的热烈讨论。
1,Civil rights group issues travel warning for M...,民权团体针对密苏里州发出旅行警告
2,The National Association for the Advancement o...,由于密苏里州的歧视性政策和种族主义袭击，美国有色人种促进协会 (NAACP) 向准备前往密苏...
3,"""The NAACP Travel Advisory for the state of Mi...",“2017 年 8 月 28 日生效的 NAACP 密苏里州旅行咨询中呼吁，因近期密苏里州发...
4,A recent Missouri law making it harder for peo...,NAACP 指出，最近通过的一项密苏里州法律使得人们更难赢得歧视诉讼，该州执法也一定程度上针...
...,...,...
194,"After the policy was enhanced, officials state...",随后政策加码，官方表示，“大学生可以低于市场价20%的价格买房”；
195,"In July, Chengdu issue regulations stating tha...",7月，成都出台规定，外地大学本科及以上学历的青年人才，凭毕业证即可申请落户成都。
196,"In addition, Kunshan, Lingang in Shanghai, as ...",除此之外，还有昆山、上海临港等城市或区域也出台了 “人才引入”相关政策，对人才购房落户降低门槛。
197,"Zhang Hongwei believed that this was, in fact,...",张宏伟认为，这实际上是各大城市对于人口尤其是人才的争夺，留住人口尤其是人才，城市才具有竞争力...


In [None]:
# Load the saved encoder-decoder model from the Drive directory it was saved to.
# NOTE(review): no device is selected here; confirm the model is on the same
# device as the inputs used for generation below.
model = EncoderDecoderModel.from_pretrained("/content/drive/MyDrive/Dataset/")

In [None]:
# Keep at most the first `test_size` rows of the test frame (the frame shown
# above is shorter than 300 rows, so every row is kept).
test_size = 300
df_test = ds_test.head(test_size)

In [None]:
# Display the test subset.
df_test

Unnamed: 0,en,zh
0,"Last week, the broadcast of period drama “Beau...",上周，古装剧《美人私房菜》临时停播，意外引发了关于国产剧收视率造假的热烈讨论。
1,Civil rights group issues travel warning for M...,民权团体针对密苏里州发出旅行警告
2,The National Association for the Advancement o...,由于密苏里州的歧视性政策和种族主义袭击，美国有色人种促进协会 (NAACP) 向准备前往密苏...
3,"""The NAACP Travel Advisory for the state of Mi...",“2017 年 8 月 28 日生效的 NAACP 密苏里州旅行咨询中呼吁，因近期密苏里州发...
4,A recent Missouri law making it harder for peo...,NAACP 指出，最近通过的一项密苏里州法律使得人们更难赢得歧视诉讼，该州执法也一定程度上针...
...,...,...
194,"After the policy was enhanced, officials state...",随后政策加码，官方表示，“大学生可以低于市场价20%的价格买房”；
195,"In July, Chengdu issue regulations stating tha...",7月，成都出台规定，外地大学本科及以上学历的青年人才，凭毕业证即可申请落户成都。
196,"In addition, Kunshan, Lingang in Shanghai, as ...",除此之外，还有昆山、上海临港等城市或区域也出台了 “人才引入”相关政策，对人才购房落户降低门槛。
197,"Zhang Hongwei believed that this was, in fact,...",张宏伟认为，这实际上是各大城市对于人口尤其是人才的争夺，留住人口尤其是人才，城市才具有竞争力...


In [None]:
# Write the test subset to a local CSV and reload it as a single Dataset.
# to_csv is called without index=False here, so the DataFrame index is written
# as an extra unnamed column; it shows up as "Unnamed: 0" after loading and is
# removed in a following cell.
df_test.to_csv("wmt_test_sub.csv")
ds_test = load_dataset('csv', data_files="wmt_test_sub.csv", split='train')

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Drop the "Unnamed: 0" column, i.e. the DataFrame index that the CSV export
# above wrote out because it did not pass index=False.
ds_test = ds_test.remove_columns("Unnamed: 0")

In [None]:
def translation(batch):
    """Translate a batch of Chinese sentences into English with the model.

    Args:
        batch: Mapping with a "zh" list of source sentences, as supplied by
            ``Dataset.map(..., batched=True)``.

    Returns:
        The same ``batch`` with a new ``"pred_en"`` key holding the decoded
        model outputs (special tokens removed), one string per input sentence.
    """
    # cut off at BERT max length 512
    inputs = tokenizer(batch["zh"],  padding="max_length", truncation="longest_first", max_length=encoder_max_length, return_tensors="pt")

    # Put the inputs on whatever device the model lives on. A model freshly
    # returned by from_pretrained sits on the CPU, so a hard-coded "cuda"
    # here would raise a device-mismatch error inside generate().
    device = model.device
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    model.eval()
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=decoder_max_length, decoder_start_token_id=tokenizer.cls_token_id)#max_new_tokens=128
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred_en"] = output_str

    return batch

In [None]:
# Translate the test set in batches of 8. The "en" column is removed from the
# result, so `outcomes` holds only the source text and the prediction; the
# English references remain available in `ds_test`.
batch_size = 8
outcomes = ds_test.map(translation, batched=True, batch_size=batch_size, remove_columns=["en"])

Map:   0%|          | 0/199 [00:00<?, ? examples/s]

In [None]:
# Display the summary of the translated dataset (feature names and row count).
outcomes

Dataset({
    features: ['zh', 'pred_en'],
    num_rows: 199
})

In [None]:
# Collect the predictions and their Chinese sources into parallel lists.
out_pred_en_texts = [row['pred_en'] for row in outcomes]
out_zh_texts = [row['zh'] for row in outcomes]

# Rebind `outcomes` to a two-column frame: source text, then prediction.
outcomes = pd.DataFrame(dict(zh=out_zh_texts, pred_en=out_pred_en_texts))

In [None]:
# Display the source sentences next to the model's predictions.
outcomes

Unnamed: 0,zh,pred_en
0,上周，古装剧《美人私房菜》临时停播，意外引发了关于国产剧收视率造假的热烈讨论。,"...,,, - - - / / /.. ) ) ) ） ） ） them them the..."
1,民权团体针对密苏里州发出旅行警告,"...,,, ( ( ( （ （ （ ( ( although although altho..."
2,由于密苏里州的歧视性政策和种族主义袭击，美国有色人种促进协会 (NAACP) 向准备前往密苏...,"...,,, ( ( ( （ （ （ ( ( although although altho..."
3,“2017 年 8 月 28 日生效的 NAACP 密苏里州旅行咨询中呼吁，因近期密苏里州发...,"...,,, ( ( ( （ （ （ ( ( although although altho..."
4,NAACP 指出，最近通过的一项密苏里州法律使得人们更难赢得歧视诉讼，该州执法也一定程度上针...,"...,,, ( ( ( （ （ （ ( ( although although altho..."
...,...,...
194,随后政策加码，官方表示，“大学生可以低于市场价20%的价格买房”；,"...,,, - - - / / /.. ) ) ) ） ） ） them them the..."
195,7月，成都出台规定，外地大学本科及以上学历的青年人才，凭毕业证即可申请落户成都。,"...,,, ( ( ( （ （ （ ( ( although although altho..."
196,除此之外，还有昆山、上海临港等城市或区域也出台了 “人才引入”相关政策，对人才购房落户降低门槛。,"...,,, ( ( ( （ （ （ ( ( although although altho..."
197,张宏伟认为，这实际上是各大城市对于人口尤其是人才的争夺，留住人口尤其是人才，城市才具有竞争力...,"...,,, ( ( ( （ （ （ ( ( although although altho..."


In [None]:
from nltk.translate.bleu_score import corpus_bleu
from datasets import load_metric

# BLEU must compare each prediction with its ENGLISH ground truth. The "en"
# column was dropped from `outcomes`, so take the references from `ds_test`,
# which `outcomes` was mapped from row by row (same order). Scoring against
# the Chinese source column would be meaningless.
reference_texts = list(ds_test['en'])
generated_texts = list(outcomes['pred_en'])
assert len(reference_texts) == len(generated_texts), "references and predictions are misaligned"

# corpus_bleu works on TOKENS, not raw strings:
#   references -> one list of reference token-lists per hypothesis
#   hypotheses -> one token list per hypothesis
# Whitespace tokenization is used; `or ""` guards against empty CSV cells
# that were read back as None.
references = [[(ref or "").split()] for ref in reference_texts]
generated_translations = [(gen or "").split() for gen in generated_texts]

# Show a small sample instead of dumping the whole corpus into the output.
print(references[:3])
print(generated_translations[:3])
#bleu_scores = bleu_metric.compute(predictions=generated_translations, references=references)["bleu"]
bleu_scores = corpus_bleu(references, generated_translations)

#average_bleu = sum(bleu_scores) / len(bleu_scores)
print("Corpus BLEU score:", bleu_scores)

[['上周，古装剧《美人私房菜》临时停播，意外引发了关于国产剧收视率造假的热烈讨论。'], ['民权团体针对密苏里州发出旅行警告'], ['由于密苏里州的歧视性政策和种族主义袭击，美国有色人种促进协会 (NAACP) 向准备前往密苏里州出游的有色人群发出旅行警告。'], ['“2017 年 8 月 28 日生效的 NAACP 密苏里州旅行咨询中呼吁，因近期密苏里州发生了一系列可疑的种族性事件，所有非裔美籍旅行者、游客以及密苏里州人在密苏里州旅行时应特别注意并采取极其谨慎的态度，特此告知，”该团体的声明宣称。'], ['NAACP 指出，最近通过的一项密苏里州法律使得人们更难赢得歧视诉讼，该州执法也一定程度上针对少数群体，这些现象促使该组织发布了旅行警告。'], ['侵犯公民权利的行为正发生在人们身上。'], ['他们因肤色被停车盘问，被殴打或被杀害，”密苏里州 NAACP 主席罗德·查培尔告诉堪萨斯城星报 (The Kansas City Star)。'], ['“我们收到了许多投诉，数量前所未有。”'], ['这是该组织在美国针对某个州发布的第一个此类警告。'], ['该组织援引了密苏里大学对黑人学生的种族诽谤以及 28 岁田纳西州黑人男性托利·桑德斯的死亡事件。'], ['今年早些时候，桑德斯在可疑情况下死亡。他在密苏里州旅行时燃油耗尽，被密苏里州警方在无指控犯罪的情况下拘留。'], ['咨询中还指出，密苏里州总检察长办公室最近的一份报告显示，“与白人相比，该州的黑人司机被停车盘查的可能性要高出 75%”。'], ['查培尔说：“该份咨询是为了提高人们的意识，警告他们的家人、朋友和同事在密苏里州可能发生的情况。”'], ['“人们需要做好准备，无论是携带保释金前往密苏里州，还是让亲属知道自己在州内旅行”。'], ['根据联邦调查局仇恨犯罪报告计划的最新数据，密苏里州在 2015 年记录了 100 起仇恨罪行；根据罪行量，该州在全国排名第 16 位。'], ['旅行警告也是对密苏里州新法律的回应，该法律将使起诉住房或就业歧视企业变得更加困难。'], ['此前，美国各州通过了移民执法法律，要求当地执法部门拘留移民违规人员，美国公民自由联盟 (ACLU) 表示此举会增加种族诉讼数量，并发布了针对德克萨斯州和亚利桑那州的旅行咨询。'], ['旅行警告通常是由美国政府针对

In [None]:
# source_lang = "zh"
# target_lang = "en"

# prefix = "translate Chinese to English:"

# def preprocess_function(examples):
#     inputs = [prefix + example[source_lang] for example in examples['zh']]
#     targets = [example[target_lang] for example in examples['en']]

#     model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
#     return model_inputs

In [None]:
# import pandas as pd

# train_en_texts = dt_rain['en']
# train_zh_texts = ds_train['zh']

# val_en_texts = val_dataset['translation']['en']
# val_zh_texts = val_dataset['translation']['zh']
