In [None]:
import gc
import re
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch import nn
import transformers
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import AdamW, get_scheduler
from transformers.trainer_pt_utils import get_parameter_names
from tqdm.notebook import tqdm_notebook
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import mean_squared_error

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
)

In [None]:
# LoRA adapter hyper-parameters (forwarded to peft.LoraConfig below).
LORA_R = 32
# NOTE(review): peft scales the adapter update by lora_alpha / r, i.e. 8 / 32 = 0.25
# here; confirm this small scaling is intentional.
LORA_ALPHA = 8
LORA_DROPOUT = 0.05
# Only the attention query/value projections receive adapters.
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
]

# Token budget per example; longer prompts are truncated by the tokenizer.
# NOTE(review): the fixed instruction text plus source/reference/translation segments
# may exceed 128 tokens for many rows, which would cut off the trailing "Score: N"
# target — verify the length distribution.
MAX_LENGTH = 128

# Effective batch = MICRO_BATCH_SIZE per device step x GRADIENT_ACCUMULATION_STEPS.
BATCH_SIZE = 64
MICRO_BATCH_SIZE = 16
assert BATCH_SIZE % MICRO_BATCH_SIZE == 0, "BATCH_SIZE must be a multiple of MICRO_BATCH_SIZE"
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
OUTPUT_DIR = "llama2_7b_finetune"

In [None]:
# Human direct-assessment judgements for WMT; the Hub repo exposes one 'train' split.
wmt_base = load_dataset(path='RicardoRei/wmt-da-human-evaluation', split='train')

In [None]:
# Language pairs ("<source>-<target>" codes) used for training and evaluation.
translations = ['en-de', 'en-ru', 'zh-en']

# Language code -> lowercase English language name.
id2langs = dict(
    ru='russian',
    en='english',
    de='german',
    zh='chinese',
)

In [None]:
# All WMT22 rows for the selected language pairs.
test_dataset = wmt_base.filter(
    lambda row: row['year'] == 2022 and row['lp'] in translations
)

In [None]:
# Fraction of each language pair's (seed-shuffled) rows to keep.
take_part = 0.5


def _take_fraction(dataset, fraction, seed=42):
    """Shuffle `dataset` with a fixed seed and keep its first `fraction` of rows.

    Uses `select` + `flatten_indices` rather than round-tripping through a Python
    dict (`Dataset.from_dict`), which re-infers column types per part and may give
    parts mismatched schemas for `concatenate_datasets`.
    """
    shuffled = dataset.shuffle(seed=seed)
    return shuffled.select(range(int(len(shuffled) * fraction))).flatten_indices()


# Per language pair: pre-2022 rows feed training, 2022 rows feed evaluation.
train_parts = []
eval_parts = []
for translation in translations:
    train_part = wmt_base.filter(lambda example: example['year'] != 2022 and example['lp'] == translation)
    eval_part = wmt_base.filter(lambda example: example['year'] == 2022 and example['lp'] == translation)
    train_parts.append(_take_fraction(train_part, take_part))
    eval_parts.append(_take_fraction(eval_part, take_part))
train_dataset = concatenate_datasets(train_parts)
eval_dataset = concatenate_datasets(eval_parts)

In [None]:
def get_gemba_da_prompt(langs, source_seg, reference_seg, target_seg, score=None):
    """Build a GEMBA-DA style prompt asking for a 0-100 translation quality score.

    Args:
        langs: language-pair code "<source>-<target>", e.g. 'en-de'.
        source_seg: source-language sentence.
        reference_seg: human reference translation.
        target_seg: translation to be scored.
        score: gold score; when given it is appended after "Score:" (training
            target), otherwise the prompt ends with "Score:" for the model to complete.

    Returns:
        The prompt string.

    Raises:
        ValueError: if `langs` does not split into exactly two codes on '-'.
    """
    # Spell out languages via the notebook-level `id2langs` mapping (unknown codes are
    # kept verbatim) so the prompt names languages instead of showing bare ISO codes.
    source_lang, target_lang = (id2langs.get(code, code) for code in langs.split('-'))
    answer = 'Score:' if score is None else f'Score: {score}'
    return ''.join([
        f'Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100, ',
        'where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\n',
        f'{source_lang} source: "{source_seg}"\n',
        f'{target_lang} human reference: "{reference_seg}"\n',
        f'{target_lang} translation: "{target_seg}"\n',
        answer,
    ])

