In [1]:
from typing import List, Dict, Tuple, Optional, Union, Any

import pandas as pd

from prompt_template_collection import PromptTemplate
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA
from metrics_collection import MetricsHelper

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
task_name = "MATH"
hp_m = None   # keep at most this many rows per query; None (or 0) keeps all rows
hp_k = 4      # forwarded as K= to RelaxedFDA.evaluate in the evaluation cell
hp_n = 1200   # keep at most this many prepared entries; None keeps all

# Read-only Kaggle input, and the working copy that prepare_dataset consumes.
# Both the write and the read below go through `working_path`, so the hp_m
# truncation is guaranteed to reach prepare_dataset regardless of the cwd.
file_path = '/kaggle/input/gpt35-building-sots-on-math/sots_df.pkl'
working_path = '/kaggle/working/sots_df.pkl'

# Set up logger
logger = Doraemon.get_logger(name=task_name, logfile=f"relaxed_fda_on_{task_name}.log")

# Load the raw data. NOTE: pd.read_pickle can execute arbitrary code while
# unpickling; only point file_path at files you trust.
df = pd.read_pickle(file_path)
if hp_m:
    # Keep only the first hp_m rows of each query group.
    df = df.groupby('query').head(hp_m).reset_index(drop=True)
df.to_pickle(working_path)

# Build the grouped evaluation set from the working copy (logger passed through).
grouped_data = RelaxedFDA.prepare_dataset(file_path=working_path, logger=logger)
if hp_n is not None:
    grouped_data = grouped_data.head(hp_n)

# Prompt templates, optimisation system prompt and encoder that the evaluation
# cell passes to RelaxedFDA.evaluate.
# NOTE(review): sot_construct_inter_math() is MATH-specific and does not follow
# task_name -- update it by hand if task_name changes.
D: List[Dict] = PromptTemplate.sot_construct_inter_math()
op_system_prompt = RelaxedFDA.get_optimize_system_prompt(task_name)
encoder = RelaxedFDA.get_encoder()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.9/124.9 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.7/210.7 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25h

2025-07-04 07:01:35.419901: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751612495.711551      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751612495.795195      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-04 07:01:55,835 INFO Loaded dataset from /kaggle/working/sots_df.pkl with shape (10800, 8)
2025-07-04 07:01:55,836 INFO In r_s mode.
2025-07-04 07:01:55,924 INFO Dataset preparation completed with 1200 entries.


In [2]:
# Run Relaxed-FDA over every grouped entry. This is by far the most expensive
# step (the progress bar estimated ~5 h for 1200 entries), so the result is
# also written to disk: the metrics cell can then be re-run after a kernel
# restart by reloading the pickle instead of repeating the whole evaluation.
result_pd: pd.DataFrame = RelaxedFDA.evaluate(
    grouped_data,
    logger,
    op_system_prompt,
    encoder,
    D,
    enable_logger_rs=False,  # forwarded unchanged; semantics defined in RelaxedFDA
    K=hp_k,
)
result_pd.to_pickle(f"relaxed_fda_results_{task_name}.pkl")

  0%|          | 0/1200 [00:00<?, ?it/s]2025-07-04 07:02:03,615 INFO Cluster 1: |C_0| = 6, P(r_0|do(X)) ≈ 0.67
2025-07-04 07:02:07,317 INFO Cluster 2: |C_1| = 1, P(r_1|do(X)) ≈ 0.11
2025-07-04 07:02:10,152 INFO Cluster 3: |C_2| = 1, P(r_2|do(X)) ≈ 0.11
2025-07-04 07:02:13,082 INFO Cluster 4: |C_3| = 1, P(r_3|do(X)) ≈ 0.11
2025-07-04 07:02:16,290 INFO Aggregated votes (weighted): {'2': 1.0} | Final answer: 2 (weight: 1.0) | Overall prob: 1.00
  0%|          | 1/1200 [00:14<4:58:01, 14.91s/it]2025-07-04 07:02:16,816 INFO Cluster 1: |C_0| = 3, P(r_0|do(X)) ≈ 0.33
2025-07-04 07:02:18,093 INFO Cluster 2: |C_1| = 2, P(r_1|do(X)) ≈ 0.22
2025-07-04 07:02:19,462 INFO Cluster 3: |C_2| = 1, P(r_2|do(X)) ≈ 0.11
2025-07-04 07:02:20,845 INFO Cluster 4: |C_3| = 3, P(r_3|do(X)) ≈ 0.33
2025-07-04 07:02:22,451 INFO Aggregated votes (weighted): {'1': 0.3333333333333333, '6': 0.2222222222222222, '\\frac{1': 0.1111111111111111, '3\\sqrt{2': 0.3333333333333333} | Final answer: 1 (weight: 0.3333333333333333)

In [3]:
# Score the predictions in result_pd. With log=True the summary (exact match,
# average F1, accuracy) is emitted through `logger`, as in the recorded output.
MetricsHelper.evaluate(result_pd, logger, log=True)

2025-07-04 13:48:21,485 INFO Evaluated 1200 examples
2025-07-04 13:48:21,486 INFO Exact Match: 580/1200 = 48.33%
2025-07-04 13:48:21,487 INFO Average F1 Score: 49.35%
2025-07-04 13:48:21,488 INFO Accuracy: 48.33%
