In [None]:
from typing import List, Dict, Tuple, Optional, Union, Any

import pandas as pd

from prompt_template_collection import PromptTemplate
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA
from metrics_collection import MetricsHelper

# --- Experiment configuration ---------------------------------------------
task_name = "CommonsenseQA"
hp_m = 8  # max SoT rows kept per query (values tried: 4, 6, 8, 10, 12)
hp_k = 4  # cluster setting (values tried: 1, 3, 4, 5, 7)
logger = Doraemon.get_logger(name=task_name, logfile=f"relaxed_fda_on_{task_name}.log")

# Input pickle produced upstream, and the location of the trimmed copy.
# The trimmed copy is written to and read back from the SAME absolute path so
# the round trip does not depend on the kernel's current working directory.
file_path = '/kaggle/input/llama-3-building-sots-on-commonsenseqa/sots_df.pkl'
working_path = '/kaggle/working/sots_df.pkl'

# --- Load, normalise column names, and trim --------------------------------
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
df = pd.read_pickle(file_path)
df = df.rename(
    columns={
        'question': 'query',
        'reason': 'r_s',
        'ground_truth': 'g_t',
        'temperature': 't_p',
    }
)
if hp_m:
    # Keep only the first hp_m rows of every query group (original row order).
    df = df.groupby('query').head(hp_m).reset_index(drop=True)
df.to_pickle(working_path)

grouped_data = RelaxedFDA.prepare_dataset(file_path=working_path, logger=logger)

# --- Prompt / encoder resources consumed by the evaluation cell ------------
D: List[Dict] = PromptTemplate.sot_construct_inter_commonsenseqa()
op_system_prompt = RelaxedFDA.get_optimize_system_prompt(task_name)
encoder = RelaxedFDA.get_encoder()

In [None]:
# Keyword options for the evaluation run, gathered in one place so they are
# easy to tweak; hp_k is forwarded as K.
evaluate_options = {
    'enable_logger_rs': False,
    'ablation': 'all',
    'K': hp_k,
}

# Positional arguments keep their original order.
result_pd: pd.DataFrame = RelaxedFDA.evaluate(
    grouped_data, logger, op_system_prompt, encoder, D, **evaluate_options
)

In [None]:
# Pass the evaluation results and the shared logger to MetricsHelper.evaluate.
# Left as a bare expression: if it is the cell's final statement, any value it
# returns is rendered as the cell output.
MetricsHelper.evaluate(result_pd, logger)