In [2]:
import litellm
import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
from judge import Judge
import json
import os

load_dotenv()


#load the dataset
df_100 = pd.read_csv("../output_data/100_data_math500.csv")
# Prompts
COT_PROMPT = """
You are a math tutor helping a student. Solve the following math problem step-by-step.

Question:
{question}

Let's think step-by-step.
"""

REFLECTION_PROMPT = """
Now reflect on the solution above. Was the reasoning and final answer correct? If there was a mistake, explain it and provide a corrected solution.
"""
judge = Judge(model='gemini/gemini-2.0-flash')

df_100.head(2)


Unnamed: 0.1,Unnamed: 0,problem,solution,answer,subject,level,unique_id
0,0,"Convert the point $(0,3)$ in rectangular coord...",We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also...,"\left( 3, \frac{\pi}{2} \right)",Precalculus,2,test/precalculus/807.json
1,1,Define\n\[p = \sum_{k = 1}^\infty \frac{1}{k^2...,We count the number of times $\frac{1}{n^3}$ a...,p - q,Intermediate Algebra,5,test/intermediate_algebra/1994.json


In [5]:
def generate_training_messages(df, start_idx:int, num_rows:int):
    end_idx = min(start_idx + num_rows, len(df))
    messages = [{
                'role': 'system',
                'content': '''You are a math tutor helping a student. Solve the following math problem step-by-step.
                  Question: <QUESTION>
                  Let's think step-by-step.
                ''',
            }]
    for idx in range(start_idx, end_idx):
        row = df.iloc[idx]
        
        # Base messages that are common for all examples
        messages.extend([
            {
                'role': 'user',
                'content': row['problem']
            },
            {
                'role': 'assistant',
                'content': f'solution : {row['solution']}' + '\n\n' + f'answer : {row['answer']}'
            }
        ])
        
    return messages

In [6]:
def metatuning(start_idx:int, num_rows:int,model,df_result):
    training_message = generate_training_messages(df_100, start_idx=start_idx, num_rows=num_rows)
    # LLM Call
    def call_litellm(model="openai/gpt-4o",messages=[]):
        response = litellm.completion(
            model=model,
            messages=messages,
            temperature=0
        )
        return response['choices'][0]['message']['content'].strip()

    temp_json_path = f"../output_data/temp_metatuning_rows.json"
    if os.path.exists(temp_json_path):
        with open(temp_json_path, 'r') as f:
            rows = json.load(f)
    else:
        rows = []
    already_done = len(rows)
    df_sliced = df_100.iloc[start_idx+num_rows+already_done:]
    index = 0
    for _,item in tqdm(df_sliced.iterrows()):
        index += 1
        print(index)
        q = item["problem"]
        truth = item["answer"]

        # Chain of Thought
        cot_prompt = COT_PROMPT.format(question=q)
        messages = training_message
        messages.append({"role": "user", "content": cot_prompt})
        cot_response = call_litellm(model=model,messages=messages)

        # Self Reflection
        reflection_prompt = cot_response + "\n\n" + REFLECTION_PROMPT
        reflection_response = call_litellm(model=model,messages=[{"role": "user", "content": reflection_prompt}])

        result = judge.prediction(query=q,answer1=reflection_response,answer2=truth).correct

        rows.append({
            "question": q,
            "ground_truth": truth,
            f"cot_answer_{num_rows}": cot_response,
            f"reflected_answer_{num_rows}": reflection_response,
            f"result_{num_rows}": result
        })
        # Save rows to temporary JSON file
        with open(temp_json_path, 'w') as f:
            json.dump(rows, f)

    # Create a DataFrame
    df_curr_result = pd.DataFrame(rows)
    #delete the temp json file
    os.remove(temp_json_path)

    #merge
    df_result = pd.merge(df_result,df_curr_result,on=["question","ground_truth"],how="left")
    return df_result

In [7]:
result_df_gpt4o = pd.read_csv("../output_data/math500_cot_reflection_output_gpt_4o.csv")
result_df_gemini = pd.read_csv("../output_data/math500_cot_reflection_output_gemini_1_5_flash.csv")

In [24]:
import time
df_result = pd.read_csv("../output_data/math500_metatuning_cot_reflection_output_gpt_4o.csv")
# df_result = metatuning(0,5,"openai/gpt-4o",result_df_gpt4o)
# df_result = metatuning(0,10,"openai/gpt-4o",df_result)
# df_result = metatuning(0,20,"openai/gpt-4o",df_result)
# df_result = metatuning(0,30,"openai/gpt-4o",df_result)
a = True
while a:
  try:
    df_result = metatuning(0,40,"openai/gpt-4o",df_result)
    a = False
  except Exception as e:
    time.sleep(61)
    a = True

0it [00:00, ?it/s]

1


1it [00:10, 10.64s/it]

2


2it [00:35, 19.11s/it]

3


3it [00:46, 15.27s/it]

4


4it [01:16, 21.23s/it]

5


5it [01:29, 18.24s/it]

6


6it [01:37, 14.67s/it]

7


7it [02:08, 20.14s/it]

8


8it [02:34, 21.81s/it]

9


9it [02:57, 22.33s/it]

10


10it [03:52, 32.31s/it]

11


11it [04:11, 28.24s/it]

12


12it [04:28, 24.73s/it]

13


12it [04:38, 23.17s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:12, 12.29s/it]

2


2it [00:23, 11.62s/it]

3


3it [00:37, 12.54s/it]

4


4it [00:51, 13.26s/it]

5


4it [00:56, 14.12s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:11, 11.45s/it]

2


2it [00:20, 10.02s/it]

3


3it [00:30,  9.99s/it]

4


4it [00:41, 10.52s/it]

5


5it [01:16, 19.10s/it]

6


6it [01:36, 19.40s/it]

7


7it [02:08, 23.66s/it]

8


7it [02:23, 20.47s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:15, 15.73s/it]

2


2it [00:28, 13.92s/it]

3


3it [00:37, 11.64s/it]

4


4it [00:47, 11.26s/it]

5


5it [01:12, 16.03s/it]

6


5it [01:26, 17.35s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:09,  9.58s/it]

2


2it [00:23, 12.22s/it]

3


3it [00:38, 13.62s/it]

4


4it [00:50, 12.81s/it]

5


4it [01:00, 15.21s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:10, 10.11s/it]

2


2it [00:22, 11.22s/it]

3


3it [00:38, 13.44s/it]

4


4it [00:54, 14.62s/it]

5


4it [01:02, 15.68s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:09,  9.17s/it]

2


2it [00:15,  7.75s/it]

3


3it [00:29, 10.22s/it]

4


3it [00:34, 11.64s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:09,  9.51s/it]

2


2it [00:22, 11.71s/it]

3


3it [00:31, 10.37s/it]

4


3it [00:38, 13.00s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:13, 13.44s/it]

2


2it [00:23, 11.55s/it]

3


3it [00:37, 12.47s/it]

4


4it [00:47, 11.77s/it]

5


4it [01:03, 15.90s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:08,  8.69s/it]

2


2it [00:21, 10.89s/it]

3


3it [00:33, 11.76s/it]

4


4it [00:47, 12.56s/it]

5


4it [01:01, 15.28s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:07,  7.58s/it]

2


2it [00:17,  9.00s/it]

3


3it [00:31, 11.12s/it]

4


4it [00:50, 14.25s/it]

5


4it [01:01, 15.48s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:11, 11.99s/it]

2


2it [00:23, 11.93s/it]

3


3it [00:43, 15.39s/it]

4


4it [00:55, 14.29s/it]

5


4it [01:04, 16.05s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:13, 13.17s/it]

2


2it [00:41, 20.87s/it]


In [25]:
df_result.to_csv("../output_data/math500_metatuning_cot_reflection_output_gpt_4o.csv",index=False)

In [28]:
import time
for val in [20,30,40]:
  print(f"********************{val}*****************")
  a = True
  while a:
    try:
      df_result = metatuning(0,val,"gemini/gemini-1.5-flash",df_result)
      a = False
      df_result.to_csv("../output_data/math500_metatuning_cot_reflection_output_gemini_1.5.csv",index=False)
    except Exception as e:
      time.sleep(61)
      a = True

********************20*****************


0it [00:00, ?it/s]

1


1it [00:06,  6.09s/it]

2


2it [00:15,  7.92s/it]

3


3it [00:32, 12.10s/it]

4


4it [00:50, 14.45s/it]

5


5it [01:07, 15.35s/it]

6


6it [01:28, 17.35s/it]

7


7it [01:53, 19.94s/it]

8


8it [02:19, 21.79s/it]

9


9it [02:50, 24.75s/it]

10


10it [03:18, 25.49s/it]

11


11it [03:43, 25.62s/it]

12


12it [04:09, 25.52s/it]

13


13it [04:34, 25.56s/it]

14


14it [05:01, 25.76s/it]

15


15it [05:38, 29.27s/it]

16


16it [06:05, 28.64s/it]

17


17it [06:39, 30.13s/it]

18


18it [07:09, 30.20s/it]

19


19it [07:43, 31.28s/it]

20


20it [08:49, 41.81s/it]

21


21it [09:24, 39.54s/it]

22


22it [10:02, 39.36s/it]

23


23it [10:38, 38.07s/it]

24


24it [11:09, 35.99s/it]

25


25it [11:41, 34.82s/it]

26


26it [12:20, 36.00s/it]

27


27it [13:00, 37.25s/it]

28


28it [13:37, 37.15s/it]

29


29it [14:14, 37.14s/it]

30


30it [14:50, 36.93s/it]

31


31it [15:47, 43.04s/it]

32


32it [16:37, 44.97s/it]

33


32it [17:22, 32.58s/it]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.




0it [00:00, ?it/s]

1


1it [00:07,  7.66s/it]

2


2it [00:17,  8.79s/it]

3


3it [00:27,  9.34s/it]

4


4it [00:38, 10.07s/it]

5


5it [00:50, 10.88s/it]

6


6it [01:03, 11.67s/it]

7


7it [01:15, 11.53s/it]

8


8it [01:31, 13.10s/it]

9


9it [01:56, 16.81s/it]

10


10it [02:22, 19.75s/it]

11


11it [02:51, 22.55s/it]

12


12it [03:18, 23.85s/it]

13


13it [03:47, 25.43s/it]

14


14it [04:15, 26.27s/it]

15


15it [04:55, 30.25s/it]

16


16it [05:27, 30.78s/it]

17


17it [06:00, 31.41s/it]

18


18it [06:39, 33.70s/it]

19


19it [07:19, 35.56s/it]

20


20it [07:58, 36.79s/it]

21


21it [08:44, 39.54s/it]

22


22it [09:36, 43.09s/it]

23


23it [10:26, 45.22s/it]

24


24it [10:55, 40.33s/it]

25


25it [11:43, 42.54s/it]

26


26it [12:29, 43.73s/it]

27


27it [13:26, 47.84s/it]

28


28it [14:17, 48.72s/it]

29


29it [15:05, 48.43s/it]

30


30it [15:58, 49.75s/it]

31


31it [16:59, 53.07s/it]

32


32it [17:50, 52.56s/it]

33


33it [18:51, 55.10s/it]

34


34it [19:45, 54.64s/it]

35


35it [20:45, 56.31s/it]

36


36it [21:44, 57.15s/it]

37


36it [31:54, 53.18s/it]
0it [00:00, ?it/s]

1


1it [00:14, 14.68s/it]

2


2it [00:27, 13.82s/it]

3


3it [00:43, 14.73s/it]

4


4it [01:02, 16.26s/it]

5


5it [01:18, 16.29s/it]

6


6it [01:34, 16.21s/it]

7


7it [01:50, 15.98s/it]

8


8it [02:16, 19.18s/it]

9


9it [02:40, 20.81s/it]

10


10it [03:02, 21.05s/it]

11


11it [03:30, 23.21s/it]

12


12it [03:52, 19.41s/it]


********************30*****************


0it [00:00, ?it/s]

1


1it [00:09,  9.08s/it]

2


2it [00:19,  9.91s/it]

3


3it [00:32, 11.35s/it]

4


4it [00:48, 13.31s/it]

5


5it [01:04, 14.00s/it]

6


6it [01:21, 15.30s/it]

7


7it [01:39, 15.88s/it]

8


8it [02:02, 18.36s/it]

9


9it [02:24, 19.45s/it]

10


10it [02:46, 20.30s/it]

11


11it [03:09, 21.14s/it]

12


12it [03:38, 23.36s/it]

13


13it [04:10, 26.06s/it]

14


14it [04:39, 26.91s/it]

15


15it [05:10, 28.29s/it]

16


16it [05:42, 29.29s/it]

17


17it [06:15, 30.46s/it]

18


18it [06:50, 31.71s/it]

19


19it [07:19, 30.88s/it]

20


20it [07:53, 31.77s/it]

21


21it [08:29, 33.21s/it]

22


22it [09:07, 34.70s/it]

23


23it [09:54, 38.24s/it]

24


24it [10:44, 41.92s/it]

25


25it [11:30, 43.05s/it]

26


26it [12:20, 45.02s/it]

27


27it [13:00, 43.76s/it]

28


28it [13:43, 43.31s/it]

29


29it [14:28, 43.81s/it]

30


30it [15:21, 46.78s/it]

31


31it [16:26, 52.01s/it]

32


32it [17:32, 56.28s/it]

33


33it [18:25, 55.48s/it]

34


34it [19:24, 56.52s/it]

35


35it [20:21, 56.46s/it]

36


36it [21:20, 57.20s/it]

37


37it [22:19, 57.86s/it]

38


38it [23:23, 59.82s/it]

39


39it [24:29, 61.54s/it]

40


40it [25:35, 62.91s/it]

41


41it [26:40, 63.59s/it]

42


42it [27:47, 64.65s/it]

43


43it [28:56, 65.69s/it]

44


44it [30:04, 66.46s/it]

45


45it [31:11, 66.71s/it]

46


46it [32:19, 67.04s/it]

47


47it [33:31, 68.52s/it]

48


48it [34:35, 67.10s/it]

49


49it [35:43, 67.36s/it]

50


50it [36:52, 67.90s/it]

51


51it [38:00, 68.01s/it]

52


52it [39:05, 67.17s/it]

53


53it [40:12, 66.93s/it]

54


54it [41:21, 67.55s/it]

55


55it [42:29, 67.85s/it]

56


56it [43:52, 72.26s/it]

57


57it [44:58, 70.36s/it]

58


58it [46:07, 69.96s/it]

59


59it [47:15, 69.54s/it]

60


60it [48:22, 68.66s/it]

61


61it [49:30, 68.50s/it]

62


62it [50:40, 69.09s/it]

63


63it [51:50, 69.13s/it]

64


64it [52:54, 67.78s/it]

65


65it [54:05, 68.59s/it]

66


66it [55:15, 68.93s/it]

67


67it [56:22, 68.61s/it]

68


68it [57:30, 68.38s/it]

69


69it [58:53, 72.84s/it]

70


70it [59:59, 51.43s/it]


********************40*****************


0it [00:00, ?it/s]

1


1it [00:08,  8.31s/it]

2


2it [00:20, 10.51s/it]

3


3it [00:29, 10.03s/it]

4


4it [00:46, 12.61s/it]

5


5it [01:02, 13.79s/it]

6


6it [01:17, 14.45s/it]

7


7it [01:35, 15.47s/it]

8


8it [01:53, 16.17s/it]

9


9it [02:12, 17.29s/it]

10


10it [02:31, 17.71s/it]

11


11it [02:55, 19.60s/it]

12


12it [03:18, 20.58s/it]

13


13it [03:40, 21.07s/it]

14


14it [04:02, 21.40s/it]

15


15it [04:28, 22.72s/it]

16


16it [04:50, 22.52s/it]

17


17it [05:11, 22.18s/it]

18


18it [05:37, 23.09s/it]

19


19it [05:57, 22.21s/it]

20


20it [06:21, 22.84s/it]

21


21it [06:51, 24.86s/it]

22


22it [07:18, 25.47s/it]

23


23it [07:55, 29.11s/it]

24


24it [08:25, 29.38s/it]

25


25it [08:49, 27.66s/it]

26


26it [09:16, 27.43s/it]

27


27it [09:58, 31.78s/it]

28


28it [10:37, 33.99s/it]

29


29it [11:15, 35.18s/it]

30


30it [11:53, 36.23s/it]

31


31it [12:28, 35.84s/it]

32


32it [13:07, 36.64s/it]

33


33it [13:30, 32.64s/it]

34


34it [14:11, 34.97s/it]

35


35it [14:54, 37.53s/it]

36


36it [15:40, 40.09s/it]

37


37it [16:28, 42.57s/it]

38


38it [17:12, 42.79s/it]

39


39it [17:59, 44.07s/it]

40


40it [18:49, 45.79s/it]

41


41it [19:28, 43.81s/it]

42


42it [20:24, 47.49s/it]

43


43it [20:50, 41.07s/it]

44


44it [21:20, 37.85s/it]

45


45it [21:53, 36.37s/it]

46


46it [22:51, 42.62s/it]

47


47it [23:30, 41.65s/it]

48


48it [24:02, 38.83s/it]

49


49it [24:45, 40.19s/it]

50


50it [25:16, 37.32s/it]

51


51it [25:56, 38.20s/it]

52


52it [26:37, 39.00s/it]

53


53it [27:17, 39.22s/it]

54


54it [27:58, 39.68s/it]

55


55it [29:05, 48.01s/it]

56


56it [30:10, 52.91s/it]

57


57it [31:17, 57.41s/it]

58


58it [32:25, 60.61s/it]

59


59it [33:34, 62.97s/it]

60


60it [34:15, 34.26s/it]


In [29]:
df_result.to_csv("../output_data/math500_metatuning_cot_reflection_output_gemini_1.5_flash.csv",index=False)

In [3]:
df_result_gpt4o = pd.read_csv("../output_data/math500_metatuning_cot_reflection_output_gpt_4o.csv")
df_result_gemini = pd.read_csv("../output_data/math500_metatuning_cot_reflection_output_gemini.csv")

In [4]:
df_result_gpt4o.columns

Index(['question', 'ground_truth', 'cot_answer', 'reflected_answer', 'result',
       'cot_answer_5', 'reflected_answer_5', 'result_5', 'cot_answer_10',
       'reflected_answer_10', 'result_10', 'cot_answer_20',
       'reflected_answer_20', 'result_20', 'cot_answer_30',
       'reflected_answer_30', 'result_30', 'cot_answer_40',
       'reflected_answer_40', 'result_40'],
      dtype='object')

In [7]:
# Function to analyze result pairs
def analyze_result_pair(df, base_col, compare_col):
    # Create a mask for non-blank values in both columns
    mask = df[base_col].notna() & df[compare_col].notna()
    
    # Get the filtered dataframe
    filtered_df = df[mask]
    
    print(f"\nAnalyzing {base_col} vs {compare_col}")
    print(f"Total rows after removing blanks: {len(filtered_df)}")
    
    # Get value counts for both columns
    print(f"\n{base_col} value counts:")
    print(filtered_df[base_col].value_counts())
    
    print(f"\n{compare_col} value counts:")
    print(filtered_df[compare_col].value_counts())
    
    
    return filtered_df

# Analyze each pair
result_pairs = [
    ('result', 'result_5'),
    ('result', 'result_10'),
    ('result', 'result_20'),
    ('result', 'result_30'),
    ('result', 'result_40')
]


In [8]:
for base_col, compare_col in result_pairs:
    filtered_df = analyze_result_pair(df_result_gpt4o, base_col, compare_col)
    print("-" * 50)


Analyzing result vs result_5
Total rows after removing blanks: 95

result value counts:
result
True     75
False    20
Name: count, dtype: int64

result_5 value counts:
result_5
True     79
False    16
Name: count, dtype: int64
--------------------------------------------------

Analyzing result vs result_10
Total rows after removing blanks: 90

result value counts:
result
True     71
False    19
Name: count, dtype: int64

result_10 value counts:
result_10
True     72
False    18
Name: count, dtype: int64
--------------------------------------------------

Analyzing result vs result_20
Total rows after removing blanks: 80

result value counts:
result
True     67
False    13
Name: count, dtype: int64

result_20 value counts:
result_20
True     68
False    12
Name: count, dtype: int64
--------------------------------------------------

Analyzing result vs result_30
Total rows after removing blanks: 70

result value counts:
result
True     59
False    11
Name: count, dtype: int64

result

In [9]:
for base_col, compare_col in result_pairs:
    filtered_df = analyze_result_pair(df_result_gemini, base_col, compare_col)
    print("-" * 50)


Analyzing result vs result_5
Total rows after removing blanks: 95

result value counts:
result
True     75
False    20
Name: count, dtype: int64

result_5 value counts:
result_5
False    62
True     33
Name: count, dtype: int64
--------------------------------------------------

Analyzing result vs result_10
Total rows after removing blanks: 90

result value counts:
result
True     71
False    19
Name: count, dtype: int64

result_10 value counts:
result_10
False    63
True     27
Name: count, dtype: int64
--------------------------------------------------

Analyzing result vs result_20
Total rows after removing blanks: 80

result value counts:
result
True     67
False    13
Name: count, dtype: int64

result_20 value counts:
result_20
False    42
True     38
Name: count, dtype: int64
--------------------------------------------------

Analyzing result vs result_30
Total rows after removing blanks: 70

result value counts:
result
True     59
False    11
Name: count, dtype: int64

result