In [1]:
import adaptive_consistency
import pandas as pd
from adaptive_consistency import BetaStoppingCriteria
import consol
import langchain_openai
import langchain_core
import os


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Read the AIME 2024 problem table from the Hugging Face Hub, rename the
# problem/answer columns to `input`/`target`, and keep only those two columns.
df = (
    pd.read_parquet("hf://datasets/Maxwell-Jia/AIME_2024/aime_2024_problems.parquet")
    .rename(columns={'Problem': 'input', 'Answer': 'target'})
    [['input', 'target']]
)

In [25]:
# Chat model used for every sample: o3-mini with low reasoning effort.
# The API key is taken from the environment so it never appears in the notebook.
llm = langchain_openai.ChatOpenAI(
    reasoning_effort="low",
    model="o3-mini",
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
)

In [14]:
import abc
import pydantic
import typing

# Base schema for structured LLM output: a pydantic model (also marked as an
# ABC) that declares a single `answer` field, typed as `str` at this level.
# Subclasses below re-declare `answer` with a narrower type.
class AbstractOutput(pydantic.BaseModel, abc.ABC):
    answer: str

# Mixin that declares a `reasons` annotation (a list of strings).
# NOTE(review): no class visible in this section inherits from it — confirm it
# is still needed.
class ReasonedMixin(abc.ABC):
    reasons: typing.List[str]

# Output schema whose `answer` is a float. This is the schema handed to
# `llm.with_structured_output` in this cell.
class FloatOutput(AbstractOutput):
    answer: float

# Output schema whose `answer` must be one of the letters "A" through "E".
class ABCDEOutput(AbstractOutput):
    answer: typing.Literal["A", "B", "C", "D", "E"]
    

# Wrap the chat model so each call is parsed into a FloatOutput. With
# include_raw=True every result is a dict carrying 'raw', 'parsed' and
# 'parsing_error'; 'parsed' is None when parsing fails.
llm_with_structured_output = llm.with_structured_output(FloatOutput, include_raw=True)

In [97]:
# Stopping rule from adaptive_consistency, constructed with 0.95 (readable
# afterwards as `beta.conf_thresh`). `beta.should_stop(samples)` returns a dict
# with 'most_common', 'prob' and 'stop'.
beta = BetaStoppingCriteria(0.95)

In [None]:
import tqdm

# Upper bound on sampling rounds, and therefore on samples drawn per problem.
MAX_ROUNDS = 40

N = len(df.input)

# Per-problem state, keyed by the problem's position in `df`:
#   stopped[i]  - True once the stopping criterion has fired for problem i
#   answers[i]  - every successfully parsed answer sampled so far
#   logs[i]     - the full structured-output dicts that produced answers[i]
#   response[i] - the accepted answer; remains [] until the criterion fires
stopped = {i: False for i in range(N)}
answers = {i: [] for i in range(N)}
logs = {i: [] for i in range(N)}
response = {i: [] for i in range(N)}

for k in tqdm.tqdm(range(MAX_ROUNDS)):
    # The flags live in the dict's values. Iterating the dict itself yields its
    # keys, and key 0 is falsy, so `all(stopped)` could never be True.
    if all(stopped.values()):
        break

    # Draw one more sample for every problem that has not converged yet.
    need_to_idx = [i for i in range(N) if not stopped[i]]
    prompts = [str(df.input.iloc[i]) for i in need_to_idx]
    output = llm_with_structured_output.batch(prompts)

    for idx, result in zip(need_to_idx, output):
        parsed = result['parsed']
        if parsed is None:
            # Structured parsing failed for this sample; report it and move on
            # without recording an answer for this round.
            print(f"error at {idx} at {k}-th iteration")
            print(result)
            continue
        answers[idx].append(parsed.answer)
        logs[idx].append(result)

    # Re-evaluate the stopping rule for the problems sampled in this round.
    for idx in need_to_idx:
        if not answers[idx]:
            continue
        verdict = beta.should_stop(answers[idx])
        if verdict['stop']:
            response[idx] = verdict['most_common']
            stopped[idx] = True

  8%|████████▊                                                                                                             | 3/40 [10:27<1:46:06, 172.06s/it]

