##### Copyright 2025 Google LLC.
Licensed under the Apache 2.0 License.

In [None]:
# @title Licensed under the Apache 2.0 License (the "License"); { display-mode: "form" }
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## PH-LLM Professional Exam Evaluation

In [None]:
# @title Import
import collections
from collections import Counter
import re
import time
from typing import Any, Callable, List, Optional, Set, Tuple
import numpy as np
import pandas as pd
from saxml.client.python import sax
import copy
import ast

In [None]:
# @title Utilities and MCQ processing.
def _get_sax_model(
    sax_address,
    use_proxy=False,
    num_conn: int = 1,
) -> sax.LanguageModel:
  """Returns a SAX language-model handle for `sax_address`.

  Args:
    sax_address: Address of the SAX model to connect to.
    use_proxy: If True, route the connection through the proxy address set in
      the Colab form field inside this function.
    num_conn: Number of connections the SAX client opens to the server.

  Returns:
    The `.LM()` handle of the SAX model. The evaluation functions in this file
    check it with `isinstance(model, sax.LanguageModel)`.
  """
  opts = sax.Options()
  if use_proxy:
    opts.proxy_addr = 'sax_proxy_address' # @param {type: "string"}
  opts.num_conn = num_conn
  return sax.Model(sax_address, opts).LM()


def find_majority_vote_answer(
    dict_list: List[dict[str, Any]]
) -> Optional[Tuple[str, Set[str]]]:
  """Returns the unique most-common valid answer across self-consistency rounds.

  Args:
    dict_list: Per-round result dicts, each with at least the keys
      'model_answer' and 'model_generations'. Rounds whose 'model_answer' is
      not one of '(A)'..'(E)' are ignored when counting votes.

  Returns:
    `(model_answer, model_generations)` taken from the first round whose answer
    is the single most frequent valid answer, or None when there are no valid
    answers or the two most frequent answers are tied.
  """
  # Tuple membership (compares with ==) rather than a set: unparseable rounds
  # may carry a non-hashable value (e.g. a dict) as their 'model_answer'.
  valid_answers = ('(A)', '(B)', '(C)', '(D)', '(E)')
  counts = Counter(
      entry['model_answer']
      for entry in dict_list
      if entry['model_answer'] in valid_answers
  )
  most_common = counts.most_common(2)
  if not most_common:
    return None
  if len(most_common) > 1 and most_common[0][1] == most_common[1][1]:
    # Tie between the top two answers: no clear winner.
    return None
  winner = most_common[0][0]
  for entry in dict_list:
    if entry['model_answer'] == winner:
      return (entry['model_answer'], entry['model_generations'])
  return None


def _postprocess_generation_answer(generations: set[str]) -> str:
  """Process the generated answers."""
  answer_re = re.compile(r'<answer>(\([ABCDE]\))</answer>', re.IGNORECASE)
  answers = []
  for gen in generations:
    gen = gen.strip()
    matcher = answer_re.search(gen)
    if matcher:
      answers.append(matcher.group(1).upper())
  # If no generation yielded a valid formatted Answer, flag as skipped.
  if not answers:
    return {_MODEL_GEN: generations, _SKIPPED: 1}
  # This extracts a list of most common answers within the xml tags
  # <answer></answer>, takes the first entry (in case of ties), and then
  # extracts the answer text (the second entry in the pair is the number of
  # times it appeared).
  model_answer = collections.Counter(answers).most_common(1)[0][0]
  return model_answer

