# PROMETHEUS v2 + Reflexion Prompting

## Key Improvements Adopted from the Top-Scoring (27/50) Notebook:
1. **5 Different System Prompts** - Diverse reasoning strategies
2. **Reflexion Prompting** - Follow-up verification ("Are you sure?")
3. **Log-Weighted Voting** - Penalize trivial small answers
4. **Early Stopping** - Stop when confidence is high
5. **Streaming + Early Exit** - Stop generation when \boxed{} found (planned; the current code does not stream and stops only at its fixed stop sequences)

**Model**: Qwen-72B-Math-INT4

In [None]:
# =============================================================================
# CELL 1: ENVIRONMENT SETUP
# =============================================================================
import os
import sys
import subprocess
import gc

# Purge TensorFlow conflicts: remove TF/Keras (plus scikit-learn and matplotlib)
# before vLLM is installed. The result is ignored on purpose - the packages may
# simply not be installed.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y",
                "tensorflow", "keras", "scikit-learn", "matplotlib"], capture_output=True)

# Use the pure-Python protobuf backend and keep transformers away from TF.
# NOTE(review): presumably these avoid native-extension version clashes - confirm.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Install vLLM offline (--no-index) from the attached wheel datasets.
WHEEL_DIR = "/kaggle/input/vllm-085-wheels"
AUX_DIR = "/kaggle/input/aimo3-offline-wheels"

if os.path.exists(WHEEL_DIR):
    result = subprocess.run(
        [
            sys.executable, "-m", "pip", "install", "vllm",
            "--no-index", f"--find-links={WHEEL_DIR}", f"--find-links={AUX_DIR}",
        ],
        capture_output=True,
        # Decode pip's output so the failure tail below prints as text instead of
        # a b'...' bytes repr; undecodable bytes are replaced rather than raising.
        encoding="utf-8",
        errors="replace",
    )
    print("vLLM installed" if result.returncode == 0 else f"vLLM failed: {result.stderr[-200:]}")
else:
    print(f"vLLM wheels not found at {WHEEL_DIR}; skipping install")

In [None]:
# =============================================================================
# CELL 2: IMPORTS & TIMING
# =============================================================================
import torch
import time
import math
import random
import re
import tempfile
import statistics
from typing import Optional, List, Tuple, Dict, Any
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import polars as pl

# Answer bounds: competition answers are integers in this inclusive range.
ANSWER_MIN, ANSWER_MAX = 0, 99999

# Reproducibility: seed Python, NumPy and torch RNGs with one shared value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Timing (AIMO3 Rule: GPU <= 5 hours). The clock starts here, after the
# package install in Cell 1; we stop 20 minutes short of the 5h limit.
START_TIME = time.time()
TOTAL_BUDGET = 4 * 3600 + 40 * 60  # 4h40m safety, in seconds
CUTOFF_TIME = START_TIME + TOTAL_BUDGET

print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
print(f"Budget: {TOTAL_BUDGET//3600}h {(TOTAL_BUDGET%3600)//60}m")

In [None]:
# =============================================================================
# CELL 3: MODEL LOADING (Qwen-72B-Math-INT4)
# =============================================================================
from vllm import LLM, SamplingParams

MODEL_PATHS = [
    "/kaggle/input/qwen-72b-math-int4",
    "/kaggle/input/d/ryancardwell/qwen-72b-math-int4",
    "/kaggle/input/qwen-72b-math-nf4",
]


def find_model(paths: Optional[List[str]] = None) -> Optional[str]:
    """Return the first model directory (one holding ``config.json``) under *paths*.

    Candidate roots are tried in order. A root is returned as-is if it directly
    contains ``config.json``. Otherwise the shallowest ``config.json`` found
    anywhere below it wins, with ties broken alphabetically. ``glob`` yields
    matches in arbitrary filesystem order, so this sort keeps the choice
    deterministic when a dataset ships several nested checkpoint directories.

    Args:
        paths: candidate root directories; defaults to ``MODEL_PATHS``.

    Returns:
        The model directory, or None when no candidate contains a config.
    """
    import glob

    for root in MODEL_PATHS if paths is None else paths:
        if os.path.isfile(os.path.join(root, "config.json")):
            return root
        # glob.escape guards against glob metacharacters in the root itself.
        configs = glob.glob(
            os.path.join(glob.escape(root), "**", "config.json"), recursive=True
        )
        if configs:
            shallowest = min(configs, key=lambda c: (c.count(os.sep), c))
            return os.path.dirname(shallowest)
    return None

