# Preference Tuning for Summarization using Synthetic Data

**⏱️ Time to complete**:

Preference tuning is a powerful tool for optimizing LLMs towards complex preferences that cannot easily be captured through supervised fine-tuning. However, manually annotating preferences between model outputs with human raters can be extremely time-consuming and expensive. In this example, we demonstrate how preference data can be generated synthetically by leveraging larger LLMs to score a model's outputs.


We will focus on the task of summarization for the [CNN/DailyMail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset.

# Table of Contents
1. [Data Preprocessing](#step-1-data-preprocessing): This section covers how to prepare preference data for the summarization task using an LLM-as-a-judge.
2. [DPO Finetuning](#step-2-fine-tuning): This section covers how to fine-tune an open-source model on the preference data on the Anyscale platform.
3. [Evaluation](#step-3-evaluation): This section lays out a blueprint for evaluation and compares performance to that of closed-source models like OpenAI's GPT-4.
4. [Iterative-DPO](#step-4-iterative): An optional step to further boost performance with iterative preference tuning.

First, let's make the necessary imports:

In [1]:
import os
import yaml
import datasets
import openai

import ray.data

# Step 1: Synthetic Data Generation

In [3]:
hf_ds = datasets.load_dataset("abisee/cnn_dailymail", '3.0.0', split="train").shuffle(seed=21)
# Extract a subset of 20,000 articles
hf_ds_subset = hf_ds.select(range(20000))

ray_ds = ray.data.from_huggingface(hf_ds_subset)
raw_example = ray_ds.take(1)[0]


## Generating Article Questions

For each article, we generate 5 multiple-choice questions using Llama-70B-Instruct.

In [6]:
import pprint 
pprint.pprint(raw_example, width=100)

{'article': 'Scam: Lisa Harrison, 34, promised customers low currency rates on US dollars and '
            'special deals . A wedding planner who stole £80,000 from couples in a bid to satisfy '
            "an 'out-of-control' online gambling addiction has been jailed. Lisa Harrison, 34, "
            'began taking money from her clients in summer 2013 by enticing them with low currency '
            'rates on US dollars and flight upgrades. She took money from 19 couples who had '
            'entrusted their savings to her after being promised the wedding of their dreams. It '
            'is understood that the company she worked for, iPlan New York, specialised in '
            'weddings in New York City. Her website iplannewyork.com, which has been taken down, '
            "said: 'iPlan New York was set up to create and style the perfect tailor made wedding "
            "for couples travelling to New York to get married! 'We are passionate about what we "
            'do and p

Our data pre-processing is going to look as follows: 

![preprocessing](./assets/preprocessing.png?1)


# TODO: Instructions for pre-processing

\<The relevant preprocessing code is in `utils/generate_questions.py` and `utils/generate_summaries.py`. You can run data generation as an Anyscale job with `configs/generate_questions_job.yaml` and `configs/generate_summaries_job.yaml`.\>

\<After preprocessing, here's an example of the Q&A generated by Llama 70B, and here's an example of the summaries generated by Mistral 7B Instruct.\>


\<We sample chosen and rejected messages from the summaries based on the Q&A accuracy score, using a threshold of 3/5 to classify examples as 'chosen' or 'rejected'. Here's an example training dataset sample for the DPO model.\>
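
As a rough illustration of the chosen/rejected split, here is a minimal sketch of how preference pairs could be assembled from scored summaries. It is not the code in `utils/`: the function name `build_preference_pairs`, the output field names, and the reading of the threshold (a score of at least 3 out of 5 counts as 'chosen', anything lower as 'rejected') are all assumptions made for this example.

```python
from itertools import product


def build_preference_pairs(prompt, scored_summaries, threshold=3, max_pairs=None):
    """Pair high-scoring summaries with low-scoring ones for one article.

    scored_summaries: list of (summary_text, num_correct) tuples, where
    num_correct is the number of questions (out of 5) answered correctly.
    ASSUMPTION: score >= threshold means 'chosen', score < threshold means 'rejected'.
    """
    chosen = [text for text, score in scored_summaries if score >= threshold]
    rejected = [text for text, score in scored_summaries if score < threshold]

    # Every (chosen, rejected) combination is a candidate training pair.
    # The field names below are hypothetical, not a documented dataset schema.
    pairs = [
        {"prompt": prompt, "chosen": c, "rejected": r}
        for c, r in product(chosen, rejected)
    ]
    return pairs if max_pairs is None else pairs[:max_pairs]


example = build_preference_pairs(
    "Summarize the article.",
    [("summary A", 5), ("summary B", 1), ("summary C", 3)],
)
print(len(example))
```

Articles whose summaries all fall on one side of the threshold yield no pairs under this scheme, so the number of training examples can be smaller than the number of articles.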

# Step 2: Fine-tuning

Now that we have the pre-processed dataset, we are ready to fine-tune `Mistral-7B-Instruct-v0.1` using DPO. On Anyscale, we've created an easy-to-use interface for preference tuning with DPO. We leverage Ray to overlap the reference model's log-probability calculation with model training, which improves GPU utilization. Most implementations compute these log probabilities synchronously with model training:

![hf model](assets/hf_dpo.png)

In contrast, our implementation using Ray is asynchronous:


![assistant model](assets/anyscale_dpo.png)

Further, our use of Ray Data means that the compute configuration for the reference model can be completely decoupled from that of the policy model. For example, the reference model calculation can run on a different node with zero code changes.
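
To see why the reference model can be decoupled this way, it helps to look at the standard DPO objective for a single preference pair. The sketch below is plain Python rather than the Anyscale implementation; the function name `dpo_loss` and its argument names are made up for this example. The reference log probabilities enter the loss only as fixed numbers, so they do not depend on the policy's current weights and can be computed ahead of, or alongside, the training step.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.01):
    """Standard DPO loss for one (chosen, rejected) pair.

    Each argument is the summed log probability of a full response.
    The ref_* values are constants with respect to the policy parameters.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)

    # loss = -log(sigmoid(margin)), written in a numerically stable form
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# The policy favours the chosen response slightly more than the reference does,
# so the loss drops below log(2), its value when policy and reference agree.
print(dpo_loss(-10.0, -12.0, -11.0, -11.0))
```

The default `beta=0.01` mirrors the `beta` value in the config shown below; a larger value would penalize deviation from the reference model more strongly.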


To get started with DPO training, we provide the config for DPO in [configs/mistral_dpo_summarization.yaml](configs/mistral_dpo_summarization.yaml).


TODO: The provided config uses 6 and 2 A10s and doesn't utilize GPUs properly. We should improve logprob processor

In [10]:
!cat configs/mistral_dpo_summarization.yaml

model_id: mistralai/Mistral-7B-Instruct-v0.1
# Example summarization dataset with 10k examples for training with an average of 2.2k tokens per sample
train_path: s3://air-example-data/preference-tuning-summarization/train.jsonl
valid_path: s3://air-example-data/preference-tuning-summarization/valid.jsonl
task: "preference_tuning"
context_length: 4096
# For DPO, it is recommended to set a high `num_data_blocks_per_device` to not bottleneck the logp processor.
# We recommend not going beyond 20 so as to not spawn too many Ray actors. 
num_data_blocks_per_device: 16
num_devices: 6 # <--- runs training on 6 GPUs
train_batch_size_per_device: 2
eval_batch_size_per_device: 2
learning_rate: 5e-6
num_epochs: 3
no_gradient_checkpoint: False
output_dir: /mnt/local_storage/
deepspeed:
  config_path: deepspeed_configs/zero_3.json
worker_resources:
  accelerator_type:A10G: 1
flash_attention_2: True
padding: "longest"
preference_tuning_config:
  beta: 0.01
  logprob_processor_scaling_config:
    cust

In [None]:
!llmforge anyscale finetune end-to-end-examples/fine-tune-preference/configs/mistral_dpo_summarization.yaml

# Step 3: Evaluation

Let's evaluate our trained model. Here we'll use two baselines: (1) the base model before finetuning (reference model in DPO) and (2) GPT-4o.

## Evaluation strategy

Our evaluation strategy involves the same Q&A scoring system as used while generating the preference data. \<TODO: Add more here\>
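
As a rough sketch of how such Q&A-based scores could be aggregated across models, consider the snippet below. It assumes that a judge model has already answered each multiple-choice question using only the summary, which is our reading of the scoring system rather than something spelled out above. The function names and the input layout are hypothetical.

```python
from statistics import mean


def qa_accuracy(answer_key, predicted_answers):
    """Fraction of multiple-choice questions answered correctly for one summary."""
    if len(answer_key) != len(predicted_answers):
        raise ValueError("answer_key and predicted_answers must have equal length")
    correct = sum(k == p for k, p in zip(answer_key, predicted_answers))
    return correct / len(answer_key)


def mean_accuracy_by_model(results):
    """Average Q&A accuracy per model.

    results: dict mapping a model name to a list of
    (answer_key, predicted_answers) tuples, one tuple per article.
    """
    return {
        model: mean(qa_accuracy(key, preds) for key, preds in examples)
        for model, examples in results.items()
    }


# Toy data for illustration only; these are not measured results.
toy_results = {
    "model_a": [(list("ABCDE"), list("ABCDA")), (list("ABCDE"), list("ABCDE"))],
    "model_b": [(list("ABCDE"), list("AACDA")), (list("ABCDE"), list("ABCDE"))],
}
print(mean_accuracy_by_model(toy_results))
```

Comparing models on the same articles and the same question sets keeps the averages directly comparable.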



# Step 4: Iterative-DPO (optional)

TODO