# Text Summarization using RLHF

## 1. Supervised Fine-Tuning (SFT)


### 1.1. Download dataset

We download the dataset with the Hugging Face `datasets` library and format each example as `### Text: <document>`, followed on the next line by `### Summary: <summary>`.

In [9]:
# install lib
!pip install -q datasets evaluate==0.4.1 rouge_score==0.1.2 peft==0.10.0

In [None]:
# load the dataset
from datasets import load_dataset

# Hub identifier of the TL;DR summarization dataset used for SFT.
sft_ds_name = 'CarperAI/openai_summarize_tldr'
sft_ds = load_dataset(sft_ds_name)
# Pull out the three predefined splits in a single unpacking step.
sft_train, sft_valid, sft_test = (sft_ds[split] for split in ('train', 'valid', 'test'))

# contruct
def formatting_func(example):
    text = f"### Text: {example['promt']}\n ### Summary: {example['label']}"
    return text

# Show what the first training example looks like after formatting
# (prints nothing if the split is empty).
first_example = next(iter(sft_train), None)
if first_example is not None:
    print(formatting_func(first_example))


### 1.2. Model

We fine-tune a pretrained OPT model (`facebook/opt-350m`). To speed up training, we can use `quantization` and `LoRA`.

In [None]:
import torch
from trl import ModelConfig, get_quantization_config, get_kbit_device_map
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer

# Base model to fine-tune; every other ModelConfig field keeps its default.
model_config = ModelConfig(
    model_name_or_path='facebook/opt-350m',
)

# "auto" / None are passed through unchanged; any other string (e.g. "float16")
# is resolved to the torch dtype object of the same name.
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
    else getattr(torch, model_config.torch_dtype)
)

# NOTE(review): presumably None unless 4-bit / 8-bit loading is enabled on
# model_config, i.e. the config above trains without quantization — verify
# against the installed TRL version.
quantization_config = get_quantization_config(model_config)

# Keyword arguments intended for model construction (passed to the trainer as
# model_init_kwargs below).
model_kwargs = dict(
    revision=model_config.model_revision,
    trust_remote_code=model_config.trust_remote_code,
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False,  # the KV cache is not used while training
    # A k-bit device map is only requested when the model is actually quantized.
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
# Reuse EOS as the padding token so sequences can be padded into batches.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# LoRA adapter configuration for a causal language model.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

### 1.3. Metric

We use the `ROUGE` metric to evaluate the model.

In [None]:
import evaluate

rouge = evaluate.load("rouge")


def compute_metrics(eval_preds):
    """Compute ROUGE between decoded model predictions and decoded reference labels.

    Args:
        eval_preds: the evaluation output object, exposing `.predictions`
            (token ids, raw logits with a trailing vocab axis, or a tuple whose
            first entry is one of those) and `.label_ids`. If it is wrapped in
            a tuple, the first element is used.

    Returns:
        dict: the scores returned by `rouge.compute`.
    """
    import numpy as np

    if isinstance(eval_preds, tuple):
        eval_preds = eval_preds[0]
    pred_ids = eval_preds.predictions
    labels_ids = eval_preds.label_ids

    # Models may return a tuple (logits, ...); the first entry is the prediction.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    pred_ids = np.asarray(pred_ids)
    # Raw logits carry an extra vocab axis; reduce them to greedy token ids so
    # they can be decoded.
    if pred_ids.ndim == 3:
        pred_ids = pred_ids.argmax(axis=-1)

    # -100 is used as the ignore/padding index (labels, and padding when
    # batches are concatenated); it is not a decodable token id.
    pad_id = tokenizer.pad_token_id
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    labels_ids = np.asarray(labels_ids)
    labels_ids = np.where(labels_ids == -100, pad_id, labels_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    return rouge.compute(predictions=pred_str, references=label_str)

### 1.4. Trainer
We set up the training arguments and build the trainer.

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

num_epochs = 10
# Renamed from the misspelled `trainig_args`: the trainer below reads
# `training_args`, so the old name raised NameError.
training_args = TrainingArguments(
    output_dir='./save_model',
    evaluation_strategy="epoch",
    save_strategy='epoch',  # kept equal to the eval strategy for load_best_model_at_end
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    num_train_epochs=num_epochs,
    load_best_model_at_end=True,
)


def preprocess_logits_for_metrics(logits, labels):
    """Reduce each evaluation batch's logits to greedy token ids before caching.

    compute_metrics only needs token ids. Returning the argmax here avoids
    keeping full-vocabulary logits for the whole validation set in memory.
    `labels` is part of the required callback signature and is unused.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


max_input_length = 512
trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=sft_train,
    eval_dataset=sft_valid,
    max_seq_length=max_input_length,
    tokenizer=tokenizer,
    peft_config=peft_config,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    packing=True,
    formatting_func=formatting_func,
)

trainer.train()