In [1]:
from tensorboard.backend.event_processing import event_accumulator

def tb_get(path, tag, step_cutoff):
    """Return one scalar value logged under ``tag`` in the TensorBoard data at ``path``.

    The entry is selected by its position in the loaded scalar series (not by
    the ``step`` field stored in the event). ``step_cutoff`` is clamped to the
    last available entry, so asking for a position past the end returns the
    final logged value.

    Args:
        path: Location handed to ``EventAccumulator`` (event file or run directory).
        tag: Scalar tag to read, e.g. ``"loss/encoder"``.
        step_cutoff: Zero-based position of the entry to read.

    Returns:
        The ``value`` attribute of the selected scalar event.

    Raises:
        IndexError: If no scalar events are stored under ``tag``.
        Exception: Anything ``EventAccumulator`` raises (for example for an
            unknown tag) propagates unchanged.
    """
    # The size-guidance key for scalar series is event_accumulator.SCALARS
    # ("scalars"). A misspelled key such as "scalar" is ignored, which leaves
    # the default reservoir cap active and can subsample long series, making
    # positional indexing below unreliable. 0 means "keep every event".
    ea = event_accumulator.EventAccumulator(
        path, size_guidance={event_accumulator.SCALARS: 0}
    )
    ea.Reload()
    scalar_events = ea.Scalars(tag)
    if not scalar_events:
        # Fail with a descriptive message instead of an opaque index error.
        raise IndexError(f"no scalar events recorded for tag {tag!r} in {path!r}")

    # Clamp so that a cutoff beyond the end reads the last logged entry.
    target = min(step_cutoff, len(scalar_events) - 1)
    return scalar_events[target].value

def _read_optional_scalar(reader, path, tag, step):
    """Read a scalar that may not have been logged; return 0.0 if it cannot be read."""
    try:
        return reader(path, tag, step)
    except Exception:
        # Best-effort on purpose: a run without this tag contributes a loss of 0.0.
        return 0.0


def check_metrics(tag, folder, key, step, retriever_key="data/retrieved", total=100, reader=None):
    """Aggregate per-cluster TensorBoard scalars for one experiment at one step.

    For every cluster index ``i`` in ``range(total)`` the run located at
    ``f"{folder}/{tag}.cluster_{i}"`` is read. A cluster whose ``key``,
    ``retriever_key`` or ``"data/mean_position"`` scalar cannot be read is
    skipped entirely and its index is recorded in ``skipped_idxs``. The
    ``"loss/encoder"`` and ``"loss/decoder"`` scalars are optional and default
    to 0.0 when unavailable.

    Args:
        tag: Run-name prefix; runs are named ``"{tag}.cluster_{i}"``.
        folder: Directory containing the per-cluster runs.
        key: Scalar tag whose positive value marks an LLM-side success.
        step: Position in the scalar series to read (forwarded to the reader).
        retriever_key: Scalar tag whose positive value marks a retriever-side success.
        total: Number of cluster indices to scan.
        reader: Optional callable ``reader(path, tag, step)`` returning a scalar.
            Defaults to ``tb_get``; useful for supplying a cached reader or a stub.

    Returns:
        dict with:
            ``asr_cnt`` / ``retrieved_cnt``: sums of the positive ``key`` /
                ``retriever_key`` values over readable clusters.
            ``optimize_total`` / ``optimize_idxs``: count and indices of clusters
                with a positive encoder or decoder loss.
            ``llm_loss`` / ``retriever_loss``: per-cluster losses (length ``total``).
            ``llm_success`` / ``llm_fail`` / ``retriever_success`` /
                ``retriever_fail``: cluster indices split by whether the value is > 0.
            ``positions`` / ``mean_position``: mean-position values of readable
                clusters and their average (NaN when no cluster was readable).
            ``skipped_idxs``: indices of clusters whose required scalars could
                not be read.
    """
    if reader is None:
        # Resolved at call time so the default follows the module-level tb_get.
        reader = tb_get

    asr_cnt = 0
    retrieved_cnt = 0
    optimize_idxs = []
    positions = []
    llm_loss = [0.0] * total
    retriever_loss = [0.0] * total
    llm_success, llm_fail = [], []
    retriever_success, retriever_fail = [], []
    skipped_idxs = []

    for i in range(total):
        run_path = f"{folder}/{tag}.cluster_{i}"

        # All three required scalars must be readable, otherwise the cluster is
        # left out of every statistic (and remembered in skipped_idxs).
        try:
            current_llm_success = reader(run_path, key, step)
            current_retriever_success = reader(run_path, retriever_key, step)
            current_position = reader(run_path, "data/mean_position", step)
        except Exception:
            skipped_idxs.append(i)
            continue

        positions.append(current_position)

        if current_llm_success > 0:
            llm_success.append(i)
            asr_cnt += current_llm_success
        else:
            llm_fail.append(i)

        if current_retriever_success > 0:
            retriever_success.append(i)
            retrieved_cnt += current_retriever_success
        else:
            retriever_fail.append(i)

        retriever_loss[i] += _read_optional_scalar(reader, run_path, "loss/encoder", step)
        llm_loss[i] += _read_optional_scalar(reader, run_path, "loss/decoder", step)

        if retriever_loss[i] > 0 or llm_loss[i] > 0:
            optimize_idxs.append(i)

    # Avoid ZeroDivisionError when not a single cluster could be read.
    mean_position = sum(positions) / len(positions) if positions else float("nan")

    return {
        "asr_cnt": asr_cnt,
        "retrieved_cnt": retrieved_cnt,
        "optimize_total": len(optimize_idxs),
        "optimize_idxs": optimize_idxs,
        "llm_loss": llm_loss,
        "retriever_loss": retriever_loss,
        "llm_success": llm_success,
        "retriever_success": retriever_success,
        "llm_fail": llm_fail,
        "retriever_fail": retriever_fail,
        "positions": positions,
        "mean_position": mean_position,
        "skipped_idxs": skipped_idxs,
    }

