In [None]:
# Environment variables are set before torch / vLLM are imported so that those
# libraries see the device list, Triton ptxas path and tokenizer settings.
import os
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"  # point Triton at the system CUDA ptxas binary
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # expose four GPUs (matches tensor_parallel_size=4 below)
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable HF tokenizers' internal parallelism (avoids fork-related warnings)
import warnings
warnings.simplefilter('ignore')  # suppress all Python warnings for cleaner logs

from collections import Counter
import random
import gc
import time
import pandas as pd
import polars as pl
import numpy as np
import torch
import re
import ast
import math
import kaggle_evaluation.aimo_2_inference_server
from vllm import LLM, SamplingParams
import vllm

# Log the installed vLLM version so runs can be reproduced later.
print(vllm.__version__)

SEED = 42


def seed_everything(seed):
    """Seed every RNG this notebook relies on so that runs are reproducible.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG,
            and PyTorch (CPU and every visible CUDA device).
    """
    # PYTHONHASHSEED is only read at interpreter start-up, so this affects
    # child processes that inherit the environment, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds all visible GPUs, not just the current device;
    # it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic
    # algorithms, which defeats deterministic=True, so it must be off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(seed=SEED)

pd.set_option('display.max_colwidth', None)  # show full generated texts in DataFrame displays
start_time = time.time()  # wall-clock reference point for every time budget below

# Adjust these for your environment / runtime limits
# Hard stop 4h45m after start: past this, predict_for_question returns a fallback answer.
cutoff_time = start_time + (4 * 60 + 45) * 60
# 51 per-question deadlines, evenly spaced from cutoff_time (first element) down to
# start_time + 1h (last element). Each question that runs to completion pops the
# last (earliest) deadline; being past cutoff_times[-1] means the run is behind
# schedule, which switches to a smaller token budget / fewer samples ("Speedrun").
cutoff_times = [int(x) for x in np.linspace(cutoff_time, start_time + 60 * 60, 50 + 1)]

# llm_model_pth = '/kaggle/input/deepseek-r1/transformers/deepseek-aideepseek-r1-distill-qwen-14b-awq-neody/1'
llm_model_pth = '/kaggle/input/scale7b-alpha0.1/transformers/14b-alpha0.2/1'
MAX_NUM_SEQS = 3  # max concurrent sequences in vLLM; also the number of samples drawn per question
MAX_MODEL_LEN = 16384  # context length (prompt + generation) in tokens; also the default generation budget

llm = LLM(
    llm_model_pth,
    dtype="half",  # fp16 weights/activations
    max_num_seqs=MAX_NUM_SEQS,
    max_model_len=MAX_MODEL_LEN,
    trust_remote_code=True,  # allow custom code shipped with the model repository
    tensor_parallel_size=4,  # shard the model across the four visible GPUs
    gpu_memory_utilization=0.95,  # fraction of each GPU's memory vLLM may reserve
    seed=SEED,
)
tokenizer = llm.get_tokenizer()  # used to report token counts of generated text

