<a href="https://colab.research.google.com/github/BorowskiKacper/llm_multiagent_debate/blob/main/mmlu_Multi_Agent_Review_Board.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# README

# Multi agent review board notebook using "mmlu" data (Source code has "geography", "gsm", "math", and "mmlu" data)
# Source code: https://github.com/composable-models/llm_multiagent_debate

# CONFIGURE parameters like which models to use and agent configs using fourth code block/cell.

In [None]:
# Install depencies and authenticate with hugging face to access gated models
!pip install transformers accelerate -q

# Authenticate with Hugging Face
from huggingface_hub import login
login()  # You'll need a HF token

In [None]:
# CONFIGURE rounds, questions, pipelines, and agent configs in this cell
from transformers import pipeline
import torch

# =====ROUNDS & QUESTIONS=====
rounds = 2
questions = 100

# =====GENERATION PIPELINES=====
# NOTE: You can create pipelines from multiple models, not just one.
# But beware that you have a limited amount of RAM and VRAM to work withRuntime -> View Resources
model_id="meta-llama/Llama-3.2-1B-Instruct"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# =====AGENTS=====
# Currently each agent uses the same model/pipeline, but different temperatures
agent_configs = [{"pipe": pipe, "temp": 0.1}, {"pipe": pipe, "temp": 0.7}, {"pipe": pipe, "temp": 1.5}]

In [None]:
from glob import glob
import pandas as pd
import json
import time
import random


In [None]:
# Load MMLU data and store data frames
!curl "https://people.eecs.berkeley.edu/~hendrycks/data.tar" | tar -xvf -

tasks = glob("./data/test/*.csv")
dfs = [pd.read_csv(task) for task in tasks]

len(dfs)


In [None]:
def construct_message(agents, question, idx):
    if len(agents) == 0:
        return {"role": "user", "content": "Can you double check that your answer is correct. Put your final answer in the form (X) at the end of your response."}

    prefix_string = "These are the solutions to the problem from other agents: "

    for agent in agents:
        agent_response = agent[idx]["content"]
        response = "\n\n One agent solution: ```{}```".format(agent_response)

        prefix_string = prefix_string + response

    prefix_string = prefix_string + """\n\n Using the reasoning from other agents as additional advice, can you give an updated answer? Examine your solution and that other agents step by step. Put your answer in the form (X) at the end of your response.""".format(question)
    return {"role": "user", "content": prefix_string}


def construct_assistant_message(completion):
    # content = completion["choices"][0]["message"]["content"]
    content = completion[0]["generated_text"][-1]["content"]
    return {"role": "assistant", "content": content}


def generate_answer(agent_config, agent_context):
    # try:

    #     completion = openai.ChatCompletion.create(
    #               model="gpt-3.5-turbo-0301",
    #               messages=answer_context,
    #               n=1)
    # except:
    #     print("retrying due to an error......")
    #     time.sleep(20)
    #     return generate_answer(answer_context)
    completion = agent_config["pipe"](
      agent_context,
      max_new_tokens=1024,
      do_sample=True,
      temperature=agent_config["temp"],
      top_p=0.9,
    )


    return completion


def parse_question_answer(df, ix):
    question = df.iloc[ix, 0]
    a = df.iloc[ix, 1]
    b = df.iloc[ix, 2]
    c = df.iloc[ix, 3]
    d = df.iloc[ix, 4]

    question = "Can you answer the following question as accurately as possible? {}: A) {}, B) {}, C) {}, D) {} Explain your answer, putting the answer in the form (X) at the end of your response.".format(question, a, b, c, d)

    answer = df.iloc[ix, 5]

    return question, answer



In [None]:
random.seed(0)
response_dict = {}

