In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("..")

import os
import numpy as np
import pandas as pd
import itertools
from multiprocessing import Pool

from src.utils import metrics, load_decompositions

# Dataset names, the matching "<dataset>_topics" names, and model names.
# All three lists are sorted so that loops over them produce rows in the same
# order as `index` below.
datasets = sorted(["WikiText", "GitHub", "OpenWebText"])
topics = sorted([d + "_topics" for d in datasets])
models = sorted(["BERT", "GPT2", "BLOOM", "Llama2"])
# Row labels for the summary tables: model is the outer level, dataset the inner.
index = pd.MultiIndex.from_product([models, datasets])

# Root directories (relative paths); each holds one "<name>-<model>" sub-folder.
proj_dir = "out/"
topics_dir = "out-topics/"

# Materialised as a list: a bare itertools.product is a one-shot iterator, so a
# second pass over it (e.g. re-running the Pool cell) would silently see no
# (dataset, model) pairs.
dataset_models_comb = list(itertools.product(datasets, models))

In [None]:
# RUN ONLY ONCE: this cell writes the metrics.npy files that the table cells load.
def save_metrics(dataset, model):
    """Compute the metrics for one (dataset, model) pair and save them to disk.

    Both names are lower-cased. The decomposition is loaded through
    ``load_decompositions.load_pos_cvec_global_mean`` and handed to
    ``metrics.ranks_and_explained_ratios_and_relative_norm``; the result is
    stored as ``metrics.npy`` inside ``<proj_dir>/<dataset>-<model>``.

    Args:
        dataset: Dataset name (any casing).
        model: Model name (any casing).
    """
    dataset = dataset.lower()
    model = model.lower()

    decomposition = load_decompositions.load_pos_cvec_global_mean(
        dataset, model, proj_dir
    )
    pos, cvec, global_mean = decomposition
    results = metrics.ranks_and_explained_ratios_and_relative_norm(
        pos, cvec, global_mean
    )

    out_dir = os.path.join(proj_dir, f"{dataset}-{model}")
    np.save(os.path.join(out_dir, "metrics.npy"), results)
    # Progress marker, printed once the file for this pair has been written.
    print(f"Finishing {dataset}+{model}")


# Run save_metrics for every (dataset, model) pair in a default-sized process
# pool; each pair is unpacked into the two positional arguments.
with Pool() as pool:
    pool.starmap(save_metrics, dataset_models_comb)

In [None]:
# RUN ONLY ONCE: this cell writes the similarity .npy files that the
# similarity table cell reads back.
# Every (topic, model) pair, with the topic varying slowest.
topics_models_comb = [(topic, model) for topic in topics for model in models]


def save_metrics(topic, model):
    """Compute and save the between/within similarity arrays for one pair.

    Reads ``pos_id1-128_<model>.npy`` (for its shape only) and
    ``cbasis_id1-128_<model>.npy`` from ``<topics_dir>/<topic>-<model>``, then
    writes ``between_similarity_NoPCA.npy`` and ``within_similarity_NoPCA.npy``
    into the same directory. Both names are lower-cased first.

    NOTE(review): this definition reuses the name ``save_metrics`` of the
    function that writes ``metrics.npy``; whichever definition was executed
    last is the one in effect.

    Args:
        topic: Topic-set name (any casing).
        model: Model name (any casing).
    """
    topic, model = topic.lower(), model.lower()
    data_dir = os.path.join(topics_dir, f"{topic}-{model}")

    # Only the first and last dimensions of `pos` are needed, so map the file
    # instead of reading the whole array into memory.
    L, _, C = np.load(f"{data_dir}/pos_id1-128_{model}.npy", mmap_mode="r").shape

    # The middle dimension (64 * 4) is hard-coded.
    # NOTE(review): np.memmap reads the file as raw float32 starting at byte 0.
    # If the cbasis file was written with np.save, its header would be read as
    # data -- confirm how these files are produced.
    context_basis = np.memmap(
        f"{data_dir}/cbasis_id1-128_{model}.npy",
        mode="r",
        dtype=np.float32,
        shape=(L, 64 * 4, C),
    )

    B = metrics.avg_similarity_between_NoPCA(context_basis)
    W = metrics.avg_similarity_within_NoPCA(context_basis)

    np.save(f"{data_dir}/between_similarity_NoPCA.npy", B)
    np.save(f"{data_dir}/within_similarity_NoPCA.npy", W)
    # Progress marker, printed once both files for this pair have been written.
    print(f"Finishing {topic}+{model}")


# Run save_metrics for every (topic, model) pair in a default-sized process
# pool; each pair is unpacked into the two positional arguments.
with Pool() as pool:
    pool.starmap(save_metrics, topics_models_comb)

