In [1]:
import torch
import os
import json
import numpy as np
from scipy.optimize import linear_sum_assignment

In [2]:
def load_diagonal(model_name):
    """Load the diagonal scores saved for ``model_name``.

    Reads ``out/<model_name>/diagonal.json`` (relative to the current working
    directory) and returns the parsed JSON content unchanged.

    Args:
        model_name: Name of the sub-directory of ``out`` holding the results.

    Returns:
        The object decoded from ``diagonal.json``. ``select`` consumes it as a
        sequence of records with ``"score"``, ``"IH"`` and ``"PTH"`` keys.

    Raises:
        FileNotFoundError: If ``diagonal.json`` does not exist for the model.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Named ``model_dir`` rather than ``dir`` to avoid shadowing the builtin.
    model_dir = os.path.join("out", model_name)

    # The context manager guarantees the handle is closed even when parsing
    # fails; the previous bare ``open`` leaked the descriptor.
    with open(os.path.join(model_dir, "diagonal.json"), "r") as f:
        diagonal = json.load(f)

    # Subspace results are currently not loaded (kept for reference):
    # subspace = json.load(
    #     open(os.path.join(model_dir, "subspace_IH_PTH_K40_largest.json"), "r")
    # )

    return diagonal  # , subspace


def find_most_similar_group(M, max_group_size):
    """Return the strongest pairs of a maximum-weight one-to-one matching.

    Args:
        M: 2-D similarity matrix supporting unary negation and ``M[i, j]``
            indexing.
        max_group_size: Upper bound on the number of pairs returned.

    Returns:
        A list of ``(row, col, similarity)`` tuples taken from the optimal
        assignment, ordered from highest to lowest similarity and truncated
        to ``max_group_size`` entries.
    """
    # linear_sum_assignment minimises total cost, so negating the matrix
    # turns the problem into a maximum-similarity assignment.
    rows, cols = linear_sum_assignment(-M)

    # Attach the similarity to each matched pair and rank best-first.
    ranked = sorted(
        ((r, c, M[r, c]) for r, c in zip(rows, cols)),
        key=lambda triple: triple[2],
        reverse=True,
    )

    return ranked[:max_group_size]

In [3]:
def select(model_name, diagonal_cutoff=2, subspace_multipler_cutoff=1.5):
    """Select heads by diagonal score and save them under ``checkpoints/``.

    Every record returned by ``load_diagonal(model_name)`` whose ``"score"``
    is at least ``diagonal_cutoff`` contributes its ``"IH"`` and ``"PTH"``
    entries. Each entry is kept once, in order of first appearance. The two
    resulting lists are printed and written with ``torch.save`` to
    ``checkpoints/<model_name>/IH_diagonal.pt`` and
    ``checkpoints/<model_name>/PTH_diagonal.pt``.

    Args:
        model_name: Model whose ``out/<model_name>/diagonal.json`` is read.
        diagonal_cutoff: Records scoring strictly below this are skipped.
        subspace_multipler_cutoff: Currently unused; it only feeds the
            disabled subspace selection kept as a comment at the end of the
            function. Retained so existing callers keep working.

    Returns:
        None. Results are printed and written to disk.
    """
    diagonal = load_diagonal(model_name)

    # select by diagonal
    ih, pth = [], []
    for d in diagonal:
        if d["score"] < diagonal_cutoff:
            continue
        # Entries are lists (unhashable), so de-duplicate with a membership
        # test that also preserves first-seen order.
        if d["IH"] not in ih:
            ih.append(d["IH"])

        if d["PTH"] not in pth:
            pth.append(d["PTH"])

    print(model_name)
    print(f"DIAGONAL\nIH[{len(ih)}]:{ih}\nPTH[{len(pth)}]: {pth}\n" + "-" * 50)

    # torch.save does not create missing parent directories, so make sure the
    # per-model checkpoint directory exists before writing.
    save_dir = os.path.join("checkpoints", model_name)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(ih, os.path.join(save_dir, "IH_diagonal.pt"))
    torch.save(pth, os.path.join(save_dir, "PTH_diagonal.pt"))

    # Disabled: selection by subspace score (needs ``subspace`` from
    # load_diagonal, which is currently not loaded).
    # ih, pth = [], []
    # baseline = subspace[0]["baseline"]
    # subspace_cutoff = baseline * subspace_multipler_cutoff
    # for idx in range(1, len(subspace)):
    #     d = subspace[idx]
    #     if d["score"] < subspace_cutoff:
    #         continue
    #     if d["LH0"] not in ih:
    #         ih.append(d["LH0"])
    #     if d["LH1"] not in pth:
    #         pth.append(d["LH1"])

    # print(f"SUBSPACE\nIH[{len(ih)}]: {ih}\nPTH[{len(pth)}]: {pth}\n" + "-" * 50)
    # save_dir = f"checkpoints/{model_name}"
    # torch.save(ih, f"{save_dir}/IH_subspace.pt")
    # torch.save(pth, f"{save_dir}/PTH_subspace.pt")

In [5]:
# Run the diagonal-score selection for every model, in this fixed order.
# subspace_multipler_cutoff is forwarded as before, although select() does not
# read it while its subspace branch is commented out.
for model_name in (
    "gpt2", "gpt2-xl",
    "llama2-7b", "gemma-7b", "gemma2-9b",
    "falcon-7b", "mistral-7b", "olmo-7b",
    "llama3-8b", "pythia-7b",
):
    select(
        model_name,
        diagonal_cutoff=2.3,
        subspace_multipler_cutoff=5,
    )

gpt2
DIAGONAL
IH[23]:[[5, 1], [6, 9], [5, 5], [7, 2], [7, 10], [5, 8], [5, 0], [8, 1], [7, 11], [9, 6], [7, 1], [9, 4], [8, 10], [9, 9], [10, 1], [10, 2], [10, 11], [10, 10], [11, 9], [10, 6], [8, 3], [7, 7], [11, 7]]
PTH[13]: [[4, 11], [5, 6], [8, 7], [6, 8], [6, 0], [9, 3], [3, 3], [7, 0], [5, 2], [1, 0], [2, 2], [4, 3], [3, 7]]
--------------------------------------------------
gpt2-xl
DIAGONAL
IH[18]:[[17, 6], [16, 21], [16, 3], [13, 0], [18, 0], [17, 14], [20, 0], [19, 18], [22, 20], [21, 3], [17, 1], [21, 20], [25, 19], [27, 23], [27, 1], [26, 20], [29, 14], [25, 16]]
PTH[16]: [[15, 19], [12, 21], [13, 20], [14, 12], [16, 5], [11, 2], [9, 7], [14, 20], [10, 15], [13, 12], [7, 12], [16, 20], [26, 6], [17, 12], [10, 10], [25, 7]]
--------------------------------------------------
llama2-7b
DIAGONAL
IH[16]:[[6, 9], [6, 30], [7, 4], [8, 26], [7, 12], [7, 13], [6, 11], [8, 31], [6, 16], [11, 15], [7, 10], [7, 28], [12, 2], [11, 2], [16, 19], [12, 26]]
PTH[4]: [[5, 15], [6, 5], [10, 3]

In [21]:
# Hand-entered head subsets, keyed by model name. Each model maps "IH" and
# "PTH" to a list of two-integer pairs.
# NOTE(review): the pairs look like [layer, head] indices -- confirm.
head_subset = {
    "llama2-7b": {
        "IH": [
            [11, 15], [7, 4], [7, 12], [17, 22],
            [8, 31], [7, 13], [7, 10], [13, 23],
        ],
        "PTH": [[5, 15], [6, 5], [10, 25], [5, 16]],
    },
    "falcon-7b": {
        # Every IH pair for this model has 5 as its first element.
        "IH": [
            [5, 41], [5, 2], [5, 18], [5, 52], [5, 65],
            [5, 10], [5, 1], [5, 69], [5, 13], [5, 43],
            [5, 14], [5, 59], [5, 39], [5, 66], [5, 49],
            [5, 63], [5, 33], [5, 7], [5, 70],
        ],
        "PTH": [[3, 38], [2, 40], [2, 21]],
    },
    "olmo-7b": {
        "IH": [[27, 14], [26, 17], [15, 15], [24, 7], [30, 13]],
        "PTH": [
            [26, 25], [24, 3], [14, 18],
            [25, 6], [28, 13], [26, 30],
        ],
    },
    "mistral-7b": {
        "IH": [
            [18, 0], [18, 2], [18, 3], [12, 4],
            [18, 1], [26, 6], [20, 29], [20, 30],
            [20, 28], [22, 4], [26, 5], [21, 1],
        ],
        "PTH": [[11, 17], [17, 22], [24, 8], [29, 4], [6, 24]],
    },
}

In [23]:
# Write each model's hand-entered IH / PTH subsets to
# checkpoints/<model_name>/IH_subset.pt and PTH_subset.pt.
for model_name, heads in head_subset.items():
    ih = heads["IH"]
    pth = heads["PTH"]
    save_dir = os.path.join("checkpoints", model_name)
    # torch.save does not create missing parent directories; create the
    # per-model directory so a fresh checkout does not fail here.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(ih, os.path.join(save_dir, "IH_subset.pt"))
    torch.save(pth, os.path.join(save_dir, "PTH_subset.pt"))