# Diversify MATH (Notebook)

This notebook mirrors the original pipeline but introduces a mock mode so we can debug the data flow without calling the OpenAI API. Flip `USE_MOCK_API` to `False` when you are ready to run real completions.


In [1]:
import dataclasses
import json
import os
import random
import time
from typing import Any, Dict, Iterable, Optional, Tuple

from datasets import load_dataset

try:
    from openai import OpenAI
except Exception:  # pragma: no cover - dev environments might not have openai installed
    OpenAI = None  # type: ignore

# Retry policy shared by the agents: number of attempts per API call and the
# base delay (seconds) fed into `_exponential_backoff`.
DEFAULT_MAX_RETRIES = 5
DEFAULT_RETRY_BASE_DELAY_S = 2.0

# True  -> agents return locally generated mock results (no API calls).
# False -> agents call the real OpenAI endpoints (requires OPENAI_API_KEY).
USE_MOCK_API = False
DATASET_SUBSET = "algebra"  # Any valid hendrycks_math subset or None
DATASET_SPLIT = "test"
DATASET_LIMIT = 1  # Maximum number of samples to process; 1 keeps debugging runs quick
OUTPUT_PATH = "demo_diversified_math.jsonl"  # JSONL file written by the pipeline
SEED = 42  # Seed for `random`, passed to `diversify_math_dataset(seed=...)`


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _ensure_openai_client() -> Any:
    """Build a fresh OpenAI client, failing loudly when the SDK is unavailable.

    Returns:
        A new ``OpenAI()`` instance constructed without arguments.

    Raises:
        RuntimeError: If the guarded ``openai`` import at the top of the
            notebook fell back to ``None``.
    """
    if OpenAI is not None:
        return OpenAI()
    raise RuntimeError(
        "openai package not found. Install `openai>=1.0.0` and set OPENAI_API_KEY."
    )


def _exponential_backoff(attempt: int, base_delay_s: float = DEFAULT_RETRY_BASE_DELAY_S) -> float:
    """Return how long to sleep before retrying after the given attempt.

    The delay doubles with every attempt past the first, gets a uniform random
    jitter in ``[0, base_delay_s]`` added on top, and is capped at 60 seconds.

    Args:
        attempt: 1-based index of the attempt that just failed; values below 1
            are treated like the first attempt.
        base_delay_s: Delay used for the first attempt and as the jitter range.

    Returns:
        The sleep duration in seconds, never more than 60.0.
    """
    doublings = attempt - 1 if attempt > 1 else 0
    delay_s = base_delay_s * (2 ** doublings) + random.uniform(0, base_delay_s)
    return delay_s if delay_s < 60.0 else 60.0


def _safe_json_loads(text: str) -> Dict[str, Any]:
    text = text.strip()
    try:
        return json.loads(text)
    except Exception:
        pass
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        snippet = text[start : end + 1]
        try:
            return json.loads(snippet)
        except Exception:
            pass
    raise ValueError("Failed to parse JSON from model response.")


In [3]:
@dataclasses.dataclass
class Diversification:
    """A rewritten problem plus optional metadata about the rewrite.

    In this notebook the agents populate only `diversified_problem`,
    `transformation_type`, `diversified_solution` (via the solver) and `notes`
    (mock mode only); the remaining fields are always left at `None`.
    """

    # The new problem statement produced by the diversifier.
    diversified_problem: str
    # Kind of transformation applied; the agents here always use "rewrite".
    transformation_type: str = "rewrite"
    # Solution text for the new problem; filled in from the solver's output.
    diversified_solution: Optional[str] = None
    # Free-form description of the change (not populated by the code in this notebook).
    change_description: Optional[str] = None
    # Presumably the value of a changed quantity before/after the rewrite —
    # not populated by the code in this notebook; TODO confirm intended use.
    changed_quantity_before: Optional[str] = None
    changed_quantity_after: Optional[str] = None
    # Extra remarks; set by the mock diversifier to flag locally generated output.
    notes: Optional[str] = None