### Table 1 (ScreeNOT)

In [3]:
# Column 0 of every saved metrics.npy, one row per (model, dataset) pair.
# Names are lower-cased up front; models vary slowest, matching `index`.
d = []
for model, dataset in itertools.product(
    map(str.lower, models), map(str.lower, datasets)
):
    metrics_file = os.path.join(proj_dir, f"{dataset}-{model}", "metrics.npy")
    c = np.load(metrics_file)[:, 0]
    n = len(c)
    # Columns longer than 13 entries are thinned to 13 evenly spaced
    # positions (both ends kept) so every row fits the same table columns.
    if n > 13:
        c = c[np.linspace(0, n - 1, num=13, dtype=int)]
    d.append(c)

columns = [f"Layer {i}" for i in range(13)]
screenot_table = pd.DataFrame(np.array(d), index=index, columns=columns)
screenot_table.astype(int)

Unnamed: 0,Unnamed: 1,Layer 0,Layer 1,Layer 2,Layer 3,Layer 4,Layer 5,Layer 6,Layer 7,Layer 8,Layer 9,Layer 10,Layer 11,Layer 12
BERT,GitHub,15,16,16,16,14,11,11,9,10,10,11,11,12
BERT,OpenWebText,15,16,18,16,11,11,9,9,11,11,11,11,13
BERT,WikiText,15,16,18,16,12,11,9,9,11,11,11,12,12
BLOOM,GitHub,8,9,9,8,9,10,10,11,10,10,10,10,10
BLOOM,OpenWebText,6,10,10,11,11,10,11,11,11,11,10,10,11
BLOOM,WikiText,6,8,10,10,9,11,11,11,10,11,10,11,11
GPT2,GitHub,15,14,13,12,12,11,11,10,10,10,11,11,10
GPT2,OpenWebText,15,13,14,12,13,11,10,10,10,10,9,9,12
GPT2,WikiText,15,14,14,13,11,11,11,11,11,10,10,10,11
Llama2,GitHub,6,10,9,8,10,8,8,9,9,9,9,8,10


### Table 1 (Stable rank)

In [4]:
# Column 1 of every saved metrics.npy, one row per (model, dataset) pair.
# Names are lower-cased up front; models vary slowest, matching `index`.
d = []
for model, dataset in itertools.product(
    map(str.lower, models), map(str.lower, datasets)
):
    metrics_file = os.path.join(proj_dir, f"{dataset}-{model}", "metrics.npy")
    c = np.load(metrics_file)[:, 1]
    n = len(c)
    # Columns longer than 13 entries are thinned to 13 evenly spaced
    # positions (both ends kept) so every row fits the same table columns.
    if n > 13:
        c = c[np.linspace(0, n - 1, num=13, dtype=int)]
    d.append(c)

columns = [f"Layer {i}" for i in range(13)]
stable_rank_table = pd.DataFrame(np.array(d), index=index, columns=columns)
stable_rank_table.round(2)

Unnamed: 0,Unnamed: 1,Layer 0,Layer 1,Layer 2,Layer 3,Layer 4,Layer 5,Layer 6,Layer 7,Layer 8,Layer 9,Layer 10,Layer 11,Layer 12
BERT,GitHub,9.19,7.79,5.26,4.73,4.34,3.84,3.48,3.2,2.7,2.45,2.04,1.84,1.91
BERT,OpenWebText,9.19,7.63,5.25,4.73,4.1,3.53,3.16,2.84,2.46,2.3,2.18,2.22,2.15
BERT,WikiText,9.19,7.77,5.05,4.6,4.01,3.49,3.15,2.83,2.43,2.28,2.14,2.17,2.13
BLOOM,GitHub,8.39,1.25,1.2,1.21,1.21,1.23,1.29,1.29,1.28,1.25,1.21,1.02,1.0
BLOOM,OpenWebText,8.33,1.27,1.3,1.24,1.24,1.27,1.32,1.34,1.33,1.26,1.16,1.01,1.0
BLOOM,WikiText,7.75,1.27,1.28,1.28,1.28,1.32,1.39,1.4,1.39,1.31,1.2,1.01,1.0
GPT2,GitHub,2.05,1.92,1.91,1.89,1.9,1.9,1.92,1.94,1.98,2.03,2.05,1.7,1.11
GPT2,OpenWebText,2.05,1.92,1.91,1.89,1.88,1.88,1.88,1.9,1.91,1.96,2.02,2.24,1.49
GPT2,WikiText,2.05,1.92,1.91,1.89,1.88,1.88,1.88,1.9,1.91,1.97,2.03,2.2,1.57
Llama2,GitHub,24.87,1.0,1.0,1.0,1.0,1.0,1.0,1.01,1.01,1.01,1.02,1.03,1.17