In [None]:
# Bundle both splits so later cells can map/format them together.
wmt_datasets = DatasetDict(train=train_dataset, eval=eval_dataset)
wmt_datasets

In [None]:
llama2_tokenizer = AutoTokenizer.from_pretrained('./llama-2-7b-hf/')
# The tokenizer needs a pad token for batch padding. `pad_token` expects a token
# *string*; assigning the int 0 to it is rejected or mis-set depending on the
# transformers version. The intent is "pad with token id 0", so set the id.
# NOTE(review): id 0 is presumably <unk> in the LLaMA vocab — confirm for this checkpoint.
# NOTE(review): for batched generation with a decoder-only model, left padding
# (padding_side='left') is usually required — confirm the desired side.
llama2_tokenizer.pad_token_id = 0

In [None]:
def tokenize_translation(data, tokenizer, prompt_func, fields, max_length=MAX_LENGTH):
    """Turn one dataset row into a causal-LM example whose labels mirror its inputs.

    The row's `fields` are passed positionally to `prompt_func`. The resulting prompt
    is tokenized without padding and truncated to `max_length` tokens. An EOS token
    (with attention 1) is appended when there is still room and the sequence does not
    already end with EOS. `labels` is a copy of `input_ids`, so prompt tokens are
    trained on as well.

    NOTE(review): if `prompt_func` puts the answer at the end (as
    get_gemba_da_prompt does), truncation of a long prompt removes the very
    "Score: N" the model should learn.

    Returns:
        The tokenizer's encoding with an added 'labels' entry.
    """
    encoded = tokenizer(
        prompt_func(*(data[name] for name in fields)),
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    ids = encoded['input_ids']
    ends_with_eos = ids[-1] == tokenizer.eos_token_id
    has_room = len(ids) < max_length
    if has_room and not ends_with_eos:
        ids.append(tokenizer.eos_token_id)
        encoded['attention_mask'].append(1)
    encoded['labels'] = list(ids)
    return encoded

In [None]:
# Build and tokenize the prompt for every row of both splits; the 'raw' column is
# passed as the `score`, so each prompt ends with "Score: <gold raw score>".
# NOTE(review): the 'eval' split is mapped the same way, so with predict_with_generate
# the model is conditioned on prompts that already contain the gold answer. Eval
# prompts should presumably be built with score=None — confirm.
tokenize_fn = partial(
    tokenize_translation,
    tokenizer=llama2_tokenizer,
    prompt_func=get_gemba_da_prompt,
    fields=['lp', 'src', 'ref', 'mt', 'raw'],
)
llama2_wmt_finetune_datasets = wmt_datasets.map(tokenize_fn, num_proc=4)

In [None]:
# Keep only the model inputs; every other column is dropped, then expose tensors.
_model_columns = {'input_ids', 'attention_mask', 'labels'}
_extra_columns = [
    column
    for column in llama2_wmt_finetune_datasets['train'].column_names
    if column not in _model_columns
]
llama2_wmt_finetune_datasets = llama2_wmt_finetune_datasets.remove_columns(_extra_columns)
llama2_wmt_finetune_datasets.set_format('torch')
llama2_wmt_finetune_datasets['train'].column_names

In [None]:
# Dynamic padding per batch: inputs use the tokenizer's pad id, labels use the
# collator's default label_pad_token_id (-100); lengths round up to a multiple of 8.
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer=llama2_tokenizer,
    padding=True,
    pad_to_multiple_of=8,
    return_tensors="pt",
)

In [None]:
# Standalone PyTorch loaders (the Trainer builds its own from the datasets);
# the training loader shuffles, the eval loader keeps dataset order.
_loader_kwargs = dict(batch_size=8, collate_fn=data_collator)
train_dataloader = DataLoader(llama2_wmt_finetune_datasets['train'], shuffle=True, **_loader_kwargs)
eval_dataloader = DataLoader(llama2_wmt_finetune_datasets['eval'], **_loader_kwargs)