def add_instruction_to_prompt(
    samples: list[dict[str, Any]], domain: str
) -> list[dict[str, Any]]:
  """Returns deep copies of `samples` with their input field rewritten.

  The domain's input field is replaced by the output of
  `mcq_prompt_lib.create_prompt_to_generate_mcqs`, which receives the original
  input, the domain's choices field and `domain`. The field names come from
  `mcq_constants`.

  Args:
    samples: Examples containing the domain's input and choices fields.
    domain: Either 'Sleep' or 'Fitness'.

  Returns:
    New sample dicts; the caller's samples are left untouched.

  Raises:
    ValueError: If `domain` is neither 'Sleep' nor 'Fitness'.
  """
  if domain == 'Sleep':
    input_key, choices_key = (
        mcq_constants.SLEEP_MCQ_INPUTS_FEATURE_NAME,
        mcq_constants.SLEEP_MCQ_EVAL_LABELS_FEATURE_NAME,
    )
  elif domain == 'Fitness':
    input_key, choices_key = (
        mcq_constants.FITNESS_MCQ_INPUTS_FEATURE_NAME,
        mcq_constants.FITNESS_MCQ_EVAL_LABELS_FEATURE_NAME,
    )
  else:
    raise ValueError(f'Invalid domain: {domain}')

  def _with_instruction(original: dict[str, Any]) -> dict[str, Any]:
    # Deep-copy so nested values in the caller's sample are never shared.
    updated = copy.deepcopy(original)
    updated[input_key] = mcq_prompt_lib.create_prompt_to_generate_mcqs(
        updated[input_key], updated[choices_key], domain
    )
    return updated

  return [_with_instruction(sample) for sample in samples]


In [None]:
def read_mcq_dataset(
    dataset_path: str, domain: str, difficulty_level: Optional[list[str]] = None
) -> pd.DataFrame:
  """Loads the MCQ CSV and keeps only the rows for one domain.

  Args:
    dataset_path: Path to a CSV with `choices` and `domain` columns (plus
      `difficulty` when filtering by difficulty). `choices` holds strings that
      are Python literals.
    domain: Value the `domain` column must equal.
    difficulty_level: Optional collection of allowed `difficulty` values. When
      None or empty, no difficulty filtering is applied.

  Returns:
    The matching rows, with their original index, and with `choices` parsed
    by `ast.literal_eval`.
  """
  with open(dataset_path, 'r') as csv_file:
    mcq_frame = pd.read_csv(csv_file)

  # literal_eval only evaluates Python literals, so it is safe on file input.
  mcq_frame['choices'] = mcq_frame['choices'].apply(ast.literal_eval)

  keep = mcq_frame['domain'] == domain
  if difficulty_level:
    keep = keep & mcq_frame['difficulty'].isin(difficulty_level)
  return mcq_frame[keep]


In [None]:
# @title Model Evaluation


