<a href="https://colab.research.google.com/github/122004/122004/blob/main/Copy_of_habermas_machine_newspeak_event_250118.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exploring Deliberative Democracy with the Habermas Machine

This colab contains the code for running a version of the Habermas Machine based on a series of prompted Gemini models, for use at the Newspeak House event on Jan 18, 2025.

The purpose of this notebook is to demonstrate the workings of the Habermas Machine, using publicly accessible models rather than the custom fine-tuned model. The Habermas Machine was introduced in this paper:

Tessler, M. H., Bakker, M. A., Jarrett, D., Sheahan, H., Chadwick, M. J., Koster, R., Evans, G., Campbell-Gillingham, J., Collins, T., Parkes, D. C., Botvinick, M., and Summerfield, C. "AI can help humans find common ground in democratic deliberation." *Science*. (2024). [[url]](https://www.science.org/stoken/author-tokens/ST-2196/full)

Colab by Michiel Bakker (miba@google.com) and MH Tessler (mhtessler@google.com)


# Setup: importing packages and libraries

In [None]:
# @title Imports

import abc
import enum
import os
import re
import sys
import textwrap
import time
from collections.abc import Collection, Mapping, Sequence
from typing import NamedTuple

from IPython.display import Javascript, Markdown

from typing_extensions import override

import google.generativeai as genai
import numpy as np
from tenacity import retry, stop_after_attempt, wait_fixed

from IPython.display import HTML
shell = get_ipython()


def adjust_font_size():
  """Displays a CSS rule that sets the notebook body font size to 14px.

  Registered below as an IPython ``pre_execute`` hook, so it runs before each
  cell execution.
  """
  # The closing </style> tag matters: without it the browser may treat any
  # HTML that follows in the same output as more CSS.
  display(HTML('''<style>
    body {
      font-size: 14px;
    }
  </style>'''))


# Every run of this cell creates a new `adjust_font_size` function object, so a
# plain `not in callbacks` test never detects hooks registered by earlier runs.
# Remove those by name first so exactly one hook stays registered.
for _callback in list(shell.events.callbacks['pre_execute']):
  if getattr(_callback, '__name__', None) == 'adjust_font_size':
    shell.events.unregister('pre_execute', _callback)
shell.events.register('pre_execute', adjust_font_size)

# Adjust font size dynamically
def change_input_font_size(size=20):
    """Runs browser-side JavaScript that resizes the text of `input` elements.

    Args:
      size: Font size in pixels applied to every element whose CSS class is
        ``input``.
    """
    script = f"""
        var elements = document.getElementsByClassName('input');
        for (var i = 0; i < elements.length; i++) {{
            elements[i].style.fontSize = '{size}px';
        }}
    """
    display(Javascript(script))

change_input_font_size(40)  # Adjust the size as needed

In [None]:
# @title LLM Client utils for colab.

def truncate(
    string: str,
    *,
    max_length: int = sys.maxsize,
    delimiters: Collection[str] = (),
) -> str:
  """Truncates a string to a maximum length and cuts it at each delimiter.

  The string is first cut to `max_length` characters. Then, for each delimiter
  in order, everything from the delimiter's first occurrence onwards is
  dropped and the delimiter itself is appended.

  The delimiter is appended even when it did not occur in the string. In this
  notebook, `COTRankingModel` samples with `terminators=['</answer>']`, and its
  response parser requires that closing tag to be present. Sampling backends
  may omit stop sequences from their output (assumption — confirm per
  backend).

  Note: because delimiters are appended, the result can be longer than
  `max_length`. With several delimiters it may end with more than one of them.

  Args:
    string: String to truncate.
    max_length: Maximum number of characters kept from `string` before any
      delimiter is appended.
    delimiters: Delimiters at which to cut; each one is appended after the cut.

  Returns:
    The prefix of `string[:max_length]` before each delimiter, followed by
    that delimiter.
  """
  truncated = string[:max_length]
  for delimiter in delimiters:
    # Keep everything before the first occurrence, then re-append the
    # delimiter (whether or not it was found).
    truncated = truncated.split(delimiter, 1)[0] + delimiter
  return truncated

In [None]:
# @title Base class for LLM clients.

DEFAULT_TEMPERATURE = 0.8
DEFAULT_TERMINATORS = ()
DEFAULT_TIMEOUT_SECONDS = 60
# We truncate the response if we detect the terminator string before the max
# tokens so we set a high default value for max tokens.
DEFAULT_MAX_TOKENS = 4096

class LLMClient(abc.ABC):
 """Language model client base class."""

 @abc.abstractmethod
 def sample_text(
     self,
     prompt: str,
     *,
     max_tokens: int = DEFAULT_MAX_TOKENS,
     terminators: Collection[str] = DEFAULT_TERMINATORS,
     temperature: float = DEFAULT_TEMPERATURE,
     timeout: float = DEFAULT_TIMEOUT_SECONDS,
     seed: int | None = None,
 ) -> str:
   """Samples text from the model.

   Args:
     prompt: The input text that the model conditions on.
     max_tokens: The maximum number of tokens in the response.
     terminators: The response will be terminated before any of these
       characters.
     temperature: Model temperature.
     timeout: Timeout for the request.
     seed: Optional seed for the sampling. If None a random seed will be used.

   Returns:
     The sampled response (i.e. does not include the prompt).

   Raises:
     TimeoutError: if the operation times out.
   """
   raise NotImplementedError('sample_text method is not implemented.')

In [None]:
# @title Language Model that uses Google AI Studio API.

DEFAULT_SAFETY_SETTINGS = (
   {
       'category': 'HARM_CATEGORY_HARASSMENT',
       'threshold': 'BLOCK_ONLY_HIGH',
   },
   {
       'category': 'HARM_CATEGORY_HATE_SPEECH',
       'threshold': 'BLOCK_ONLY_HIGH',
   },
   {
       'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
       'threshold': 'BLOCK_ONLY_HIGH',
   },
   {
       'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
       'threshold': 'BLOCK_ONLY_HIGH',
   },
)

class AIStudioClient(LLMClient):
 """Language model client backed by the Google AI Studio (Gemini) API."""

 def __init__(
     self,
     model_name: str,
     *,
     safety_settings: Sequence[Mapping[str, str]] = DEFAULT_SAFETY_SETTINGS,
     sleep_periodically: bool = False,
 ) -> None:
   """Initializes the instance.

   Reads the API key from the `GOOGLE_API_KEY` environment variable and passes
   it to `genai.configure`.

   Args:
     model_name: which language model to use. For more details, see
       https://aistudio.google.com/.
     safety_settings: Gemini safety settings. For more details, see
       https://ai.google.dev/gemini-api/docs/safety.
     sleep_periodically: if True, sleep 10 seconds on every 10th call to
       `sample_text` to stay under the API rate limit.

   Raises:
     KeyError: if `GOOGLE_API_KEY` is not set.
   """
   self._api_key = os.environ['GOOGLE_API_KEY']
   self._model_name = model_name
   self._safety_settings = safety_settings
   self._sleep_periodically = sleep_periodically

   genai.configure(api_key=self._api_key)
   self._model = genai.GenerativeModel(
       model_name=self._model_name,
       safety_settings=safety_settings,
   )

   self._calls_between_sleeping = 10
   self._n_calls = 0

 @override
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(10))
 def sample_text(
     self,
     prompt: str,
     *,
     max_tokens: int = DEFAULT_MAX_TOKENS,
     terminators: Collection[str] = DEFAULT_TERMINATORS,
     temperature: float = DEFAULT_TEMPERATURE,
     timeout: float = DEFAULT_TIMEOUT_SECONDS,
     seed: int | None = None,
 ) -> str:
   """Samples text from the Gemini model (see base class).

   Any exception escaping this method is retried by the `tenacity` decorator:
   at most 3 attempts, 10 seconds apart.

   If the API response carries no text (no candidates, or a candidate with no
   parts — presumably a blocked prompt/response; TODO confirm), the problem is
   printed and an empty string is used instead of raising. Retrying the same
   prompt would likely just repeat the failure.

   The text is finally passed through `truncate(..., delimiters=terminators)`.
   """
   # NOTE(review): `timeout` is currently ignored, i.e. not forwarded to the
   # API call.
   del timeout
   del seed  # AI Studio does not support seeds.

   self._n_calls += 1
   if self._sleep_periodically and (
       self._n_calls % self._calls_between_sleeping == 0):
     print('Sleeping for 10 seconds...')
     time.sleep(10)

   sample = self._model.generate_content(
       prompt,
       generation_config=genai.GenerationConfig(
           temperature=temperature,
           max_output_tokens=max_tokens,
           stop_sequences=terminators,
       ),
       safety_settings=self._safety_settings,
       stream=False,
   )
   try:
     # AI Studio returns a list of parts, but we only use the first one.
     response = sample.candidates[0].content.parts[0].text
   except (ValueError, IndexError) as e:
     # IndexError: `candidates` or `parts` is empty.
     print('An error occurred: ', e)
     print(f'prompt: {prompt}')
     print(f'sample: {sample}')
     response = ''
   return truncate(response, delimiters=terminators)

In [None]:
# @title Base class for reward models.

# Output of a ranking model: the ranking (or None on error) plus an optional
# explanation / error description.
RankingResult = NamedTuple(
   'RankingResult',
   [
       ('ranking', np.ndarray | None),
       ('explanation', str | None),
   ],
)


