In [None]:
import gc
import os
import re
from collections import defaultdict
from functools import partial

import numpy as np
import pandas as pd
import torch
import transformers
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from peft import (
    PeftConfig,
    PeftModel,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training
)
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import mean_squared_error
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import AdamW, get_scheduler
from transformers import pipeline
from transformers.trainer_pt_utils import get_parameter_names

In [None]:
# Hub identifier of the base checkpoint that is fine-tuned and evaluated below.
MODEL_NAME = 'bigscience/mt0-xxl-mt'
# Hugging Face access token, read from the environment so that no credential
# is ever stored in the notebook itself; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get('HF_TOKEN')

In [None]:
# --- LoRA adapter settings (consumed by LoraConfig further down) ---
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ['q', 'v']

# Default token limit used by the tokenization helper.
MAX_LENGTH = 256

# --- Optimisation settings ---
# Gradients are accumulated over BATCH_SIZE // MICRO_BATCH_SIZE micro-batches.
BATCH_SIZE = 32
MICRO_BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
OUTPUT_DIR = 'mt0_finetune'

In [None]:
# Human-evaluation data used as the source of the fine-tuning examples.
wmt_train = load_dataset(
    'RicardoRei/wmt-da-human-evaluation',
    split='train',
)
# Human-evaluation data used for testing; its 'train' split is the one requested.
wmt_test = load_dataset(
    "RicardoRei/wmt-mqm-human-evaluation",
    split="train",
)

In [None]:
# Language pairs ("source-target" codes) kept for training and evaluation.
translations = ['en-de', 'en-ru', 'zh-en']

# Language code -> language name as it is written into the prompt.
id2langs = dict(
    ru='russian',
    en='english',
    de='german',
    zh='chinese',
)

In [None]:
# Keep only the 2022 rows that belong to one of the selected language pairs.
wmt_test = wmt_test.filter(
    lambda example: example['year'] == 2022 and example['lp'] in translations
)
# Expose the score under 'raw', the column name the tokenization step reads as target.
wmt_test = wmt_test.rename_column('score', 'raw')

In [None]:
seed = 1
train_take_part = 0.25   # fraction of each language pair used for training
eval_take_part = 0.01    # fraction (taken right after the training slice) used for evaluation

train_parts = []
eval_parts = []

for translation in translations:
    # All non-2022 rows of this language pair, in a seeded random order.
    train_part = wmt_train.filter(
        lambda example: example['year'] != 2022 and example['lp'] == translation
    ).shuffle(seed=seed)

    n_rows = len(train_part)
    train_end = int(n_rows * train_take_part)
    eval_end = int(n_rows * (train_take_part + eval_take_part))

    # Slicing yields a plain dict of columns, which is turned back into a Dataset.
    train_parts.append(Dataset.from_dict(train_part[:train_end]))
    eval_parts.append(Dataset.from_dict(train_part[train_end:eval_end]))

train_dataset = concatenate_datasets(train_parts)
eval_dataset = concatenate_datasets(eval_parts)

In [None]:
def preprocess(text):
    """Lower-case ``text`` and collapse every whitespace run into one space.

    Leading and trailing whitespace is removed. ``None`` is mapped to an
    empty string.
    """
    if text is None:
        return ''
    # split() without arguments already discards leading/trailing whitespace.
    words = text.lower().split()
    return ' '.join(words)

def get_gemba_da_prompt(langs, source_seg, reference_seg, target_seg):
    """Build the scoring prompt for one translated segment.

    Args:
        langs: Language pair written as ``"<src>-<tgt>"``; both codes must be
            keys of ``id2langs``.
        source_seg: Source-language segment.
        reference_seg: Human reference translation.
        target_seg: Translation to be scored.

    Returns:
        The prompt text, ending in ``"Score:"``. All three segments are passed
        through ``preprocess`` before being inserted.
    """
    src_lang_id, tgt_lang_id = langs.split('-')
    source_lang = id2langs[src_lang_id]
    target_lang = id2langs[tgt_lang_id]

    source_seg = preprocess(source_seg)
    reference_seg = preprocess(reference_seg)
    target_seg = preprocess(target_seg)

    prompt_lines = [
        f'Score the following translation from {source_lang} to {target_lang} respect to the human reference on a continuous scale from 0 to 100, '
        'where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".',
        f'{source_lang} source: "{source_seg}"',
        f'{target_lang} human reference: "{reference_seg}"',
        f'{target_lang} translation: "{target_seg}"',
        'Score:',
    ]
    return '\n'.join(prompt_lines)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
# Pad with token id 0. The id has to be assigned through ``pad_token_id``:
# ``pad_token`` holds the token *string*, so assigning the integer 0 to it
# does not select id 0.
tokenizer.pad_token_id = 0

