In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"
import re
import time
import random
import warnings
from collections import Counter
import numpy as np, pandas as pd, polars as pl

import torch
import vllm
from vllm import LLM, SamplingParams

import kaggle_evaluation.aimo_2_inference_server

warnings.simplefilter('ignore')

INFO 03-30 09:21:58 __init__.py:183] Automatically detected platform cuda.


In [2]:
def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
seed_everything(seed=5)

start_time = time.time()
cutoff_time = start_time + (4 * 60 + 50) * 60
cutoff_times = [int(x) for x in np.linspace(cutoff_time, start_time + 60 * 60, 50 + 1)]

In [3]:
if os.getenv('KAGGLE_KERNEL_RUN_TYPE') or os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    llm_model_pth = '/kaggle/input/m/shelterw/deepseek-r1/transformers/light-r1-7b-ds-awq/1'
else:
    llm_model_pth = '/kaggle/input/deepseek-r1/transformers/deepseek-r1-distill-qwen-7b-awq-casperhansen/1'

MAX_NUM_SEQS = 75
MAX_MODEL_LEN = 31000

llm = LLM(
    llm_model_pth,
    #dtype="half",                 # The data type for the model weights and activations
    max_num_seqs=MAX_NUM_SEQS,    # Maximum number of sequences per iteration. Default is 256
    max_model_len=MAX_MODEL_LEN,  # Model context length
    trust_remote_code=True,       # Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer
    tensor_parallel_size=4,       # The number of GPUs to use for distributed execution with tensor parallelism
    gpu_memory_utilization=0.95,  # The ratio (between 0 and 1) of GPU memory to reserve for the model
    seed=2025,
)

tokenizer = llm.get_tokenizer()

