In [None]:
import os

import torch
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianModel, AdamW
from transformers.optimization import get_scheduler

### 加载编码器

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    'Helsinki-NLP/opus-mt-en-ro',
    use_fast=True,
)

print(tokenizer)

# Trial encoding. NOTE(review): the batch holds a single two-element list,
# which is presumably treated as one (text, text_pair) example rather than
# two independent sentences -- confirm against the tokenizer docs.
sentence_pair = ['Hello,this is one sentence', 'This is another sentence']
tokenizer.batch_encode_plus([sentence_pair])

### 数据集处理

In [None]:
dataset = load_dataset(path='wmt16', name='ro-en')

# Subsample every split with a fixed shuffle seed (1): the full corpus is too
# large to train on in this notebook.
for split_name, keep in (('train', 20000), ('validation', 200), ('test', 200)):
    dataset[split_name] = dataset[split_name].shuffle(seed=1).select(range(keep))

In [None]:
# 数据预处理
def preprocess_function(data, tokenizer):
    """Tokenize one batch of English -> Romanian translation records.

    Args:
        data: Batch mapping whose 'translation' entry is a sequence of
            records, each holding the English text under 'en' and the
            Romanian text under 'ro'.
        tokenizer: Tokenizer exposing ``batch_encode_plus`` and the
            ``as_target_tokenizer`` context manager.

    Returns:
        The encoding produced for the English texts, with the Romanian
        token ids attached under the 'labels' key. Both sides are truncated
        to at most 128 tokens.
    """
    # Split every record into its source (en) and target (ro) text.
    sources, targets = [], []
    for pair in data['translation']:
        sources.append(pair['en'])
        targets.append(pair['ro'])

    # The source language is encoded with the tokenizer as-is.
    encoded = tokenizer.batch_encode_plus(sources, max_length=128, truncation=True)

    # The target language must be encoded inside the target-tokenizer context.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer.batch_encode_plus(targets, max_length=128, truncation=True)
    encoded['labels'] = target_encoding['input_ids']

    return encoded

# Tokenize every split in batches of 1000 examples using 4 worker processes.
# The raw 'translation' column is dropped once it has been encoded.
map_options = {
    'batched': True,
    'batch_size': 1000,
    'num_proc': 4,
    'remove_columns': ['translation'],
    'fn_kwargs': {'tokenizer': tokenizer},
}
dataset = dataset.map(function=preprocess_function, **map_options)

print(dataset['train'][0])
dataset

### 数据整理函数

In [None]:
def collate_fn(data):
    """Collate tokenized examples into one padded batch of tensors.

    Args:
        data: Non-empty list of example dicts. Each must carry a 'labels'
            list of token ids alongside the tokenizer's model inputs.

    Returns:
        The batch produced by ``tokenizer.pad`` with an extra
        'decoder_input_ids' entry: the labels shifted one position to the
        right, starting with the pad id and with every -100 replaced by the
        pad id.

    Notes:
        * Labels are padded with -100 up to the longest label in the batch.
        * The input examples are not modified; padded copies are built.
        * The pad id is read from ``tokenizer.pad_token_id`` instead of
          rebuilding the whole vocabulary with ``get_vocab()`` on every
          batch.
    """
    pad_id = tokenizer.pad_token_id

    # Length of the longest label sequence in this batch.
    longest = max(len(example['labels']) for example in data)

    # Pad every label list to the same length so they can be stacked.
    padded = [
        {
            **example,
            'labels': list(example['labels']) + [-100] * (longest - len(example['labels'])),
        }
        for example in data
    ]

    # Pad the remaining model inputs and convert everything to tensors.
    batch = tokenizer.pad(
        encoded_inputs=padded,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt',
    )

    # Build the decoder inputs: pad id first, then the labels shifted right.
    labels = batch['labels']
    decoder_input_ids = torch.full_like(labels, pad_id, dtype=torch.long)
    decoder_input_ids[:, 1:] = labels[:, :-1]
    # -100 is only meaningful to the loss; it is not a valid token id.
    decoder_input_ids[decoder_input_ids == -100] = pad_id
    batch['decoder_input_ids'] = decoder_input_ids

    return batch

### 数据集加载器

In [None]:
loader = torch.utils.data.DataLoader(
    dataset['train'],
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)

# Pull only the first batch for inspection; `i` and `data` stay defined for
# the cells below.
i, data = next(enumerate(loader))

for key, tensor in data.items():
    print(key, tensor.shape, tensor[:2])

len(loader)

### 定义模型

