In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"
import re
import time
import random
import warnings
from collections import Counter
import numpy as np, pandas as pd, polars as pl

import torch
import vllm
from vllm import LLM, SamplingParams

import kaggle_evaluation.aimo_2_inference_server

warnings.simplefilter('ignore')

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all visible CUDA devices), and configures cuDNN for deterministic
    algorithm selection.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so the value
        written here is only picked up by processes spawned afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Several GPUs are made visible via CUDA_VISIBLE_DEVICES, so seed all of
    # them rather than only the current device. This is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Auto-tuning (benchmark=True) may pick a different kernel on each run,
    # which defeats the deterministic flag below; keep it disabled.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
seed_everything(seed=5)

# Wall-clock budget for the whole run, measured from the moment this cell runs.
start_time = time.time()
# Hard deadline: 4 h 50 min after start. Past it, questions get the fallback answer.
cutoff_time = start_time + (4 * 60 + 50) * 60
# 51 integer timestamps, evenly spaced and descending from the hard deadline
# down to one hour after start (so the last element is the earliest one).
cutoff_times = list(map(int, np.linspace(cutoff_time, start_time + 60 * 60, 50 + 1)))

In [None]:
# Pick the model weights: the Light-R1 AWQ checkpoint when either Kaggle
# environment variable is set, otherwise the DeepSeek-R1 distill AWQ checkpoint.
llm_model_pth = (
    '/kaggle/input/m/shelterw/deepseek-r1/transformers/light-r1-7b-ds-awq/1'
    if os.getenv('KAGGLE_KERNEL_RUN_TYPE') or os.getenv('KAGGLE_IS_COMPETITION_RERUN')
    else '/kaggle/input/deepseek-r1/transformers/deepseek-r1-distill-qwen-7b-awq-casperhansen/1'
)

# Engine limits, also read by the generation helpers further down.
MAX_NUM_SEQS = 75      # sequences scheduled per iteration
MAX_MODEL_LEN = 30000  # context length handed to the engine

# dtype is left at the engine default (a "half" override was tried and disabled).
llm = LLM(
    llm_model_pth,
    max_num_seqs=MAX_NUM_SEQS,
    max_model_len=MAX_MODEL_LEN,
    trust_remote_code=True,       # allow custom model/tokenizer code shipped with the checkpoint
    tensor_parallel_size=4,       # one shard per GPU listed in CUDA_VISIBLE_DEVICES
    gpu_memory_utilization=0.95,  # fraction of each GPU's memory the engine may reserve
    seed=2025,
)

# Tokenizer bound to the loaded model; used to render chat templates.
tokenizer = llm.get_tokenizer()

In [None]:
def extract_boxed_text(text):
    """Return the content of the last non-empty ``\\boxed{...}`` in ``text``.

    The match stops at the first closing brace, so nested braces are cut
    short (``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1``).

    Args:
        text: String to search.

    Returns:
        The captured content, or ``""`` when there is no ``\\boxed{...}`` or
        every occurrence is empty.
    """
    # Walk the matches from the end so the most recent answer wins.
    for candidate in reversed(re.findall(r'\\boxed{([^}]*)}', text)):
        if candidate != "":
            return candidate
    return ""

def batch_message_filter(list_of_messages) -> tuple[list[list[dict]], list[str]]:
    """Split conversations into unanswered ones and extracted answers.

    Each conversation's last message is scanned with ``extract_boxed_text``.

    Args:
        list_of_messages: List of conversations, each a list of message dicts
            with a ``'content'`` key.

    Returns:
        A pair ``(unanswered, answers)``: the conversations whose last message
        has no boxed answer (same objects, input order preserved) and the
        boxed strings found in the others.
    """
    unanswered = []
    answers = []
    for conversation in list_of_messages:
        boxed = extract_boxed_text(conversation[-1]['content'])
        if not boxed:
            unanswered.append(conversation)
            continue
        answers.append(boxed)
    return unanswered, answers

def select_answer(answers):
    """Choose the final answer by majority vote over extracted answer strings.

    Only entries that parse as integers take part. Votes are tallied on the
    value modulo 1000, which is the form that is returned, so congruent
    candidates such as ``"1042"`` and ``"42"`` reinforce each other instead of
    splitting the vote.

    Args:
        answers: Iterable of candidate answers (typically strings).

    Returns:
        The winning answer in ``[0, 999]``, or ``210`` when no candidate is a
        valid integer.
    """
    votes = Counter()
    for answer in answers:
        try:
            value = int(answer)
            # Reject anything whose integer and float readings disagree.
            if value != float(answer):
                continue
        except (TypeError, ValueError, OverflowError):
            # Not an integer literal: ignore this candidate. Narrow exception
            # types keep KeyboardInterrupt/SystemExit propagating.
            continue
        # The small random jitter breaks ties between equally popular answers.
        votes[value % 1000] += 1 + random.random() / 1_000
    if not votes:
        return 210
    winner, _ = max(votes.items(), key=lambda item: (item[1], item[0]))
    return winner