class BaseRankingModel(abc.ABC):
 """Base class for reward models that rank multiple statements."""

 @abc.abstractmethod
 def predict_ranking(
     self,
     llm_client: LLMClient,
     question: str,
     opinion: str,
     statements: Sequence[str],
     previous_winner: str | None = None,
     critique: str | None = None,
 ) -> RankingResult:
   """Predicts how a citizen would rank the given statements.

   Args:
     llm_client: LLM client to use.
     question: Question that the citizen is responding to.
     opinion: Text-based opinion of the citizen.
     statements: Statements that are ranked.
     previous_winner: The statement that won the previous round.
     critique: The citizen's critique of the previous winner.

   Returns:
     A RankingResult tuple, consisting of:
       - Array of ranks with dimensions [num_statements,]. Lower is better,
         so the most preferred statement gets rank 0 and tied statements
         share a rank. For example, [0, 1, 0] means the citizen prefers
         statements 0 and 2 over statement 1 and is indifferent between
         statements 0 and 2. None if the model produced an error.
       - Explanation for the ranking (for example the raw output including
         the chain-of-thought) or a description of the error. None if there
         is no explanation.
   """
   raise NotImplementedError('predict_ranking method is not implemented.')

In [None]:
# @title [HAS PROMPT] A ranking model that uses chain-of-thought reasoning.

class COTRankingModel(BaseRankingModel):
 """A ranking model that uses chain-of-thought reasoning to rank statements."""

 @override
 def predict_ranking(
     self,
     llm_client: LLMClient,
     question: str,
     opinion: str,
     statements: Sequence[str],
     previous_winner: str | None = None,
     critique: str | None = None,
     num_retries_on_error: int = 5,
 ) -> RankingResult:
   """Ranks statements by prompting an LLM to reason step by step.

   The LLM is asked for its reasoning followed by an arrow ranking such as
   "B > A > C" over the statement letters. See the base class for the
   ranking convention.

   Args:
     llm_client: LLM client used to sample the chain-of-thought response.
     question: Question that the citizen is responding to.
     opinion: Text-based opinion of the citizen.
     statements: Statements to rank, labelled A, B, C, ... in order.
     previous_winner: Statement that won the previous round, if any.
     critique: The citizen's critique of `previous_winner`. It must be given
       exactly when `previous_winner` is given.
     num_retries_on_error: Maximum number of LLM calls. A call is retried
       when its response cannot be parsed or ranks the wrong number of
       statements.

   Returns:
     RankingResult. On success, `ranking` has one entry per statement and
     `explanation` is the raw LLM response. If every attempt fails (or
     `num_retries_on_error < 1`), `ranking` is None and `explanation`
     describes the last failure.

   Raises:
     ValueError: if only one of `previous_winner` / `critique` is given, or
       if there are fewer than 2 or more than 26 statements.
   """
   if previous_winner is None and critique is not None:
     raise ValueError(
         'If there is a critique, there should be a previous_winner.'
     )
   if previous_winner is not None and critique is None:
     raise ValueError(
         'If there is a previous_winner, there should be a critique.'
     )
   if len(statements) < 2:
     raise ValueError('There should be at least two statements to rank.')
   if len(statements) > 26:
     # Statements are labelled with single letters A-Z in the prompt.
     raise ValueError('There can be at most 26 statements to rank (A-Z).')

   # The prompt is identical across attempts; only the sample changes.
   prompt = _generate_rm_prompt(
       question, opinion, statements, previous_winner, critique
   )
   result = RankingResult(
       None, 'NO_ATTEMPTS: num_retries_on_error must be at least 1.'
   )
   for _ in range(num_retries_on_error):
     response = llm_client.sample_text(prompt, terminators=['</answer>'])
     ranking, explanation = _process_rm_model_response(response)

     if ranking is None:
       # Unparseable response; `explanation` holds the error. Try again.
       result = RankingResult(None, explanation)
       continue
     if len(ranking) != len(statements):
       error_msg = 'INCORRECT_RANKING_LENGTH'
       if explanation:
         error_msg += f', Explanation: {explanation}'
       result = RankingResult(None, error_msg)
       continue
     return RankingResult(ranking, explanation)
   return result


def _generate_rm_opinion_critique_prompt(
   question: str,
   opinion: str,
   statements: Sequence[str],
   previous_winner: str,
   critique: str,
) -> str:
 """Generates a prompt for the LLM using opinion and critique."""
 prompt = f"""As an AI assistant, your job is to rank these statements in the order that the participant would most likely agree with them, based on their opinion and critique to a summary statement from a previous discussion round. Use Arrow notation for the ranking, where ">" means "preferred to". Ties are NOT allowed and items should be in descending order of preference so you can ONLY use ">" and the letters of the statements in the ranking. Examples of valid rankings: B > A, D > A > C > B. B > C > A > E > D.

Please think through this task step-by-step:

1. Analyze the participant's opinion and critique, noting key points and sentiments.
2. Analyze the critique to the summary statement from the previous discussion round.
3. Compare each statement to the participant's opinion and critique, considering how well it aligns with or supports their view.
4. Consider any nuances or implications in the statements that might appeal to or repel the participant based on their expressed opinion.
5. Rank the statements accordingly using only ">" and the letters of the statements.

Provide your answer in the following format:
<answer>
[Your step-by-step reasoning and explanation for the ranking]
<sep>
[Final ranking using arrow notation]
</answer>

For example for five statements A, B, C, D and E a valid response could be:
<answer>
1. The participant's opinion emphasizes the importance of environmental protection and the need for immediate action to address climate change. The critique of the previous winner highlights that it failed to offer specific solutions.

2. The critique emphasizes the need for concrete solutions to address climate change, indicating that the participant values action-oriented approaches.

3. - Statement A directly addresses the urgency of climate action and proposes concrete steps, aligning with both the participant's opinion and critique.
 - Statements B and D acknowledge the seriousness of climate change but offer less concrete solutions. B focuses on global cooperation, while D emphasizes economic considerations.
 - Statement C downplays the urgency of climate change, contradicting the participant's stance.
 - Statement E completely opposes the participant's view by denying the existence of climate change.

4.  The participant's emphasis on immediate action suggests a preference for proactive solutions and a dislike for approaches that downplay the issue or offer only abstract ideas.

5. Based on this analysis, the ranking is: A > D > B > C > E

<sep>
A > D > B > C > E
</answer>

It is important to follow the template EXACTLY. So ALWAYS start with <answer>, then the explanation, then <sep> then only the final ranking and then </answer>.

Below you will find the question, the participant's opinion, the statement from the previous round, and a critique of that statement. You will also find a list of statements to rank.

Question: {question}

Participant's Opinion: {opinion}

Statement from previous round: {previous_winner}

Critique: {critique}

Statements to rank:
"""
 for i, statement in enumerate(statements):
   letter = chr(ord('A') + i)  # A, B, C, D, etc.
   try:
     statement = (
         statement.strip().strip('').strip('""').strip('\n').strip()
     )
   except Exception as exc:
     raise ValueError(f'Issue with statement: {statement}') from exc
   prompt += f'{letter}. {statement}\n'

 return prompt.strip()


def _generate_rm_opinion_only_prompt(
   question: str,
   opinion: str,
   statements: Sequence[str],
) -> str:
 """Generates a prompt for the LLM using only the opinion."""
 prompt = f"""
Task: As an AI assistant, your job is to rank these statements in the order that the participant would most likely agree with them, based on their opinion. Use Arrow notation for the ranking, where ">" means "preferred to". Ties are NOT allowed and items should be in descending order of preference so you can ONLY use ">" and the letters of the statements in the final ranking. Examples of valid final rankings: B > A, D > A > C > D. B > C > A > E > D.

Please think through this task step-by-step:

1. Analyze the participant's opinion, noting key points and sentiments.
2. Compare each statement to the participant's opinion, considering how well it aligns with or supports their view.
3. Consider any nuances or implications in the statements that might appeal to or repel the participant based on their expressed opinion.
4. Rank the statements accordingly using only ">" and the letters of the statements.

Provide your answer in the following format:
<answer>
[Your step-by-step reasoning and explanation for the ranking]
<sep>
[Final ranking using arrow notation]
</answer>

For example for five statements A, B, C, D and E a valid response could be:
<answer>
1. The participant's opinion emphasizes the importance of environmental protection and the need for immediate action to address climate change.

2. - Statement A directly addresses the urgency of climate action and proposes concrete steps, aligning with the participant's opinion.
  - Statements B and D acknowledge the seriousness of climate change but offer less concrete solutions. B focuses on global cooperation, while D emphasizes economic considerations.
  - Statement C downplays the urgency of climate change, contradicting the participant's stance.
  - Statement E completely opposes the participant's view by denying the existence of climate change.

3.  The participant's emphasis on immediate action suggests a preference for proactive solutions and a dislike for approaches that downplay the issue or offer only abstract ideas.

4. Based on this analysis, the ranking is: A > D > B > C > E

<sep>
A > D > B > C > E
</answer>

It is important to follow the template EXACTLY. So ALWAYS start with <answer>, then the explanation, then <sep> then only the final ranking and then </answer>.

Below you will find the question and the participant's opinion. You will also find a list of statements to rank.

Question: {question}

Participant's Opinion: {opinion}

Statements to rank:
"""
 for i, statement in enumerate(statements):
   letter = chr(ord('A') + i)  # A, B, C, D, etc.
   try:
     statement = (
         statement.strip().strip('').strip('""').strip('\n').strip()
     )
   except Exception as exc:
     raise ValueError(f'Issue with statement: {statement}') from exc
   prompt += f'{letter}. {statement}\n'

 return prompt.strip()


def _generate_rm_prompt(
   question: str,
   opinion: str,
   statements: Sequence[str],
   previous_winner: str | None = None,
   critique: str | None = None,
) -> str:
 """Chooses and builds the ranking prompt for the current round.

 Rounds with a previous winner use the opinion-plus-critique template;
 otherwise only the opinion is shown to the LLM.

 Returns:
   The prompt string produced by the selected template.
 """
 if previous_winner is not None:
   return _generate_rm_opinion_critique_prompt(
       question, opinion, statements, previous_winner, critique
   )
 return _generate_rm_opinion_only_prompt(question, opinion, statements)


