In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch import nn
import transformers
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import AdamW, get_scheduler
from transformers import pipeline
from transformers.trainer_pt_utils import get_parameter_names
from tqdm.notebook import tqdm_notebook
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import mean_squared_error

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
)

from functools import partial
import re
import gc

[2023-07-24 08:34:01,063] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT= 0.05
LORA_TARGET_MODULES = [
    'q_proj',
    'k_proj',
    'v_proj',
    'up_proj',
    'down_proj'
]

MAX_LENGTH = 256

BATCH_SIZE = 64
MICRO_BATCH_SIZE = 32
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LEARNING_RATE = 3e-4
OUTPUT_DIR = "llama2_7b_finetune"

In [3]:
wmt_base = load_dataset('RicardoRei/wmt-da-human-evaluation', split='train')

Found cached dataset csv (/home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


In [4]:
translations = ['en-de', 'en-ru', 'zh-en']
id2langs = {
    'ru': 'russian',
    'en': 'english',
    'de': 'german',
    'zh': 'chinese'
}

In [5]:
train_take_part = 0.25
eval_take_part = 0.1

train_dataset = []
eval_dataset = []

for translation in translations:
    train_part = wmt_base.filter(lambda example: (example['year'] != 2022) & (example['lp'] == translation))
    train_part = train_part.shuffle(seed=42)
    train_dataset.append(Dataset.from_dict(train_part[:int(len(train_part)*train_take_part)]))

    eval_part = wmt_base.filter(lambda example: (example['year'] == 2022) & (example['lp'] == translation))
    eval_part = eval_part.shuffle(seed=42)
    eval_dataset.append(Dataset.from_dict(eval_part[:int(len(eval_part)*eval_take_part)]))
train_dataset = concatenate_datasets(train_dataset)
eval_dataset = concatenate_datasets(eval_dataset)

Loading cached processed dataset at /home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-1cfd06b39276352a.arrow
Loading cached shuffled indices for dataset at /home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-c80783202fc9a509.arrow
Loading cached processed dataset at /home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-9e804864585ee5cf.arrow
Loading cached processed dataset at /home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-a0338df62a22c4f9.arrow
L

In [6]:
def preprocess(text):
    return ' '.join(text.lower().strip().split())

def get_gemba_da_prompt(langs, source_seg, reference_seg, target_seg, score=None):
    src_lang_id, tgt_lang_id = langs.split('-')
    source_lang, target_lang = id2langs[src_lang_id], id2langs[tgt_lang_id]
    source_seg, reference_seg, target_seg = map(preprocess, [source_seg, reference_seg, target_seg])
    return ''.join([f'Score the following translation from {source_lang} to {target_lang} respect to the human reference on a continuous scale from 0 to 100, ',
            'where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\n',
            f'{source_lang} source: "{source_seg}"\n',
            f'{target_lang} human reference: "{reference_seg}"\n',
            f'{target_lang} translation: "{target_seg}"\n',
            'Score:' if score is None else f'Score: {score}'])

In [7]:
wmt_datasets = DatasetDict({'train': train_dataset, 'eval': eval_dataset})
wmt_datasets

DatasetDict({
    train: Dataset({
        features: ['lp', 'src', 'mt', 'ref', 'score', 'raw', 'annotators', 'domain', 'year'],
        num_rows: 90281
    })
    eval: Dataset({
        features: ['lp', 'src', 'mt', 'ref', 'score', 'raw', 'annotators', 'domain', 'year'],
        num_rows: 2082
    })
})

In [8]:
llama2_tokenizer = AutoTokenizer.from_pretrained('./llama-2-7b-hf/')
llama2_tokenizer.pad_token = 0

In [9]:
def tokenize_translation(data, tokenizer, prompt_func, fields, max_length=MAX_LENGTH, inference=False):
    prompt = prompt_func(*[data[field] for field in fields])
    result = tokenizer(prompt, truncation=True, max_length=max_length, padding=False)
    if result['input_ids'][-1] != tokenizer.eos_token_id \
        and len(result['input_ids']) < max_length:
        result['input_ids'].append(tokenizer.eos_token_id)
        result['attention_mask'].append(1)
    if not inference:
        result['labels'] = result['input_ids'].copy()
    return result

In [10]:
llama2_wmt_finetune_datasets = wmt_datasets.map(partial(tokenize_translation,
                                                        tokenizer=llama2_tokenizer,
                                                        prompt_func=get_gemba_da_prompt,
                                                        fields=['lp', 'src', 'ref', 'mt', 'raw']),
                                                num_proc=4)

Map (num_proc=4):   0%|          | 0/90281 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/2082 [00:00<?, ? examples/s]

In [11]:
llama2_wmt_finetune_datasets = llama2_wmt_finetune_datasets.remove_columns(set(llama2_wmt_finetune_datasets['train'].column_names)-{'input_ids', 'attention_mask', 'labels'})
llama2_wmt_finetune_datasets.set_format('torch')
llama2_wmt_finetune_datasets['train'].column_names

['input_ids', 'attention_mask', 'labels']

In [12]:
data_collator = transformers.DataCollatorForSeq2Seq(
    llama2_tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

In [13]:
train_dataloader = DataLoader(
    llama2_wmt_finetune_datasets['train'], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    llama2_wmt_finetune_datasets['eval'], batch_size=8, collate_fn=data_collator
)

In [14]:
for batch in train_dataloader:
    break
{k: v.shape for k, v in batch.items()}

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': torch.Size([8, 256]),
 'attention_mask': torch.Size([8, 256]),
 'labels': torch.Size([8, 256])}

In [15]:
model = AutoModelForCausalLM.from_pretrained('./llama-2-7b-hf/', device_map='auto',
                                             torch_dtype=torch.float16, load_in_8bit=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 28,049,408 || all params: 6,766,465,024 || trainable%: 0.4145356238525057


In [17]:
def get_score(text):
    score = re.search('Score: \d+(\.\d+)?|$', text).group(0).split(' ', 1)[-1]
    if score.replace('.', '').isnumeric:
        return float(score)
    return -1.0

def compute_metrics(eval_data, tokenizer):
    preds, labels = eval_data

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    pred_scores = []
    label_scores = []
    skips = 0
    for pred, label in zip(decoded_preds, decoded_labels):
        label = get_score(label)
        if label < 0:
            skips += 1
            continue
        pred = get_score(pred)
        pred_scores.append(pred)
        label_scores.append(label)

    r = spearmanr(pred_scores, label_scores)
    tau = kendalltau(pred_scores, label_scores)
    return {'R': r.statistic, 'tau': tau.statistic, 'skip_for_eval': skips}

In [18]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./llama2_7b_finetune",
    num_train_epochs=1,
    per_device_train_batch_size=MICRO_BATCH_SIZE,
    per_device_eval_batch_size=MICRO_BATCH_SIZE,
    logging_steps=20,
    warmup_steps=100,
    save_strategy='steps',
    evaluation_strategy='steps',
    weight_decay=1e-6,
    eval_steps=400,
    save_steps=400,
    save_total_limit=3,
    dataloader_num_workers=4,
    gradient_checkpointing=True,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    load_best_model_at_end=True,
    report_to=None,
    optim='adamw_torch',
    predict_with_generate=True,
    generation_max_length=int(MAX_LENGTH*3/2)
)

trainer = transformers.Seq2SeqTrainer(
    model=model,
    tokenizer=llama2_tokenizer,
    args=training_args,
    train_dataset=llama2_wmt_finetune_datasets['train'],
    eval_dataset=llama2_wmt_finetune_datasets['eval'],
    data_collator=data_collator,
    compute_metrics=partial(compute_metrics, tokenizer=llama2_tokenizer)
)

model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(
        self, old_state_dict()
    )
).__get__(model, type(model))
 
