In [7]:
import os
import yaml
from itertools import product
import math

In [8]:
def map_vars_to_config(batch_size=0, k_input=0, temperature = 1, n_rescores = 1):

    max_pipeline_iterations = math.ceil(k_input / batch_size) * n_rescores

    var_config_mapping = {
        'data': {
            'dl19_bm25': {'data': {
                'dataloader_class': 'PyseriniLoader',
                'index': 'msmarco-v1-passage',
                'run_path': 'data/dl19-passage/run.msmarco-v1-passage.bm25-default.dl19_sorted.txt',
                'topics': 'dl19-passage'
            }},
            'trec-covid_bm25': {'data': {
                'dataloader_class': 'PyseriniLoader',
                'index': 'beir-v1.0.0-trec-covid.flat',
                'run_path': 'data/beir-v1.0.0-trec-covid-test/run.bm25.trec-covid_sorted.txt',
                'topics': 'beir-v1.0.0-trec-covid-test'
            }},
            'dl19_splade': {'data': {
                'dataloader_class': 'PyseriniLoader',
                'index': 'msmarco-v1-passage',
                'run_path': 'data/dl19-passage/run.msmarco-v1-passage.splade-pp-ed-pytorch.dl19.txt',
                'topics': 'dl19-passage'
            }},
            'trec-covid_splade': {'data': {
                'dataloader_class': 'PyseriniLoader',
                'index': 'beir-v1.0.0-trec-covid.flat',
                'run_path': 'data/beir-v1.0.0-trec-covid-test/run.beir.splade-pp-ed.trec-covid.txt',
                'topics': 'beir-v1.0.0-trec-covid-test'
            }},
            'pluk' : {'data': {
                'dataloader_class': 'PLLoader',
                'run_path': 'data/pl_test_set_2/run.txt',
                'queries_path': 'data/pl_test_set_2/queries.tsv',
                'passages_path': 'data/pl_test_set_2/passages.json'
            }},
            'toy_city_weather'  : {'data': {
                'dataloader_class': 'TestLoader',
                'run_path': 'data/toy_city_weather/run.txt',
                'queries_path': 'data/toy_city_weather/queries.json',
                'passages_path': 'data/toy_city_weather/passages.json'
                }}
        },
        'strategy': {
            'pw': {
                'agent': {
                    'policy_steps': [
                        {'component': 'AgentLogic', 'method': 'get_next_batch'},
                        {'component': 'Prompter', 'method': 'pw_rerank'}
                    ],
                    'postprocessing': {'component': 'AgentLogic', 'method': 'pw_postprocess'}
                },
                'templates': {'pw_rerank': 'pw_rerank.jinja2'}
            },
            'lw_simple': {
                'agent': {
                    'policy_steps': [
                        {'component': 'AgentLogic', 'method': 'get_next_batch'},
                        {'component': 'Prompter', 'method': 'lw_rerank'}
                    ],
                    'postprocessing': {'component': 'AgentLogic', 'method': 'lw_simple_postprocess'}
                },
                'templates': {'lw_rerank': 'lw_rerank.jinja2'},
                'rerank': {'lw_padding': True}
            },
            'lw_agg_kemeny_young': {
                'agent': {
                    'policy_steps': [
                        {'component': 'AgentLogic', 'method': 'get_next_batch'},
                        {'component': 'Prompter', 'method': 'lw_rerank'}
                    ],
                    'postprocessing': {'component': 'AgentLogic', 'method': 'lw_rank_agg_postprocess'}
                },
                'templates': {'lw_rerank': 'lw_rerank.jinja2'},
                'rerank': {'lw_padding': True,
                          'lw_rank_agg_func': 'kemeny_young'},
            }
        },
        'label_instructions': {
            'RAG_eval_zero_ten_instr': {'templates': {'label_macro_name': 'RAG_eval_zero_ten_instr'}, 'n_labels': 11},
            '0_3': {'templates': {'label_macro_name': 'zero_three_instr'}, 'n_labels': 4},
            'umb': {'templates': {'label_macro_name': 'umbrella_instr'}, 'n_labels': 4}
        },
        'llm': {
            '4o': {'llm': {
                'model_class': 'LLMFactoryLLM',
                'model_name': 'gpt-4o',
                'num_retries': 3
            }},
            'sonnet': {'llm': {
                'model_class': 'ClaudeLLM',
                'model_name': 'anthropic.claude-3-sonnet-20240229-v1:0',
                'num_retries': 3
            }},
            'nova-pro': {'llm': {
                'model_class': 'NovaLLM',
                'model_name': 'amazon.nova-pro-v1:0',
                'num_retries': 3
            }},
            'flash-1.5': {'llm': {
                'model_class': 'GeminiLLM',
                'model_name': 'gemini-1.5-flash',
                'num_retries': 20
            }}
        },
        'temperature' : {
            temperature: {
                            'llm': {'temperature': temperature},
            }
        },
        'batch_size': {
            batch_size: {
                        'rerank': {'batch_size': batch_size},
                        'agent': {'max_pipeline_iterations': max_pipeline_iterations}
                         },
        },
        'batching' : {
            'seq': {'agent': {'preprocessing': {'component': 'AgentLogic', 'method': 'preprocess_batch_sequential'}},
                           'rerank': {'n_rescores': n_rescores}},
            'intra': {'agent': {'preprocessing': {'component': 'AgentLogic', 'method': 'preprocess_batch_sequential_in_batch_shuffle'}},
                                            'rerank': {'n_rescores': n_rescores}},
            'inter': {'agent': {'preprocessing': {'component': 'AgentLogic', 'method': 'preprocess_batch_inter_batch_shuffle'}},
                                    'rerank': {'n_rescores': n_rescores}}                                
        },
        'score_agg' : {
            'amean' : {'agent': {'score_agg_func': 'amean'}}
        },
        'k_input': {
            k_input: {'data': {'k_input': k_input}}
        }
    }
    return var_config_mapping