@dataclasses.dataclass
class Verification:
    """Outcome of checking a proposed solution for a diversified problem."""

    # "pass" or "fail" (the verifier maps CORRECT/INCORRECT onto these).
    verdict: str
    # Short justification from the verifier, or a pipeline error message.
    reason: str
    # Additional structured checks; always an empty dict in this notebook.
    consistency_checks: Dict[str, Any]


In [4]:
class DiversifierAgent:
    """Rewrites a math problem into a new problem with the same reasoning structure.

    In mock mode no client is created and the "rewrite" is the original problem
    with a ``[Mock rewrite]`` prefix; otherwise a chat completion is requested
    from ``model`` with retries and exponential backoff.
    """

    def __init__(self, model: str, use_mock: bool = False):
        """Store the model name and create an OpenAI client unless mocking."""
        self.model = model
        self.use_mock = use_mock
        self.client = None if use_mock else _ensure_openai_client()

    def diversify(
        self,
        problem: str,
        solution: str,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ) -> Tuple[str, Diversification]:
        """Produce a rewritten problem for ``problem``.

        Args:
            problem: Original problem statement.
            solution: Official solution, given to the model as context only.
            max_retries: Number of attempts before the last error is re-raised.

        Returns:
            ``(raw_response, diversification)`` where ``raw_response`` is the
            unprocessed model output and ``diversification`` holds the
            extracted problem text (no solution is attached here).

        Raises:
            Exception: Whatever the final attempt raised (API errors, or
                ``ValueError`` for an empty rewrite).
            RuntimeError: If ``max_retries`` is below 1, so no attempt was made.
        """
        if self.use_mock:
            return self._mock_diversification(problem, solution)

        system_msg = (
            "You are a math problem rewriter for dataset augmentation.\n\n"
            "Goal:\n"
            "Given an original problem and its official solution, create a NEW problem that:\n"
            "- Uses the SAME underlying reasoning steps and mathematical structure.\n"
            "- Has a DIFFERENT surface form (phrasing, variable names, ordering, numbers).\n"
            "- Is well-posed and solvable.\n\n"
            "Requirements:\n"
            "- Preserve:\n"
            "  - The solution method (e.g., system of equations, induction, casework, inequality manipulation, etc.).\n"
            "  - The qualitative structure of the argument (same key intermediate relationships).\n"
            "  - The answer type (integer vs fraction vs expression, etc.).\n"
            "- You MAY:\n"
            "  - Change numerical values.\n"
            "  - Rename variables and symbols.\n"
            "  - Reorder how information is presented.\n"
            "  - Slightly change the style (more algebraic/abstract vs more 'wordy').\n"
            "- You MUST NOT:\n"
            "  - Change the core topic or method (e.g., algebra -> combinatorics).\n"
            "  - Make the problem trivial or drastically more difficult.\n"
            "  - Introduce contradictions or underdetermined systems.\n"
            "  - Output any solution or hint.\n\n"
            "Output only the new problem, as a single self-contained statement, with no explanation."
        )
        user_msg = (
            "Original problem:\n---\n"
            f"{problem}\n---\n\n"
            "Official solution (for understanding the reasoning steps):\n---\n"
            f"{solution}\n---\n\n"
            "Task: Rewrite the problem as a NEW PROBLEM that satisfies the system instructions.\n"
            "Do NOT output any solution or explanation.\n\n"
            "NEW PROBLEM:"
        )

        last_error: Optional[str] = None
        for attempt in range(1, max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_msg},
                        {"role": "user", "content": user_msg},
                    ],
                    temperature=0.2,
                )
                content = response.choices[0].message.content or ""
                new_problem = self._extract_problem_statement(content)
                if not new_problem:
                    # Treated like any other failure so the request is retried.
                    raise ValueError("Diversifier returned empty problem text.")
                diversification = Diversification(
                    diversified_problem=new_problem,
                    transformation_type="rewrite",
                    diversified_solution=None,
                )
                return content, diversification
            except Exception as exc:
                last_error = f"{type(exc).__name__}: {exc}"
                if attempt == max_retries:
                    raise
                time.sleep(_exponential_backoff(attempt))
        # Only reachable when the loop body never ran (max_retries < 1).
        raise RuntimeError(f"Diversifier failed after retries: {last_error}")

    def _extract_problem_statement(self, content: str) -> str:
        """Strip echo markers and code fences from the model output.

        Handles three decorations: an echoed ``NEW PROBLEM:`` marker (text
        before it is dropped), a surrounding triple-backtick fence, and a
        language tag on the opening fence line (e.g. ```` ```latex ````).

        Returns:
            The cleaned problem text; empty if nothing usable remains.
        """
        text = content.strip()
        marker = "NEW PROBLEM:"
        if marker in text:
            text = text.split(marker, 1)[1].strip()
        if text.startswith("```"):
            # Because the text starts with a fence, split() always yields at
            # least two parts and parts[1] is the fenced body, even when the
            # closing fence is missing.
            fenced = text.split("```")[1]
            info, newline, body = fenced.partition("\n")
            # A single alphanumeric word on the opening fence line is a
            # language tag, not part of the problem. It is dropped only when
            # real content follows on later lines.
            if newline and body.strip() and info.strip().isalnum():
                fenced = body
            text = fenced
        return text.strip()

    def _mock_diversification(self, problem: str, solution: str) -> Tuple[str, Diversification]:
        """Return a deterministic local rewrite; ``solution`` is unused."""
        mock_problem = f"[Mock rewrite] {problem.strip()}"
        diversification = Diversification(
            diversified_problem=mock_problem,
            transformation_type="rewrite",
            diversified_solution=None,
            change_description=None,
            changed_quantity_before=None,
            changed_quantity_after=None,
            notes="Generated locally without hitting the API.",
        )
        return mock_problem, diversification