In [None]:
def tokenize_(data, tokenizer, prompt_func, input_fields, target_field=None, max_length=MAX_LENGTH):
    """Turn one dataset row into tokenized model inputs (and optionally labels).

    Args:
        data: Mapping holding the row's fields.
        tokenizer: Callable tokenizer that also exposes ``eos_token_id``.
        prompt_func: Function that builds the prompt text; it receives the
            values of ``input_fields`` as positional arguments, in order.
        input_fields: Names of the fields of ``data`` passed to ``prompt_func``.
        target_field: Optional field whose ``str()`` value is tokenized into
            ``labels``. No labels are produced when it is ``None``.
        max_length: Token limit forwarded to the tokenizer. With ``None`` the
            end-of-sequence check below is skipped.

    Returns:
        The tokenizer output for the prompt, with a ``labels`` entry added
        when ``target_field`` is given.
    """
    prompt_args = [data[field] for field in input_fields]
    encoded = tokenizer(prompt_func(*prompt_args), truncation=True, max_length=max_length, padding=False)

    # Append EOS only if it is missing and there is still room below max_length.
    input_ids = encoded['input_ids']
    needs_eos = (
        max_length is not None
        and input_ids[-1] != tokenizer.eos_token_id
        and len(input_ids) < max_length
    )
    if needs_eos:
        input_ids.append(tokenizer.eos_token_id)
        encoded['attention_mask'].append(1)

    if target_field is None:
        return encoded

    target_tokens = tokenizer(str(data[target_field]), truncation=True, max_length=max_length, padding=False)
    encoded['labels'] = target_tokens['input_ids']
    return encoded

In [None]:
# A single tokenization callable shared by both splits: it builds the prompt
# from the lp/src/ref/mt columns and tokenizes the 'raw' column as the label.
tokenize_with_labels = partial(
    tokenize_,
    tokenizer=tokenizer,
    prompt_func=get_gemba_da_prompt,
    input_fields=['lp', 'src', 'ref', 'mt'],
    target_field='raw',
)

train_dataset = train_dataset.map(tokenize_with_labels, num_proc=4)
eval_dataset = eval_dataset.map(tokenize_with_labels, num_proc=4)

In [None]:
# Bundle both splits so the next cells can process them in one go.
finetune_datasets = DatasetDict({
    'train': train_dataset,
    'eval': eval_dataset,
})
finetune_datasets

In [None]:
# Only these columns are kept; every other column is removed from each split.
model_columns = {'input_ids', 'attention_mask', 'labels'}
for split_name in finetune_datasets:
    split = finetune_datasets[split_name]
    extra_columns = set(split.column_names) - model_columns
    finetune_datasets[split_name] = split.remove_columns(extra_columns)

finetune_datasets.set_format('torch')
finetune_datasets

In [None]:
# Batch collator: dynamic padding, lengths rounded up to a multiple of 8,
# PyTorch tensors as output.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    padding=True,
    pad_to_multiple_of=8,
    return_tensors="pt",
)

In [None]:
# Loaders over the tokenized splits; only the training loader shuffles.
loader_kwargs = dict(batch_size=8, collate_fn=data_collator)

train_dataloader = DataLoader(finetune_datasets['train'], shuffle=True, **loader_kwargs)
eval_dataloader = DataLoader(finetune_datasets['eval'], **loader_kwargs)

In [None]:
# Sanity check: fetch one collated batch and display the shape of every tensor in it.
batch = next(iter(train_dataloader))
{name: tensor.shape for name, tensor in batch.items()}

In [None]:
# Load the base seq2seq checkpoint with 8-bit weights; module placement is
# delegated to the library through device_map='auto'.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
)

In [None]:
# Prepare the quantized base model for training, then wrap it with LoRA adapters.
model = prepare_model_for_kbit_training(model)

lora_settings = dict(
    task_type="SEQ_2_SEQ",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
)
config = LoraConfig(**lora_settings)

model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
# ``TrainingArguments`` is not among the names imported at the top of the
# notebook, so it is referenced through the ``transformers`` module, the same
# way ``transformers.Trainer`` is referenced below.
training_args = transformers.TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    per_device_eval_batch_size=MICRO_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # Use LEARNING_RATE from the configuration cell rather than the library default.
    learning_rate=LEARNING_RATE,
    weight_decay=1e-6,
    warmup_steps=100,
    logging_steps=20,
    # Evaluation and checkpointing run on the same step schedule, which
    # load_best_model_at_end relies on.
    evaluation_strategy='steps',
    eval_steps=200,
    save_strategy='steps',
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    dataloader_num_workers=4,
    gradient_checkpointing=True,
    # The string 'none' explicitly disables all logging integrations.
    report_to='none',
    optim='adamw_torch'
)

trainer = transformers.Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=finetune_datasets['train'],
    eval_dataset=finetune_datasets['eval'],
    data_collator=data_collator
)

# Disable the generation cache on the model config for training.
model.config.use_cache = False