def batch_message_generate(list_of_messages) -> list[list[dict]]:
    """Generate one assistant reply per conversation, shortest reply first.

    NOTE: a later cell defines another function with this same name; after a
    top-to-bottom run that later definition is the one bound to
    ``batch_message_generate``.

    Args:
        list_of_messages: List of conversations (lists of message dicts).
            Each conversation is mutated in place: the reply is appended.

    Returns:
        The same conversation objects, ordered by ascending number of
        generated tokens.
    """
    token_budget = MAX_MODEL_LEN - 2000
    if time.time() > cutoff_times[-1]:
        # Past the earliest scheduled cutoff: shrink the budget to two thirds.
        print("Speedrun")
        token_budget = 2 * (MAX_MODEL_LEN - 2000 ) // 3

    params = SamplingParams(
        temperature=0.5,
        top_p=0.95,
        min_p=0.05,
        skip_special_tokens=True,
        max_tokens=token_budget,
        stop=["</think>"],  # stop once the reasoning section is closed
        seed=777,
    )

    rendered_prompts = []
    for conversation in list_of_messages:
        rendered_prompts.append(
            tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=True
            )
        )

    generations = llm.generate(
        prompts=rendered_prompts,
        sampling_params=params,
    )
    print([len(generation.outputs[0].token_ids) for generation in generations])

    # Pair every conversation with the token count of its first completion.
    scored = []
    for conversation, generation in zip(list_of_messages, generations):
        completion = generation.outputs[0]
        conversation.append({'role': 'assistant', 'content': completion.text})
        scored.append((len(completion.token_ids), conversation))

    print(f"First sort : {[length for length, _ in scored]}")
    scored = sorted(scored, key=lambda pair: pair[0])
    print(f"Second sort: {[length for length, _ in scored]}")

    return [conversation for _, conversation in scored]

In [None]:
def batch_message_generate(list_of_messages) -> list[list[dict]]:
    """Generate the next assistant turn for a single conversation.

    This redefines the batched helper of the same name from the previous cell
    and only accepts a batch of exactly one conversation.

    Args:
        list_of_messages: One-element list holding the conversation (a list
            of message dicts). The conversation is mutated in place: the
            generated reply is appended as an ``'assistant'`` message.

    Returns:
        A list containing the updated conversation.

    Raises:
        AssertionError: If the batch does not contain exactly one conversation.
    """
    assert len(list_of_messages) == 1, "Use sequential mode for iterative prompts"

    # The generation budget is fixed at the full context length; it is not
    # reduced as the deadline approaches.
    params = SamplingParams(
        temperature=0.5,
        top_p=0.95,
        min_p=0.05,
        skip_special_tokens=True,
        max_tokens=MAX_MODEL_LEN,
        stop=["</think>"],  # stop once the reasoning section is closed
        seed=777,
    )

    # Render each conversation into prompt text via the model's chat template.
    rendered_prompts = []
    for conversation in list_of_messages:
        rendered_prompts.append(
            tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=True
            )
        )

    generations = llm.generate(
        prompts=rendered_prompts,
        sampling_params=params,
    )

    # Attach each reply to its conversation and remember its token count.
    by_length = []
    for conversation, generation in zip(list_of_messages, generations):
        completion = generation.outputs[0]
        conversation.append({'role': 'assistant', 'content': completion.text})
        by_length.append((len(completion.token_ids), conversation))

    # Ascending by reply length, same ordering contract as the batched helper.
    ordered = sorted(by_length, key=lambda pair: pair[0])
    return [conversation for _, conversation in ordered]

