In [None]:
import os
import gc
import torch
import numpy as np
import time
import warnings
from typing import Optional
import pandas as pd
import polars as pl
import kaggle_evaluation.aimo_3_inference_server
import re
import keyword
from collections import Counter
import random
from math import ceil
from vllm import LLM, SamplingParams
from prompts import step1_prompt, self_improvement_prompt, correction_prompt, verification_system_prompt, verification_remider

# Quiet library warnings and show full cell contents when DataFrames are displayed.
warnings.simplefilter('ignore')
pd.set_option('display.max_colwidth', None)

# Process-wide environment for transformers / Triton / CUDA / tokenizers.
# NOTE(review): these are assigned after torch/vllm were imported above; any variable a
# library only reads at import time (possibly TRANSFORMERS_NO_TF / TRANSFORMERS_NO_FLAX)
# may have no effect here — consider moving this cell before the imports. Confirm.
_RUNTIME_ENV = {
    "TRANSFORMERS_NO_TF": "1",
    "TRANSFORMERS_NO_FLAX": "1",
    "TRITON_PTXAS_PATH": "/usr/local/cuda/bin/ptxas",
    "CUDA_VISIBLE_DEVICES": "0",
    "TOKENIZERS_PARALLELISM": "false",
}
os.environ.update(_RUNTIME_ENV)

In [None]:
# Total wall-clock budget for the run: 4 h 53 min, measured from this cell.
TIME_BUDGET_SECONDS = (4 * 60 + 53) * 60
start_time = time.time()
cutoff_time = start_time + TIME_BUDGET_SECONDS
# 50 integer deadlines evenly spaced from the final cutoff (first element) down to
# roughly 500 s from now (last element), i.e. a non-increasing schedule.
cutoff_times = np.linspace(cutoff_time, time.time() + 500, 50).astype(int).tolist()

In [None]:
llm_model_pth = '/kaggle/input/deepseek-r1/transformers/deepseek-r1-distill-llama-70b-awq/1'

MAX_NUM_SEQS = 256
MAX_MODEL_LEN = 32768

# vLLM engine settings, gathered in one place so they are easy to tune.
# NOTE(review): the checkpoint path suggests an AWQ-quantized model; confirm the installed
# vLLM's AWQ kernel accepts bfloat16 activations (some AWQ kernels are float16-only).
engine_kwargs = dict(
    dtype="bfloat16",              # dtype for weights/activations
    max_num_seqs=MAX_NUM_SEQS,     # max sequences scheduled per iteration
    max_model_len=MAX_MODEL_LEN,   # context length (prompt + generation)
    trust_remote_code=True,        # allow custom model/tokenizer code shipped with the checkpoint
    tensor_parallel_size=1,        # number of GPUs used for tensor parallelism
    gpu_memory_utilization=0.98,   # fraction of GPU memory vLLM may reserve
    seed=391,
    enable_prefix_caching=True,    # reuse KV cache across prompts sharing a prefix
)
llm = LLM(llm_model_pth, **engine_kwargs)

tokenizer = llm.get_tokenizer()

In [None]:
def extract_boxed_text(text):
    """Return the contents of the last non-empty ``\\boxed{...}`` in ``text``.

    Braces inside the box are balanced, so nested LaTeX such as
    ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}`` instead of being cut off at the
    first ``}``. Any occurrence of ``oxed{`` counts as an opening marker. A box whose
    braces never close is ignored. Box contents may span multiple lines.

    Args:
        text: model output to search. ``None`` or ``""`` yields ``""``.

    Returns:
        str: inner text of the last non-empty box (not stripped), or ``""`` if none.
    """
    if not text:
        return ""
    marker = 'oxed{'
    contents = []
    pos = text.find(marker)
    while pos != -1:
        body_start = pos + len(marker)
        depth = 1
        cursor = body_start
        while cursor < len(text) and depth > 0:
            char = text[cursor]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
            cursor += 1
        if depth == 0:
            # cursor now sits one past the matching closing brace.
            contents.append(text[body_start:cursor - 1])
            pos = text.find(marker, cursor)
        else:
            # Unbalanced box: resume searching just after this marker.
            pos = text.find(marker, body_start)
    for content in reversed(contents):
        if content:
            return content
    return ""
    
def select_answer(answers, default=3, modulus=100000):
    """Majority-vote an integer answer and reduce it modulo ``modulus``.

    Each candidate is converted with ``int()``. Candidates that fail conversion, or
    whose ``float()`` value differs from their ``int()`` value, are not counted. Each
    counted vote adds ``1 + random.random() / 1000``, so ties between equally common
    answers are broken randomly by the global ``random`` module.

    Args:
        answers: iterable of candidate answers (e.g. strings extracted from ``\\boxed{}``).
        default: returned unchanged when no candidate is a valid integer.
        modulus: the winning answer is returned as ``winner % modulus``.

    Returns:
        int: the most-voted answer modulo ``modulus``, or ``default``.
    """
    votes = Counter()
    for answer in answers:
        try:
            value = int(answer)
            if value == float(answer):
                votes[value] += 1 + random.random() / 1000
        except (TypeError, ValueError, OverflowError):
            # Malformed / non-numeric / out-of-range candidates are simply skipped;
            # a bare except here would also swallow KeyboardInterrupt.
            continue
    if not votes:
        return default
    # Highest vote wins; on an exact tie the larger answer wins (same as sorting (v, k)).
    winner = max(votes, key=lambda k: (votes[k], k))
    return winner % modulus