# Replace ``model.state_dict`` with a bound function that filters the full
# state dict through ``get_peft_model_state_dict``.
# NOTE(review): presumably meant to keep checkpoints down to the adapter
# weights -- confirm this is still needed with the installed peft version.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(
        self, old_state_dict()
    )
).__get__(model, type(model))

# NOTE(review): ``trainer`` was built above with the uncompiled object, so the
# compiled wrapper is only bound to the notebook-level name ``model``.
model = torch.compile(model)

In [None]:
# Launch the fine-tuning run.
trainer.train()

In [None]:
# Save the model to a local directory; the inference cell reads the same path
# back as ``peft_model_id``.
model.save_pretrained('./mt0-xxl-mt_wmt_finetune_lora_qvo_final')

In [None]:
# Tokenize the test rows with the same prompt builder. max_length=None is
# forwarded to the tokenization helper instead of its default limit.
tokenize_for_test = partial(
    tokenize_,
    tokenizer=tokenizer,
    prompt_func=get_gemba_da_prompt,
    input_fields=['lp', 'src', 'ref', 'mt'],
    target_field='raw',
    max_length=None,
)
test_dataset = wmt_test.map(tokenize_for_test, num_proc=4)

In [None]:
# Rebuild the inference stack from the saved adapter directory: the adapter
# config names the base checkpoint, which is loaded in 8-bit and then wrapped
# with the saved adapter.
peft_model_id = './mt0-xxl-mt_wmt_finetune_lora_qvo_final'
config = PeftConfig.from_pretrained(peft_model_id)
base_model_name = config.base_model_name_or_path

inference_model = AutoModelForSeq2SeqLM.from_pretrained(
    base_model_name,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
    use_auth_token=HF_TOKEN,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_auth_token=HF_TOKEN)
model = PeftModel.from_pretrained(inference_model, peft_model_id)

In [None]:
def get_score(text):
    """Extract the first number found in ``text`` as a float.

    Args:
        text: Text to search, e.g. a decoded prediction or label.

    Returns:
        The first number in ``text`` (optional leading minus sign, optional
        decimal part) as a float, or ``-1.0`` when ``text`` is ``None``/empty
        or contains no number.
    """
    if not text:
        return -1.0
    # Raw string: in a normal literal '\d' and '\.' are invalid escape sequences.
    match = re.search(r'-?\d+(?:\.\d+)?', text)
    if match is None:
        return -1.0
    try:
        return float(match.group(0))
    except ValueError:
        # Defensive: keep the sentinel if the matched text cannot be converted.
        return -1.0

In [None]:
model.eval()

# Per language pair: parsed model scores and the reference scores they are
# compared against. Both lists grow together, one entry per kept example, so
# they always stay index-aligned.
preds = defaultdict(list)
labels = defaultdict(list)
for translation in translations:
    print(translation)
    lang_dataset = test_dataset.filter(lambda x: x['lp'] == translation)
    lang_dataset = lang_dataset.remove_columns(
        set(lang_dataset.column_names) - {'input_ids', 'attention_mask', 'labels'}
    )
    lang_dataset.set_format('torch')
    lang_dataloader = DataLoader(
        lang_dataset, shuffle=False, batch_size=32, collate_fn=data_collator
    )
    pb = tqdm_notebook(lang_dataloader)
    try:
        with torch.no_grad():
            for batch in pb:
                labels_ = batch.pop('labels')
                # -100 marks padded label positions; swap in the pad id so they can be decoded.
                labels_ = np.where(labels_ != -100, labels_, tokenizer.pad_token_id)
                # Number of examples in this batch (len(batch) would count its keys).
                n_examples = batch['input_ids'].shape[0]

                try:
                    generated = model.generate(
                        input_ids=batch['input_ids'].to('cuda'),
                        # The batch is padded, so the mask is passed explicitly.
                        attention_mask=batch['attention_mask'].to('cuda'),
                        do_sample=True,
                        top_k=5,
                        max_new_tokens=10,
                    )
                    decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
                except RuntimeError:
                    # Generation failed for this batch: use the sentinel score for every example.
                    decoded_preds = ['-1.0'] * n_examples
                decoded_labels = tokenizer.batch_decode(labels_, skip_special_tokens=True)

                for pred, label in zip(decoded_preds, decoded_labels):
                    label = get_score(label)
                    # NOTE(review): get_score uses -1.0 as its "no number found"
                    # sentinel, so this check also drops genuinely negative
                    # reference scores -- confirm that is intended.
                    if label < 0:
                        continue
                    pred = get_score(pred)
                    preds[translation].append(pred)
                    labels[translation].append(label)

                # Running Kendall tau over everything collected so far for this pair.
                tau = kendalltau(preds[translation], labels[translation])
                pb.set_description(f"'tau': {tau.statistic}")
                torch.cuda.empty_cache()
    except KeyboardInterrupt:
        # An interrupt abandons the current language pair and moves on to the next one.
        continue

---