In [None]:
class Model(torch.nn.Module):
    """Marian encoder-decoder with an explicit, untied output projection.

    The seq2seq body comes from ``MarianModel``. The output head (``fc`` and
    ``final_logits_bias``) is copied from the matching
    ``AutoModelForSeq2SeqLM`` checkpoint so that the model starts from the
    pretrained translation head.

    Args:
        checkpoint: Name or path of the pretrained checkpoint to load.
            Defaults to the en->ro Marian model used by this notebook.
    """

    def __init__(self, checkpoint='Helsinki-NLP/opus-mt-en-ro'):
        super().__init__()
        self.pretrained = MarianModel.from_pretrained(checkpoint)

        # The full seq2seq model is loaded only to copy its output head.
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        lm_head = seq2seq.lm_head

        # A buffer is stored in the state_dict but receives no gradient
        # (unlike nn.Parameter). The pretrained bias is copied instead of
        # starting from zeros, so the checkpoint's output bias is kept.
        self.register_buffer(
            'final_logits_bias',
            seq2seq.final_logits_bias.detach().clone(),
        )

        # Size the projection from the pretrained head itself rather than
        # hard-coding 512 / the tokenizer's vocab size, so the weights below
        # always fit.
        self.fc = torch.nn.Linear(lm_head.in_features, lm_head.out_features, bias=False)
        self.fc.load_state_dict(lm_head.state_dict())

        # Release the temporary seq2seq wrapper; only the copies are kept.
        del seq2seq

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        """Run the model and compute the token-level cross-entropy loss.

        Args:
            input_ids: Source token ids.
            attention_mask: Attention mask for ``input_ids``.
            labels: Target token ids. Positions set to -100 are skipped by
                ``CrossEntropyLoss`` (its default ``ignore_index``).
            decoder_input_ids: Decoder inputs (the labels shifted right).

        Returns:
            dict with 'loss' (scalar tensor) and 'logits' (vocabulary scores
            for every decoder position).
        """
        hidden = self.pretrained(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        ).last_hidden_state

        logits = self.fc(hidden) + self.final_logits_bias

        # Merge the first two dimensions so every position is one sample.
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())

        return {'loss': loss, 'logits': logits}

### 模型试算

In [None]:
model = Model()

# Parameter count, reported in units of 10,000.
param_total = sum(p.numel() for p in model.parameters())
print(param_total / 10000)

# Smoke test: run the sample batch fetched from `loader` through the model.
out = model(**data)

out['loss'], out['logits'].shape

### 加载评价指标

In [None]:
metric = load_metric(path='sacrebleu')

# Sanity check: predictions identical to their references.
sample_predictions = ['hello there', 'general kenobi']
sample_references = [['hello there'], ['general kenobi']]
metric.compute(predictions=sample_predictions, references=sample_references)

### 训练

In [None]:
def train():
    """Fine-tune the global ``model`` for one pass over the global ``loader``.

    Uses AdamW (lr 2e-5) with a linear schedule without warm-up, clips the
    gradient norm to 1.0, and prints step, loss, learning rate and token
    accuracy every 50 steps. The model is moved back to the CPU at the end
    so later cells can use it with CPU batches.
    """
    global model
    # Train on the GPU when one is available (the previous expression
    # selected 'cpu' in both branches).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, batch in enumerate(loader):
        batch = {k: v.to(device) for k, v in batch.items()}

        out = model(**batch)
        loss = out['loss']

        loss.backward()
        # Guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        # The optimizer holds every model parameter, so this clears all
        # gradients; a separate model.zero_grad() is not needed.
        optimizer.zero_grad()

        if i % 50 == 0:
            # logits[:, t] is the prediction for labels[:, t] (the decoder
            # input is the label sequence shifted right), so they are
            # compared position by position. Padding (-100) is excluded
            # from both the hits and the denominator.
            labels = batch['labels']
            predicted = out['logits'].argmax(dim=2)
            valid = labels != -100

            hits = ((predicted == labels) & valid).sum().item()
            accuracy = hits / max(valid.sum().item(), 1)

            lr = optimizer.param_groups[0]['lr']

            print(i, loss.item(), lr, accuracy)

    model = model.to('cpu')

train()

# torch.save does not create missing parent directories; make sure the target
# folder exists so a finished training run is not lost at the save step.
# NOTE: this pickles the whole module, so loading it later requires the
# `Model` class definition to be importable.
save_path = './models/en_gen.model'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model, save_path)

### 测试函数

In [None]:
def test(metric):
    """Score the global ``model`` on the first 11 test batches with ``metric``.

    Predictions are the per-position argmax of the logits obtained with the
    gold decoder inputs (teacher forcing), not free-running generation, so
    the score is an optimistic estimate.

    Args:
        metric: Object exposing ``compute(predictions=..., references=...)``
            where each reference is a list of strings.
    """
    model.eval()
    device = next(model.parameters()).device
    pad_id = tokenizer.pad_token_id

    # Fixed order and no dropped batch, so repeated evaluations score the
    # same examples.
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
    )

    predictions = []
    references = []
    for i, data in enumerate(loader_test):
        data = {k: v.to(device) for k, v in data.items()}

        with torch.no_grad():
            out = model(**data)

        # Positions padded with -100 carry no target: blank them in both the
        # predictions and the references so they decode to nothing.
        target_mask = data['labels'] != -100
        pred_ids = out['logits'].argmax(dim=2).masked_fill(~target_mask, pad_id)
        label_ids = data['labels'].masked_fill(~target_mask, pad_id)

        # References come from the labels themselves (not the shifted decoder
        # inputs), and special tokens are stripped so '<pad>' / '</s>' never
        # reach the metric.
        pred = tokenizer.batch_decode(pred_ids.cpu(), skip_special_tokens=True)
        label = tokenizer.batch_decode(label_ids.cpu(), skip_special_tokens=True)

        predictions.extend(pred)
        references.extend(label)

        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])

            print('input_ids = ', input_ids)
            print('pred = ', pred[0])
            print('label = ', label[0])

        # Quick evaluation: stop after 11 batches.
        if i == 10:
            break

    # The metric expects a list of references per prediction.
    references = [[j] for j in references]
    metrics_out = metric.compute(predictions=predictions, references=references)
    print(metrics_out)

test(metric)