class VerifierAgent:
    """Judges whether a proposed solution correctly solves a diversified problem.

    In mock mode every solution is accepted without an API call; otherwise a
    chat completion is requested from ``model`` and its ``VERDICT:`` /
    ``REASON:`` lines are parsed into a ``Verification``.
    """

    # Model-name prefixes of OpenAI reasoning models, which reject a custom
    # sampling temperature; the parameter is omitted for them.
    _NO_TEMPERATURE_PREFIXES = ("o1", "o3", "o4")

    def __init__(self, model: str, use_mock: bool = False):
        """Store the model name and create an OpenAI client unless mocking."""
        self.model = model
        self.use_mock = use_mock
        self.client = None if use_mock else _ensure_openai_client()

    def verify(
        self,
        original_problem: str,
        original_solution: str,
        diversification: Diversification,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ) -> Tuple[str, Verification]:
        """Check ``diversification.diversified_solution`` against its problem.

        Args:
            original_problem: Source problem, passed to the model as context.
            original_solution: Official solution of the source problem.
            diversification: The new problem and its proposed solution.
            max_retries: Number of attempts before the last error is re-raised.

        Returns:
            ``(raw_response, verification)``. When there is no proposed
            solution, the result is a ``"fail"`` verification with an empty
            raw response and no API call is made.

        Raises:
            Exception: Whatever the final attempt raised (API errors, or
                ``ValueError`` for an unparsable verdict).
            RuntimeError: If ``max_retries`` is below 1, so no attempt was made.
        """
        if self.use_mock:
            return self._mock_verification(diversification)

        # Without this guard the f-string below would ask the model to judge
        # the literal text "None" as the proposed solution.
        if not (diversification.diversified_solution or "").strip():
            verification = Verification(
                verdict="fail",
                reason="No proposed solution was provided to verify.",
                consistency_checks={},
            )
            return "", verification

        system_msg = (
            "You are a math solution verifier.\n\n"
            "Your job:\n"
            "- Check whether a PROPOSED SOLUTION for a NEW PROBLEM is mathematically correct.\n"
            "- You are also given the ORIGINAL PROBLEM and its OFFICIAL SOLUTION as reference to understand the intended reasoning pattern. "
            "The original pair is context only; judge the NEW PROBLEM and NEW SOLUTION.\n\n"
            "Requirements:\n"
            "1. Focus of evaluation: verify the final answer and core reasoning, ensuring no fatal algebraic or logical errors.\n"
            "2. Use of original problem/solution: you MAY compare structures, but MUST base correctness on the new pair.\n"
            "3. Scope: do NOT rewrite the solution or generate a full new one; be strict about mathematical mistakes.\n"
            "4. Output format:\n"
            "   REASON: brief justification (1–5 sentences).\n"
            "   VERDICT: CORRECT or INCORRECT"
        )
        user_msg = (
            "ORIGINAL PROBLEM:\n---\n"
            f"{original_problem}\n---\n\n"
            "OFFICIAL SOLUTION:\n---\n"
            f"{original_solution}\n---\n\n"
            "NEW PROBLEM:\n---\n"
            f"{diversification.diversified_problem}\n---\n\n"
            "PROPOSED SOLUTION:\n---\n"
            f"{diversification.diversified_solution}\n---\n\n"
            "Decide if the proposed solution correctly solves the new problem."
        )

        # Deterministic sampling where the model allows it; reasoning models
        # get no temperature argument at all (API default).
        request_kwargs: Dict[str, Any] = {}
        if not self.model.startswith(self._NO_TEMPERATURE_PREFIXES):
            request_kwargs["temperature"] = 0.0

        last_error: Optional[str] = None
        for attempt in range(1, max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_msg},
                        {"role": "user", "content": user_msg},
                    ],
                    **request_kwargs,
                )
                content = response.choices[0].message.content or ""
                verdict_word, reason = self._parse_verifier_response(content)
                if verdict_word not in {"correct", "incorrect"}:
                    # Treated like any other failure so the request is retried.
                    raise ValueError("Verifier must output VERDICT: CORRECT/INCORRECT.")
                verdict = "pass" if verdict_word == "correct" else "fail"
                if not reason:
                    reason = "No reason provided by verifier."
                verification = Verification(
                    verdict=verdict,
                    reason=reason,
                    consistency_checks={},
                )
                return content, verification
            except Exception as exc:
                last_error = f"{type(exc).__name__}: {exc}"
                if attempt == max_retries:
                    raise
                time.sleep(_exponential_backoff(attempt))
        # Only reachable when the loop body never ran (max_retries < 1).
        raise RuntimeError(f"Verifier failed after retries: {last_error}")

    def _parse_verifier_response(self, content: str) -> Tuple[str, str]:
        """Extract the verdict word and the reason from the model output.

        Labels are matched case-insensitively and ignoring Markdown emphasis
        (``**VERDICT:** CORRECT``). The verdict tolerates trailing punctuation
        or commentary (``CORRECT.``, ``CORRECT (matches)``); an ambiguous line
        naming both verdicts is returned verbatim so the caller rejects it.
        A reason may span several lines; non-empty lines following ``REASON:``
        are joined with spaces until a ``VERDICT:`` line is reached. If a
        label occurs more than once, the last occurrence wins.

        Returns:
            ``(verdict_word, reason)``; either is ``""`` when absent.
        """
        verdict_word = ""
        reason_parts = []
        collecting_reason = False
        for line in content.splitlines():
            stripped = line.strip()
            # Used only to recognise labels; the original text supplies values.
            probe = stripped.replace("*", "").replace("`", "").lstrip("#>- ").upper()
            if probe.startswith("VERDICT:"):
                collecting_reason = False
                value = stripped.split(":", 1)[1]
                words = [w.strip("*`_.,;:!\"'()[]").lower() for w in value.split()]
                words = [w for w in words if w]
                named = {w for w in words if w in ("correct", "incorrect")}
                if words and words[0] in named and len(named) == 1:
                    verdict_word = words[0]
                else:
                    verdict_word = " ".join(words)
            elif probe.startswith("REASON:"):
                collecting_reason = True
                first = stripped.split(":", 1)[1].strip(" \t*")
                reason_parts = [first] if first else []
            elif collecting_reason and stripped:
                reason_parts.append(stripped)
        return verdict_word, " ".join(reason_parts)

    def _mock_verification(self, diversification: Diversification) -> Tuple[str, Verification]:
        """Return a canned passing verdict; ``diversification`` is not inspected."""
        mock_response = "VERDICT: CORRECT\nREASON: Mock verifier assumes the proposed solution is sound."
        verification = Verification(
            verdict="pass",
            reason="Mock verifier assumes the proposed solution is sound.",
            consistency_checks={},
        )
        return mock_response, verification


