<a href="https://colab.research.google.com/github/JayTiptown/self-consistency-ensemble/blob/self-consistency/self_consistency_ensemble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
# imports

import asyncio
import random
from collections import defaultdict
import nest_asyncio
nest_asyncio.apply()  # Needed to run asyncio in Colab
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import pandas as pd
from google.colab import userdata

In [15]:
# model import

!pip install openai
!pip install anthropic
!pip install google-genai

import openai
import anthropic
import google.genai

# Read API keys from Colab's secret store (google.colab.userdata) instead of hardcoding them.
# Configures the module-level `openai` client that sample_chain calls via openai.chat.completions.create.
openai.api_key = userdata.get("OPENAI_API_KEY")
# NOTE(review): sample_chain has no anthropic implementation yet, and the anthropic SDK
# may not consult a module-level `api_key` attribute (its client usually takes the key as a
# constructor argument or from the ANTHROPIC_API_KEY env var) — confirm before relying on this.
anthropic.api_key = userdata.get("ANTHROPIC_API_KEY")
# Gemini key left disabled: the 'gemini' provider branch in sample_chain is not implemented.
# google.genai.api_key = userdata.get("GEMINI_API_KEY")



In [16]:
# chain-of-thought sampling (one live model call per sample)

def _extract_final_answer(text):
  """Pull the final answer out of a chain-of-thought completion.

  Scans lines bottom-up for one starting with "Answer:" (case-insensitive,
  surrounding whitespace ignored) and returns the text after the colon. If that
  line has nothing after the colon, the non-empty lines below it are joined and
  returned. Falls back to the last non-empty line. Returns "" for empty or None
  text (chat completions can return None content).
  """
  lines = [ln.strip() for ln in (text or "").strip().splitlines()]
  for idx in range(len(lines) - 1, -1, -1):
    if lines[idx].lower().startswith("answer:"):
      remainder = lines[idx].split(":", 1)[1].strip()
      return remainder or " ".join(ln for ln in lines[idx + 1:] if ln)
  return lines[-1] if lines else ""