def _check_response_format(response: str) -> bool:
 """Checks if the response is in a correct format with <answer> and <sep>.

 Args:
   response: The model's raw response

 Returns:
   bool: True if the format is correct, False otherwise
 """
 pattern = r'<answer>\s*.*?\s*<sep>\s*.*?\s*</answer>'
 return bool(re.search(pattern, response, re.DOTALL))


def _check_arrow_format(arrow_ranking):
 """Checks if the arrow ranking format is correct.

 Args:
   arrow_ranking: The arrow ranking string (eg A > B > C)

 Returns:
 bool: True if the format is correct, False otherwise
 """
 if len(arrow_ranking) < 3:
   return False

 # Remove whitespace and replace multiple spaces with single spaces.
 arrow_ranking = re.sub(r'\s+', ' ', arrow_ranking.strip())

 # Remove spaces around '>' and '=' symbols.
 arrow_ranking = re.sub(r'\s*(>|=)\s*', r'\1', arrow_ranking)

 # Check if the ranking contains only allowed characters.
 if not re.match(r'^[A-Z>=]+$', arrow_ranking):
   return False

 # Check for consecutive '>' symbols, '=' at the start/end,
 # or '=' immediately before '>'.
 if (
     '>>' in arrow_ranking
     or arrow_ranking.startswith('=')
     or arrow_ranking.endswith('=')
     or '=>' in arrow_ranking
 ):
   return False

 # Split by '>' and check each group
 groups = arrow_ranking.split('>')

 if len(groups) < 1:
   return False

 seen_letters = set()
 for group in groups:
   # Check if the group is empty.
   if not group:
     return False
   # Check if each group contains only unique letters separated by '='.
   letters = group.split('=')
   if len(letters) != len(set(letters)):
     return False
   # Check if any letter in this group has been seen before.
   if any(letter in seen_letters for letter in letters):
     return False
   seen_letters.update(letters)

 return True

def _extract_arrow_ranking(text: str) -> str | None:
 """Extracts the arrow ranking from a given string.

 Args:
   text: The input string containing the arrow ranking.

 Returns:
   The extracted arrow ranking or None if not found.
 """
 # Regular expression to match a full arrow ranking pattern
 match = re.search(r'\b([A-Z](?:\s*(?:>|=)\s*[A-Z])*)\b', text)

 if match:
   return match.group(1).replace(' ', '')  # Removes any extra spaces
 else:
   return None

def _process_rm_model_response(response: str) -> tuple[np.ndarray | None, str]:
 """Parses a ranking-model response into a ranking array.

 Parsing order:
 1. Try the `<answer> explanation <sep> ranking </answer>` template.
 2. Otherwise, try a "final ranking:" line.
 3. If the ranking part yields no valid arrow ranking, search the explanation
    as well.

 Args:
   response: The raw model response.

 Returns:
   A tuple `(ranking, explanation)`:
   - On success, `ranking` is an integer array with one rank per statement
     letter (A, B, C, ... in order). 0 is most preferred and tied letters
     share a rank. `explanation` is the full raw `response`.
   - On failure, `ranking` is None and `explanation` is one of:
     - "INCORRECT_TEMPLATE: <response>", if neither template matched;
     - "INCORRECT_ARROW_RANKING: <response>", if no valid arrow ranking was
       found, or its letters do not form a gap-free run starting at A.
 """
 if _check_response_format(response):
   match = re.search(
       r'<answer>\s*(.*?)\s*<sep>\s*(.*?)\s*</answer>', response, re.DOTALL
   )
   if match is None:
     return None, f'INCORRECT_TEMPLATE: {response}'
   explanation = match.group(1).strip()
   arrow_ranking = _extract_arrow_ranking(match.group(2).strip())
 else:
   # Backup as it sometimes returns "final ranking:" in a different format.
   match = re.search(r'(?i)final ranking:\s*(.*)', response)
   if match is None:
     return None, f'INCORRECT_TEMPLATE: {response}'
   explanation = response
   arrow_ranking = _extract_arrow_ranking(match.group(1))

 if arrow_ranking is None or not _check_arrow_format(arrow_ranking):
   # Check if the ranking is in the explanation.
   arrow_ranking = _extract_arrow_ranking(explanation.strip())
   if arrow_ranking is None or not _check_arrow_format(arrow_ranking):
     return None, f'INCORRECT_ARROW_RANKING: {response}'

 # Remove any whitespace so splitting below yields bare letters.
 arrow_ranking = re.sub(r'\s+', '', arrow_ranking)
 unique_elements = sorted(set(re.findall(r'[A-Z]', arrow_ranking)))

 # Array positions are assigned by sorted letter, so the letters must be
 # exactly A, B, C, ... without gaps. Otherwise, e.g. "A > B > D" would
 # silently rank statement D in the slot of statement C.
 expected = [chr(ord('A') + i) for i in range(len(unique_elements))]
 if unique_elements != expected:
   return None, f'INCORRECT_ARROW_RANKING: {response}'

 # Convert arrow ranking to numpy array: a letter's rank is the index of its
 # '>'-separated group, so '='-tied letters share a rank.
 ranking_dict = dict.fromkeys(unique_elements, 0)
 for rank, group in enumerate(arrow_ranking.split('>')):
   for element in group.split('='):
     ranking_dict[element] = rank

 result = np.array([ranking_dict[element] for element in unique_elements])
 return result, response

In [None]:
# @title Base social ranking method for aggregating ranked preferences.

RANKING_MOCK = -1

SocialRankingResult = NamedTuple(
   'SocialRankingResult',
   [
       ('social_ranking', np.ndarray),
       ('untied_social_ranking', np.ndarray),
   ],
)

class Base(abc.ABC):
 """Social ranking base class."""

 def __init__(self, tie_breaking_method: 'TieBreakingMethod'):
   """Initialize the Base class.

   Args:
       tie_breaking_method: Method that is used to break ties.
   """
   self._tie_breaking_method = tie_breaking_method

 @abc.abstractmethod
 def aggregate(
     self,
     rankings: np.ndarray,
     seed: int | None = None,  # TODO(miba): Input rng instead of seed.
 ) -> SocialRankingResult:
   """Aggregates a set of rankings into a single social ranking.

   Args:
     rankings: Array of batched rankings with dimensions: [num_citizens,
       num_candidates]. In this array lower is better and the best candidate
       is thus given rank 0. For example, an array [[1, 0], [0, 0]] corresponds
       to the first citizen preferring candidate 1 over candidate 0 while the
       second citizen has no preference for either candidate over the other.
     seed: Optional seed for tie breaking.

   Return:
     A tuple of:
     - An array with the aggregated social rank for each candidate.
       If B>C>A, the array will be [2, 0, 1]. The array has dimensions:
       [num_candidates,]. In this array, ties are allowed and there can be
       multiple potential winners.
     - Untied aggregated social rank. The ranks are untied using the
       `tie_breaking_method`. If there were no ties in the social ranking,
       this just returns the same ranking.
   """
   raise NotImplementedError('Base class is not implemented.')

In [None]:
# @title Social ranking method that implements the Schulze method.