MODEL_PATH = find_model()
print(f"Model: {MODEL_PATH}")

# The engine is only built when a checkpoint directory was found; otherwise
# LLM_MODEL stays None and solve() answers with its default.
LLM_MODEL = None
if MODEL_PATH:
    LLM_MODEL = LLM(
        model=MODEL_PATH,
        trust_remote_code=True,
        seed=SEED,
        # Memory / context window
        gpu_memory_utilization=0.92,
        max_model_len=8192,
        # Execution: one GPU, no CUDA-graph capture
        tensor_parallel_size=1,
        enforce_eager=True,
    )
if LLM_MODEL is not None:
    print("Model loaded!")

In [None]:
# =============================================================================
# CELL 4: REFLEXION SYSTEM PROMPTS (from top 27/50 notebook)
# =============================================================================

# System prompts for the parallel attempts in solve(). Each string is sent as the
# system message of one generate_with_reflexion() call.
# NOTE(review): solve() uses min(4, len(SYSTEM_PROMPTS)) prompts at indices 0..3,
# so the fifth prompt ("Restart on inconsistency") is currently never used.
# "\\boxed{}" in these non-raw strings renders as a single-backslash \boxed{}.
SYSTEM_PROMPTS = [
    # Rigorous multi-method
    """You are solving a national/international-level mathematics olympiad problem. 
You must rigorously define all variables, explore multiple solution strategies before committing, 
perform full case analysis where required, justify every nontrivial step, 
explicitly check boundary cases and hidden assumptions, 
and verify the final result using at least one independent method. 
Return only the final numerical answer inside \\boxed{}. 
The answer must be an integer in [0, 99999]. Never guess.""",

    # Self-refutation
    """Solve the problem with full rigor. After obtaining a candidate solution, 
actively attempt to refute your own answer by searching for counterexamples, 
re-running the logic from a different viewpoint, and stress-testing edge cases. 
Only after the answer survives refutation, return it in \\boxed{}. 
The answer must be an integer in [0, 99999]. Never guess.""",

    # IMO-style efficiency
    """Solve this problem as if under IMO-level time pressure: 
identify the key invariant, symmetry, or extremal principle early, 
avoid brute force unless strictly justified, compress reasoning without sacrificing correctness, 
and perform at least one final arithmetic verification pass. 
Return only the final integer answer in \\boxed{}, with 0 ≤ answer ≤ 99999. Never guess.""",

    # Dual approach
    """You must attempt at least two fundamentally different solution approaches 
(e.g., algebraic vs geometric, combinatorial vs number-theoretic). 
Proceed with the more rigorous one and use the other as a verification tool. 
Return only the verified final answer in \\boxed{}, where the answer is an integer in [0, 99999]. Never guess.""",

    # Restart on inconsistency
    """Solve the problem rigorously. If at any point a step relies on an unproven assumption, 
a jump in logic is detected, or the computation becomes inconsistent, 
you must restart the solution from first principles. 
Return only the final verified integer answer inside \\boxed{}, with 0 ≤ answer ≤ 99999. Never guess."""
]

print(f"Loaded {len(SYSTEM_PROMPTS)} system prompts")

In [None]:
# =============================================================================
# CELL 5: ANSWER EXTRACTION & VALIDATION
# =============================================================================

def extract_boxed(text: str) -> Optional[str]:
    """Return the stripped contents of the last non-empty ``\\boxed{...}`` in *text*.

    Braces are matched by nesting depth, so nested LaTeX is returned whole. This
    covers cases such as ``\\boxed{\\frac{1}{2}}`` and ``\\boxed{12{,}345}``, where
    a non-greedy regex would stop at the first ``}``. Boxed content spanning
    several lines is found as well. An unterminated ``\\boxed{`` (e.g. from a
    truncated generation) is ignored.

    Returns:
        The contents, or None when *text* is empty/None or holds no non-empty box.
    """
    if not text:
        return None

    contents = []
    # "oxed" (as before) matches \boxed, a bare "boxed" and "Boxed".
    for opener in re.finditer(r"oxed\s*\{", text):
        start = opener.end()
        depth = 1
        pos = start
        while pos < len(text) and depth:
            if text[pos] == "{":
                depth += 1
            elif text[pos] == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            contents.append(text[start:pos - 1])

    for content in reversed(contents):
        if content.strip():
            return content.strip()
    return None

