In [1]:
import logging
import os

import numpy as np
import torch
import wandb
from datasets import load_dataset, load_metric
from peft import LoraConfig
from transformers import AutoTokenizer, GPTNeoXForCausalLM, TrainingArguments
from trl import SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Open a Weights & Biases run; the training arguments below report to "wandb".
wandb.init(
    entity="irisiris",
    project="efficient-llm-sft",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mirisiris[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Send INFO-and-above log records to a file so the decoded samples logged during
# evaluation are kept outside the notebook output.
LOG_FILE = './log/finetune-2.8b.log'
# basicConfig opens the log file right away and raises FileNotFoundError when the
# parent directory is missing, so make sure the directory exists first.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logging.basicConfig(filename=LOG_FILE, level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [4]:
# Load the "train" split of the tatsu-lab/alpaca dataset; the train/validation
# split used for fine-tuning is carved out of it in the next cell.
dataset = load_dataset("tatsu-lab/alpaca", split="train")

In [5]:
# Keep 80% of the examples for training and use the remainder for validation.
# The fixed seed makes the split reproducible across runs.
train_ratio = 0.8
split_datasets = dataset.train_test_split(
    train_size=train_ratio,
    seed=1006,
)
train_dataset, val_dataset = split_datasets['train'], split_datasets['test']

In [6]:
# LoRA adapter settings; passed to the trainer as `peft_config`.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [7]:
# Base model: EleutherAI/pythia-2.8b-deduped pinned to revision "step143000".
# Downloaded files are cached in a local directory named after model and revision.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-2.8b-deduped",
    cache_dir="./pythia-2.8b-deduped/step143000",
    revision="step143000",
)

In [8]:
# Tokenizer from the same checkpoint, revision and cache directory as the model.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-2.8b-deduped",
    cache_dir="./pythia-2.8b-deduped/step143000",
    revision="step143000",
)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Training arguments, grouped by concern.
trainer_args = TrainingArguments(
    # checkpoints: written once per epoch, at most two kept
    output_dir="./output/2.8b",
    save_strategy="epoch",
    save_total_limit=2,
    # optimisation schedule
    num_train_epochs=15,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=8,
    # evaluation: once per epoch
    evaluation_strategy="epoch",
    per_device_eval_batch_size=1,
    eval_accumulation_steps=4,
    # reproducibility and experiment tracking
    seed=1006,
    report_to="wandb",
)

In [10]:
# Load the ROUGE metric used by `compute_metrics`.
# NOTE(review): the cell output echoes this line, which looks like a warning about
# `load_metric` — presumably a deprecation; confirm and consider migrating.
rouge = load_metric("rouge", trust_remote_code=True)

  rouge = load_metric("rouge", trust_remote_code=True)


In [11]:
def compute_metrics(eval_pred):
    """Compute ROUGE F-measures between decoded predictions and decoded labels.

    Args:
        eval_pred: A ``(predictions, label_ids)`` pair. ``predictions`` may itself
            be a tuple, in which case its first element holds the predicted
            token ids.

    Returns:
        dict: One entry per key returned by ``rouge.compute``, holding the mid
        F-measure scaled to a percentage.
    """
    pred_ids, label_ids = eval_pred
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a real token id and cannot be decoded. Labels use it for ignored
    # positions, and predictions can contain it too (presumably from the Trainer
    # padding gathered batches), so replace it with the pad token in both arrays.
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    decoded_predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)  # a list of decoded strings
    if decoded_predictions:  # avoid IndexError on an empty evaluation set
        logger.info(f"decoded_predictions: {decoded_predictions[0]}")

    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if decoded_labels:
        logger.info(f"decoded_labels: {decoded_labels[0]}")

    rouge_output = rouge.compute(
        predictions=decoded_predictions,
        references=decoded_labels,
    )

    return {key: value.mid.fmeasure * 100 for key, value in rouge_output.items()}

In [12]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw logits to predicted token ids before they are kept for metrics.

    Args:
        logits: Either the logits tensor itself, or a tuple/list whose first
            element is the logits tensor (any further elements are ignored).
        labels: Returned unchanged.

    Returns:
        tuple: ``(pred_ids, labels)`` where ``pred_ids`` is the argmax of the
        logits over the last dimension.
    """
    # Unwrap only when the model output really is a tuple. Indexing a bare logits
    # tensor with [0] would silently select the first element along its leading
    # dimension and drop that dimension from the predictions.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

In [13]:
# Assemble the SFT trainer from the pieces defined above.
trainer = SFTTrainer(
    # model, tokenizer and LoRA adapter configuration
    model=model,
    tokenizer=tokenizer,
    peft_config=lora_config,
    # training arguments and data
    args=trainer_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="text",
    max_seq_length=512,
    packing=True,
    # evaluation hooks
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

In [14]:
# Start fine-tuning with the trainer configured above.
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum
1,1.1649,1.156159,73.57339,48.340833,65.629199,71.420109
2,1.1373,1.142821,73.723584,48.693642,65.929097,71.595419


config.json: 100%|██████████| 571/571 [00:00<00:00, 2.69MB/s]


KeyboardInterrupt: 

In [None]:
# Optional: save the fine-tuned model once training has finished.
# trainer.model.save_pretrained("./output/2.8b/final_checkpoint/")