In [1]:
import torch
import os
import json

In [17]:
def load(model_name):
    """Load the two precomputed JSON result files for one model.

    Reads ``out/<model_name>/diagonal.json`` and
    ``out/<model_name>/subspace_IH_PTH_K40_largest.json`` relative to the
    current working directory.

    Args:
        model_name: Name of the sub-directory of ``out`` that holds the
            model's result files.

    Returns:
        A ``(diagonal, subspace)`` tuple with the parsed contents of the two
        JSON files, in that order.

    Raises:
        FileNotFoundError: If either file does not exist.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    # Named `model_dir` rather than `dir` so the builtin is not shadowed.
    model_dir = os.path.join("out", model_name)

    # Context managers make sure each handle is closed as soon as it is
    # parsed, instead of relying on garbage collection.
    with open(os.path.join(model_dir, "diagonal.json"), "r") as diagonal_file:
        diagonal = json.load(diagonal_file)

    with open(
        os.path.join(model_dir, "subspace_IH_PTH_K40_largest.json"), "r"
    ) as subspace_file:
        subspace = json.load(subspace_file)

    return diagonal, subspace


def _pick_unique_heads(entries, ih_key, pth_key, head_count_max, keep=None):
    """Collect up to ``head_count_max`` distinct values per key, in first-seen order.

    Args:
        entries: Iterable of dicts, walked in order.
        ih_key: Key whose values fill the first returned list.
        pth_key: Key whose values fill the second returned list.
        head_count_max: Maximum length of each returned list.
        keep: Optional predicate; entries for which it returns a falsy value
            are skipped. ``None`` keeps every entry.

    Returns:
        ``(ih, pth)``: two lists without duplicates. Either may be shorter
        than ``head_count_max`` if ``entries`` is exhausted first.
    """
    ih, pth = [], []
    for entry in entries:
        # Stop as soon as both lists are full; nothing more can be added.
        if len(ih) >= head_count_max and len(pth) >= head_count_max:
            break
        if keep is not None and not keep(entry):
            continue
        if len(ih) < head_count_max and entry[ih_key] not in ih:
            ih.append(entry[ih_key])
        if len(pth) < head_count_max and entry[pth_key] not in pth:
            pth.append(entry[pth_key])
    return ih, pth


def select(model_name, score_cutoff=2, head_count_max=10):
    """Pick head subsets for one model and save them with ``torch.save``.

    Two independent selections are made from the data returned by
    ``load(model_name)``:

    * DIAGONAL: walks ``diagonal`` in order, ignoring entries whose
      ``"score"`` is below ``score_cutoff``, and collects distinct ``"IH"``
      and ``"PTH"`` values.
    * SUBSPACE: walks ``subspace`` in order starting at index 1 and collects
      distinct ``"LH0"`` values (saved as IH) and ``"LH1"`` values (saved
      as PTH).

    Each selection is printed and written to ``checkpoints/<model_name>/`` as
    ``IH_subset_<kind>.pt`` and ``PTH_subset_<kind>.pt``. In the notebook's
    printed output the selected values are two-element integer lists;
    presumably ``[layer, head]`` pairs -- TODO confirm.

    Args:
        model_name: Model identifier; used for the input directory (via
            ``load``) and for the output directory under ``checkpoints``.
        score_cutoff: Minimum ``"score"`` a diagonal entry needs to be
            considered. Not applied to the subspace selection.
        head_count_max: Maximum number of entries in each saved list. A list
            is shorter when the data does not contain enough distinct values.

    Returns:
        None. Results are printed and saved to disk.
    """
    diagonal, subspace = load(model_name)

    save_dir = f"checkpoints/{model_name}"

    # select by diagonal
    # `not score < cutoff` (rather than `score >= cutoff`) mirrors the skip
    # condition exactly, including for non-comparable values such as NaN.
    ih, pth = _pick_unique_heads(
        diagonal,
        "IH",
        "PTH",
        head_count_max,
        keep=lambda d: not d["score"] < score_cutoff,
    )

    print(model_name)
    print(f"DIAGONAL\nIH: {ih}\nPTH: {pth}\n" + "-" * 50)
    # Create the output directory up front so torch.save cannot fail on a
    # model that has no checkpoint directory yet.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(ih, f"{save_dir}/IH_subset_diagonal.pt")
    torch.save(pth, f"{save_dir}/PTH_subset_diagonal.pt")

    # select by subspace
    # NOTE(review): iteration starts at index 1, so subspace[0] is never
    # considered. Kept as-is; confirm that skipping the first entry is
    # intentional.
    # Slicing bounds the walk: if the data runs out before both lists are
    # full, the selection stops with what it has (same as the diagonal
    # selection) instead of raising IndexError.
    ih, pth = _pick_unique_heads(subspace[1:], "LH0", "LH1", head_count_max)

    print(f"SUBSPACE\nIH: {ih}\nPTH: {pth}\n" + "-" * 50)
    torch.save(ih, f"{save_dir}/IH_subset_subspace.pt")
    torch.save(pth, f"{save_dir}/PTH_subset_subspace.pt")

In [18]:
# Run the head selection for every model, using a score cutoff of 1.5
# instead of select()'s default of 2.
MODEL_NAMES = (
    "gpt2",
    "gpt2-xl",
    "llama2-7b",
    "gemma-7b",
    "falcon-7b",
    "mistral-7b",
    "olmo-7b",
)

for model_name in MODEL_NAMES:
    select(model_name, score_cutoff=1.5)

gpt2
DIAGONAL
IH: [[5, 1], [6, 9], [5, 5], [7, 2], [7, 10], [5, 8], [5, 0], [8, 1], [7, 11], [9, 6]]
PTH: [[4, 11], [5, 6], [8, 7], [6, 8], [6, 0], [9, 3], [3, 3], [7, 0], [5, 2], [1, 0]]
--------------------------------------------------
SUBSPACE
IH: [[5, 1], [5, 0], [7, 11], [7, 1], [6, 9], [8, 1], [5, 8], [7, 10], [7, 2], [7, 7]]
PTH: [[4, 11], [6, 0], [6, 8], [5, 6], [4, 6], [4, 3], [6, 5], [5, 2], [7, 0], [10, 9]]
--------------------------------------------------
gpt2-xl
DIAGONAL
IH: [[17, 6], [16, 21], [16, 3], [13, 0], [18, 0], [17, 14], [20, 0], [19, 18], [22, 20], [21, 3]]
PTH: [[15, 19], [12, 21], [13, 20], [14, 12], [16, 5], [11, 2], [9, 7], [14, 20], [10, 15], [13, 12]]
--------------------------------------------------
SUBSPACE
IH: [[17, 6], [16, 21], [13, 0], [17, 1], [21, 3], [20, 0], [25, 18], [21, 20], [16, 3], [28, 11]]
PTH: [[15, 19], [13, 20], [16, 5], [12, 21], [14, 12], [14, 20], [17, 12], [6, 1], [9, 7], [11, 2]]
-------------------------------------------------