In [1]:
import json
import os
import numpy as np
import re
import matplotlib.pyplot as plt

# Results root; each model's result JSON files live in out_dir/<model_name>/.
out_dir = "out"

In [2]:
def result_to_acc(result):
    """Collapse per-sample metrics into one mean value per experiment.

    Parameters
    ----------
    result : dict
        Mapping ``exp_id -> {"prob": sequence, "err": sequence, ...}`` as
        stored in a result JSON file. Keys other than ``"prob"``/``"err"``
        are ignored.

    Returns
    -------
    dict
        Mapping ``exp_id -> {"prob": np.mean(prob), "err": np.mean(err)}``,
        preserving the key order of ``result``.
    """
    return {
        exp_id: {
            "prob": np.mean(exp_res["prob"]),
            "err": np.mean(exp_res["err"]),
        }
        for exp_id, exp_res in result.items()
    }


# First ``_<digits>_`` group in a result file name. Among files sharing a
# prefix/method, the one with the largest number is used (presumably the run
# with the most samples/steps -- TODO confirm against the script writing them).
_RUN_NUMBER_PATTERN = re.compile(r'_(\d+)_')


def _latest_result_file(files, prefix, method):
    """Return the name in ``files`` matching ``prefix``/``method`` with the largest run number.

    Raises
    ------
    FileNotFoundError
        If no name starts with ``prefix``, ends with ``f"{method}.json"`` and
        carries a ``_<digits>_`` run number.
    """
    suffix = f"{method}.json"
    candidates = []
    for name in files:
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        match = _RUN_NUMBER_PATTERN.search(name)
        if match is None:
            # No run number to rank by; skip rather than crash on ``None.group``.
            continue
        candidates.append((int(match.group(1)), name))
    if not candidates:
        raise FileNotFoundError(
            f"no '{prefix}*{suffix}' result file with a run number found"
        )
    # ``max`` returns the first of equal keys, so ties resolve to the
    # earliest file in listing order.
    return max(candidates, key=lambda item: item[0])[1]


def _load_acc(path):
    """Load one result JSON file and summarise it with ``result_to_acc``."""
    # ``with`` guarantees the file handle is closed.
    with open(path, "r") as fh:
        return result_to_acc(json.load(fh))


def _fetch_latest_accs(model_name, method, base_dir, prefixes):
    """Summarise the latest result file for every ``key -> filename prefix`` pair."""
    root = out_dir if base_dir is None else base_dir
    path_to_dir = os.path.join(root, model_name)
    files = os.listdir(path_to_dir)
    return {
        key: _load_acc(os.path.join(path_to_dir, _latest_result_file(files, prefix, method)))
        for key, prefix in prefixes.items()
    }


def fetch_shuffle_acc(model_name, method="diagonal", base_dir=None):
    """Load the latest QK/OV shuffle results for ``model_name``.

    Parameters
    ----------
    model_name : str
        Sub-directory of the results root holding this model's JSON files.
    method : str, default "diagonal"
        Text the file names must end with, immediately before ``.json``.
    base_dir : str or None, default None
        Results root; ``None`` uses the notebook-level ``out_dir``.

    Returns
    -------
    dict
        ``{"qk": ..., "ov": ...}``, each the ``result_to_acc`` summary of the
        ``shuffle_result_QK`` / ``shuffle_result_OV`` file with the largest
        run number.

    Raises
    ------
    FileNotFoundError
        If the model directory or a required result file is missing.
    """
    return _fetch_latest_accs(model_name, method, base_dir, {
        "qk": "shuffle_result_QK",
        "ov": "shuffle_result_OV",
    })


def fetch_project_acc(model_name, method="diagonal", base_dir=None):
    """Load the latest QK/OV projection results (``proj_True`` / ``proj_False``) for ``model_name``.

    Parameters
    ----------
    model_name : str
        Sub-directory of the results root holding this model's JSON files.
    method : str, default "diagonal"
        Text the file names must end with, immediately before ``.json``.
    base_dir : str or None, default None
        Results root; ``None`` uses the notebook-level ``out_dir``.

    Returns
    -------
    dict
        ``{"qk_true", "qk_false", "ov_true", "ov_false"}`` -> ``result_to_acc``
        summary of the matching ``proj_{QK,OV}_proj_{True,False}`` file with
        the largest run number.

    Raises
    ------
    FileNotFoundError
        If the model directory or a required result file is missing.
    """
    return _fetch_latest_accs(model_name, method, base_dir, {
        "qk_true": "proj_QK_proj_True",
        "qk_false": "proj_QK_proj_False",
        "ov_true": "proj_OV_proj_True",
        "ov_false": "proj_OV_proj_False",
    })


In [3]:
# Summarise the latest shuffle and projection results for one model; both
# fetchers read out_dir/<model_name>/ with their default "diagonal" method.
model_name = "gpt2-xl"
shuffle, project = fetch_shuffle_acc(model_name), fetch_project_acc(model_name)

shuffle

{'qk': {'original': {'prob': 0.9068408282501417, 'err': 0.025333333333333333},
  'random baseline 1': {'prob': 0.1843292725704617,
   'err': 0.37777777777777777},
  'random baseline 2': {'prob': 0.3045043077655994, 'err': 0.2702222222222222},
  'random baseline 3': {'prob': 0.27058592548950267,
   'err': 0.2653333333333333},
  'shuffle 1': {'prob': 0.9004752803680995, 'err': 0.025333333333333333},
  'shuffle 2': {'prob': 0.8798525613874286, 'err': 0.028444444444444446},
  'shuffle 3': {'prob': 0.8707555131726133, 'err': 0.029777777777777778}},
 'ov': {'original': {'prob': 0.9056114201986378, 'err': 0.018666666666666668},
  'random baseline 1': {'prob': 0.6447968763294653, 'err': 0.1},
  'random baseline 2': {'prob': 0.6826967181023396,
   'err': 0.07866666666666666},
  'random baseline 3': {'prob': 0.6430495136492055,
   'err': 0.10577777777777778},
  'shuffle 1': {'prob': 0.9032856373908493, 'err': 0.01911111111111111},
  'shuffle 2': {'prob': 0.9062313407344259, 'err': 0.017333333333

In [64]:
# Display the projection summary (keys: qk_true / qk_false / ov_true / ov_false).
project

{'qk_true': {'0': {'prob': 0.9068408282501417, 'err': 0.025333333333333333},
  '10': {'prob': 0.6923194868956376, 'err': 0.08},
  '20': {'prob': 0.5486441263000108, 'err': 0.1817777777777778},
  '30': {'prob': 0.33133813925013783, 'err': 0.36666666666666664},
  '40': {'prob': 0.17304037297671904, 'err': 0.576},
  '50': {'prob': 0.06761172218498018, 'err': 0.7644444444444445},
  '60': {'prob': 0.023969261593632597, 'err': 0.8951111111111111},
  '70': {'prob': 0.011312580076886933, 'err': 0.9475555555555556},
  '80': {'prob': 0.0092375478582913, 'err': 0.9564444444444444},
  '90': {'prob': 0.008500026670178393, 'err': 0.9697777777777777},
  '100': {'prob': 0.008057461318671633, 'err': 0.9666666666666667},
  '110': {'prob': 0.008138159123751413, 'err': 0.9635555555555556},
  '120': {'prob': 0.006689199761588347, 'err': 0.9702222222222222},
  '130': {'prob': 0.006096469190708211, 'err': 0.9737777777777777},
  '140': {'prob': 0.0053475987185499195, 'err': 0.9746666666666667},
  '150': {'pro