In [1]:
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union, Any

import pandas as pd

from prompt_template_collection import PromptTemplate
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA
from metrics_collection import MetricsHelper
from metrics_collection import MetricsHelper

task_name = "GSM8K"
# Max SoT rows kept per query; None (or any falsy value) keeps every SoT in the dataset.
# Values swept in the experiments: 4, 6, 8, 10, 12 (the original note marks 10 as the default).
hp_m = None
# K passed to RelaxedFDA.evaluate (number of reasoning clusters); values swept: 1, 3, 4, 5, 7.
hp_k = 7
# Cap on the number of prepared (grouped) entries that are evaluated.
N_EVAL = 1200
# How many items to keep from PromptTemplate.sot_construct_inter_gsm8k()
# (presumably few-shot demonstrations — confirm against prompt_template_collection).
N_DEMOS = 1

SOTS_SOURCE_PATH = "/kaggle/input/building-gsm8k-sots-dataset/sots.pkl"
# Resolve the working copy once so the write and the read below hit the same file,
# independent of the current working directory (on Kaggle: /kaggle/working/sots_df.pkl).
SOTS_WORKING_PATH = Path.cwd() / "sots_df.pkl"

# Set up logger
logger = Doraemon.get_logger(name=task_name, logfile=f"relaxed-fda_on_{task_name}.log")

# NOTE(review): read_pickle can execute arbitrary code on load; only use trusted dataset files.
df = pd.read_pickle(SOTS_SOURCE_PATH)
# Map the raw dataset columns onto the names RelaxedFDA expects.
df = df.rename(columns={'question': 'query', 'reason': 'r_s', 'ground_truth': 'g_t', 'temperature': 't_p'})
if hp_m:
    # Keep at most hp_m SoT rows per query, in their existing order.
    df = df.groupby('query').head(hp_m).reset_index(drop=True)
df.to_pickle(SOTS_WORKING_PATH)
# prepare_dataset re-reads the working copy from disk, so pass the exact path just written.
grouped_data = RelaxedFDA.prepare_dataset(
    file_path=str(SOTS_WORKING_PATH), mode='r_s', logger=logger
).head(N_EVAL)

D: List[Dict] = PromptTemplate.sot_construct_inter_gsm8k()[:N_DEMOS]
op_system_prompt = RelaxedFDA.get_optimize_system_prompt(task_name)
encoder = RelaxedFDA.get_encoder()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.9/124.9 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.7/210.7 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25h

2025-07-04 01:11:59.990494: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751591520.213195      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751591520.275347      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-04 01:12:15,803 INFO Loaded dataset from /kaggle/working/sots_df.pkl with shape (11003, 7)
2025-07-04 01:12:15,804 INFO In r_s mode.
2025-07-04 01:12:15,874 INFO Dataset preparation completed with 1249 entries.


In [2]:
# Run Relaxed-FDA over the prepared SoT groups. K=hp_k sets the cluster count
# (the log below shows 7 clusters per query for hp_k = 7); per-query reasoning
# logging is turned off via enable_logger_rs=False.
result_pd: pd.DataFrame = RelaxedFDA.evaluate(
    grouped_data, logger,           # prepared dataset + shared logger
    op_system_prompt, encoder, D,   # optimize prompt, RelaxedFDA encoder, prompt templates
    K=hp_k, enable_logger_rs=False,
)

  0%|          | 0/1200 [00:00<?, ?it/s]2025-07-04 01:12:21,919 INFO Cluster 1: |C_0| = 1, P(r_0|do(X)) ≈ 0.11
2025-07-04 01:12:26,058 INFO Cluster 2: |C_1| = 2, P(r_1|do(X)) ≈ 0.22
2025-07-04 01:12:29,160 INFO Cluster 3: |C_2| = 1, P(r_2|do(X)) ≈ 0.11
2025-07-04 01:12:39,228 INFO Cluster 4: |C_3| = 1, P(r_3|do(X)) ≈ 0.11
2025-07-04 01:12:53,981 INFO Cluster 5: |C_4| = 1, P(r_4|do(X)) ≈ 0.11
2025-07-04 01:13:00,652 INFO Cluster 6: |C_5| = 2, P(r_5|do(X)) ≈ 0.22
2025-07-04 01:13:04,982 INFO Cluster 7: |C_6| = 1, P(r_6|do(X)) ≈ 0.11
2025-07-04 01:13:08,239 INFO Aggregated candidate votes (weighted): {'18': 1.0}
2025-07-04 01:13:08,241 INFO Final aggregated answer is 18 and weight 1.0
2025-07-04 01:13:08,243 INFO Overall estimated probability (aggregated): 1.00
  return self.fit(X, sample_weight=sample_weight).labels_
2025-07-04 01:13:08,659 INFO Cluster 1: |C_0| = 1, P(r_0|do(X)) ≈ 0.11
2025-07-04 01:13:11,251 INFO Cluster 2: |C_1| = 1, P(r_1|do(X)) ≈ 0.11
2025-07-04 01:13:13,736 INFO Cl

In [3]:
# Score the Relaxed-FDA predictions; with log=True the metrics are emitted through `logger`.
MetricsHelper.evaluate(
    result_pd,
    logger,
    log=True,
)

2025-07-04 10:35:49,870 INFO Evaluated 1200 examples
2025-07-04 10:35:49,871 INFO Exact Match: 978/1200 = 81.50%
2025-07-04 10:35:49,872 INFO Average F1 Score: 81.64%
2025-07-04 10:35:49,872 INFO Accuracy: 81.50%
