In [1]:
minimum = 1e-4  # substitute for zero entries in harmonic_mean (avoids 1/0)

def harmonic_mean(numbers, eps=None):
    """Return the harmonic mean ``n / sum(1 / x)`` of ``numbers``.

    Zero entries are replaced by ``eps`` before taking reciprocals, so a zero
    contributes ``1 / eps`` (1e4 with the default) to the denominator and pulls
    the result close to 0 instead of raising ``ZeroDivisionError``.

    Args:
        numbers: Any finite iterable of numbers (list, tuple, generator or a
            1-D NumPy array). It is materialised once, so generators work too.
        eps: Value substituted for zero entries. ``None`` (default) uses the
            module-level ``minimum``, read at call time.

    Returns:
        The harmonic mean as a float.

    Raises:
        ValueError: If ``numbers`` is empty.
    """
    # Materialise first: `not array` is ambiguous for NumPy arrays and
    # generators have no len().
    values = list(numbers)
    if not values:
        raise ValueError("List is empty, cannot compute harmonic mean.")

    zero_substitute = minimum if eps is None else eps
    reciprocal_sum = sum(1 / (x if x != 0 else zero_substitute) for x in values)
    return len(values) / reciprocal_sum

In [2]:
from pathlib import Path
import json
from tqdm import tqdm

# Root directory holding one sub-directory per editing algorithm.
RESULTS_DIR = Path("results")

# Checkpoints (number of edits evaluated) at which metrics are recorded: 1, 2, 4, ..., 1024.
record_ckpt = [1 << exponent for exponent in range(11)]

# metrics[algorithm][metric_name][run] is a list that receives one value per checkpoint.
# Key order is kept as-is ("step" first) in case downstream code relies on it.
metrics = {
    method: {
        key: [[] for _ in range(3)]  # three runs, each with its own independent list
        for key in ("step", "ES", "GS", "LS", "ERS", "ORS", "S")
    }
    for method in ("KE", "KN", "MEND", "ROME", "MEMIT", "WilKE")
}

# Aggregate per-checkpoint editing metrics for every algorithm and run.
#
# For each run directory RESULTS_DIR/<alg>/run_<NNN>, the edit case files are
# loaded in edit order. Then, for every checkpoint `ckpt` in record_ckpt, the
# metrics over the first `ckpt` edits are appended to metrics[alg][key][run]:
#   ES / GS / LS : mean per-case accuracy on rewrite / paraphrase / neighborhood
#                  prompts (case["post"][...]),
#   ERS / ORS    : edit / original retention rates read from
#                  retention_of_edit_<ckpt>.json (NaN if that file is malformed),
#   S            : harmonic mean of the five values above.


def _load_edit_results(run_dir):
    """Load the ``edit_*.json`` case files of one run, ordered by edit index.

    Files that are not valid JSON are reported and skipped; nothing is appended
    in their place.
    NOTE(review): skipping shifts later cases down one position, so a checkpoint
    then covers a different set of edits — confirm this is acceptable.
    """
    # NOTE(review): the edit index is taken from the 4th-from-last "_"-separated
    # field of the file *name*; confirm this matches the naming scheme.
    case_files = sorted(
        run_dir.glob("edit_*.json"),
        key=lambda path: int(path.name.split("_")[-4]),
    )
    loaded = []
    for case_file in tqdm(case_files):
        try:
            with open(case_file, "r") as f:
                loaded.append(json.load(f))
        except json.JSONDecodeError:
            print(f"Could not decode {case_file} due to format error; skipping.")
    return loaded


def _mean_post_accuracy(cases, key):
    """Average over ``cases`` of the fraction of truthy entries in ``case["post"][key]``."""
    return sum(sum(case["post"][key]) / len(case["post"][key]) for case in cases) / len(cases)


def _load_retention_rates(path):
    """Return ``(edit_retention_rate, orig_retention_rate)`` read from ``path``.

    If the file is not valid JSON, the problem is reported and ``(nan, nan)`` is
    returned. The failure then shows up in the aggregated metrics instead of a
    previous checkpoint's values being reused.
    """
    try:
        with open(path, "r") as f:
            retention = json.load(f)
    except json.JSONDecodeError:
        print(f"Could not decode {path} due to format error; skipping.")
        return float("nan"), float("nan")
    return (
        retention["edit_retention"] / retention["edit_length"],
        retention["orig_retention"] / retention["orig_length"],
    )


