In [None]:
from typing import List, Dict, Tuple, Optional, Union, Any
import pandas as pd
import asyncio

from prompt_template_collection import PromptTemplate
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA
from metrics_collection import MetricsHelper

# --- Experiment configuration -------------------------------------------------
task_name = "FEVER"
# Rows kept per query (head of each 'query' group); None or 0 keeps every row.
# Values noted for the sweep: 4, 6, 8, 10, 12 SoTs.
hp_m = None
# Forwarded to RelaxedFDA as `K`; values noted for the sweep ("cluster"): 1, 3, 4, 5, 7.
hp_k = 4
# True -> batched RelaxedFDA.async_evaluate (next cell); False -> RelaxedFDA.evaluate.
conc = True

# Kaggle input dataset (presumably the output of the "building-sot-on-fever" notebook).
SOTS_INPUT_PATH = '/kaggle/input/building-sot-on-fever/sots_df.pkl'
# Working copy handed to RelaxedFDA.prepare_dataset. A single constant is used for
# both the write and the read so they cannot diverge when the kernel's cwd is not
# /kaggle/working.
SOTS_WORK_PATH = '/kaggle/working/sots_df.pkl'
# Upstream column names -> short names used by this notebook.
COLUMN_RENAMES = {
    'question': 'query',
    'reason': 'r_s',
    'evidence': 'context',
    'ground_truth': 'g_t',
    'temperature': 't_p',
}

logger = Doraemon.get_logger(name=__name__, logfile="relaxed_FDA_on_fever.log")

# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted inputs.
df = pd.read_pickle(SOTS_INPUT_PATH).rename(columns=COLUMN_RENAMES)
if hp_m:
    # Keep only the first hp_m rows of every query group.
    df = df.groupby('query').head(hp_m).reset_index(drop=True)
df.to_pickle(SOTS_WORK_PATH)

grouped_data = RelaxedFDA.prepare_dataset(file_path=SOTS_WORK_PATH, mode='r_s', logger=logger)

encoder = RelaxedFDA.get_encoder()

D: List[Dict] = PromptTemplate.sot_construct_inter_fever(number=5)
op_system_prompt = RelaxedFDA.get_optimize_system_prompt(task_name)

In [None]:
if conc:
    tasks=grouped_data.to_dict(orient='records')

    batch_size = 20
    
    async def process_batches(tasks, *args, **kwargs):
        results = []
        for i in range(0, len(tasks), batch_size):
            batch = tasks[i:i + batch_size]
            result_pd = await RelaxedFDA.async_evaluate(
                batch,
                *args,
                **kwargs
            )
            results.append(result_pd)
            await asyncio.sleep(5)
        # Concatenate all DataFrames into one
        return pd.concat(results, ignore_index=True)
    

    result_pd: pd.DataFrame = await process_batches(
        tasks,
        logger,
        op_system_prompt,
        encoder,
        D,
        enable_logger_rs=False,
        ablation='all',
        K=hp_k,
        T=3,
        max_tokens=300,
        max_concurrent=2
    )

else:    
    result_pd: pd.DataFrame = RelaxedFDA.evaluate(
        grouped_data,
        logger,
        op_system_prompt,
        encoder,
        D,
        enable_logger_rs=False,
        ablation='wo_wt',
        K=hp_k,
        T=2,
        max_tokens=300,
    )

In [None]:
# Score the evaluation results; whatever MetricsHelper.evaluate returns is shown
# as this cell's output (nothing is shown if it returns None).
metrics_summary = MetricsHelper.evaluate(result_pd, logger)
metrics_summary