# DeepSeek-R1-Distill-Qwen-14B Execution Prediction Notebook

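This notebook evaluates **deepseek-ai/DeepSeek-R1-Distill-Qwen-14B**, a reasoning model. Following the paper, we use its **zero-shot execution-prediction and execution-choice prompts** and collect five generations per problem to compute the OC/OR/MC/MR and execution-choice metrics.

- **OC / MC** are the rates at which a prediction for the original / mutated program matches that program's own output.
- **OR / MR** are the rates at which the prediction instead matches the *other* program's output (a "reversion").

By default, reversion metrics are skipped for problems whose outputs are booleans.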


## Step 1: Install Dependencies

In [None]:
!pip install -q transformers accelerate datasets tqdm

## Step 2: Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_ID = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B'
print('Loading model:', MODEL_ID)

# Weights are loaded in bfloat16; device_map='auto' lets transformers/accelerate
# place them on whatever device(s) are available.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map='auto')

# Prompts are moved to the device holding the model's first parameter. With a
# sharded model this is presumably the embedding layer's device — TODO confirm.
device = next(model.parameters()).device
print('✓ Model ready on', device)


## Step 3: Load Benchmark Dataset

In [None]:
from datasets import load_dataset

DATASET_REPO_ID = "asgaur/leetcode-contests-431-467-mutations-strat110"
SPLIT = "train"

print(f"Loading dataset from {DATASET_REPO_ID}...")
# No split is requested, so every split is loaded; later cells index by SPLIT.
dataset = load_dataset(DATASET_REPO_ID)
num_samples = len(dataset[SPLIT])
print(f"Total samples: {num_samples}")


## Step 4: Helper Functions

In [None]:
import json
import re
import ast
from typing import Dict

# Zero-shot instructions for single-program execution prediction. The model is
# asked to wrap its full assertion in [ANSWER] ... [/ANSWER] tags, which
# extract_answer_from_response parses.
ZERO_SHOT_EXECUTION_PROMPT = """You are given a Python program and an assertion containing an input to a function. Replace the ?? in the assertion with a literal (no unsimplified expressions, no function calls) representing the function's return value for the given input. Execute the program exactly as written, even if it is incorrect or incomplete. Provide the full assertion in [ANSWER] and [/ANSWER] tags."""

# Zero-shot instructions for execution choice: the model picks program A or B
# and answers with a small JSON object. Only the instructions are sent; the
# paper's figure caption is deliberately not part of the prompt, matching
# ZERO_SHOT_EXECUTION_PROMPT.
ZERO_SHOT_CHOICE_PROMPT = """You are given two Python programs below and an assertion containing an input to a function. First, choose either program, whichever one you are more confident in reasoning about. Then, replace the ?? in the assertion with a literal (no unsimplified expressions, no function calls) representing the function's return value for the given input on your chosen program. Execute the program exactly as written, even if it is incorrect or incomplete. For your final answer, output the letter of your chosen program (A or B) and the full assertion in the following json format:
{
"chosen_program": chosen_program_letter,
"assertion": full_assertion
}
"""

def _format_input_args(function_name: str, test_input) -> str:
    """Return the argument text to place inside ``function_name(...)``.

    A dataset input may already be written as a full call such as ``f(1, 2)``.
    In that case only the text between the outer parentheses is kept, so the
    assertion does not become ``f(f(1, 2))``. Any other input is used verbatim.
    """
    if test_input and test_input.startswith(f"{function_name}(") and test_input.endswith(")"):
        return test_input[len(function_name) + 1:-1]
    return test_input


def build_execution_prediction_prompt(sample: Dict, use_mutated: bool = False) -> str:
    """Build the zero-shot execution-prediction prompt for one sample.

    Args:
        sample: Dataset row. Reads ``function_name``, ``input`` and either
            ``code`` or ``mutated_code``.
        use_mutated: When True, show ``mutated_code`` instead of ``code``.

    Returns:
        ZERO_SHOT_EXECUTION_PROMPT followed by a [PYTHON] block containing the
        program and an ``assert f(args) == ??`` line.

    Raises:
        ValueError: If the selected program is missing or empty. Otherwise it
            would be rendered into the prompt as the text ``None``.
    """
    function_name = sample['function_name']
    code_key = 'mutated_code' if use_mutated else 'code'
    program = sample.get(code_key)
    if not program:
        raise ValueError(f"Sample has no {code_key!r} to build a prompt from.")
    input_args = _format_input_args(function_name, sample['input'])
    return (
        f"{ZERO_SHOT_EXECUTION_PROMPT}\n\n"
        f"[PYTHON]\n{program}\n"
        f"assert {function_name}({input_args}) == ??\n"
        f"[/PYTHON]"
    )


def build_execution_choice_prompt(sample: Dict, original_first: bool = True):
    """Build the zero-shot execution-choice prompt showing both programs.

    Args:
        sample: Dataset row. Reads ``function_name``, ``code``, ``input`` and
            ``mutated_code`` (falls back to ``code`` when missing/empty).
        original_first: When True the original program is PROGRAM_A;
            otherwise the mutated program is.

    Returns:
        ``(prompt, mapping)``, where ``mapping`` maps the letters 'A'/'B' to
        'original'/'mutated' for this ordering.
    """
    function_name = sample['function_name']
    original_code = sample['code']
    mutated_code = sample.get('mutated_code') or original_code
    input_args = _format_input_args(function_name, sample['input'])
    if original_first:
        program_a, program_b = original_code, mutated_code
        mapping = {'A': 'original', 'B': 'mutated'}
    else:
        program_a, program_b = mutated_code, original_code
        mapping = {'A': 'mutated', 'B': 'original'}
    question = (
        f"[PROGRAM_A]\n{program_a}\n[/PROGRAM_A]\n"
        f"[PROGRAM_B]\n{program_b}\n[/PROGRAM_B]\n"
        f"[ASSERTION]\nassert {function_name}({input_args}) == ??\n[/ASSERTION]"
    )
    prompt = f"{ZERO_SHOT_CHOICE_PROMPT}\n{question}"
    return prompt, mapping


def parse_execution_choice_response(response: str) -> Dict[str, str]:
    """Extract the ``{"chosen_program": ..., "assertion": ...}`` payload from a response.

    Every ``{"chosen_program"`` occurrence is a candidate. Candidates are tried
    from last to first, because reasoning models may restate the (non-JSON)
    answer template while thinking before giving the real answer. Each
    candidate is decoded with ``json.JSONDecoder.raw_decode``, so braces inside
    the assertion (e.g. dict outputs) do not cut the object short.

    If no candidate is valid JSON, the last one is salvaged with regexes
    (e.g. to handle an unquoted letter such as ``"chosen_program": A``).

    Raises:
        ValueError: If no candidate exists, or if the fallback cannot find
            both fields.
    """
    starts = [m.start() for m in re.finditer(r'\{\s*"chosen_program"', response)]
    if not starts:
        raise ValueError('Could not find JSON payload in response.')

    decoder = json.JSONDecoder()
    for start in reversed(starts):
        try:
            payload, _ = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            continue
        if isinstance(payload, dict):
            return payload

    # Best-effort salvage of the last (most likely final) candidate.
    tail = response[starts[-1]:]
    chosen_match = re.search(r'"chosen_program"\s*:\s*"?([A-Za-z])"?', tail)
    assertion_match = re.search(r'"assertion"\s*:\s*("(?:[^"\\]|\\.)*")', tail)
    if not chosen_match or not assertion_match:
        raise ValueError('Failed to parse execution choice JSON response.')
    assertion_literal = assertion_match.group(1)
    try:
        # JSON string escapes are a subset of Python's, so literal_eval unescapes them.
        assertion = ast.literal_eval(assertion_literal)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        assertion = assertion_literal.strip('"')
    return {'chosen_program': chosen_match.group(1), 'assertion': assertion}


def _assertion_rhs_from_ast(text: str):
    """Return the source of X in a lone ``assert <expr> == X`` statement, else None."""
    try:
        tree = ast.parse(text)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        return None
    if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Assert):
        return None
    test = tree.body[0].test
    if not (isinstance(test, ast.Compare) and len(test.ops) == 1 and isinstance(test.ops[0], ast.Eq)):
        return None
    segment = ast.get_source_segment(text, test.comparators[0])
    return segment.strip() if segment else None


def extract_output_from_assertion(assertion: str) -> str:
    """Return the right-hand side of ``assert <call> == <output>``.

    Optional [ASSERTION] / [/ASSERTION] wrappers are removed first. The text is
    then parsed with ``ast``, which handles nested parentheses in the call,
    trailing comments and assertion messages. If the text is not valid Python,
    a regex fallback is tried. If nothing matches, the stripped text itself is
    returned. ``None`` and blank input yield ``''``.
    """
    if assertion is None:
        return ''
    # JSON payloads may carry a non-string "assertion" value.
    text = str(assertion).strip()
    if not text:
        return ''
    text = re.sub(r'^\[ASSERTION\]\s*', '', text, flags=re.IGNORECASE)
    text = re.sub(r'\s*\[/ASSERTION\]$', '', text, flags=re.IGNORECASE)
    rhs = _assertion_rhs_from_ast(text)
    if rhs is not None:
        return rhs
    # Non-greedy call match so nested parentheses in the arguments are allowed.
    match = re.search(r'assert\s+[\w\.]+\(.*?\)\s*==\s*(.+)', text)
    if match:
        return match.group(1).strip()
    return text


def extract_answer_from_response(response: str) -> str:
    """Extract the predicted output literal from a model response.

    R1-style models write their reasoning first and close it with ``</think>``.
    Text after the last ``</think>`` is searched first, then the whole
    response. Within each region, the last ``[ANSWER]...[/ANSWER]`` block wins
    (its assertion is reduced to the right-hand side); failing that, the last
    bare ``assert f(...) == X`` line is used. Taking the last match keeps
    answers quoted during reasoning (e.g. restating the required format) from
    winning. If nothing matches, the whole stripped response is returned.
    """
    text = response or ''
    _, sep, tail = text.rpartition('</think>')
    search_spaces = [tail, text] if sep else [text]
    for space in search_spaces:
        tagged = re.findall(r'\[ANSWER\](.*?)\[/ANSWER\]', space, re.DOTALL | re.IGNORECASE)
        if tagged:
            return extract_output_from_assertion(tagged[-1].strip())
        # Non-greedy call match so nested parentheses in the arguments are allowed.
        bare = re.findall(r"assert\s+[\w.]+\(.*?\)\s*==\s*(.+?)(?:\n|$)", space, re.MULTILINE)
        if bare:
            return bare[-1].strip()
    return text.strip()

def check_predicted_output(predicted_output: str, expected_output: str):
    """Compare a predicted output literal with the expected one.

    Both sides are first compared as whitespace-stripped strings. If they
    differ, both are parsed with ``ast.literal_eval`` and compared as Python
    values (so ``[1,2]`` matches ``[1, 2]``). Unparseable model output counts
    as a mismatch rather than aborting the benchmark.

    Returns:
        ``(True, None)`` on a match, otherwise ``(False, message)`` describing
        both sides.
    """
    predicted = '' if predicted_output is None else str(predicted_output).strip()
    expected = '' if expected_output is None else str(expected_output).strip()
    if predicted == expected:
        return True, None
    try:
        predicted_val = ast.literal_eval(predicted)
        expected_val = ast.literal_eval(expected)
        if predicted_val == expected_val:
            return True, None
    # TypeError: e.g. unhashable set/dict members such as "{[1]}";
    # MemoryError/RecursionError: pathologically nested model output.
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass
    return False, f"Predicted: {predicted}, Expected: {expected}"


def is_boolean_output(value: str) -> bool:
    """Return True if *value* denotes a boolean (``True``/``False``, any case).

    Python literals are checked with ``ast.literal_eval``. Text that is not a
    valid literal falls back to a case-insensitive 'true'/'false' comparison.
    ``None`` is never boolean; an actual ``bool`` always is.
    """
    if value is None:
        return False
    if isinstance(value, bool):
        return True
    text = str(value).strip()
    try:
        return isinstance(ast.literal_eval(text), bool)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return text.lower() in {'true', 'false'}


## Step 5: Test Execution Prediction on One Sample


In [None]:
# The sampling constants are defined in the Step 6 configuration, which runs
# after this cell in top-to-bottom order. Reuse them if already defined;
# otherwise fall back to the same defaults.
SEED = globals().get('SEED', 42)
TEMPERATURE = globals().get('TEMPERATURE', 0.2)
TOP_P = globals().get('TOP_P', 0.95)
MAX_NEW_TOKENS = globals().get('MAX_NEW_TOKENS', 800)

test_sample = dataset[SPLIT][0]
prompt = build_execution_prediction_prompt(test_sample, use_mutated=False)
print('Prompt:\n', prompt)

torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
messages = [{"role": "user", "content": prompt}]
encoded = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors='pt',
)
encoded = {k: v.to(device) for k, v in encoded.items()}
outputs = model.generate(
    **encoded,
    do_sample=True,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_new_tokens=MAX_NEW_TOKENS,
)
# Decode only the newly generated tokens, not the echoed prompt.
generated = outputs[0, encoded['input_ids'].shape[-1]:]
response = tokenizer.decode(generated, skip_special_tokens=True)
print('\nModel response:\n', response)
prediction = extract_answer_from_response(response)
print('\nParsed prediction:', prediction)
print('Expected output :', test_sample['output'])


## Step 6: Execution Prediction Benchmark

In [None]:
import random
import time
from typing import Dict, List
import torch
import pandas as pd
from tqdm.auto import tqdm

# --- Problem selection ------------------------------------------------------
START_INDEX = 0
NUM_PROBLEMS = None  # None = every problem from START_INDEX to the end

# --- Sampling ---------------------------------------------------------------
NUM_GENERATIONS = 5
TEMPERATURE = 0.2
TOP_P = 0.95
# NOTE(review): R1-distilled models reason at length before answering; 800 new
# tokens may cut a generation off before its [ANSWER] tag. Consider raising
# this if many predictions come back unparsed.
MAX_NEW_TOKENS = 800

# --- Reproducibility / scoring ----------------------------------------------
SEED = 42
# When True, OR/MR are not computed for problems whose original or mutated
# output is a boolean (presumably because with only two possible values a
# reversion match is uninformative — confirm against the paper).
SKIP_BOOLEAN_FOR_REVERSION = True

random.seed(SEED)

def _compute_pass(counts: Dict[str, int]):
    return counts['success'] / counts['total'] if counts['total'] else None

def _decode_predictions(prompt: str, seed: int, *, max_new_tokens=None, temperature=None, top_p=None):
    """Sample one chat-formatted completion for *prompt* and time it.

    The torch RNGs (and all CUDA RNGs when available) are reseeded with *seed*
    before generating, so a call's sampling does not depend on earlier calls.
    Sampling settings default to the module-level MAX_NEW_TOKENS / TEMPERATURE
    / TOP_P. The keyword-only overrides let other benchmarks reuse this helper
    with their own settings.

    Returns:
        ``(response_text, latency_seconds)``. Only the newly generated tokens
        are decoded, with special tokens removed.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    messages = [{"role": "user", "content": prompt}]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors='pt',
    )
    encoded = {k: v.to(device) for k, v in encoded.items()}
    # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
    start_time = time.perf_counter()
    outputs = model.generate(
        **encoded,
        do_sample=True,
        temperature=TEMPERATURE if temperature is None else temperature,
        top_p=TOP_P if top_p is None else top_p,
        max_new_tokens=MAX_NEW_TOKENS if max_new_tokens is None else max_new_tokens,
    )
    latency = time.perf_counter() - start_time
    prompt_length = encoded['input_ids'].shape[-1]
    response = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    return response, latency


def _run_generation(prompt, seed, gen_idx, expected_output, other_output,
                    include_reversion, correct_key, reversion_key):
    """Decode and score one generation for a single program.

    *expected_output* is the output of the program shown in *prompt*.
    *other_output* is the other program's output; it is compared only when
    *include_reversion* is True (otherwise the reversion field is None).

    Returns:
        The per-generation record, with the correctness flags stored under
        *correct_key* / *reversion_key*.
    """
    response_text, latency = _decode_predictions(prompt, seed)
    prediction = extract_answer_from_response(response_text)
    correct, _ = check_predicted_output(prediction, expected_output)
    reverted = None
    if include_reversion:
        reverted, _ = check_predicted_output(prediction, other_output)
    return {
        'generation': gen_idx,
        'prediction': prediction,
        'response': response_text,
        'latency_s': latency,
        correct_key: bool(correct),
        reversion_key: bool(reverted) if isinstance(reverted, bool) else None,
    }


split = dataset[SPLIT]
stop = len(split) if NUM_PROBLEMS is None else min(len(split), START_INDEX + NUM_PROBLEMS)
problem_indices = list(range(START_INDEX, stop))
if not problem_indices:
    raise ValueError('No problems selected. Adjust START_INDEX/NUM_PROBLEMS.')

metrics_counts = {
    'OC': {'success': 0, 'total': 0},
    'OR': {'success': 0, 'total': 0},
    'MC': {'success': 0, 'total': 0},
    'MR': {'success': 0, 'total': 0},
}
reversion_skip_count = 0
skipped_problem_count = 0
prediction_records: List[Dict] = []
all_latencies: List[float] = []

for idx in tqdm(problem_indices, desc='Execution prediction'):
    sample = split[idx]
    if not sample.get('mutated_code'):
        skipped_problem_count += 1
        prediction_records.append({'problem_index': int(idx), 'problem_id': sample.get('id'), 'skipped': True, 'skip_reason': 'missing mutated_code'})
        continue

    original_output = sample['output']
    mutated_output = sample.get('mutated_output') or original_output
    include_reversion = True
    if SKIP_BOOLEAN_FOR_REVERSION and (is_boolean_output(original_output) or is_boolean_output(mutated_output)):
        include_reversion = False
        reversion_skip_count += 1

    original_prompt = build_execution_prediction_prompt(sample, use_mutated=False)
    mutated_prompt = build_execution_prediction_prompt(sample, use_mutated=True)

    original_predictions = []
    mutated_predictions = []
    # Per-problem seed block. The +500 offset keeps original and mutated seeds
    # distinct (assumes NUM_GENERATIONS <= 500).
    seed_base = SEED + idx * 1000
    for gen_idx in range(NUM_GENERATIONS):
        original_record = _run_generation(
            original_prompt, seed_base + gen_idx, gen_idx,
            original_output, mutated_output, include_reversion, 'oc_correct', 'or_correct',
        )
        mutated_record = _run_generation(
            mutated_prompt, seed_base + 500 + gen_idx, gen_idx,
            mutated_output, original_output, include_reversion, 'mc_correct', 'mr_correct',
        )
        original_predictions.append(original_record)
        mutated_predictions.append(mutated_record)
        all_latencies.extend([original_record['latency_s'], mutated_record['latency_s']])

    oc_successes = sum(rec['oc_correct'] for rec in original_predictions)
    mc_successes = sum(rec['mc_correct'] for rec in mutated_predictions)
    # Reversion flags are None when reversion is skipped; bool(None) is False.
    or_successes = sum(bool(rec['or_correct']) for rec in original_predictions)
    mr_successes = sum(bool(rec['mr_correct']) for rec in mutated_predictions)

    metrics_counts['OC']['success'] += oc_successes
    metrics_counts['OC']['total'] += NUM_GENERATIONS
    metrics_counts['MC']['success'] += mc_successes
    metrics_counts['MC']['total'] += NUM_GENERATIONS
    if include_reversion:
        metrics_counts['OR']['success'] += or_successes
        metrics_counts['OR']['total'] += NUM_GENERATIONS
        metrics_counts['MR']['success'] += mr_successes
        metrics_counts['MR']['total'] += NUM_GENERATIONS

    prediction_records.append({'problem_index': int(idx), 'problem_id': sample.get('id'), 'difficulty': sample.get('difficulty'), 'include_reversion': include_reversion, 'original_output': original_output, 'mutated_output': mutated_output, 'oc_successes': oc_successes, 'or_successes': or_successes if include_reversion else None, 'mc_successes': mc_successes, 'mr_successes': mr_successes if include_reversion else None, 'original_predictions': original_predictions, 'mutated_predictions': mutated_predictions})

metrics_summary = {metric: _compute_pass(counts) for metric, counts in metrics_counts.items()}
avg_latency = (sum(all_latencies) / len(all_latencies)) if all_latencies else None
benchmark_summary = {
    'dataset': DATASET_REPO_ID,
    # Problems skipped for missing mutated_code produce no generations, so they
    # are reported separately rather than counted as evaluated.
    'problems_evaluated': len(problem_indices) - skipped_problem_count,
    'problems_skipped': skipped_problem_count,
    'generations_per_problem': NUM_GENERATIONS,
    'oc_pass_at_1': metrics_summary['OC'],
    'or_pass_at_1': metrics_summary['OR'],
    'mc_pass_at_1': metrics_summary['MC'],
    'mr_pass_at_1': metrics_summary['MR'],
    'avg_latency_s': avg_latency,
    # Only incremented when SKIP_BOOLEAN_FOR_REVERSION is True.
    'reversion_skipped_problems': reversion_skip_count,
}

benchmark_table = pd.DataFrame([
    {'Metric': 'Dataset', 'Value': benchmark_summary['dataset']},
    {'Metric': 'Problems Evaluated', 'Value': benchmark_summary['problems_evaluated']},
    {'Metric': 'Problems Skipped', 'Value': benchmark_summary['problems_skipped']},
    {'Metric': 'Generations per Problem', 'Value': benchmark_summary['generations_per_problem']},
    {'Metric': 'OC pass@1', 'Value': metrics_summary['OC']},
    {'Metric': 'OR pass@1', 'Value': metrics_summary['OR']},
    {'Metric': 'MC pass@1', 'Value': metrics_summary['MC']},
    {'Metric': 'MR pass@1', 'Value': metrics_summary['MR']},
    {'Metric': 'Avg latency (s)', 'Value': benchmark_summary['avg_latency_s']},
    {'Metric': 'Reversion skipped', 'Value': benchmark_summary['reversion_skipped_problems']},
])
# Floats are shown with 3 decimals; strings, ints and None pass through.
formatters = {'Value': (lambda val: f"{val:.3f}" if isinstance(val, float) else val)}

print('✓ Execution prediction benchmark complete!')


## Step 7: Visualize Execution Prediction Metrics

In [None]:
import matplotlib.pyplot as plt
import numpy as np

if 'metrics_summary' not in globals():
    raise RuntimeError('Run the benchmark cell first.')

# Pass@1 rates as percentages. None marks a metric with no attempts
# (e.g. OR/MR when every selected problem had its reversion check skipped).
metrics = {name: metrics_summary.get(name) for name in ('OC', 'OR', 'MC', 'MR')}
labels = list(metrics)
values = [None if metrics[name] is None else metrics[name] * 100 for name in labels]
bar_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

plt.figure(figsize=(8, 4))
bars = plt.bar(labels, [0 if pct is None else pct for pct in values], color=bar_colors)
for bar, val in zip(bars, values):
    x_mid = bar.get_x() + bar.get_width() / 2
    if val is not None:
        plt.text(x_mid, bar.get_height(), f"{val:.1f}%", ha='center', va='bottom')
        continue
    # Missing metric: faded zero-height bar labelled N/A.
    bar.set_alpha(0.3)
    plt.text(x_mid, bar.get_height(), 'N/A', ha='center', va='bottom')
plt.ylim(0, 100)
plt.ylabel('pass@1 (%)')
plt.title('Execution Prediction Metrics (OC/OR/MC/MR)')
plt.show()
plt.close()

display(benchmark_table.style.format(formatters))


## Step 8: Save Execution Prediction Metrics

In [None]:
import json
import math
from datetime import datetime

if 'benchmark_summary' not in globals():
    raise RuntimeError('Run the benchmark cell first.')


def _clean(value):
    """Map NaN floats to None; return every other value (including None) unchanged."""
    if isinstance(value, float) and math.isnan(value):
        return None
    return value


timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_filename = f"execution_prediction_metrics_{timestamp}.json"

payload = {
    'benchmark_summary': {key: _clean(val) for key, val in benchmark_summary.items()},
    'metrics_summary': {key: _clean(val) for key, val in metrics_summary.items()},
    'metrics_counts': metrics_counts,
}

with open(output_filename, 'w') as out_file:
    json.dump(payload, out_file, indent=2)

print(f'✓ Saved metrics to {output_filename}')


## Step 9: Execution Choice Benchmark (Preference / Correctness / Reversion)

In [None]:
import pandas as pd
from tqdm.auto import tqdm

import time

import torch

NUM_PROBLEMS_CHOICE = 20
START_INDEX_CHOICE = 0
NUM_RUNS_PER_PROBLEM = 2  # at most 2: original-as-A first, then mutated-as-A
MAX_NEW_TOKENS_CHOICE = 800
TEMPERATURE_CHOICE = TEMPERATURE
TOP_P_CHOICE = TOP_P


def _choice_pass(success: int, total: int):
    """Return ``success / total``, or None when nothing was counted."""
    return success / total if total else None


def _decode_choice(prompt: str, seed: int):
    """Sample one response for *prompt* using the execution-choice settings.

    Uses the same decoding recipe as the Step 6 helper, but applies
    MAX_NEW_TOKENS_CHOICE / TEMPERATURE_CHOICE / TOP_P_CHOICE.

    Returns:
        ``(response_text, latency_seconds)``.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    encoded = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors='pt',
    )
    encoded = {k: v.to(device) for k, v in encoded.items()}
    start_time = time.perf_counter()
    outputs = model.generate(
        **encoded,
        do_sample=True,
        temperature=TEMPERATURE_CHOICE,
        top_p=TOP_P_CHOICE,
        max_new_tokens=MAX_NEW_TOKENS_CHOICE,
    )
    latency = time.perf_counter() - start_time
    generated = outputs[0, encoded['input_ids'].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True), latency


choice_split = dataset[SPLIT]
stop_choice = len(choice_split) if NUM_PROBLEMS_CHOICE is None else min(len(choice_split), START_INDEX_CHOICE + NUM_PROBLEMS_CHOICE)
choice_indices = list(range(START_INDEX_CHOICE, stop_choice))
if not choice_indices:
    raise ValueError('No problems selected for execution choice. Adjust START_INDEX_CHOICE/NUM_PROBLEMS_CHOICE.')

# Run k uses ordering k: True puts the original program in slot A.
orderings = [True, False]
selected_orderings = orderings[:NUM_RUNS_PER_PROBLEM]

execution_choice_counts = {
    'preference': {'original': 0, 'mutated': 0, 'total': 0},
    # OR/MR reversion counts are kept in the OC/MC buckets ('reversion_*').
    'OC': {'correct': 0, 'total': 0, 'reversion_correct': 0, 'reversion_total': 0},
    'MC': {'correct': 0, 'total': 0, 'reversion_correct': 0, 'reversion_total': 0},
    'invalid_runs': 0,
    'skipped_problems': 0,
}
execution_choice_results = []
choice_latencies = []

for idx in tqdm(choice_indices, desc='Execution choice'):
    sample = choice_split[idx]
    # Without a mutated program both choices would be the same code; skip the
    # problem, as the execution-prediction benchmark does.
    if not sample.get('mutated_code'):
        execution_choice_counts['skipped_problems'] += 1
        execution_choice_results.append({
            'problem_index': int(idx),
            'problem_id': sample.get('id'),
            'skipped': True,
            'skip_reason': 'missing mutated_code',
        })
        continue

    original_output = sample['output']
    mutated_output = sample.get('mutated_output') or original_output
    include_reversion = not (
        SKIP_BOOLEAN_FOR_REVERSION
        and (is_boolean_output(original_output) or is_boolean_output(mutated_output))
    )
    base_seed = SEED + idx * 1000

    for run_offset, original_first in enumerate(selected_orderings):
        prompt, mapping = build_execution_choice_prompt(sample, original_first=original_first)
        response_text, latency = _decode_choice(prompt, base_seed + run_offset)
        choice_latencies.append(latency)

        run_record = {
            'problem_index': int(idx),
            'problem_id': sample.get('id'),
            'function_name': sample.get('function_name'),
            'original_first': original_first,
            'response': response_text,
            'latency_s': latency,
            'include_reversion': include_reversion,
            'chosen_program_letter': None,
            'chosen_program_type': None,
            'assertion': None,
            'prediction': None,
            'correct_for_chosen_program': None,
            'reversion_for_other_program': None,
            'correctness_error': None,
            'reversion_error': None,
        }

        try:
            parsed = parse_execution_choice_response(response_text)
        except Exception as exc:  # any parse failure makes this run invalid, not fatal
            execution_choice_counts['invalid_runs'] += 1
            run_record['correctness_error'] = str(exc)
            execution_choice_results.append(run_record)
            continue

        chosen_letter = parsed.get('chosen_program')
        # Tolerate formatting variance such as "a" or " B ".
        if isinstance(chosen_letter, str):
            chosen_letter = chosen_letter.strip().upper()
        if not isinstance(chosen_letter, str) or chosen_letter not in mapping:
            execution_choice_counts['invalid_runs'] += 1
            run_record['correctness_error'] = 'Missing/invalid chosen_program in response.'
            execution_choice_results.append(run_record)
            continue

        # The JSON "assertion" may be null or a non-string value.
        assertion_text = parsed.get('assertion')
        if assertion_text is None:
            assertion_text = ''
        elif not isinstance(assertion_text, str):
            assertion_text = str(assertion_text)

        chosen_type = mapping[chosen_letter]
        predicted_output = extract_output_from_assertion(assertion_text)
        if chosen_type == 'original':
            chosen_output, other_output = original_output, mutated_output
        else:
            chosen_output, other_output = mutated_output, original_output

        is_correct, _ = check_predicted_output(predicted_output, chosen_output)
        reversion_flag = None
        if include_reversion:
            reversion_flag, _ = check_predicted_output(predicted_output, other_output)

        run_record.update({
            'chosen_program_letter': chosen_letter,
            'chosen_program_type': chosen_type,
            'assertion': assertion_text,
            'prediction': predicted_output,
            'correct_for_chosen_program': bool(is_correct),
            'reversion_for_other_program': bool(reversion_flag) if isinstance(reversion_flag, bool) else None,
        })
        execution_choice_results.append(run_record)

        execution_choice_counts['preference']['total'] += 1
        execution_choice_counts['preference'][chosen_type] += 1

        target = execution_choice_counts['OC' if chosen_type == 'original' else 'MC']
        target['total'] += 1
        if is_correct:
            target['correct'] += 1
        if include_reversion:
            target['reversion_total'] += 1
            if reversion_flag:
                target['reversion_correct'] += 1

preference_counts = execution_choice_counts['preference']
execution_choice_summary = {
    'preference_original': _choice_pass(preference_counts['original'], preference_counts['total']),
    'preference_mutated': _choice_pass(preference_counts['mutated'], preference_counts['total']),
    'oc_correct': _choice_pass(execution_choice_counts['OC']['correct'], execution_choice_counts['OC']['total']),
    'or_reversion': _choice_pass(execution_choice_counts['OC']['reversion_correct'], execution_choice_counts['OC']['reversion_total']),
    'mc_correct': _choice_pass(execution_choice_counts['MC']['correct'], execution_choice_counts['MC']['total']),
    'mr_reversion': _choice_pass(execution_choice_counts['MC']['reversion_correct'], execution_choice_counts['MC']['reversion_total']),
    'invalid_runs': execution_choice_counts['invalid_runs'],
    'skipped_problems': execution_choice_counts['skipped_problems'],
}

choice_metrics_table = pd.DataFrame([
    {'Metric': 'Preference (Original)', 'Value': execution_choice_summary['preference_original']},
    {'Metric': 'Preference (Mutated)', 'Value': execution_choice_summary['preference_mutated']},
    {'Metric': 'OC Correct', 'Value': execution_choice_summary['oc_correct']},
    {'Metric': 'OR Reversion', 'Value': execution_choice_summary['or_reversion']},
    {'Metric': 'MC Correct', 'Value': execution_choice_summary['mc_correct']},
    {'Metric': 'MR Reversion', 'Value': execution_choice_summary['mr_reversion']},
])

print('✓ Execution choice benchmark complete!')


## Step 10: Visualize Execution Choice Metrics

In [None]:
import matplotlib.pyplot as plt

if 'execution_choice_summary' not in globals():
    raise RuntimeError('Run the execution choice benchmark first.')

# (display label, summary key) pairs, in plotting order.
label_to_key = [
    ('Pref. Original', 'preference_original'),
    ('Pref. Mutated', 'preference_mutated'),
    ('OC Correct', 'oc_correct'),
    ('OR Reversion', 'or_reversion'),
    ('MC Correct', 'mc_correct'),
    ('MR Reversion', 'mr_reversion'),
]
choice_metrics = {label: execution_choice_summary.get(key) for label, key in label_to_key}
labels = list(choice_metrics)
values = [None if choice_metrics[k] is None else choice_metrics[k] * 100 for k in labels]
palette = ['#9467bd', '#8c564b', '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

plt.figure(figsize=(10, 4))
bars = plt.bar(labels, [0 if pct is None else pct for pct in values], color=palette)
for bar, val in zip(bars, values):
    x_mid = bar.get_x() + bar.get_width() / 2
    if val is not None:
        plt.text(x_mid, bar.get_height(), f"{val:.1f}%", ha='center', va='bottom')
        continue
    # No data for this rate: faded placeholder bar labelled N/A.
    bar.set_alpha(0.3)
    plt.text(x_mid, bar.get_height(), 'N/A', ha='center', va='bottom')
plt.ylim(0, 100)
plt.ylabel('Rate (%)')
plt.title('Execution Choice Metrics')
plt.show()
plt.close()

# Floats get 3 decimals; anything else (e.g. None) is shown as-is.
display(choice_metrics_table.style.format({'Value': lambda v: f"{v:.3f}" if isinstance(v, float) else v}))


## Step 11: Save Execution Choice Metrics

In [None]:
import json
import math
from datetime import datetime

if 'execution_choice_summary' not in globals():
    raise RuntimeError('Run the execution choice benchmark first.')


def _clean(value):
    """Map NaN floats to None; return every other value (including None) unchanged."""
    if isinstance(value, float) and math.isnan(value):
        return None
    return value


timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_filename = f"execution_choice_metrics_{timestamp}.json"

payload = {
    'execution_choice_summary': {key: _clean(val) for key, val in execution_choice_summary.items()},
    'execution_choice_counts': execution_choice_counts,
    'execution_choice_results': execution_choice_results,
}

with open(output_filename, 'w') as out_file:
    json.dump(payload, out_file, indent=2)

print(f'✓ Saved execution choice metrics to {output_filename}')