In [5]:
class SolverAgent:
    """Writes a solution for a diversified problem, guided by the original pair.

    In mock mode the reference solution is echoed back with a synthetic
    ``Final answer:`` line; otherwise a chat completion is requested from
    ``model`` with retries and exponential backoff.
    """

    def __init__(self, model: str, use_mock: bool = False):
        """Store the model name and create an OpenAI client unless mocking."""
        self.model = model
        self.use_mock = use_mock
        self.client = None if use_mock else _ensure_openai_client()

    def solve(
        self,
        original_problem: str,
        original_solution: str,
        diversified_problem: str,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ) -> Tuple[str, str]:
        """Solve ``diversified_problem``.

        Args:
            original_problem: Source problem, passed to the model as context.
            original_solution: Official solution of the source problem.
            diversified_problem: The new problem to solve.
            max_retries: Number of attempts before the last error is re-raised.

        Returns:
            ``(raw_response, final_solution)``. ``final_solution`` is the part
            of the response starting at ``Final answer:``; if the model
            omitted that marker, the whole stripped response is used instead.

        Raises:
            Exception: Whatever the final attempt raised.
            RuntimeError: If ``max_retries`` is below 1, so no attempt was made.
        """
        if self.use_mock:
            return self._mock_solve(reference_solution=original_solution)

        system_msg = (
            "You are a math solution writer.\n\n"
            "Your job:\n"
            "- Given an ORIGINAL PROBLEM and its OFFICIAL SOLUTION, and a NEW PROBLEM (a perturbed variant),\n"
            "- Use the original solution as a reference to understand the intended reasoning pattern,\n"
            "- Then produce a complete, correct solution to the NEW PROBLEM.\n\n"
            "Requirements:\n"
            "1. Primary target: The final answer must correctly solve the NEW PROBLEM; adapt all steps and calculations accordingly.\n"
            "2. Reasoning and rigor: Show clear, step-by-step reasoning with valid algebra and logic adapted to the new statements.\n"
            "3. Consistency: Ensure every step and the final answer match the NEW PROBLEM; if it is unsolvable, explain why.\n"
            "4. Output format: Write a coherent solution and end with 'Final answer: ...' without mentioning the original problem."
        )
        user_msg = (
            "Original problem:\n---\n"
            f"{original_problem}\n---\n\n"
            "Original official solution (for reference only):\n---\n"
            f"{original_solution}\n---\n\n"
            "New problem:\n---\n"
            f"{diversified_problem}\n---\n\n"
            "Task: Using the system instructions, solve the NEW PROBLEM. Provide a clear step-by-step solution and end with 'Final answer: ...'"
        )

        last_error: Optional[str] = None
        for attempt in range(1, max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_msg},
                        {"role": "user", "content": user_msg},
                    ],
                    temperature=1,
                )
                content = response.choices[0].message.content or ""
                final_solution = self._extract_final_solution(content)
                return content, final_solution
            except Exception as exc:
                last_error = f"{type(exc).__name__}: {exc}"
                if attempt == max_retries:
                    raise
                time.sleep(_exponential_backoff(attempt))
        # Only reachable when the loop body never ran (max_retries < 1).
        raise RuntimeError(f"Solver failed after retries: {last_error}")

    @staticmethod
    def _extract_final_solution(content: str) -> str:
        """Return the response from the first ``Final answer:`` marker onward.

        ``str.find`` returns -1 when the marker is missing, and slicing with
        ``content[-1:]`` would keep only the last character of the response.
        In that case the full stripped response is returned instead.
        """
        marker_at = content.find("Final answer:")
        if marker_at == -1:
            return content.strip()
        return content[marker_at:].strip()

    def _mock_solve(self, reference_solution: str) -> Tuple[str, str]:
        """Echo the reference solution, using its last line as the final answer."""
        mock_solution_body = reference_solution.strip() or "(mock) solution unavailable"
        mock_solution = f"{mock_solution_body}\n\nFinal answer: {mock_solution_body.splitlines()[-1]}"
        return mock_solution, mock_solution