In [None]:
# Instruction prefixes placed in front of the question, one per reasoning pass.
# predict_for_question consumes the active entries in list order, so their
# order matters. The first two entries below are commented out (disabled).
prompts = [#"You must Take quality time to Verify every small cases upto the 7th value before assuming a pattern. There should be NO OVERSIGHT OR ASSUMPTION DUE TO GENERALIZATION. Summarize what you have done so far making sure every calculation is VERY ACCURATE!!! Complete the solution and arrive at the ONLY accurate answer after taking modulo 1000. Put your final answer within \\boxed{}.",
           #"Solve this question step by step. Recheck every critical step for optimal accuracy. Do not mis‐simplify any expression. Summarize what you have done so far. Complete the solution. Take modulo 1000 of final answer. **IMPORTANT**: Put your final answer within \\boxed{}.",
           # Pass 1: analytic attempt that asks for an early boxed answer.
           "Break this question down first. Carefully Analyze and follow the problem statement exactly . Avoid: (- mis-calculations, - overly rough approximations, -imprecise handling of the radical–axis relations, - misapplication of an alternate counting method that over‐counted by introducing unnecessary factors, - incorrect generalization of the sum‐of‐digits formula over the whole range). **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000.",
           # Pass 2: asks for exhaustive computational verification instead of theory.
           "You MUST IGNORE  theoretical reasoning. You MUST TAKE QUALITY TIME TO PERFORM A real-time exhaustive verification of ALL CASEs using Sagemath or Sympy. **IMPORTANT**: Arrive at early answers. Before using an alternative method, Put possible answer within \\boxed{} after taking modulo 1000."
          ]



In [None]:
def predict_for_question(question: str) -> int:
    """Answer one question by running the prompt templates sequentially.

    Every pass sends one template from ``prompts`` together with the question,
    collects the boxed answer (if any) from the reply, and carries only the
    system message and that reply into the next pass. The final answer is
    chosen from all collected answers with ``select_answer``.

    Args:
        question: The problem statement.

    Returns:
        The selected answer modulo 1000, or ``210`` when the global deadline
        ``cutoff_time`` has already passed.
    """
    if time.time() > cutoff_time:
        return 210

    question += " Please read the question again to understand it very well. Avoid off-by-one calculation"

    messages = [{"role": "system", "content": ""}]
    all_extracted_answers = []

    # Two passes are run. Slicing instead of indexing with range(2) keeps this
    # from raising IndexError when fewer than two templates are active.
    for prompt_idx, prompt in enumerate(prompts[:2]):
        messages.append({
            "role": "user",
            "content": f"{prompt}\n\n{question}"
        })

        messages = batch_message_generate([messages])[0]
        response = messages[-1]['content']

        print(f"\n===== PROMPT {prompt_idx+1} RESPONSE =====")
        print(f"Prompt Config: {prompt[:100]}...")  # first 100 chars of the prompt
        # Log a bounded preview only. The previous slice printed everything
        # except the last 300 characters, i.e. almost the whole reasoning trace.
        print(f"Generated Response: {response[:500]}...")  # first 500 chars of the response
        print("-"*70)

        # Collect the boxed answer from this pass, if there is one.
        _, extracted = batch_message_filter([messages])
        all_extracted_answers.extend(extracted)
        print(f"Extracted answers: {all_extracted_answers}")

        # Reset the context: keep the system message and carry the latest
        # reply forward as the only history for the next pass.
        messages = [messages[0], messages[-1]]

    answer = select_answer(all_extracted_answers)
    print(f"Answer: {answer}")
    return answer % 1000

In [None]:
def predict(id_: pl.DataFrame, question: pl.DataFrame) -> pl.DataFrame | pd.DataFrame:
    """Inference-server callback: answer one question and return one result row.

    Args:
        id_: Container whose element 0 is the question id.
        question: Container whose element 0 is the question text.

    Returns:
        A polars DataFrame with columns ``'id'`` and ``'answer'``.
    """
    row_id = id_.item(0)
    print("------")
    print(row_id)
    question_text = question.item(0)
    prediction = predict_for_question(question_text)
    # The question is echoed after solving so it closes the per-question log.
    print(question_text)
    print("------\n\n")
    return pl.DataFrame({'id': row_id, 'answer': prediction})

In [None]:
# Write a local copy of the reference set without its 'answer' column; the
# local gateway run reads this file.
(
    pd.read_csv('/kaggle/input/ai-mathematical-olympiad-progress-prize-2/reference.csv')
    .drop(columns='answer')
    .to_csv('reference.csv', index=False)
)

In [None]:
# Register `predict` with the competition inference server.
inference_server = kaggle_evaluation.aimo_2_inference_server.AIMO2InferenceServer(predict)
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Interactive / local run: drive the server from the answer-free
    # reference file (the competition test.csv path is left disabled).
    inference_server.run_local_gateway(
        (
            'reference.csv',
        )
    )
else:
    # Competition rerun: hand control to the evaluation gateway.
    inference_server.serve()