In [9]:
def deep_update(original, update):
    for key, value in update.items():
        if isinstance(value, dict):
            original[key] = deep_update(original.get(key, {}), value)
        else:
            original[key] = value
    return original

def save_config(config, directory, experiment_name, pw=False):
    full_path = os.path.join(directory, experiment_name)
    os.makedirs(full_path, exist_ok=True)
    with open(os.path.join(full_path, "config.yaml"), 'w') as f:
        yaml.dump(config, f)

    measures = []
    for i in range(1, 176):
        measures.append(f"ndcg_cut_{i}")
    for i in range(1, 176):
        measures.append(f"map_cut_{i}")
    for i in range(1, 176):
        measures.append(f"P_{i}")
    for i in range(1, 176):
        measures.append(f"recall_{i}")

    eval_params = {
        'measures': measures,
        'qrels_path' : os.path.join(os.path.dirname(config['data']['run_path']),'qrels.txt'),
        'logging': {'level': 'DEBUG'},
        'min_pw_rel_score': 3,
        'run_pw_anal': pw
    }

    with open(os.path.join(full_path, "eval_config.yaml"), 'w') as f:
        yaml.dump(eval_params, f)

In [27]:
EXP_DIR = '../trials/main_exp/lw_shuf_1_score/nova_90_cand'
if not os.path.exists(EXP_DIR):
    os.makedirs(EXP_DIR)

BASE_CONFIG_PATH = '../configs/base_config.yaml'

In [28]:
#set the parmeters
#options/explanations of params:
'''
param_grid = {
    'data': ['dl19_splade','dl19_bm25','trec-covid_bm25','trec-covid_splade','pluk'], #which dataset
    'strategy': ['pw','lw_simple', 'lw_agg_kemeny_young'], #lw_simple is a single lw rerank operation, 'lw_agg_kemeny_young' agregates multiple ranked lists using a linear program
    'llm': ['sonnet',flash-1.5,'4o','nova-pro'], #which llm
    'temperature': [1],
    'k_input': [30], #number of passages in initial ranked list
    'batch_size': [30], #number of passages in each batch
    'batching': ['seq', 'intra','inter'], #see readme
    'n_rescores': [1], #how times each passage is scored
    'score_agg': ['mean'], #how scores are aggregated, e.g. amean: arithmetic mean
    #'label_instructions':  ['0_3', 'umb', 'RAG_eval_zero_ten_instr'] # for pw only, which prompt is used? 0_3 is similar to the PL prompt, RAG eval is similar to likert paper, umbrella is jimmy lins bing reproduction
}
'''

# Define the parameters
param_grid = {
    'data': ['pluk','dl19_bm25', 'trec-covid_bm25'],
    'llm': ['nova-pro'], #'sonnet','nova-pro'],
    'temperature': [1], #['flash-1.5'], #['4o'], #, 
    'batching': ['intra'], #'seq','intra','inter'
    'k_input': [90],
    'strategy': ['lw_simple'],
    'n_rescores': [1],
    'batch_size': [90]#,
    #'label_instructions':  ['umb'],
    #'score_agg': ['amean']
}

#must be careful when defining n_rescores as it directly effects the number of pipeline iterations

if 'flash-1.5' in param_grid['llm'] and 'pluk' in param_grid['data']:
    param_grid = None
    print('ERROR: Gemini models are currently not compatible with the PL test set since public API keys are currently used which is not secure')

# Run cell below to make experiment batch

In [29]:
# Load the base config file
with open(BASE_CONFIG_PATH, 'r') as f:
    base_config = yaml.safe_load(f)

#for file nameing, don't use these params:
PARAM_NAMES_TO_OMMIT = {'n_rescores', 'label_instructions', 'label_instructions', 'temperature', 'batching', 'strategy'} #

# Generate and save config files for each combination
for param_values in product(*param_grid.values()):
    param_values_dict = dict(zip(param_grid.keys(), param_values))
    experiment_name = '_'.join([f"{value}" for param, value in param_values_dict.items() if param not in PARAM_NAMES_TO_OMMIT])
    updated_config = yaml.safe_load(yaml.dump(base_config))  # deep copy

    # Apply updates to the config based on the current parameter values
    var_config_mapping = map_vars_to_config(
        batch_size=param_values_dict.get('batch_size', 0),
        k_input=param_values_dict.get('k_input', 0),
        temperature = param_values_dict.get('temperature', 1),
        n_rescores = param_values_dict.get('n_rescores', 1)
    )
    for param, value in param_values_dict.items():
        if param in var_config_mapping and value in var_config_mapping[param]:
            deep_update(updated_config, var_config_mapping[param][value])
        else:
            print(f"No mapping found for parameter '{param}' with value '{value}'")

    #check if pw strategy for positional confusion analysis (not supported for lw)
    if param_values_dict.get('strategy') == 'pw':
        pw = True
    else:
        pw = False

    save_config(updated_config, EXP_DIR, experiment_name, pw = pw)


No mapping found for parameter 'n_rescores' with value '1'
No mapping found for parameter 'n_rescores' with value '1'
No mapping found for parameter 'n_rescores' with value '1'