def is_valid_answer(text: str) -> bool:
    """Return True if *text* parses with ``int()`` to a value in [0, 99999].

    ``int()`` tolerates surrounding whitespace and a leading sign. Decimals,
    LaTeX, empty strings and None are rejected.
    """
    try:
        val = int(text)
    except (TypeError, ValueError, OverflowError):
        # Narrow catch: unparseable input is simply "not valid", while
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return False
    return 0 <= val <= 99999

def normalize_answer(raw: Any) -> Optional[int]:
    """Convert a raw model answer to an integer in [0, 99999], or None.

    Whitespace and thousands separators (``,``, LaTeX ``{,}``, ``\\,`` and
    ``\\!``) are removed first. The remaining text must then contain exactly
    one number; it is converted with ``int(float(...))``, so ``"42.0"`` -> 42.
    Text containing several numbers (``\\frac{1}{2}``, ``2^{10}``, ``1e5``)
    returns None. Previously all non-digits were stripped, which fused such
    inputs into fabricated values like 12 or 210.
    """
    if raw is None:
        return None

    s = "".join(str(raw).split())
    # "{,}" must go before "," or it would leave "{}" splitting the digits.
    for separator in ("{,}", "\\,", "\\!", ","):
        s = s.replace(separator, "")

    tokens = re.findall(r"-?\d+(?:\.\d+)?", s)
    if len(tokens) != 1:
        return None
    try:
        val = int(float(tokens[0]))
    except (ValueError, OverflowError):  # e.g. huge digit strings -> inf
        return None
    return val if 0 <= val <= 99999 else None

print("Answer utilities loaded")

In [None]:
# =============================================================================
# CELL 6: LOG-WEIGHTED VOTING (from top 27/50 notebook)
# =============================================================================

def vote_answer(
    counter: Counter,
    force_answer: bool = False,
    min_votes: int = 2,
) -> Tuple[Optional[int], bool]:
    """
    Pick the best answer by log-weighted voting and report whether it is confident.

    Each distinct answer ``v`` seen ``c`` times scores ``c * log(1.25 + |v|)``.
    This down-weights small answers (0, 1, 2), which are often wrong guesses.

    The winner is *confident* only when it has at least ``min_votes`` votes and
    its score exceeds ``max(3, total / (2 + log(1 + total)))``. Agreement is
    required because one vote for any answer >= 19 already scores above 3
    (log(1.25 + 19) > 3). A lone sample would otherwise end voting immediately.

    Args:
        counter: answer -> number of attempts that produced it. It is read via
            a snapshot, so worker threads may keep incrementing it concurrently.
        force_answer: when True, return the default 12453 for an empty tally
            and print the top-5 (value, score, count) tuples.
        min_votes: minimum number of votes the winner needs to be confident.

    Returns:
        (answer, is_confident). answer is None only when the tally is empty
        and ``force_answer`` is False.
    """
    # In CPython, dict() copies a Counter in one C-level pass under the GIL, so a
    # concurrent counter[x] += 1 cannot change the dict while we iterate it.
    votes = dict(counter)
    if not votes:
        return (12453 if force_answer else None, False)

    scores = {value: math.log(1.25 + abs(value)) * count for value, count in votes.items()}
    total_score = sum(scores.values())
    score_list = sorted(
        ((score, votes[value], value) for value, score in scores.items()),
        key=lambda x: -x[0],
    )

    best_score, best_count, best_value = score_list[0]

    threshold = total_score / (2 + math.log(1 + total_score))
    is_confident = best_count >= min_votes and best_score > max(3, threshold)

    if force_answer:
        print(f"  Vote: {[(v, s, c) for s, c, v in score_list[:5]]}")

    return (best_value, is_confident)

print("Log-weighted voting loaded")