### Table 1 (Relative norm)

In [5]:
# Row labels: model is the outer level, dataset the inner.
indices = pd.MultiIndex.from_product([models, datasets])

# Column 3 of every saved metrics.npy, one row per (model, dataset) pair.
# Names are lower-cased up front; models vary slowest, matching `indices`.
d = []
for model, dataset in itertools.product(
    map(str.lower, models), map(str.lower, datasets)
):
    metrics_file = os.path.join(proj_dir, f"{dataset}-{model}", "metrics.npy")
    c = np.load(metrics_file)[:, 3]
    n = len(c)
    # Columns longer than 13 entries are thinned to 13 evenly spaced
    # positions (both ends kept) so every row fits the same table columns.
    if n > 13:
        c = c[np.linspace(0, n - 1, num=13, dtype=int)]
    d.append(c)

columns = [f"Layer {i}" for i in range(13)]
relative_norm_table = pd.DataFrame(np.array(d), index=indices, columns=columns)
relative_norm_table.round(3)

Unnamed: 0,Unnamed: 1,Layer 0,Layer 1,Layer 2,Layer 3,Layer 4,Layer 5,Layer 6,Layer 7,Layer 8,Layer 9,Layer 10,Layer 11,Layer 12
BERT,GitHub,0.445,0.483,0.569,0.616,0.648,0.707,0.764,0.786,0.768,0.686,0.631,0.562,0.473
BERT,OpenWebText,0.465,0.546,0.66,0.759,0.877,0.977,0.973,0.967,0.953,0.901,0.777,0.658,0.596
BERT,WikiText,0.454,0.502,0.626,0.695,0.798,0.916,0.968,0.965,0.949,0.887,0.756,0.682,0.627
BLOOM,GitHub,0.013,0.123,0.232,0.279,0.343,0.385,0.343,0.306,0.301,0.306,0.325,0.219,0.181
BLOOM,OpenWebText,0.012,0.138,0.194,0.264,0.315,0.342,0.297,0.267,0.264,0.3,0.392,0.575,0.589
BLOOM,WikiText,0.013,0.149,0.222,0.287,0.328,0.352,0.336,0.301,0.295,0.325,0.407,0.494,0.491
GPT2,GitHub,0.999,0.994,0.996,0.972,0.812,0.762,0.672,0.6,0.489,0.442,0.386,0.303,0.123
GPT2,OpenWebText,1.0,0.996,0.99,0.989,0.986,0.983,0.981,0.979,0.974,0.841,0.631,0.48,0.075
GPT2,WikiText,1.0,0.995,0.991,0.989,0.987,0.986,0.984,0.984,0.981,0.933,0.68,0.555,0.11
Llama2,GitHub,0.029,0.221,0.221,0.222,0.222,0.222,0.223,0.223,0.224,0.225,0.226,0.227,0.183


### Table 1 (Inter & intra similarity)

In [7]:
# Mean and standard deviation of the saved between/within similarity arrays,
# one row per (model, topic set) pair.
# `index` is labelled by dataset name; the rows line up because, for the
# current names, `topics` sorts in the same order as `datasets`.
row = []
for model, topic in itertools.product(
    map(str.lower, models), map(str.lower, topics)
):
    data_dir = os.path.join(topics_dir, f"{topic}-{model}")
    B = np.load(f"{data_dir}/between_similarity_NoPCA.npy")
    W = np.load(f"{data_dir}/within_similarity_NoPCA.npy")
    # Order of entries: B mean, B std, W mean, W std.
    row.append([stat(values) for values in (B, W) for stat in (np.mean, np.std)])

similarity_table = pd.DataFrame(
    np.array(row), index=index, columns=["B mean", "B std", "W mean", "W std"]
)
similarity_table.round(2)

Unnamed: 0,Unnamed: 1,B mean,B std,W mean,W std
BERT,GitHub,0.17,0.04,0.31,0.05
BERT,OpenWebText,0.13,0.04,0.26,0.04
BERT,WikiText,0.12,0.04,0.28,0.06
BLOOM,GitHub,0.2,0.14,0.4,0.08
BLOOM,OpenWebText,0.15,0.14,0.32,0.09
BLOOM,WikiText,0.14,0.13,0.31,0.09
GPT2,GitHub,0.28,0.06,0.42,0.04
GPT2,OpenWebText,0.1,0.01,0.44,0.04
GPT2,WikiText,0.1,0.01,0.41,0.04
Llama2,GitHub,0.17,0.1,0.4,0.07