def evaluate_model(
    sax_address: str,
    dataset_path: str,
    domain: str,
    eval_func: Callable[
        [
            sax.LanguageModel,
            str,
            dict[str, str],
            Optional[str],
            Optional[float],
            Optional[int],
            Optional[int],
        ],
        dict[str, int],
    ],
    num_examples: int = -1,
    num_replicas: int = 1,
    prompt_type: Optional[str] = None,
    use_proxy: bool = False,
    temperature: float = 0.0,
    max_decoding_steps: int = 2048,
    sc_round: int = 1,
    mcq_difficulty_level: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
  """Runs `eval_func` on each selected MCQ example and returns the results.

  Args:
    sax_address: Path to the SAX model address.
    dataset_path: Path to the dataset of MCQ example questions.
    domain: The domain of the MCQ dataset ('Sleep' or 'Fitness').
    eval_func: Function used to evaluate the model specified at `sax_address`.
      Called with keyword arguments model, domain, features, prompt_type,
      temperature, max_decoding_steps and sc_round.
    num_examples: Number of examples to evaluate. If <0, evaluates all examples.
    num_replicas: Number of model replicas available. This many connections are
      opened to the server and this many examples are evaluated concurrently.
    prompt_type: The type of prompt to use (e.g., CoT or Step-Back).
    use_proxy: Whether to use the proxy to connect to the model.
    temperature: The temperature to use for the llm model.
    max_decoding_steps: The maximum number of decoding steps to run.
    sc_round: The round of self-consistency.
    mcq_difficulty_level: Difficulty levels to keep. Only applied to the
      'Sleep' domain; 'Fitness' questions are never filtered by difficulty.

  Returns:
    One dict per evaluated example, in dataset order, holding the keys returned
    by `eval_func` merged with that example's dataset columns.

  Raises:
    ValueError: If `domain` is not 'Sleep' or 'Fitness', or if `eval_func`
      returns a key that collides with a dataset column.
  """
  # Local import keeps the notebook's import cell unchanged.
  import concurrent.futures

  start_time = time.time()
  model = _get_sax_model(
      sax_address, num_conn=num_replicas, use_proxy=use_proxy
  )
  if domain == 'Sleep':
    difficulty_level = mcq_difficulty_level
  elif domain == 'Fitness':
    # Fitness questions are not filtered by difficulty. This must be None: the
    # string 'None' is truthy and `Series.isin` rejects a bare string.
    difficulty_level = None
  else:
    raise ValueError(f'Unknown domain: {domain}')
  feature_dicts = read_mcq_dataset(
      dataset_path, domain=domain, difficulty_level=difficulty_level
  )
  if num_examples < 0:
    num_examples = len(feature_dicts)
  examples_to_evaluate = [
      row.to_dict()
      for _, row in feature_dicts.iloc[:num_examples].iterrows()
  ]

  def _run_one_example(feats: dict[str, Any]) -> dict[str, Any]:
    """Evaluates one example and merges its dataset columns into the result."""
    inputs = feats.copy()
    res = eval_func(
        model=model,
        domain=domain,
        features=inputs,
        prompt_type=prompt_type,
        temperature=temperature,
        max_decoding_steps=max_decoding_steps,
        sc_round=sc_round,
    )
    collisions = set(res).intersection(inputs)
    if collisions:
      raise ValueError(
          'eval_func returned keys that collide with dataset columns: '
          f'{sorted(collisions)}'
      )
    res.update(inputs)
    return res

  # The model handle was opened with `num_replicas` connections, so issue that
  # many requests concurrently. `map` preserves input order and re-raises a
  # worker's exception when its result is consumed.
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=max(1, num_replicas)
  ) as executor:
    retval = list(executor.map(_run_one_example, examples_to_evaluate))
  print(
      f'Evaluated {dataset_path} with {sax_address} in'
      f' {time.time() - start_time} seconds using {num_replicas} workers.',
      flush=True,
  )
  return retval


################################################################################
# Methods for evaluating MCQs.
################################################################################

# Potential outcomes from evaluating the model on the question.
# Result dicts use these as keys with the value 1; `_accuracy` sums the
# `_CORRECT` and `_INCORRECT` entries across all results.
_CORRECT = 'correct'
_INCORRECT = 'incorrect'
# Set (alongside `_MODEL_GEN`) by `_postprocess_generation_answer` when no
# generation contains a parseable `<answer>(X)</answer>` tag.
_SKIPPED = 'skipped'
# Placeholder stored under `_MODEL_GEN` when `eval_generate_sc` finds no unique
# winning answer and falls back to `eval_score`. NOTE(review): this single set
# object is shared by every such result dict; do not mutate it.
_NO_MAJORITY_VOTE = {'NO MAJORITY VOTE, USED lm.Score INSTEAD'}

# The answer the model provided (if not _SKIPPED).
_MODEL_ANSWER = 'model_answer'

# Relevant only for lm.Score -- the raw logprobs of each choice.
_MODEL_SCORES = 'model_scores'

# Relevant only for lm.Generate -- the generated text.
_MODEL_GEN = 'model_generations'


