In [None]:
import sys
sys.path.append('../')  # To ensure utils can be imported if needed

from utils.data_loader import load_precomputed_results

import matplotlib.pyplot as plt

import json
import tiktoken # for token counting
import numpy as np
from collections import defaultdict

In [2]:
%load_ext autoreload
%autoreload 2

## Fine-tuning

In [None]:
data_path = "../pcd_data/finetune/correction_training_set.jsonl"
# data_path = "../pcd_data/finetune/correction_training_set_100.jsonl"

# Read the chat-format fine-tuning set: every line of the JSONL file is one JSON example.
with open(data_path, 'r', encoding='utf-8') as jsonl_file:
    dataset = list(map(json.loads, jsonl_file))

# Initial dataset stats
print("Num examples:", len(dataset))
print("First example:")
for first_message in dataset[0]["messages"]:
    print(first_message)

In [None]:
# Format error checks: tally, per category, how many examples / messages fail a structural check.
format_errors = defaultdict(int)

for example in dataset:
    # Guard clauses: an example must be a dict with a non-empty "messages" list.
    if not isinstance(example, dict):
        format_errors["data_type"] += 1
        continue

    messages = example.get("messages", None)
    if not messages:
        format_errors["missing_messages_list"] += 1
        continue

    for message in messages:
        has_required_keys = "role" in message and "content" in message
        if not has_required_keys:
            format_errors["message_missing_key"] += 1

        unknown_keys = [key for key in message if key not in ("role", "content", "name", "function_call", "weight")]
        if unknown_keys:
            format_errors["message_unrecognized_key"] += 1

        role = message.get("role", None)
        if role not in ("system", "user", "assistant", "function"):
            format_errors["unrecognized_role"] += 1

        content = message.get("content", None)
        function_call = message.get("function_call", None)
        # A message needs some payload, and "content" itself must be a string.
        has_payload = bool(content) or bool(function_call)
        if not has_payload or not isinstance(content, str):
            format_errors["missing_content"] += 1

    has_assistant_turn = any(msg.get("role", None) == "assistant" for msg in messages)
    if not has_assistant_turn:
        format_errors["example_missing_assistant_message"] += 1

if not format_errors:
    print("No errors found")
else:
    print("Found errors:")
    for error_name, error_count in format_errors.items():
        print(f"{error_name}: {error_count}")

In [None]:
# Shared tokenizer for the approximate token statistics computed below.
# NOTE(review): gpt-4o-family models (the fine-tuning target in this notebook) are generally
# reported to use "o200k_base", so cl100k_base counts are only an estimate — confirm.
encoding = tiktoken.get_encoding("cl100k_base")