In [None]:
# Pull one collated batch and show each tensor's shape as a sanity check.
# `next(..., None)` plus the assert fails loudly on an empty loader instead of
# leaving `batch` undefined or, worse, stale from an earlier kernel run.
batch = next(iter(train_dataloader), None)
assert batch is not None, "train_dataloader produced no batches"
{k: v.shape for k, v in batch.items()}

In [None]:
# Load the local LLaMA-2-7B checkpoint with 8-bit weights, float16 for the rest,
# and automatic device placement.
# NOTE(review): newer transformers versions prefer quantization_config=BitsAndBytesConfig(...)
# over load_in_8bit — confirm against the installed version.
model = AutoModelForCausalLM.from_pretrained(
    './llama-2-7b-hf/',
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
)

In [None]:
# LoRA adapters on LORA_TARGET_MODULES for a causal-LM task, no bias training.
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
)
# Prepare the quantized model for training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), config)
model.print_trainable_parameters()

In [None]:
def compute_metrics(eval_data, tokenizer):
    preds, labels = eval_data
    print
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    pred_scores = []
    label_scores = []
    for pred, label in zip(decoded_preds, decoded_labels):
        label = label.rsplit('Score:', 1)[-1].strip()
        if not label.replace('.', '').isnumeric():
            continue
        pred = pred.rsplit('Score:', 1)[-1].strip()
        pred = float(pred) if pred.replace('.', '').isnumeric() else -1.0
        pred_scores.append(float(pred))
        label_scores.append(float(label))

    r = spearmanr(pred_scores, label_scores)
    tau = kendalltau(pred_scores, label_scores)
    return {'R': r.statistic, 'tau': tau.statistic}

In [None]:
# Trainer configuration. Effective batch = MICRO_BATCH_SIZE x GRADIENT_ACCUMULATION_STEPS.
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,  # single source of truth in the config cell
    num_train_epochs=1,
    # LEARNING_RATE from the config cell must be passed explicitly, otherwise the
    # Trainer silently falls back to its own default learning rate.
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    per_device_eval_batch_size=MICRO_BATCH_SIZE // 2,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=100,
    weight_decay=1e-6,
    optim='adamw_torch',
    gradient_checkpointing=True,
    logging_steps=5,
    # NOTE(review): evaluating (with generation) and checkpointing every 5 optimizer
    # steps is expensive; consider a larger interval. Newer transformers releases
    # rename `evaluation_strategy` to `eval_strategy`.
    evaluation_strategy='steps',
    eval_steps=5,
    save_strategy='steps',
    save_steps=5,
    save_total_limit=3,
    load_best_model_at_end=True,
    dataloader_num_workers=4,
    # In transformers 4.x, report_to=None means "every installed integration";
    # the string "none" is what disables reporting.
    report_to="none",
    predict_with_generate=True,
    # NOTE(review): for a decoder-only model, generate's max_length counts the prompt,
    # so a prompt near MAX_LENGTH tokens leaves room for about one new token — likely
    # too few for a numeric score; confirm.
    generation_max_length=MAX_LENGTH + 1,
)

# Seq2SeqTrainer is used for its predict_with_generate support; compute_metrics
# receives the tokenizer through functools.partial.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    tokenizer=llama2_tokenizer,
    args=training_args,
    train_dataset=llama2_wmt_finetune_datasets['train'],
    eval_dataset=llama2_wmt_finetune_datasets['eval'],
    data_collator=data_collator,
    compute_metrics=partial(compute_metrics, tokenizer=llama2_tokenizer),
)

# The KV cache is incompatible with gradient checkpointing (enabled in training_args).
model.config.use_cache = False

# Deliberately no `model.state_dict` monkey-patch: for PEFT models the Trainer saves
# via PeftModel.save_pretrained, which already reduces the weights to the adapter with
# get_peft_model_state_dict. Filtering a second time strips the adapter-name keys, and
# in peft >= 0.3 this leaves the saved adapter (and the best-model reload) empty.
# Also no `torch.compile(model)` here: the Trainer above already holds the uncompiled
# model, so rebinding `model` afterwards would not change training. Set
# `torch_compile=True` in the training arguments if compilation is wanted.

In [None]:
# Run fine-tuning; the returned TrainOutput summary is shown as the cell output.
train_output = trainer.train()
train_output