In [6]:
def _load_hendrycks_math(subset: Optional[str], split: str) -> Any:
    """Load one split of the ``EleutherAI/hendrycks_math`` dataset.

    Args:
        subset: Dataset configuration name (e.g. ``"algebra"``); a falsy value
            means no configuration name is passed to ``load_dataset``.
        split: Split to load, forwarded as ``split=``.

    Returns:
        Whatever ``datasets.load_dataset`` returns for that request.
    """
    positional = ["EleutherAI/hendrycks_math"]
    if subset:
        positional.append(subset)
    return load_dataset(*positional, split=split)


def _iter_samples(dataset: Any, limit: Optional[int]) -> Iterable[Dict[str, Any]]:
    count = 0
    for row in dataset:
        yield row
        count += 1
        if limit is not None and count >= limit:
            break


def diversify_math_dataset(
    model_diversifier: str,
    model_verifier: str,
    subset: Optional[str],
    split: str,
    limit: Optional[int],
    seed: int,
    output_path: str,
    include_failed: bool = False,
    model_solver: Optional[str] = None,
    use_mock_diversifier: bool = False,
    use_mock_verifier: bool = False,
    use_mock_solver: bool = False,
) -> Tuple[int, int]:
    """Run the diversify -> solve -> verify pipeline and write results as JSONL.

    For each dataset row the problem is rewritten, optionally solved (when
    ``model_solver`` is given), and the result verified. A record is written
    for every row that passes verification, and for failing rows too when
    ``include_failed`` is true.

    Args:
        model_diversifier: Model name for the diversifier agent.
        model_verifier: Model name for the verifier agent.
        subset: Dataset subset passed to ``_load_hendrycks_math``.
        split: Dataset split to load.
        limit: Maximum number of rows to read; ``None`` reads all rows.
        seed: Seed applied to the global ``random`` module.
        output_path: Destination JSONL file; overwritten if it exists.
        include_failed: Also write records whose verification did not pass.
        model_solver: Model name for the solver agent; no solver runs if falsy.
        use_mock_diversifier: Use the diversifier's mock mode.
        use_mock_verifier: Use the verifier's mock mode.
        use_mock_solver: Use the solver's mock mode.

    Returns:
        ``(attempted, accepted)``: rows read (including rows skipped for
        missing text) and rows that passed verification.
    """
    random.seed(seed)

    dataset = _load_hendrycks_math(subset=subset, split=split)
    diversifier = DiversifierAgent(model=model_diversifier, use_mock=use_mock_diversifier)
    verifier = VerifierAgent(model=model_verifier, use_mock=use_mock_verifier)
    solver = (
        SolverAgent(model=model_solver, use_mock=use_mock_solver)
        if model_solver
        else None
    )

    # Diversification attributes copied into each record, in output key order.
    diversification_fields = (
        "diversified_problem",
        "diversified_solution",
        "transformation_type",
        "change_description",
        "changed_quantity_before",
        "changed_quantity_after",
        "notes",
    )

    attempted = 0
    accepted = 0
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for row in _iter_samples(dataset, limit=limit):
            attempted += 1
            # `or ""` keeps a None value from being stringified to "None".
            original_problem: str = str(row.get("problem") or "").strip()
            original_solution: str = str(row.get("solution") or "").strip()
            level = row.get("level")
            problem_type = row.get("type")

            if not original_problem or not original_solution:
                continue

            raw_diversifier_response: Optional[str] = None
            raw_verifier_response: Optional[str] = None
            raw_solver_response: Optional[str] = None
            solver_solution: Optional[str] = None
            # Reset per row. It stays None if diversification itself fails,
            # but survives a later verifier failure so the record still shows
            # the generated problem.
            diversification = None
            passed = False

            try:
                raw_diversifier_response, diversification = diversifier.diversify(
                    problem=original_problem,
                    solution=original_solution,
                )
                if solver:
                    try:
                        raw_solver_response, solver_solution = solver.solve(
                            original_problem=original_problem,
                            original_solution=original_solution,
                            diversified_problem=diversification.diversified_problem,
                        )
                    except Exception as solver_exc:  # don't fail pipeline just because solver failed
                        raw_solver_response = f"Solver error: {type(solver_exc).__name__}: {solver_exc}"
                        solver_solution = None
                if solver_solution:
                    diversification = dataclasses.replace(
                        diversification, diversified_solution=solver_solution
                    )
                raw_verifier_response, verification = verifier.verify(
                    original_problem=original_problem,
                    original_solution=original_solution,
                    diversification=diversification,
                )
                passed = verification.verdict == "pass"
            except Exception as exc:
                verification = Verification(
                    verdict="fail",
                    reason=f"Pipeline error: {type(exc).__name__}: {exc}",
                    consistency_checks={},
                )
                passed = False

            if not (passed or include_failed):
                continue

            record: Dict[str, Any] = {
                "subset": subset,
                "split": split,
                "level": level,
                "type": problem_type,
                "original_problem": original_problem,
                "original_solution": original_solution,
            }
            for field_name in diversification_fields:
                # getattr with a default also covers diversification being None.
                record[field_name] = getattr(diversification, field_name, None)
            record.update(
                {
                    "diversifier_raw_response": raw_diversifier_response,
                    "verifier_raw_response": raw_verifier_response,
                    "solver_raw_response": raw_solver_response,
                    "solver_solution": solver_solution,
                    "verification": dataclasses.asdict(verification),
                }
            )
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            if passed:
                accepted += 1
    return attempted, accepted