class Schulze(Base):
 """Schulze social ranking method (Schulze, M. 2011).

 We follow the steps from https://electowiki.org/wiki/Schulze_method.
 """

 @override
 def aggregate(
     self,
     rankings: np.ndarray,
     seed: int | None = None,
 ) -> SocialRankingResult:
   """Aggregates rankings into a single social ranking and unties the ranking.

   Args:
     rankings: Array of batched rankings with dimensions: [num_citizens,
       num_candidates]. In this array lower is better and the best candidate
       is thus given rank 0. For example, an array [[1, 0], [0, 0]] corresponds
       to the first citizen preferring candidate 1 over candidate 0 while the
       second citizen has no preference for either candidate over the other.
       Rows that are entirely RANKING_MOCK are ignored.
     seed: Random seed that is used to break ties.

   Returns:
     A tuple of:
     - An array with the aggregated social rank for each candidate.
       If B>C>A, the array will be [2, 0, 1]. The array has dimensions:
       [num_candidates,]. In this array, ties are allowed and there can be
       multiple potential winners.
     - Untied aggregated social rank. The ranks are untied using the
       `tie_breaking_method`. If there were no ties in the social ranking,
       this just returns the same ranking.

   Raises:
     ValueError: if the tie-breaking method is not supported.
   """
   rng = np.random.default_rng(seed)

   # TODO(miba): Add test for mock rankings.
   rankings = filter_out_mocks(rankings)
   if rankings.size == 0:
     # No non-mock rankings remain: report every candidate as mock-ranked and
     # use a random order as the untied ranking.
     social_ranking_with_ties = np.full(
         (rankings.shape[1]), RANKING_MOCK, dtype=np.int32)
     social_ranking_without_ties = rng.permutation(
         np.arange(rankings.shape[1])).astype(np.int32)
     return SocialRankingResult(
         social_ranking_with_ties, social_ranking_without_ties
     )
   social_ranking = self.aggregate_with_ties(rankings)

   # Return if there are no ties or the method is TIES_ALLOWED.
   if (
       is_untied_ranking(social_ranking)
       or self._tie_breaking_method == TieBreakingMethod.TIES_ALLOWED
   ):
     return SocialRankingResult(social_ranking, social_ranking)

   # Otherwise the ties must be broken; keep the tied version for the result.
   tied_social_ranking = social_ranking.copy()

   # Schulze tie-breaking ranking of the candidates (TBRC).
   if self._tie_breaking_method == TieBreakingMethod.TBRC:
     # Copy and shuffle rankings so we can randomly select ballots.
     random_ballots = rankings.copy()
     rng.shuffle(random_ballots)

     # Apply the shuffled ballots one at a time until no ties remain.
     for random_ballot in random_ballots:
       social_ranking = untie_ranking_with_ballot(
           social_ranking, random_ballot
       )
       # Stop (exit the for loop) as soon as the ranking is fully untied.
       if is_untied_ranking(social_ranking):
         return SocialRankingResult(
             tied_social_ranking,
             social_ranking,
         )

   # If there are still ties or the method is random, break ties randomly.
   if self._tie_breaking_method in [
       TieBreakingMethod.RANDOM,
       TieBreakingMethod.TBRC,
   ]:
     # New random ballot that can be added to untie rankings.
     random_ballot = np.arange(social_ranking.size)
     rng.shuffle(random_ballot)
     social_ranking = untie_ranking_with_ballot(
         social_ranking, random_ballot
     )
     return SocialRankingResult(
         tied_social_ranking,
         social_ranking,
     )
   raise ValueError(
       f'tie_breaking_method {self._tie_breaking_method.name} is not'
       ' supported.'
   )

 def aggregate_with_ties(
     self,
     rankings: np.ndarray,
 ) -> np.ndarray:
   """Aggregates rankings into a single social ranking with potential ties."""

   check_rankings(rankings)

   pairwise_defeats = self._compute_pairwise_defeats(rankings)
   strongest_path_strengths = self._compute_strongest_path_strengths(
       pairwise_defeats)
   social_ranking = self._rank_candidates(
       strongest_path_strengths)
   return social_ranking

 def _compute_pairwise_defeats(self, rankings: np.ndarray) -> np.ndarray:
   """Computes the number of voters who prefer one candidate over another.

   Args:
     rankings: Array of batched rankings with dimensions: [num_citizens,
       num_candidates].

   Returns:
     An int32 array whose entry [x, y] is the number of voters who prefer
       candidate x to candidate y. Dimensions [num_candidates, num_candidates].

   Raises:
     ValueError: if `rankings` is not two-dimensional.
   """
   if rankings.ndim != 2:
     raise ValueError(
         'rankings should have dimensions [num_citizens, num_candidates].')
   # prefers[c, x, y] is True when citizen c ranks x strictly better than y
   # (lower is better as the highest rank is 0).
   prefers = rankings[:, :, np.newaxis] < rankings[:, np.newaxis, :]
   return prefers.sum(axis=0, dtype=np.int32)

 def _compute_strongest_path_strengths(
     self, pairwise_defeats: np.ndarray) -> np.ndarray:
   """Computes the strength of the strongest path between candidates.

   Args:
     pairwise_defeats: An array with the number of voters who prefer candidate
         x to candidate y. Dimensions [num_candidates, num_candidates].
   Returns:
     An int32 array with the strength of the strongest path between candidate
     x and candidate y. Dimensions [num_candidates, num_candidates].

   Raises:
     ValueError: if `pairwise_defeats` is not square or has a non-zero
       diagonal.
   """
   if len(set(pairwise_defeats.shape)) != 1:
     raise ValueError('pairwise_defeats should be a square array.')
   if np.any(np.diag(pairwise_defeats) != 0):
     raise ValueError('pairwise_defeats should have an all zero diagonal.')

   num_candidates = pairwise_defeats.shape[0]
   # Direct links: d[x, y] is kept only where x beats y head-to-head. The
   # diagonal stays 0 because d[x, x] > d[x, x] is never true.
   path_strengths = np.where(
       pairwise_defeats > pairwise_defeats.T, pairwise_defeats, 0
   ).astype(np.int32)

   # Widest-path relaxation with idx as the intermediate candidate: a path
   # idy -> idx -> idz is as strong as its weakest link.
   for idx in range(num_candidates):
     for idy in range(num_candidates):
       if idx != idy:
         for idz in range(num_candidates):
           if idx != idz and idy != idz:
             path_strengths[idy, idz] = max(
                 path_strengths[idy, idz],
                 min(path_strengths[idy, idx], path_strengths[idx, idz]),
             )

   return path_strengths

 def _rank_candidates(self, path_strengths: np.ndarray) -> np.ndarray:
   """Rank the candidates by winning path strength.

   Args:
     path_strengths: An array with the strength of the strongest path between
       candidate x and candidate y. Dimensions [num_candidates,
       num_candidates].
   Returns:
     An array with the aggregated social rank for each candidate. Dimensions
       [num_candidates,]. Note that this social rank can contain ties.
   """

   if len(set(path_strengths.shape)) != 1:
     raise ValueError('The path_strengths array should be square.')
   if np.any(np.diag(path_strengths) != 0):
     raise ValueError('path_strengths should have an all zero diagonal.')

   # Compute the margin array and determine pairwise weak preferences.
   pairwise_dominance = (path_strengths - path_strengths.T) >= 0

   # Potential winners are those that are preferred (weakly) most often.
   weakly_preferred_count = pairwise_dominance.sum(axis=1)

   # We can compute the rankings from the weakly preferred count as the binary
   # relationships (A >= B) from Schulze are transitive (see page 200 from
   # https://arxiv.org/pdf/1804.02973.pdf).
   _, rankings = np.unique(-1 * weakly_preferred_count, return_inverse=True)
   return rankings

In [None]:
# @title Base class for statement models.

# Output of a statement model: the generated statement and an explanation
# (e.g. the chain-of-thought that produced it).
StatementResult = NamedTuple(
   'StatementResult',
   [
       ('statement', str),
       ('explanation', str),
   ],
)

class BaseStatementModel(abc.ABC):
 """Base class for models that generate a statement from citizens' opinions."""

 @abc.abstractmethod
 def generate_statement(
     self,
     llm_client: LLMClient,
     question: str,
     opinions: Sequence[str],
     previous_winner: str | None = None,
     critiques: Sequence[str] | None = None,
     seed: int | None = None,
 ) -> StatementResult:
   """Generates a candidate statement from the citizens' opinions.

   In later rounds, the previous winning statement and the citizens'
   critiques of it can also be provided.

   Args:
     llm_client: The LLM client used to generate the statement.
     question: Question that the citizens are responding to.
     opinions: Text-based opinions of the citizens.
     previous_winner: The statement that won the previous round.
     critiques: Critiques of the previous winner.
     seed: Optional seed for the sampling. If None a random seed will be used.

   Returns:
     A tuple containing:
       - The predicted statement.
       - The explanation (e.g. chain-of-thought)
   """
   raise NotImplementedError('generate_statement method is not implemented.')

In [None]:
# @title Utils for social choice methods.

class TieBreakingMethod(enum.Enum):
  """Method for breaking ties in an aggregated (social) ranking.

  A member is handed to a social choice method, e.g.
  `Schulze(tie_breaking_method=TieBreakingMethod.TBRC)`. How each option is
  applied is decided by that social choice method, not by this enum.
  """
  # NOTE(review): presumably tied candidates keep equal ranks; confirm in the
  # social choice implementation.
  TIES_ALLOWED = 'ties_allowed'
  # NOTE(review): presumably ties are broken by a random ordering; confirm.
  RANDOM = 'random'
  TBRC = 'tbrc'  # Schulze tie-breaking ranking of the candidates (TBRC).

def filter_out_mocks(rankings: np.ndarray) -> np.ndarray:
  """Filters out mock rankings and checks whether mocks are correctly used.

  An entry is a mock when it equals RANKING_MOCK. A citizen (row) must either
  use a mock rank for every candidate or for none of them.

  Args:
    rankings: Integer array of rankings, one row per citizen and one column
      per candidate.

  Returns:
    The rows of `rankings` that contain no mock entries.

  Raises:
    ValueError: If `rankings` is not an integer array, or if a row mixes mock
      and real ranks.
  """
  if not np.issubdtype(rankings.dtype, np.integer):
    raise ValueError(
        f'The array should be an integer array but is {rankings.dtype}.'
    )
  is_mock = rankings == RANKING_MOCK
  any_mock = is_mock.any(axis=1)
  all_mock = is_mock.all(axis=1)
  # A row is consistent when it is either fully mocked or not mocked at all.
  if not (any_mock == all_mock).all():
    raise ValueError(
        'If a citizen uses a mock rank for one candidate '
        'it should use a mock rank for all candidates.'
    )
  return rankings[np.logical_not(any_mock)]


def normalize_ranking(ranking: np.ndarray) -> np.ndarray:
  """Maps a 1-D ranking onto consecutive ranks, e.g. [0, 2, 5, 5] -> [0, 1, 2, 2].

  Raises:
    ValueError: If `ranking` is not one-dimensional.
  """
  if ranking.ndim == 1:
    # The inverse indices of np.unique are the positions of each value in the
    # sorted distinct values, i.e. a dense 0-based rank.
    return np.unique(ranking, return_inverse=True)[1]
  raise ValueError('The input array should be a single ranking so `ndim=1`')


def is_untied_ranking(ranking: np.ndarray) -> bool:
  """Returns True when no two entries of the 1-D `ranking` share a rank.

  Raises:
    ValueError: If `ranking` is not one-dimensional.
  """
  if ranking.ndim == 1:
    distinct_ranks = np.unique(ranking)
    return len(distinct_ranks) == len(ranking)
  raise ValueError('The input array should be a single ranking so `ndim=1`')