In [None]:
# =============================================================================
# CELL 7: GENERATION WITH REFLEXION
# =============================================================================

import threading

STOP_TOKENS = ["```output", "```\nOutput", "\n\n\n", "<|im_end|>"]

# Serializes LLM_MODEL.generate(). solve() calls generate_with_reflexion from
# several threads, but vLLM's offline LLM drives one shared engine loop and is
# not safe to call concurrently: threads can collect each other's finished
# requests (or none). The lock trades cross-attempt batching for correctness.
_GENERATE_LOCK = threading.Lock()


def generate_with_reflexion(
    question: str,
    system_prompt: str,
    counter: Counter,
    completed: set,
    question_id: str = "",
    max_iterations: int = 2
) -> Optional[int]:
    """
    Generate one answer for *question*, using reflexion follow-up turns.

    After each model turn, the last ``\\boxed{}`` in the response decides the
    next step:
    - no valid integer in [0, 99999]: ask the model to box its answer;
    - answer <= 10: ask "Are you sure that is the answer?";
    - first turn shorter than 1000 characters: ask for an alternative-method check;
    - otherwise stop.
    At most ``max_iterations`` model turns are generated.

    Args:
        question: problem statement, sent as the user message.
        system_prompt: system message for this attempt.
        counter: shared vote tally; a valid final answer is incremented here.
        completed: shared set of question ids that reached early stop. It is
            checked before every generation, so no new generation starts once
            voting is done.
        question_id: key looked up in ``completed``.
        max_iterations: maximum number of model turns.

    Returns:
        The last non-empty boxed value across all assistant turns if it is a
        valid answer, else None.
    """
    if question_id in completed:
        return None
    if time.time() >= CUTOFF_TIME:
        return None

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question}
    ]

    for iteration in range(max_iterations):
        prompt = format_messages(messages)
        params = SamplingParams(
            temperature=0.7 if iteration == 0 else 0.5,
            top_p=0.9,
            max_tokens=4096,
            stop=STOP_TOKENS,
        )

        with _GENERATE_LOCK:
            # Re-check once we hold the lock: another attempt may have triggered
            # early stopping, or the deadline may have passed, while we waited.
            # Checking before (not after) generating means a response we paid
            # for is never thrown away.
            if question_id in completed or time.time() >= CUTOFF_TIME:
                break
            outputs = LLM_MODEL.generate([prompt], sampling_params=params)
        response = outputs[0].outputs[0].text

        messages.append({"role": "assistant", "content": response})

        boxed = extract_boxed(response)

        # Reflexion logic
        if not boxed or not is_valid_answer(boxed):
            follow_up = "The answer is expected to be an integer between 0 and 99999 inclusive. Place your final answer in \\boxed{}. Do not guess."
        elif int(boxed) <= 10:
            # Small answers are often guesses - ask for a double-check.
            follow_up = "Are you sure that is the answer? Double-check your work."
        elif iteration == 0 and len(response) < 1000:
            # Very short first answers get an explicit verification request.
            follow_up = "Have you verified your answer using an alternative method?"
        else:
            break
        messages.append({"role": "user", "content": follow_up})

    # Final extraction over every assistant turn; the last box wins.
    full_text = "\n".join([m["content"] for m in messages if m["role"] == "assistant"])
    boxed = extract_boxed(full_text)

    if boxed and is_valid_answer(boxed):
        answer = int(boxed)
        counter[answer] += 1
        return answer

    return None