def eval_score(
    *,
    model: sax.LanguageModel,
    domain: str,
    features: dict[str, Any],
    prompt_type: Optional[str] = None,
    temperature: float = 0.0,
    max_decoding_steps: int = 5,
    sc_round=None,
) -> dict[str, Any]:
  """Answers one MCQ by picking the choice label lm.Score rates highest.

  The question is first wrapped with `add_instruction_to_prompt`. Each key of
  `features['choices']` is then scored as a suffix of that prompt, and the
  highest-scoring key becomes the model's answer.

  Args:
    model: SAX language model handle.
    domain: 'Sleep' or 'Fitness'.
    features: One example; needs 'question', 'choices' (a dict) and 'answer'.
    prompt_type: Ignored (kept for a uniform evaluator signature).
    temperature: Ignored.
    max_decoding_steps: Ignored.
    sc_round: Ignored.

  Returns:
    A dict with `_MODEL_ANSWER`, `_MODEL_SCORES` (one score per choice, in
    choice order) and one of `_CORRECT`/`_INCORRECT` set to 1.

  Raises:
    ValueError: If `model` is not a `sax.LanguageModel`.
  """
  # Generation-only knobs; scoring does not use them.
  del sc_round, prompt_type, temperature, max_decoding_steps
  prompted = add_instruction_to_prompt([features], domain=domain)
  full_question = prompted[0]['question']

  if not isinstance(model, sax.LanguageModel):
    raise ValueError(f'Unsupported model type: {type(model)}')
  sax_options = sax.ModelOptions()
  sax_options.SetTimeout(150)
  # Score each choice label separately as a suffix of the same prompt.
  scores = [
      score
      for choice in features['choices']
      for score in model.Score(full_question, [choice], sax_options)
  ]

  labels = list(features['choices'].keys())
  model_answer = labels[np.argmax(scores)]
  verdict = _CORRECT if model_answer == features['answer'] else _INCORRECT
  return {
      _MODEL_ANSWER: model_answer,
      _MODEL_SCORES: scores,
      verdict: 1,
  }


def _create_mcq_generate_prompt(
    mcq_question: str,
    mcq_options: dict[str, str],
    prompt_type: str,
    domain: str,
) -> str:
  """Converts a sleep MCQ question to a generate prompt."""
  if prompt_type == 'step_back' and domain == 'Sleep':
    return mcq_prompt_lib.SLEEP_TAKE_STEP_BACK_MCQ.format(
        mcq_options=', '.join(sorted(mcq_options)),
        mcq_question=mcq_question.strip(),
        domain=domain,
    )
  elif prompt_type == 'cot' and domain == 'Sleep':
    return mcq_prompt_lib.SLEEP_COT_MCQ.format(
        mcq_options=', '.join(sorted(mcq_options)),
        mcq_question=mcq_question.strip(),
        domain=domain,
    )
  elif prompt_type == 'cot' and domain == 'Fitness':
    return mcq_prompt_lib.FITNESS_COT_MCQ.format(
        mcq_options=', '.join(sorted(mcq_options)),
        mcq_question=mcq_question.strip(),
        domain=domain,
    )
  elif prompt_type == 'step_back' and domain == 'Fitness':
    return mcq_prompt_lib.FITNESS_TAKE_STEP_BACK_MCQ.format(
        mcq_options=', '.join(sorted(mcq_options)),
        mcq_question=mcq_question.strip(),
        domain=domain,
    )
  else:
    raise ValueError(
        f'Unsupported combination of prompt type and domain: {prompt_type} and'
        f' {domain=}.'
    )