In [2]:
def _find_boxed_contents(text):
    """Return the contents of every ``\\boxed{...}`` in *text*, in order of appearance.

    Braces are matched by depth counting, so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole (``\\frac{1}{2}``).
    An unterminated ``\\boxed{`` is ignored.
    """
    marker = "\\boxed{"
    contents = []
    start = text.find(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        pos = body_start
        while pos < len(text) and depth > 0:
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            # pos is one past the closing brace.
            contents.append(text[body_start:pos - 1])
        start = text.find(marker, body_start)
    return contents


def extract_boxed_text(text):
    """Extract the final answer string from a generated solution.

    The last non-empty ``\\boxed{...}`` wins. If there is none, a list of
    "the answer is N"-style phrases is tried in order. For the first pattern
    that matches, its last occurrence is returned.

    Args:
        text: Generated solution text.

    Returns:
        The stripped answer string, or ``""`` when nothing is found.
    """
    for content in reversed(_find_boxed_contents(text)):
        if content.strip() != "":
            return content.strip()

    # Each pattern has exactly one capture group, so findall yields strings.
    # Keep the commas: adjacent raw strings would silently concatenate into a
    # single two-group pattern, and findall would then return tuples.
    answer_patterns = [
        r'the answer is[:\s]+([0-9]+)',
        r'answer[:\s=]+([0-9]+)',
        r'final answer[:\s=]+([0-9]+)',
        r'answer is[:\s]+([0-9]+)',
        r'the answer is\s*([0-9]+)',  # \s* allows zero whitespace before the number
        r'answer is\s*([0-9]+)',
        r'therefore.*?is\s*([0-9]+)',
    ]

    for pattern in answer_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    return ""

In [3]:
def _parse_integer_answer(answer):
    """Convert one extracted answer to an int, or return None if it is not integral.

    Accepts plain integers (``"42"``, ``"-7"``), integral decimals (``"12.0"``)
    and thousands separators (``"1,000"``, LaTeX ``"1{,}000"``). Fractions,
    expressions and non-integral numbers yield None.
    """
    cleaned = str(answer).strip().replace("{,}", ",").replace("$", "")
    # Only strip commas that are clearly thousands separators, so that a list
    # such as "1,2" is not misread as 12.
    if re.fullmatch(r"[+-]?\d{1,3}(?:,\d{3})+", cleaned):
        cleaned = cleaned.replace(",", "")
    try:
        return int(cleaned)
    except ValueError:
        pass
    try:
        value = float(cleaned)
    except ValueError:
        return None
    # is_integer() is False for inf and nan, so those are rejected too.
    return int(value) if value.is_integer() else None


def select_answer(answers):
    """Majority-vote an integer answer from the candidate answers.

    Args:
        answers: Candidate answers extracted from sampled solutions
            (typically strings); ``None`` or empty is allowed.

    Returns:
        The most frequent integral answer reduced modulo 1000, or 210 when no
        candidate parses as an integer. Ties go to the answer seen first.
    """
    valid_answers = [
        parsed
        for parsed in map(_parse_integer_answer, answers or [])
        if parsed is not None
    ]
    if not valid_answers:
        return 210  # fallback guess when nothing usable was extracted

    # Counter.most_common orders equal counts by first insertion, so a tie
    # resolves to whichever answer appeared first.
    best_answer, _ = Counter(valid_answers).most_common(1)[0]
    return best_answer % 1000

In [4]:
def create_starter_messages(question, index):
    """Build the initial conversation for one sampled solution.

    The model is prompted with a raw text prompt that mirrors its fine-tuning
    input format, rather than with the tokenizer's chat template.

    Args:
        question: Problem statement inserted into the prompt.
        index: Sample index. It is currently unused because every sample
            receives the same prompt. It is kept so that callers can vary
            prompts per sample later.

    Returns:
        A new one-element list holding the prompt string. Callers append
        assistant turns to it, so a fresh list is returned on every call.
    """
    prompt = f'<｜begin▁of▁sentence｜><｜User｜>Please reason step by step, and put your final answer within \\boxed{{}} after taking modulo 1000. Question: {question}<｜Assistant｜>'
    return [prompt]

def batch_message_generate(list_of_messages) -> list[list[dict]]:
    """Sample one completion per conversation and append it as an assistant turn.

    Each conversation is a list whose first element is the raw prompt string,
    as built by ``create_starter_messages``. The completion is appended in
    place as ``{'role': 'assistant', 'content': <text>}``.

    Args:
        list_of_messages: Conversations to extend; they are mutated in place.

    Returns:
        The same conversation lists, sorted by the number of generated tokens
        (shortest completion first).
    """
    if not list_of_messages:
        return []

    max_tokens = MAX_MODEL_LEN
    # Behind the per-question schedule: shrink the generation budget.
    if cutoff_times and time.time() > cutoff_times[-1]:
        print("Speedrun")
        max_tokens = 2 * MAX_MODEL_LEN // 3

    # Give every request its own seed. With one shared per-request seed, each
    # (identical) prompt would draw the same random stream, so every sample
    # would come out the same and majority voting would be meaningless.
    # Offsetting from SEED keeps runs reproducible.
    sampling_params = [
        SamplingParams(
            temperature=1.0,
            top_p=0.9,
            min_p=0.01,
            seed=SEED + offset,
            skip_special_tokens=True,
            max_tokens=max_tokens,
            # Generation halts at the end of the reasoning section, so only
            # text produced before </think> reaches answer extraction.
            stop=["</think>"],
        )
        for offset in range(len(list_of_messages))
    ]

    # Prompts are used verbatim; no chat template is applied.
    list_of_texts = [messages[0] for messages in list_of_messages]

    # A list of SamplingParams is paired one-to-one with the prompts.
    request_output = llm.generate(
        prompts=list_of_texts,
        sampling_params=sampling_params,
    )

    sort_keys_and_list_of_messages = []
    for messages, single_request_output in zip(list_of_messages, request_output):
        completion = single_request_output.outputs[0]
        messages.append({'role': 'assistant', 'content': completion.text})
        sort_keys_and_list_of_messages.append((len(completion.token_ids), messages))

    sort_keys_and_list_of_messages.sort(key=lambda item: item[0])
    return [messages for _, messages in sort_keys_and_list_of_messages]

In [5]:
def batch_message_filter(list_of_messages) -> tuple[list[list[dict]], list[str]]:
    """Split conversations into unanswered ones and the answers found so far.

    For each conversation, the latest turn's content is tokenized to build a
    printed length report, and then passed to ``extract_boxed_text``.

    Returns:
        ``(unanswered, answers)``. ``unanswered`` holds the conversations whose
        latest turn yielded no answer. ``answers`` holds the non-empty
        extracted answers in input order.
    """
    answers_found = []
    still_unanswered = []

    print("\n=== CoT Token Counts and Answers ===")
    for cot_number, conversation in enumerate(list_of_messages, start=1):
        latest_text = conversation[-1]['content']
        # Re-tokenizes the text; encode() may add special tokens, so this can
        # differ slightly from the number of tokens actually generated.
        n_tokens = len(tokenizer.encode(latest_text))
        answer = extract_boxed_text(latest_text)

        print(f"CoT{cot_number}: {n_tokens} tokens, answer: {answer if answer else 'None'}")

        if not answer:
            still_unanswered.append(conversation)
            continue
        answers_found.append(answer)

    return still_unanswered, answers_found

In [6]:
def predict_for_question(question: str) -> int:
    """Solve one problem by sampling several solutions and majority voting.

    Args:
        question: Problem statement text.

    Returns:
        The integer answer chosen by ``select_answer``, or 210 when the global
        time budget (``cutoff_time``) is already exhausted.
    """
    if time.time() > cutoff_time:
        return 210

    print(question)

    # Behind the per-question schedule: draw fewer samples.
    num_seqs = MAX_NUM_SEQS
    if cutoff_times and time.time() > cutoff_times[-1]:
        num_seqs = 2 * MAX_NUM_SEQS // 3

    list_of_messages = [create_starter_messages(question, index) for index in range(num_seqs)]
    list_of_messages = batch_message_generate(list_of_messages)

    # Save debug output (skipped during the official scoring rerun).
    if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
        df = pd.DataFrame(
            {
                "question": [question] * len(list_of_messages),
                "message": [messages[-1]["content"] for messages in list_of_messages],
            }
        )
        df.to_csv(f"{str(int(time.time() - start_time)).zfill(5)}.csv", index=False)

    # Only the extracted answers matter; unanswered conversations are dropped.
    _, extracted_answers = batch_message_filter(list_of_messages)
    answer = select_answer(extracted_answers)

    print(f"Final answer: {answer}")
    print("\n\n")
    # Advance to the next per-question deadline (guard against running out).
    if cutoff_times:
        cutoff_times.pop()
    return answer

In [7]:
def predict(id_: pl.DataFrame, question: pl.DataFrame) -> pl.DataFrame | pd.DataFrame:
    """Inference-server callback that answers a single problem.

    Args:
        id_: Single-element container holding the problem id (read with
            ``.item(0)``).
        question: Single-element container holding the problem statement
            (read with ``.item(0)``).

    Returns:
        A polars DataFrame with ``id`` and ``answer`` columns.
    """
    problem_id = id_.item(0)
    print("------")
    print(problem_id)

    answer = predict_for_question(question.item(0))
    return pl.DataFrame({'id': problem_id, 'answer': answer})

In [8]:
# Write a copy of the reference problems, without the answer column, to the
# working directory (e.g. as input for a local gateway run).
(
    pd.read_csv('/kaggle/input/ai-mathematical-olympiad-progress-prize-2/reference.csv')
    .drop(columns=['answer'])
    .to_csv('reference.csv', index=False)
)

In [9]:
# inference_server = kaggle_evaluation.aimo_2_inference_server.AIMO2InferenceServer(predict)

# if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
#     inference_server.serve()
# else:
#     inference_server.run_local_gateway(
#         (
#             'reference.csv',
#         )
#     )

In [10]:
# Start the inference server only inside Kaggle's scoring rerun; interactive
# "Save" runs just report that no server was started.
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("Not in competition environment - Saved!")
else:
    inference_server = kaggle_evaluation.aimo_2_inference_server.AIMO2InferenceServer(predict)
    inference_server.serve()

Not in competition environment - Saved!