async def sample_chain(prompt, model_spec, temperature=0.7, max_tokens=150):
  """Draw one chain-of-thought sample and extract its final answer.

  Args:
    prompt: question text. A "Let's think step by step." cue and an "Answer:"
      cue are appended before sending.
    model_spec: (provider, model) tuple, e.g. ('openai', 'gpt-4'). The model
      name is forwarded to the provider.
    temperature: sampling temperature. It must be > 0 for repeated samples to
      differ, which self-consistency relies on.
    max_tokens: completion length cap. A low cap can cut the reasoning off
      before the final "Answer:" line is written.

  Returns:
    (answer, total_logprob). total_logprob is always 0 for now because logprobs
    are not requested from the chat API.

  Raises:
    NotImplementedError: provider is 'anthropic' or 'gemini' (not wired up yet).
    ValueError: unknown provider.
  """
  provider, model = model_spec
  full_prompt = prompt + "\nLet's think step by step.\nAnswer:"

  if provider == 'openai':
    # The OpenAI client call is blocking; run it in a worker thread so the
    # event loop can keep other concurrent samples in flight.
    response = await asyncio.to_thread(
        openai.chat.completions.create,
        model=model,
        messages=[{"role": "user", "content": full_prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    trace = response.choices[0].message.content
    answer = _extract_final_answer(trace)
  elif provider in ('anthropic', 'gemini'):
    # Fail loudly instead of falling through with `answer` unbound.
    raise NotImplementedError(f"Provider {provider!r} is not implemented yet")
  else:
    raise ValueError(f"Unknown provider: {provider}")

  total_logprob = 0  # Placeholder: logprobs are not requested from the chat API.
  return answer, total_logprob

In [17]:
# adaptive controller
def controller(prompt, max_budget=5):
  """
  Decide how many chains to sample for a prompt.

  Fixed-budget policy: `prompt` is accepted (presumably reserved for an
  adaptive policy, per the cell header) but ignored, and the whole
  `max_budget` is spent.
  """
  samples_to_draw = max_budget
  return samples_to_draw

In [18]:
# parallel orchestrator

async def orchestrate_samples(prompt, budget, model_spec):
  """Launch `budget` independent sample_chain calls concurrently.

  Returns a list of the per-sample results in task-creation order
  (asyncio.gather preserves input order).
  """
  pending = []
  for _ in range(budget):
    pending.append(asyncio.create_task(sample_chain(prompt, model_spec)))
  results = await asyncio.gather(*pending)
  return results

In [19]:
# aggregator

from collections import Counter

def _normalize_answer(answer):
    """Voting key for an answer.

    Collapses whitespace, trims surrounding spaces, periods, commas, and similar
    punctuation plus double quotes, and case-folds the result.
    """
    return " ".join(str(answer).split()).strip(" .,;:!?\"").casefold()


def aggregate_votes(samples, normalize=_normalize_answer):
    """
    Majority-vote over sampled answers (self-consistency aggregation).

    Args:
        samples: iterable of (answer_str, log_prob_float); log_prob_float is ignored.
        normalize: maps an answer to the key used for voting, so trivially different
            surface forms ("Delhi", "delhi.") count as one answer. Pass ``None`` to
            compare raw strings exactly.

    Returns:
        (winner, counts): ``winner`` is the first-seen surface form of the most
        frequent answer group. Ties go to the group seen first. ``counts`` maps each
        group's first-seen surface form to its vote count, in first-seen order.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    key_of = normalize if normalize is not None else (lambda answer: answer)
    votes = Counter()
    display = {}  # voting key -> first raw answer seen for that key
    for answer, _ in samples:
        key = key_of(answer)
        display.setdefault(key, answer)
        votes[key] += 1
    if not votes:
        raise ValueError("aggregate_votes needs at least one sample")
    # most_common breaks ties by insertion order, i.e. the first-seen group wins.
    top_key = votes.most_common(1)[0][0]
    counts = {display[key]: n for key, n in votes.items()}
    return display[top_key], counts

In [26]:
# self-consistency run
def run_self_consistency(prompt, model_spec, max_budget, ground_truth=None, verbose=True):
    """
    Sample `prompt` several times concurrently and majority-vote the answers.

    Args:
        prompt: question text sent with every sample.
        model_spec: (provider, model) tuple forwarded to the sampler.
        max_budget: upper bound on the sample count; `controller` picks the actual budget.
        ground_truth: optional reference answer, printed alongside the vote when given.
        verbose: when False, skip the printed report and only return the result.

    Returns:
        (winner, counts) as produced by `aggregate_votes`.
    """
    budget = controller(prompt, max_budget)
    # asyncio.run refuses to start inside an already-running event loop (as in
    # Jupyter/Colab); it works here only because of nest_asyncio.apply().
    samples = asyncio.run(orchestrate_samples(prompt, budget, model_spec))
    winner, counts = aggregate_votes(samples)

    if verbose:
        print(f"Prompt:\n{prompt}\n")
        print("Sampled Answers:")
        for i, (answer, _) in enumerate(samples, 1):
            print(f"{i}: {answer}")

        print("\nAnswer Frequencies:")
        for ans, count in counts.items():
            print(f"{repr(ans)}: {count}")

        if ground_truth is not None:
            print(f"\nGround Truth Answer:\n{ground_truth}")

        print(f"\nFinal Answer (Majority Vote):\n{winner}")
    return winner, counts

# Example usage — issues max_budget (20) live OpenAI requests, one per sample.
# Open-ended questions like this produce free-form answers that rarely match verbatim,
# so the majority vote can degenerate to a tie of distinct strings (see output below).
model_spec = ('openai', 'gpt-4')
run_self_consistency("Who invented the atomic bomb?", model_spec, max_budget=20)

Prompt:
Who invented the atomic bomb?

Sampled Answers:
1: The atomic bomb was developed by a team of scientists during the Manhattan Project, a research effort during World War II. The project was led by the United States with support from the United Kingdom and Canada. Key figures included Robert Oppenheimer, Enrico Fermi, and Richard Feynman.
2: J. Robert Oppenheimer and his team at the Manhattan Project.
3: J. Robert Oppenheimer and his team during the Manhattan Project.
4: The atomic bomb was developed by a team of scientists during the Manhattan Project, a research program led by the United States with assistance from the United Kingdom and Canada during World War II. However, the theoretical foundation was laid by many scientists, including Albert Einstein and Enrico Fermi. The director of the project was American physicist Robert J. Oppenheimer.
5: The atomic bomb was primarily developed by a team of scientists and engineers as part of the United States' Manhattan Project durin

('The atomic bomb was developed by a team of scientists during the Manhattan Project, a research effort during World War II. The project was led by the United States with support from the United Kingdom and Canada. Key figures included Robert Oppenheimer, Enrico Fermi, and Richard Feynman.',
 {'The atomic bomb was developed by a team of scientists during the Manhattan Project, a research effort during World War II. The project was led by the United States with support from the United Kingdom and Canada. Key figures included Robert Oppenheimer, Enrico Fermi, and Richard Feynman.': 1,
  'J. Robert Oppenheimer and his team at the Manhattan Project.': 1,
  'J. Robert Oppenheimer and his team during the Manhattan Project.': 1,
  'The atomic bomb was developed by a team of scientists during the Manhattan Project, a research program led by the United States with assistance from the United Kingdom and Canada during World War II. However, the theoretical foundation was laid by many scientists, 

In [33]:
# prompt: What would I add to the code to support testing self-consistency with the HaluEval QA dataset? The data will be in a json format, and the file is called qa_data.json
# TODO add knowledge later
import json
import linecache

def load_halueval_data(filepath="/content/qa_data.json", num_rows=1000):
    """
    Load HaluEval QA records from a JSON Lines file (one JSON object per line).

    Args:
        filepath (str): Path to the data file.
        num_rows (int | None): Number of physical lines to read from the top of the
            file. Blank and malformed lines count toward this limit but are skipped.
            ``None`` reads the whole file.

    Returns:
        list: One dict per successfully parsed line, in file order.
    """
    data = []
    # Iterate the file handle directly rather than using linecache. linecache caches
    # file contents for the whole session, so a replaced qa_data.json would keep
    # returning stale rows. It also loads the entire file regardless of num_rows.
    with open(filepath, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            if num_rows is not None and line_no > num_rows:
                break
            line = raw_line.strip()
            if not line:
                continue  # Skip empty lines
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Truncate the echoed line so one bad record can't flood the output.
                print(f"Skipping line {line_no} due to JSONDecodeError: {line[:200]} - {e}")
    return data

def _halueval_is_correct(winner, ground_truth):
    """Lenient correctness check for one HaluEval item.

    Returns None when the item has no ground truth and False for an empty or
    missing answer. Otherwise it returns whether the case-insensitive answer
    equals, or is contained in, the ground truth.
    """
    if not ground_truth:
        return None
    answer = (winner or "").strip().lower()
    if not answer:
        # "" is a substring of every string, so it must never count as a match.
        return False
    return answer in ground_truth.strip().lower()


def evaluate_self_consistency_on_halueval(model_spec, max_budget=20, data_filepath="qa_data.json", num_rows=1000):
    """
    Evaluates self-consistency for each question in the HaluEval QA dataset.

    Args:
        model_spec (tuple): A tuple specifying the language model provider and model name
                            (e.g., ('openai', 'gpt-4')).
        max_budget (int): The maximum number of samples to generate for self-consistency.
        data_filepath (str): The path to the JSON Lines file containing the HaluEval dataset.
        num_rows (int): Number of lines to read from the top of the data file
                        (forwarded to load_halueval_data).

    Returns:
        list: One dict per question with keys 'question', 'ground_truth_answer',
              'self_consistency_winner', 'sampled_answer_counts', and 'correct'
              (True/False from a lenient case-insensitive match, or None when
              the item has no ground truth).
    """
    halueval_data = load_halueval_data(data_filepath, num_rows)
    results = []

    for item in halueval_data:
        question = item['question']
        # The reference answer is read from 'right_answer'; items without it get None.
        ground_truth_answer = item.get('right_answer')

        print(f"\nProcessing Question: {question}")
        print(f"Ground Truth Answer: {ground_truth_answer}")

        # Run self-consistency
        winner, counts = run_self_consistency(question, model_spec, max_budget)

        result_item = {
            'question': question,
            'ground_truth_answer': ground_truth_answer,
            'self_consistency_winner': winner,
            'sampled_answer_counts': counts,
            'correct': _halueval_is_correct(winner, ground_truth_answer),
        }
        results.append(result_item)

    return results

# Example usage to run evaluation on the HaluEval dataset:
# Make sure you have the qa_data.json file (one JSON object per line) in the working
# directory, or provide the full path. This issues up to num_rows * max_budget API requests.
model_spec = ('openai', 'gpt-4')
halueval_results = evaluate_self_consistency_on_halueval(model_spec, max_budget=10, data_filepath="qa_data.json", num_rows=100)

# Analyze halueval_results to see how often the self-consistency winner matches
# the ground truth, or examine the distribution of answers. For instance, accuracy:
def cond(res):
    """Lenient match: the voted answer equals, or appears inside, the ground truth (case-insensitive).

    Empty winners never count as correct, because "" is a substring of every string.
    Items with a missing ground truth count as incorrect instead of crashing on
    None.lower().
    """
    winner = (res.get('self_consistency_winner') or '').strip().lower()
    truth = (res.get('ground_truth_answer') or '').strip().lower()
    return bool(winner) and bool(truth) and (winner == truth or winner in truth)

correct_predictions = sum(1 for res in halueval_results if cond(res))
# Guard against an empty result list (e.g. a missing or empty data file).
accuracy = correct_predictions / len(halueval_results) if halueval_results else float('nan')
print(f"\nSelf-Consistency Accuracy on HaluEval: {accuracy:.2f} ({correct_predictions}/{len(halueval_results)})")



Processing Question: Which magazine was started first Arthur's Magazine or First for Women?
Ground Truth Answer: Arthur's Magazine
Prompt:
Which magazine was started first Arthur's Magazine or First for Women?

Sampled Answers:
1: Arthur's Magazine
2: Arthur's Magazine
3: Arthur's Magazine
4: Arthur's Magazine
5: Arthur's Magazine
6: Arthur's Magazine
7: Arthur's Magazine
8: Arthur's Magazine
9: Arthur's Magazine
10: Arthur's Magazine

Answer Frequencies:
"Arthur's Magazine": 10

Final Answer (Majority Vote):
Arthur's Magazine

Processing Question: The Oberoi family is part of a hotel company that has a head office in what city?
Ground Truth Answer: Delhi
Prompt:
The Oberoi family is part of a hotel company that has a head office in what city?

Sampled Answers:
1: Delhi
2: Delhi
3: Delhi
4: Delhi
5: Delhi
6: Delhi
7: Delhi
8: Delhi
9: Delhi
10: Delhi

Answer Frequencies:
'Delhi': 10

Final Answer (Majority Vote):
Delhi

Processing Question: Musician and satirist Allie Goertz wrote a s