def check_rankings(rankings: np.ndarray, allow_ties: bool = True) -> None:
  """Validates a batch of rankings, raising ValueError if it is malformed.

  Args:
    rankings: Integer array of shape [num_citizens, num_candidates] where
        lower is better and each citizen's best candidate has rank 0. For
        example, [[1, 0], [0, 0]] means citizen 1 prefers candidate 1 over
        candidate 0, while citizen 2 is indifferent between them. Mock ranks
        are expected to have been filtered out already.
    allow_ties: When True, equal ranks within a row are accepted; when False,
        every row must be a strict ordering.

  Raises:
    ValueError: On a non-integer dtype, a row whose best rank is not 0, or a
      row whose sorted ranks are not consecutive (ties only if allowed).
  """
  if not np.issubdtype(rankings.dtype, np.integer):
    raise ValueError(
        f'The array should be an integer array but is {rankings.dtype}.')

  ordered = np.sort(rankings, axis=1)
  if (ordered[:, 0] != 0).any():
    raise ValueError('All rankings should have a 0 as highest ranking')

  # Gaps between consecutive sorted ranks: 0 marks a tie, 1 the next rank.
  steps = np.diff(ordered, axis=1)

  if not allow_ties:
    if not (steps == 1).all():
      raise ValueError('Incorrect ratings, the step size between ratings should'
                       ' be 1 as ties are not allowed.')
    return

  if not ((steps == 0) | (steps == 1)).all():
    raise ValueError(
        'Incorrect ratings, the step size between ratings should be 0 or 1.')


def untie_ranking_with_ballot(
    ranking: np.ndarray, ballot: np.ndarray) -> np.ndarray:
  """Breaks ties in `ranking` using `ballot` and returns consecutive ranks.

  Candidates keep their relative order from `ranking`. Only candidates that
  are tied in `ranking` are ordered by `ballot`, and candidates tied in both
  stay tied.

  Raises:
    ValueError: If `ranking` is not 1-D or `ballot` has a different shape.
  """
  if ranking.ndim != 1:
    raise ValueError('The input array should be a single ranking so `ndim=1`')
  if ranking.shape != ballot.shape:
    raise ValueError('The ranking and ballot should have the same shape.')

  num_candidates = len(ranking)
  primary = normalize_ranking(ranking)
  secondary = normalize_ranking(ballot)
  # Normalized ballot ranks are always below num_candidates. Scaling the
  # primary ranks by num_candidates therefore makes the combined key order by
  # `ranking` first and fall back to `ballot` only within tied groups.
  combined_key = primary * num_candidates + secondary
  # Compress the combined keys back to consecutive ranks.
  return normalize_ranking(combined_key)

In [None]:
# @title [HAS PROMPT] A model that uses the chain-of-thought method to generate statements.

class COTModel(BaseStatementModel):
  """Statement model that uses chain-of-thought reasoning.

  The LLM is asked for its reasoning, then `<sep>`, then the statement, all
  inside `<answer>...</answer>`. The response is parsed into a
  StatementResult(statement, explanation).
  """

  def generate_statement(
      self,
      llm_client: LLMClient,
      question: str,
      opinions: Sequence[str],
      previous_winner: str | None = None,
      critiques: Sequence[str] | None = None,
      seed: int | None = None,
      override_prompt: str | None = None,
      retries_if_empty: int = 5,
  ) -> StatementResult:
    """Generates a statement (see base model).

    Args:
      llm_client: Client whose `sample_text` produces the raw response.
      question: Question that the citizens are responding to.
      opinions: Text-based opinions of the citizens.
      previous_winner: Previous round's winner. When given, a revision
        prompt that includes `critiques` is built; otherwise an opinion-only
        prompt is built.
      critiques: Critiques of `previous_winner`.
      seed: Seed for the first sampling attempt. Later attempts use derived
        seeds (seed + attempt, wrapped into the int32 range), so a seeded,
        deterministic backend does not return the same rejected response
        again. If None, every attempt passes None.
      override_prompt: If not None, this text is sent to the LLM verbatim
        instead of the prompt built from the arguments above.
      retries_if_empty: Maximum number of sampling attempts. At least one
        attempt is always made.

    Returns:
      The result of the first attempt whose statement is longer than 5
      characters and follows the expected template. Otherwise, the result of
      the last attempt, whose statement may be short or empty with the
      explanation 'INCORRECT_TEMPLATE'.
    """
    # The prompt does not change between attempts, so build it once.
    if override_prompt is None:
      prompt = _generate_gm_prompt(
          question, opinions, previous_winner, critiques)
    else:
      prompt = override_prompt

    int32_max = np.iinfo(np.int32).max
    statement, explanation = '', 'INCORRECT_TEMPLATE'
    for attempt in range(max(1, retries_if_empty)):
      # Keep the caller's seed for the first attempt (reproducibility), and
      # vary it on retries so they can actually produce a different sample.
      if seed is None or attempt == 0:
        attempt_seed = seed
      else:
        attempt_seed = (seed + attempt) % int32_max
      response = llm_client.sample_text(
          prompt, terminators=['</answer>'], seed=attempt_seed)

      statement, explanation = _process_gm_model_response(response)
      if len(statement) > 5 and explanation != 'INCORRECT_TEMPLATE':
        break
    return StatementResult(statement, explanation)

def _generate_gm_opinion_critique_prompt(
   question: str,
   opinions: Sequence[str],
   previous_winner: str,
   critiques: Sequence[str],
) -> str:
 """Generates a prompt using opinions, previous winner, and critiques."""

 prompt = f"""You are assisting a citizens' jury in forming a group jury opinion on an important question. The jury members have provided their individual opinions, a first draft of a jury statement was created, and critiques of that draft were gathered. Your role is to generate a revised group jury statement that incorporates the feedback and aims to better represent the collective view of the jury. Do not make large edits unless several citizens ask for large edits. Stick to the original proposal and respond surgically to the feedback.

Please think through this task step-by-step:

1. Carefully analyze the individual opinions, noting key themes, points of agreement, and areas of disagreement.
2. Review the previous draft group jury statement and identify its strengths and weaknesses vis a vis the critiques.
3. Analyze the critiques of the previous draft, paying attention to specific suggestions and concerns raised by the jury members. Analyze this both with respect to the opinion expressed in the statement as well as the specific proposal(s) in the draft statement.
4. Based on the opinions, the previous draft, and the critiques, revise the group statement, creating a revised group jury statement that addresses the concerns raised and better reflects the collective view of the jury and an actionable, agreeable plan. Ensure the statement is clear, concise, addresses the core issue posed in the question, and appeals to the group to the maximum extent possible.  Do not mention specific opinion and critique numbers when making your revisions.

Provide your answer in the following format:
<answer>
[Your step-by-step reasoning and explanation for the revised statement]
<sep>
[Revised consensus statement]
</answer>

Example:
<answer>
1. The individual opinions show a general agreement that childcare is important and that parents need support, with several opinions highlighting the financial burden of childcare and the need to enable parents to work (Opinions 1, 3, 4, 5). There is consensus that childcare should support development and learning (Opinions 1, 2, 5) and not just be a childminding service. A point of contention is the age at which universal free childcare should start, with some preferring a later start to allow for bonding (Opinion 1) and others wanting childcare from birth, or the option for parental leave (Opinion 4). There's support for universal paid parental leave for both parents (Opinions 1, 5) and the idea of offering the choice between childcare and paid leave at certain ages (Opinion 1, 5).
2. The previous draft statement made some good points, but could not reconcile the differences in the start age of childcare, and was potentially unclear on some issues, such as whether paid childcare from birth was to be offered as an option (not just the default).
3. The critiques highlight the need to make a clear distinction between universal free childcare from birth and parental leave. While all agree childcare must enhance development (Critique 2, 5), some wanted to allow parents to access paid childcare from birth if needed (Critique 1), though not as the default. Concerns were raised about the costs to small businesses of paid parental leave (Critique 2), the need for childcare to make going back to work worthwhile, and the need for sufficient pay during parental leave (Critique 4). Critiques also raised the need for universality to be irrespective of circumstance (Critique 3), and a positive reiteration of some aspects of the draft was also given (Critique 5).
4. Based on the opinions, the previous draft, and the critiques, the revised statement offers both universal paid parental leave and universal free childcare, but not necessarily from birth, allowing for early bonding between parents and children while also acknowledging the need for working parents to be supported. The revised statement provides a structure for a universal system that addresses multiple family situations without imposing a single approach, recognizing parents can choose paid care from birth if needed but that universal *free* childcare should start later. The statement is clear and concise, addresses the issue of the form and cost of childcare, and is intended to appeal to as many jury members as possible by offering both supports. We also acknowledge the importance of these systems being available to everyone, irrespective of circumstance. To implement this we would propose that the government offer 6 months of parental leave at 80% of the average weekly wage and to implement universal free childcare from 9 months, operating from 8am to 6pm, Monday to Friday. All parents should have the option to access paid childcare from birth, with the government subsiding this childcare at a rate of 20% until the child is 9 months old. This phased approach allows parents flexibility while also moving towards universal childcare.
<sep>
We agree that it is important to support parents and children through both parental leave and free childcare. We support the government providing universal paid parental leave, and that this should be available to both parents. We also support the provision of universal free childcare, but not necessarily from birth, acknowledging the benefits of bonding in the early months. However, we support parents being able to access paid childcare from birth if they need it. To enable working parents to return to work we propose that free childcare should be provided in a way that supports children's development and learning, not just as a childminding service. These supports should be available to all, irrespective of their family situation. To implement this we would propose that the government offer 6 months of parental leave at 80% of the average weekly wage and to implement universal free childcare from 9 months, operating from 8am to 6pm, Monday to Friday. All parents should have the option to access paid childcare from birth, with the government subsidising this childcare at a rate of 20% until the child is 9 months old.
</answer>

Below you will find the question, the individual opinions, the previous draft jury statement, and the critiques provided by the jury members.

Question: {question}

Individual Opinions:
"""
 for i, opinion in enumerate(opinions):
   prompt += f'Opinion Person {i+1}: {opinion}\n'

 prompt += f"""
Previous Draft Jury Statement: {previous_winner}

Critiques of the Previous Draft:
"""

 for i, critique in enumerate(critiques):
   prompt += f'Critique Person {i+1}: {critique}\n'

 return prompt.strip()

