In [1]:
from transformers import AutoTokenizer, GPTNeoXForCausalLM, TrainingArguments
from datasets import load_dataset, load_metric
from peft import LoraConfig
from trl import SFTTrainer
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load the dataset
dataset = load_dataset("tatsu-lab/alpaca", split="train")

In [3]:
train_ratio = 0.9
split_datasets = dataset.train_test_split(train_size=train_ratio, seed=1006)

train_dataset = split_datasets['train']
val_dataset = split_datasets['test']

In [4]:
# lora config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

In [5]:
# base model and tokenizer
model = GPTNeoXForCausalLM.from_pretrained(
  "EleutherAI/pythia-70m-deduped",
  revision="step3000",
  cache_dir="./pythia-70m-deduped/step3000",
)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(
  "EleutherAI/pythia-70m-deduped",
  revision="step3000",
  cache_dir="./pythia-70m-deduped/step3000",
)

tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# prepare training arguments
trainer_args = TrainingArguments(
    num_train_epochs=2,
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    seed=1006,
    output_dir="./output",
    save_strategy="epoch",
    eval_accumulation_steps=4,
)

In [8]:
# compute metrics
rouge = load_metric("rouge", trust_remote_code=True)

  rouge = load_metric("rouge", trust_remote_code=True)


In [9]:
def compute_metrics(eval_pred):
    label_ids = eval_pred.label_ids
    pred_ids = eval_pred.predictions[0]

    decoded_predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=decoded_predictions, 
        references=decoded_labels,
    )

    return {key: value.mid.fmeasure * 100 for key, value in rouge_output.items()}

In [10]:
def preprocess_logits_for_metrics(logits, labels):
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels

In [11]:
# prepare trainer
trainer = SFTTrainer(
    model=model, 
    tokenizer=tokenizer,
    args=trainer_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    max_seq_length=2048,
    dataset_text_field="text",
    peft_config=lora_config,
    packing=True
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [12]:
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,2.4044,2.250429,68.711784,38.730587,53.091947,66.168742
2,2.2298,2.215025,69.059928,39.129061,53.602359,66.536307


TrainOutput(global_step=1276, training_loss=2.2914053698692203, metrics={'train_runtime': 875.5506, 'train_samples_per_second': 5.829, 'train_steps_per_second': 1.457, 'total_flos': 2814002979667968.0, 'train_loss': 2.2914053698692203, 'epoch': 2.0})

In [13]:
trainer.model.save_pretrained("./output/final_checkpoint/")