def eval_generate(
    *,
    model: sax.LanguageModel,
    domain: str,
    features: dict[str, str],
    prompt_type: Optional[str] = None,
    temperature: float = 0.0,
    max_decoding_steps: int = 2048,
    sc_round=None,
) -> dict[str, Any]:
  """Answers one MCQ by parsing the `<answer>` tag out of lm.Generate output.

  Args:
    model: SAX language model handle.
    domain: 'Sleep' or 'Fitness'.
    features: One example; needs 'question', 'choices' and 'answer'.
    prompt_type: 'cot' or 'step_back' (see `_create_mcq_generate_prompt`).
    temperature: Sampling temperature passed to the model.
    max_decoding_steps: Per-example decode-step limit passed to the model.
    sc_round: Ignored (kept for a uniform evaluator signature).

  Returns:
    A dict with:
    - `_MODEL_ANSWER`: the parsed answer, e.g. '(B)', or None when no
      generation contained a parseable answer tag.
    - One of `_CORRECT`/`_INCORRECT` set to 1.
    - `_MODEL_GEN`: the unique generations.
    - `_SKIPPED` set to 1, only for unparseable output.

  Raises:
    ValueError: If `model` is not a `sax.LanguageModel`, or the
      prompt_type/domain combination is unsupported.
  """
  del sc_round  # unused.
  full_question = _create_mcq_generate_prompt(
      features['question'],
      features['choices'],
      prompt_type,
      domain,
  )
  if not isinstance(model, sax.LanguageModel):
    raise ValueError(f'Unsupported model type: {type(model)}')
  sax_options = sax.ModelOptions()
  # N.B.: Could consider modifying this. Right now this does multiple
  # generations and aggregates. But it could be changed to only do the most
  # likely (temperature=0) or other values.
  sax_options.SetExtraInput('temperature', temperature)
  sax_options.SetExtraInput(
      'per_example_max_decode_steps', max_decoding_steps
  )
  sax_options.SetTimeout(150)
  # Uniquify all the different generations created.
  generations = {gen for gen, _ in model.Generate(full_question, sax_options)}

  model_answer = _postprocess_generation_answer(generations)
  if isinstance(model_answer, dict):
    # `_postprocess_generation_answer` signals "no parseable answer" by
    # returning a dict carrying `_SKIPPED`. Record the skip explicitly instead
    # of storing that dict as the answer. The question still counts as
    # incorrect, so it stays in the accuracy denominator.
    return {
        _MODEL_ANSWER: None,
        _INCORRECT: 1,
        _SKIPPED: 1,
        _MODEL_GEN: generations,
    }
  return {
      _MODEL_ANSWER: model_answer,
      _CORRECT if model_answer == features['answer'] else _INCORRECT: 1,
      _MODEL_GEN: generations,
  }


def eval_generate_sc(
    *,
    model: sax.LanguageModel,
    domain: str,
    features: dict[str, str],
    prompt_type: Optional[str] = None,
    temperature: float = 0.0,
    max_decoding_steps: int = 2048,
    sc_round: int = 5,
) -> dict[str, Any]:
  """Answers one MCQ by self-consistency voting over repeated lm.Generate calls.

  Runs `eval_generate` `sc_round` times and keeps the answer selected by
  `find_majority_vote_answer`. If there is no unique winner, falls back to
  `eval_score`. In that case its per-choice scores are dropped and
  `_NO_MAJORITY_VOTE` is recorded as the generations.

  Returns:
    A dict with `_MODEL_ANSWER`, one of `_CORRECT`/`_INCORRECT` set to 1, and
    `_MODEL_GEN`.
  """
  rounds = []
  for _ in range(sc_round):
    rounds.append(
        eval_generate(
            model=model,
            domain=domain,
            features=features,
            prompt_type=prompt_type,
            temperature=temperature,
            max_decoding_steps=max_decoding_steps,
            sc_round=None,
        )
    )

  vote = find_majority_vote_answer(rounds)
  if vote:
    winning_answer, winning_generations = vote
    verdict = _CORRECT if winning_answer == features['answer'] else _INCORRECT
    return {
        _MODEL_ANSWER: winning_answer,
        verdict: 1,
        _MODEL_GEN: winning_generations,
    }

  # No clear winner: defer to likelihood scoring of the choices.
  fallback = eval_score(
      model=model,
      domain=domain,
      features=features,
      prompt_type=None,
      temperature=temperature,
      max_decoding_steps=max_decoding_steps,
      sc_round=sc_round,
  )
  fallback.pop(_MODEL_SCORES)
  fallback[_MODEL_GEN] = _NO_MAJORITY_VOTE
  return fallback

## Evaluation

In [None]:
def _accuracy(results: list[dict[str, Any]]) -> tuple[int, int, float]:
  """Returns (correct, incorrect, accuracy) tuple."""
  correct = sum(q.get(_CORRECT, 0) for q in results)
  incorrect = sum(q.get(_INCORRECT, 0) for q in results)
  acc = np.nan if correct + incorrect == 0 else correct / (correct + incorrect)
  return correct, incorrect, acc