model = torch.compile(model)

In [None]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mw-vskv-w[0m ([33mairi23-efficient-llm-metrics[0m). Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss


In [None]:
model.save_pretrained('./llama2_7b_wmt_finetune_lora_qkv_updown_proj_final')

In [49]:
test_dataset = wmt_base.filter(lambda example: (example['year'] == 2022) & (example['lp'] in translations))

Loading cached processed dataset at /home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-750e7e0241e2fef2.arrow


In [42]:

test_dataset = test_dataset.map(partial(tokenize_translation,
                                        tokenizer=llama2_tokenizer,
                                        prompt_func=get_gemba_da_prompt,
                                        max_length=MAX_LENGTH*2,
                                        fields=['lp', 'src', 'ref', 'mt'],
                                        inference=True),
                                num_proc=4)
labels = test_dataset['raw']

test_dataset = test_dataset.remove_columns(set(test_dataset.column_names)-{'input_ids', 'attention_mask'})
test_dataset.set_format('torch')
test_dataloader = DataLoader(
    test_dataset, shuffle=False, batch_size=16, collate_fn=data_collator
)

Loading cached processed dataset at /home/vyskov/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-750e7e0241e2fef2.arrow


Map (num_proc=4):   0%|          | 0/20820 [00:00<?, ? examples/s]

In [81]:
i = 14535
pipeline('text-generation', model=model, tokenizer=llama2_tokenizer, device_map='auto', torch_dtype=torch.float16)(get_gemba_da_prompt(
    test_dataset[i]['lp'], test_dataset[i]['src'], test_dataset[i]['ref'], test_dataset[i]['mt']), max_length=400), test_dataset[i]['raw']

The model 'OptimizedModule' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 

([{'generated_text': 'Score the following translation from zh to en respect to the human reference on a continuous scale from 0 to 100, where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\nzh source: "按照先试点再推行，由点到面逐步推进的原则，探索建立校企联合招生、联合培养、一体化育人的长效机制，切实提升学生岗位技能，提高学生对口就业率和就业质量。"\nen human reference: "According to the principle of experimental units first and implementation later, and gradual advancement from point to area, the long-term mechanism of joint recruitment by schools and enterprises, joint training and integrated training shall be explored and established, so as to fundamentally improve the job skills of students and improve the rate of employment fitting students’ major, as well as the quality of employment quality of students."\nen translation: "In accordance with the principle of piloting first and then implementing it, and gradually advancing from point to point, we will explore the establishment of a long-term me

In [90]:
model.eval()

with torch.no_grad():
    for batch in test_dataloader:
        output = model.generate(**batch, max_new_tokens=10)
        print(output)
        print(llama2_tokenizer.batch_decode(output, skip_special_tokens=True))
        break

tensor([[29900, 29900, 29900,  ..., 18120, 29915, 29879],
        [29900, 29900, 29900,  ..., 18120, 29915, 29879],
        [29900, 29900, 29900,  ...,   474, 29915, 29885],
        ...,
        [29900, 29900, 29900,  ...,   474, 29915, 29885],
        [29900, 29900, 29900,  ...,  7013,  2039,    13],
        [    1,  2522,   487,  ...,  1213,    13, 20097]])
['000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 Score the following translation from zh to en respect to the human reference on a continuous scale from 0 to 100, where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\nzh source: "是否有途径处罚他"\nen human reference: "Is there a way to punish him?"\nen translation: "Is there a way to punish him"\nScore: nobody\'s business but the lord\'s', '0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000