for alg_name, alg_metrics in metrics.items():
    # The number of runs follows the per-run lists allocated in `metrics`.
    for run_round in range(len(alg_metrics["step"])):
        run_dir = RESULTS_DIR / alg_name / f"run_{run_round:03d}"
        print(f"Current process folder: {run_dir}")
        results = _load_edit_results(run_dir)

        for ckpt in record_ckpt:
            if ckpt > len(results):
                raise ValueError(
                    f"{run_dir} has only {len(results)} readable edit files; "
                    f"cannot evaluate checkpoint {ckpt}."
                )
            first_edits = results[:ckpt]
            ES = _mean_post_accuracy(first_edits, "rewrite_prompts_correct")
            GS = _mean_post_accuracy(first_edits, "paraphrase_prompts_correct")
            LS = _mean_post_accuracy(first_edits, "neighborhood_prompts_correct")
            ERS, ORS = _load_retention_rates(run_dir / f"retention_of_edit_{ckpt}.json")

            for key, value in (
                ("step", ckpt),
                ("ES", ES),
                ("GS", GS),
                ("LS", LS),
                ("ERS", ERS),
                ("ORS", ORS),
            ):
                alg_metrics[key][run_round].append(value)
            alg_metrics["S"][run_round].append(harmonic_mean([ES, GS, LS, ERS, ORS]))

Current proecss folder: results(gpt2-xl)/KE/run_000(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 16888.53it/s]


Current proecss folder: results(gpt2-xl)/KE/run_001(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 13911.07it/s]


Current proecss folder: results(gpt2-xl)/KE/run_002(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 19030.33it/s]


Current proecss folder: results(gpt2-xl)/KN/run_000(seed=9)


100%|██████████| 1028/1028 [00:00<00:00, 18966.31it/s]


Current proecss folder: results(gpt2-xl)/KN/run_001(seed=9)


100%|██████████| 1028/1028 [00:00<00:00, 14743.17it/s]


Current proecss folder: results(gpt2-xl)/KN/run_002(seed=9)


100%|██████████| 1028/1028 [00:00<00:00, 18026.67it/s]


Current proecss folder: results(gpt2-xl)/MEND/run_000(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18831.19it/s]


Current proecss folder: results(gpt2-xl)/MEND/run_001(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18839.94it/s]


Current proecss folder: results(gpt2-xl)/MEND/run_002(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 13060.73it/s]


Current proecss folder: results(gpt2-xl)/ROME/run_000(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18957.08it/s]


Current proecss folder: results(gpt2-xl)/ROME/run_001(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18960.00it/s]


Current proecss folder: results(gpt2-xl)/ROME/run_002(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 14868.74it/s]


Current proecss folder: results(gpt2-xl)/MEMIT/run_000(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 19242.51it/s]


Current proecss folder: results(gpt2-xl)/MEMIT/run_001(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 19012.83it/s]


Current proecss folder: results(gpt2-xl)/MEMIT/run_002(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18994.18it/s]


Current proecss folder: results(gpt2-xl)/WilKE/run_000(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 14570.86it/s]


Current proecss folder: results(gpt2-xl)/WilKE/run_001(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18487.76it/s]


Current proecss folder: results(gpt2-xl)/WilKE/run_002(seed=9)


100%|██████████| 1025/1025 [00:00<00:00, 18537.26it/s]


In [3]:
import numpy as np

def calculate_average_and_std(metric, ddof=0):
    """Return the element-wise mean and standard deviation across runs.

    Args:
        metric: Array-like of shape (n_runs, n_checkpoints), e.g. a list of
            equal-length per-run lists. Reduction is over axis 0 (the runs).
        ddof: Delta degrees of freedom passed to ``np.std``. The default 0 gives
            the population standard deviation; use 1 for the sample standard
            deviation (often preferred with only a handful of runs).

    Returns:
        Tuple ``(mean, std)`` of 1-D float arrays, one entry per checkpoint.
    """
    # Converting up front makes ragged (unequal-length) run lists fail loudly here.
    values = np.asarray(metric, dtype=float)
    return np.mean(values, axis=0), np.std(values, axis=0, ddof=ddof)

In [4]:
# Print, for each algorithm, the across-run mean and standard deviation of every
# metric at each checkpoint. "step" only holds the checkpoint sizes, so it is
# skipped by name rather than by its position in the dict.
for alg_name, alg_metrics in metrics.items():
    print(f"=============> For Method {alg_name} <=============")
    for metric, per_run_values in alg_metrics.items():
        if metric == "step":
            continue
        print(f"Metric {metric}: {calculate_average_and_std(per_run_values)}")

Metric ES: (array([0.        , 0.5       , 0.5       , 0.25      , 0.125     ,
       0.0625    , 0.03125   , 0.015625  , 0.0078125 , 0.00390625,
       0.00195312]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
Metric GS: (array([0.        , 0.25      , 0.375     , 0.1875    , 0.09375   ,
       0.046875  , 0.0234375 , 0.01171875, 0.00585938, 0.00292969,
       0.00146484]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
Metric LS: (array([0.4       , 0.2       , 0.1       , 0.0875    , 0.04375   ,
       0.021875  , 0.0109375 , 0.00546875, 0.00273437, 0.00136719,
       0.00068359]), array([5.55111512e-17, 2.77555756e-17, 1.38777878e-17, 1.38777878e-17,
       6.93889390e-18, 3.46944695e-18, 1.73472348e-18, 8.67361738e-19,
       4.33680869e-19, 2.16840434e-19, 1.08420217e-19]))
Metric ERS: (array([1. , 0.5, 0.5, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
Metric ORS: (array([0.91699219, 0.29980469, 0.14160156, 0.000976