# Contrastive Learning from AI Revisions (CLAIR)
This notebook accompanies the "Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment" paper.

In this notebook, we will create preference pairs for alignment through contrastive revisions. We use an LLM behind an API for the revision process, but we've cached the results so you can run the notebook without an API key.
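The cached results work because every API call goes through `query_chat_model`, which the notebook wraps with `joblib.Memory`. A call whose arguments match a stored entry returns the stored response instead of contacting the API. The sketch below imitates that idea with the standard library only. `disk_cache` and `fake_query` are hypothetical names, and joblib's actual keying (which also accounts for the function's module, name and code) differs in detail.

```python
import functools
import hashlib
import json
import os
import pickle
import tempfile


def disk_cache(cache_dir, namespace):
    """Hypothetical stdlib stand-in for joblib.Memory, for illustration only.

    The cache key depends only on `namespace`, the function name, and the
    (JSON-serializable) call arguments, so it is stable across sessions.
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            key_src = json.dumps([namespace, f.__name__, args, kwargs], sort_keys=True)
            key = hashlib.sha256(key_src.encode("utf-8")).hexdigest()
            path = os.path.join(cache_dir, key + ".pkl")
            if os.path.exists(path):  # cache hit: skip the call entirely
                with open(path, "rb") as fh:
                    return pickle.load(fh)
            result = f(*args, **kwargs)  # cache miss: run and persist
            with open(path, "wb") as fh:
                pickle.dump(result, fh)
            return result
        return wrapper
    return decorator


calls = []

with tempfile.TemporaryDirectory() as tmp:
    @disk_cache(tmp, "CLAIR_preferences")
    def fake_query(prompt):
        calls.append(prompt)  # records real executions only
        return {"echo": prompt}

    first = fake_query("revise this")
    second = fake_query("revise this")  # served from disk

print(first == second, len(calls))
```

This is also why the notebook's `cache` helper pins `__module__` and `__qualname__` before wrapping: in a notebook, functions otherwise live in `__main__`, and a stable name is what lets the shipped cache be found again in a new session.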

![alt text](images/github-clair-notebook.png "Contrastive Learning from AI Revisions")

In [1]:
import requests
from joblib import Memory
import datasets
import pandas as pd



In [2]:
# Set your API key. If it is unset, only the cached responses can be used.
API_KEY = 'your-openai-api-key'

# Change model and kwargs. Only tested on this exact model.
model_name = 'gpt-4-0125-preview'
model_url = 'https://api.openai.com/v1/chat/completions'
model_kwargs = {
    'max_tokens': 4096,
    'temperature': .7
}
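`model_kwargs` is merged into the JSON body that `query_chat_model` posts to the Chat Completions endpoint. The sketch below builds that body without any network call; `build_chat_payload` is a hypothetical helper that mirrors the dictionary assembled inside `query_chat_model`.

```python
import json


def build_chat_payload(user_prompt, system_prompt="", model_name="gpt-4-0125-preview", **model_kwargs):
    """Assemble the request body that query_chat_model sends (hypothetical helper).

    Extra keyword arguments such as max_tokens or temperature are merged at
    the top level of the body, next to "model" and "messages".
    """
    return {
        "model": model_name,
        "messages": [
            # An empty system prompt is still sent as a system message
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        **model_kwargs,
    }


payload = build_chat_payload("Revise this answer.", max_tokens=4096, temperature=0.7)
print(json.dumps(payload, indent=2))
```

One consequence worth knowing: joblib keys cached calls on the function's arguments, while `model_kwargs` and `API_KEY` are read as globals. As far as we can tell, changing `temperature` here will therefore still return previously cached responses for the same prompt; clear `./cache` if you want fresh generations.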

Helper functions that can be ignored for now:

In [3]:
# Setup the cache directory
memory = Memory("./cache", verbose=0)

# Get joblib cache working in notebooks
# source: https://stackoverflow.com/questions/75202475/joblib-persistence-across-sessions-machines
def cache(mem, module, **mem_kwargs):
    def cache_(f):
        f.__module__ = module
        f.__qualname__ = f.__name__
        return mem.cache(f, **mem_kwargs)
    return cache_

# Extract existing preferences from UltraFeedback: the highest-scored completion is chosen, the lowest-scored is rejected
def get_preferences_from_ultrafeedback(dataset):
    instruction = []
    chosen = []
    rejected = []

    for _, row in dataset.iterrows():
        response = [x["response"] for x in row["completions"]]
        score = [x["overall_score"] for x in row["completions"]]

        if len(score):
            chosen_index = score.index(max(score))
            rejected_index = score.index(min(score))

            instruction.append(row["instruction"])
            chosen.append(response[chosen_index])
            rejected.append(response[rejected_index])

    return pd.DataFrame.from_dict({
        "text": instruction,
        "rejected": rejected,
        "chosen": chosen  
    })

# Visualize a preference triple
def visualize_triple(triple: dict):
    print('---TEXT (first 400 characters):\n')
    print(triple['text'][:400])
    print('---REJECTED (first 400 characters):\n')
    print(triple['rejected'][:400])
    print('---CHOSEN (first 400 characters):\n')
    print(triple['chosen'][:400])
    if 'rational' in triple:
        print('---REVISION RATIONALE (first 400 characters):\n')
        print(triple['rational'][:400])


@cache(memory, "CLAIR_preferences")
def query_chat_model(user_prompt, system_prompt='', url='https://api.openai.com/v1/chat/completions', model_name='gpt-4-0125-preview'):
    print(f"Querying {model_name} API at {url}...")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }
    data = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        **model_kwargs
    }
    response = requests.post(url, headers=headers, json=data)
    
    if response.status_code == 200:
        result = response.json()
        return result
    else:
        raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

# Parse revisions: split the raw completion into (revision, rationale)
def get_revision_from_response(response):
    raw_completion = response['choices'][0]['message']['content']

    # Markers the model may use to introduce its revised answer, checked in order
    markers = [
        "**Corrected Student Solution:**",
        "{corrected_student_solution}:",
        "{corrected_student_solution}",
        "**Worsened Student Solution:**",
        "{worsened_student_solution}:",
        "{worsened_student_solution}",
    ]
    splits = [raw_completion]  # default when no marker is present
    for marker in markers:
        if marker in raw_completion:
            splits = raw_completion.split(marker)
            break

    if len(splits) < 2:
        # Raise so the calling loop can report and skip this example
        raise ValueError('Failed to parse response')

    # Text after the marker is the revision; text before it is the rationale
    edit = splits[1].strip()

    rational = splits[0]
    if '{teacher_reasoning}' in rational:
        rational = rational.split('{teacher_reasoning}')[1].strip(':')
    rational = rational.strip()
    return edit, rational
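To see what `get_revision_from_response` expects, here is a mocked Chat Completions response whose content follows the `{teacher_reasoning}` / `{corrected_student_solution}` structure the revision prompt asks for. `parse_revision` is a simplified, hypothetical restatement that checks only two of the markers, and the mocked text is made up.

```python
def parse_revision(response, markers=("{corrected_student_solution}:", "{corrected_student_solution}")):
    """Return (revision, rationale) from a Chat Completions-style dict.

    Simplified sketch: splits at the first marker found and raises
    ValueError when none of the markers occurs in the completion.
    """
    raw = response["choices"][0]["message"]["content"]
    for marker in markers:
        if marker in raw:
            splits = raw.split(marker)
            break
    else:
        raise ValueError("Failed to parse response")

    revision = splits[1].strip()
    rationale = splits[0]
    if "{teacher_reasoning}" in rationale:
        # Drop the identifier and the colon that usually follows it
        rationale = rationale.split("{teacher_reasoning}")[1].strip(":")
    return revision, rationale.strip()


mock_response = {
    "choices": [{"message": {"content": (
        "{teacher_reasoning}: The answer skips a step.\n\n"
        "{corrected_student_solution}: 2 + 2 = 4, because adding 2 to 2 gives 4."
    )}}]
}
print(parse_revision(mock_response))
```

A completion that uses none of the markers makes `parse_revision` raise `ValueError`, which is easier to spot than a silently wrong split.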



## Load data
We will load data from an existing dataset. Alternatively, you could use a language model to generate your own data.

In this notebook, we will use prompts and existing responses from the UltraFeedback dataset. UltraFeedback already contains preference pairs, against which we can compare our CLAIR preferences.

In [4]:
ultraf = datasets.load_dataset('openbmb/UltraFeedback')['train'].to_pandas()
ultraf = ultraf[:20] # take first 20 examples
ultraf = get_preferences_from_ultrafeedback(ultraf)
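`get_preferences_from_ultrafeedback` keeps, per prompt, the completion with the highest `overall_score` as `chosen` and the one with the lowest as `rejected`. The pure-Python sketch below applies the same rule to a made-up row (`pick_pair` and the scores are illustrative, not UltraFeedback data). It also shows two edge cases inherited from `list.index`: ties go to the first occurrence, and a row whose scores are all equal yields the same response as both chosen and rejected.

```python
def pick_pair(completions):
    """Return (chosen, rejected) from a list of {'response', 'overall_score'} dicts.

    Same rule as get_preferences_from_ultrafeedback; returns None for a row
    without completions, which that function skips.
    """
    responses = [c["response"] for c in completions]
    scores = [c["overall_score"] for c in completions]
    if not scores:
        return None
    chosen = responses[scores.index(max(scores))]    # first highest score
    rejected = responses[scores.index(min(scores))]  # first lowest score
    return chosen, rejected


# Toy row shaped like an UltraFeedback `completions` field (values are made up)
toy = [
    {"response": "A", "overall_score": 6.5},
    {"response": "B", "overall_score": 8.0},
    {"response": "C", "overall_score": 3.0},
    {"response": "D", "overall_score": 8.0},
]
print(pick_pair(toy))  # ('B', 'C'): the tie between B and D goes to B

# With identical scores, the same response is both chosen and rejected
print(pick_pair([{"response": "X", "overall_score": 5.0}] * 2))  # ('X', 'X')
```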

In [5]:
# look at an example
visualize_triple(ultraf.loc[2])

---TEXT (first 400 characters):

Identify the interrelated economic, political, and social factors that contributed to the stock market crash of 1929, including but not limited to the impact of World War I on the global economy, the role of government policies such as the Smoot-Hawley Tariff Act, the effects of speculative investment practices and margin trading, and the socioeconomic disparities of the time period. Additionally,
---REJECTED (first 400 characters):

Sure, I'd be happy to help you learn about the causes and effects of the 1929 stock market crash and how it compares to other financial crises.

The stock market crash of 1929 was a significant event that occurred on October 29, known as Black Tuesday. This event marked the beginning of the Great Depression, a period of economic downturn and unemployment that lasted for nearly a decade. There were
---CHOSEN (first 400 characters):

The stock market crash of 1929 was a result of a complex interplay of economic, political, an

## Revise and Improve
We will contrastively revise and improve answers using an LLM via an API, starting from the `rejected` answers in UltraFeedback.

In [6]:
revision_template = """You are a teacher and your task is to minimally improve a student's answer. I will give you a {{task}} and a {{student_solution}}. Your job is to revise the {{student_solution}} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {{corrected_student_solution}} being a revision or a correction in your final solution.\n\n{{task}}: {input}\n\n{{student_solution}}: {output}\n\n-----------------\n\nLet's first think step by step with a {{teacher_reasoning}} to decide how to improve the {{student_solution}}, then give the {{corrected_student_solution}}. Mention the {{teacher_reasoning}} and {{corrected_student_solution}} identifiers to structure your answer.\n\n"""

# let's look at the revision prompt
print(revision_template)

You are a teacher and your task is to minimally improve a student's answer. I will give you a {{task}} and a {{student_solution}}. Your job is to revise the {{student_solution}} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {{corrected_student_solution}} being a revision or a correction in your final solution.

{{task}}: {input}

{{student_solution}}: {output}

-----------------

Let's first think step by step with a {{teacher_reasoning}} to decide how to improve the {{student_solution}}, then give the {{corrected_student_solution}}. Mention the {{teacher_reasoning}} and {{corrected_student_solution}} identifiers to structure your answer.
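The template relies on `str.format` escaping: doubled braces such as `{{task}}` survive formatting as the literal identifiers `{task}`, while `{input}` and `{output}` are replaced. That is why `print(revision_template)` shows double braces but the filled prompts show single ones. Substituted values are not re-parsed, so prompts that contain code with braces are safe. A minimal check, using a shortened version of the template:

```python
# Shortened version of revision_template, keeping only the brace patterns
template = "{{task}}: {input}\n\n{{student_solution}}: {output}"

filled = template.format(input="What is 2 + 2?", output="5")
print(filled)

# Braces inside the substituted values are inserted verbatim, not parsed
code_prompt = template.format(input="Fix: int main() { return 0 }", output="{}")
print(code_prompt)
```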




In [7]:
# create the prompts using data from ultrafeedback
prompts = []
prompts_text = []
prompts_rejected = []

for _, triple in ultraf.iterrows():
    prompts.append(revision_template.format(input=triple['text'], output=triple['rejected']))
    prompts_text.append(triple['text'])
    prompts_rejected.append(triple['rejected'])

# visualize prompt
print(prompts[2])
        

You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution.

{task}: Identify the interrelated economic, political, and social factors that contributed to the stock market crash of 1929, including but not limited to the impact of World War I on the global economy, the role of government policies such as the Smoot-Hawley Tariff Act, the effects of speculative investment practices and margin trading, and the socioeconomic disparities of the time period. Additionally, provide a comparative analysis of the causes and effects of the 1929 stock market crash with other notable financial crises in history, such as the Great Depression of the 1930s and the 2008 global f

In [8]:
# query the API and parse the responses
# For the first 20 examples, this code will rely on cached revisions
preferences = []
for prompt, prompt_text, prompt_rejected in zip(prompts, prompts_text, prompts_rejected):
    try:
        response = query_chat_model(prompt, model_name=model_name, url=model_url)
        revision, revision_rational = get_revision_from_response(response)

        preferences.append({
            'text': prompt_text,
            'rejected': prompt_rejected,
            'chosen': revision,
            'rational': revision_rational,
        })
    except Exception as e:
        # Log the error and move on to the next example
        print(e)

# Turn this into a dataset
ultraf_revisions = pd.DataFrame.from_records(preferences)
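Because the loop above skips any example whose request or parsing fails, `ultraf_revisions` can end up with fewer rows than `ultraf`, and a given position would then no longer refer to the same prompt in both frames. The sketch below keeps each row's original index and records failures separately. `collect_revisions` and `fake_revise` are hypothetical; the fake stands in for `query_chat_model` plus the response parser.

```python
def collect_revisions(rows, revise):
    """Apply `revise` to each (index, text, rejected) row, keeping the index.

    `revise` is any callable returning (revision, rationale) or raising.
    Failed rows are recorded in `failures` instead of being dropped silently.
    """
    records, failures = [], {}
    for idx, text, rejected in rows:
        try:
            revision, rationale = revise(text, rejected)
        except Exception as e:
            failures[idx] = repr(e)  # remember what failed, then continue
            continue
        records.append({"index": idx, "text": text, "rejected": rejected,
                        "chosen": revision, "rational": rationale})
    return records, failures


def fake_revise(text, rejected):
    # Stand-in for the API call and parser: fails on purpose for one row
    if "fail" in text:
        raise ValueError("Failed to parse response")
    return rejected.upper(), "capitalized"


rows = [(0, "q0", "a0"), (1, "fail me", "a1"), (2, "q2", "a2")]
records, failures = collect_revisions(rows, fake_revise)
print([r["index"] for r in records], failures)
```

With pandas, `pd.DataFrame.from_records(records).set_index('index')` would then line up with `ultraf` by label rather than by position.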

In [9]:
# Visualize an edit
visualize_triple(ultraf_revisions.loc[2])

---TEXT (first 400 characters):

Identify the interrelated economic, political, and social factors that contributed to the stock market crash of 1929, including but not limited to the impact of World War I on the global economy, the role of government policies such as the Smoot-Hawley Tariff Act, the effects of speculative investment practices and margin trading, and the socioeconomic disparities of the time period. Additionally,
---REJECTED (first 400 characters):

Sure, I'd be happy to help you learn about the causes and effects of the 1929 stock market crash and how it compares to other financial crises.

The stock market crash of 1929 was a significant event that occurred on October 29, known as Black Tuesday. This event marked the beginning of the Great Depression, a period of economic downturn and unemployment that lasted for nearly a decade. There were
---CHOSEN (first 400 characters):

The stock market crash of 1929, culminating on October 29, known as "Black Tuesday," was a pi

In [10]:
ultraf_revisions.head()

| | text | rejected | chosen | rational |
|---|---|---|---|---|
| 0 | Can you write a C++ program that prompts the u... | Sure, here is the program using the C++11 algo... | `` ```cpp\n#include <iostream>\n#include <string>... `` | The student's solution seems to misunderstand ... |
| 1 | Suppose you are a content creator and want to ... | To use GPT for generating compelling titles an... | To use GPT for automatically generating compel... | The student's solution provides a comprehensiv... |
| 2 | Identify the interrelated economic, political,... | Sure, I'd be happy to help you learn about the... | The stock market crash of 1929, culminating on... | The student's answer provides a decent overvie... |
| 3 | How can I convert the decimal number 31 to bin... | Sure, I'd be happy to help with that! \n\nIn J... | Sure, I'd be happy to help with that!\n\nIn Ja... | The student's solution contains a couple of in... |
| 4 | Can you modify the C++ code provided below to ... | Sure, I can help you with that. Here's how you... | To generate the first 20 Fibonacci numbers in ... | The student's solution provides a basic outlin... |
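Since CLAIR revises UltraFeedback's `rejected` answers, both frames share the same `rejected` column, and the pairs can be compared directly. One rough proxy for how contrastive a pair is: the word-level similarity between `rejected` and `chosen`. The sketch below uses `difflib.SequenceMatcher`; `pair_similarity` is a hypothetical helper, the example sentences are made up, and this is an illustrative metric, not necessarily the one used in the paper.

```python
from difflib import SequenceMatcher


def pair_similarity(rejected, chosen):
    """Word-level similarity ratio in [0, 1] between two responses.

    Minimal revisions should score close to 1; unrelated responses score low.
    """
    return SequenceMatcher(None, rejected.split(), chosen.split()).ratio()


rejected = "the function sorts the list and returns the first item"
revised = "the function sorts the list and returns the smallest item"
unrelated = "a completely different answer about another topic"

print(round(pair_similarity(rejected, revised), 2))    # one word changed
print(round(pair_similarity(rejected, unrelated), 2))  # no shared words
```

Applied to the notebook's frames, something like `ultraf_revisions.apply(lambda r: pair_similarity(r['rejected'], r['chosen']), axis=1).mean()`, compared with the same expression on `ultraf`, gives a rough sense of how minimal the CLAIR revisions are; if they are indeed minimal, the CLAIR average should be the higher of the two.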


## You've now created CLAIR preferences!
<!-- We've already created a dataset of 32K revisions. Let's load this and compare CLAIR revisions with conventional RLAIF preferences. -->