In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [3]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import transformers
import torch.nn.functional as F
import random
from tqdm import tqdm
import subprocess as sp



In [4]:
#########
wd = "drive/MyDrive/gsm_test" # CHANGE TO YOUR WORKING DIRECTOR
#########

In [5]:
##########
file = open("drive/My Drive/gsm_test/HF_token.txt", "r") # REQUIRE HUGGINGFACE TOKEN TO ACCESS CERTAIN MODELS, SETUP TOKEN, SAVE TO TEXT
##########

HF_token = file.readline()
file.close()
os.environ["HF_TOKEN"] = HF_token

#!pip install -U "huggingface_hub[cli]"
!huggingface-cli login --token $HF_TOKEN --add-to-git-credential
!git config --global credential.helper store

Token is valid (permission: read).
The token `MyFirstToken` has been saved to /root/.cache/huggingface/stored_tokens
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.[0m
Token has not been saved to git credential helper.
Your token has been saved to /root/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [6]:
######### Specify a folder at Google drive to save downloaded pretrained models
###########
model_path = "drive/My Drive/Colab Notebooks/huggingface_models"
###########

def create_folder(dir):
    if not os.path.isdir(dir):
        os.mkdir(dir)

def get_gpu_memory():
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
    memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
    return memory_free_values

def fix_random_seed(seed, reproduce=False):
    # torch.backends.cudnn.enabled = True
    # torch.backends.cudnn.benchmark = True

    if reproduce:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        ## NOTE: uncomment for CUDA >= 10.2
        # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        ## NOTE: uncomment for pytorch >= 1.8
        # torch.use_deterministic_algorithms(True)

    random.seed(seed)
    np.random.seed(seed)
    rng = torch.manual_seed(seed)

    return rng

In [7]:
model_dict = {
    1: "mistral-7b",  # Checked
    2: "llama2-7b",
    3: "llama3-8b",
    4: "llama3-8B-Instruct",  # Checked
    5: "gemma2-9b-it",
    6: "qwen-1.5B",   # Checked
    7: "qwen-3B",  # Checked
    8: "qwen-7B",  # Checked
    ## The following are math reasoning models
    9: "llama3.1-8b",
    10: "qwen2.5-7b",  # Checked, all correct
    11: "mathstral-7b",  # Checked, incorrect on questions with most operands
    12: "deepseek-7b",  # Checked, incorrect on last two questions with most operands
}


In [8]:
from typing import Tuple, List, Dict
from typing import Union
from math import isclose
import math
from fractions import Fraction

def fmt_money(x: float) -> str:
    return f"${x:.2f}"

def safe_round_money(x: float) -> float:
    return round(x + 1e-12, 2)

###############################################################################
# Each generator now allows up to 6 extra clauses (k in 0..6)
# All "steps" logic removed; functions return only (question_text, answer)
###############################################################################

def gen_sheep(k: int, base_value: int = 20) -> Tuple[str, int]:
    """
    Chain order (left -> right):
      Toulouse = 2 * Charleston
      Charleston = 4 * Seattle
      Seattle = 2 * Alice
      Alice = 5 * Bob
      Bob = 3 * David
      (then optionally) David = 2 * Zoe, Zoe = 4 * John, John = 5 * Rachel

    k: how many of the extra_names to append after Seattle (0..6).
    base_value is the value assigned to the last person in the chain (extra_names[k-1] if k>0 else Seattle).
    """
    extra_names = ["Alice", "Bob", "David", "Zoe", "John", "Rachel"]  # keep this exactly
    extra_chain_mults = [2, 5, 3, 2, 4, 5]  # corresponds to relations between Seattle->Alice, Alice->Bob, ...

    assert 0 <= k <= len(extra_names), f"k must be 0..{len(extra_names)}"

    core_names = ["Toulouse", "Charleston", "Seattle"]
    core_mults = [2, 4]  # Toulouse = 2 * Charleston; Charleston = 4 * Seattle

    names = core_names + extra_names[:k]
    multipliers = core_mults + extra_chain_mults[:k]

    clauses = ""
    for i in range(len(names) - 1):
        A = names[i]
        B = names[i + 1]
        M = multipliers[i]
        clauses += f"{A} has {M} times as many sheep as {B}. "

    last_name = names[-1]
    values: Dict[str, int] = {}
    values[last_name] = int(base_value)

    for i in range(len(names) - 2, -1, -1):
        left = names[i]
        right = names[i + 1]
        M = multipliers[i]
        values[left] = int(values[right] * M)

    present_names = ["Toulouse", "Charleston", "Seattle"] + extra_names[:k]
    total = sum(values[name] for name in present_names)

    def nice_join(lst: List[str]) -> str:
        if len(lst) == 1:
            return lst[0]
        return ", ".join(lst[:-1]) + " and " + lst[-1]

    q = clauses + "How many sheep do " + nice_join(present_names) + f" have together if {last_name} has {base_value} sheep?"
    return q, total

def gen_dvds(k: int) -> Tuple[str, int]:
    groups_customers = [4, 2, 8, 3, 7, 5]  # 6 groups
    groups_DVDs =      [3, 8, 9, 2, 1, 9]
    assert 0 <= k <= len(groups_customers)
    clauses = ""
    total = 0
    total += 3 * 1
    total += 2 * 2
    for i in range(k):
        n = groups_customers[i]
        m = groups_DVDs[i]
        item = "DVDs" if m > 1 else "DVD"
        clauses += f"His next {n} customers buy {m} {item} each. "
        total += n * m
    q = (
        "Billy sells DVDs. He has 8 customers on Tuesday. "
        "His first 3 customers buy one DVD each. "
        "His next 2 customers buy 2 DVDs each. "
        + clauses +
        "His last 3 customers don't buy any DVDs. "
        "How many DVDs did Billy sell on Tuesday?"
    )
    return q, total

def gen_clothes(k: int) -> Tuple[str, float]:
    items = ["suspenders", "socks", "sunglasses", "gloves", "earrings", "slippers"]  # 6 items
    num_pairs = [2, 3, 1, 3, 4, 5]
    prices =    [20.50, 5.00, 30.00, 21.50, 94.50, 20.50]
    assert 0 <= k <= len(items)

    total = 0.0
    total += 3 * 16.5
    total += 3 * 22.5
    total += 3 * 42.0

    clauses = ""
    clauses2 = ""
    for i in range(k):
        n = num_pairs[i]
        item = items[i]
        p = prices[i]
        clauses += f"{n} pairs of {item}, "
        clauses2 += f"One pair of {item} costs ${p}. "
        subtotal = n * p
        total += subtotal

    q = (
        "Mishka bought 3 pairs of shorts, " + clauses + "3 pairs of pants, and 3 pairs of shoes. "
        "One pair of shorts costs $16.50. " + clauses2 +
        "One pair of pants costs $22.50 and one pair of shoes costs $42. "
        "How many dollars did Mishka spend on all the clothing items?"
    )
    total = safe_round_money(total)
    return q, total

def gen_rainwater(k: int, per_inch: float = 15.0, price_per_gallon: float = 1.2) -> Tuple[str, float]:
    extra_days_inches = [2, 1, 5, 3, 4, 6]   # 6 extras
    extra_day_names = ["Wednesday", "Thursday", "Friday", "Saturday", "Sunday", "Next Monday"]
    assert 0 <= k <= len(extra_days_inches)
    total_gallons = 0.0
    total_gallons += 4 * per_inch
    total_gallons += 3 * per_inch
    q_days = "On Monday it rained 4 inches and on Tuesday it rained 3 inches. "
    clauses = ""
    for i in range(k):
        day = extra_day_names[i]
        inches = extra_days_inches[i]
        clauses += f"On {day} it rained {inches} inches. "
        total_gallons += inches * per_inch
    total_money = safe_round_money(total_gallons * price_per_gallon)
    q = (
        f"For each inch of rain he collects {per_inch} gallons. "
        + q_days + clauses +
        f"He can sell water for ${price_per_gallon} per gallon. How much money did he make from selling all the water?"
    )
    return q, total_money

def gen_carlotta(k: int, perf_minutes: int = 6) -> Tuple[str, int]:
    extras = [
        ("warm-up", 1), ("costume changes", 2), ("interviews", 1),
        ("meet-and-greet", 3), ("vocal coaching", 2), ("travel time", 7)
    ]  # up to 6 extras
    assert 0 <= k <= len(extras)

    sing = perf_minutes

    practice_per_min = 3
    tantrum_per_min = 5

    practice = practice_per_min * sing
    tantrums = tantrum_per_min * sing
    total = sing + practice + tantrums

    extras_clause = ""
    extra_names = []
    if k > 0:
        parts = []
        for i in range(k):
            name, per_min = extras[i]
            parts.append(f"{per_min} minutes {name}")
            extra_names.append(name)
            amt = per_min * sing
            total += amt
        extras_clause = "She also spends " + ", ".join(parts[:-1]) + (", and " + parts[-1] if len(parts) > 1 else parts[0]) + " for each minute singing. "

    opening = (
        "For every 1 minute that Carlotta sings on stage during the final opera performance, "
        "she spends an additional 3 minutes practicing and 5 minutes throwing temper tantrums. "
    )
    q_body = opening + extras_clause + f"If her final stage performance is {perf_minutes} minutes long, "

    base_activities = ["practicing", "throwing tantrums", "singing"]
    all_activities = base_activities + extra_names

    def nice_join(lst: List[str]) -> str:
        if len(lst) == 0:
            return ""
        if len(lst) == 1:
            return lst[0]
        if len(lst) == 2:
            return f"{lst[0]} and {lst[1]}"
        return ", ".join(lst[:-1]) + " and " + lst[-1]

    q = q_body + "what is the total combined amount of time, in minutes, that she spends " + nice_join(all_activities) + " in the final stage performance?"
    return q, total

from typing import Tuple, List, Dict

def gen_pills(k: int) -> Tuple[str, int]:
    """
    k: number of extra multiplier-style clauses to add (0..6).
    Returns: (question_text, total_pills_per_week)
    """
    assert 0 <= k <= 6

    # Base daily pills (keys used for computation)
    per_day: Dict[str, int] = {}
    per_day["insulin"] = 2
    per_day["blood_pressure"] = 3
    per_day["anticonvulsants"] = 2 * per_day["blood_pressure"]  # twice as many as blood pressure pills

    # Extra definitions: (display_name, reference_key, multiplier)
    extras_defs = [
        ("vitamin C", "insulin", 1),            # as many vitamin C as insulin per day
        ("cholesterol", "blood_pressure", 2),    # twice as many cholesterol as blood pressure per day
        ("allergy", "anticonvulsants", 1),      # as many allergy as anticonvulsants per day
        ("thyroid", "cholesterol", 1),          # as many thyroid as cholesterol per day
        ("probiotic", "insulin", 1),            # as many probiotic as insulin per day
        ("iron", "probiotic", 2)                # twice as many iron as probiotic per day
    ]

    # helper to convert multiplier to English phrase similar to original style
    def mult_phrase(m: int) -> str:
        if m == 1:
            return "as many"
        if m == 2:
            return "twice as many"
        if m == 3:
            return "three times as many"
        return f"{m} times as many"

    extra_clauses: List[str] = []

    # process up to k extras, each references some previously-mentioned key
    for i in range(k):
        display_name, ref_key, mult = extras_defs[i]
        key = display_name.lower().replace(" ", "_")

        # check the referenced pill is already present (insulin, blood_pressure, anticonvulsants or a previously processed extra)
        if ref_key not in per_day:
            # If ref_key missing, fallback to blood_pressure (shouldn't happen with current defs)
            ref_key = "blood_pressure"

        # compute value by multiplier of the referenced per_day count
        per_day[key] = int(per_day[ref_key] * mult)

        # textual clause: match original phrasing: "She also takes twice as many X as Y pills each day."
        phrase = mult_phrase(mult)
        ref_display = ref_key.replace("_", " ")
        clause = f"She also takes {phrase} {display_name} as {ref_display} pills each day."
        extra_clauses.append(clause)

    # Build question text: base sentence + extra clauses (if any) + final question line
    base = (
        "Holly needs to take 2 insulin pills per day, 3 blood pressure pills per day, "
        "and twice as many anticonvulsants as blood pressure pills each day. "
    )
    extras_text = " ".join(extra_clauses) + " " if extra_clauses else ""
    q = base + extras_text + "How many pills does Holly take in a week?"

    # compute weekly total
    total_per_day = sum(per_day.values())
    total_week = total_per_day * 7

    return q, total_week


def gen_snackpacks(k: int) -> Tuple[str, int]:
    """
    k: number of extra clauses to include (0..6).
    Returns: (question_text, items_per_baggie (int))
    Ensures every snack count is a multiple of 16.
    """
    assert 0 <= k <= 6

    pretzels = 64
    goldfish = 4 * pretzels
    suckers = 32

    extras_defs = [
        ("direct", "granola bars", 20),
        ("mult",   "cookies",      "granola bars", Fraction(2, 1)),
        ("direct", "fruit snacks", 16),
        ("mult",   "carrots",      "fruit snacks", Fraction(3, 1)),
        ("direct", "crackers",     24),
        ("mult",   "juice packs",  "crackers",     Fraction(2, 3)),
    ]

    extras_used = extras_defs[:k]

    def keyify(name: str) -> str:
        return name.lower().replace(" ", "_")

    ref_map = {}
    for item in extras_used:
        if item[0] == "mult":
            _, display_name, ref_display, mult = item
            ref_key = keyify(ref_display)
            ref_map.setdefault(ref_key, []).append(mult)

    def _gcd(a: int, b: int) -> int:
        return math.gcd(a, b)
    def _lcm(a: int, b: int) -> int:
        if a == 0 or b == 0:
            return 0
        return a // _gcd(a, b) * b

    counts = {
        "pretzels": pretzels,
        "goldfish": goldfish,
        "suckers": suckers,
    }

    list_parts = [f"{pretzels} pretzels", f"four times as many goldfish", f"{suckers} suckers"]
    extras_text_parts: List[str] = []
    direct_keys: List[str] = []

    def requirement_factor_for_ref(ref_key: str) -> int:
        factors = [1]
        for mult in ref_map.get(ref_key, []):
            if isinstance(mult, Fraction):
                p = mult.numerator
                q = mult.denominator
                factors.append(q // math.gcd(p, q))
            else:
                factors.append(1)
        res = 1
        for f in factors:
            res = _lcm(res, f)
        return res

    def mult_phrase(m):
        if isinstance(m, int):
            if m == 1:
                return "as many"
            if m == 2:
                return "twice"
            words = {3: "three", 4: "four", 5: "five", 6: "six", 7: "seven", 8: "eight", 9: "nine", 10: "ten"}
            if m in words:
                return f"{words[m]} times"
            return f"{m} times"
        if isinstance(m, Fraction):
            if m == Fraction(1, 1):
                return "as many"
            if m == Fraction(1, 2):
                return "half"
            if m == Fraction(2, 1):
                return "twice"
            if m == Fraction(3, 1):
                return "three times"
            if m == Fraction(2, 3):
                return "two-thirds"
            return f"{m.numerator}/{m.denominator} times"
        return str(m)

    for item in extras_used:
        if item[0] == "direct":
            _, display_name, cnt = item
            key = keyify(display_name)
            orig = int(cnt)
            factor = requirement_factor_for_ref(key)
            required_multiple = 16 * factor
            multiplier = (orig + required_multiple - 1) // required_multiple
            adjusted = multiplier * required_multiple
            counts[key] = adjusted
            extras_text_parts.append(f"{adjusted} {display_name}")
            direct_keys.append(key)
    for item in extras_used:
        if item[0] == "mult":
            _, display_name, ref_display, mult = item
            ref_key = keyify(ref_display)
            if ref_key not in counts:
                ref_key = "pretzels"
            if isinstance(mult, Fraction):
                val = counts[ref_key] * mult
                val_int = int(val)
            else:
                val_int = int(counts[ref_key] * int(mult))
            key = keyify(display_name)
            counts[key] = val_int
            phrase = mult_phrase(mult)
            if phrase in ("twice", "half") or phrase.endswith("times"):
                extras_text_parts.append(f"{phrase} as many {display_name} as {ref_display}")
            else:
                extras_text_parts.append(f"{phrase} {display_name} as {ref_display}")

    all_list_parts = list_parts + extras_text_parts
    all_list = ", ".join(all_list_parts)

    total_items = sum(counts.values())

    if total_items % 16 != 0:
        delta = (16 - (total_items % 16)) % 16
        counts["extra_stickers"] = counts.get("extra_stickers", 0) + delta
        total_items = sum(counts.values())
        all_list = all_list + (", {} extra_stickers".format(delta) if delta else "")

    per_baggie = total_items // 16

    q = (
        "MIlle is making snack packs for her kindergarten class. "
        f"She's going to divide {all_list} into baggies for the 16 kids in the class. "
        "How many items does each baggie have?"
    )

    return q, per_baggie

def gen_shells(k: int) -> Tuple[str, int]:
    extras_names = ["Maya", "Liam", "Noah", "Ava", "Sophia", "Ethan"]
    extras_counts = [3, 5, 2, 4, 6, 7]
    assert 0 <= k <= len(extras_names)

    initial = 2
    ed_found = 7 + 2 + 4
    jacob_found = ed_found + 2

    comparative_phrases = [
        ("Maya",  "12 fewer than Jacob",    "Jacob"),
        ("Liam",  "2 more than Maya",       "Maya"),
        ("Noah",  "3 fewer than Liam",      "Liam"),
        ("Ava",   "2 more than Noah",       "Noah"),
        ("Sophia","2 more than Ava",        "Ava"),
        ("Ethan", "1 more than Sophia",     "Sophia"),
    ]

    counts = {
        "Ed": ed_found,
        "Jacob": jacob_found,
    }

    already_names = ["Ed", "Jacob"] + extras_names[:k]
    already_phrase = ", ".join(already_names[:-1]) + (" and " + already_names[-1] if len(already_names) > 1 else already_names[0])

    for i in range(k):
        name, phrase, ref = comparative_phrases[i]
        cnt = extras_counts[i]
        counts[name] = cnt

    total = initial + counts["Ed"] + counts["Jacob"]
    for i in range(k):
        name = extras_names[i]
        total += counts[name]

    q = (
        f"{already_phrase} already had {initial} shells in their shell collection. "
        "Then they went to the beach to collect even more. "
        "Ed found 7 limpet shells, 2 oyster shells, and 4 conch shells. "
        "Jacob found 2 more shells than Ed did. "
        + (" ".join(f"{extras_names[i]} found {comparative_phrases[i][1]}." for i in range(k)) + " " if k>0 else "") +
        "How many shells do they have altogether?"
    )

    return q, total

def gen_lunch(k: int) -> Tuple[str, float]:
    extra_items = [("apple", 0.30), ("milk", 0.50), ("salad", 0.75), ("yogurt", 0.60), ("cheese stick", 0.40), ("banana", 0.25)]
    assert 0 <= k <= len(extra_items)
    n3 = 5 * 30
    n4 = 4 * 28
    n5 = 4 * 27
    total_students = n3 + n4 + n5
    prices = [("hamburger", 2.10), ("carrots", 0.50), ("cookie", 0.20)]
    total_cost = total_students * sum(p for _, p in prices)
    clauses = ""
    for i in range(k):
        name, p = extra_items[i]
        clauses += f"Each student also gets {name}, which costs ${p}. "
        total_cost += total_students * p
    total_cost = safe_round_money(total_cost)
    q = (
        "Marla is planning the lunch menu for an elementary school. "
        "There are 5 third grade classes with 30 students each, 4 fourth grade classes with 28 students each, and 4 fifth grade classes with 27 students each. "
        "Each student gets a hamburger, which costs $2.10, some carrots, which cost $0.50, and a cookie, which cost $0.20. "
        + clauses +
        "How much does one lunch for all the students cost?"
    )
    return q, total_cost

def gen_pinata(k: int) -> Tuple[str, int]:
    assert 0 <= k <= 6

    red = 40
    yellow = 3 * red - 20
    blue = yellow // 2

    counts: Dict[str, int] = {"red": red, "yellow": yellow, "blue": blue}

    def mult_to_phrase(m):
        if isinstance(m, Fraction):
            if m == Fraction(1, 1):
                return "as many"
            if m == Fraction(1, 2):
                return "half"
            if m == Fraction(1, 5):
                return "one-fifth"
            if m == Fraction(3, 4):
                return "three-quarters"
            if m == Fraction(2, 3):
                return "two-thirds"
            return f"{m.numerator}/{m.denominator}"
        else:
            m_int = int(m)
            if m_int == 1:
                return "as many"
            if m_int == 2:
                return "two times"
            if m_int == 3:
                return "three times"
            if m_int == 4:
                return "four times"
            if m_int == 5:
                return "five times"
            return f"{m_int} times"

    extras_defs = [
        ("green",  Fraction(1, 2),  0, "red"),
        ("purple", 5,              10, "yellow"),
        ("orange", Fraction(1, 5), 0, "yellow"),
        ("brown",  4,  0, "blue"),
        ("pink",   Fraction(3, 4),  0, "brown"),
        ("teal",   3,               2, "brown"),
    ]

    extras_defs = extras_defs[:k]

    clauses_parts: List[str] = []
    for name, mult, offset, ref_color in extras_defs:
        ref = ref_color.lower()
        if ref not in counts:
            ref = "red"
        ref_val = counts[ref]
        numeric_mult_val = (ref_val * mult) if isinstance(mult, Fraction) else (ref_val * int(mult))
        val = int(numeric_mult_val) - int(offset)
        counts[name] = val
        mult_phrase = mult_to_phrase(mult)
        if offset == 0:
            clause = f"{mult_phrase} as many {name} candies as {ref_color} candies"
        elif offset > 0:
            clause = f"{offset} less than {mult_phrase} as many {name} candies as {ref_color} candies"
        else:
            clause = f"{abs(offset)} more than {mult_phrase} as many {name} candies as {ref_color} candies"
        clauses_parts.append("There were also " + clause + ".")

    core = (
        "With one mighty blow, Maria cracked open the pinata, and candies spilled all over the floor. "
        "There were 40 red candies, 20 less than three times as many yellow candies as red candies, "
        "and half as many blue candies as yellow candies. "
    )
    q = core + (" ".join(clauses_parts) + " " if clauses_parts else "") + \
        "If Carlos ate all of the yellow candies, how many candies remained?"

    total_before = sum(counts.values())
    remaining = total_before - counts["yellow"]
    return q, remaining

def fmt_money(x: float) -> str:
    return f"${x:.2f}"


def safe_round_money(x: float) -> float:
    return round(x + 1e-12, 2)

###############################################################################
# Extended: original 10 generators kept; 10 new generators added (q11..q20)
# Each new generator accepts k extras (0..6) where applicable and returns
# (question_text, answer) to match the existing API.
###############################################################################

# --- existing functions (gen_sheep .. gen_pinata) are assumed unchanged ---
# For brevity in this file we keep the previously provided 10 generators intact
# and then add the 10 new ones. In a real repo you'd import or keep the old
# definitions; here we redefine only the new generators and update gen_question.

# --- NEW: Paint problem (q11) ---

def gen_paint(k: int) -> Tuple[str, int]:
    """Henrietta repainting. Base: living room walls 600 sqft, three bedrooms 400 sqft each.
    One gallon covers 600 sqft. k extra room types (0..6) appended with their counts.
    Returns (question_text, gallons_needed) where gallons is an integer (ceiling).
    """
    assert 0 <= k <= 6
    base_rooms = [("living room walls", 600, 1), ("bedroom walls", 400, 3)]
    extra_defs = [
        ("kitchen walls", 500, 6),
        ("dining room walls", 240, 5),
        ("bathroom walls", 150, 8),
        ("garage walls", 480, 5),
        ("hallway walls", 900, 8),
        ("sunroom walls", 300, 6),
    ][:k]

    clauses = ""
    total_sqft = 0
    for name, sqft, count in base_rooms + extra_defs:
        total_sqft += sqft * count
        if count == 1:
            clauses += f"The {name} take up {sqft} square feet. "
        else:
            clauses += f"The {name} (each) take up {sqft} square feet and there are {count} of them. "

    coverage = 600
    gallons_needed = math.ceil(total_sqft / coverage)
    q = clauses + f"If one gallon of paint covers {coverage} square feet, how many gallons of paint does Henrietta need to paint her house?"
    return q, gallons_needed


# --- NEW: Night out / club bill (q12) ---

def gen_club(k: int) -> Tuple[str, float]:
    """James at a club. Base: club charges $20 to enter. James buys 2 rounds for his 5 friends.
    He also buys 6 drinks for himself at $6 each and orders fried chicken $14. He leaves 30% tip
    on everything he orders (food & drinks). k extra ordered items (quantity and unit price)
    are appended. Returns (question_text, total_spent) as dollars rounded to cents.
    """
    assert 0 <= k <= 6
    entry_fee = 20.0
    rounds = {"rounds_count": 2, "friends": 5, "drink_price": 6.0}
    self_drinks = {"count": 6, "price": 6.0}
    food = [("fried chicken", 1, 14.0)]

    extra_orders = [
        ("nachos", 2, 8.0),
        ("wings (bucket)", 1, 12.0),
        ("bottle of wine", 7, 25.0),
        ("shots", 9, 3.5),
        ("dessert platter", 4, 9.0),
        ("extra beers", 6, 5.0),
    ][:k]

    # Compute subtotal for items James ordered (excluding club entry which is a charge but not "ordered")
    # James pays entry for himself once.
    subtotal_orders = 0.0
    # 2 rounds for his 5 friends -> each round = friends * drink_price
    rounds_total = rounds["rounds_count"] * rounds["friends"] * rounds["drink_price"]
    subtotal_orders += rounds_total
    # James's own drinks
    subtotal_orders += self_drinks["count"] * self_drinks["price"]
    # food
    for name, qty, price in food:
        subtotal_orders += qty * price
    # extras
    extra_clauses = ""
    for name, qty, price in extra_orders:
        subtotal_orders += qty * price
        extra_clauses += f"He also orders {qty} {name} at ${price} each. "

    tip_rate = 0.30
    tip = subtotal_orders * tip_rate
    total = entry_fee + subtotal_orders + tip
    total = safe_round_money(total)

    q = (
        f"The club charges ${entry_fee:.0f} to enter. "
        f"James buys {rounds['rounds_count']} rounds for his {rounds['friends']} friends. "
        f"He also buys {self_drinks['count']} drinks for himself. Drinks cost ${rounds['drink_price']:.0f} each. He decides to eat some food so he orders fried chicken which costs $14. "
        + extra_clauses +
        f"He leaves a {int(tip_rate*100)}% tip on everything he orders. How much did he spend for the night?"
    )
    return q, total


# --- NEW: Wine ages (q13) ---

def gen_wine(k: int, known_name: str = "Carlo Rosi", known_age: int = 40) -> Tuple[str, int]:
    """Base relations (phrasing kept unambiguous):
       - Franzia is three times as old as Carlo Rosi
       - Carlo Rosi is four times as old as Twin Valley
       Known: Carlo Rosi age given (default 40)
       Add up to k extra wines given as (name, reference_name, multiplier).
       Returns (question_text, total_age).

    Notes:
      - To ensure all ages are integer the known_age must be divisible by 4 (so Twin Valley is integer).
      - The final question sentence now appears after all descriptive clauses.
    """
    assert 0 <= k <= 6

    # ensure integer ages for the chain relation Carlo Rosi = 4 * Twin Valley
    if known_age % 4 != 0:
        raise ValueError("known_age must be divisible by 4 so Twin Valley's age is an integer")

    ages: Dict[str, int] = {}
    # assign known age
    ages[known_name] = int(known_age)

    # unambiguous base relations
    ages["Twin Valley"] = ages[known_name] // 4
    ages["Franzia"] = ages[known_name] * 3

    # extra wines (in an order that ensures references are available)
    extra_defs = [
        ("Silver Ridge", "Franzia", 1),   # same age as Franzia
        ("Old Vine", "Twin Valley", 2),   # twice Twin Valley
        ("Crestwood", "Carlo Rosi", 1),   # same as Carlo
        ("Barrel Hill", "Old Vine", 3),   # 3 times Old Vine
        ("Sunset", "Crestwood", 2),
        ("Golden Acre", "Twin Valley", 4),
    ][:k]

    clauses_parts: List[str] = []
    # build ages and human-friendly clauses
    for name, ref, mult in extra_defs:
        if ref not in ages:
            # this should not happen with the chosen ordering; raise so bug is visible
            raise ValueError(f"Reference '{ref}' for '{name}' is not defined yet")
        ages[name] = int(ages[ref] * mult)
        if mult == 1:
            clauses_parts.append(f"The {name} is the same age as the {ref}.")
        else:
            clauses_parts.append(f"The {name} is {mult} times as old as the {ref}.")

    # compose final question text (move the 'calculate' sentence to the end)
    base = (
        "The Franzia wine is three times as old as the Carlo Rosi, "
        "and the Carlo Rosi is four times as old as the Twin Valley. "
    )
    extras_text = " ".join(clauses_parts)
    q = (
        base + (extras_text + " " if extras_text else "") +
        f"If the {known_name} is {known_age} years old, what is the total age of all the listed wines?"
    )

    total = sum(ages.values())
    return q, total


# --- NEW: Carpets among houses (q14) ---

def gen_carpets(k: int) -> Tuple[str, int]:
    """Base: house1=12, house2=20, house3=10, house4 = 2 * house3
    Add more houses (direct or multiplier). Returns (question_text, total_carpets)
    """
    assert 0 <= k <= 6
    counts: Dict[str, int] = {
        "house 1": 12,
        "house 2": 20,
        "house 3": 10,
    }
    counts["house 4"] = 2 * counts["house 3"]

    extra_defs = [
        ("house 5", "direct", 8),        # 8 carpets
        ("house 6", "mult", ("house 5", 4)),
        ("house 7", "mult", ("house 2", 7)),
        ("house 8", "direct", 15),
        ("house 9", "mult", ("house 4", 3)),
        ("house 10", "direct", 6),
    ][:k]

    clauses = ""
    for name, typ, val in extra_defs:
        if typ == "direct":
            counts[name] = int(val)
            clauses += f"{name.title()} has {val} carpets. "
        else:
            ref, mult = val
            ref_val = counts.get(ref, counts["house 1"])  # fallback
            counts[name] = int(ref_val * mult)
            clauses += f"{name.title()} has {mult} times as many carpets as {ref}. "

    total = sum(counts.values())
    q = (
        "There are 12 carpets in house 1, 20 carpets in house 2, and 10 carpets in house 3. "
        "House 4 has twice as many carpets as house 3. " + clauses +
        "How many carpets do all the houses have in total?"
    )
    return q, total


# --- NEW: Josie's safari trip (q15) ---

def gen_safari(k: int) -> Tuple[str, int]:
    """
    k: number of extra comparative clauses to add (0..6).
    Returns: (question_text, total_animals)
    """
    assert 0 <= k <= 6

    # Base counts from the original problem
    per_animal: Dict[str, int] = {}
    per_animal["antelopes"] = 80
    per_animal["rabbits"] = per_animal["antelopes"] + 34                     # 34 more rabbits than antelopes
    per_animal["hyenas"] = (per_animal["antelopes"] + per_animal["rabbits"]) - 42  # 42 fewer than antelopes+rabbits
    per_animal["wild_dogs"] = per_animal["hyenas"] + 50                      # 50 more wild dogs than hyenas
    per_animal["leopards"] = per_animal["rabbits"] // 2                      # leopards half the rabbits (integer)

    # extras_defs: tuples:
    # (display_name, ref_keys (str or list[str]), multiplier (int or Fraction), offset (int))
    # value computed as int(sum(per_animal[ref] for ref in ref_keys) * multiplier) + offset
    extras_defs = [
        ("zebras",    ["leopards"],                      2,                0),   # twice as many zebras as leopards
        ("giraffes",  ["zebras"],                        1,               -3),   # 3 fewer giraffes than zebras
        ("buffaloes", ["antelopes", "rabbits"],          Fraction(1, 2),   0),   # half as many buffaloes as (antelopes + rabbits)
        ("monkeys",   ["rabbits", "leopards"],           Fraction(2, 3),   0),   # two-thirds as many monkeys as (rabbits + leopards)
        ("elephants", ["wild_dogs", "hyenas"],           1,                5),   # 5 more elephants than (wild_dogs + hyenas)
        ("hippos",    ["antelopes"],                     Fraction(1, 5),   2),   # 2 more than one-fifth as many hippos as antelopes
    ]

    # helper to make a compact multiplier token
    def mult_token(m: Union[int, Fraction]) -> str:
        if isinstance(m, Fraction):
            if m == Fraction(1, 1):
                return "as many"
            if m == Fraction(1, 2):
                return "half"
            if m == Fraction(2, 3):
                return "two-thirds"
            if m == Fraction(1, 5):
                return "one-fifth"
            return f"{m.numerator}/{m.denominator}"
        else:
            if m == 1:
                return "as many"
            if m == 2:
                return "twice"
            if m == 3:
                return "three times"
            return f"{int(m)} times"

    clauses: List[str] = []

    # process up to k extras
    for i in range(k):
        display, ref_keys, multiplier, offset = extras_defs[i]
        key = display.lower().replace(" ", "_")

        # normalize ref_keys to list
        if isinstance(ref_keys, str):
            ref_list = [ref_keys]
        else:
            ref_list = list(ref_keys)

        # ensure referenced animals exist; if not, fallback to antelopes
        ref_norm = []
        for r in ref_list:
            rk = r.lower().replace(" ", "_")
            if rk not in per_animal:
                rk = "antelopes"
            ref_norm.append(rk)

        # compute reference sum
        ref_sum = sum(per_animal[rk] for rk in ref_norm)

        # compute numeric value
        if isinstance(multiplier, Fraction):
            numeric = int(ref_sum * multiplier)
        else:
            numeric = int(ref_sum * int(multiplier))
        val = numeric + int(offset)
        per_animal[key] = val

        # build two forms of reference phrase:
        # - simple: for templates like "twice as many X as Y" -> "leopards" or "antelopes and rabbits combined"
        # - full: for templates that need an article or explicit 'total number' -> "the total number of antelopes and rabbits combined"
        if len(ref_norm) == 1:
            ref_name = ref_norm[0].replace("_", " ")
            ref_phrase_simple = ref_name
            ref_phrase_full = ref_name
        else:
            nice_names = " and ".join(r.replace("_", " ") for r in ref_norm)
            ref_phrase_simple = f"{nice_names} combined"
            ref_phrase_full = f"the total number of {nice_names} combined"

        # multiplier token and templates
        mp_token = mult_token(multiplier)

        if offset == 0:
            # no offset: use the standard comparative templates
            if mp_token == "as many":
                clause = f"There were also as many {display} as {ref_phrase_full}."
            else:
                clause = f"There were also {mp_token} as many {display} as {ref_phrase_simple}."
        elif offset > 0:
            # more
            if mp_token == "as many":
                clause = f"There were also {offset} more {display} than {ref_phrase_full}."
            else:
                clause = f"There were also {offset} more {display} than {mp_token} as many {ref_phrase_simple}."
        else:
            # fewer
            if mp_token == "as many":
                clause = f"There were also {abs(offset)} fewer {display} than {ref_phrase_full}."
            else:
                clause = f"There were also {abs(offset)} fewer {display} than {mp_token} as many {ref_phrase_simple}."

        # normalize whitespace and append
        clause = " ".join(clause.split())
        clauses.append(clause)

    # build the question text
    core = (
        "Josie and her family went on a safari and had an exciting time watching different animals grazing and playing together. "
        "She decided to count each animal she saw and calculated their total number. "
        "She counted 80 antelopes, 34 more rabbits than antelopes, 42 fewer hyenas than the total number of antelopes and rabbits combined, "
        "50 more wild dogs than hyenas, and the number of leopards was half the number of rabbits. "
    )
    extras_text = " ".join(clauses) + " " if clauses else ""
    q = core + extras_text + "What is the total number of animals that Josie counted?"

    total = sum(per_animal.values())
    return q, total


# --- NEW: Fish counts (q16) ---

def gen_fish(k: int) -> Tuple[str, int]:
    """Micah has 7 fish. Kenneth has three times as many as Micah and Matthias has 15 less than Kenneth.
    Add up to k extra boys with relations ("times" or "less"). Return (question_text, total_fish).
    """
    assert 0 <= k <= 6
    counts: Dict[str, int] = {}
    counts["Micah"] = 7
    counts["Kenneth"] = 3 * counts["Micah"]
    counts["Matthias"] = counts["Kenneth"] - 15

    extras = [
        ("Noah", ("times", 7, "Micah")),      # Noah has 7 times Micah
        ("Liam", ("less", 5, "Kenneth")),     # Liam has 5 less than Kenneth
        ("Owen", ("times", 3, "Matthias")),   # Owen has 3 times Matthias
        ("Ethan", ("less", 10, "Owen")),      # Ethan has 10 less than Owen
        ("Ava", ("times", 4, "Micah")),       # Ava same as Micah
        ("Sophia", ("less", 2, "Ava")),       # Sophia 2 less than Ava
    ][:k]

    clauses = ""
    for name, rule in extras:
        typ = rule[0]
        if typ == "times":
            mult = int(rule[1])
            ref = rule[2]
            val = counts.get(ref, counts["Micah"]) * mult
            counts[name] = val
            clauses += f"{name} has {mult} times as many fish as {ref} in his aquarium. "
        else:
            less = int(rule[1])
            ref = rule[2]
            val = counts.get(ref, counts["Micah"]) - less
            counts[name] = val
            clauses += f"{name} has {less} fewer fish than {ref} in his aquarium. "

    total = sum(counts.values())
    base_clause = "Micah has 7 fish in his aquarium. Kenneth has three times as many fish in his aquarium and Matthias has 15 fewer fish than Kenneth in his aquarium. "
    q = base_clause + clauses + "How many fish do they have in total?"
    return q, total


# --- NEW: Basketball scoring segments (q17) ---

def gen_basketball(k: int) -> Tuple[str, int]:
    """
    Build a basketball scoring question using up to k segments (0..6).
    The "last seconds" segment is always the final clause added (if k >= 1),
    and its duration is computed so the total of all included segments = 40 seconds.

    Returns (question_text, fireflies_points_more_than_hornets).
    """
    assert 0 <= k <= 6

    # starting score (with 40 seconds left)
    initial_hornets = 86
    initial_fireflies = 74

    # segments: (seconds, list of (team, count, points_per_score))
    # Note: the seconds here are the canonical durations if all segments are present.
    # We'll only use the durations for the "previous" segments; the last segment's
    # seconds will be computed dynamically as 40 - sum(previous durations).
    segments_data = [
        (5,  [("Fireflies", 7, 3), ("Hornets", 2, 2)]),   # first 5 seconds
        (3,  [("Hornets", 2, 3),   ("Fireflies", 1, 2)]), # next 3 seconds
        (8,  [("Fireflies", 3, 2), ("Hornets", 1, 2)]),   # next 8 seconds
        (10, [("Hornets", 1, 3),   ("Fireflies", 2, 3)]), # next 10 seconds
        (8,  [("Fireflies", 4, 2), ("Hornets", 0, 2)]),   # next 8 seconds
        (6,  [("Fireflies", 2, 3), ("Hornets", 1, 2)]),   # canonical last 6 seconds (we will override seconds when needed)
    ]

    hornets = initial_hornets
    fireflies = initial_fireflies
    text_clauses: List[str] = []

    if k == 0:
        # no extra segments; question will only contain the initial score line
        pass
    else:
        # include the first (k-1) segments in order (these use their canonical durations)
        prev_duration_sum = 0
        for i in range(max(0, k - 1)):
            secs, actions = segments_data[i]
            prev_duration_sum += secs
            pieces = []
            for team, cnt, pts in actions:
                kind = "three-point baskets" if pts == 3 else "two-point baskets"
                pieces.append(f"{team} scored {cnt} {kind}")
            clause = f"In the next {secs} seconds, " + " while ".join(pieces) + "."
            text_clauses.append(clause)
            # apply scores
            for team, cnt, pts in actions:
                if team == "Fireflies":
                    fireflies += cnt * pts
                else:
                    hornets += cnt * pts

        # now compute the dynamic last-seconds duration so total = 40
        last_secs = 40 - prev_duration_sum
        if last_secs < 0:
            # safety: shouldn't happen with the provided canonical durations,
            # but guard against negative by setting to 0 (and not changing scoring wording).
            last_secs = 0

        # use the scoring actions of the canonical last segment but report the dynamic seconds
        _, last_actions = segments_data[-1]
        pieces = []
        for team, cnt, pts in last_actions:
            kind = "three-point baskets" if pts == 3 else "two-point baskets"
            pieces.append(f"{team} scored {cnt} {kind}")
        last_clause = f"In the last {last_secs} seconds, " + " while ".join(pieces) + "."
        text_clauses.append(last_clause)
        # apply last segment scores (counts don't depend on seconds)
        for team, cnt, pts in last_actions:
            if team == "Fireflies":
                fireflies += cnt * pts
            else:
                hornets += cnt * pts

    segments_text = " ".join(text_clauses) + " " if text_clauses else ""
    q = (
        f"With 40 seconds left the Dunkinville Hornets were leading the Fredingo Fireflies "
        f"by a score of {initial_hornets} to {initial_fireflies}. "
        + segments_text +
        "How many more points did the Fireflies score in the game than did the Hornets?"
    )

    diff = fireflies - hornets
    return q, diff


# --- NEW: Jewelry (q18) ---

def gen_jewelry(k: int) -> Tuple[str, int]:
    """
    Extend the jewelry-store question by adding up to 6 additional display items.
    k: number of extra display clauses to include (0..6)
    Returns: (question_text, total_cost_to_fill_displays_in_dollars)

    Each clause states BOTH the capacity and the actual amount currently held.
    Relationship-based computations are used internally if needed, but no comparison
    relationships are printed in the problem text.
    """
    assert 0 <= k <= 6

    # Base displays: key -> (capacity, current, unit_singular, price_dollars)
    displays: Dict[str, Tuple[int, int, str, int]] = {
        "necklace": (12, 5, "necklace", 4),
        "ring":     (30, 18, "ring", 10),
        "bracelet": (15, 8, "bracelet", 5),
    }

    # Extra displays definitions in order (max 6). Each entry is a dict with:
    #  - key: internal key
    #  - readable: how to name it in the sentence
    #  - cap: capacity
    #  - kind: "direct", "mult", "diff", or "sum"
    #  - params: parameters meaning depends on kind
    #  - unit: singular unit word for price sentence
    #  - price: supplier price per unit (dollars)
    extras_defs = [
        # direct current
        {"key": "watch_display", "readable": "watch display", "cap": 20, "kind": "direct",
         "params": {"current": 7}, "unit": "watch", "price": 25},
        # twice as many as watch_display
        {"key": "earring_display", "readable": "earring display", "cap": 50, "kind": "mult",
         "params": {"ref": "watch_display", "mult": 2}, "unit": "earring", "price": 8},
        # one fewer than necklace's current
        {"key": "tiara_display", "readable": "tiara display", "cap": 5, "kind": "diff",
         "params": {"ref": "necklace", "offset": -1}, "unit": "tiara", "price": 60},
        # sum of necklace and watch_display
        {"key": "cufflink_display", "readable": "cufflink display", "cap": 25, "kind": "sum",
         "params": {"ref1": "necklace", "ref2": "watch_display"}, "unit": "cufflink", "price": 15},
        # twice as many as bracelet
        {"key": "brooch_display", "readable": "brooch display", "cap": 18, "kind": "mult",
         "params": {"ref": "bracelet", "mult": 2}, "unit": "brooch", "price": 12},
        # three fewer than bracelet
        {"key": "anklet_display", "readable": "anklet display", "cap": 12, "kind": "diff",
         "params": {"ref": "bracelet", "offset": -3}, "unit": "anklet", "price": 10},
    ]

    # process up to k extras and compute their current counts (clamped between 0 and capacity)
    for i in range(min(k, len(extras_defs))):
        ed = extras_defs[i]
        key = ed["key"]
        readable = ed["readable"]
        cap = int(ed["cap"])
        kind = ed["kind"]
        unit = ed["unit"]
        price = int(ed["price"])

        if kind == "direct":
            cur = int(ed["params"]["current"])

        elif kind == "mult":
            ref = ed["params"]["ref"]
            mult = int(ed["params"]["mult"])
            if ref not in displays:
                # fallback to bracelet if reference missing (shouldn't happen)
                ref = "bracelet"
            ref_cur = displays[ref][1]
            cur = int(ref_cur * mult)

        elif kind == "diff":
            ref = ed["params"]["ref"]
            offset = int(ed["params"]["offset"])
            if ref not in displays:
                ref = "bracelet"
            ref_cur = displays[ref][1]
            cur = int(ref_cur + offset)

        elif kind == "sum":
            ref1 = ed["params"]["ref1"]
            ref2 = ed["params"]["ref2"]
            if ref1 not in displays:
                ref1 = "necklace"
            if ref2 not in displays:
                ref2 = "necklace"
            cur = int(displays[ref1][1] + displays[ref2][1])

        else:
            raise ValueError("unknown kind")

        # clamp current between 0 and capacity
        if cur < 0:
            cur = 0
        if cur > cap:
            cur = cap

        # save into displays
        displays[key] = (cap, cur, unit, price)

    # Build the textual problem: base display sentences + extra display sentences (no comparisons)
    base_text = (
        "A jewelry store is restocking its shelves. "
        "The necklace stand, which can hold 12 necklaces, currently holds 5 necklaces. "
        "The ring display, which can hold 30 rings, currently holds 18 rings. "
        "The bracelet display, which can hold 15 bracelets, currently holds 8 bracelets. "
    )

    extras_clauses = []
    # iterate extras in order that were added
    for i in range(min(k, len(extras_defs))):
        ed = extras_defs[i]
        key = ed["key"]
        readable = ed["readable"]
        cap, cur, unit, _ = displays[key]
        # form plural by adding 's' in a simple controlled way; unit strings above are singular
        unit_plural = unit + "s"
        clause = f"The {readable}, which can hold {cap} {unit_plural}, currently holds {cur} {unit_plural}."
        extras_clauses.append(clause)

    extras_text = " ".join(extras_clauses) + " " if extras_clauses else ""

    # Build a single supplier sentence at the end listing prices for every display (base + extras), in introduction order.
    # We'll list: necklace, ring, bracelet, then any extras in the order defined.
    order_keys = ["necklace", "ring", "bracelet"] + [ed["key"] for ed in extras_defs[:min(k, len(extras_defs))]]
    price_parts = []
    seen_units = set()
    for key in order_keys:
        cap, cur, unit, price = displays[key]
        # avoid duplicate unit entries if same unit used multiple times (unlikely here but safe)
        if unit in seen_units:
            continue
        seen_units.add(unit)
        price_parts.append(f"${price} per {unit}")

    # format price sentence with commas and final "and"
    if not price_parts:
        price_sentence = ""
    elif len(price_parts) == 1:
        price_sentence = "The supplier charges " + price_parts[0] + ". "
    else:
        price_sentence = "The supplier charges " + ", ".join(price_parts[:-1]) + ", and " + price_parts[-1] + ". "

    q = base_text + extras_text + price_sentence + "How much, in dollars, will the store need to pay to fill the displays?"

    # Compute total dollars needed to fill every display to capacity
    total_cost = 0
    for key, (cap, cur, _, price) in displays.items():
        need = max(0, cap - cur)
        total_cost += need * price

    return q, total_cost


# --- NEW: Marbles collection (q19) ---

def gen_marbles(k: int) -> Tuple[str, int]:
    """
    Generate the extended marbles problem with up to k extra people (0..6).
    Changes applied:
      - place the final "If Jenny collected ..., what's the total?" at the end
      - map multiplier 1 -> 3 (printed "three times"), multiplier 2 -> 5 (printed "five times")
      - print readable multiplier phrases (no "1 times")
    Returns: (question_text, total_blue)
    """
    assert 0 <= k <= 6

    # base data for Jenny
    j_red = 30
    j_blue = 25

    # initial counts
    counts_red: Dict[str, int] = {"Jenny": j_red}
    counts_blue: Dict[str, int] = {"Jenny": j_blue}

    # base Mary & Anie relationships (unchanged logic)
    counts_red["Mary"] = 2 * counts_red["Jenny"]
    counts_red["Anie"] = counts_red["Mary"] + 20

    counts_blue["Anie"] = 2 * counts_blue["Jenny"]
    counts_blue["Mary"] = counts_blue["Anie"] // 2  # Mary has half Anie's blue

    # extras definitions (in order). Each tuple: (name, red_rule, blue_rule)
    # red_rule and blue_rule are tuples: ("times"/"less", numeric, reference_name)
    extras_all = [
        ("Tom",  ("times", 1, "Mary"),   ("less", 5, "Jenny")),
        ("Sara", ("less", 10, "Anie"),   ("times", 2, "Mary")),
        ("Leo",  ("times", 1, "Jenny"),  ("times", 1, "Jenny")),
        ("Maya", ("times", 2, "Tom"),    ("less", 3, "Sara")),
        ("Ivy",  ("less", 2, "Leo"),     ("times", 1, "Mary")),
        ("Zane", ("times", 3, "Jenny"),  ("times", 2, "Anie")),
    ]
    extras = extras_all[:k]

    # helper: map multiplier values per user's rule (1 -> 3, 2 -> 5)
    def mapped_multiplier(raw: int) -> int:
        if raw == 1:
            return 3
        if raw == 2:
            return 5
        return raw

    # helper: readable phrase for multiplier (use words for small numbers)
    num_to_word = {3: "three", 5: "five", 2: "two", 1: "one"}
    def multiplier_phrase(n: int) -> str:
        w = num_to_word.get(n)
        if w:
            return f"{w} times"
        return f"{n} times"

    clauses_parts = []

    # compute extras
    for person, red_rule, blue_rule in extras:
        # RED
        if red_rule[0] == "times":
            raw_mult = int(red_rule[1])
            ref = red_rule[2]
            mult = mapped_multiplier(raw_mult)
            ref_val = counts_red.get(ref, counts_red["Jenny"])
            counts_red[person] = ref_val * mult
            phrase = multiplier_phrase(mult)
            clauses_parts.append(f"{person} collected {phrase} as many red marbles as {ref}.")
        elif red_rule[0] == "less":
            less = int(red_rule[1])
            ref = red_rule[2]
            ref_val = counts_red.get(ref, counts_red["Jenny"])
            counts_red[person] = ref_val - less
            clauses_parts.append(f"{person} collected {less} fewer red marbles than {ref}.")
        else:
            raise ValueError("unknown red rule")

        # BLUE
        if blue_rule[0] == "times":
            raw_mult = int(blue_rule[1])
            ref = blue_rule[2]
            mult = mapped_multiplier(raw_mult)
            ref_val = counts_blue.get(ref, counts_blue["Jenny"])
            counts_blue[person] = ref_val * mult
            phrase = multiplier_phrase(mult)
            clauses_parts.append(f"{person} collected {phrase} as many blue marbles as {ref}.")
        elif blue_rule[0] == "less":
            less = int(blue_rule[1])
            ref = blue_rule[2]
            ref_val = counts_blue.get(ref, counts_blue["Jenny"])
            counts_blue[person] = ref_val - less
            clauses_parts.append(f"{person} collected {less} fewer blue marbles than {ref}.")
        else:
            raise ValueError("unknown blue rule")

    # final total of blue marbles
    total_blue = sum(counts_blue.values())

    # Build the question text: story, extras, THEN the final "If Jenny collected ..." question at the end
    story = (
        "From Monday to Friday, Mary collected twice as many red marbles as Jenny "
        "and half the number of blue marbles collected by Anie, "
        "who collected 20 more red marbles than Mary and twice the number of blue marbles Jenny collected. "
    )
    extras_text = " ".join(clauses_parts) + " " if clauses_parts else ""
    question_line = f"If Jenny collected {j_red} red marbles and {j_blue} blue marbles, what's the total number of blue marbles collected by the friends together?"

    q = story + extras_text + question_line

    return q, total_blue


# --- NEW: Cycling distances (q20) ---

Number = int | float

def gen_cycling(k: int) -> Tuple[str, Number]:
    """
    Natalia's cycling. Base:
      - Monday = 40
      - Tuesday = 50
      - Wednesday = 50% fewer than Tuesday (i.e., 50% of Tuesday)
      - Thursday = Monday + Wednesday
    k extra days appended with a variety of relations (0..6).
    Returns: (question_text, total_km) where total_km is an int when the exact total
             is an integer, otherwise a float (exact value).
    This version uses Fraction to avoid truncation/rounding errors.
    """
    assert 0 <= k <= 6

    # helper to format a Fraction for the problem text: show integer as "65", fraction as "32.5"
    def fmt(fr: Fraction) -> str:
        if fr.denominator == 1:
            return str(fr.numerator)
        # convert to float for a concise decimal representation
        # using float() is safe here because these are small rationals like halves, etc.
        val = float(fr)
        # Remove trailing .0 if it happens via conversion (shouldn't if denominator != 1)
        return str(val)

    # store values as Fractions for exact arithmetic
    days: Dict[str, Fraction] = {}
    days["Monday"] = Fraction(40)
    days["Tuesday"] = Fraction(50)
    # "50% fewer than Tuesday" interpreted as 50% of Tuesday (i.e., Tuesday * 0.5)
    days["Wednesday"] = days["Tuesday"] * Fraction(1, 2)
    days["Thursday"] = days["Monday"] + days["Wednesday"]

    # extras definitions (same semantics as your original code)
    extras: List[Tuple[str, Tuple]] = [
        ("Friday", ("direct", 30)),
        ("Saturday", ("percent_less", 50, "Thursday")),  # 50% fewer than Thursday
        ("Sunday", ("sum", ("Monday", "Saturday"))),     # sum of Monday and Saturday
        ("Next Monday", ("direct", 60)),
        ("Next Tuesday", ("percent_more", 20, "Next Monday")),
        ("Next Wednesday", ("sum", ("Friday", "Tuesday"))),
    ][:k]

    clauses_parts: List[str] = []
    for day, rule in extras:
        typ = rule[0]
        if typ == "direct":
            val = Fraction(int(rule[1]))
            days[day] = val
            clauses_parts.append(f"On {day} she rode {fmt(val)} kilometers.")
        elif typ == "percent_less":
            pct = int(rule[1])
            ref = rule[2]
            ref_val = days.get(ref, Fraction(0))
            # compute exact fraction: ref * (100 - pct) / 100
            val = ref_val * Fraction(100 - pct, 100)
            days[day] = val
            clauses_parts.append(f"On {day} she rode {pct}% fewer kilometers than {ref}.")
        elif typ == "percent_more":
            pct = int(rule[1])
            ref = rule[2]
            ref_val = days.get(ref, Fraction(0))
            val = ref_val * Fraction(100 + pct, 100)
            days[day] = val
            clauses_parts.append(f"On {day} she rode {pct}% more kilometers than {ref}.")
        elif typ == "sum":
            a, b = rule[1]
            val = days.get(a, Fraction(0)) + days.get(b, Fraction(0))
            days[day] = val
            clauses_parts.append(f"On {day} she rode as many kilometers as the sum of the kilometers from {a} and {b}.")
        else:
            raise ValueError(f"Unknown rule type: {typ}")

    # build the textual problem; use fmt to turn Fractions into readable numbers
    q = (
        "Natalia is riding a bicycle for the cycling competition. "
        f"On Monday she rode {fmt(days['Monday'])} kilometers and on Tuesday {fmt(days['Tuesday'])} kilometers. "
        f"On Wednesday she rode {fmt(days['Wednesday'])} kilometers and on Thursday she rode {fmt(days['Thursday'])} kilometers. "
        + " ".join(clauses_parts) +
        " How many kilometers did Natalia ride in total?"
    )

    # compute exact total (Fraction), then return as int if whole, else float
    total_exact: Fraction = sum(days.values(), start=Fraction(0))
    if total_exact.denominator == 1:
        total_out: Number = int(total_exact.numerator)
    else:
        total_out = float(total_exact)  # keep exact decimal representation (e.g., 527.5 -> 527.5)

    return q, total_out


def gen_omelets(k: int) -> Tuple[str, int]:
    """
    Extend the omelets time question by adding up to k extra prep/cleanup clauses (0..6).
    Returns: (question_text, total_minutes)
    """
    assert 0 <= k <= 6

    peppers = 4
    onions = 2
    omelets = 5

    time_chop_pepper = 3
    time_chop_onion = 4
    time_grate_cheese_per_omelet = 1
    time_cook_per_omelet = 5

    base_minutes = (
        peppers * time_chop_pepper
        + onions * time_chop_onion
        + omelets * time_grate_cheese_per_omelet
        + omelets * time_cook_per_omelet
    )

    extras: List[Tuple[str, int]] = [
        ("He also needs to chop 3 tomatoes, and it takes him 4 minutes to chop a tomato.", 3 * 4),
        (f"It takes him 2 minutes to slice enough mushrooms for one omelet, and he adds mushrooms to each of the {omelets} omelets.", omelets * 2),
        (f"It takes him 3 minutes to dice enough ham for one omelet, and he adds ham to each of the {omelets} omelets.", omelets * 3),
        (f"It takes him 1 minute to rinse a handful of spinach for one omelet, and he adds spinach to each of the {omelets} omelets.", omelets * 1),
        (f"It takes him 2 minutes to crack and whisk the eggs for one omelet, and he does this separately for each of the {omelets} omelets.", omelets * 2),
        ("After he finishes cooking, it takes him 6 minutes to wash the cutting board and knife.", 6),
    ]

    clauses = " ".join(c for c, _ in extras[:k])
    extras_minutes = sum(m for _, m in extras[:k])

    core = (
        "Bill is making omelets for his family's breakfast. "
        "It takes him 3 minutes to chop a pepper, 4 minutes to chop an onion, and 1 minute to grate enough cheese for one omelet. "
        "It takes him 5 minutes to assemble and cook one omelet. "
        f"He needs to chop up {peppers} peppers, chop {onions} onions, and grate cheese for cooking each of {omelets} omelets. "
    )
    q = core + (clauses + " " if clauses else "") + f"How long will he spend preparing for and cooking the {omelets} omelets?"

    return q, base_minutes + extras_minutes


def gen_gardens(k: int) -> Tuple[str, int]:
    """
    Extend the garden-area question by adding up to k extra people with gardens (0..6).
    Returns: (question_text, total_square_feet)
    """
    assert 0 <= k <= 6

    base_people = [
        ("Mancino", 3, 16, 5),
        ("Marquita", 2, 8, 4),
    ]
    base_area = sum(n * l * w for _, n, l, w in base_people)

    extras = [
        ("Luis", 4, 10, 6),
        ("Tasha", 1, 12, 9),
        ("Nia", 2, 7, 7),
        ("Omar", 3, 5, 14),
        ("Chen", 2, 9, 3),
        ("Priya", 1, 20, 4),
    ]
    extra_area = sum(n * l * w for _, n, l, w in extras[:k])

    extra_clauses = " ".join(
        f"{name} is working on {n} garden{'s' if n != 1 else ''} that "
        f"{'each measure' if n != 1 else 'measures'} {l} feet by {w} feet."
        for name, n, l, w in extras[:k]
    )

    core = (
        "Mancino is tending 3 gardens that each measure 16 feet by 5 feet. "
        "His sister, Marquita, is tilling the soil for two gardens that each measure 8 feet by 4 feet. "
    )
    q = core + (extra_clauses + " " if extra_clauses else "") + "How many square feet combined are in all their gardens?"

    return q, base_area + extra_area


def gen_punch(k: int) -> Tuple[str, int]:
    """
    Extend the punch servings question by adding up to k extra ingredients (0..6).
    All ingredient volumes are chosen so the final number of 10-oz servings is an integer.

    Returns: (question_text, num_servings)
    """
    assert 0 <= k <= 6

    base_total_oz = 6 * 12 + 28 + 40  # Mountain Dew + ice + fruit juice

    extras: List[Tuple[str, int]] = [
        ("She also adds 5 8-oz cans of ginger ale.", 5 * 8),
        ("Then she pours in 2 15-oz bottles of orange soda.", 2 * 15),
        ("Next, she adds a 20-oz bottle of cranberry juice.", 20),
        ("She also stirs in 3 scoops of sherbet that weigh 10 oz each.", 3 * 10),
        ("Finally, she adds 2 cups of pineapple chunks that weigh 10 oz per cup.", 2 * 10),
        ("For extra fizz, she adds a 50-oz bottle of sparkling water.", 50),
    ]

    total_oz = base_total_oz + sum(oz for _, oz in extras[:k])
    assert total_oz % 10 == 0

    extra_text = " ".join(txt for txt, _ in extras[:k])
    core = (
        "Carrie is making punch. "
        "She adds 6 12-oz cans of Mountain Dew, 28 oz of ice, and a 40 oz bottle of fruit juice. "
    )
    q = core + (extra_text + " " if extra_text else "") + "How many 10 oz servings of punch does Carrie have?"

    return q, total_oz // 10


def gen_bread_and_treats(k: int) -> Tuple[str, int]:
    """
    Extend the bread-and-treats comparison question by adding up to k extra people (0..6).
    Each added person has treats described by a comparison, and bread described as a comparison to their treats.

    Returns: (question_text, total_pieces_of_bread_and_treats)
    """
    assert 0 <= k <= 6

    # Base relationships
    wanda_bread = 90
    wanda_treats = wanda_bread // 3
    jane_treats = wanda_treats * 2
    jane_bread = int(Fraction(3, 4) * jane_treats)

    treats: Dict[str, int] = {"Wanda": wanda_treats, "Jane": jane_treats}
    bread: Dict[str, int] = {"Wanda": wanda_bread, "Jane": jane_bread}

    # (name, treats_fn(existing_treats)->treats, bread_fn(treats)->bread, clause_text)
    extras_defs = [
        ("Carlos",
         lambda t: int(Fraction(1, 3) * t["Jane"]),
         lambda tr: 2 * tr,
         "Carlos brings one-third as many treats as Jane and twice as many pieces of bread as treats."),
        ("Mia",
         lambda t: t["Wanda"] - 10,
         lambda tr: int(Fraction(1, 2) * tr),
         "Mia brings 10 fewer treats than Wanda and half as many pieces of bread as treats."),
        ("Omar",
         lambda t: int(Fraction(3, 2) * t["Wanda"]),
         lambda tr: int(Fraction(4, 5) * tr),
         "Omar brings one and a half times as many treats as Wanda and four-fifths as many pieces of bread as treats."),
        ("Priya",
         lambda t: 2 * t["Carlos"],
         lambda tr: int(Fraction(1, 4) * tr),
         "Priya brings twice as many treats as Carlos and one-fourth as many pieces of bread as treats."),
        ("Elena",
         lambda t: t["Mia"] + 5,
         lambda tr: 3 * tr,
         "Elena brings 5 more treats than Mia and three times as many pieces of bread as treats."),
        ("Ben",
         lambda t: int(Fraction(1, 2) * t["Jane"]),
         lambda tr: int(Fraction(3, 5) * tr),
         "Ben brings half as many treats as Jane and three-fifths as many pieces of bread as treats."),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, treat_fn, bread_fn, clause = extras_defs[i]
        tcount = int(treat_fn(treats))
        bcount = int(bread_fn(tcount))
        treats[name] = tcount
        bread[name] = bcount
        clauses.append(clause)

    included = ["Wanda", "Jane"] + [extras_defs[i][0] for i in range(k)]
    total_items = sum(treats[p] + bread[p] for p in included)

    extras_text = " ".join(clauses)
    if k == 0:
        tail = "what is the total number of pieces of bread and treats that Wanda and Jane brought to the zoo?"
    else:
        tail = "what is the total number of pieces of bread and treats that everyone mentioned brought to the zoo?"

    core = (
        "Jane brings 75% as many pieces of bread as treats to feed the live pets at the zoo. "
        "Wanda brings half as many treats as Jane and three times as many pieces of bread as treats. "
    )
    q = core + (extras_text + " " if extras_text else "") + "If Wanda brings 90 pieces of bread, " + tail

    return q, total_items


def gen_rehab_centers(k: int) -> Tuple[str, int]:
    """
    Extend the rehabilitation-centers comparison chain by adding up to k extra people (0..6),
    each described using a multiplier and/or offset comparison.

    Returns: (question_text, total_centers_visited)
    """
    assert 0 <= k <= 6

    centers: Dict[str, int] = {}
    centers["Lisa"] = 6
    centers["Jude"] = int(Fraction(1, 2) * centers["Lisa"])   # half as many as Lisa
    centers["Han"] = 2 * centers["Jude"] - 2                   # 2 less than twice Jude
    centers["Jane"] = 2 * centers["Han"] + 6                   # 6 more than twice Han

    extras_defs = [
        ("Omar",  lambda c: c["Jude"] + 3,
         "Omar visited 3 more rehabilitation centers than Jude did."),
        ("Priya", lambda c: 3 * c["Han"] - 1,
         "Priya visited 1 less than three times as many rehabilitation centers as Han did."),
        ("Chen",  lambda c: 2 * c["Lisa"],
         "Chen visited twice as many rehabilitation centers as Lisa did."),
        ("Nia",   lambda c: 2 * c["Jane"] - 4,
         "Nia visited 4 fewer than twice as many rehabilitation centers as Jane did."),
        ("Luis",  lambda c: int(Fraction(1, 2) * c["Nia"]) + 5,
         "Luis visited 5 more than half as many rehabilitation centers as Nia did."),
        ("Tasha", lambda c: 2 * c["Omar"] - 2,
         "Tasha visited 2 fewer than twice as many rehabilitation centers as Omar did."),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, fn, clause = extras_defs[i]
        centers[name] = int(fn(centers))
        clauses.append(clause)

    included = ["Lisa", "Jude", "Han", "Jane"] + [extras_defs[i][0] for i in range(k)]
    total = sum(centers[p] for p in included)

    extras_text = " ".join(clauses)
    core = (
        "When Jane visited Iran, she visited 6 more than twice as many rehabilitation centers as Han. "
        "Han visited 2 less than twice as many rehabilitation centers as Jude did. "
        "Jude visited half as many rehabilitation centers as Lisa did. "
    )
    q = (
        core
        + (extras_text + " " if extras_text else "")
        + "If Lisa visited 6 rehabilitation centers, how many rehabilitation centers did they all visit in total, assuming they are all different?"
    )

    return q, total


def gen_dog_time(k: int) -> Tuple[str, int]:
    """
    Extend the dog-activity time question by adding up to k extra activities (0..6).
    Some activities require simple calculation (fractions of time, distance/speed).

    Returns: (question_text, total_minutes)
    """
    assert 0 <= k <= 6

    bath = 20
    blow_dry = bath // 2
    walk_minutes = int(Fraction(3, 6) * 60)  # 3 miles at 6 mph

    base_minutes = bath + blow_dry + walk_minutes

    extras: List[Tuple[str, int]] = [
        ("After the bath, Marcus spends 8 minutes brushing his dog.", 8),
        ("He then spends half as long trimming her nails as he spent brushing her.", 4),
        ("Before the walk, he drives 9 miles to the trailhead at 30 miles per hour.", int(Fraction(9, 30) * 60)),
        ("At the trailhead, he spends 6 minutes filling her water bowl and giving her a treat.", 6),
        ("After the walk, he jogs 2 miles with her at 8 miles per hour.", int(Fraction(2, 8) * 60)),
        ("When they get home, he spends 10 minutes cleaning up muddy paw prints.", 10),
    ]

    extra_minutes = sum(m for _, m in extras[:k])

    # Keep the story in time order: first up through the trailhead (extras 0..3),
    # then the walk, then any after-walk activities (extras 4..5).
    pre_walk_text = " ".join(txt for txt, _ in extras[:min(k, 4)])
    post_walk_text = " ".join(txt for txt, _ in extras[4:k])

    core1 = "Marcus spends 20 minutes giving his dog a bath and half as long blow-drying her. "
    core2 = "Then he takes her for a walk along a 3-mile trail. If Marcus walks at 6 miles per hour, "
    q = (
        core1
        + (pre_walk_text + " " if pre_walk_text else "")
        + core2
        + (post_walk_text + " " if post_walk_text else "")
        + "how much time does he spend with his dog total?"
    )

    return q, base_minutes + extra_minutes


def gen_crayons(k: int) -> Tuple[str, int]:
    """
    Extend the crayons counting question by adding up to k extra box statements (0..6).
    Returns: (question_text, total_crayons)
    """
    assert 0 <= k <= 6

    base_total = 6 * 8 + 7 * 5 + 11

    extras = [
        ("green", 4, 6),
        ("purple", 2, 9),
        ("yellow", 3, 10),
        ("brown", 5, 4),
        ("black", 1, 12),
        ("pink", 8, 3),
    ]
    extra_total = sum(num * per for _, num, per in extras[:k])

    extra_text = " ".join(
        f"There are {num} box{'es' if num != 1 else ''} of crayons that "
        f"{'hold' if num != 1 else 'holds'} {per} {color} crayon{'s' if per != 1 else ''}."
        for color, num, per in extras[:k]
    )

    core = (
        "There are 6 boxes of crayons that hold 8 orange crayons. "
        "There are 7 boxes of crayons that have 5 blue crayons. "
        "There is 1 box of 11 red crayons. "
    )
    q = core + (extra_text + " " if extra_text else "") + "How many crayons are there in total?"

    return q, base_total + extra_total


def gen_zoo_tarantulas(k: int) -> Tuple[str, int]:
    """
    Extend the zoo-legs question by adding up to k extra animal sightings (0..6).
    The numbers are chosen so the remaining legs are always divisible by 8.

    Returns: (question_text, tarantulas_needed)
    """
    assert 0 <= k <= 6

    goal_legs = 1100
    base_animals = [("chimps", 12, 2), ("lions", 8, 4), ("lizards", 5, 4)]
    extras = [
        ("penguins", 8, 2),
        ("elephants", 6, 4),
        ("parrots", 12, 2),
        ("turtles", 4, 4),
        ("scorpions", 3, 8),
        ("frogs", 8, 4),
    ]

    legs_seen = sum(c * legs for _, c, legs in base_animals) + sum(c * legs for _, c, legs in extras[:k])
    remaining = goal_legs - legs_seen
    assert remaining % 8 == 0
    tarantulas_needed = remaining // 8

    extras_text = " ".join(f"He also saw {c} {name}." for name, c, _ in extras[:k])
    core = (
        "Borgnine wants to see 1100 legs at the zoo. "
        "He has already seen 12 chimps, 8 lions, and 5 lizards. "
    )
    q = core + (extras_text + " " if extras_text else "") + "He is next headed to see the tarantulas. How many tarantulas does he need to see to meet his goal?"

    return q, tarantulas_needed


def gen_stationery_lengths(k: int) -> Tuple[str, int]:
    """
    Extend the stationery-length chain by adding up to k extra stationery items (0..6).
    Returns: (question_text, total_length_cm)
    """
    assert 0 <= k <= 6

    lengths: Dict[str, int] = {}
    lengths["pencil"] = 12
    lengths["pen"] = lengths["pencil"] - 2
    lengths["rubber"] = lengths["pen"] - 3

    extras_defs = [
        ("marker", lambda L: L["pencil"] + 1, "A marker is 1 centimeter longer than the pencil."),
        ("crayon", lambda L: L["rubber"] - 2, "A crayon is 2 centimeters shorter than the rubber."),
        ("ruler", lambda L: 2 * L["pencil"], "A ruler is twice as long as the pencil."),
        ("paintbrush", lambda L: L["pen"] + 3, "A paintbrush is 3 centimeters longer than the pen."),
        ("stylus", lambda L: int(Fraction(1, 2) * L["ruler"]), "A stylus is half as long as the ruler."),
        ("paperclip chain", lambda L: L["stylus"] - 4, "A paperclip chain is 4 centimeters shorter than the stylus."),
    ]

    clauses: List[str] = []
    for i in range(k):
        item, fn, clause = extras_defs[i]
        key = item.replace(" ", "_")
        lengths[key] = int(fn(lengths))
        clauses.append(clause)

    included = ["rubber", "pen", "pencil"] + [extras_defs[i][0].replace(" ", "_") for i in range(k)]
    total_len = sum(lengths[key] for key in included)

    extras_text = " ".join(clauses)
    core = (
        "A pen is longer than the rubber by 3 centimeters, and shorter than the pencil by 2 centimeters. "
        "Considering the pencil is 12 centimeters long, "
    )
    q = core + (extras_text + " " if extras_text else "") + "what is the total length of the rubber, the pen, the pencil, and all the other stationery items mentioned altogether?"

    return q, total_len


def gen_consecutive_ages(k: int) -> Tuple[str, int]:
    """
    Extend the consecutive-ages question by adding up to k additional younger brothers (0..6).
    The sum of the three OLDEST brothers is fixed at 96, and their ages are consecutive integers.

    Returns: (question_text, youngest_age)
    """
    assert 0 <= k <= 6

    sum_three_oldest = 96
    youngest_of_three_oldest = (sum_three_oldest - 3) // 3  # x+(x+1)+(x+2) = 96
    assert 3 * youngest_of_three_oldest + 3 == sum_three_oldest

    younger_names = ["Liam", "Noah", "Owen", "Ethan", "Mason", "Lucas"]

    current_youngest = youngest_of_three_oldest
    clauses: List[str] = []
    for i in range(k):
        name = younger_names[i]
        if i == 0:
            ref = "the youngest of the three oldest brothers"
        else:
            ref = younger_names[i - 1]
        clauses.append(f"They also have a younger brother named {name} who is 1 year younger than {ref}.")
        current_youngest -= 1

    q = (
        "The ages of the three oldest brothers are consecutive integers with the sum of 96. "
        + (" ".join(clauses) + " " if clauses else "")
        + "How old is the youngest brother?"
    )

    return q, current_youngest


def gen_spelling_contest(k: int) -> Tuple[str, int]:
    """
    Extend the spelling-contest question by adding up to k extra competitors (0..6).

    Returns: (question_text, total_questions_asked)

    Interpretation: since each competitor was asked a different set of questions,
    the total number of questions asked in the competition equals the sum of
    (correct + wrong) over all competitors mentioned.
    """
    assert 0 <= k <= 6

    correct: Dict[str, int] = {}
    wrong: Dict[str, int] = {}

    # Base competitors
    correct["Drew"] = 20
    wrong["Drew"] = 6

    correct["Carla"] = 14
    wrong["Carla"] = 2 * wrong["Drew"]  # twice as many wrong as Drew got wrong

    # Extra competitors: (name, compute_fn, clause_text)
    # NOTE: All extra competitors have fewer than 20 correct so Drew still "won the competition".
    extras_defs = [
        (
            "Maya",
            lambda c, w: (c["Drew"] - 3, w["Drew"] + 2),
            "Another competitor, Maya, got 3 fewer questions correct than Drew and 2 more questions wrong than Drew.",
        ),
        (
            "Eli",
            lambda c, w: (c["Carla"] + 5, w["Carla"] // 2),
            "Another competitor, Eli, got 5 more questions correct than Carla and half as many questions wrong as Carla.",
        ),
        (
            "Sasha",
            lambda c, w: (2 * w["Drew"], w["Carla"] - 4),
            "Another competitor, Sasha, got twice as many questions correct as the number of questions Drew got wrong, and she got 4 fewer questions wrong than Carla.",
        ),
        (
            "Noah",
            lambda c, w: (c["Carla"], w["Drew"] - 3),
            "Another competitor, Noah, got the same number of questions correct as Carla and 3 fewer questions wrong than Drew.",
        ),
        (
            "Priya",
            lambda c, w: (c["Eli"] - 2, w["Sasha"] + 1),
            "Another competitor, Priya, got 2 fewer questions correct than Eli and 1 more question wrong than Sasha.",
        ),
        (
            "Jamal",
            lambda c, w: (c["Maya"] - 4, 2 * w["Noah"]),
            "Another competitor, Jamal, got 4 fewer questions correct than Maya and twice as many questions wrong as Noah.",
        ),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, fn, clause = extras_defs[i]
        c_val, w_val = fn(correct, wrong)
        correct[name] = int(c_val)
        wrong[name] = int(w_val)
        clauses.append(clause)

        # sanity checks
        assert correct[name] >= 0 and wrong[name] >= 0
        assert correct[name] < correct["Drew"]  # keep Drew as clear winner

    # total questions asked (each competitor got different questions)
    names = ["Drew", "Carla"] + [extras_defs[i][0] for i in range(k)]
    total_questions = sum(correct[n] + wrong[n] for n in names)

    core = (
        "In a spelling contest held in her school, Drew got 20 questions correct, winning her the competition. "
        "She got six questions wrong, and her competitor Carla got 14 questions correct and twice as many questions wrong "
        "as the number of questions Drew got wrong. "
    )
    extras_text = (" ".join(clauses) + " ") if clauses else ""
    q = core + extras_text + "If each competitor was asked a different question, how many questions were asked in the competition?"

    return q, total_questions


def gen_bike_year_phases(k: int) -> Tuple[str, int]:
    """
    Extend the biking-in-a-year question by adding up to k additional riding phases (0..6).
    The year is treated as 365 days.

    Returns: (question_text, total_miles)
    """
    assert 0 <= k <= 6

    year_days = 365
    first_days = 183
    first_rate = 30

    # Extra phases happen AFTER the first 183 days, before the final "remaining days" phase.
    # (days_in_phase, miles_per_day)
    extras_phases: List[Tuple[int, int]] = [
        (20, 25),
        (30, 40),
        (25, 20),
        (15, 50),
        (10, 33),
        (12, 45),
    ]

    used_extra_days = sum(d for d, _ in extras_phases[:k])
    remaining_days = year_days - first_days - used_extra_days
    assert remaining_days >= 0

    last_rate = 35

    total_miles = first_days * first_rate
    total_miles += sum(d * r for d, r in extras_phases[:k])
    total_miles += remaining_days * last_rate

    clauses = []
    for d, r in extras_phases[:k]:
        clauses.append(f"For the next {d} days of the year, he rides {r} miles per day.")
    extras_text = (" ".join(clauses) + " ") if clauses else ""

    q = (
        f"Tom puts {first_rate} miles per day on his bike for the first {first_days} days of the year. "
        + extras_text
        + f"For the remaining {remaining_days} days of the year he goes {last_rate} miles per day. "
        "How many miles does he ride for the year?"
    )

    return q, total_miles


def gen_two_meals_fruit(k: int) -> Tuple[str, int]:
    """
    Extend the two-meals fruit question by adding up to k new kinds of fruit (0..6).
    Each added fruit includes:
      - how many were eaten last night, and
      - how many are planned for today, described via a comparison (like the original).

    Returns: (question_text, total_fruits_across_two_meals)
    """
    assert 0 <= k <= 6

    last: Dict[str, int] = {"apples": 3, "bananas": 1, "oranges": 4}

    today: Dict[str, int] = {}
    today["apples"] = last["apples"] + 4
    today["bananas"] = 10 * last["bananas"]
    today["oranges"] = 2 * today["apples"]

    # Extra fruit definitions (in order). Each item:
    # (fruit_key, last_night_count, today_compute_fn, clause_text)
    extras_defs = [
        (
            "pears",
            2,
            lambda L, T: T["apples"] - 2,
            "Last night she also had 2 pears. Today she is planning to have 2 fewer pears than the number of apples she'll have today.",
        ),
        (
            "peaches",
            3,
            lambda L, T: 3 * L["peaches"],
            "Last night she also had 3 peaches. Today she is planning to have three times as many peaches as she ate last night.",
        ),
        (
            "grapes",
            8,
            lambda L, T: T["oranges"] // 2,
            "Last night she also ate 8 grapes. Today she is planning to have half as many grapes as the oranges she'll have today.",
        ),
        (
            "plums",
            5,
            lambda L, T: L["bananas"] + 1,
            "Last night she also had 5 plums. Today she is planning to have 1 more plum than the number of bananas she ate last night.",
        ),
        (
            "strawberries",
            10,
            lambda L, T: 2 * T["pears"],
            "Last night she also had 10 strawberries. Today she is planning to have twice as many strawberries as the pears she'll have today.",
        ),
        (
            "kiwis",
            4,
            lambda L, T: (T["apples"] + T["bananas"]) - 3,
            "Last night she also had 4 kiwis. Today she is planning to have 3 fewer kiwis than the total number of apples and bananas she'll have today.",
        ),
    ]

    clauses: List[str] = []
    for i in range(k):
        fruit, last_count, today_fn, clause = extras_defs[i]
        last[fruit] = int(last_count)
        today[fruit] = int(today_fn(last, today))
        clauses.append(clause)

        assert last[fruit] >= 0 and today[fruit] >= 0

    total_last = sum(last.values())
    total_today = sum(today.values())
    total_two_meals = total_last + total_today

    core = (
        "Last night Kannon had 3 apples, a banana, and 4 oranges for dinner. "
        "Today she is planning to have 4 more apples than last night, 10 times as many bananas as she ate last night, "
        "and twice as many oranges as apples she'll have today. "
    )
    extras_text = (" ".join(clauses) + " ") if clauses else ""
    q = core + extras_text + "How many fruits would she have eaten so far in the two meals?"

    return q, total_two_meals


def gen_run_two_arenas(k: int) -> Tuple[str, int]:
    """
    Extend the running-around-two-arenas question by adding up to k extra runners (0..6).

    Returns: (question_text, total_time_hours)
    """
    assert 0 <= k <= 6

    times: Dict[str, int] = {}
    times["Emma"] = 20
    times["Fernando"] = 2 * times["Emma"]

    extras_defs = [
        ("Gina",   lambda t: t["Fernando"] // 2,
         "Gina takes half as long as Fernando to run all around the two arenas."),
        ("Harold", lambda t: t["Emma"] + 10,
         "Harold takes 10 hours more than Emma to run all around the two arenas."),
        ("Ivy",    lambda t: t["Fernando"] - 5,
         "Ivy takes 5 hours less than Fernando to run all around the two arenas."),
        ("Jack",   lambda t: 3 * t["Gina"],
         "Jack takes three times as long as Gina to run all around the two arenas."),
        ("Kara",   lambda t: t["Jack"] // 4,
         "Kara takes one-fourth as long as Jack to run all around the two arenas."),
        ("Luis",   lambda t: t["Kara"] + 2,
         "Luis takes 2 hours more than Kara to run all around the two arenas."),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, fn, clause = extras_defs[i]
        times[name] = int(fn(times))
        clauses.append(clause)
        assert times[name] >= 0

    names = ["Emma", "Fernando"] + [extras_defs[i][0] for i in range(k)]
    total_time = sum(times[n] for n in names)

    extras_text = (" ".join(clauses) + " ") if clauses else ""
    tail = (
        "what's the total time it takes both of them to run all around the two arenas?"
        if k == 0
        else "what's the total time it takes everyone mentioned to run all around the two arenas?"
    )

    q = (
        "Emma can run all-around two arenas in 20 hours. "
        "If it takes Fernando twice as long to run his course, "
        + extras_text
        + tail
    )

    return q, total_time


def gen_fruit_grower_money(k: int) -> Tuple[str, int]:
    """
    Extend the fruit-grower revenue question by adding up to k extra fruit types (0..6).
    All fruits sell for the same $50 per kg.

    Returns: (question_text, total_money_dollars)
    """
    assert 0 <= k <= 6

    kg: Dict[str, int] = {}
    kg["mangoes"] = 400
    kg["apples"] = 2 * kg["mangoes"]
    kg["oranges"] = kg["mangoes"] + 200

    price_per_kg = 50

    extras_defs = [
        ("bananas", lambda w: w["apples"] - 100,
         "He also produced 100 kg fewer bananas than apples."),
        ("peaches", lambda w: w["oranges"] // 2,
         "The total produce of peaches was half the total produce of oranges."),
        ("guavas", lambda w: w["mangoes"] // 4,
         "He harvested one-fourth as many guavas as mangoes."),
        ("lemons", lambda w: w["peaches"] + 50,
         "The total produce of lemons was 50 kg more than the total produce of peaches."),
        ("plums", lambda w: 2 * w["guavas"],
         "He produced twice as many plums as guavas."),
        ("cherries", lambda w: w["oranges"] - 150,
         "He harvested 150 kg fewer cherries than oranges."),
    ]

    clauses: List[str] = []
    for i in range(k):
        fruit, fn, clause = extras_defs[i]
        kg[fruit] = int(fn(kg))
        clauses.append(clause)
        assert kg[fruit] >= 0

    total_kg = sum(kg.values())
    total_money = total_kg * price_per_kg

    extras_text = (" ".join(clauses) + " ") if clauses else ""
    q = (
        "Mr. Sergio is a fruit grower producing apples, mangoes and oranges. "
        "In a particular season, the total produce of apples was twice the total produce of mangoes, "
        "and the total produce of oranges was 200 kg more than that of mangoes. "
        + extras_text +
        "If his farm produced 400 kg of mangoes and he sold the fruits at $50 per kg, "
        "calculate the total amount of money he got in that season."
    )

    return q, total_money


def gen_opal_savings_rounds(k: int) -> Tuple[str, int]:
    """
    Extend the horse-race savings question by adding up to k additional betting rounds (0..6).

    Process:
      - Start with $100 winnings.
      - She saves half of her current money and bets the other half.
      - The bet returns a profit percentage; her "earnings" are bet * (1 + profit).
      - She saves half of those earnings.
      - For any further rounds, she bets the money she did NOT put into savings (the other half of earnings).

    Returns: (question_text, total_saved_dollars)
    """
    assert 0 <= k <= 6

    start = Fraction(100, 1)

    # Base: first bet profit is 60%
    base_profit = Fraction(3, 5)  # 60%
    saved = Fraction(0, 1)

    # Round 1: save half of winnings, bet half
    saved += start / 2
    bet = start / 2

    # Round 2: 60% profit, then save half of earnings
    earnings = bet * (1 + base_profit)
    saved += earnings / 2
    remaining_to_bet = earnings / 2

    extras = [
        (Fraction(1, 4), "25%"),
        (Fraction(1, 5), "20%"),
        (Fraction(1, 1), "100%"),
        (Fraction(3, 5), "60%"),
        (Fraction(1, 2), "50%"),
        (Fraction(1, 3), "33 1/3%"),
    ]

    clauses: List[str] = []
    for i in range(k):
        prof, label = extras[i]
        clauses.append(
            f"She then bet the money she did not put into savings, made a {label} profit, "
            "and again put half of her new earnings into savings."
        )
        earnings_i = remaining_to_bet * (1 + prof)
        saved += earnings_i / 2
        remaining_to_bet = earnings_i / 2

    # We designed profits so saved stays an integer number of dollars
    assert saved.denominator == 1
    total_saved = int(saved)

    extras_text = (" ".join(clauses) + " ") if clauses else ""
    q = (
        "Opal won $100.00 betting on a horse race. "
        "She put half of her winnings into savings and bet the other half of her winnings. "
        "This time, she made a 60% profit and again, put half of her earnings into savings. "
        + extras_text +
        "How much money in total did she put into her savings?"
    )

    return q, total_saved


def gen_country_club_expenditures(k: int) -> Tuple[str, int]:
    """
    Extend the country club cost question by adding up to k extra expenditure types (0..6).

    Returns: (question_text, john_pays_first_year_dollars)
    """
    assert 0 <= k <= 6

    family_members = 4  # John + 3 other family members
    join_fee_per_person = 4000
    monthly_cost_per_person = 1000
    months = 12

    total_cost = family_members * join_fee_per_person + months * family_members * monthly_cost_per_person

    extras_defs = [
        ("There is also an annual maintenance fee of $600 per person.", family_members * 600),
        ("Each person also pays $50 per month for gym access.", family_members * 50 * months),
        ("The family buys 2 guest passes each month at $30 each.", 2 * 30 * months),
        ("There is a one-time processing fee of $500 for the whole family.", 500),
        ("They rent lockers for 3 of the members at $120 per locker for the year.", 3 * 120),
        ("They spend $200 every quarter on club events for the family.", 200 * 4),
    ]

    clauses = [txt for txt, _ in extras_defs[:k]]
    total_cost += sum(cost for _, cost in extras_defs[:k])

    # John pays half the total
    assert total_cost % 2 == 0
    john_pays = total_cost // 2

    extras_text = (" ".join(clauses) + " ") if clauses else ""
    q = (
        "John joins a country club with 3 other members of his family. "
        "The fee to join is $4000 per person. "
        "There is also a monthly cost of $1000 per person. "
        + extras_text +
        "John agrees to pay half the total cost for the whole family. "
        "How much does John pay for the first year?"
    )

    return q, john_pays


def gen_national_parks(k: int) -> Tuple[str, int]:
    """
    Extend the original 2-park question by adding up to k extra TYPES of animals (0..6),
    while keeping the same two parks: Safari National park and Savanna National park.

    The question always asks for the TOTAL number of animals in Savanna National park,
    counting lions/snakes/giraffes plus any added animal types.

    Returns: (question_text, total_animals_in_savanna)
    """
    assert 0 <= k <= 6

    safari: Dict[str, int] = {}
    savanna: Dict[str, int] = {}

    # --- Base animals (from the original problem) ---
    safari["lions"] = 100
    safari["snakes"] = safari["lions"] // 2
    safari["giraffes"] = safari["snakes"] - 10

    savanna["lions"] = 2 * safari["lions"]
    savanna["snakes"] = 3 * safari["snakes"]
    savanna["giraffes"] = safari["giraffes"] + 20

    # --- Extra animal types (add up to k) ---
    # Each entry: (animal_key, safari_count_fn, savanna_count_fn, text_clause)
    extras_defs = [
        (
            "zebras",
            lambda S, V: 2 * S["giraffes"],
            lambda S, V: V["lions"] // 2,
            "Safari National park also has twice as many zebras as giraffes. "
            "Savanna National park has half as many zebras as lions.",
        ),
        (
            "elephants",
            lambda S, V: S["snakes"] - 5,
            lambda S, V: S["elephants"] + 20,
            "Safari National park also has 5 fewer elephants than snakes. "
            "Savanna National park has 20 more elephants than Safari National park.",
        ),
        (
            "monkeys",
            lambda S, V: 3 * S["giraffes"],
            lambda S, V: V["zebras"] - 10,
            "Safari National park also has three times as many monkeys as giraffes. "
            "Savanna National park has 10 fewer monkeys than its zebras.",
        ),
        (
            "crocodiles",
            lambda S, V: S["lions"] // 5,
            lambda S, V: 2 * S["crocodiles"],
            "Safari National park also has one-fifth as many crocodiles as lions. "
            "Savanna National park has twice as many crocodiles as Safari National park.",
        ),
        (
            "flamingos",
            lambda S, V: S["giraffes"] - 10,
            lambda S, V: 3 * S["flamingos"],
            "Safari National park also has 10 fewer flamingos than giraffes. "
            "Savanna National park has three times as many flamingos as Safari National park.",
        ),
        (
            "hippos",
            lambda S, V: S["elephants"],
            lambda S, V: V["crocodiles"] + 5,
            "Safari National park also has as many hippos as elephants. "
            "Savanna National park has 5 more hippos than its crocodiles.",
        ),
    ]

    clauses: List[str] = []
    for i in range(k):
        animal, safari_fn, savanna_fn, clause = extras_defs[i]
        safari[animal] = int(safari_fn(safari, savanna))
        savanna[animal] = int(savanna_fn(safari, savanna))
        clauses.append(clause)

        # basic safety checks (non-negative counts)
        assert safari[animal] >= 0
        assert savanna[animal] >= 0

    extras_text = (" ".join(clauses) + " ") if clauses else ""

    q = (
        "Safari National park has 100 lions, half as many snakes, and 10 fewer giraffes than snakes. "
        "On the other hand, Savanna National park has double as many lions as Safari National park, "
        "triple as many snakes as Safari National park, and 20 more giraffes than Safari National park. "
        + extras_text
        + "How many animals in total does Savanna National park have?"
    )

    total_animals_in_savanna = sum(savanna.values())
    return q, total_animals_in_savanna


def gen_zoo_animals_spend(k: int) -> Tuple[str, int]:
    """
    Extend the zoo-animal purchase question by adding up to k extra animal types (0..6).

    Returns: (question_text, total_spent_dollars)
    """
    assert 0 <= k <= 6

    qty: Dict[str, int] = {}
    cost_each: Dict[str, int] = {}

    # Base animals
    qty["goats"] = 3
    cost_each["goats"] = 400

    qty["llamas"] = 2 * qty["goats"]
    cost_each["llamas"] = int(cost_each["goats"] * 1.5)  # 50% more each

    extras_defs = [
        (
            "alpacas",
            lambda q, c: (4, int(c["llamas"] * 0.75)),  # 25% less than a llama
            "He also buys 4 alpacas that cost 25% less each than a llama.",
        ),
        (
            "zebras",
            lambda q, c: (2, 2 * c["goats"]),
            "He buys 2 zebras that cost twice as much as a goat each.",
        ),
        (
            "ponies",
            lambda q, c: (q["goats"], c["goats"] + 100),
            "He buys as many ponies as goats, and each pony costs $100 more than a goat.",
        ),
        (
            "peacocks",
            lambda q, c: (5, c["ponies"] // 2),
            "He buys 5 peacocks that cost 50% less each than a pony.",
        ),
        (
            "kangaroos",
            lambda q, c: (3, int(c["zebras"] * 1.2)),
            "He buys 3 kangaroos that cost 20% more each than a zebra.",
        ),
        (
            "rabbits",
            lambda q, c: (8, 75),
            "He buys 8 rabbits for $75 each.",
        ),
    ]

    clauses: List[str] = []
    for i in range(k):
        animal, fn, clause = extras_defs[i]
        q_i, c_i = fn(qty, cost_each)
        qty[animal] = int(q_i)
        cost_each[animal] = int(c_i)
        clauses.append(clause)

        assert qty[animal] >= 0 and cost_each[animal] >= 0

    total_spent = sum(qty[a] * cost_each[a] for a in qty.keys())

    extras_text = (" ".join(clauses) + " ") if clauses else ""
    q = (
        "Tim decides to get animals for his zoo. "
        "He buys 3 goats for $400 each. "
        "He gets twice as many llamas which cost 50% more each. "
        + extras_text +
        "How much did he spend?"
    )

    return q, total_spent


def gen_circus_entertainers(k: int) -> Tuple[str, int]:
    """
    Extend the circus entertainers standing-time question by adding up to k extra entertainers (0..6).

    Returns: (question_text, total_minutes)
    """
    assert 0 <= k <= 6

    minutes: Dict[str, int] = {}
    minutes["Pulsar"] = 10
    minutes["Polly"] = 3 * minutes["Pulsar"]
    minutes["Petra"] = minutes["Polly"] // 6

    extras_defs = [
        ("Marvin", lambda m: m["Polly"] // 2,
         "Then, Marvin the mischievous monkey stands on his back legs for half as long as Polly."),
        ("Zelda", lambda m: m["Pulsar"] + 4,
         "Next, Zelda the zany zebra stands on her back legs for 4 minutes more than Pulsar."),
        ("Trixie", lambda m: 2 * m["Petra"],
         "After that, Trixie the tiny tiger stands on her back legs for twice as long as Petra."),
        ("Wally", lambda m: m["Marvin"] // 3,
         "Then, Wally the waddling walrus stands on his back legs for one-third as long as Marvin."),
        ("Greta", lambda m: m["Zelda"] - 2,
         "Next, Greta the graceful giraffe stands on her back legs for 2 minutes less than Zelda."),
        ("Hector", lambda m: 3 * m["Wally"],
         "Finally, Hector the hopping hare stands on his back legs for three times as long as Wally."),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, fn, clause = extras_defs[i]
        minutes[name] = int(fn(minutes))
        clauses.append(clause)
        assert minutes[name] >= 0

    total_minutes = sum(minutes.values())

    extras_text = (" ".join(clauses) + " ") if clauses else ""
    tail = (
        "that the three entertainers stand on their back legs?"
        if k == 0
        else "that all the entertainers mentioned stand on their back legs?"
    )

    q = (
        "Pulsar, the shuffle-dancing bear, Polly, the pink prancing poodle, and Petra, the proud portly pachyderm, "
        "are entertainers at the Big Top Circus. "
        "In one show, Pulsar stands on his two back legs for a total of 10 minutes. "
        "Then, Polly stands on her back legs for three times as long as Pulsar. "
        "And then, finally, Petra stands on his back legs for one-sixth as long as Polly. "
        + extras_text +
        "What is the combined length of time, in minutes, "
        + tail
    )

    return q, total_minutes


def gen_sailboat_sails(k: int) -> Tuple[str, int]:
    """
    Extend the sailboat-distance question by adding up to k extra sail-configuration travel segments (0..6).

    Returns: (question_text, total_distance_in_land_miles)

    Notes:
    - 1 knot = 1 nautical mile per hour
    - 1 nautical mile = 1.15 land miles (= 23/20 land miles)
    - The added segments are chosen so the final land-mile distance is an integer.
    """
    assert 0 <= k <= 6

    # Base speeds (knots)
    speed_one_sail = 25
    speed_two_sails = 50

    # Base travel times (hours)
    base_hours_one = 4
    base_hours_two = 4

    nm_to_land = Fraction(23, 20)  # 1.15 exactly

    base_nm = speed_one_sail * base_hours_one + speed_two_sails * base_hours_two

    # Each extra clause adds (speed * time) nautical miles; chosen so totals stay multiples of 20 nm.
    extras: List[Tuple[str, int]] = [
        (
            "Under the same conditions, the use of three sails brings the ship to the speed of 60 knots, "
            "and she sails for 1 hour using three sails.",
            60 * 1,
        ),
        (
            "Under the same conditions, the use of four sails brings the ship to the speed of 70 knots, "
            "and she sails for 2 hours using four sails.",
            70 * 2,
        ),
        (
            "Under the same conditions, sailing with one reefed sail brings her ship to the speed of 15 knots, "
            "and she sails for 4 hours with a reefed sail.",
            15 * 4,
        ),
        (
            "Under the same conditions, the use of two sails while carrying extra cargo brings her ship to the speed of 40 knots, "
            "and she sails for 1 hour using two sails with the extra cargo.",
            40 * 1,
        ),
        (
            "Under the same conditions, the use of one sail with a favorable current brings her ship to the speed of 30 knots, "
            "and she sails for 2 hours with that setup.",
            30 * 2,
        ),
        (
            "Under the same conditions, if she lowers all sails and just drifts, her ship moves at 5 knots, "
            "and she drifts for 4 hours.",
            5 * 4,
        ),
    ]

    extra_text = " ".join(clause for clause, _ in extras[:k])
    extra_nm = sum(dist_nm for _, dist_nm in extras[:k])

    total_nm = base_nm + extra_nm
    total_land = total_nm * nm_to_land
    assert total_land.denominator == 1  # by construction

    q = (
        "Buffy has a sailboat with a planing hull that she sails in the Gulf of Mexico. "
        "Under optimal wind conditions, the use of two sails brings the ship to the speed of 50 knots, "
        "whereas under the same conditions, the use of one sail brings her ship to the speed of 25 knots. "
        "A knot is a nautical term for speed equal to 1 nautical mile per hour, and one nautical mile equals 1.15 land miles. "
        f"If she travels in her sailboat under optimal wind conditions for {base_hours_one} hours with one sail "
        f"and then for another {base_hours_two} hours with two sails, "
    )
    q += (extra_text + " " if extra_text else "")
    q += "what is the total distance, in land miles, that she will travel?"

    return q, int(total_land)


def gen_market_remaining(k: int) -> Tuple[str, int]:
    """
    Extend Peter's market purchases by adding up to k additional items (0..6).

    Returns: (question_text, remaining_money_dollars)
    """
    assert 0 <= k <= 6

    start_money = 500

    base_items: List[Tuple[str, int]] = [
        ("6 kilos of potatoes for $2 per kilo", 6 * 2),
        ("9 kilos of tomato for $3 per kilo", 9 * 3),
        ("5 kilos of cucumbers for $4 per kilo", 5 * 4),
        ("3 kilos of bananas for $5 per kilo", 3 * 5),
    ]
    base_cost = sum(cost for _, cost in base_items)

    extras: List[Tuple[str, int]] = [
        ("4 kilos of rice for $6 per kilo", 4 * 6),
        ("2 kilos of onions for $3 per kilo", 2 * 3),
        ("5 kilos of apples for $4 per kilo", 5 * 4),
        ("3 kilos of chicken for $8 per kilo", 3 * 8),
        ("10 loaves of bread for $2 each", 10 * 2),
        ("2 bottles of olive oil for $11 each", 2 * 11),
    ]

    extra_cost = sum(cost for _, cost in extras[:k])
    total_cost = base_cost + extra_cost
    remaining = start_money - total_cost
    assert remaining >= 0

    base_text = (
        "Peter carried $500 to the market. "
        "He bought 6 kilos of potatoes for $2 per kilo, 9 kilos of tomato for $3 per kilo, "
        "5 kilos of cucumbers for $4 per kilo, and 3 kilos of bananas for $5 per kilo. "
    )

    extra_text = ""
    if k > 0:
        extra_list = ", ".join(desc for desc, _ in extras[:k])
        extra_text = f"He also bought {extra_list}. "

    q = base_text + extra_text + "How much is Peter's remaining money?"
    return q, remaining


def gen_cafeteria(k: int) -> Tuple[str, int]:
    """
    Extend the cafeteria bill question by adding up to k additional friend/group order clauses (0..6).

    Returns: (question_text, total_cost_dollars)
    """
    assert 0 <= k <= 6

    coffee = 4
    cake = 7
    ice_cream = 3

    # Base orders
    mell_cost = 2 * coffee + 1 * cake
    friend_cost_each = mell_cost + ice_cream
    base_total = mell_cost + 2 * friend_cost_each

    extras: List[Tuple[str, int]] = [
        (
            "One more friend, Nina, ordered the same as Mell, but she also bought 2 bowls of ice cream.",
            mell_cost + 2 * ice_cream,
        ),
        (
            "Two more friends each ordered 1 cup of coffee and 2 bowls of ice cream.",
            2 * (1 * coffee + 2 * ice_cream),
        ),
        (
            "Another friend, Omar, ordered 3 cups of coffee and 1 piece of cake.",
            3 * coffee + cake,
        ),
        (
            "Three of her friends each ordered 1 cup of coffee and 1 piece of cake.",
            3 * (coffee + cake),
        ),
        (
            "One friend, Ava, only bought 4 bowls of ice cream.",
            4 * ice_cream,
        ),
        (
            "Two friends placed one combined order of 5 cups of coffee and 2 pieces of cake.",
            5 * coffee + 2 * cake,
        ),
    ]

    extras_text = " ".join(clause for clause, _ in extras[:k])
    extra_cost = sum(cost for _, cost in extras[:k])
    total_cost = base_total + extra_cost

    q = (
        "Mell went to a cafeteria to spend some time with her friends. "
        "She ordered two cups of coffee and one piece of cake. "
        "Two of her friends ordered the same, but each of them also bought a bowl of ice cream. "
        f"One cup of coffee is ${coffee}, one piece of cake is ${cake}, and a bowl of ice cream is ${ice_cream}. "
    )
    q += (extras_text + " " if extras_text else "")
    q += "How much money did Mell and her friends need to pay at the cafeteria?"

    return q, total_cost


def gen_lasagna(k: int) -> Tuple[str, int]:
    """
    Extend the lasagna pieces question by adding up to k additional people (0..6).
    The answer is the total number of pieces Manny should cut the lasagna into.

    Interpretation used (spelled out by the added clauses):
    - Manny cuts the lasagna into equal pieces.
    - Manny eats exactly 1 whole piece.
    - If someone eats only part of a piece, they still need a piece to be served,
      and someone else is described as eating the leftover part.

    Returns: (question_text, pieces_to_cut)
    """
    assert 0 <= k <= 6

    # Pieces that must exist (served pieces), not necessarily fully eaten by that person.
    pieces_needed: Dict[str, int] = {
        "Manny": 1,     # eats 1 full piece
        "Kai": 2,       # eats 2 pieces (twice Manny)
        "Raphael": 1,   # served 1 piece but eats only half
        "Lisa": 2,      # eats 2 full pieces (plus leftovers from others)
        "Aaron": 0,     # eats no lasagna
    }

    extras_defs: List[Tuple[str, int, str]] = [
        (
            "Nora",
            1,
            "Another friend, Nora, is picky and will only eat one-third of what Manny eats, "
            "and Lisa will eat whatever Nora leaves of her piece.",
        ),
        (
            "Ben",
            2,
            "Another friend, Ben, will eat the same amount of lasagna as Kai.",
        ),
        (
            "Tina",
            3,
            "Another friend, Tina, loves lasagna and will eat 3 pieces.",
        ),
        (
            "Omar",
            1,
            "Another friend, Omar, will eat half as much as Kai.",
        ),
        (
            "Zoe",
            2,
            "Another friend, Zoe, will eat one more piece of lasagna than Manny.",
        ),
        (
            "Gus",
            1,
            "Another friend, Gus, will only taste lasagna and eat one-fourth of what Manny eats, "
            "and Kai will finish whatever Gus leaves of his piece.",
        ),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, pcs, clause = extras_defs[i]
        pieces_needed[name] = pcs
        clauses.append(clause)

    pieces_to_cut = sum(pieces_needed.values())

    extras_text = " ".join(clauses)
    q = (
        "Manny is making lasagna for dinner with his four friends, Lisa, Raphael, Aaron, and Kai. "
        "He needs to know how many pieces to cut the lasagna into to serve it. "
        "Manny only wants one piece. "
        "Aaron doesn't like lasagna much and will probably only eat garlic bread and salad. "
        "Kai is always hungry and will eat twice as much as Manny. "
        "Raphael always eats half the amount Manny does, but his sister Lisa loves lasagna and will eat two pieces, "
        "plus any Raphael has left of his piece. "
    )
    q += (extras_text + " " if extras_text else "")
    q += "How many pieces should Manny cut his lasagna into?"

    return q, pieces_to_cut


def gen_tshirts_sectors(k: int) -> Tuple[str, int]:
    """
    Extend the t-shirt purchase problem by adding up to k additional color sectors (0..6).

    - Women's t-shirts are always $5 cheaper than men's t-shirts of the same color.
    - For a unique, integer answer, the generated question states that EACH sector has 10 men and 10 women.

    Returns: (question_text, total_spent_dollars)
    """
    assert 0 <= k <= 6

    men_prices: Dict[str, int] = {"white": 20, "black": 18}
    extras: List[Tuple[str, int]] = [
        ("blue", 22),
        ("red", 21),
        ("green", 19),
        ("gray", 17),
        ("yellow", 16),
        ("purple", 23),
    ]

    for color, price in extras[:k]:
        men_prices[color] = price

    sector_colors = ["white", "black"] + [c for c, _ in extras[:k]]
    men_per_sector = 10
    women_per_sector = 10
    total_employees = (men_per_sector + women_per_sector) * len(sector_colors)

    total_spent = 0
    for color in sector_colors:
        mp = men_prices[color]
        wp = mp - 5
        total_spent += men_per_sector * mp + women_per_sector * wp

    head = (
        "Juan bought T-shirts for his employees. He bought shirts for men and women. "
        "Women's t-shirts are $5 cheaper than men's t-shirts of the same color. "
    )

    if k == 0:
        sectors_sentence = (
            "His company has 2 sectors, one in white t-shirts and the other in black t-shirts. "
            "He paid $20 for white men's t-shirts and $18 for black men's t-shirts. "
            "The 2 sectors have the same number of men and women, with a total of 40 employees. "
        )
    else:
        sectors_sentence = (
            f"His company has {2 + k} sectors with different t-shirt colors. "
            "Two sectors use white t-shirts and black t-shirts. "
            "He paid $20 for white men's t-shirts and $18 for black men's t-shirts. "
        )
        extra_price_sentences = " ".join(
            f"He paid ${price} for {color} men's t-shirts." for color, price in extras[:k]
        )
        sectors_sentence += extra_price_sentences + " "
        sectors_sentence += (
            f"Each sector has {men_per_sector} men and {women_per_sector} women, "
            f"so there are {total_employees} employees in total. "
        )

    q = head + sectors_sentence + "How much did he spend total on buying t-shirts?"
    return q, total_spent


def gen_aquarium_fish(k: int) -> Tuple[str, int]:
    """
    Extend the aquarium fish comparison chain by adding up to k additional people (0..6).

    Returns: (question_text, total_fish_all_people)
    """
    assert 0 <= k <= 6

    fish: Dict[str, int] = {}
    fish["Billy"] = 10
    fish["Tony"] = 3 * fish["Billy"]
    fish["Sarah"] = fish["Tony"] + 5
    fish["Bobby"] = 2 * fish["Sarah"]

    extras_defs = [
        ("Nina", lambda f: f["Billy"] - 2,
         "Nina has 2 fewer fish in her aquarium than Billy does."),
        ("Omar", lambda f: f["Bobby"] // 2,
         "Omar has half as many fish in his aquarium as Bobby has."),
        ("Priya", lambda f: 4 * f["Nina"],
         "Priya has 4 times as many fish in her aquarium as Nina has."),
        ("Chen", lambda f: f["Sarah"] + 10,
         "Chen has 10 more fish in his aquarium than Sarah has."),
        ("Luis", lambda f: f["Tony"] + f["Billy"],
         "Luis has as many fish in his aquarium as Tony and Billy have combined."),
        ("Ava", lambda f: f["Omar"] - 3,
         "Ava has 3 fewer fish in her aquarium than Omar has."),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, fn, clause = extras_defs[i]
        fish[name] = int(fn(fish))
        clauses.append(clause)
        assert fish[name] >= 0

    names = ["Bobby", "Sarah", "Tony", "Billy"] + [extras_defs[i][0] for i in range(k)]
    total = sum(fish[n] for n in names)

    extras_text = " ".join(clauses)
    tail = (
        "How many fish do all 4 people have put together?"
        if k == 0
        else "How many fish do all the people mentioned have put together?"
    )

    q = (
        "Bobby has an aquarium with twice as many fish as Sarah's has. "
        "Sarah has 5 more fish in her aquarium than Tony does. "
        "Tony has 3 times as many fish in his aquarium as Billy does. "
        "Billy has 10 fish. "
    )
    q += (extras_text + " " if extras_text else "")
    q += tail

    return q, total


def gen_stuffed_animals(k: int) -> Tuple[str, int]:
    """
    Extend the stuffed-animals comparison question by adding up to k additional people (0..6).

    Returns: (question_text, total_stuffed_animals)
    """
    assert 0 <= k <= 6

    sa: Dict[str, int] = {}
    sa["McKenna"] = 34
    sa["Kenley"] = 2 * sa["McKenna"]
    sa["Tenly"] = sa["Kenley"] + 5

    extras_defs = [
        ("Riley", lambda s: s["Tenly"] - 10,
         "Riley has 10 fewer stuffed animals than Tenly."),
        ("Avery", lambda s: s["Kenley"] // 2,
         "Avery has half as many stuffed animals as Kenley."),
        ("Jordan", lambda s: 3 * s["McKenna"],
         "Jordan has 3 times as many stuffed animals as McKenna."),
        ("Casey", lambda s: s["Avery"] + 7,
         "Casey has 7 more stuffed animals than Avery."),
        ("Morgan", lambda s: s["Riley"] + s["Casey"],
         "Morgan has as many stuffed animals as Riley and Casey have combined."),
        ("Peyton", lambda s: s["McKenna"] - 2,
         "Peyton has 2 fewer stuffed animals than McKenna."),
    ]

    clauses: List[str] = []
    for i in range(k):
        name, fn, clause = extras_defs[i]
        sa[name] = int(fn(sa))
        clauses.append(clause)
        assert sa[name] >= 0

    names = ["McKenna", "Kenley", "Tenly"] + [extras_defs[i][0] for i in range(k)]
    total = sum(sa[n] for n in names)

    extras_text = " ".join(clauses)
    tail = (
        "How many stuffed animals do the three girls have in all?"
        if k == 0
        else "How many stuffed animals do all the people mentioned have in all?"
    )

    q = (
        "McKenna has 34 stuffed animals. "
        "Kenley has twice as many as McKenna. "
        "Tenly has 5 more than Kenley. "
    )
    q += (extras_text + " " if extras_text else "")
    q += tail

    return q, total


def _join_names(names: List[str]) -> str:
    """Join names into a readable English list."""
    if len(names) == 0:
        return ""
    if len(names) == 1:
        return names[0]
    if len(names) == 2:
        return f"{names[0]} and {names[1]}"
    return ", ".join(names[:-1]) + f", and {names[-1]}"


def gen_jelly_bracelets(k: int) -> Tuple[str, int]:
    """
    Extend the jelly-bracelets question by adding up to k additional friends (0..6).
    Robin buys one bracelet for each letter in each friend's FIRST name.

    Returns: (question_text, total_cost_dollars)
    """
    assert 0 <= k <= 6

    cost_per = 2
    base_names = ["Jessica", "Tori", "Lily", "Patrice"]
    extra_names = ["Ava", "Monique", "Ben", "Charlotte", "Ian", "Zoe"]

    names = base_names + extra_names[:k]
    total_cost = cost_per * sum(len(n) for n in names)

    q = (
        "Robin wants to buy jelly bracelets for her friends. "
        "She decides to buy one bracelet for each letter of the first name of her friends. "
        f"Her friends are {_join_names(names)}. "
        f"If each jelly bracelet costs ${cost_per}, what is the total she will spend in dollars?"
    )
    return q, total_cost


def gen_leg_collection(k: int) -> Tuple[str, int]:
    """
    Extend Monroe's legs-counting question by adding up to k extra animal types (0..6).
    Each added clause states the number of that animal AND how many legs each one has.

    Returns: (question_text, total_legs)
    """
    assert 0 <= k <= 6

    base_animals: List[Tuple[str, int, int]] = [
        ("spiders", 8, 8),
        ("ants", 12, 6),
    ]

    extras: List[Tuple[str, int, int, str]] = [
        ("beetles", 5, 6, "He also has 5 beetles."),
        ("scorpions", 3, 8, "He also has 3 scorpions."),
        ("crickets", 10, 6, "He also has 10 crickets."),
        ("crabs", 2, 10, "He also has 2 crabs."),
        ("ladybugs", 7, 6, "He also has 7 ladybugs."),
        ("tarantulas", 4, 8, "He also has 4 tarantulas."),
    ]

    total_legs = sum(count * legs for _, count, legs in base_animals)
    total_legs += sum(count * legs for _, count, legs, _ in extras[:k])

    extras_text = " ".join(clause for _, _, _, clause in extras[:k])

    q = (
        "Monroe has a collection of ants and a collection of spiders. "
        "He has 8 spiders and 12 ants. "
    )
    q += (extras_text + " " if extras_text else "")
    q += "What is the total number of legs in his entire collection?"

    return q, total_legs


def gen_gumballs(k: int) -> Tuple[str, int]:
    """
    Extend the gumball machine question by adding up to k extra colors (0..6).

    Returns: (question_text, total_gumballs)
    """
    assert 0 <= k <= 6

    red = 16
    blue = red // 2
    green = 4 * blue

    counts: Dict[str, int] = {"red": red, "blue": blue, "green": green}

    # (color, compute_fn, clause)
    extras_defs = [
        ("yellow", lambda c: c["blue"],
         "The machine has as many yellow gumballs as blue gumballs."),
        ("purple", lambda c: c["red"] - 3,
         "The machine has 3 fewer purple gumballs than red gumballs."),
        ("orange", lambda c: 2 * c["blue"],
         "The machine has twice as many orange gumballs as blue gumballs."),
        ("pink", lambda c: c["green"] // 4,
         "The machine has one-fourth as many pink gumballs as green gumballs."),
        ("white", lambda c: c["yellow"] + c["blue"],
         "The machine has as many white gumballs as the total number of yellow and blue gumballs combined."),
        ("brown", lambda c: c["purple"] + 5,
         "The machine has 5 more brown gumballs than purple gumballs."),
    ]

    clauses: List[str] = []
    for i in range(k):
        color, fn, clause = extras_defs[i]
        counts[color] = int(fn(counts))
        clauses.append(clause)
        assert counts[color] >= 0

    total = sum(counts.values())

    extras_text = " ".join(clauses)
    q = (
        "A gumball machine has red, green, and blue gumballs. "
        "The machine has half as many blue gumballs as red gumballs. "
        "For each blue gumball, the machine has 4 times as many green gumballs. "
    )
    q += (extras_text + " " if extras_text else "")
    q += "If the machine has 16 red gumballs how many gumballs are in the machine?"

    return q, total



In [9]:
questions = []
ground_truths = []

def gen_question(qnum: int, k: int, base_value: int = 20):
    if qnum == 1:
        return gen_sheep(k, base_value)
    elif qnum == 2:
        return gen_dvds(k)
    elif qnum == 3:
        return gen_clothes(k)
    elif qnum == 4:
        return gen_rainwater(k)
    elif qnum == 5:
        return gen_carlotta(k)
    elif qnum == 6:
        return gen_pills(k)
    elif qnum == 7:
        return gen_snackpacks(k)
    elif qnum == 8:
        return gen_shells(k)
    elif qnum == 9:
        return gen_lunch(k)
    elif qnum == 10:
        return gen_pinata(k)
    elif qnum == 11:
        return gen_paint(k)
    elif qnum == 12:
        return gen_club(k)
    elif qnum == 13:
        return gen_wine(k)
    elif qnum == 14:
        return gen_carpets(k)
    elif qnum == 15:
        return gen_safari(k)
    elif qnum == 16:
        return gen_fish(k)
    elif qnum == 17:
        return gen_basketball(k)
    elif qnum == 18:
        return gen_jewelry(k)
    elif qnum == 19:
        return gen_marbles(k)
    elif qnum == 20:
        return gen_cycling(k)
    elif qnum == 21:
        return gen_omelets(k)
    elif qnum == 22:
        return gen_gardens(k)
    elif qnum == 23:
        return gen_punch(k)
    elif qnum == 24:
        return gen_bread_and_treats(k)
    elif qnum == 25:
        return gen_rehab_centers(k)
    elif qnum == 26:
        return gen_dog_time(k)
    elif qnum == 27:
        return gen_crayons(k)
    elif qnum == 28:
        return gen_zoo_tarantulas(k)
    elif qnum == 29:
        return gen_stationery_lengths(k)
    elif qnum == 30:
        return gen_consecutive_ages(k)
    elif qnum == 31:
        return gen_spelling_contest(k)
    elif qnum == 32:
        return gen_bike_year_phases(k)
    elif qnum == 33:
        return gen_two_meals_fruit(k)
    elif qnum == 34:
        return gen_run_two_arenas(k)
    elif qnum == 35:
        return gen_fruit_grower_money(k)
    elif qnum == 36:
        return gen_opal_savings_rounds(k)
    elif qnum == 37:
        return gen_country_club_expenditures(k)
    elif qnum == 38:
        return gen_national_parks(k)
    elif qnum == 39:
        return gen_zoo_animals_spend(k)
    elif qnum == 40:
        return gen_circus_entertainers(k)
    elif qnum == 41:
        return gen_sailboat_sails(k)
    elif qnum == 42:
        return gen_market_remaining(k)
    elif qnum == 43:
        return gen_cafeteria(k)
    elif qnum == 44:
        return gen_lasagna(k)
    elif qnum == 45:
        return gen_tshirts_sectors(k)
    elif qnum == 46:
        return gen_aquarium_fish(k)
    elif qnum == 47:
        return gen_stuffed_animals(k)
    elif qnum == 48:
        return gen_jelly_bracelets(k)
    elif qnum == 49:
        return gen_leg_collection(k)
    elif qnum == 50:
        return gen_gumballs(k)
    else:        raise ValueError("qnum must be 1..50")

# ============================
# BUILD DATASET: 50 QUESTIONS PER k
# ============================

NUM_Q = 50
K_VALUES = list(range(1, 7))  # k = 1..6

dataset_by_k = {k: [] for k in K_VALUES}

for qnum in range(1, NUM_Q + 1):
    for k in K_VALUES:
        question, correct = gen_question(qnum, k)
        dataset_by_k[k].append({
            "qnum": qnum,               # 1..50
            "k": k,                     # 1..6
            "question": question,
            "ground_truth": correct,
        })

assert all(len(dataset_by_k[k]) == NUM_Q for k in K_VALUES), "Each k must have exactly 50 questions"


In [10]:
# ============================
# MULTI-MODEL EVAL + CSV EXPORT
# ============================

import gc
import time
import pandas as pd
import re
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make sure your cache dir exists
os.makedirs(model_path, exist_ok=True)
os.makedirs(wd, exist_ok=True)

# 1) Map your "friendly names" to HF repo IDs (+ optional settings)
MODEL_REGISTRY = {
    "pythia-7b":        {"hf_repo": "EleutherAI/pythia-6.9b-deduped", "revision": "step3000", "trust_remote_code": False},
    "gemma2-2b":        {"hf_repo": "google/gemma-2-2b",              "trust_remote_code": False},
    "mistral-7b":       {"hf_repo": "mistralai/Mistral-7B-Instruct-v0.3", "trust_remote_code": False},
    "gptJ-6b":          {"hf_repo": "EleutherAI/gpt-j-6B",            "trust_remote_code": False},
    "gpt2-medium":      {"hf_repo": "gpt2-medium",                   "trust_remote_code": False},
    "gemma2-9b":        {"hf_repo": "google/gemma-2-9b",              "trust_remote_code": False},
    "llama2-7b":        {"hf_repo": "meta-llama/Llama-2-7b-chat-hf",  "trust_remote_code": False},
    "llama3-8b":        {"hf_repo": "meta-llama/Meta-Llama-3-8B",     "trust_remote_code": False},
    "llama3-8B-Instruct":{"hf_repo": "meta-llama/Meta-Llama-3-8B-Instruct", "trust_remote_code": False},
    "gemma2-9b-it":     {"hf_repo": "google/gemma-2-9b-it",           "trust_remote_code": False},
    "qwen-1.5B":        {"hf_repo": "Qwen/Qwen2.5-1.5B-Instruct",     "trust_remote_code": True},
    "qwen-3B":          {"hf_repo": "Qwen/Qwen2.5-3B-Instruct",       "trust_remote_code": True},
    "qwen-7B":          {"hf_repo": "Qwen/Qwen2.5-7B-Instruct",       "trust_remote_code": True},
    "llama3.1-8b":      {"hf_repo": "meta-llama/Meta-Llama-3.1-8B-Instruct", "trust_remote_code": False},
    "qwen2.5-7b":       {"hf_repo": "Qwen/Qwen2.5-Math-7B-Instruct",  "trust_remote_code": True},
    "mathstral-7b":     {"hf_repo": "mistralai/Mathstral-7B-v0.1",    "trust_remote_code": False},
    "deepseek-7b":      {"hf_repo": "deepseek-ai/deepseek-math-7b-instruct", "trust_remote_code": True},
}

def _maybe_token_id(tokenizer, token_str: str):
    try:
        tok_id = tokenizer.convert_tokens_to_ids(token_str)
    except Exception:
        return None
    if tok_id is None:
        return None
    try:
        if tokenizer.convert_ids_to_tokens(tok_id) == token_str:
            return tok_id
    except Exception:
        return None   # <-- FIXED (was Nonexx)
    return None

def get_eos_token_ids(tokenizer):
    eos_ids = []
    if tokenizer.eos_token_id is not None:
        eos_ids.append(tokenizer.eos_token_id)

    # Common end-of-turn / end tokens across chat models
    for tok in ["<|eot_id|>", "<|end_of_text|>", "<|im_end|>", "</s>"]:
        tid = _maybe_token_id(tokenizer, tok)
        if tid is not None and tid not in eos_ids:
            eos_ids.append(tid)

    # Fallback if somehow empty
    if len(eos_ids) == 0 and tokenizer.pad_token_id is not None:
        eos_ids = [tokenizer.pad_token_id]
    return eos_ids

# 3) Answer extraction helpers
def extract_number_after_hashes(text: str):
    # grabs the LAST occurrence after ####
    m = re.findall(r"(?:####)\s*([^\n\r]+)", text)
    return m[-1].strip() if m else None

# Matches either:
#  - comma-separated thousands: 15,362 or 1,234,567.89
#  - or normal numbers: 15362 or 15362.89
NUM_RE = re.compile(r"[-+]?\d{1,3}(?:,\d{3})+(?:\.\d+)?|[-+]?\d+(?:\.\d+)?")

def extract_last_number(text: str):
    matches = NUM_RE.findall(text)
    return matches[-1] if matches else None

def to_float(x):
    if x is None:
        return None
    x = str(x).strip().replace(",", "")
    m = re.search(r"[-+]?\d+(?:\.\d+)?", x)  # allow optional +
    if not m:
        return None
    try:
        return float(m.group(0))
    except:
        return None

def is_correct(pred_val, gt_val, tol=1e-6):
    if pred_val is None:
        return False
    try:
        return abs(float(pred_val) - float(gt_val)) <= tol
    except Exception:
        return False

# 4) Prompt builder: tries tokenizer chat template; falls back to plain prompt.
#    Also handles Gemma-it models that don't support "system" role by retrying without it.
def build_prompt_for_model(q: str, tokenizer, model_alias: str):
    instruction = "Solve the problem. Give the final answer as: #### <number>"

    # Your special handling for Qwen-family (keeps your “show reasoning” style)
    if model_alias.lower().startswith("qwen"):
        user_text = (
            "Solve the problem step by step. Show your reasoning.\n"
            "At the end, give the final answer on its own line in the format:\n"
            "#### <number>\n\n"
            f"{q}"
        )
        messages = [{"role": "user", "content": user_text}]
    else:
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": q},
        ]

    # If chat template exists, prefer it
    if getattr(tokenizer, "chat_template", None):
        try:
            prompt_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return prompt_text, False  # add_special_tokens=False recommended for chat templates
        except Exception:
            # Retry WITHOUT system role (Gemma-it often throws "System role not supported")
            try:
                fallback_messages = [{"role": "user", "content": instruction + "\n\n" + q}]
                prompt_text = tokenizer.apply_chat_template(
                    fallback_messages, tokenize=False, add_generation_prompt=True
                )
                return prompt_text, False
            except Exception:
                pass

    # Plain fallback (non-chat models like GPT-2 / GPT-J / base Gemma, etc.)
    plain = f"{instruction}\n\nProblem:\n{q}\n\nAnswer:\n#### "
    return plain, True  # add_special_tokens=True is fine for plain prompts

# 5) Run a single model
def evaluate_one_model(model_id: int, model_alias: str, batch_size: int = 1):
    cfg = MODEL_REGISTRY.get(model_alias, None)

    rows = []
    summary = {
        "model_id": model_id,
        "model_alias": model_alias,
        "hf_repo": None,
        "status": "ok",
        "num_questions": len(questions),
        "num_correct": 0,
        "accuracy": 0.0,
        "seconds": None,
        "error": "",
    }

    if cfg is None:
        summary["status"] = "missing_registry_entry"
        summary["error"] = f"No HF repo mapping for model_alias='{model_alias}'"
        return rows, summary

    hf_repo = cfg["hf_repo"]
    revision = cfg.get("revision", None)
    trust_remote_code = cfg.get("trust_remote_code", False)
    summary["hf_repo"] = hf_repo

    start_t = time.time()

    try:
        model = AutoModelForCausalLM.from_pretrained(
            hf_repo,
            revision=revision,
            cache_dir=model_path,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=trust_remote_code,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            hf_repo,
            revision=revision,
            cache_dir=model_path,
            padding_side="left",
            trust_remote_code=trust_remote_code,
        )

        # pad token handling
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        model.generation_config.pad_token_id = tokenizer.pad_token_id

        model.eval()

        eos_ids = get_eos_token_ids(tokenizer)
        eos_for_generate = eos_ids if len(eos_ids) > 1 else eos_ids[0]

        max_new_tokens = 1024 if model_alias.lower().startswith("qwen") else 512

        # Build all prompts (for this model) *after* tokenizer is loaded
        prompts = []
        add_special_tokens_flag = None
        for q in questions:
            ptxt, add_spec = build_prompt_for_model(q, tokenizer, model_alias)
            prompts.append(ptxt)
            if add_special_tokens_flag is None:
                add_special_tokens_flag = add_spec

        # Generate in small batches to reduce OOM risk
        num_correct = 0
        for s in range(0, len(prompts), batch_size):
            batch_prompts = prompts[s:s+batch_size]
            batch_questions = questions[s:s+batch_size]
            batch_gts = ground_truths[s:s+batch_size]

            enc = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                add_special_tokens=add_special_tokens_flag,
            ).to(model.device)

            with torch.no_grad():
                out = model.generate(
                    **enc,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    eos_token_id=eos_for_generate,
                )

            # Decode only the generated continuation (avoids picking numbers from the prompt)
            gen_tokens = out[:, enc["input_ids"].shape[1]:]
            responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

            for i, (q, gt, resp) in enumerate(zip(batch_questions, batch_gts, responses)):
                # Prefer #### ... if present, else last number
                pred_str = extract_number_after_hashes(resp)
                if pred_str is None:
                    pred_str = extract_last_number(resp)

                pred_val = to_float(pred_str)
                correct = is_correct(pred_val, gt)

                if correct:
                    num_correct += 1

                rows.append({
                    "model_id": model_id,
                    "model_alias": model_alias,
                    "hf_repo": hf_repo,
                    "question_idx": s + i,
                    "question": q,
                    "ground_truth": gt,
                    "response_text": resp,
                    "predicted_str": pred_str,
                    "predicted_val": pred_val,
                    "is_correct": correct,
                    "error": "",
                })

        summary["num_correct"] = num_correct
        summary["accuracy"] = num_correct / len(questions)

    except RuntimeError as e:
        # Catch CUDA OOM etc.
        summary["status"] = "runtime_error"
        summary["error"] = repr(e)
        # still write placeholder rows so CSV has consistent shape
        for qi, (q, gt) in enumerate(zip(questions, ground_truths)):
            rows.append({
                "model_id": model_id,
                "model_alias": model_alias,
                "hf_repo": hf_repo,
                "question_idx": qi,
                "question": q,
                "ground_truth": gt,
                "response_text": "",
                "predicted_str": None,
                "predicted_val": None,
                "is_correct": False,
                "error": f"runtime_error: {repr(e)}",
            })

    except Exception as e:
        summary["status"] = "error"
        summary["error"] = repr(e)
        for qi, (q, gt) in enumerate(zip(questions, ground_truths)):
            rows.append({
                "model_id": model_id,
                "model_alias": model_alias,
                "hf_repo": hf_repo,
                "question_idx": qi,
                "question": q,
                "ground_truth": gt,
                "response_text": "",
                "predicted_str": None,
                "predicted_val": None,
                "is_correct": False,
                "error": f"error: {repr(e)}",
            })

    finally:
        # Cleanup to avoid GPU memory accumulation
        try:
            del model
            del tokenizer
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        summary["seconds"] = time.time() - start_t

    return rows, summary



In [None]:
import os
import re
import gc
import time
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make sure your cache dir exists
os.makedirs(model_path, exist_ok=True)
os.makedirs(wd, exist_ok=True)

def safe_filename(name: str) -> str:
    # safe for filenames across platforms
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", name).strip("_")

def load_model_and_tokenizer(model_alias: str):
    cfg = MODEL_REGISTRY.get(model_alias, None)
    if cfg is None:
        raise KeyError(f"No HF repo mapping for model_alias='{model_alias}'")

    hf_repo = cfg["hf_repo"]
    revision = cfg.get("revision", None)
    trust_remote_code = cfg.get("trust_remote_code", False)

    model = AutoModelForCausalLM.from_pretrained(
        hf_repo,
        revision=revision,
        cache_dir=model_path,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        hf_repo,
        revision=revision,
        cache_dir=model_path,
        padding_side="left",
        trust_remote_code=trust_remote_code,
    )

    # pad token handling
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.eval()

    return model, tokenizer, hf_repo

def evaluate_loaded_model_on_k(
    model,
    tokenizer,
    model_id: int,
    model_alias: str,
    hf_repo: str,
    k: int,
    dataset_k: list,
    batch_size: int = 1,
):
    rows = []
    summary = {
        "model_id": model_id,
        "model_alias": model_alias,
        "hf_repo": hf_repo,
        "k": k,
        "status": "ok",
        "num_questions": len(dataset_k),
        "num_correct": 0,
        "accuracy": 0.0,
        "seconds": None,
        "error": "",
        "csv_path": "",
    }

    start_t = time.time()

    try:
        eos_ids = get_eos_token_ids(tokenizer)
        eos_for_generate = eos_ids if len(eos_ids) > 1 else eos_ids[0]
        max_new_tokens = 2048 if model_alias.lower().startswith("qwen") else 1024

        # Build prompts for this k
        prompts = []
        add_special_tokens_flag = None
        for item in dataset_k:
            ptxt, add_spec = build_prompt_for_model(item["question"], tokenizer, model_alias)
            prompts.append(ptxt)
            if add_special_tokens_flag is None:
                add_special_tokens_flag = add_spec
            elif add_special_tokens_flag != add_spec:
                # Rare edge-case: mix of chat-template vs plain fallback.
                # Use the "safer" option if any prompt needs specials.
                add_special_tokens_flag = add_special_tokens_flag or add_spec

        num_correct = 0

        for s in range(0, len(prompts), batch_size):
            batch_prompts = prompts[s:s + batch_size]
            batch_items = dataset_k[s:s + batch_size]

            enc = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                add_special_tokens=add_special_tokens_flag,
            ).to(model.device)

            with torch.no_grad():
                out = model.generate(
                    **enc,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    eos_token_id=eos_for_generate,
                )

            # Decode only the generated continuation
            gen_tokens = out[:, enc["input_ids"].shape[1]:]
            responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

            for i, (item, resp) in enumerate(zip(batch_items, responses)):
                q = item["question"]
                gt = item["ground_truth"]
                qnum = item["qnum"]
                idx_in_k = s + i  # 0..49

                pred_str = extract_number_after_hashes(resp)
                if pred_str is None:
                    pred_str = extract_last_number(resp)

                pred_val = to_float(pred_str)
                correct = is_correct(pred_val, gt)

                if correct:
                    num_correct += 1

                rows.append({
                    "model_id": model_id,
                    "model_alias": model_alias,
                    "hf_repo": hf_repo,
                    "k": k,
                    "qnum": qnum,                 # 1..50
                    "question_idx_in_k": idx_in_k, # 0..49
                    "question": q,
                    "ground_truth": gt,
                    "response_text": resp,         # <-- RAW MODEL OUTPUT (continuation)
                    "predicted_str": pred_str,
                    "predicted_val": pred_val,
                    "is_correct": correct,
                    "error": "",
                })

        summary["num_correct"] = num_correct
        summary["accuracy"] = num_correct / max(1, len(dataset_k))

    except RuntimeError as e:
        summary["status"] = "runtime_error"
        summary["error"] = repr(e)

        for idx_in_k, item in enumerate(dataset_k):
            rows.append({
                "model_id": model_id,
                "model_alias": model_alias,
                "hf_repo": hf_repo,
                "k": k,
                "qnum": item["qnum"],
                "question_idx_in_k": idx_in_k,
                "question": item["question"],
                "ground_truth": item["ground_truth"],
                "response_text": "",
                "predicted_str": None,
                "predicted_val": None,
                "is_correct": False,
                "error": f"runtime_error: {repr(e)}",
            })

    except Exception as e:
        summary["status"] = "error"
        summary["error"] = repr(e)

        for idx_in_k, item in enumerate(dataset_k):
            rows.append({
                "model_id": model_id,
                "model_alias": model_alias,
                "hf_repo": hf_repo,
                "k": k,
                "qnum": item["qnum"],
                "question_idx_in_k": idx_in_k,
                "question": item["question"],
                "ground_truth": item["ground_truth"],
                "response_text": "",
                "predicted_str": None,
                "predicted_val": None,
                "is_correct": False,
                "error": f"error: {repr(e)}",
            })

    finally:
        summary["seconds"] = time.time() - start_t

        # Clear per-k generation cache (model stays loaded, but clears transient CUDA stuff)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    return rows, summary


# =======================================
# MULTI-MODEL EVAL: per k, per model CSVs
# =======================================

all_summary_rows = []

for model_id in sorted(model_dict.keys()):
    model_alias = model_dict[model_id]
    print(f"\n===== Loading model {model_id}: {model_alias} =====")

    cfg = MODEL_REGISTRY.get(model_alias, None)
    if cfg is None:
        # Missing registry entry: write placeholder summaries + files
        err = f"No HF repo mapping for model_alias='{model_alias}'"
        for k in K_VALUES:
            out_dir = os.path.join(wd, f"k_{k}")
            os.makedirs(out_dir, exist_ok=True)

            out_path = os.path.join(out_dir, f"{model_id:02d}_{safe_filename(model_alias)}.csv")
            placeholder_rows = []
            for idx_in_k, item in enumerate(dataset_by_k[k]):
                placeholder_rows.append({
                    "model_id": model_id,
                    "model_alias": model_alias,
                    "hf_repo": "",
                    "k": k,
                    "qnum": item["qnum"],
                    "question_idx_in_k": idx_in_k,
                    "question": item["question"],
                    "ground_truth": item["ground_truth"],
                    "response_text": "",
                    "predicted_str": None,
                    "predicted_val": None,
                    "is_correct": False,
                    "error": err,
                })
            pd.DataFrame(placeholder_rows).to_csv(out_path, index=False)

            all_summary_rows.append({
                "model_id": model_id,
                "model_alias": model_alias,
                "hf_repo": "",
                "k": k,
                "status": "missing_registry_entry",
                "num_questions": len(dataset_by_k[k]),
                "num_correct": 0,
                "accuracy": 0.0,
                "seconds": 0.0,
                "error": err,
                "csv_path": out_path,
            })
        continue

    model = tokenizer = None
    hf_repo = None
    try:
        model, tokenizer, hf_repo = load_model_and_tokenizer(model_alias)

        for k in K_VALUES:
            print(f"  -> Evaluating k={k} on 50 questions")
            out_dir = os.path.join(wd, f"k_{k}")
            os.makedirs(out_dir, exist_ok=True)

            rows, summ = evaluate_loaded_model_on_k(
                model=model,
                tokenizer=tokenizer,
                model_id=model_id,
                model_alias=model_alias,
                hf_repo=hf_repo,
                k=k,
                dataset_k=dataset_by_k[k],
                batch_size=1,
            )

            out_path = os.path.join(out_dir, f"{model_id:02d}_{safe_filename(model_alias)}.csv")
            pd.DataFrame(rows).to_csv(out_path, index=False)

            summ["csv_path"] = out_path
            all_summary_rows.append(summ)

            print(f"     status={summ['status']}, acc={summ['accuracy']:.3f}, correct={summ['num_correct']}/{summ['num_questions']}, seconds={summ['seconds']:.1f}")

    except Exception as e:
        # If model load fails: create placeholder files for all k
        err = f"load_error: {repr(e)}"
        for k in K_VALUES:
            out_dir = os.path.join(wd, f"k_{k}")
            os.makedirs(out_dir, exist_ok=True)

            out_path = os.path.join(out_dir, f"{model_id:02d}_{safe_filename(model_alias)}.csv")
            placeholder_rows = []
            for idx_in_k, item in enumerate(dataset_by_k[k]):
                placeholder_rows.append({
                    "model_id": model_id,
                    "model_alias": model_alias,
                    "hf_repo": hf_repo or "",
                    "k": k,
                    "qnum": item["qnum"],
                    "question_idx_in_k": idx_in_k,
                    "question": item["question"],
                    "ground_truth": item["ground_truth"],
                    "response_text": "",
                    "predicted_str": None,
                    "predicted_val": None,
                    "is_correct": False,
                    "error": err,
                })
            pd.DataFrame(placeholder_rows).to_csv(out_path, index=False)

            all_summary_rows.append({
                "model_id": model_id,
                "model_alias": model_alias,
                "hf_repo": hf_repo or "",
                "k": k,
                "status": "load_error",
                "num_questions": len(dataset_by_k[k]),
                "num_correct": 0,
                "accuracy": 0.0,
                "seconds": 0.0,
                "error": err,
                "csv_path": out_path,
            })

    finally:
        # Cleanup model after finishing all k
        try:
            del model
            del tokenizer
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()


# ============================
# SAVE GLOBAL SUMMARY
# ============================

summary_df = pd.DataFrame(all_summary_rows)
summary_csv_path = os.path.join(wd, "all_model_summary_by_k.csv")
summary_df.to_csv(summary_csv_path, index=False)

# Optional: a pivoted "accuracy table" (models x k)
acc_pivot = summary_df.pivot_table(
    index=["model_id", "model_alias"],
    columns="k",
    values="accuracy",
    aggfunc="first",
)
acc_pivot_csv_path = os.path.join(wd, "accuracy_pivot_by_k.csv")
acc_pivot.to_csv(acc_pivot_csv_path)

print("\nSaved:")
print("Summary (long):", summary_csv_path)
print("Accuracy pivot :", acc_pivot_csv_path)

# Optional: also save per-k model summary inside each k directory
for k in K_VALUES:
    k_dir = os.path.join(wd, f"k_{k}")
    k_summary_path = os.path.join(k_dir, "summary.csv")
    summary_df[summary_df["k"] == k].to_csv(k_summary_path, index=False)



===== Loading model 1: mistral-7b =====


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  -> Evaluating k=1 on 50 questions
     status=ok, acc=0.440, correct=22/50, seconds=542.9
  -> Evaluating k=2 on 50 questions
     status=ok, acc=0.300, correct=15/50, seconds=606.7
  -> Evaluating k=3 on 50 questions
     status=ok, acc=0.100, correct=5/50, seconds=734.8
  -> Evaluating k=4 on 50 questions
     status=ok, acc=0.040, correct=2/50, seconds=725.6
  -> Evaluating k=5 on 50 questions
     status=ok, acc=0.020, correct=1/50, seconds=804.8
  -> Evaluating k=6 on 50 questions
     status=ok, acc=0.060, correct=3/50, seconds=846.9

===== Loading model 2: llama2-7b =====


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


  -> Evaluating k=1 on 50 questions
     status=ok, acc=0.140, correct=7/50, seconds=491.6
  -> Evaluating k=2 on 50 questions
     status=ok, acc=0.080, correct=4/50, seconds=534.9
  -> Evaluating k=3 on 50 questions
     status=ok, acc=0.020, correct=1/50, seconds=619.8
  -> Evaluating k=4 on 50 questions
     status=ok, acc=0.000, correct=0/50, seconds=695.5
  -> Evaluating k=5 on 50 questions
     status=ok, acc=0.000, correct=0/50, seconds=764.5
  -> Evaluating k=6 on 50 questions
     status=ok, acc=0.000, correct=0/50, seconds=778.1

===== Loading model 3: llama3-8b =====


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  -> Evaluating k=1 on 50 questions
     status=ok, acc=0.120, correct=6/50, seconds=948.5
  -> Evaluating k=2 on 50 questions
     status=ok, acc=0.120, correct=6/50, seconds=1049.5
  -> Evaluating k=3 on 50 questions
     status=ok, acc=0.100, correct=5/50, seconds=1259.3
  -> Evaluating k=4 on 50 questions
     status=ok, acc=0.120, correct=6/50, seconds=1231.5
  -> Evaluating k=5 on 50 questions
     status=ok, acc=0.060, correct=3/50, seconds=1142.8
  -> Evaluating k=6 on 50 questions
     status=ok, acc=0.000, correct=0/50, seconds=1124.3

===== Loading model 4: llama3-8B-Instruct =====


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  -> Evaluating k=1 on 50 questions
     status=ok, acc=0.760, correct=38/50, seconds=250.9
  -> Evaluating k=2 on 50 questions
     status=ok, acc=0.560, correct=28/50, seconds=315.5
  -> Evaluating k=3 on 50 questions
     status=ok, acc=0.520, correct=26/50, seconds=359.7
  -> Evaluating k=4 on 50 questions
     status=ok, acc=0.400, correct=20/50, seconds=399.7
  -> Evaluating k=5 on 50 questions
     status=ok, acc=0.320, correct=16/50, seconds=448.1
  -> Evaluating k=6 on 50 questions
     status=ok, acc=0.180, correct=9/50, seconds=504.3

===== Loading model 5: gemma2-9b-it =====


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  -> Evaluating k=1 on 50 questions


In [9]:
# # ============================
# # MULTI-MODEL EVAL + CSV EXPORT
# # ============================

# import gc
# import time
# import pandas as pd
# import re
# from transformers import AutoModelForCausalLM, AutoTokenizer

# # Make sure your cache dir exists
# os.makedirs(model_path, exist_ok=True)
# os.makedirs(wd, exist_ok=True)

# # 1) Map your "friendly names" to HF repo IDs (+ optional settings)
# MODEL_REGISTRY = {
#     "pythia-7b":        {"hf_repo": "EleutherAI/pythia-6.9b-deduped", "revision": "step3000", "trust_remote_code": False},
#     "gemma2-2b":        {"hf_repo": "google/gemma-2-2b",              "trust_remote_code": False},
#     "mistral-7b":       {"hf_repo": "mistralai/Mistral-7B-Instruct-v0.3", "trust_remote_code": False},
#     "gptJ-6b":          {"hf_repo": "EleutherAI/gpt-j-6B",            "trust_remote_code": False},
#     "gpt2-medium":      {"hf_repo": "gpt2-medium",                   "trust_remote_code": False},
#     "gemma2-9b":        {"hf_repo": "google/gemma-2-9b",              "trust_remote_code": False},
#     "llama2-7b":        {"hf_repo": "meta-llama/Llama-2-7b-chat-hf",  "trust_remote_code": False},
#     "llama3-8b":        {"hf_repo": "meta-llama/Meta-Llama-3-8B",     "trust_remote_code": False},
#     "llama3-8B-Instruct":{"hf_repo": "meta-llama/Meta-Llama-3-8B-Instruct", "trust_remote_code": False},
#     "gemma2-9b-it":     {"hf_repo": "google/gemma-2-9b-it",           "trust_remote_code": False},
#     "qwen-1.5B":        {"hf_repo": "Qwen/Qwen2.5-1.5B-Instruct",     "trust_remote_code": True},
#     "qwen-3B":          {"hf_repo": "Qwen/Qwen2.5-3B-Instruct",       "trust_remote_code": True},
#     "qwen-7B":          {"hf_repo": "Qwen/Qwen2.5-7B-Instruct",       "trust_remote_code": True},
#     "llama3.1-8b":      {"hf_repo": "meta-llama/Meta-Llama-3.1-8B-Instruct", "trust_remote_code": False},
#     "qwen2.5-7b":       {"hf_repo": "Qwen/Qwen2.5-Math-7B-Instruct",  "trust_remote_code": True},
#     "mathstral-7b":     {"hf_repo": "mistralai/Mathstral-7B-v0.1",    "trust_remote_code": False},
#     "deepseek-7b":      {"hf_repo": "deepseek-ai/deepseek-math-7b-instruct", "trust_remote_code": True},
# }

# # 2) Robust EOS collection (only include tokens that really exist for that tokenizer)
# def _maybe_token_id(tokenizer, token_str: str):
#     try:
#         tok_id = tokenizer.convert_tokens_to_ids(token_str)
#     except Exception:
#         return None
#     if tok_id is None:
#         return None
#     # Verify round-trip: id -> token should match exactly (avoids "UNK id" false positives)
#     try:
#         if tokenizer.convert_ids_to_tokens(tok_id) == token_str:
#             return tok_id
#     except Exception:
#         return Nonexx
#     return None

# def get_eos_token_ids(tokenizer):
#     eos_ids = []
#     if tokenizer.eos_token_id is not None:
#         eos_ids.append(tokenizer.eos_token_id)

#     # Common end-of-turn / end tokens across chat models
#     for tok in ["<|eot_id|>", "<|end_of_text|>", "<|im_end|>", "</s>"]:
#         tid = _maybe_token_id(tokenizer, tok)
#         if tid is not None and tid not in eos_ids:
#             eos_ids.append(tid)

#     # Fallback if somehow empty
#     if len(eos_ids) == 0 and tokenizer.pad_token_id is not None:
#         eos_ids = [tokenizer.pad_token_id]
#     return eos_ids

# # 3) Answer extraction helpers
# def extract_number_after_hashes(text: str):
#     # grabs the LAST occurrence after ####
#     m = re.findall(r"(?:####)\s*([^\n\r]+)", text)
#     return m[-1].strip() if m else None

# def extract_last_number(text: str):
#     nums = re.findall(r"-?\d+(?:\.\d+)?", text)
#     return nums[-1] if nums else None

# def to_float(x):
#     if x is None:
#         return None
#     x = str(x).strip().replace(",", "")
#     m = re.search(r"-?\d+(?:\.\d+)?", x)
#     if not m:
#         return None
#     try:
#         return float(m.group(0))
#     except Exception:
#         return None

# def is_correct(pred_val, gt_val, tol=1e-6):
#     if pred_val is None:
#         return False
#     try:
#         return abs(float(pred_val) - float(gt_val)) <= tol
#     except Exception:
#         return False

# # 4) Prompt builder: tries tokenizer chat template; falls back to plain prompt.
# #    Also handles Gemma-it models that don't support "system" role by retrying without it.
# def build_prompt_for_model(q: str, tokenizer, model_alias: str):
#     instruction = "Solve the problem. Give the final answer as: #### <number>"

#     # Your special handling for Qwen-family (keeps your “show reasoning” style)
#     if model_alias.lower().startswith("qwen"):
#         user_text = (
#             "Solve the problem step by step. Show your reasoning.\n"
#             "At the end, give the final answer on its own line in the format:\n"
#             "#### <number>\n\n"
#             f"{q}"
#         )
#         messages = [{"role": "user", "content": user_text}]
#     else:
#         messages = [
#             {"role": "system", "content": instruction},
#             {"role": "user", "content": q},
#         ]

#     # If chat template exists, prefer it
#     if getattr(tokenizer, "chat_template", None):
#         try:
#             prompt_text = tokenizer.apply_chat_template(
#                 messages, tokenize=False, add_generation_prompt=True
#             )
#             return prompt_text, False  # add_special_tokens=False recommended for chat templates
#         except Exception:
#             # Retry WITHOUT system role (Gemma-it often throws "System role not supported")
#             try:
#                 fallback_messages = [{"role": "user", "content": instruction + "\n\n" + q}]
#                 prompt_text = tokenizer.apply_chat_template(
#                     fallback_messages, tokenize=False, add_generation_prompt=True
#                 )
#                 return prompt_text, False
#             except Exception:
#                 pass

#     # Plain fallback (non-chat models like GPT-2 / GPT-J / base Gemma, etc.)
#     plain = f"{instruction}\n\nProblem:\n{q}\n\nAnswer:\n#### "
#     return plain, True  # add_special_tokens=True is fine for plain prompts

# # 5) Run a single model
# def evaluate_one_model(model_id: int, model_alias: str, batch_size: int = 1):
#     cfg = MODEL_REGISTRY.get(model_alias, None)

#     rows = []
#     summary = {
#         "model_id": model_id,
#         "model_alias": model_alias,
#         "hf_repo": None,
#         "status": "ok",
#         "num_questions": len(questions),
#         "num_correct": 0,
#         "accuracy": 0.0,
#         "seconds": None,
#         "error": "",
#     }

#     if cfg is None:
#         summary["status"] = "missing_registry_entry"
#         summary["error"] = f"No HF repo mapping for model_alias='{model_alias}'"
#         return rows, summary

#     hf_repo = cfg["hf_repo"]
#     revision = cfg.get("revision", None)
#     trust_remote_code = cfg.get("trust_remote_code", False)
#     summary["hf_repo"] = hf_repo

#     start_t = time.time()

#     try:
#         model = AutoModelForCausalLM.from_pretrained(
#             hf_repo,
#             revision=revision,
#             cache_dir=model_path,
#             device_map="auto",
#             torch_dtype="auto",
#             trust_remote_code=trust_remote_code,
#         )
#         tokenizer = AutoTokenizer.from_pretrained(
#             hf_repo,
#             revision=revision,
#             cache_dir=model_path,
#             padding_side="left",
#             trust_remote_code=trust_remote_code,
#         )

#         # pad token handling
#         if tokenizer.pad_token_id is None:
#             tokenizer.pad_token = tokenizer.eos_token
#         model.generation_config.pad_token_id = tokenizer.pad_token_id

#         model.eval()

#         eos_ids = get_eos_token_ids(tokenizer)
#         eos_for_generate = eos_ids if len(eos_ids) > 1 else eos_ids[0]

#         max_new_tokens = 1024 if model_alias.lower().startswith("qwen") else 512

#         # Build all prompts (for this model) *after* tokenizer is loaded
#         prompts = []
#         add_special_tokens_flag = None
#         for q in questions:
#             ptxt, add_spec = build_prompt_for_model(q, tokenizer, model_alias)
#             prompts.append(ptxt)
#             if add_special_tokens_flag is None:
#                 add_special_tokens_flag = add_spec

#         # Generate in small batches to reduce OOM risk
#         num_correct = 0
#         for s in range(0, len(prompts), batch_size):
#             batch_prompts = prompts[s:s+batch_size]
#             batch_questions = questions[s:s+batch_size]
#             batch_gts = ground_truths[s:s+batch_size]

#             enc = tokenizer(
#                 batch_prompts,
#                 return_tensors="pt",
#                 padding=True,
#                 add_special_tokens=add_special_tokens_flag,
#             ).to(model.device)

#             with torch.no_grad():
#                 out = model.generate(
#                     **enc,
#                     max_new_tokens=max_new_tokens,
#                     do_sample=False,
#                     eos_token_id=eos_for_generate,
#                 )

#             # Decode only the generated continuation (avoids picking numbers from the prompt)
#             gen_tokens = out[:, enc["input_ids"].shape[1]:]
#             responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

#             for i, (q, gt, resp) in enumerate(zip(batch_questions, batch_gts, responses)):
#                 # Prefer #### ... if present, else last number
#                 pred_str = extract_number_after_hashes(resp)
#                 if pred_str is None:
#                     pred_str = extract_last_number(resp)

#                 pred_val = to_float(pred_str)
#                 correct = is_correct(pred_val, gt)

#                 if correct:
#                     num_correct += 1

#                 rows.append({
#                     "model_id": model_id,
#                     "model_alias": model_alias,
#                     "hf_repo": hf_repo,
#                     "question_idx": s + i,
#                     "question": q,
#                     "ground_truth": gt,
#                     "response_text": resp,
#                     "predicted_str": pred_str,
#                     "predicted_val": pred_val,
#                     "is_correct": correct,
#                     "error": "",
#                 })

#         summary["num_correct"] = num_correct
#         summary["accuracy"] = num_correct / len(questions)

#     except RuntimeError as e:
#         # Catch CUDA OOM etc.
#         summary["status"] = "runtime_error"
#         summary["error"] = repr(e)
#         # still write placeholder rows so CSV has consistent shape
#         for qi, (q, gt) in enumerate(zip(questions, ground_truths)):
#             rows.append({
#                 "model_id": model_id,
#                 "model_alias": model_alias,
#                 "hf_repo": hf_repo,
#                 "question_idx": qi,
#                 "question": q,
#                 "ground_truth": gt,
#                 "response_text": "",
#                 "predicted_str": None,
#                 "predicted_val": None,
#                 "is_correct": False,
#                 "error": f"runtime_error: {repr(e)}",
#             })

#     except Exception as e:
#         summary["status"] = "error"
#         summary["error"] = repr(e)
#         for qi, (q, gt) in enumerate(zip(questions, ground_truths)):
#             rows.append({
#                 "model_id": model_id,
#                 "model_alias": model_alias,
#                 "hf_repo": hf_repo,
#                 "question_idx": qi,
#                 "question": q,
#                 "ground_truth": gt,
#                 "response_text": "",
#                 "predicted_str": None,
#                 "predicted_val": None,
#                 "is_correct": False,
#                 "error": f"error: {repr(e)}",
#             })

#     finally:
#         # Cleanup to avoid GPU memory accumulation
#         try:
#             del model
#             del tokenizer
#         except Exception:
#             pass
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()
#         gc.collect()

#         summary["seconds"] = time.time() - start_t

#     return rows, summary

# # 6) Evaluate ALL models and save CSVs
# all_rows = []
# summary_rows = []

# for model_id in sorted(model_dict.keys()):
#     model_alias = model_dict[model_id]
#     print(f"\n===== Evaluating {model_id}: {model_alias} =====")
#     rows, summ = evaluate_one_model(model_id, model_alias, batch_size=1)
#     all_rows.extend(rows)
#     summary_rows.append(summ)
#     print(f" -> status={summ['status']}, acc={summ['accuracy']:.3f}, correct={summ['num_correct']}/{summ['num_questions']}, seconds={summ['seconds']:.1f}")

# df = pd.DataFrame(all_rows)
# summary_df = pd.DataFrame(summary_rows)

# detailed_csv_path = os.path.join(wd, "all_model_outputs.csv")
# summary_csv_path  = os.path.join(wd, "all_model_summary.csv")

# df.to_csv(detailed_csv_path, index=False)
# summary_df.to_csv(summary_csv_path, index=False)

# print("\nSaved:")
# print("Detailed:", detailed_csv_path)
# print("Summary :", summary_csv_path)

# # Optional: show summary in notebook
# summary_df.sort_values(["status", "accuracy"], ascending=[True, False])



===== Evaluating 1: pythia-7b =====


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

 -> status=ok, acc=0.000, correct=0/18, seconds=528.1

===== Evaluating 2: gemma2-2b =====
 -> status=error, acc=0.000, correct=0/18, seconds=1.1

===== Evaluating 3: mistral-7b =====


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

 -> status=ok, acc=0.333, correct=6/18, seconds=449.6

===== Evaluating 4: gptJ-6b =====


Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

 -> status=ok, acc=0.000, correct=0/18, seconds=663.4

===== Evaluating 5: gpt2-medium =====
 -> status=ok, acc=0.000, correct=0/18, seconds=207.8

===== Evaluating 6: gemma2-9b =====
 -> status=error, acc=0.000, correct=0/18, seconds=1.1

===== Evaluating 7: llama2-7b =====


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


 -> status=ok, acc=0.167, correct=3/18, seconds=365.2

===== Evaluating 8: llama3-8b =====


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

 -> status=ok, acc=0.222, correct=4/18, seconds=509.9

===== Evaluating 10: llama3-8B-Instruct =====


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

 -> status=ok, acc=0.556, correct=10/18, seconds=340.3

===== Evaluating 11: gemma2-9b-it =====
 -> status=error, acc=0.000, correct=0/18, seconds=1.1

===== Evaluating 12: qwen-1.5B =====


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


 -> status=ok, acc=0.500, correct=9/18, seconds=356.3

===== Evaluating 13: qwen-3B =====


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

 -> status=ok, acc=0.722, correct=13/18, seconds=497.2

===== Evaluating 14: qwen-7B =====


config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

 -> status=ok, acc=0.889, correct=16/18, seconds=337.7

===== Evaluating 15: llama3.1-8b =====


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

 -> status=ok, acc=0.778, correct=14/18, seconds=213.9

===== Evaluating 16: qwen2.5-7b =====


config.json:   0%|          | 0.00/658 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/161 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

 -> status=ok, acc=0.944, correct=17/18, seconds=403.2

===== Evaluating 17: mathstral-7b =====


config.json:   0%|          | 0.00/624 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/4.26G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/588k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

 -> status=ok, acc=0.611, correct=11/18, seconds=390.3

===== Evaluating 18: deepseek-7b =====


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

 -> status=ok, acc=0.611, correct=11/18, seconds=347.7

Saved:
Detailed: drive/MyDrive/gsm_test/all_model_outputs.csv
Summary : drive/MyDrive/gsm_test/all_model_summary.csv


Unnamed: 0,model_id,model_alias,hf_repo,status,num_questions,num_correct,accuracy,seconds,error
1,2,gemma2-2b,google/gemma-2-2b,error,18,0,0.0,1.071481,OSError('You are trying to access a gated repo...
5,6,gemma2-9b,google/gemma-2-9b,error,18,0,0.0,1.135303,OSError('You are trying to access a gated repo...
9,11,gemma2-9b-it,google/gemma-2-9b-it,error,18,0,0.0,1.055158,OSError('You are trying to access a gated repo...
14,16,qwen2.5-7b,Qwen/Qwen2.5-Math-7B-Instruct,ok,18,17,0.944444,403.208078,
12,14,qwen-7B,Qwen/Qwen2.5-7B-Instruct,ok,18,16,0.888889,337.674092,
13,15,llama3.1-8b,meta-llama/Meta-Llama-3.1-8B-Instruct,ok,18,14,0.777778,213.857952,
11,13,qwen-3B,Qwen/Qwen2.5-3B-Instruct,ok,18,13,0.722222,497.189481,
15,17,mathstral-7b,mistralai/Mathstral-7B-v0.1,ok,18,11,0.611111,390.278368,
16,18,deepseek-7b,deepseek-ai/deepseek-math-7b-instruct,ok,18,11,0.611111,347.657365,
8,10,llama3-8B-Instruct,meta-llama/Meta-Llama-3-8B-Instruct,ok,18,10,0.555556,340.280829,