In [None]:
def extract_detailed_solution(solution, marker='Detailed Solution', after=True):
    """Return the stripped text on one side of the first ``marker`` in ``solution``.

    Args:
        solution: text to split (e.g. a model response containing
            '### Detailed Solution ###').
        marker: substring to locate; only its first occurrence is used.
        after: if True return the text following the marker, otherwise the text
            preceding it.

    Returns:
        str: the selected side with surrounding whitespace stripped, or ``''`` when
        ``marker`` does not occur. Any characters adjacent to the marker (such as a
        trailing '###') are kept.
    """
    position = solution.find(marker)
    if position < 0:
        return ''
    piece = solution[position + len(marker):] if after else solution[:position]
    return piece.strip()

In [None]:
def create_starter_messages(question, index):
    """Build the chat messages (system + user) for the first solving step.

    Three prompt variants are constructed and ``index`` selects one cyclically
    (``index % 3``). All three variants are currently identical, so ``index`` has no
    visible effect until distinct variants are added.

    Args:
        question: problem text. The modulo/boxed-answer instruction is appended to it.
        index: variant selector (any int; wrapped modulo the number of variants).

    Returns:
        list[dict]: a fresh ``[{"role": "system", ...}, {"role": "user", ...}]`` list.
    """
    variants = [
        [
            {"role": "system", "content": step1_prompt},
            {"role": "user", "content": question + ' Return final answer within \\boxed{}, after taking modulo 100000.'},
        ]
        for _ in range(3)
    ]
    return variants[index % len(variants)]

In [None]:
def verify_solution(problem_statement, solution, verbose=True):
    """Have the verifier model check a solution, then summarise any problems it found.

    Two calls are made through ``call_gemini_api``. It is not defined in this cell and
    must exist in the notebook namespace before this function runs.

    1. The problem plus the detailed-solution part of ``solution`` (and
       ``verification_remider``) are sent under ``verification_system_prompt``.
    2. The resulting verification log is sent back with a yes/no question asking
       whether it says the solution is correct.

    Args:
        problem_statement: the problem text.
        solution: full solution text. Only the part after the 'Detailed Solution'
            marker is verified (``''`` if the marker is absent).
        verbose: print intermediate results.

    Returns:
        tuple[str, str]: ``(bug_report, verdict)``.
        - ``verdict`` is the raw yes/no response, or ``"no"`` if either call returned
          nothing.
        - ``bug_report`` is ``""`` when ``verdict`` contains the whole word "yes".
        - Otherwise it is the log text before the 'Detailed Verification' marker, or
          the whole log when that marker is missing.
    """
    dsol = extract_detailed_solution(solution)
    newst = f"""
======================================================================
### Problem ###

{problem_statement}

======================================================================
### Solution ###

{dsol}

{verification_remider}
"""
    if verbose:
        print(">>>>>>> Start verification.")

    contents1 = [{"role": "user", "parts": [{"text": newst}]}]
    verification_log = call_gemini_api(
        system_instruction=verification_system_prompt,
        contents=contents1,
        verbose=verbose
    )
    if not verification_log:
        print(">>>>>>> Verification call failed.")
        return "", "no"

    if verbose:
        print(">>>>>>> Verification results:")
        print(verification_log)

    check_correctness = f'Response in "yes" or "no". Is the following statement saying the solution is correct, or does not contain critical error or a major justification gap?\n\n{verification_log}'
    contents2 = [{"role": "user", "parts": [{"text": check_correctness}]}]
    verdict = call_gemini_api(
        system_instruction=None,
        contents=contents2,
        verbose=verbose
    )
    if not verdict:
        print(">>>>>>> Verification check call failed.")
        return "", "no"

    if verbose:
        print(">>>>>>> Is verification good?")
        print(verdict)

    # Whole-word match: a plain substring test would treat "eyes"/"yesterday" as approval.
    is_approved = re.search(r'\byes\b', verdict, re.IGNORECASE) is not None

    bug_report = ""
    if not is_approved:
        # If the log lacks the expected marker, fall back to the whole log so a
        # rejected solution never comes back with an empty ("no bugs") report.
        bug_report = extract_detailed_solution(verification_log, "Detailed Verification", False) or verification_log

    if verbose:
        print(">>>>>>> Bug report:")
        print(bug_report)

    return bug_report, verdict

[{'role': 'system', 'content': '\n### Core Instructions ###\n\n*   **Rigor is Paramount:** Your primary goal is to produce a complete and rigorously justified solution. Every step in your solution must be logically sound and clearly explained. A correct final answer derived from flawed or incomplete reasoning is considered a failure.\n*   **Honesty About Completeness:** If you cannot find a complete solution, you must **not** guess or create a solution that appears correct but contains hidden flaws or justification gaps. Instead, you should present only significant partial results that you can rigorously prove. A partial result is considered significant if it represents a substantial advancement toward a full solution. Examples include:\n    *   Proving a key lemma.\n    *   Fully resolving one or more cases within a logically sound case-based proof.\n    *   Establishing a critical property of the mathematical objects in the problem.\n    *   For an optimization problem, proving an up