In [106]:
import datasets

ds = datasets.load_dataset("zouharvi/bio-mqm-dataset", split='test')
ds = ds.filter(lambda example: example['lang_src'] == 'en' and example['lang_tgt'] == 'zh').select_columns(['src', 'tgt']).rename_columns({'src': 'english', 'tgt': 'chinese'})
ds_valid = datasets.load_dataset("zouharvi/bio-mqm-dataset", split='validation')
ds_valid = ds_valid.filter(lambda example: example['lang_src'] == 'en' and example['lang_tgt'] == 'zh').select_columns(['src', 'tgt']).rename_columns({'src': 'english', 'tgt': 'chinese'})

In [107]:
from torch.utils.data import Dataset, random_split
import json
 
max_dataset_size = 6000
train_set_size = 4000
valid_set_size = 2000
 
class TRANS(Dataset):
    def __init__(self, data_file:str):
        self.data = self.load_data(data_file)
    def __init__(self, data: datasets.Dataset):
        self.data = {}
        for idx in range(len(data)):
            if idx >= max_dataset_size:
                break
            sample = data[idx]
            self.data[idx] = {
                'english': sample['english'],
                'chinese': sample['chinese']
            }
    
    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                sample = json.loads(line.strip())
                Data[idx] = sample
        return Data
    
    def __len__(self):
        return len(self.data)
 
    def __getitem__(self, idx):
        return self.data[idx]

data = TRANS(ds)
test_data = TRANS(ds_valid)
train_data, valid_data = random_split(data, [train_set_size, valid_set_size])

# data = TRANS('data/translation2019zh_train.json')
# train_data, valid_data = random_split(data, [train_set_size, valid_set_size])
# test_data = TRANS('data/translation2019zh_valid.json')
print("训练集大小:", len(train_data))
print("验证集大小:", len(valid_data))
print("测试集大小:", len(test_data))

训练集大小: 4000
验证集大小: 2000
测试集大小: 1819


In [108]:
import torch
from transformers import AutoTokenizer
 
#这是hugging face里的模型，需要科学上网
model_checkpoint = "Helsinki-NLP/opus-mt-en-zh"
 
#这是modelscope里的，国内可以直接访问
# model_checkpoint ="moxying/opus-mt-zh-en"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
 
max_input_length = 128
max_target_length = 128
 
#每次给模型输入4批数据
inputs = [train_data[s_idx]["english"] for s_idx in range(4)]
targets = [train_data[s_idx]["chinese"] for s_idx in range(4)]
 
model_inputs = tokenizer(
    inputs, 
    padding=True, 
    max_length=max_input_length, 
    truncation=True,
    return_tensors="pt"
)
  #默认情况下分词器会采用源语言的设定来编码文本，要编码目标语言则需要通过上下文管理器
#  as_target_tokenizer(),否则中文分词器可能无法识别大部分的英文单词
with tokenizer.as_target_tokenizer(): 
    labels = tokenizer(
        targets, 
        padding=True, 
        max_length=max_target_length, 
        truncation=True,
        return_tensors="pt"
    )["input_ids"]
 
end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index):
    labels[idx][end_idx+1:] = -50



In [109]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
 
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)
 
def collote_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['english'])
        batch_targets.append(sample['chinese'])
    batch_data = tokenizer(
        batch_inputs, 
        padding=True, 
        max_length=max_input_length,
        truncation=True, 
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets, 
            padding=True, 
            max_length=max_target_length,
            truncation=True, 
            return_tensors="pt"
        )["input_ids"]
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx+1:] = -100
        batch_data['labels'] = labels
    return batch_data
 
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=32, shuffle=False, collate_fn=collote_fn)

Using cuda device


In [110]:
from tqdm.auto import tqdm
 
def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
 
        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

In [111]:
from sacrebleu.metrics import BLEU
import numpy as np

bleu = BLEU()
 
def test_loop(dataloader, model):
    preds, labels = [], []
    
    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
            ).cpu().numpy()
        label_tokens = batch_data["labels"].cpu().numpy()
        
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
 
        preds += [pred.strip() for pred in decoded_preds]
        labels += [[label.strip()] for label in decoded_labels]
    bleu_score = bleu.corpus_score(preds, labels).score
    print(f"BLEU: {bleu_score:>0.2f}\n")
    return bleu_score

# test_loop(valid_dataloader, model)

In [None]:
from transformers import get_scheduler
from torch.optim import AdamW

learning_rate = 2e-5
epoch_num = 3
 
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_bleu = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    valid_bleu = test_loop(valid_dataloader, model)
    if valid_bleu > best_bleu:
        best_bleu = valid_bleu
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'ord_model/epoch_{t+1}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin')
print("Done!")

Epoch 1/3
-------------------------------


  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

BLEU: 11.99

saving new weights...

Epoch 2/3
-------------------------------


  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

BLEU: 11.99

Epoch 3/3
-------------------------------


  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

BLEU: 11.99

Done!


In [113]:
import re

# model.load_state_dict(torch.load('epoch_1_valid_bleu_82.97_model_weights.bin'))

def clean_translation(text):
    text = re.sub(r"\s+", " ", text)  # 去除多余的空格
    text = re.sub(r"\s([?.!\"'])", r"\1", text)  # 去除句末的多余空格
    return text

def translate(text, tokenizer, model):
    inputs = tokenizer.encode(text, return_tensors="pt", truncation=True)
    inputs = inputs.to(device)
    outputs = model.generate(inputs, max_length=5000, num_beams=8, early_stopping=True)
    translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return clean_translation(translated_text)

# print(translate("However, these drugs are administered via intravitreal injections that are associated with sight-threatening complications.", tokenizer, model))
print(test_data[1])
print(translate(test_data[1]['english'], tokenizer, model))

{'english': 'The most feared of these complications is endophthalmitis, a severe infection of the eye with extremely poor visual outcomes.', 'chinese': '这些并发症中最令人担心的是眼内炎，这是一种严重的眼睛感染，视力极差。'}
最担心的这些并发症是内眼炎,眼部严重感染,视觉效果极差。