# not exact!
# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    """Approximate the token count of one chat-format example.

    Every message costs ``tokens_per_message``, plus the encoded length of each field value.
    A "name" field costs an extra ``tokens_per_name``. A fixed 3 tokens are added at the end.

    Only text is encoded. Dict values (e.g. ``function_call``) are encoded as their JSON text.
    Any other non-string value (e.g. a numeric ``weight`` or ``None`` content) contributes no
    tokens. Previously these raised TypeError, although the format check accepts those keys.

    Args:
        messages: list of message dicts.
        tokens_per_message: fixed overhead added per message.
        tokens_per_name: extra overhead when a message has a "name" field.

    Returns:
        int: approximate number of tokens.
    """
    import json  # local import keeps the helper self-contained

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if isinstance(value, dict):
                value = json.dumps(value)
            if isinstance(value, str):
                num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens

def num_assistant_tokens_from_messages(messages):
    """Count encoded tokens in the content of all assistant messages of one example.

    Messages without a "role" are ignored. Assistant messages whose content is missing or
    ``None`` (e.g. function_call-only turns) contribute 0. Previously these raised
    KeyError/TypeError.

    Args:
        messages: list of message dicts.

    Returns:
        int: total assistant-content token count.
    """
    return sum(
        len(encoding.encode(message.get("content") or ""))
        for message in messages
        if message.get("role") == "assistant"
    )

def print_distribution(values, name):
    """Print min/max, mean/median and the 5th/95th percentiles of ``values``.

    The "p5 / p95" line previously printed the 10th/90th percentiles; it now prints the
    percentiles its label names. Empty input prints a note instead of raising in ``min()``.

    Args:
        values: sequence of numbers.
        name: label used in the heading.
    """
    print(f"\n#### Distribution of {name}:")
    if len(values) == 0:
        print("no values")
        return
    print(f"min / max: {min(values)}, {max(values)}")
    print(f"mean / median: {np.mean(values)}, {np.median(values)}")
    print(f"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")

In [None]:
# Warnings and tokens counts: per-example role coverage plus message-count and
# (approximate) token-length distributions.
n_missing_system = 0
n_missing_user = 0
n_messages = []
convo_lens = []
assistant_message_lens = []

for example in dataset:
    example_messages = example["messages"]
    # bool -> int: add 1 when the role never appears in this example.
    n_missing_system += not any(m["role"] == "system" for m in example_messages)
    n_missing_user += not any(m["role"] == "user" for m in example_messages)
    n_messages.append(len(example_messages))
    convo_lens.append(num_tokens_from_messages(example_messages))
    assistant_message_lens.append(num_assistant_tokens_from_messages(example_messages))

print("Num examples missing system message:", n_missing_system)
print("Num examples missing user message:", n_missing_user)
for dist_values, dist_name in (
    (n_messages, "num_messages_per_example"),
    (convo_lens, "num_total_tokens_per_example"),
    (assistant_message_lens, "num_assistant_tokens_per_example"),
):
    print_distribution(dist_values, dist_name)
n_too_long = sum(1 for length in convo_lens if length > 16385)
print(f"\n{n_too_long} examples may be over the 16,385 token limit, they will be truncated during fine-tuning")

In [None]:
# Pricing and default n_epochs estimate.
# Heuristic: start from TARGET_EPOCHS and scale so that epochs * examples lands within
# [MIN_TARGET_EXAMPLES, MAX_TARGET_EXAMPLES], clamped to [MIN_DEFAULT_EPOCHS, MAX_DEFAULT_EPOCHS].
MAX_TOKENS_PER_EXAMPLE = 16385

TARGET_EPOCHS = 3
MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
MIN_DEFAULT_EPOCHS = 1
MAX_DEFAULT_EPOCHS = 25
# USD per 1M billed training tokens, as previously inlined in the cost line.
# NOTE(review): confirm against current OpenAI pricing for the chosen base model.
TRAINING_PRICE_PER_1M_TOKENS = 3

n_epochs = TARGET_EPOCHS
n_train_examples = len(dataset)
if n_train_examples == 0:
    # Guard: both scaling rules below divide by the number of examples.
    print("Dataset is empty; keeping the default epoch count.")
elif n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

# Examples longer than the limit are truncated, so they are billed at most the limit.
n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")
print(f"By default, this will cost ~${n_epochs * n_billing_tokens_in_dataset / 1_000_000 * TRAINING_PRICE_PER_1M_TOKENS:.2f}")

In [None]:
# Read the OpenAI key from a local JSON config file (never hard-code it in the notebook),
# then build the SDK client used by the fine-tuning cells below.
with open("../api_key/config.json", 'r') as key_config_file:
    key_config = json.load(key_config_file)
openai_api_key = key_config["openai_api_key"]

from openai import OpenAI
client = OpenAI(api_key=openai_api_key)

In [None]:
import os

# Upload the training file through the SDK client instead of shelling out to curl.
# The old command put the API key on a shell command line, where it is visible to other
# processes and in shell history. It also re-hard-coded the dataset path.
# Uploading `data_path` guarantees the file validated above is the one sent.
with open(data_path, "rb") as training_file:
    uploaded_training_file = client.files.create(file=training_file, purpose="fine-tune")

# Pass this ID as `training_file` when creating the fine-tuning job.
uploaded_training_file.id

In [None]:
# Launch the fine-tuning job. Replace the `...` placeholder with the ID of the uploaded
# training file (see the previous cell) before running.
fine_tuning_job = client.fine_tuning.jobs.create(
    model="gpt-4o-mini-2024-07-18",
    training_file=...,
)
fine_tuning_job

In [None]:
# List up to 10 fine-tuning jobs; the `fine_tuned_model` field of a finished job is the
# model name used in the next cell.
listed_jobs = client.fine_tuning.jobs.list(limit=10)
listed_jobs

In [None]:
# Quick sanity check of the fine-tuned model. Replace `...` with the model name from the
# `fine_tuned_model` field shown in the previous cell.
sanity_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
completion = client.chat.completions.create(model=..., messages=sanity_messages)

print(completion.choices[0].message)

## Evaluation

First run `python run_experiments_finetune.py`.

This produces `exp_results/mix/test_400_perturbed_premise_gpt-4o-mini-sft.jsonl`; then execute the following cells:

In [None]:
# Load previously saved results. The bootstrapping cell below adds per-model entries to this
# object, and the pickle cell writes it back to the same path.
precomputed_results_path = "../exp_results/eval/results.pkl"
df_results = load_precomputed_results(precomputed_results_path)

In [4]:
from utils.data_loader import load_dataset_results, load_api_key
from utils.evaluator.cot_evaluator import CoTEvaluator
from utils.analysis.metrics import perturbation_ratio_given_correct
from utils.analysis.bootstrapping import bootstrap, bootstrap_with_ratios
from utils.analysis.data_processing import assign_confidence_bins

In [None]:
from tqdm import tqdm
import pandas as pd
import json

# Evaluation configuration: datasets/models to score, the perturbation mode, the sample size
# (all part of the results file name), and the answer columns read from each results row.
datasets = ["mix"]
models = ["gpt-4o-mini-sft"]
mode = "premise"
sample_size = 400
columns = [
    "base_original", "base_misinformed",
    "inst_original", "inst_misinformed",
    "ft_original", "ft_misinformed",
    "inst_ft_original", "inst_ft_misinformed",
]

with open("../api_key/config_yichen.json", "r") as api_config_file:
    config = json.load(api_config_file)
api_key = config["openai_api_key"]

cot_evaluator = CoTEvaluator(api_key=api_key)

In [6]:
############################################
# Step 1: Evaluate CoT and Store Results
############################################
# overall_results[model_name][dataset_name] is a DataFrame with one row per results row and
# one column per entry in `columns`, each cell holding what cot_evaluator produced.


def _evaluate_row(row):
    """Run cot_evaluator.evaluate_cot_list on every column in `columns` for one results row.

    Columns containing "original" are evaluated directly. All other columns are indexed by
    role ("human"), and a role missing from the row is recorded as an empty dict.
    """
    evaluation_results = {}
    for column in columns:
        if "original" in column or column == "prfx_q":
            evaluation_results[column] = cot_evaluator.evaluate_cot_list(
                row[column], row["correct_answer"]
            )
        else:
            evaluation_results[column] = {}
            for role in ["human"]:
                if role in row[column]:
                    evaluation_results[column][role] = cot_evaluator.evaluate_cot_list(
                        row[column][role],
                        row["correct_answer"]
                    )
                else:
                    evaluation_results[column][role] = {}
    return evaluation_results


overall_results = {}
print("Evaluating CoT outputs...")
for model_name in tqdm(models, desc="Models"):
    overall_results[model_name] = {}
    for dataset_name in tqdm(datasets, desc=f"{model_name} Datasets", leave=False):
        # Use the dataset loop variable (previously hard-coded to "mix").
        results_path = f"../exp_results/{dataset_name}/test_{sample_size}_perturbed_{mode}_{model_name}.jsonl"
        try:
            df_sample = pd.read_json(results_path, lines=True)
        except (FileNotFoundError, ValueError) as err:
            # Still skipped, but no longer silently (a bare except also hid real errors).
            print(f"Skipping {model_name}-{dataset_name}: could not read {results_path} ({err})")
            continue

        df_sample["correct_answer"] = df_sample["correct_answer"].astype(str)

        # Collect per-row dicts and build the frame once; pd.concat inside the loop was quadratic.
        evaluated_rows = [
            _evaluate_row(row)
            for _, row in tqdm(df_sample.iterrows(), total=df_sample.shape[0], desc=f"Evaluating {model_name}-{dataset_name}", leave=False)
        ]
        overall_results[model_name][dataset_name] = pd.DataFrame(evaluated_rows, columns=columns)

Evaluating CoT outputs...


Models:   0%|          | 0/1 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Models: 100%|██████████| 1/1 [00:01<00:00,  1.63s/it]


In [None]:
import numpy as np

# For every model / source / column / dataset, bootstrap perturbation_ratio_given_correct
# (n=1000). Store the mean and the 2.5-97.5 percentile band in df_results[model_name][source].
print("Computing bootstrapping results for radar/table...")
for model_name in tqdm(models, desc="Radar Bootstrapping - Models"):
    if model_name not in overall_results:
        continue
    df_results[model_name] = {}
    for source in tqdm(["human"], desc=f"{model_name} Sources", leave=False):
        mean_df = pd.DataFrame(index=datasets, columns=columns)
        margin_df = pd.DataFrame(index=datasets, columns=columns)
        ratio_lists = {col: {} for col in columns}

        for column in tqdm(columns, desc=f"{model_name}-{source} Columns", leave=False):
            for dataset_name in tqdm(datasets, desc=f"{model_name}-{source}-{column} Datasets", leave=False):
                if dataset_name not in overall_results[model_name]:
                    # No evaluation for this dataset: record NaN placeholders.
                    mean_df.loc[dataset_name, column] = np.nan
                    margin_df.loc[dataset_name, column] = (np.nan, np.nan)
                    ratio_lists[column][dataset_name] = []
                    continue

                df_data = overall_results[model_name][dataset_name].reset_index(drop=True)
                # "original" columns are compared against base_original; others against prfx_q.
                original_column = "base_original" if "original" in column else "prfx_q"
                role = None if column == original_column else source
                ratios = bootstrap_with_ratios(
                    df_data,
                    lambda resample: perturbation_ratio_given_correct(
                        resample,
                        cij=original_column,
                        pij=column,
                        evaluation_type="overall_correct",
                        perturbation_role=role,
                    ),
                    n=1000,
                )
                mean_val = np.mean(ratios)
                interval = (np.percentile(ratios, 2.5), np.percentile(ratios, 97.5))

                mean_df.loc[dataset_name, column] = mean_val
                margin_df.loc[dataset_name, column] = interval
                ratio_lists[column][dataset_name] = ratios

        df_results[model_name][source] = {
            "mean": mean_df,
            "margin": margin_df,
            "ratio_lists": ratio_lists,
        }

Computing bootstrapping results for radar/table...


Radar Bootstrapping - Models:   0%|          | 0/1 [00:00<?, ?it/s]
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A
[A
Radar Bootstrapping - Models: 100%|██████████| 1/1 [04:29<00:00, 269.69s/it]


In [9]:
import pickle

# Persist df_results to the same path the earlier cell loads precomputed results from.
# NOTE: unpickling runs arbitrary code, so only reload this file from a trusted location.
results_pickle_path = "../exp_results/eval/results.pkl"
with open(results_pickle_path, "wb") as results_file:
    pickle.dump(df_results, results_file)

In [10]:
import pickle
import pandas as pd
from tqdm import tqdm
from utils.evaluator.cot_evaluator import CoTEvaluator
import argparse
from utils.data_loader import load_dataset_results, load_api_key

In [11]:
def initialize_cot_evaluator(openai_api_key):
    """Create the CoTEvaluator that judges detection/correction.

    The point-out model is fixed to 'gpt-4o'; perform_evaluation reads token usage under
    that same key.

    Args:
        openai_api_key: OpenAI key forwarded to CoTEvaluator.

    Returns:
        The constructed CoTEvaluator.
    """
    return CoTEvaluator(api_key=openai_api_key, point_out_model_name='gpt-4o')

In [22]:
def aggregate_sample(args, openai_api_key):
    """Collect fine-tuned outputs on perturbed premises whose unperturbed answer was correct.

    For each model in ``args.model_names`` this reads
    ``../exp_results/{args.dataset_name}/test_{args.sample_size}_perturbed_{args.mode}_{model}.jsonl``.
    For every row and each of the 5 candidates ``c_0``..``c_4``, it keeps the ``ft_misinformed``
    output only when the matching ``base_original`` output was verified correct.

    Candidates that are empty or that the verifier fails on are skipped (best-effort). Missing
    ``base_original``/``ft_misinformed`` keys raise KeyError.

    Args:
        args: object with ``model_names``, ``sample_size``, ``mode`` and ``dataset_name``.
        openai_api_key: key for the CoTEvaluator used to verify answers.

    Returns:
        DataFrame with question, correct_answer, premise, perturbed_premise, overall_correct,
        output, model and dataset columns. It has at most ``args.sample_size`` rows, sampled
        with ``random_state=42``, and is empty if nothing qualified.
    """
    rows_to_add = []
    cot_evaluator = initialize_cot_evaluator(openai_api_key)

    for model_name in tqdm(args.model_names, desc="Loading samples"):
        # Bug fix: read *this* model's results (previously always args.model_names[0]) from the
        # configured dataset directory (previously hard-coded to "mix").
        df_sample = pd.read_json(
            f"../exp_results/{args.dataset_name}/test_{args.sample_size}_perturbed_{args.mode}_{model_name}.jsonl",
            lines=True,
        )
        df_sample["correct_answer"] = df_sample["correct_answer"].astype(str)
        for _, row in tqdm(df_sample.iterrows(), total=len(df_sample), desc=f"Processing {model_name}"):
            for source in ["human"]:
                for i in range(5):
                    original_output = row["base_original"][f"c_{i}"]
                    output = row["ft_misinformed"]["human" if source == "human" else "self"][f"c_{i}"]
                    try:
                        overall_correct_original = cot_evaluator.answer_verifier(original_output[-1], row["correct_answer"])
                    except Exception:
                        # Best-effort: empty output or verifier failure -> skip this candidate.
                        continue
                    # Only keep cases where the model answered correctly without perturbation.
                    if not overall_correct_original:
                        continue
                    try:
                        overall_correct = cot_evaluator.answer_verifier(output[-1], row["correct_answer"])
                        rows_to_add.append({
                            "question": row["question"],
                            "correct_answer": row["correct_answer"],
                            "premise": row["premise"][source],
                            "perturbed_premise": row["perturbed_premise"][source],
                            "overall_correct": overall_correct,
                            "output": output,
                            "model": model_name,
                            "dataset": args.dataset_name,
                        })
                    except Exception:
                        continue

    # Build the DataFrame once from the accumulated rows.
    df_aggregated = pd.DataFrame(rows_to_add)
    if df_aggregated.empty:
        print("Warning: no eligible samples were collected.")
        return df_aggregated

    # Drop empty outputs and outputs whose last step lacks "answer" (treated as truncated).
    keep = df_aggregated["output"].apply(lambda out: out != [] and 'answer' in out[-1])
    df_aggregated = df_aggregated[keep.astype(bool)]

    # Sample up to args.sample_size rows; sampling more rows than exist would raise.
    n_to_sample = min(args.sample_size, len(df_aggregated))
    if n_to_sample < args.sample_size:
        print(f"Warning: only {len(df_aggregated)} eligible samples; using all of them instead of {args.sample_size}.")
    return df_aggregated.sample(n=n_to_sample, random_state=42)

In [13]:
def perform_evaluation(df_sample, cot_evaluator, args):
    """Judge detection, correction and perturbation for every sampled output.

    Correction and error positions are only queried when an error was detected. Otherwise
    correction is recorded as False with the detection explanations and empty steps.

    Args:
        df_sample: DataFrame with question, output, premise and perturbed_premise columns.
            Not modified.
        cot_evaluator: evaluator providing overall_point_out_error_verifier,
            point_out_position_verifier, overall_success_correction_verifier,
            overall_perturbation_verifier and get_usage.
        args: unused; kept so existing call sites keep working.

    Returns:
        A copy of df_sample with object columns "detection" and "correction"
        ({label, explanations, steps}) and "perturbation" ({label, explanations}).
    """
    df_evaluated = df_sample.copy()
    # Plain None-filled object columns; unlike a row-wise apply this also works on empty input.
    for result_column in ("detection", "correction", "perturbation"):
        df_evaluated[result_column] = None

    # A single manually-updated bar (previously a second tqdm wrapped the loop, drawing two bars).
    with tqdm(total=len(df_sample), desc="Evaluating samples") as pbar:
        for idx, row in df_sample.iterrows():
            question, output = row["question"], row["output"]
            premise, perturbed_premise = row["premise"], row["perturbed_premise"]

            detection_label, detection_explanations = cot_evaluator.overall_point_out_error_verifier(question, output, perturbed_premise)
            if detection_label:
                detection_positions = cot_evaluator.point_out_position_verifier(question, output, perturbed_premise)
                correction_label, correction_explanations = cot_evaluator.overall_success_correction_verifier(question, output, premise, perturbed_premise)
            else:
                detection_positions = []
                correction_label, correction_explanations = False, detection_explanations
            perturbation_label, perturbation_explanations = cot_evaluator.overall_perturbation_verifier(question, output, premise, perturbed_premise)

            df_evaluated.at[idx, "detection"] = {"label": detection_label, "explanations": detection_explanations, "steps": detection_positions}
            df_evaluated.at[idx, "correction"] = {"label": correction_label, "explanations": correction_explanations, "steps": detection_positions}
            df_evaluated.at[idx, "perturbation"] = {"label": perturbation_label, "explanations": perturbation_explanations}

            # Token usage is cosmetic; fall back to zeros if it is unavailable.
            try:
                usage = cot_evaluator.get_usage()["gpt-4o"]
            except Exception:
                usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
            pbar.set_postfix({
                "Prompt tokens": usage['prompt_tokens'],
                "Completion tokens": usage['completion_tokens'],
                "Total tokens": usage['total_tokens']
            })
            pbar.update(1)

    return df_evaluated

In [25]:
class Args:
    """Notebook stand-in for command-line options, passed to aggregate_sample and perform_evaluation."""

    def __init__(self):
        # Number of rows to sample; also part of the results file name.
        self.sample_size = 400
        # JSON file the OpenAI key is read from (via load_api_key).
        self.api_config_file_path = "../api_key/config_yiyang.json"
        self.dataset_name = "mix"
        self.model_names = ["gpt-4o-mini-sft"]
        self.mode = "premise"
        # Answer columns of the experiment output (not read by aggregate_sample/perform_evaluation).
        self.columns = [
            "base_original", "base_misinformed",
            "inst_original", "inst_misinformed",
            "ft_original", "ft_misinformed",
            "inst_ft_original", "inst_ft_misinformed",
        ]
        
# Run the pipeline in order: read the OpenAI key, build the judge, collect eligible outputs,
# then judge each one. perform_evaluation makes several evaluator calls per sample, so this
# step is slow (about 49 minutes for 400 samples in the recorded run).
args = Args()
openai_api_key = load_api_key(args.api_config_file_path)
cot_evaluator = initialize_cot_evaluator(openai_api_key)
# aggregate_sample builds its own evaluator internally for answer verification.
df_sample = aggregate_sample(args, openai_api_key)
df_evaluated = perform_evaluation(df_sample, cot_evaluator, args)

Processing gpt-4o-mini-sft: 100%|██████████| 400/400 [00:00<00:00, 1682.27it/s]
Loading samples: 100%|██████████| 1/1 [00:00<00:00,  2.28it/s]
100%|██████████| 400/400 [48:59<00:00,  7.35s/it]59<00:00,  6.94s/it, Prompt tokens=3164072, Completion tokens=75951, Total tokens=3240023]
Evaluating samples: 100%|██████████| 400/400 [48:59<00:00,  7.35s/it, Prompt tokens=3164072, Completion tokens=75951, Total tokens=3240023]


In [26]:
# Inspect the judged samples: detection/correction/perturbation hold one result dict per row.
df_evaluated

Unnamed: 0,question,correct_answer,premise,perturbed_premise,overall_correct,output,model,dataset,detection,correction,perturbation
1706,"If $n$ is 1 less than a multiple of 50, what i...",2.0,$n = 50k - 1$; $n^2 + 2n + 3 = (50k - 1)^2 + 2...,$k = 40 + n - 1$; $k^2 \div 2k - 3 = (40n - 1)...,True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model's s...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 detec..."
1199,a bullet train 150 m long is running with a sp...,15 sec,"$T = \frac{D}{S_{\text{relative}}}$, where $S_...","$T = \frac{S_{\text{relative}}}{D}$, where $S_...",True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model cor...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 corre..."
1873,The sum of the $x$-coordinates of the vertices...,10.0,$\text{Sum of x-coordinates of vertices} = X$;...,$\text{Sum of x-coordinates of midpoints} = X$...,True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model's s...","{'label': False, 'explanations': 'The model's ...","{'label': True, 'explanations': 'Step 7 is cor..."
276,"Ruby is 6 times older than Sam. In 9 years, Ru...",6,$\text{Ruby's age} = 6 \times \text{Sam's age}...,$\text{Sam's age} = 6 \times \text{Ruby's age}...,True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model cor...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 detec..."
417,Antoine's french onion soup recipe calls for 2...,2,$\text{Onion cost} = 2 \times \text{Onion weig...,$\text{Onion cost} = 3 \div \text{Onion weight...,True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model's s...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 detec..."
...,...,...,...,...,...,...,...,...,...,...,...
1341,the circumferences of the fore and hind - whee...,24 metres,$\text{Distance} = \text{lcm}\left(\text{Circu...,$\text{Distance} = \text{lcm}\left(\text{Circu...,True,[The given formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model cor...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 corre..."
676,The expression $3x^2 + 14x + 8$ can be written...,-2.0,$3x^2 + 14x + 8 = (3x + A)(x + B)$; $A = \frac...,$3x^2 \div 15x \div 8 = (3 + A)(B \div x)$; $A...,True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model's s...","{'label': False, 'explanations': 'The model's ...","{'label': False, 'explanations': 'Step 1 detec..."
1821,Find four-thirds of $\frac{9}{2}$.,6.0,$\frac{4}{3} \times \frac{9}{2}$,$4 + 30 \times \frac{9}{1}$,True,[The given formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The first ste...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 corre..."
614,The equation of the circle that passes through...,-120.0,$(x-h)^2+(y-k)^2=r^2$; $(x-2)^2+(y-3)^2=(-1-2)...,$(y-k)^2 \times (x-h)^2 = r^2$; $(y-2)^2 \time...,True,[The first formula from the user contains a mi...,gpt-4o-mini-sft,mix,"{'label': True, 'explanations': 'The model's s...","{'label': True, 'explanations': 'The model cor...","{'label': False, 'explanations': 'Step 1 detec..."


## Plot Sankey Graph

In [27]:
from utils.plotters.sankey_plot import plot_sankey_data, plot_position_distribution

import matplotlib.pyplot as plt

In [29]:
# Persist the per-sample judgements as JSONL, one record per row.
# NOTE(review): run this before any cell that reduces detection/correction to bare labels,
# otherwise the explanations are lost from the saved file.
evaluated_output_path = "../exp_results/mix/test_400_premise_evaluated_gpt-4o-mini-sft.jsonl"
df_evaluated.to_json(evaluated_output_path, orient="records", lines=True)

In [28]:
# Build the Detect -> Correct -> Overall Sankey diagram from the judged samples.
# Work on a derived frame instead of mutating df_evaluated. The old in-place overwrite of
# detection/correction with their labels meant a later save to JSONL lost the explanations,
# and re-running this cell crashed.


def _judgement_label(value):
    """Return value["label"] for a judgement dict; pass already-reduced labels through."""
    return value["label"] if isinstance(value, dict) else value


def _safe_ratio(numerator, denominator):
    """numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