In [7]:
# attempted, accepted = diversify_math_dataset(
#     model_diversifier="gpt-4o-mini",
#     model_verifier="gpt-4o-mini",
#     subset=DATASET_SUBSET,
#     split=DATASET_SPLIT,
#     limit=DATASET_LIMIT,
#     seed=SEED,
#     output_path=OUTPUT_PATH,
#     include_failed=True,
#     model_solver="gpt-4o-mini",
#     use_mock_diversifier=USE_MOCK_API,
#     use_mock_verifier=USE_MOCK_API,
#     use_mock_solver=USE_MOCK_API,
# )

# print(f"attempted\t{attempted}")
# print(f"accepted\t{accepted}")
# print(f"output\t{os.path.abspath(OUTPUT_PATH)}")

# Step-by-step debug run: diversify and solve individual samples so the
# intermediate objects can be inspected in the cells below.
dataset = _load_hendrycks_math(subset=DATASET_SUBSET, split=DATASET_SPLIT)
diversifier = DiversifierAgent(model="gpt-4.1", use_mock=USE_MOCK_API)

solver = SolverAgent(model="o3", use_mock=USE_MOCK_API)

# Pre-bind the names printed after the loop so a skipped sample (empty problem
# or solution) shows up as None instead of raising NameError.
raw_response = None
raw_solver_response = None
diversification = None