def analyze_results(results: list[dict[str, Any]]) -> None:
  """Prints accuracy overall and, when available, per difficulty level.

  Args:
    results: Per-example result dicts as returned by `evaluate_model`. Either
      all or none of them must carry a 'difficulty' key.

  Raises:
    ValueError: If only some of the results carry a 'difficulty' key.
  """

  def _difficulty_label(value: Any) -> Any:
    # Missing difficulties (None, or NaN as pandas reads an empty CSV cell)
    # share one bucket. NaN != NaN, so grouping on it directly would create
    # empty, unmatched strata.
    if value is None or (isinstance(value, float) and np.isnan(value)):
      return 'unlabeled'
    return value

  stratifications = {'All': results}
  num_questions_with_difficulty = sum(int('difficulty' in q) for q in results)
  if num_questions_with_difficulty not in [0, len(results)]:
    raise ValueError(
        'Expected either all or none of the questions to be annotated with '
        f'difficulty, found {num_questions_with_difficulty}/{len(results)}.'
    )
  if num_questions_with_difficulty:
    by_difficulty = collections.defaultdict(list)
    for q in results:
      by_difficulty[_difficulty_label(q['difficulty'])].append(q)
    stratifications.update(by_difficulty)

  # Sort on the string form so mixed-type labels (e.g. 'All' next to a numeric
  # difficulty) cannot raise TypeError; for all-string labels the order is
  # the same as plain sorting.
  for diff, strat in sorted(
      stratifications.items(), key=lambda item: str(item[0])
  ):
    correct, incorrect, acc = _accuracy(strat)
    print(
        f'Accuracy for {diff} questions: {correct}/{correct + incorrect} ='
        f' {acc:.2f}'
    )


def save_results(results: list[dict[str, Any]], filename: str) -> None:
  """Writes per-example results to `filename` as CSV, including the row index.

  The file is opened with `newline=''`, as pandas recommends for text handles
  passed to `DataFrame.to_csv`. It also uses explicit UTF-8, so model
  generations containing non-ASCII text do not depend on the platform locale.

  Args:
    results: Per-example result dicts; each becomes one CSV row.
    filename: Destination path; an existing file is overwritten.
  """
  df_results = pd.DataFrame(results)
  with open(filename, 'w', encoding='utf-8', newline='') as f:
    df_results.to_csv(f, index=True)


def perform_full_evaluation(
    *,
    sax_address: str,
    dataset_path: str,
    domain: str,
    outroot: str | None = None,
    num_examples: int = -1,
    num_replicas: int = 1,
    prompt_type: Optional[str] = None,
    use_proxy: bool = False,
    use_eval_generate: bool = False,
    temperature: float = 0,
    max_decoding_steps: int = 2048,
    mcq_difficulty_level: Optional[list[str]] = None,
    sc_round: Optional[int] = None,
) -> list[dict[str, Any]]:
  """Evaluates one model, prints accuracy and optionally saves the results.

  The evaluator is chosen as follows:
  - `use_eval_generate` False: `eval_score` (lm.Score).
  - `use_eval_generate` True and `sc_round` truthy: `eval_generate_sc`
    (self-consistency).
  - Otherwise: `eval_generate` (lm.Generate).

  Args:
    sax_address: SAX model address; its last path component is appended to
      `outroot` when naming output files.
    dataset_path: Path to the MCQ CSV.
    domain: 'Sleep' or 'Fitness'.
    outroot: If set, results are saved as CSV under this prefix.
    num_examples: Number of examples to evaluate; <0 means all.
    num_replicas: Number of model connections / concurrent requests.
    prompt_type: Prompt style for lm.Generate ('cot' or 'step_back').
    use_proxy: Whether to connect through the SAX proxy.
    use_eval_generate: Use lm.Generate instead of lm.Score.
    temperature: Sampling temperature.
    max_decoding_steps: Decode-step limit for lm.Generate.
    mcq_difficulty_level: Difficulty levels to keep (Sleep only).
    sc_round: Number of self-consistency rounds.

  Returns:
    The per-example results from `evaluate_model`, in every mode.
  """
  if outroot:
    outroot += f'.{sax_address.split("/")[-1]}'

  common_kwargs = dict(
      sax_address=sax_address,
      dataset_path=dataset_path,
      domain=domain,
      num_examples=num_examples,
      num_replicas=num_replicas,
      use_proxy=use_proxy,
      temperature=temperature,
      mcq_difficulty_level=mcq_difficulty_level,
  )
  if not use_eval_generate:
    results = evaluate_model(eval_func=eval_score, **common_kwargs)
    title = '## Results for lm.Score evaluation. ##'
    suffix = '.score.csv'
  elif sc_round:
    results = evaluate_model(
        eval_func=eval_generate_sc,
        prompt_type=prompt_type,
        max_decoding_steps=max_decoding_steps,
        sc_round=sc_round,
        **common_kwargs,
    )
    title = f'## Results for self-consistency {prompt_type} evaluation. ##'
    suffix = f'.{prompt_type}.sc.csv'
  else:
    results = evaluate_model(
        eval_func=eval_generate,
        prompt_type=prompt_type,
        max_decoding_steps=max_decoding_steps,
        **common_kwargs,
    )
    title = f'## Results for lm.Generate {prompt_type} evaluation. ##'
    suffix = f'.{prompt_type}.csv'

  print(title)
  print('\n# Test data:')
  analyze_results(results)
  if outroot:
    save_results(results, outroot.format(split='test') + suffix)
  return results