def _generate_gm_opinion_only_prompt(
   question: str,
   opinions: Sequence[str],
) -> str:
 """Generates a prompt for the LLM using only the opinions."""
 prompt = f"""
You are assisting a citizens' jury in forming an initial consensus opinion on an important question. The initial consensus opinion should begin with a clear statement of the group's position on the question (e.g., Yes or No, Support or Oppose). It should also include a thorough and detailed justification or argument, based on the jury's opinions that were written and using specific language from the jury's opinions when possible. In essence, the statement should capture the main points of agreement and represent the collective view of the jury but using the specific language from the jury.
In addition, consider drafting a concrete proposal for a plan of action. The proposal would be crafted from proposals suggested inside the opinions of the jury members or variants thereof. The proposal should appeal to the group to the maximum extent possible. The proposal should outline a plan for implementation.  While research or studies or task forces may be part of the process, the core of the proposal should be a specific action or set of actions aimed at addressing the issue directly, and not just studying it. If there are numbers involved (like times, ages, or monetary quantities), propose specific numbers. But again, only make a proposal if the group seems ready for it. One sign they are ready for it is if there is substantial agreement already on the issue, or if someone in the jury directly proposes something.

Please think through this task step-by-step:

1. Carefully analyze the individual opinions, noting key themes, points of agreement, areas of disagreement, and whether or not arguments make specific proposals or are just statements of values or arguments.
2. Based on the analysis, synthesize a clear, specific, and actionable jury statement that represents the shared perspective of the jury members.  Address the core issue posed in the question, and ensure the statement appeals to the group members to the maximum extent possible.  Do not refer specific opinion numbers or citizen numbers in the statement, but you may in the analysis.
3. Consider adding an additional concrete suggestion based on the opinions submitted. This can be a proposal from one of the jury members or a variant thereof. In the absence of specific proposals, come up a new proposal that would appeal to the most group members as possible. Consider the trade-off of being specific but alienting more people, and err on the side of being more specific and actionable.

Provide your answer in the following format:
<answer>
[Your step-by-step reasoning and explanation for the statement]
<sep>
[Draft consensus statement]
</answer>

Example:
<answer>
1. Most opinions acknowledge the benefits of childcare for parents and children, with support for enabling parents to work and children to develop socially and educationally (Opinions 1, 3, 4, 5). There is a general consensus that support for parents is needed, whether through childcare or paid leave. A key point of divergence is the appropriate age for universal free childcare to begin, with some advocating for a later start (Opinions 1, 2) to allow for bonding in the early months, and some wanting a start from birth, or offering a choice of childcare or stay-at-home income (Opinions 4). Opinions also highlight the need to target childcare towards those who cannot afford to pay (Opinion 2). Some mention the gendered nature of childcare (Opinions 3), with the need to reduce barriers to care for both men and women. There's a strong emphasis on the high cost of childcare impacting families financially (Opinions 3, 4).
2. The jury statement prioritizes support for parents through parental leave and free childcare, acknowledging the importance of early childhood development. It suggests a phased approach, with paid parental leave from birth and free childcare from a later age, allowing for both bonding and economic support, which reflects the different perspectives presented. The need for childcare to support both development and learning, and not just function as a childminding service, is also highlighted.
3. Building on the desire for support for parents from birth and concrete suggestions by the opinions, we propose a two-pronged approach to start within 3 months. First, the government should establish a universal paid parental leave scheme, offering 6 months of leave at 80% of the parent's average weekly wage, available to both parents and transferable between them. The funds for this scheme should be diverted from current paid childcare schemes. Second, the government should introduce universal free childcare for children from the age of 9 months, operating from 8am to 6pm Monday to Friday, to be implemented nationally at the same time as the parental leave program. To facilitate uptake the city should promote both programs through national campaigns on TV, social media, and in public information centres, especially at family doctor offices, and offer a help line for any queries. This addresses the need for support at birth and offers concrete action on childcare.
<sep>
In general, free childcare is a good thing, but it is important to consider how it is provided and for which age groups. We feel that it is important to offer support to parents in the form of parental leave, and that this should be available to both parents. In addition, we feel that free childcare should be provided from a young age, and that it should be provided in a way that supports children's development and learning, and not just as a childminding service. However, we do not feel that free childcare should be provided from birth, as we feel that it is important for babies to have a consistent primary caregiver in their early months. For this reason, we would support the government providing universal paid parental leave from birth, and providing universal free childcare from, say, 6 months old. We would also offer parents the opportunity to either use free childcare between 6 months and 1 year, or to have paid parental leave for the same period.
</answer>

Below you will find the question and the individual opinions of the jury members.

Question: {question}

Individual Opinions:
"""

 for i, opinion in enumerate(opinions):
   prompt += f'Opinion Person {i+1}: {opinion}\n'

 return prompt.strip()

def _generate_gm_prompt(
    question: str,
    opinions: Sequence[str],
    previous_winner: str | None = None,
    critiques: Sequence[str] | None = None,
) -> str:
  """Picks the prompt for the current round.

  A previous winner means a critique round, so the revision prompt (opinions,
  previous winner and critiques) is built. Without one, only the opinions are
  used.
  """
  if previous_winner is not None:
    return _generate_gm_opinion_critique_prompt(
        question, opinions, previous_winner, critiques
    )
  return _generate_gm_opinion_only_prompt(question, opinions)

def _process_gm_model_response(response: str) -> tuple[str, str]:
 """Processes the model's response, extracting the statement and explanation.

 Args:
     response: The raw model response.

 Returns:
     A tuple of (statement, explanation).  If the response format is
     incorrect, returns ("", "INCORRECT_TEMPLATE").
 """
 match = re.search(
     r'<answer>\s*(.*?)\s*<sep>\s*(.*?)\s*</answer>', response, re.DOTALL
 )
 if match:
   explanation = match.group(1).strip()
   statement = match.group(2).strip()
   return statement, explanation
 else:
   return '', 'INCORRECT_TEMPLATE'

In [None]:
# @title Top-level utils for Habermas Machine.

def numerical_ranking_to_ordinal_text(ranking_array):
  """Converts a numerical ranking array to ordinal text representation.

  Position i of `ranking_array` holds the rank of item i (lower is better).
  Items are shown with 1-based labels and grouped by rank from best to worst.
  Groups are joined by " > ", and tied items within a group by " = ".
  Elements with the same rank are ordered ascendingly by their original index.
  For example:
      * [1, 1, 0, 0] becomes "3 = 4 > 1 = 2"
      * [1, 2, 2, 0] becomes "4 > 1 > 2 = 3"
      * [0, 0, 2, 1] becomes "1 = 2 > 4 > 3"

  Args:
      ranking_array: 1-D integer NumPy array of rankings.

  Returns:
      The ordinal text, or an empty string for an empty array.

  Raises:
      ValueError: If `ranking_array` is not a 1-D integer NumPy array.
  """
  # Checked separately so that non-arrays (e.g. lists) get a ValueError
  # instead of an AttributeError from reading `.dtype`.
  if not isinstance(ranking_array, np.ndarray):
    raise ValueError(
        "The array should be an integer array but is"
        f" {type(ranking_array).__name__}."
    )
  if not np.issubdtype(ranking_array.dtype, np.integer):
    raise ValueError(
        f"The array should be an integer array but is {ranking_array.dtype}."
    )
  if ranking_array.ndim != 1:
    raise ValueError("The input array should be a single ranking so `ndim=1`")

  # Map each rank to the indices that hold it. Indices are visited in
  # ascending order, so every group is already sorted.
  groups = {}
  for index, rank in enumerate(ranking_array.tolist()):
    groups.setdefault(rank, []).append(index)

  # Labels are 1-based; groups go from best (lowest) to worst rank.
  return " > ".join(
      " = ".join(str(index + 1) for index in groups[rank])
      for rank in sorted(groups)
  )

In [None]:
# @title Habermas Machine.