error at 5 at 2-th iteration
{'raw': AIMessage(content='', additional_kwargs={'parsed': None, 'refusal': '{\n  "answer": 14\n} \n\n/* Explanation:\n\nWe will explain briefly one way to “see” the answer.\n\nDefine\n\u2003\u2003f(x) = ||x| – ½|\u2003\u2003and\u2003\u2003g(x) = ||x| – ¼|.\nThen the two curves may be written as\n\n\u2003\u2003y = 4 g(f( sin(2πx) ))\n\u2003\u2003x = 4 g(f( cos(3πy) )).\n\nA short (but non‐trivial) analysis shows that, in every “period‐cell” (in fact, one may prove that the range of both functions is [0,1] so that both graphs lie in the unit square) the “sin–curve”\n\u2003\u2003C₁ = { (x, y) : y = 4 g(f( sin(2πx) )),  x ∈ ℝ }\noscillates in a “double‐W” pattern (its “W” arises because f and g are “V–shaped”) while the “cos–curve”\n\u2003\u2003C₂ = { (x, y) : x = 4 g(f( cos(3πy) )),  y ∈ ℝ }\noscillates in a “triple–W” pattern. A careful (piece‐wise linear) count then shows that the curves meet in exactly 14 points.\n\nFor example, one may show that after “un

 10%|███████████▊                                                                                                          | 4/40 [13:06<1:40:01, 166.71s/it]

error at 4 at 3-th iteration
{'raw': AIMessage(content='', additional_kwargs={'parsed': None, 'refusal': '{\n  "answer": 1763\n}'}, response_metadata={'token_usage': {'completion_tokens': 16923, 'prompt_tokens': 132, 'total_tokens': 17055, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 16896, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'o3-mini-2025-01-31', 'system_fingerprint': 'fp_42bfad963b', 'finish_reason': 'stop', 'logprobs': None}, id='run-4e4a2cae-ac0e-43fc-a809-21987d8d8f33-0', usage_metadata={'input_tokens': 132, 'output_tokens': 16923, 'total_tokens': 17055, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 16896}}), 'parsing_error': OpenAIRefusalError('{\n  "answer": 1763\n}'), 'parsed': None}


 42%|█████████████████████████████████████████████████▋                                                                   | 17/40 [43:53<1:00:18, 157.31s/it]

error at 5 at 16-th iteration
{'raw': AIMessage(content='', additional_kwargs={'parsed': None, 'refusal': '{\n  "answer": 48\n} \n\n/* Explanation:\n\nWe will show that the two‐curve system\n       y = 4·g(f(sin(2πx)))\u2003\u2003and\u2003\u2003x = 4·g(f(cos(3πy)))\n(where\n       f(t)= ||t| – ½|\u2003\u2003and\u2003\u2003g(t)= ||t| – ¼| )\nhas exactly 48 solutions (pairs (x,y) in ℝ²). \n\nA very brief outline of one rigorous approach is as follows.\n\nStep\xa01. Notice that because f and g are “V–shaped” (even functions) and because\n       |sin(2πx)| ≤ 1  and  |cos(3πy)| ≤ 1,\none may show that \n       4g(f(sin(2πx))) ∈ [0,1]  and  4g(f(cos(3πy))) ∈ [0,1].\nThus any solution (x,y) must have x,y ∈ [0,1].\n\nStep\xa02. A short computation shows that\n       f(sin(2πx)) = ||sin(2πx)| – ½|\nand then\n       g(f(sin(2πx))) = | ||sin(2πx)| – ½| – ¼ |.\nThus writing\n       y = 4| ||sin(2πx)| – ½| – ¼ |,\none may verify (by “book‐keeping” the critical values – namely the numbers\n0, ¼, ½, 

 60%|███████████████████████████████████████████████████████████████████████▍                                               | 24/40 [58:28<35:14, 132.18s/it]

In [73]:
# Show, per problem, whether the stopping criterion fired during sampling.
stopped

{0: True,
 1: True,
 2: True,
 3: True,
 4: True,
 5: True,
 6: False,
 7: True,
 8: True,
 9: True,
 10: True,
 11: True,
 12: True,
 13: True,
 14: True,
 15: True,
 16: True,
 17: False,
 18: True,
 19: True,
 20: True,
 21: True,
 22: True,
 23: True,
 24: True,
 25: True,
 26: True,
 27: False,
 28: True,
 29: True}

In [75]:
# Print how many parsed samples were collected for each problem. Iterating the
# dict avoids hard-coding the number of problems.
for i, samples in answers.items():
    print(i, len(samples))

0 4
1 4
2 4
3 4
4 4
5 16
6 40
7 4
8 4
9 4
10 4
11 4
12 4
13 4
14 4
15 7
16 4
17 40
18 9
19 4
20 4
21 8
22 7
23 4
24 4
25 31
26 4
27 40
28 15
29 4


In [63]:
# Inspect the criterion's verdict (most common answer, its probability, stop
# flag) for problem 1.
beta.should_stop(answers[1])

{'most_common': 23.0, 'prob': 0.75, 'stop': False}

In [76]:
# Accepted answer per problem; an empty list marks a problem whose criterion
# never fired.
response

{0: 33.0,
 1: 23.0,
 2: 116.0,
 3: 809.0,
 4: 197.0,
 5: 24.0,
 6: [],
 7: 601.0,
 8: 25.0,
 9: 55.0,
 10: 540.0,
 11: 45.0,
 12: 204.0,
 13: 699.0,
 14: 294.0,
 15: 110.0,
 16: 721.0,
 17: [],
 18: 468.0,
 19: 902.0,
 20: 211.0,
 21: 80.0,
 22: 480.0,
 23: 236.0,
 24: 73.0,
 25: 113.0,
 26: 127.0,
 27: [],
 28: 104.0,
 29: 321.0}

In [77]:
# Problem 6 never converged; look at its current most common answer.
beta.should_stop(answers[6])

{'most_common': 371.0, 'prob': 0.5, 'stop': False}

In [78]:
# Manual fallback: record the most common sampled answer for problem 6, since
# the stopping criterion never fired for it.
response[6] = 371.0

In [79]:
# Problem 17 never converged; look at its current most common answer.
beta.should_stop(answers[17])

{'most_common': 30.0, 'prob': 0.7382664680480958, 'stop': False}

In [80]:
# Manual fallback: record the most common sampled answer for problem 17, since
# the stopping criterion never fired for it.
response[17] = 30.0

In [81]:
# Problem 27 never converged; look at its current most common answer.
beta.should_stop(answers[27])

{'most_common': 104.0, 'prob': 0.7789658308029176, 'stop': False}

In [82]:
# Manual fallback: record the most common sampled answer for problem 27, since
# the stopping criterion never fired for it.
response[27] = 104.0

In [83]:
# Final answers after the manual fallbacks for the non-converged problems.
response

{0: 33.0,
 1: 23.0,
 2: 116.0,
 3: 809.0,
 4: 197.0,
 5: 24.0,
 6: 371.0,
 7: 601.0,
 8: 25.0,
 9: 55.0,
 10: 540.0,
 11: 45.0,
 12: 204.0,
 13: 699.0,
 14: 294.0,
 15: 110.0,
 16: 721.0,
 17: 30.0,
 18: 468.0,
 19: 902.0,
 20: 211.0,
 21: 80.0,
 22: 480.0,
 23: 236.0,
 24: 73.0,
 25: 113.0,
 26: 127.0,
 27: 104.0,
 28: 104.0,
 29: 321.0}

In [89]:
# List every problem whose final response differs from the reference answer,
# printing the row label, the expected value and the value that was produced.
for i, row in df.iterrows():
    expected, given = row['target'], response[i]
    if expected == given:
        continue
    print(i, expected, given)

5 385 24.0
17 315 30.0


In [90]:
# All parsed samples collected for problem 5, one of the mismatches reported by
# the accuracy check.
answers[5]

[12.0,
 8.0,
 24.0,
 20.0,
 24.0,
 6.0,
 0.0,
 24.0,
 8.0,
 24.0,
 24.0,
 16.0,
 24.0,
 96.0,
 24.0,
 24.0]

In [92]:
# Frequency of each distinct sampled answer for problem 17.
pd.Series(answers[17]).value_counts()

30.0     12
245.0     9
54.0      3
45.0      3
975.0     2
153.0     2
85.0      1
69.0      1
39.0      1
81.0      1
585.0     1
15.0      1
75.0      1
36.0      1
55.0      1
Name: count, dtype: int64

In [96]:
# Confidence threshold the stopping criterion was constructed with.
beta.conf_thresh

0.95

In [99]:
# consol's Bayesian-posterior confidence model with its default settings; it is
# compared against `beta` further down.
bayesian_posterior = consol.confidence_models.BayesianPosteriorConfidenceModel()

In [102]:
# Spot-check the decision for a 4-vs-0 tally.
# NOTE(review): arguments presumed to be (top-answer count, runner-up count) — confirm against consol.
bayesian_posterior.test(4,0)

True

In [103]:
# Spot-check the decision for a 1-vs-0 tally.
# NOTE(review): arguments presumed to be (top-answer count, runner-up count) — confirm against consol.
bayesian_posterior.test(1,0)

False

In [106]:
# Spot-check the decision for a 4-vs-1 tally.
# NOTE(review): arguments presumed to be (top-answer count, runner-up count) — confirm against consol.
bayesian_posterior.test(4,1)

False

In [107]:
# consol's SPRT confidence model with its default settings.
sprt_posterior = consol.confidence_models.SprtConfidenceModel()

In [110]:
# Spot-check the SPRT decision for a 4-vs-0 tally.
# NOTE(review): arguments presumed to be (top-answer count, runner-up count) — confirm against consol.
sprt_posterior.test(4,0)

False

In [111]:
# Spot-check the SPRT decision for a 5-vs-0 tally.
# NOTE(review): arguments presumed to be (top-answer count, runner-up count) — confirm against consol.
sprt_posterior.test(5,0)

False

In [114]:
# Spot-check the SPRT decision for a 16-vs-0 tally.
# NOTE(review): arguments presumed to be (top-answer count, runner-up count) — confirm against consol.
sprt_posterior.test(16,0)

True

In [115]:
# consol's mSPRT confidence model with its default settings.
# NOTE(review): this object is not used by any cell visible in this section.
mSprt_posterior = consol.confidence_models.MsprtConfidenceModel()

In [117]:
# Spot-check the decision for a 6-vs-0 tally.
# NOTE(review): this calls `sprt_posterior`, not the `mSprt_posterior` created
# in the preceding cell — confirm which model was meant to be probed here.
sprt_posterior.test(6,0)

False

In [125]:
# Cross-check the two stopping rules on every two-answer tally with i votes for
# one answer and j <= i votes for the other, for i below 40. The empty tally
# (i = j = 0) is excluded. Each disagreement is printed as
# (i, j, adaptive-consistency decision, bayesian_posterior decision).
for i in range(1, 40):
    for j in range(i + 1):
        ac = beta.should_stop([10] * i + [20] * j)
        ours = bayesian_posterior.test(i, j)
        if ac['stop'] == ours:
            continue
        print(i, j, ac['stop'], ours)