In [1]:
from typing import List, Dict, Tuple, Optional, Union, Any
import pandas as pd
import asyncio

from prompt_template_collection import PromptTemplate
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA
from metrics_collection import MetricsHelper

# --- Experiment configuration ---------------------------------------------
task_name = "FEVER"
hp_m = None  # rows (SoTs) kept per query via head(); original note lists 4, 6, 8, 10, 12. None keeps every row.
hp_k = 4  # forwarded as K= to the evaluation call; original note lists cluster values 1, 3, 4, 5, 7.
conc = True  # True -> batched async evaluation, False -> synchronous RelaxedFDA.evaluate
logger = Doraemon.get_logger(name=__name__, logfile="relaxed_FDA_on_fever.log")

# --- Load and normalise the SoT table ---------------------------------------
# NOTE(review): read_pickle can execute arbitrary code embedded in the file;
# this assumes the Kaggle input dataset is trusted — confirm.
df = pd.read_pickle('/kaggle/input/building-sot-on-fever/sots_df.pkl')
# Shorten the column names (presumably to what RelaxedFDA expects — verify).
df = df.rename(
    columns={
        'question': 'query',
        'reason': 'r_s',
        'evidence': 'context',
        'ground_truth': 'g_t',
        'temperature': 't_p',
    }
)
if hp_m:
    # Keep only the first hp_m rows of every query group.
    df = df.groupby('query').head(hp_m).reset_index(drop=True)

# Write and read through ONE path so the two cannot drift apart: previously the
# frame was saved to the relative 'sots_df.pkl' but loaded from the absolute
# '/kaggle/working/sots_df.pkl', which only matched when cwd was /kaggle/working.
sots_path = '/kaggle/working/sots_df.pkl'
df.to_pickle(sots_path)

grouped_data = RelaxedFDA.prepare_dataset(file_path=sots_path, mode='r_s', logger=logger)

encoder = RelaxedFDA.get_encoder()

# Prompt material from PromptTemplate (annotated as a list of dicts) and the
# task-specific system prompt; both are passed to the evaluation call.
D: List[Dict] = PromptTemplate.sot_construct_inter_fever(number=5)
op_system_prompt = RelaxedFDA.get_optimize_system_prompt(task_name)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m852.3 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.9/124.9 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.7/210.7 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h

2025-07-25 07:49:48.691400: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753429789.000449      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753429789.089791      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-25 07:50:10,303 INFO Loaded dataset from /kaggle/working/sots_df.pkl with shape (10607, 6)
2025-07-25 07:50:10,304 INFO In r_s mode.
2025-07-25 07:50:10,374 INFO Dataset preparation completed with 1184 entries.


In [2]:
if conc:
    tasks=grouped_data.to_dict(orient='records')

    batch_size = 20
    
    async def process_batches(tasks, *args, **kwargs):
        results = []
        for i in range(0, len(tasks), batch_size):
            batch = tasks[i:i + batch_size]
            result_pd = await RelaxedFDA.async_evaluate(
                batch,
                *args,
                **kwargs
            )
            results.append(result_pd)
            await asyncio.sleep(5)
        # Concatenate all DataFrames into one
        return pd.concat(results, ignore_index=True)
    

    result_pd: pd.DataFrame = await process_batches(
        tasks,
        logger,
        op_system_prompt,
        encoder,
        D,
        enable_logger_rs=False,
        ablation='all',
        K=hp_k,
        T=3,
        max_tokens=300,
        max_concurrent=2
    )

else:    
    result_pd: pd.DataFrame = RelaxedFDA.evaluate(
        grouped_data,
        logger,
        op_system_prompt,
        encoder,
        D,
        enable_logger_rs=False,
        ablation='wo_wt',
        K=hp_k,
        T=2,
        max_tokens=300,
    )

Evaluating rows:   0%|          | 0/20 [00:00<?, ?it/s]2025-07-25 07:50:25,449 INFO Initialized client for provider gpt3
Evaluating rows: 100%|██████████| 20/20 [01:16<00:00,  3.82s/it]
  return self.fit(X, sample_weight=sample_weight).labels_
Evaluating rows:   5%|▌         | 1/20 [00:51<16:19, 51.57s/it]2025-07-25 07:52:30,073 INFO Cluster 4: No items found
Evaluating rows: 100%|██████████| 20/20 [01:13<00:00,  3.65s/it]
Evaluating rows:   0%|          | 0/20 [00:00<?, ?it/s]2025-07-25 07:53:47,994 ERROR Inference failed for index 0, retrying in 2s: Error code: 429 - {'error': {'code': '429', 'message': 'Requests to the ChatCompletions_Create Operation under Azure OpenAI API version 2025-01-01-preview have exceeded token rate limit of your current AIServices S0 pricing tier. Please retry after 1 second. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit. For Free Account customers, upgrade to Pay as you Go here: https://aka.m

In [3]:
# Score the evaluation results; in the recorded run the metrics (Exact Match,
# Average F1 Score, Accuracy) were emitted through `logger` as INFO lines.
MetricsHelper.evaluate(result_pd, logger)

2025-07-25 09:16:49,440 INFO Exact Match: 79.48%
2025-07-25 09:16:49,442 INFO Average F1 Score: 79.48%
2025-07-25 09:16:49,443 INFO Accuracy: 79.48%


In [4]:
# Persist the evaluation results to the current working directory so they can
# be reloaded without repeating the evaluation run.
result_pd.to_pickle('result_pd.pkl')