for i in range(questions):
    df = random.choice(dfs)
    ix = len(df)
    idx = random.randint(0, ix-1)

    question, answer = parse_question_answer(df, idx)

    # agent_contexts = [[{"role": "user", "content": question}] for agent in range(agents)]
    agent_contexts = [[{"role": "user", "content": question}] for agent in range(len(agent_configs))]
    print(f"Question {i+1}/{questions}: {question}")
    print(f"Answer: {answer}")

    for round in range(rounds):
        for i, agent_context in enumerate(agent_contexts):
            print(f"Round: {round + 1}/{rounds} | Agent: {i+1}/{len(agent_configs)}")
            if round != 0:
                agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:]
                message = construct_message(agent_contexts_other, question, 2 * round - 1)
                agent_context.append(message)

            completion = generate_answer(agent_configs[i], agent_context)

            assistant_message = construct_assistant_message(completion)
            agent_context.append(assistant_message)
            print(assistant_message)

    response_dict[question] = (agent_contexts, answer)

json.dump(response_dict, open("mmlu_{}_{}.json".format(len(agent_configs), rounds), "w"))


Eval

In [35]:
import sklearn
import json
from collections import Counter
import re

In [39]:
def parse_answer(input_str):
    pattern = r'\((\w)\)'
    matches = re.findall(pattern, input_str)

    solution = None

    for match_str in matches[::-1]:
        solution = match_str.upper()
        if solution:
            break

    return solution


def parsed_answers(pred_solutions):
    if type(pred_solutions) == list:
        pred_answers = []

        for pred_solution in pred_solutions:
            pred_answer = parse_answer(pred_solution)
            pred_answers.append(pred_answer)

    else:
        pred_answers = [parse_answer(pred_solutions)]

    return pred_answers


def most_frequent(letters):
  if not letters:
    raise ValueError("Expected non-empty list.")

  return Counter(letters).most_common(1)[0][0]

In [62]:
class AgentStats:
  def __init__(self):
    self.counts = {}
    self.correct_count = 0

  def add_stat(self, predicted_answer, gt):
    self.counts[predicted_answer] = self.counts.get(predicted_answer, 0) + 1
    self.correct_count += predicted_answer == gt

  def __str__(self):
    return f"Counts: {self.counts}\nCorrect: {self.correct_count}\n========================="


In [66]:
# Load Inference results
response_dict = json.load(open("mmlu_3_2.json", "r"))
questions = list(response_dict.keys())

# Store individual answers, consensus answers, and ground truth for each question
results = []
for question in questions:
    responses, gt = response_dict[question]

    pred_solutions = []
    for response in responses:
        pred_solution = response[-1]['content']
        pred_solutions.append(pred_solution)


    pred_answers = parsed_answers(pred_solutions) # Individual answers
    consensus = most_frequent(pred_answers) # Consensus
    res = [pred_answers, consensus, gt]
    results.append(res)

# Parse stats for each agent answer and consensus answers
agents = [AgentStats() for _ in range(len(results[0]))]
group = AgentStats()
for answers, consensus, gt in results:
  # Individual eval
  for i, answer in enumerate(answers):
    agents[i].add_stat(answer, gt)

  # Group eval
  group.add_stat(consensus, gt)

print("---------------Individual Agents---------------")
for agent in agents:
  print(agent)

print("---------------Group results---------------")
print(group)

print(f"Total Questions: {len(questions)}")


---------------Individual Agents---------------
Counts: {'A': 8, 'C': 5, 'B': 4, 'X': 51, 'D': 9, None: 15, 'E': 2, 'I': 1, 'Y': 1, '1': 1, 'T': 1, 'G': 1}
Correct: 9
Counts: {'X': 42, None: 17, 'B': 7, 'Y': 1, 'D': 12, 'C': 12, 'A': 5, '1': 1, 'W': 1, 'E': 1}
Correct: 15
Counts: {'A': 12, 'D': 11, None: 21, 'X': 29, '4': 1, 'B': 12, 'C': 5, 'I': 1, '2': 2, 'W': 2, '8': 1, 'F': 1, 'Î’': 1}
Correct: 14
---------------Group results---------------
Counts: {'A': 9, 'C': 6, 'B': 7, 'X': 50, 'D': 7, None: 15, 'E': 1, 'I': 1, 'Y': 1, '1': 1, 'W': 1}
Correct: 13
Total Questions: 99
