# AIMO3 - GPT-OSS-120B + Reflexion

**Model**: GPT-OSS-120B (OpenAI's open-weight model)
**Attach**: `danielhanchen/gpt-oss-120b` or `gpt-oss-120b`

Key techniques from top 38/50 solution:
- vLLM OpenAI-compatible server
- 5 diverse system prompts (the solver launches 4 generations per problem, so only the first 4 are used)
- Reflexion follow-ups
- Log-weighted voting

In [None]:
import subprocess
import os
import sys

# Remove preinstalled packages that conflict with the inference stack.
conflicting_packages = ["tensorflow", "matplotlib", "keras", "scikit-learn"]
subprocess.run(["pip", "uninstall", "--yes", *conflicting_packages])

# Enumerate every regular file under /kaggle/usr/lib; the listing is captured
# and thrown away.
# NOTE(review): this only lists paths, it does not read any file contents --
# presumably intended as a filesystem warm-up; confirm it is still needed.
print("Caching model files...")
find_command = ["find", "/kaggle/usr/lib", "-type", "f", "-print0"]
subprocess.run(find_command, capture_output=True)

In [None]:
import os
import time
import torch
import numpy as np
import math
import re
from collections import Counter
from typing import Optional, List, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import polars as pl

# Wall-clock budget for the whole run, measured from the moment this cell executes.
START_TIME = time.time()
TOTAL_BUDGET = 4 * 3600 + 45 * 60  # 4h45m, in seconds
CUTOFF_TIME = START_TIME + TOTAL_BUDGET

# Inclusive bounds of an acceptable final answer.
ANSWER_MIN = 0
ANSWER_MAX = 99999

# Report the visible GPU (or the literal string 'None') and the time budget.
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'
print(f"GPU: {gpu_name}")
budget_hours, budget_remainder = divmod(TOTAL_BUDGET, 3600)
print(f"Budget: {budget_hours}h {budget_remainder // 60}m")

In [None]:
# Find GPT-OSS-120B model
import glob

# Candidate dataset mount points, tried in order.
MODEL_PATHS = [
    "/kaggle/input/gpt-oss-120b/transformers/default/1",
    "/kaggle/input/gpt-oss-120b",
    "/kaggle/input/danielhanchen/gpt-oss-120b",
    "/kaggle/input/d/danielhanchen/gpt-oss-120b",
]


def find_model_path(candidates, search_root="/kaggle/input"):
    """Locate the directory that holds the model's ``config.json``.

    Args:
        candidates: Directories to try, in priority order.
        search_root: Tree walked as a last resort when no candidate matches.

    Returns:
        The chosen directory, or None when nothing suitable exists.

    For each existing candidate the directory itself wins if it contains
    ``config.json``; otherwise the shallowest nested ``config.json`` is used
    (ties broken alphabetically). Picking by depth instead of taking the first
    glob hit keeps the result deterministic and avoids selecting a nested
    sub-folder's config over the model root.
    """
    for base in candidates:
        if not os.path.isdir(base):
            continue
        if os.path.isfile(os.path.join(base, "config.json")):
            return base
        pattern = os.path.join(glob.escape(base), "**", "config.json")
        configs = glob.glob(pattern, recursive=True)
        if configs:
            shallowest = min(configs, key=lambda c: (c.count(os.sep), c))
            return os.path.dirname(shallowest)

    # Last resort: any directory whose path mentions "gpt" and has a config.json.
    for root, dirs, files in os.walk(search_root):
        dirs.sort()  # make the traversal order (and so the result) reproducible
        if "config.json" in files and "gpt" in root.lower():
            return root

    return None


MODEL_PATH = find_model_path(MODEL_PATHS)

print(f"Model path: {MODEL_PATH}")

In [None]:
# Cache model weights into memory
def cache_model(path, num_workers=16, chunk_mb=1024):
    """Read every weight file under *path* once and report the throughput.

    The bytes are discarded after being read; presumably the goal is to have
    the files in the OS cache before the model is loaded (the code itself only
    shows the read).

    Args:
        path: Directory searched recursively for ``.bin``, ``.pt`` and
            ``.safetensors`` files.
        num_workers: Number of threads reading files concurrently.
        chunk_mb: Size of each read call, in MiB.

    Returns:
        Total number of bytes read; 0 when no weight file was found.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    exts = (".bin", ".pt", ".safetensors")
    chunk_size = chunk_mb * 1024 * 1024

    def warmup_file(fpath):
        # Stream the file in fixed-size chunks; only the byte count is kept.
        total = 0
        with open(fpath, "rb") as f:
            for data in iter(lambda: f.read(chunk_size), b""):
                total += len(data)
        return total

    files = [
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if name.endswith(exts)
    ]

    if not files:
        print(f"No model files found in {path}")
        return 0

    print(f"Caching {len(files)} files...")
    t0 = time.time()
    total_bytes = 0

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = {pool.submit(warmup_file, f): f for f in files}
        for i, fut in enumerate(as_completed(futures), 1):
            try:
                total_bytes += fut.result()
            except OSError as err:
                # Warming is best-effort: one unreadable file must not abort the rest.
                print(f"  Skipped {futures[fut]}: {err}")
            if i % 10 == 0:
                print(f"  [{i}/{len(files)}] cached")

    elapsed = time.time() - t0
    gb = total_bytes / 1024**3
    # Tiny inputs can finish within the clock resolution; avoid dividing by zero.
    rate = gb / elapsed if elapsed > 0 else 0.0
    print(f"Cached {gb:.2f} GB in {elapsed:.1f}s ({rate:.2f} GB/s)")
    return total_bytes

# Warm the weight files only when discovery found a model directory
# (MODEL_PATH stays None otherwise).
if MODEL_PATH:
    cache_model(MODEL_PATH)

In [None]:
# Start vLLM server
import subprocess

# Environment for the server process. These are set on os.environ before
# Popen is called (no explicit env= is passed), so the child inherits them.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Passed to the server as --max-model-len.
SEQUENCE_LENGTH = 65536

# Command line for vLLM's OpenAI-compatible API server.
# - "vllm-model" is the model name the client later requests.
# - port 8000 matches the client's base URL.
# - --max-num-seqs 4 matches the 4 worker threads the solver uses.
# NOTE(review): MODEL_PATH is None when model discovery failed; Popen would
# then receive a non-string argument -- confirm a model is attached.
command = [
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--model", MODEL_PATH,
    "--served-model-name", "vllm-model",
    "--tensor-parallel-size", "1",
    "--max-num-seqs", "4",
    "--gpu-memory-utilization", "0.96",
    "--host", "0.0.0.0",
    "--port", "8000",
    "--dtype", "auto",
    "--max-model-len", str(SEQUENCE_LENGTH),
]

# Launch in the background with stdout and stderr both written to the log file.
# The call does not wait for the server to become ready.
with open("/kaggle/working/vllm.log", "w") as logfile:
    vllm_process = subprocess.Popen(
        command, stdout=logfile, stderr=subprocess.STDOUT, start_new_session=True
    )

print(f"vLLM server starting (PID: {vllm_process.pid})")
print("Logs: /kaggle/working/vllm.log")

In [None]:
# Setup OpenAI client
from openai import OpenAI

# Point both the environment and the client object at the local vLLM endpoint.
# The API key is a placeholder value.
os.environ.update(
    {
        "OPENAI_API_BASE": "http://127.0.0.1:8000/v1",
        "OPENAI_API_KEY": "sk-local",
    }
)

client = OpenAI(api_key="sk-local", base_url="http://127.0.0.1:8000/v1")

# GPT-OSS stop tokens: every id from 200000 through 201087 except the four
# ids listed below, in ascending order.
_NON_STOP_TOKEN_IDS = {200005, 200006, 200007, 200008}
STOP_TOKEN_IDS = sorted(set(range(200_000, 201_088)) - _NON_STOP_TOKEN_IDS)

# Wait for server
def await_server(max_wait=900):
    """Poll the local endpoint until it answers or the attempts run out.

    Args:
        max_wait: Maximum number of polling attempts. Every attempt sleeps one
            second before calling ``client.models.list()``.

    Returns:
        True as soon as a models listing succeeds. False when every attempt
        failed, or when the launched server process has already exited (in
        which case waiting longer cannot help).
    """
    # Looked up lazily so the function also works when no server process was
    # launched from this notebook.
    server_proc = globals().get("vllm_process")

    for _ in range(max_wait):
        time.sleep(1)

        if server_proc is not None and server_proc.poll() is not None:
            print(f"Server process exited with code {server_proc.returncode}; see /kaggle/working/vllm.log")
            return False

        try:
            client.models.list()
        except Exception:
            # Not reachable yet. Unlike a bare ``except``, this still lets
            # KeyboardInterrupt / SystemExit stop the wait.
            continue

        print("Server ready!")
        return True

    print("Server timeout!")
    return False

await_server()

In [None]:
# System prompts (from top solutions)
# Each entry is one complete system message. The solver assigns
# SYSTEM_PROMPTS[i % len(SYSTEM_PROMPTS)] to generation i.
# "\\boxed{}" in these non-raw strings is the text \boxed{} at runtime; every
# prompt asks for the final integer in that form.
SYSTEM_PROMPTS = [
    """You are solving a national/international-level mathematics olympiad problem. 
Rigorously define all variables, explore multiple solution strategies, 
perform full case analysis, justify every step, check boundary cases, 
and verify using at least one independent method. 
Return only the final numerical answer inside \\boxed{}. 
Answer must be an integer in [0, 99999]. Never guess.""",

    """Solve with full rigor. After obtaining a candidate solution, 
actively attempt to refute it by searching for counterexamples, 
re-running logic from a different viewpoint, and stress-testing edge cases. 
Only after the answer survives refutation, return it in \\boxed{}. 
Answer must be an integer in [0, 99999]. Never guess.""",

    """Solve as if under IMO time pressure: identify the key invariant, 
symmetry, or extremal principle early. Avoid brute force unless justified. 
Compress reasoning without sacrificing correctness. 
Perform at least one arithmetic verification pass. 
Return only the final integer in \\boxed{}, 0 ≤ answer ≤ 99999. Never guess.""",

    """Attempt at least two fundamentally different solution approaches 
(algebraic vs geometric, combinatorial vs number-theoretic). 
Proceed with the more rigorous one and use the other for verification. 
Return only the verified answer in \\boxed{}, integer in [0, 99999]. Never guess.""",

    """Solve rigorously. If any step relies on an unproven assumption, 
has a logic jump, or becomes inconsistent, restart from first principles. 
Return only the verified integer in \\boxed{}, 0 ≤ answer ≤ 99999. Never guess."""
]

print(f"{len(SYSTEM_PROMPTS)} system prompts loaded")

In [None]:
# Answer extraction
def extract_boxed(text: str) -> Optional[str]:
    """Return the content of the last non-blank ``oxed{...}`` group in *text*.

    The match is non-greedy and stops at the first closing brace, so nested
    braces are cut short (``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1``).
    Surrounding whitespace is stripped; groups that are empty after stripping
    are skipped. Returns None when no usable group exists.
    """
    found = re.findall(r"oxed{(.*?)}", text)
    stripped_latest_first = (inner.strip() for inner in reversed(found))
    return next((inner for inner in stripped_latest_first if inner), None)

def is_valid_answer(text: str) -> bool:
    """Tell whether *text* converts to an integer in the range 0..99999.

    Anything ``int()`` cannot convert (non-numeric text, ``None``, ...) is
    reported as invalid rather than raising.
    """
    try:
        val = int(text)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "invalid"; interrupts are not swallowed.
        return False
    return 0 <= val <= 99999

# Log-weighted voting (from top 27/50)
def vote_answer(counter: Counter, force: bool = False) -> Tuple[Optional[int], bool]:
    """Pick the highest-scoring answer and say whether the vote looks decisive.

    Each answer ``v`` with ``c`` votes scores ``log(1.25 + |v|) * c``.

    Args:
        counter: Votes per candidate answer.
        force: When True, an empty counter yields the fallback 12453 instead
            of None, and the top five scored entries are printed.

    Returns:
        ``(answer, confident)``. ``confident`` is True when the answer is the
        only candidate, when it beats the runner-up by more than 1 point, or
        when its score exceeds ``max(3, total / (2 + log(1 + total)))``.
        An empty counter always gives ``confident == False``.
    """
    if len(counter) == 0:
        fallback = 12453 if force else None
        return (fallback, False)

    weighted = {}
    for value, votes in counter.items():
        weighted[value] = math.log(1.25 + abs(value)) * votes
    total = sum(weighted.values())
    # Highest score first; equal scores keep the counter's order.
    ranked = sorted(weighted.items(), key=lambda item: item[1], reverse=True)

    best_val, best_score = ranked[0]
    threshold = total / (2 + math.log(1 + total))

    unopposed = len(ranked) == 1
    confident = (
        unopposed
        or best_score - ranked[1][1] > 1
        or best_score > max(3, threshold)
    )

    if force:
        print(f"  Vote: {ranked[:5]}")

    return (best_val, confident)

print("Utilities loaded")

In [None]:
# Generation with reflexion
import threading

# Question ids that need no further generation (filled by the solver / API).
COMPLETED_IDS = set()

# Serializes updates to the shared vote Counter: the solver runs several
# generate_solution() calls concurrently in a thread pool, and
# ``counter[ans] += 1`` is a read-modify-write.
_COUNTER_LOCK = threading.Lock()


def generate_solution(question: str, question_id: str, idx: int, system_prompt: str, counter: Counter) -> Optional[int]:
    """Run one sampled attempt, with at most one follow-up turn, and record its answer.

    Args:
        question: Problem statement sent as the user message.
        question_id: Id checked against COMPLETED_IDS to stop early.
        idx: Index of this generation (currently unused).
        system_prompt: System message for this attempt.
        counter: Shared vote tally; incremented for the extracted answer.

    Returns:
        The integer answer taken from the last non-empty boxed group across
        all turns, or None when the question is already completed, the time
        budget is exhausted, or no valid answer was produced.

    A follow-up turn is requested when the first reply has no valid boxed
    answer, when the answer is <= 10, or when the reply is shorter than 1000
    characters. Request failures are printed and end the attempt.
    """
    if question_id in COMPLETED_IDS or time.time() >= CUTOFF_TIME:
        return None

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    full_response = ""

    for iteration in range(2):
        if question_id in COMPLETED_IDS or time.time() >= CUTOFF_TIME:
            break

        try:
            resp = client.chat.completions.create(
                model="vllm-model",
                messages=messages,
                max_tokens=4096,
                temperature=0.7 if iteration == 0 else 0.5,
                extra_body=dict(
                    min_p=0.02,
                    stop_token_ids=STOP_TOKEN_IDS,
                    chat_template_kwargs=dict(enable_thinking=True)
                ),
            )

            message = resp.choices[0].message
            response = message.content or ""
            # Put any reasoning text first so a boxed answer in the visible
            # content remains the last one found.
            reasoning = getattr(message, "reasoning_content", None)
            if reasoning:
                response = reasoning + "\n" + response

        except Exception as e:
            # Best-effort: a failed request ends this attempt but keeps
            # whatever earlier turns produced.
            print(f"  Gen error: {e}")
            break

        messages.append({"role": "assistant", "content": response})
        full_response += response

        boxed = extract_boxed(response)

        # Reflexion logic: choose the follow-up question, or stop.
        if not boxed or not is_valid_answer(boxed):
            follow_up = "Place your final integer answer in \\boxed{}. Answer must be 0-99999."
        elif int(boxed) <= 10:
            follow_up = "Are you sure? Double-check your work."
        elif iteration == 0 and len(response) < 1000:
            follow_up = "Have you verified your answer?"
        else:
            break
        messages.append({"role": "user", "content": follow_up})

    # Search all turns together so an answer from the first turn is kept when
    # the follow-up reply contains no boxed group.
    boxed = extract_boxed(full_response)
    if boxed and is_valid_answer(boxed):
        ans = int(boxed)
        with _COUNTER_LOCK:
            counter[ans] += 1
        return ans

    return None

print("Generation with reflexion loaded")

In [None]:
# Main solver
def solve(question: str, question_id: str) -> int:
    """Answer one problem by voting over several parallel generations.

    Args:
        question: Problem statement.
        question_id: Id used to signal early stop to the running generations.

    Returns:
        An integer clamped to 0..99999. Falls back to 12453 when the time
        budget is already exhausted; when no generation produced an answer,
        falls back to the first number in 1..99999 found in the question
        text, or 12453 if there is none.
    """
    if time.time() >= CUTOFF_TIME:
        return 12453

    counter = Counter()
    num_gens = 4
    stopped_early = False

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(
                generate_solution,
                question, question_id, i,
                SYSTEM_PROMPTS[i % len(SYSTEM_PROMPTS)],
                counter
            )
            for i in range(num_gens)
        ]

        for fut in as_completed(futures):
            try:
                fut.result()
                if stopped_early:
                    continue
                # Vote on a snapshot so the scoring loop never iterates a
                # Counter that worker threads are still inserting into.
                ans, confident = vote_answer(counter.copy())
                if confident and ans is not None:
                    stopped_early = True
                    # Tells the remaining generations to skip further turns.
                    COMPLETED_IDS.add(question_id)
                    print(f"  Early stop: {ans}")
            except Exception as e:
                print(f"  Error: {e}")

    # All workers have finished here, so the live counter is safe to read.
    answer, _ = vote_answer(counter, force=True)

    if answer is None:
        nums = [int(x) for x in re.findall(r'\b\d+\b', question) if 0 < int(x) < 100000]
        answer = nums[0] if nums else 12453

    return max(0, min(99999, int(answer)))

print("Solver loaded")

In [None]:
# API
SOLVED = 0  # number of predict() calls so far

def predict(id_: pl.Series, problem: pl.Series) -> pl.DataFrame:
    """Answer a single problem and package the result for the inference server.

    Reads the first element of ``id_`` and ``problem``, logs progress and the
    remaining minutes, runs the solver, marks the id as completed, and
    returns a frame with columns ``id`` and ``answer``.
    """
    global SOLVED
    SOLVED = SOLVED + 1

    qid, question = id_.item(0), problem.item(0)

    time_left = (CUTOFF_TIME - time.time()) / 60
    print(f"\n{'='*50}")
    print(f"Problem {SOLVED} | {qid} | {time_left:.1f}m left")
    print(f"Q: {question[:60]}...")

    answer = solve(question, qid)
    COMPLETED_IDS.add(qid)
    print(f"ANSWER: {answer}")

    result = pl.DataFrame({"id": id_, "answer": answer})
    return result

print("API ready")

In [None]:
# Run
import kaggle_evaluation.aimo_3_inference_server

banner = "="*60
print(banner)
print("AIMO3 - GPT-OSS-120B + Reflexion")
print(f"Model: {MODEL_PATH}")
print(banner)

server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

# Without the KAGGLE_IS_COMPETITION_RERUN variable set (non-empty), replay the
# test CSV through the local gateway; with it, serve the real gateway.
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    server.run_local_gateway(
        ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',)
    )
else:
    server.serve()

elapsed_minutes = (time.time()-START_TIME)/60
print(f"\nDone in {elapsed_minutes:.1f}m | Solved: {SOLVED}")