def result_to_csv(results, key, tag):
    """Format one CSV row: ``tag`` followed by ``result[key]`` for every result.

    Args:
        results: Iterable of mappings (e.g. the dicts returned by ``check_metrics``).
        key: Entry to extract from each mapping.
        tag: Row label placed in the first column.

    Returns:
        A comma-separated string without a trailing newline.
    """
    pieces = [f"{tag}"]
    pieces.extend(f",{entry[key]}" for entry in results)
    return "".join(pieces)

def common_idxs_result_to_csv(common_idx, results, key, tag):
    """Format one CSV row counting how many of ``common_idx`` appear in each result.

    Args:
        common_idx: Set of indices to restrict the count to.
        results: Iterable of mappings whose ``key`` entry is an iterable of indices.
        key: Entry of each mapping holding the indices to intersect.
        tag: Row label; the size of ``common_idx`` is appended as ``"{tag}-{n}"``.

    Returns:
        ``"{tag}-{len(common_idx)}"`` followed by one overlap count per result,
        comma-separated.
    """
    overlap_sizes = (len(set(entry[key]) & common_idx) for entry in results)
    row_label = f"{tag}-{len(common_idx)}"
    return "".join([row_label, *(f",{size}" for size in overlap_sizes)])

In [None]:
# Full Evaluation Example
# For every (llm, retriever, dataset) combination, read the three method
# variants at each entry of CHECK_STEPS and print CSV-style tables.
CHECK_STEPS = [4, 8, 16, 32, 64]
for llm in ["llama", "qwen"]:
    for retriever in ["contriever", "bge"]:
        for dataset in ["msmarco", "nq", "hotpotqa"]:
            # One list per variant; each gets one check_metrics() result per step.
            all_results_baseline, all_results_k, all_results_v2 = [], [], []
            for step in CHECK_STEPS:
                # The series position passed to check_metrics is step // 4.
                result_baseline = check_metrics("rag_base", f"logs/{llm}-{retriever}/{dataset}/base/tb", "data/poisioned", step // 4)
                result_k = check_metrics("rag_k", f"logs/{llm}-{retriever}/{dataset}/liar/tb",  "data/poisioned", step // 4)
                result_v2 = check_metrics("rag_v2", f"logs/{llm}-{retriever}/{dataset}/v2/tb",  "data/poisioned", step // 4)
                all_results_baseline += [result_baseline]
                all_results_k += [result_k]
                all_results_v2 += [result_v2]

            # Row label -> per-step results, in the order the rows are printed.
            labelled_results = (
                ("baseline", all_results_baseline),
                ("k", all_results_k),
                ("v2", all_results_v2),
            )
            step_columns = ",".join(map(str, CHECK_STEPS))

            print(f"{llm}-{retriever}-{dataset}")
            # Clusters flagged in optimize_idxs by all three variants at the first step.
            optimize_idxs = set.intersection(
                *[set(rows[0]['optimize_idxs']) for _label, rows in labelled_results]
            )

            # (metric key, header line) for the plain per-metric tables.
            metric_tables = (
                ("mean_position", "mean_position," + step_columns),
                ("asr_cnt", "asr_cnt," + step_columns),
                ("retrieved_cnt", "retrieved_cnt,"),
            )
            for metric, header in metric_tables:
                print(header)
                for label, rows in labelled_results:
                    print(result_to_csv(rows, metric, label))

            # LLM successes restricted to the shared optimize_idxs set.
            print("optimize_success,")
            for label, rows in labelled_results:
                print(common_idxs_result_to_csv(optimize_idxs, rows, "llm_success", label))
            print("")


In [None]:
# Extra Step Evaluation Example
# Same report as the full evaluation, but for the additional steps listed in
# EXTRA_CHECK_STEPS. Every header in this cell is built from EXTRA_CHECK_STEPS
# so the cell does not depend on CHECK_STEPS having been defined by another
# cell and the header columns match the number of values in each row.
EXTRA_CHECK_STEPS = [96, 128]
for llm in ["llama", "qwen"]:
    for retriever in ["contriever", "bge"]:
        for dataset in ["msmarco", "nq", "hotpotqa"]:
            # One list per variant; each gets one check_metrics() result per step.
            all_results_baseline = []
            all_results_k = []
            all_results_v2 = []
            for step in EXTRA_CHECK_STEPS:
                # The series position passed to check_metrics is step // 4.
                result_baseline = check_metrics("rag_base", f"logs/{llm}-{retriever}/{dataset}/base/tb", "data/poisioned", step // 4)
                result_k = check_metrics("rag_k", f"logs/{llm}-{retriever}/{dataset}/liar/tb",  "data/poisioned", step // 4)
                result_v2 = check_metrics("rag_v2", f"logs/{llm}-{retriever}/{dataset}/v2/tb",  "data/poisioned", step // 4)
                all_results_baseline.append(result_baseline)
                all_results_k.append(result_k)
                all_results_v2.append(result_v2)
            print(f"{llm}-{retriever}-{dataset}")
            # Clusters flagged in optimize_idxs by all three variants at the first extra step.
            optimize_idxs = set(all_results_baseline[0]['optimize_idxs']) & set(all_results_k[0]['optimize_idxs']) & set(all_results_v2[0]['optimize_idxs'])
            # Header uses EXTRA_CHECK_STEPS (the steps actually evaluated here).
            print("mean_position," + ",".join(map(str, EXTRA_CHECK_STEPS)))
            print(result_to_csv(all_results_baseline, "mean_position", "baseline"))
            print(result_to_csv(all_results_k, "mean_position", "k"))
            print(result_to_csv(all_results_v2, "mean_position", "v2"))
            print("asr_cnt," + ",".join(map(str, EXTRA_CHECK_STEPS)))
            print(result_to_csv(all_results_baseline, "asr_cnt", "baseline"))
            print(result_to_csv(all_results_k, "asr_cnt", "k"))
            print(result_to_csv(all_results_v2, "asr_cnt", "v2"))
            print("retrieved_cnt,")
            print(result_to_csv(all_results_baseline, "retrieved_cnt", "baseline"))
            print(result_to_csv(all_results_k, "retrieved_cnt", "k"))
            print(result_to_csv(all_results_v2, "retrieved_cnt", "v2"))
            # LLM successes restricted to the shared optimize_idxs set.
            print("optimize_success,")
            print(common_idxs_result_to_csv(optimize_idxs, all_results_baseline, "llm_success", "baseline"))
            print(common_idxs_result_to_csv(optimize_idxs, all_results_k, "llm_success", "k"))
            print(common_idxs_result_to_csv(optimize_idxs, all_results_v2, "llm_success", "v2"))
            print("")


In [None]:
# Joint Only Evaluation Example
# Reports only the "rag_v2" variant on a single dataset.
# NOTE(review): the log folder below is a literal placeholder that does not
# vary with llm/retriever, so every iteration reads the same location until a
# real path is substituted.
CHECK_STEPS = [4, 8, 16, 32, 64]
dataset = "msmarco"
for llm in ["llama", "qwen"]:
    for retriever in ["contriever", "bge"]:
        # One check_metrics() result per entry of CHECK_STEPS.
        all_results_v2 = []
        for step in CHECK_STEPS:
            # The series position passed to check_metrics is step // 4.
            result_v2 = check_metrics("rag_v2", f"/path/to/tb",  "data/poisioned", step // 4)
            all_results_v2 += [result_v2]

        print(f"{llm}-{retriever}-{dataset}")
        # Computed as in the other evaluation cells; not used by the prints below.
        optimize_idxs = set(all_results_v2[0]['optimize_idxs'])

        step_columns = ",".join(map(str, CHECK_STEPS))
        # (metric key, header line) for each printed table.
        metric_tables = (
            ("mean_position", "mean_position," + step_columns),
            ("asr_cnt", "asr_cnt," + step_columns),
            ("retrieved_cnt", "retrieved_cnt,"),
        )
        for metric, header in metric_tables:
            print(header)
            print(result_to_csv(all_results_v2, metric, "v2"))
        print("")