In [None]:
# Model server(s) to evaluate; the cell below evaluates each address in turn.
SAX_ADDRESSES = [
    'YOUR_SAX_ADDRESS'
]
# Passed as `num_replicas`, i.e. the number of SAX connections to open.
g_num_model_replicas = 5  # @param {type:"integer"}
# Dataset domain; `evaluate_model` accepts only 'Sleep' or 'Fitness'.
g_domain = 'Sleep'  # @param ['Sleep', 'Fitness']
# Sampling temperature for lm.Generate (presumably non-zero so that the
# self-consistency rounds can produce different answers).
temperature = 0.7  # @param {type:"number"}
# Per-example decode-step limit for lm.Generate.
max_decoding_steps = 2048  # @param {type:"integer"}
# NOTE(review): not referenced by the evaluation cell below.
max_decoding_steps_score = 5
# Number of lm.Generate rounds whose answers are majority-voted.
sc_round = 3  # @param {type:"integer"}
# CSV read by `read_mcq_dataset` (needs `choices`, `domain`, `difficulty`).
dataset_path = './synthetic_mcq_data.csv' # @param
# Output prefix: files are named `<outroot>.<last part of SAX address><suffix>`,
# so '/tmp/' yields hidden files such as '/tmp/.<model>.cot.sc.csv'.
outroot = '/tmp/' # @param

## CoT + Self-Consistency: lm.Generate (majority vote), falling back to lm.Score

In [None]:
# Run CoT prompting with self-consistency for every configured model:
# `sc_round` lm.Generate rounds are majority-voted, and lm.Score is used when
# there is no unique winning answer.
for sax_address in SAX_ADDRESSES:
  _ = perform_full_evaluation(
      sax_address=sax_address,
      dataset_path=dataset_path,
      domain=g_domain, # or 'Fitness'
      outroot=outroot,
      # num_examples=3, # Only used for debugging.
      num_replicas=g_num_model_replicas,
      prompt_type='cot',
      use_proxy=True,
      use_eval_generate=True,
      temperature=temperature,
      max_decoding_steps=max_decoding_steps,
      # Sleep difficulty filter; `evaluate_model` does not use it when
      # g_domain is 'Fitness'. NOTE(review): `mcq_constants` is not imported
      # in the cells shown; confirm it is imported elsewhere.
      mcq_difficulty_level=[
          mcq_constants.SLEEP_MCQ_DIFF_LEVEL_EASY,
          mcq_constants.SLEEP_MCQ_DIFF_LEVEL_MODERATE,
          mcq_constants.SLEEP_MCQ_DIFF_LEVEL_HARD,
      ],
      sc_round=sc_round,
  )