# NOTE(review): OUTPUT_PATH is opened for writing (and therefore truncated),
# but nothing is written to `f` in this cell — confirm whether that is intended.
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    for row in _iter_samples(dataset, limit=DATASET_LIMIT):
        original_problem: str = str(row.get("problem", "")).strip()
        original_solution: str = str(row.get("solution", "")).strip()
        level = row.get("level")
        problem_type = row.get("type")

        if not original_problem or not original_solution:
            continue

        raw_response, diversification = diversifier.diversify(
            problem=original_problem,
            solution=original_solution,
        )
        raw_solver_response, solver_solution = solver.solve(
            original_problem=original_problem,
            original_solution=original_solution,
            diversified_problem=diversification.diversified_problem,
        )
        if solver_solution:
            # Attach the solver's answer so the verifier cell can check it.
            diversification = dataclasses.replace(
                diversification, diversified_solution=solver_solution
            )

print("diversifier response: ", raw_response)

print("solver response: ", raw_solver_response)




diversifier response:  How many vertical asymptotes are there in the graph of the function \( y = \frac{5}{x^2 - 4x - 12} \)?
solver response:  To locate vertical asymptotes of a rational function, find the real zeros of the denominator that are not canceled by corresponding zeros in the numerator.

1. Factor the denominator.
   x² − 4x − 12 = (x − 6)(x + 2)