class HabermasMachine:
  """Mediates caucus deliberation among participants.

  The Habermas Machine facilitates AI-mediated deliberation among a group of
  participants on a given question. It acts as a "mediator," iteratively
  refining a group statement that aims to capture the common ground of the
  participants' opinions.

  The process involves:

  1. Gathering initial opinions from participants.
  2. Generating candidate group statements using a Large Language Model (LLM).
  3. Evaluating these statements using a personalized reward model, predicting
     the order of preference of each participant for each statement.
  4. Aggregating individual preferences using a social choice method to select
     a winning statement.
  5. Gathering critiques of the winning statement from participants.
  6. Generating revised statements based on the critiques and previous opinions.
  7. Optionally, repeating steps 3-6 for multiple rounds, refining the statement
     iteratively. In the paper, we use one opinion and one critique round.

  This class manages the entire deliberation process, including interaction with
  the LLM, the reward model, and the social choice mechanism.  It maintains the
  history of opinions, critiques, candidate statements, and winning statements
  for each round. A round's history is only recorded once the round has
  completed, so a round that raises (e.g. on an API error) can be retried by
  calling `mediate` again with the same inputs.
  """

  def __init__(
      self,
      question: str,
      statement_client: LLMClient,
      reward_client: LLMClient,
      statement_model: BaseStatementModel,
      reward_model: BaseRankingModel,
      social_choice_method: Base,
      num_candidates: int = 16,
      num_citizens: int = 5,
      seed: int | None = None,
      verbose: bool = False,
  ):
    """Initializes the Habermas Machine.

    Args:
      question: The question the participants deliberate on.
      statement_client: LLM client passed to `statement_model`.
      reward_client: LLM client passed to `reward_model`.
      statement_model: Generates the candidate statements.
      reward_model: Predicts each citizen's ranking of the candidates.
      social_choice_method: Aggregates the citizens' rankings; its `aggregate`
        method must return a (tied, untied) pair of social rankings.
      num_candidates: Number of candidate statements generated per round.
      num_citizens: Number of participants; `mediate` expects exactly this
        many opinions or critiques.
      seed: Seed for the internal random number generator, which drives the
        shuffling and the per-call seeds.
      verbose: If True, print each round's inputs, candidates and rankings.
    """
    self._question = question  # Question to be answered.
    self._round = 0  # Current round (round 0 is the opinion round).
    self._critiques = []  # Critiques from current and previous rounds.
    self._statement_client = statement_client
    self._reward_client = reward_client
    self._statement_model = statement_model
    self._reward_model = reward_model
    self._social_choice_method = social_choice_method
    self._num_candidates = num_candidates  # Number of candidates to generate.
    self._rng = np.random.default_rng(seed)  # Random number generator.
    self._num_citizens = num_citizens
    self._previous_winners = []  # Winning statements from previous rounds.
    self._ranking_explanations = []  # Explanations for the rankings.
    self._previous_tied_rankings = []  # Rankings from previous rounds.
    self._previous_untied_rankings = []  # Untied rankings from previous rounds.
    self._statement_explanations = []  # Explanations for the statements.
    self._previous_candidates = []  # Candidates from previous rounds.
    self._verbose = verbose  # Whether to print round information.
    self._opinions = []  # Initial opinions.

  def _get_new_seed(self):
    """Draws a fresh seed in [0, int32 max) from the machine's RNG."""
    return self._rng.integers(np.iinfo(np.int32).max)

  def _generate_statements(
      self,
  ) -> tuple[list[str], list[str]]:  # statements, explanations.
    """Generates candidate statements."""
    statements = []
    explanations = []
    for _ in range(self._num_candidates):
      # Shuffle the opinions and critiques to avoid ordering bias. The same
      # permutation is used for both, so "Person i" stays the same citizen.
      indices = self._rng.permutation(self._num_citizens)
      shuffled_opinions = [self._opinions[j] for j in indices]
      shuffled_critiques = (
          [self._critiques[-1][i] for i in indices] if self._critiques else None
      )
      statement, explanation = self._statement_model.generate_statement(
          llm_client=self._statement_client,
          question=self._question,
          opinions=shuffled_opinions,
          previous_winner=(
              self._previous_winners[-1] if self._previous_winners else None),
          critiques=shuffled_critiques,
          seed=self._get_new_seed(),
      )
      statements.append(statement)
      explanations.append(explanation)
    return statements, explanations

  def _get_rankings(
      self, statements: list[str]) -> tuple[np.ndarray, list[None | str]]:
    """Gets rankings over all candidates for each citizen.

    Returns:
      A (num_citizens, num_candidates) array whose column j is the rank of
      statements[j], plus one ranking explanation per citizen.

    Raises:
      ValueError: If the reward model returns no ranking for a citizen.
    """
    all_rankings = []
    explanations = []
    for i in range(self._num_citizens):
      # Shuffle the statements to avoid ordering bias.
      indices = self._rng.permutation(self._num_candidates)
      shuffled_statements = [statements[j] for j in indices]

      ranking, explanation = self._reward_model.predict_ranking(
          llm_client=self._reward_client,
          question=self._question,
          opinion=self._opinions[i],
          statements=shuffled_statements,
          previous_winner=(
              self._previous_winners[-1] if self._round > 0 else None
          ),
          critique=self._critiques[-1][i] if self._round > 0 else None,
      )

      if ranking is None:
        raise ValueError(
            f"Ranking is None for citizen {i+1}. Explanation: {explanation}")

      # ranking[k] ranks shuffled_statements[k] == statements[indices[k]], so
      # scattering through `indices` restores the original statement order.
      unshuffled_ranking = np.full_like(ranking, fill_value=RANKING_MOCK)
      unshuffled_ranking[indices] = ranking
      all_rankings.append(unshuffled_ranking)
      explanations.append(explanation)
    return np.array(all_rankings), explanations

  def overwrite_previous_winner(self, winner: str):
    """Overwrites the last winner.

    Raises:
      ValueError: If no round has been completed yet.
    """
    if self._round == 0:
      raise ValueError("There is no previous winner before the opinion round.")
    else:
      if self._verbose:
        print("\nOverwriting last winner.")
        print(f"Previous winner: {self._previous_winners[-1]}")
        print(f"New winner: {winner}")
      self._previous_winners[-1] = winner

  def mediate(
      self, opinions_or_critiques: Sequence[str]) -> tuple[str, list[str]]:
    """Runs a single mediation step and returns the winning statement.

    In round 0 the inputs are the citizens' opinions. In later rounds they are
    the citizens' critiques of the previous winner.

    Args:
      opinions_or_critiques: One opinion or critique per citizen, in citizen
        order.

    Returns:
      The winning statement and all candidate statements sorted by the untied
      social ranking (best first).

    Raises:
      ValueError: If the number of inputs differs from the number of citizens,
        or if the reward model fails to rank for some citizen.
    """
    if len(opinions_or_critiques) != self._num_citizens:
      raise ValueError(
          f"Expected {self._num_citizens} opinions or critiques, got"
          f" {len(opinions_or_critiques)}."
      )

    is_critique_round = self._round > 0
    if is_critique_round:
      self._critiques.append(list(opinions_or_critiques))
    else:
      self._opinions = list(opinions_or_critiques)

    try:
      return self._run_round()
    except BaseException:
      # Undo the critique append so a retried call does not leave a stale,
      # duplicate entry misaligned with the per-round history.
      if is_critique_round:
        self._critiques.pop()
      raise

  def _run_round(self) -> tuple[str, list[str]]:
    """Generates, ranks and aggregates candidates for the current round.

    Expects this round's inputs to be stored already: `self._opinions` in
    round 0, and `self._critiques[-1]` afterwards. History and the round
    counter are only updated after every step has succeeded.
    """
    if self._verbose:
      if self._round == 0:
        print("\n\nOpinion round.")
      else:
        print(f"\n\nCritique round {self._round}.")
      print(f"\nQuestion: {self._question}")
      print("\nOpinions:")
      for i, opinion in enumerate(self._opinions):
        print(f"\tCitizen {i + 1}: {opinion}")
      if self._round > 0:
        print(f"\nPrevious winner: {self._previous_winners[-1]}")
        print("\nCritiques:")
        for i, critique in enumerate(self._critiques[-1]):
          print(f"\tCitizen {i + 1}: {critique}")

    statements, statement_explanations = self._generate_statements()

    if self._verbose:
      print("\nStatements generated:")
      for i, statement in enumerate(statements):
        print(f"\tStatement {i+1}: {statement}")

    all_rankings, ranking_explanations = self._get_rankings(statements)
    if self._verbose:
      print("\nRankings:")
      for i, ranking in enumerate(all_rankings):
        print(
            f"\tCitizen {i + 1}:"
            f" {numerical_ranking_to_ordinal_text(ranking)}"
        )

    tied_social_ranking, untied_social_ranking = (
        self._social_choice_method.aggregate(
            all_rankings, seed=self._get_new_seed()
        )
    )

    if self._verbose:
      print("\nPotentially tied social ranking:")
      print(numerical_ranking_to_ordinal_text(tied_social_ranking))
      print("\nUntied social ranking:")
      print(numerical_ranking_to_ordinal_text(untied_social_ranking))

    # Pair every statement with its tied and untied social rank.
    statements_with_tied_rankings = [
        (statement, tied_social_ranking[idx])
        for idx, statement in enumerate(statements)
    ]
    statements_with_untied_rankings = [
        (statement, untied_social_ranking[idx])
        for idx, statement in enumerate(statements)
    ]

    # Order the candidates best-first by the untied social ranking.
    sorted_indices = np.argsort(untied_social_ranking)
    sorted_statements = [statements[i] for i in sorted_indices]
    sorted_statement_explanations = [
        statement_explanations[i] for i in sorted_indices
    ]
    winner = sorted_statements[0]

    if self._verbose:
      print(f"\nWinning statement: {winner}")

    # Commit the round only now that every fallible step has succeeded.
    self._previous_tied_rankings.append(statements_with_tied_rankings)
    self._previous_untied_rankings.append(statements_with_untied_rankings)
    self._ranking_explanations.append(ranking_explanations)
    self._statement_explanations.append(sorted_statement_explanations)
    self._previous_winners.append(winner)
    self._previous_candidates.append(sorted_statements)
    self._round += 1
    return winner, sorted_statements

In [None]:
# @title Functions for names, opinions, and critiques

def get_number_of_people():
    """Asks how many people are in the group and returns the answer as an int.

    Re-prompts until the answer is a positive whole number, so a typo does
    not abort the cell with a traceback.
    """
    while True:
        answer = input("How many people are in the group? ")
        try:
            count = int(answer)
        except ValueError:
            print(f"Please enter a whole number, got {answer!r}.")
            continue
        if count > 0:
            return count
        print("The group needs at least one person.")

def get_and_number_names(num_names):
    """Prompts for `num_names` names with numbered prompts; returns them in order."""
    return [
        input(f"Enter name {number}: ")
        for number in range(1, num_names + 1)
    ]


def print_numbered_names(name_list):
  """Prints each name as 'Citizen <n>: <name>' under a header.

  Nothing is printed for an empty (or None) list.
  """
  if not name_list:
    return
  print("\nThe names you entered are:")
  for number, name in enumerate(name_list, start=1):
    print(f"Citizen {number}: {name}")


def get_opinions(question, names):
  """Prints the question, then prompts for each person's opinion on it.

  Args:
    question: The question, printed before prompting.
    names: Participant names, in citizen order.

  Returns:
    A list of opinions aligned with `names` (opinions[i] is names[i]'s).
  """
  print(f'The question is: {question}')
  # A list (not a dict) keeps the opinions in the same order as `names`, so
  # the two can be zipped together.
  return [input(f"What is {name}'s opinion? ") for name in names]