def format_messages(messages: List[Dict]) -> str:
    """Render chat *messages* in Qwen's ChatML layout with an open assistant turn.

    Each message becomes ``<|im_start|>{role}\\n{content}<|im_end|>``. Turns are
    newline-joined, and a trailing ``<|im_start|>assistant\\n`` cues the reply.
    """
    turns = [f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>" for msg in messages]
    turns += ["<|im_start|>assistant\n"]
    return "\n".join(turns)

print("Reflexion generation loaded")

In [None]:
# =============================================================================
# CELL 8: PARALLEL SOLVER WITH EARLY STOPPING
# =============================================================================

PROBLEM_COUNTER = Counter()  # NOTE(review): not read by solve(); kept for compatibility
COMPLETED_IDS = set()        # NOTE(review): not read by solve(); kept for compatibility

def solve(question: str, question_id: str = "") -> int:
    """
    Solve *question* using parallel reflexion attempts, early stopping and voting.

    1. Launch up to 4 generate_with_reflexion() attempts, one per system prompt.
    2. Whenever an attempt returns an answer, re-vote. Once vote_answer()
       reports confidence, add *question_id* to ``completed``; the remaining
       attempts check that set and stop early.
    3. Take a log-weighted vote over every collected answer.

    If no attempt produced a valid answer, fall back to the first number in
    (0, 100000) that appears in the question, else 12453. Returns 12453
    immediately when no model is loaded or the deadline has passed.
    """
    if not LLM_MODEL:
        print("  No model - returning default")
        return 12453

    if time.time() >= CUTOFF_TIME:
        print("  Timeout - returning default")
        return 12453

    counter = Counter()
    completed = set()

    num_generations = min(4, len(SYSTEM_PROMPTS))

    # One worker per attempt (at least one: ThreadPoolExecutor rejects 0).
    with ThreadPoolExecutor(max_workers=max(1, num_generations)) as executor:
        futures = [
            executor.submit(
                generate_with_reflexion,
                question,
                SYSTEM_PROMPTS[i % len(SYSTEM_PROMPTS)],
                counter,
                completed,
                question_id,
            )
            for i in range(num_generations)
        ]

        for future in as_completed(futures):
            try:
                result = future.result()
                # Vote only until the first confident result; later attempts
                # just add votes to the final tally.
                if result is not None and question_id not in completed:
                    answer, confident = vote_answer(counter)
                    if confident:
                        completed.add(question_id)
                        print(f"  Early stop: {answer} (confident)")
            except Exception as e:
                print(f"  Generation error: {e}")

    # Final vote, only when something was collected: with force_answer=True,
    # vote_answer() substitutes its own default for an empty tally, which made
    # the question-number fallback below unreachable.
    answer = vote_answer(counter, force_answer=True)[0] if counter else None

    if answer is None:
        # Fallback: first number in the question that is a legal answer.
        nums = [int(x) for x in re.findall(r'\b\d+\b', question) if 0 < int(x) < 100000]
        answer = nums[0] if nums else 12453

    # Clamp to valid range
    answer = max(ANSWER_MIN, min(ANSWER_MAX, int(answer)))

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return answer

print("Parallel solver loaded")

In [None]:
# =============================================================================
# CELL 9: KAGGLE API INTERFACE
# =============================================================================

SOLVED = 0  # problems for which solve() returned without raising
TOTAL = 0   # problems received by predict()

def predict(id_: pl.Series, problem: pl.Series) -> pl.DataFrame:
    """
    AIMO3 API predict function: answer one problem.

    Reads element 0 of *id_* and *problem* (assumed to be one-element Series -
    only index 0 is read), runs solve(), and returns a frame with ``id`` and
    ``answer`` columns. predict() is the boundary the inference server calls.
    An exception from solve() is therefore logged and answered with the
    default 12453 instead of propagating into the server.
    """
    global SOLVED, TOTAL
    TOTAL += 1

    question_id = id_.item(0)
    question = problem.item(0)

    time_left = (CUTOFF_TIME - time.time()) / 60
    print(f"\n{'='*60}")
    print(f"Problem {TOTAL} | ID: {question_id} | Time: {time_left:.1f}m")
    print(f"Q: {question[:80]}...")

    try:
        answer = solve(question, question_id)
    except Exception as e:
        print(f"  Solve failed: {e!r} - returning default")
        answer = 12453
    else:
        SOLVED += 1

    print(f"ANSWER: {answer}")

    return pl.DataFrame({"id": id_, "answer": answer})

print("API interface ready")

In [None]:
# =============================================================================
# CELL 10: RUN
# =============================================================================
import kaggle_evaluation.aimo_3_inference_server

print("="*70)
print("PROMETHEUS v2 + Reflexion")
print(f"Model: {MODEL_PATH}")
print(f"Budget: {TOTAL_BUDGET//3600}h {(TOTAL_BUDGET%3600)//60}m")
print("="*70)

inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',)
    )

print(f"\nDone in {(time.time() - START_TIME)/60:.1f}m | Solved: {SOLVED}/{TOTAL}")