2. Check the numerator.
   The numerator is the constant 5, which is never zero, so no factor cancels with the denominator.

3. Identify the x-values where the denominator is zero.
   (x − 6)(x + 2) = 0  ⇒  x = 6 or x = −2

At these two x-values the function is undefined, and because there is no cancellation, the graph has vertical asymptotes there.

Therefore, the graph has 2 vertical asymptotes.

Final answer: 2


In [8]:
# NOTE(review): the commented-out lines below extract and parse a JSON object
# from `raw_response`; presumably left over from an earlier prompt that
# returned JSON — confirm and delete if no longer needed.
# left = raw_response.find("{")
# right = raw_response.rfind("}")

# print(raw_response[left:right+1])
# res = json.loads(raw_response[left:right+1])
# print(res)

# Show the Diversification built in the previous cell (including the solver's
# answer, if one was attached).
print(diversification)


Diversification(diversified_problem='How many vertical asymptotes are there in the graph of the function \\( y = \\frac{5}{x^2 - 4x - 12} \\)?', transformation_type='rewrite', diversified_solution='Final answer: 2', change_description=None, changed_quantity_before=None, changed_quantity_after=None, notes=None)


In [9]:
# Verify the single debug sample. Relies on `original_problem`,
# `original_solution` and `diversification` still being bound from the
# diversify/solve cell above, so that cell must have been run first.
verifier = VerifierAgent(model="o3-mini", use_mock=USE_MOCK_API)
raw_verifier_response, verification = verifier.verify(
    original_problem=original_problem,
    original_solution=original_solution,
    diversification=diversification,
)
# Raw model output first, then the parsed Verification object.
print(raw_verifier_response)
print(verification)


REASON: The denominator factors as (x - 6)(x + 2), so the function has vertical asymptotes at x = 6 and x = -2. Thus, there are 2 vertical asymptotes, which matches the proposed answer.
VERDICT: CORRECT
Verification(verdict='pass', reason='The denominator factors as (x - 6)(x + 2), so the function has vertical asymptotes at x = 6 and x = -2. Thus, there are 2 vertical asymptotes, which matches the proposed answer.', consistency_checks={})


In [10]:
# Side-by-side dump of the debug sample: original problem and solution, the
# diversified result, and the verifier's parsed and raw output.
print(original_problem)
print(original_solution)
print(diversification)
print(verification)
print(raw_verifier_response)




How many vertical asymptotes does the graph of $y=\frac{2}{x^2+x-6}$ have?
The denominator of the rational function factors into $x^2+x-6=(x-2)(x+3)$. Since the numerator is always nonzero, there is a vertical asymptote whenever the denominator is $0$, which occurs for $x = 2$ and $x = -3$.  Therefore, the graph has $\boxed{2}$ vertical asymptotes.
Diversification(diversified_problem='How many vertical asymptotes are there in the graph of the function \\( y = \\frac{5}{x^2 - 4x - 12} \\)?', transformation_type='rewrite', diversified_solution='Final answer: 2', change_description=None, changed_quantity_before=None, changed_quantity_after=None, notes=None)
Verification(verdict='pass', reason='The denominator factors as (x - 6)(x + 2), so the function has vertical asymptotes at x = 6 and x = -2. Thus, there are 2 vertical asymptotes, which matches the proposed answer.', consistency_checks={})
REASON: The denominator factors as (x - 6)(x + 2), so the function has vertical asymptotes at x =