def print_numbered_opinions(opinions, names):
  """Prints each opinion as 'Citizen <n> (<name>): <opinion>' under a header.

  Opinions are paired with `names` positionally, and nothing is printed when
  `opinions` is empty.
  """
  if not opinions:
    return
  print("\nThe opinions you entered are:")
  for number, (name, opinion) in enumerate(zip(names, opinions), start=1):
    print(f"\nCitizen {number} ({name}): {opinion}")


def get_critiques(question, statement, names):
  """Prints the question and winning statement, then prompts for critiques.

  Args:
    question: The question, printed first for context.
    statement: The winning statement being critiqued, printed after it.
    names: Participant names, in citizen order.

  Returns:
    A list of critiques aligned with `names` (critiques[i] is names[i]'s).
  """
  print(f'The question is: {question}')
  print(f'The initial winning statement is: {statement}')
  # A list (not a dict) keeps the critiques in the same order as `names`, so
  # the two can be zipped together.
  return [input(f"What is {name}'s critique? ") for name in names]


def print_numbered_critiques(critiques, names):
  """Prints each critique as 'Citizen <n> (<name>): <critique>' under a header.

  Critiques are paired with `names` positionally, and nothing is printed when
  `critiques` is empty.
  """
  if not critiques:
    return
  print("\nThe critiques you entered are:")
  for number, (name, critique) in enumerate(zip(names, critiques), start=1):
    print(f"\nCitizen {number} ({name}): {critique}")

In [None]:
# @title Nice statement printing

def _statement_markdown(statement, width, title):
    """Build the Markdown shown by nice_statement_printing: bold title + fenced, wrapped text."""
    text = str(statement)
    # Wrap the statement itself, one source line at a time, *before* fencing it.
    # Running textwrap.fill over the finished Markdown would instead turn every
    # newline into a space: the ``` fences would land on the title line and the
    # statement's paragraphs would merge, so the wrapping never showed up.
    wrapped = "\n".join(
        textwrap.fill(line, width=width) if line.strip() else ""
        for line in text.splitlines()
    )
    # Use a fence longer than any backtick run in the text, so a statement that
    # itself contains ``` cannot close the code block early.
    longest_run = max((len(run) for run in re.findall(r"`+", wrapped)), default=0)
    fence = "`" * max(3, longest_run + 1)
    return f"**{title}**\n\n{fence}\n{wrapped}\n{fence}"


def nice_statement_printing(statement, width, title="Winning statement:"):
    """Display `statement` under a bold title as a fixed-width, wrapped code block.

    Args:
      statement: Text to show (e.g. a winning statement); converted with str().
      width: Maximum line length, in characters, of the wrapped statement.
      title: Heading rendered in bold above the statement.

    Returns:
      None. The result is rendered in the notebook via IPython `display`.
    """
    display(Markdown(_statement_markdown(statement, width, title)))

# Configuration

In [None]:
# @title Set API key

# SECURITY: never hard-code API keys in a notebook. The key that was embedded
# here in plain text must be treated as leaked; revoke it in Google AI Studio.
# Provide the key through the GOOGLE_API_KEY environment variable, or type it
# at the hidden prompt below.
if not os.environ.get('GOOGLE_API_KEY'):
  import getpass  # only needed when the key is not already set
  _api_key = getpass.getpass('Enter your Google AI Studio API key: ').strip()
  if not _api_key:
    raise ValueError('A Google AI Studio API key is required to run this notebook.')
  os.environ['GOOGLE_API_KEY'] = _api_key
  del _api_key  # don't keep the secret around as a notebook global

# Model, client and social-choice setup lives in the next cell
# ("Set model, candidates, and methods"); it is intentionally not repeated here.

In [None]:
# @title Set model, candidates, and methods

# Shared configuration, reused by both questions below.
MODEL = 'gemini-1.5-pro'
NUM_CANDIDATES = 5  # passed to HabermasMachine as num_candidates

# Separate clients for the statement model and the reward (ranking) model;
# both talk to the same underlying MODEL.
statement_client = AIStudioClient(model_name=MODEL)
reward_client = AIStudioClient(model_name=MODEL)

statement_model, reward_model = COTModel(), COTRankingModel()
social_choice_method = Schulze(tie_breaking_method=TieBreakingMethod.TBRC)

# Question 1

In [None]:
# @title Select Question and set up group (if this doesn't immediately query you for the group size, stop and retry)
# Colab form field: choose the question this group deliberates on.
QUESTION_1 = 'Should the UK government prioritize policies that actively encourage AI development in key industries, even if it means potentially delaying significant investment in addressing climate change?' # @param [ 'Should the UK government prioritize policies that actively encourage AI development in key industries, even if it means potentially delaying significant investment in addressing climate change?', 'Should the UK introduce a Universal Basic Income (UBI) for all citizens, regardless of their employment status?']

# Interactively ask for the group size, then each participant's name, and echo
# the numbered list back for confirmation.
num_citizens = get_number_of_people()
names = get_and_number_names(num_citizens)
print_numbered_names(names)

In [None]:
# @title Initialize the Habermas Machine

# Group-specific inputs first, then the shared clients/models/voting method
# configured in the "Set model, candidates, and methods" cell.
hm_1 = HabermasMachine(
    question=QUESTION_1,
    num_citizens=num_citizens,
    num_candidates=NUM_CANDIDATES,
    statement_client=statement_client,
    statement_model=statement_model,
    reward_client=reward_client,
    reward_model=reward_model,
    social_choice_method=social_choice_method,
    verbose=True,
)

In [None]:
# @title Get opinions

# One free-text opinion per participant, in the same order as `names`.
opinions = get_opinions(question=QUESTION_1, names=names)
print_numbered_opinions(opinions=opinions, names=names)

In [None]:
# @title Generate initial statement (~2 min).

# First mediation round over the opinions; keep only the winning statement.
winner, _ = hm_1.mediate(opinions)

print('\n' * 5)  # visual gap between the (verbose) mediation output and the result
nice_statement_printing(winner, width=100, title="Initial winner:")

In [None]:
# @title Get critiques

# Each participant critiques the initial winning statement.
critiques = get_critiques(question=QUESTION_1, statement=winner, names=names)
print_numbered_critiques(critiques=critiques, names=names)

In [None]:
# @title Generate revised statement (~2 min).

# Second mediation round, this time over the critiques; keep only the winner.
winner, _ = hm_1.mediate(critiques)

print('\n' * 5)  # visual gap between the (verbose) mediation output and the result
nice_statement_printing(winner, width=100, title="Final winner:")

# Question 2

In [None]:
# @title Select Question and set up group (if this doesn't immediately query you for the group size, stop and retry)
# Colab form field: choose the question this group deliberates on.
QUESTION_2 = 'Should the UK introduce a Universal Basic Income (UBI) for all citizens, regardless of their employment status?' # @param [ 'Should the UK government prioritize policies that actively encourage AI development in key industries, even if it means potentially delaying significant investment in addressing climate change?', 'Should the UK introduce a Universal Basic Income (UBI) for all citizens, regardless of their employment status?']

# Interactively ask for the group size and names for this second group.
# NOTE: this rebinds the same `num_citizens` and `names` globals that the
# Question 1 cells use, so re-running any Question 1 cell after this one will
# operate on this group's names.
num_citizens = get_number_of_people()
names = get_and_number_names(num_citizens)
print_numbered_names(names)

In [None]:
# @title Initialize the Habermas Machine

# Same shared configuration as hm_1, but for QUESTION_2 and the second group.
hm_2 = HabermasMachine(
    question=QUESTION_2,
    num_citizens=num_citizens,
    num_candidates=NUM_CANDIDATES,
    statement_client=statement_client,
    statement_model=statement_model,
    reward_client=reward_client,
    reward_model=reward_model,
    social_choice_method=social_choice_method,
    verbose=True,
)

In [None]:
# @title Get opinions

# One free-text opinion per participant, in the same order as `names`.
opinions_2 = get_opinions(question=QUESTION_2, names=names)
print_numbered_opinions(opinions=opinions_2, names=names)

In [None]:
# @title Generate initial statement.

# First mediation round over the opinions; keep only the winning statement.
winner_2, _ = hm_2.mediate(opinions_2)

print('\n' * 5)  # visual gap between the (verbose) mediation output and the result
nice_statement_printing(winner_2, width=100, title="Initial winner:")

In [None]:
# @title Get critiques

# Each participant critiques the initial winning statement.
critiques_2 = get_critiques(question=QUESTION_2, statement=winner_2, names=names)
print_numbered_critiques(critiques=critiques_2, names=names)

In [None]:
# @title Generate revised statement.

# Second mediation round, this time over the critiques; keep only the winner.
winner_2, _ = hm_2.mediate(critiques_2)

print('\n' * 5)  # visual gap between the (verbose) mediation output and the result
nice_statement_printing(winner_2, width=100, title="Final winner:")

# License and disclaimer

Copyright 2024 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
you may not use this file except in compliance with the Apache 2.0 license.
You may obtain a copy of the Apache 2.0 license at:
https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0
International License (CC-BY). You may obtain a copy of the CC-BY license at:
https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and
materials distributed here under the Apache 2.0 or CC-BY licenses are
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
either express or implied. See the licenses for the specific language governing
permissions and limitations under those licenses.

This is not an official Google product.

If you use or build on this code, please cite: Tessler, M. H., Bakker, M. A., Jarrett, D., Sheahan, H., Chadwick, M. J., Koster, R., Evans, G., Campbell-Gillingham, J., Collins, T., Parkes, D. C., Botvinick, M., and Summerfield, C. "AI can help humans find common ground in democratic deliberation." *Science*. (2024).