INFO 03-30 09:22:28 config.py:526] This model supports multiple tasks: {'embed', 'score', 'reward', 'classify', 'generate'}. Defaulting to 'generate'.
INFO 03-30 09:22:31 awq_marlin.py:109] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 03-30 09:22:32 config.py:1383] Defaulting to use mp for distributed inference
INFO 03-30 09:22:32 llm_engine.py:232] Initializing a V0 LLM engine (v0.7.1) with config: model='/kaggle/input/m/shelterw/deepseek-r1/transformers/light-r1-7b-ds-awq/1', speculative_config=None, tokenizer='/kaggle/input/m/shelterw/deepseek-r1/transformers/light-r1-7b-ds-awq/1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=31000, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 03-30 09:23:13 model_runner.py:1116] Loading model weights took 1.3375 GB
[1;36m(VllmWorkerProcess pid=351)[0;0m [1;36m(VllmWorkerProcess pid=354)[0;0m INFO 03-30 09:23:13 model_runner.py:1116] Loading model weights took 1.3375 GB
INFO 03-30 09:23:13 model_runner.py:1116] Loading model weights took 1.3375 GB
[1;36m(VllmWorkerProcess pid=359)[0;0m INFO 03-30 09:23:13 model_runner.py:1116] Loading model weights took 1.3375 GB
[1;36m(VllmWorkerProcess pid=354)[0;0m [1;36m(VllmWorkerProcess pid=359)[0;0m [1;36m(VllmWorkerProcess pid=351)[0;0m INFO 03-30 09:23:35 worker.py:266] Memory profiling takes 21.23 seconds
INFO 03-30 09:23:35 worker.py:266] Memory profiling takes 21.23 seconds
[1;36m(VllmWorkerProcess pid=354)[0;0m [1;36m(VllmWorkerProcess pid=359)[0;0m [1;36m(VllmWorkerProcess pid=351)[0;0m INFO 03-30 09:23:35 worker.py:266] the current vLLM instance can use total_gpu_memory (22.28GiB) x gpu_memory_utilization (0.95) = 21.16GiB
INFO 03-30 09:23:35 worker.p

Capturing CUDA graph shapes:   0%|          | 0/13 [00:00<?, ?it/s]

INFO 03-30 09:23:41 model_runner.py:1435] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.


Capturing CUDA graph shapes:  92%|█████████▏| 12/13 [00:11<00:00,  1.09it/s]

[1;36m(VllmWorkerProcess pid=354)[0;0m 

Capturing CUDA graph shapes: 100%|██████████| 13/13 [00:12<00:00,  1.03it/s]

INFO 03-30 09:23:53 model_runner.py:1563] Graph capturing finished in 13 secs, took 0.16 GiB
[1;36m(VllmWorkerProcess pid=351)[0;0m INFO 03-30 09:23:53 model_runner.py:1563] Graph capturing finished in 13 secs, took 0.16 GiB
INFO 03-30 09:23:53 model_runner.py:1563] Graph capturing finished in 13 secs, took 0.16 GiB
[1;36m(VllmWorkerProcess pid=359)[0;0m INFO 03-30 09:23:53 model_runner.py:1563] Graph capturing finished in 13 secs, took 0.16 GiB
INFO 03-30 09:23:53 llm_engine.py:429] init engine (profile, create kv cache, warmup model) took 39.85 seconds





In [4]:
#base_system_prompt = """You are a helpful and harmless math assistant. You should learn transferable problem-solving strategies from the given "Example", rather than memorizing specific instances. Break this question into parts. Take quality time to Verify every case and condition in the intermediate step. Iterate over every value of n from the least to the highest value without any assumptions or generalizations. CRUCIAL: Your response MUST END with exactly one \\boxed{} containing the final answer after modulo 1000. Don't stop generating until you reach this \\boxed{} answer.\n\n"""

TOPIC_PROMPTS = {
    "Number Theory": [
        {
            "role": "user",
            "content": "Break this question down first. Carefully Analyze and follow the problem statement exactly . Avoid: (- mis-calculations, - overly rough approximations, -imprecise handling of the radical–axis relations, - misapplication of an alternate counting method that over‐counted by introducing unnecessary factors, - incorrect generalization of the sum‐of‐digits formula over the whole range). **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
        },
        {
            "role": "user",
            "content": "You MUST IGNORE  theoretical reasoning. You MUST TAKE QUALITY TIME TO PERFORM A real-time exhaustive verification of ALL CASES using Sagemath or Sympy. **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put ALL possible answers within \\boxed{} after taking modulo 1000."
        },
        {
            "role": "user",
            "content": "Don't solve, ONLY take time to verify all search cases to the 7th case then STOP QUICKLY. Summarize with verified steps. NOTE: There are other cases that meet the condition after the 7th case, but because of time constraints and computational limitations, stop at the 7th case and find a pattern that satisfies other cases. You MUST AVOID initial rough search that might hint at a potential candidate, a more rigorous and detailed analysis is needed. DON'T hypothesize. **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put ALL possible answers within \\boxed{} after taking modulo 1000."
        }
        
    ],
    
    "Geometry": [
        {
            "role": "user",
            "content": "Break this question down first. Carefully Analyze and follow the problem statement exactly . Avoid: (- mis-calculations, - overly rough approximations, -imprecise handling of the radical–axis relations, - misapplication of an alternate counting method that over‐counted by introducing unnecessary factors, - incorrect generalization of the sum‐of‐digits formula over the whole range). **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
        },
        {
            "role": "user",
            "content": "Break this question down first. Analyze Verify  Conditions. Solve. No mis-calculations. No overly rough approximations. No imprecise handling of the radical–axis relations. Take modulo 1000 of final answer."
        },
        {
            "role": "user",
            "content": "Break this question down first, then analyze and verify every condition and solve it. Construct geometric diagrams and apply geometric principles, Don't make any assumptions. Avoid miscalculations when applying Theorems, and overly rough approximations and an imprecise handling of the radical–axis relations. **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
        }
    ],
    
    "Combinatorics": [
        {
            "role": "user",
            "content": "Break this question down first. Carefully Analyze and follow the problem statement exactly . Avoid: (- mis-calculations, - overly rough approximations, -imprecise handling of the radical–axis relations, - misapplication of an alternate counting method that over‐counted by introducing unnecessary factors, - incorrect generalization of the sum‐of‐digits formula over the whole range). **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
        },
        {
            "role": "user",
            "content": "You must Take quality time to Verify every small case for optimal accuracy with Exhaustive CAS Checks without making any assumptions in the question you will receive. If you don't verify every small case accurately, don't solve anything else. There should be NO OVERSIGHT OR ASSUMPTION DUE TO GENERALIZATION. Summarize what you have done so far making sure every calculation is VERY ACCURATE!!! Complete the solution and arrive at the ONLY accurate answer after taking modulo 1000. Put your final answer within \\boxed{}."
        },
        {
            "role": "user",
            "content": "Solve this question step by step. Recheck every critical step for optimal accuracy. Do not mis‐simplify any expression. Summarize what you have done so far. Complete the solution. Take modulo 1000 of final answer. **IMPORTANT**: Put your final answer within \\boxed{}."
        }
    ],
    
    "Algebra/Modular Arithmetic": [
        {
            "role": "user",
            "content": "Break this question down first. Carefully Analyze and follow the problem statement exactly . Avoid: (- mis-calculations, - overly rough approximations, -imprecise handling of the radical–axis relations, - misapplication of an alternate counting method that over‐counted by introducing unnecessary factors, - incorrect generalization of the sum‐of‐digits formula over the whole range). **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
        },
        {
            "role": "user",
            "content": "Solve this question step by step. Recheck every critical step for optimal accuracy. Do not mis‐simplify any expression. Summarize what you have done so far. Complete the solution. Take modulo 1000 of final answer. **IMPORTANT**: Put your final answer within \\boxed{}."
        },
        {
            "role": "user",
            "content": "You must Take quality time to Verify every small case for optimal accuracy with Exhaustive CAS Checks without making any assumptions in the question you will receive. If you don't verify every small case accurately, don't solve anything else. There should be NO OVERSIGHT OR ASSUMPTION DUE TO GENERALIZATION. Summarize what you have done so far making sure every calculation is VERY ACCURATE!!! Complete the solution and arrive at the ONLY accurate answer after taking modulo 1000. Put your final answer within \\boxed{}."
        }
        
    ]
}


In [5]:
import re


def detect_topic(question):
    question = question.lower()
    
    topic_patterns = {
        "Number Theory": {
            "keywords": {"gcd", "lcm", "mod", "divisible", "prime", "congruence", 
                         "remainder", "factor", "multiple", "coprime", "composite",
                         "diophantine", "euclidean", "bezout", "chinese remainder",
                         "greatest common divisors", "sum of digits", "digit sum", 
                         "s(n)", "base 10", "digits"},
            "weight": 1.2,
            "regex": r"mod(ulo)?\s*\d+|prime\s*power|≡\s*\d+\s*\(mod|sum\s+of\s+digits"
        },
        "Combinatorics": {
            "keywords": {"arrange", "combination", "permutation", "probability", 
                        "graph", "tree", "path", "subset", "binomial", "pigeonhole",
                        "inclusion-exclusion", "recurrence", "generating function"},
            "weight": 1.1,
            "regex": r"\d+\s*ways|arrang(e|ing).* (not|never)"
        },
        "Geometry": {
            "keywords": {"triangle", "circle", "coordinate", "volume", "area", 
                        "circumradius", "altitude", "hypotenuse", "parabola", 
                        "ellipse", "theorem", "bisector", "bisects", 
                        "circumcircle of triangle", "orthocenter"},
            "weight": 1.0,
            "regex": r"\b(r|d)\s*=\s*\d+|foot\s+of\s+perpendicular"
        }
    }

    special_cases = {
        "Number Theory": [
            r"\bs\(n\)\b",
            r"sum\s+of\s+digits",
            r"base\s+10",
            r"\d+day",
            r"polynomial\s+congruence",
            r"exponential\s+mod",
            r"remainder\s+when\s+divided"
        ],
        "Geometry": [
            r"circumradius\s*=\s*\d+",
            r"triangle.*side\s+lengths",
            r"volume\s+of\s+cylinder"
        ],
        "Combinatorics": [
            r"at\s+least\s+\d+\s*heads",
            r"probability\s+of\s+winning",
            r"vertices\s+and\s+edges"
        ]
    }

    scores = {topic: 0 for topic in topic_patterns}
    
    # Base scoring
    for topic, data in topic_patterns.items():
        # Keyword matches
        kw_matches = len([kw for kw in data["keywords"] if kw in question])
        scores[topic] += kw_matches * data["weight"]
        
        # Regex matches
        if "regex" in data:
            scores[topic] += len(re.findall(data["regex"], question)) * 2
    
    # Special case boosts
    for topic, patterns in special_cases.items():
        for pattern in patterns:
            if re.search(pattern, question, re.IGNORECASE):
                scores[topic] += 3
    
    # Priority tiebreaker
    priority_order = [
        "Number Theory",
        "Combinatorics", 
        "Geometry"
    ]
    
    max_score = max(scores.values())
    candidates = [t for t, s in scores.items() if s == max_score]
    
    for topic in priority_order:
        if topic in candidates:
            return topic
    
    return "Number Theory"  # Fallback to highest priority

In [6]:
def extract_boxed_text(text):
    pattern = r'\\boxed{([^}]*)}'
    matches = re.findall(pattern, text)
    if not matches:
        return ""
    for match in matches[::-1]:
        if match != "":
            return match
    return ""

def batch_message_filter(list_of_messages) -> tuple[list[list[dict]], list[str]]:
    extracted_answers = []
    list_of_messages_to_keep = []
    for messages in list_of_messages:
        answer = extract_boxed_text(messages[-1]['content'])
        if answer:
            extracted_answers.append(answer)
        else:
            list_of_messages_to_keep.append(messages)
    return list_of_messages_to_keep, extracted_answers

def select_answer(answers):
    counter = Counter()
    for answer in answers:
        try:
            if int(answer) == float(answer):
                counter[int(answer)] += 1 + random.random() / 1_000
        except:
            pass
    if not counter:
        return 210
    _, answer = sorted([(v,k) for k,v in counter.items()], reverse=True)[0]
    return answer%1000

In [7]:
def batch_message_generate(list_of_messages) -> list[list[dict]]:
    # Enforce single-sequence processing for iterative refinement
    assert len(list_of_messages) == 1, "Use sequential mode for iterative prompts"
    
    # Dynamic token limits based on remaining time
    max_tokens = MAX_MODEL_LEN
    

    # Configure sampling parameters
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.95,
        min_p=0.05,
        skip_special_tokens=True,
        max_tokens=max_tokens,
        stop=["</think>"],
        seed=777,
    )

    # Convert message chain to prompt text
    list_of_texts = [
        tokenizer.apply_chat_template(
            conversation=messages,
            tokenize=False,
            add_generation_prompt=True
        )
        for messages in list_of_messages
    ]

    # Generate response using LLM
    request_output = llm.generate(
        prompts=list_of_texts,
        sampling_params=sampling_params,
    )

    # Update message chain with generated response
    updated_messages = []
    for messages, single_request_output in zip(list_of_messages, request_output):
        # Append assistant's response to message history
        messages.append({
            'role': 'assistant', 
            'content': single_request_output.outputs[0].text
        })
        
        # Store token count for sorting (maintains compatibility)
        updated_messages.append((
            len(single_request_output.outputs[0].token_ids),
            messages
        ))

    # Sort by generated response length (ascending)
    updated_messages.sort(key=lambda x: x[0])
    
    # Return only the message chain (drop length metadata)
    return [messages for _, messages in updated_messages]

In [8]:
#prompts = [#"You must Take quality time to Verify every small cases upto the 7th value before assuming a pattern. There should be NO OVERSIGHT OR ASSUMPTION DUE TO GENERALIZATION. Summarize what you have done so far making sure every calculation is VERY ACCURATE!!! Complete the solution and arrive at the ONLY accurate answer after taking modulo 1000. Put your final answer within \\boxed{}.",
           #"Solve this question step by step. Recheck every critical step for optimal accuracy. Do not mis‐simplify any expression. Summarize what you have done so far. Complete the solution. Take modulo 1000 of final answer. **IMPORTANT**: Put your final answer within \\boxed{}.",
           #"Break this question down first. Carefully Analyze and follow the problem statement exactly . Avoid: (- mis-calculations, - overly rough approximations, -imprecise handling of the radical–axis relations, - misapplication of an alternate counting method that over‐counted by introducing unnecessary factors, - incorrect generalization of the sum‐of‐digits formula over the whole range). **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000.",
           #"You MUST IGNORE  theoretical reasoning. You MUST TAKE QUALITY TIME TO PERFORM A real-time exhaustive verification of ALL CASEs using Sagemath or Sympy. **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
          #]



In [9]:
def predict_for_question(question: str) -> int:
    if time.time() > cutoff_time:
        return 210
    

    question += " When solving a problem, YOU MUST follow the problem statement exactly and avoid switching to alternate methods that introduce extra factors or over-counting.. YOU MUST put ALL the suspected answers in \\boxed{} before verifying further"
    

    topic = detect_topic(question)
    prompts = TOPIC_PROMPTS.get(topic, TOPIC_PROMPTS["Number Theory"])
    all_extracted_answers = []

    # Create 3 INDEPENDENT message threads
    for i in range(3):
        # Fresh message history for each prompt
        messages = [
            {"role": "system", "content": ""},
            {"role": "user", "content": f"{prompts[i]}\n\n{question}"}
        ]
        
        # Generate and print response
        messages = batch_message_generate([messages])[0]
        
        print(f"\n===== PROMPT {i+1} RESPONSE =====")
        print(f"Prompt Config: {prompts[i][:100]}...")  # Show first 100 chars of prompt
        print(f"Generated Response: {messages[-1]['content'][-500:]}...")# Show last 500 chars of response
        print("-"*70)

        # Extract and store answer
        _, extracted = batch_message_filter([messages])
        all_extracted_answers.extend(extracted)
        print(f"Extracted answers: {all_extracted_answers}")
    
    
        
        # Reset context while keeping system prompt
        messages = [messages[0]] + [messages[-1]]  # Carry final response forward if needed

    answer = select_answer(all_extracted_answers)
    print(f"Answer: {answer}")
    return answer % 1000

In [10]:
def predict(id_: pl.DataFrame, question: pl.DataFrame) -> pl.DataFrame | pd.DataFrame:
    id_ = id_.item(0)
    print("------")
    print(id_)
    question = question.item(0)
    answer = predict_for_question(question)
    print(question)
    print("------\n\n")
    return pl.DataFrame({'id': id_, 'answer': answer})

In [11]:
pd.read_csv(
    '/kaggle/input/ai-mathematical-olympiad-progress-prize-2/reference.csv'
).drop('answer', axis=1).to_csv('reference.csv', index=False)

In [12]:
inference_server = kaggle_evaluation.aimo_2_inference_server.AIMO2InferenceServer(predict)
if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        (
#            '/kaggle/input/ai-mathematical-olympiad-progress-prize-2/test.csv',
            'reference.csv',
        )
    )

------
a1d40b


Processed prompts: 100%|██████████| 1/1 [01:11<00:00, 71.82s/it, est. speed input: 4.05 toks/s, output: 103.81 toks/s]



===== PROMPT 1 RESPONSE =====


GatewayRuntimeError: (<GatewayRuntimeErrorType.SERVER_RAISED_EXCEPTION: 3>, "unhashable type: 'slice'")