df_sankey = pd.DataFrame({
    "overall_correct": df_evaluated["overall_correct"].astype(bool),
    "detection": df_evaluated["detection"].apply(_judgement_label),
    "correction": df_evaluated["correction"].apply(_judgement_label),
})

df_sankey['Detect'] = df_sankey['detection'].map({True: 'Corr\\n', False: 'N-Corr    \\n'})


# Three-state correction label: not detected / detected but not fixed / detected and fixed.
def map_correct(row):
    if not row['detection']:
        return 'N-Corr\\n'
    if not row['correction']:
        return 'NF-Corr\\n'
    return 'F-Corr\\n'


df_sankey['Correct'] = df_sankey.apply(map_correct, axis=1)
df_sankey['Overall'] = df_sankey['overall_correct'].map({True: '✅ Answer\\n', False: '❎ Answer\\n'})

# Count samples per (Detect, Correct, Overall) path for the Sankey flows.
grouped = df_sankey.groupby(['Detect', 'Correct', 'Overall']).size().reset_index(name='counts')
html_template = plot_sankey_data(['Detect', 'Correct', 'Overall'], grouped)

is_fixed = grouped['Correct'] == 'F-Corr\\n'
is_right = grouped['Overall'] == '✅ Answer\\n'
is_wrong = grouped['Overall'] == '❎ Answer\\n'

ratio = _safe_ratio(grouped[is_fixed & is_right].counts.sum(), grouped[is_fixed].counts.sum())
print(f"{ratio*100:.2f}% of the successful corrections have a correct overall answer.")
# Share of wrong final answers that did NOT come from a successful correction.
ratio = 1 - _safe_ratio(grouped[is_fixed & is_wrong].counts.sum(), grouped[is_wrong].counts.sum())
print(f"{ratio*100:.2f}% of the overall inaccuracies have a failed or No Identification.")

# Saved as standalone HTML (TODO: PDF export). UTF-8 is required because the labels contain
# emoji, which the platform default encoding may not support.
with open("../figures/sankey_finetune_gpt4o.html", "w", encoding="utf-8") as html_file:
    html_file.write(html_template)

95.02% of the successful corrections have a correct overall answer.
42.31% of the overall inaccuracies have a failed or No Identification.
