In [None]:
import argparse
import gc
import json
import math
import os
from collections import defaultdict

import torch as t
import numpy as np
from einops import rearrange
from tqdm import tqdm, trange

from activation_utils import SparseAct
from attribution import patching_effect, jvp
from circuit_plotting import plot_circuit, plot_circuit_posaligned
from dictionary_learning import AutoEncoder
from loading_utils import load_examples, load_examples_nopair
from nnsight import LanguageModel

%load_ext autoreload
%autoreload 2

In [None]:
def get_activation_stats(model_str, layer, component, dset, seq_len=64, batch_size=5, simulate=False,
                         n_examples=1000, n_bins=200, min_power=-10, max_power=6):
    """
    Collect log-scale histograms of SAE feature activations for one layer/component.

    Loads the language model named by `model_str` and the matching SAE for
    (`layer`, `component`), tokenizes the first `n_examples` entries of `dset`
    in batches, and histograms the SAE features over the second half of each
    sequence.

    Args:
        model_str: 'gpt2' selects the GPT-2 OAI v5 SAEs from the HF hub; any
            other value is treated as Pythia-70m-deduped and loads SAEs from the
            local `dictionaries/pythia-70m-deduped/` directory.
        layer: index of the transformer block whose submodule is hooked.
        component: 'resid', 'attn' or 'mlp'.
        dset: indexable collection of raw text strings.
        seq_len: texts are truncated to this many tokens; shorter ones are skipped.
        batch_size: number of dataset entries tokenized per step.
        simulate: if True, only count valid examples and skip the forward pass
            (the returned histograms are then all zeros).
        n_examples: upper bound of the dataset scan (default 1000, the
            previously hard-coded value).
        n_bins, min_power, max_power: number of bins and log10 range of the
            activation-magnitude histogram. `plot_model_hists` assumes the same
            values (defaults 200 bins over [-10, 6]).

    Returns:
        (hist, nnz_hist, feat_size):
        - hist: numpy histogram of log10|activation| over nonzero activations.
        - nnz_hist: numpy histogram of log10(#nonzero features per token) over
          [0, log10(feat_size)].
        - feat_size: the SAE dictionary size.

    Raises:
        ValueError: if `layer` does not index a block of the model.
    """
    model = LanguageModel(model_str, device_map='cuda', dispatch=True)
    submodule = None
    if model_str == 'gpt2':
        # GPT-2 SAEs come from jbloom's OAI v5 HF repos:
        # - resid uses the 32k "post" dictionaries, stored under a `.pt` subfolder;
        # - attn/mlp use the 128k "out" dictionaries.
        if component == 'resid':
            n_feats, loc, ext = '32k', 'post', '.pt'
        else:
            n_feats, loc, ext = '128k', 'out', ''
        repo = f"jbloom/GPT2-Small-OAI-v5-{n_feats}-{component}-{loc}-SAEs"
        sae = AutoEncoder.from_hf(repo, f"v5_{n_feats}_layer_{layer}{ext}/sae_weights.safetensors", device="cuda")
        for i, t_layer in enumerate(model.transformer.h):
            if i == layer:
                # resid hooks the whole block; attn/mlp hook the child module of that name.
                submodule = t_layer if component == 'resid' else getattr(t_layer, component)
                break
    else:
        sae = AutoEncoder.from_pretrained(
                f'dictionaries/pythia-70m-deduped/{component}_out_layer{layer}/10_32768/ae.pt',
                device='cuda'
            )

        for i, t_layer in enumerate(model.gpt_neox.layers):
            if i == layer:
                if component == 'resid':
                    submodule = t_layer
                elif component == 'attn':
                    submodule = t_layer.attention
                else:
                    submodule = t_layer.mlp
                break

    if submodule is None:
        # Previously an out-of-range layer surfaced later as an UnboundLocalError.
        raise ValueError(f"layer {layer} is out of range for model {model_str!r}")

    feat_size = sae.encoder.weight.shape[0]

    # Probe once to find out whether the submodule's output is a tuple; if so,
    # element 0 is the activation fed to the SAE below.
    with model.trace("_"):
        output_submod = submodule.output.save()
    is_tuple = isinstance(output_submod.value, tuple)

    total_valid = 0  # only counted in simulate mode; NOTE(review): never returned
    hist = t.zeros(n_bins).to('cuda')
    nnz_hist = t.zeros(n_bins).to('cuda')  # bins span log10(1) .. log10(feat_size)

    for i in trange(0, n_examples, batch_size):
        entries = dset[i:i+batch_size]
        # Keep only texts that fill the whole seq_len window after truncation.
        valid_entries = []
        for e in entries:
            encoded = model.tokenizer(e, return_tensors='pt', max_length=seq_len, truncation=True).to('cuda')['input_ids']
            if encoded.shape[1] == seq_len:
                valid_entries.append(encoded)
        if not valid_entries:
            continue
        batch = t.cat(valid_entries, dim=0)
        if simulate:
            total_valid += len(valid_entries)
            continue

        with model.trace(batch), t.no_grad():
            x = submodule.output
            if is_tuple:
                x = x[0]
            f = sae.encode(x).save()

        if f.ndim == 2:
            f = f.unsqueeze(0)

        # Only the second half of each sequence is histogrammed.
        f_late = f[:, seq_len//2:, :]
        nnz = (f_late != 0).sum(dim=2).flatten()   # [N, seq_len//2].flatten()
        abs_f = abs(f_late)
        # torch.histc ignores values below `min`, so tokens with no active
        # feature (log10(0) = -inf) are excluded from nnz_hist.
        nnz_hist += t.histc(t.log10(nnz), bins=n_bins, min=np.log10(1), max=np.log10(feat_size))
        hist += t.histc(t.log10(abs_f[abs_f != 0]), bins=n_bins, min=min_power, max=max_power)

    return hist.cpu().numpy(), nnz_hist.cpu().numpy(), feat_size

In [None]:
# Cache of (model, component, layer) -> get_activation_stats(...) results,
# reloaded from disk when present so the sweep below can resume.
# NOTE: pickle.load can execute arbitrary code - only load a file this notebook wrote.
save_path = 'distribs-64-nnz-log10.pkl'
import pickle
if os.path.exists(save_path):
    with open(save_path, 'rb') as f:
        results = pickle.load(f)
else:
    results = {}

In [None]:
from datasets import load_dataset

dset = load_dataset("NeelNanda/pile-10k")['train']['text']

# Sweep every (model, component, layer) triple:
# - skip triples already cached in `results`;
# - re-pickle after each one so an interrupted run can resume.
# The loop variable is `model_str`, not `model`, so it doesn't leave a string
# bound to the name other cells use for a LanguageModel.
for model_str in ['gpt2', 'EleutherAI/pythia-70m-deduped']:
    n_layers = 12 if model_str == 'gpt2' else 6
    for component in ['attn', 'resid', 'mlp']:
        for layer in range(n_layers):
            key = (model_str, component, layer)
            if key in results:
                continue
            print(model_str, component, layer)
            results[key] = get_activation_stats(model_str, layer, component, dset, batch_size=16)
            # Drop Python references first so empty_cache can release their CUDA memory.
            gc.collect()
            t.cuda.empty_cache()
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)

In [None]:
# plot results for GPT2, do a 12 x 3 plot where column 1 is resid, column 2 is attn, column 3 is mlp

import matplotlib.pyplot as plt
import numpy as np

def plot_model_hists(model_str, n_layers, results, hist_or_nnz='hist', thresh=None, as_sparsity=False,
                     n_bins=200, min_power=-10, max_power=6):
    """
    Plot per-layer SAE activation histograms as an `n_layers` x 3 grid.

    Columns are resid, attn and mlp; rows are layers. Each panel draws a
    dashed red line at the median bin and annotates it with the
    count-weighted standard deviation of the bin positions.

    Args:
        model_str: model key used in `results`.
        n_layers: number of layers (rows) to plot.
        results: dict mapping (model_str, component, layer) to the
            (hist, nnz_hist, feat_size) tuples returned by `get_activation_stats`.
        hist_or_nnz: 'hist' plots log10 activation magnitudes; any other value
            plots the number of nonzero features per token.
        thresh: only valid with 'hist'.
            - If `as_sparsity` is False, bins below log10(thresh) are zeroed.
            - If `as_sparsity` is True, roughly only the top `thresh` fraction of
              activations (by cumulative count) is kept.
        as_sparsity: see `thresh`.
        n_bins, min_power, max_power: must match the values `get_activation_stats`
            used to build the histograms.

    Raises:
        ValueError: if `thresh` is given for the NNZ plot.
    """
    if hist_or_nnz != 'hist' and thresh is not None:
        raise ValueError("thresh can only be computed for hist")

    value_hist_color = 'blue'
    # squeeze=False keeps `axs` 2-D even when n_layers == 1.
    fig, axs = plt.subplots(n_layers, 3, figsize=(15, 3.6*n_layers), squeeze=False)

    for layer in range(n_layers):
        for i, component in enumerate(['resid', 'attn', 'mlp']):
            stats = results[model_str, component, layer]
            if hist_or_nnz == 'hist':
                xlabel = 'log10(Activation magnitude)'
                hist = stats[0]
                # torch.histc uses n_bins equal-width bins over [min_power, max_power];
                # plot each count at its bin's left edge (log10 units).
                bins = np.linspace(min_power, max_power, n_bins, endpoint=False)
                if thresh is not None:
                    if not as_sparsity:
                        # The histogram is over log10|act|, so compare in log10.
                        thresh_loc = np.searchsorted(bins, np.log10(thresh))
                    else:
                        # First bin where the cumulative fraction reaches 1 - thresh.
                        percentile_hist = np.cumsum(hist) / hist.sum()
                        thresh_loc = np.searchsorted(percentile_hist, 1-thresh)
                    hist = hist.copy()
                    # Keep the bin containing the threshold. Clamp at 0 so that
                    # thresh_loc == 0 doesn't become hist[:-1] (zeroing all but one bin).
                    hist[:max(thresh_loc - 1, 0)] = 0
            else:
                xlabel = 'NNZ'
                hist = stats[1]
                # NNZ was histogrammed in log10 space over [0, log10(feat_size)];
                # map the bin left edges back to feature counts.
                bins = 10 ** np.linspace(0, np.log10(stats[2]), n_bins, endpoint=False)
                nonzero_bins = np.nonzero(hist)[0]
                if nonzero_bins.size:
                    # Drop the empty tail so the x-axis ends at the largest observed NNZ.
                    max_index = nonzero_bins.max()
                    bins = bins[:max_index+1]
                    hist = hist[:max_index+1]

            ax = axs[layer, i]
            ax.set_xlabel(xlabel, color=value_hist_color)
            ax.set_ylabel('Frequency', color=value_hist_color)
            ax.plot(bins, hist, color=value_hist_color)
            ax.tick_params(axis='x', colors=value_hist_color)
            ax.tick_params(axis='y', colors=value_hist_color)

            total = hist.sum()
            if total == 0:
                # Nothing to summarise: no data, or everything was thresholded away.
                ax.set_title(f'{component} layer {layer} (empty)')
                continue
            # Median bin, plus count-weighted mean/std of the bin positions.
            median_idx = (hist.cumsum() >= total / 2).nonzero()[0][0]
            median_val = bins[median_idx]
            mean = (bins * hist).sum() / total
            std = np.sqrt(((bins - mean)**2 * hist).sum() / total)
            ax.set_title(f'{component} layer {layer} (log(total) = {np.log(total):.2f})')
            # Dashed line at the median, annotated with "median +- std".
            ax.axvline(median_val, color='r', linestyle='--')
            ax.text(median_val+0.5, hist.max(), f'{median_val:.2f} +- {std:.2f}', color='r')

    plt.tight_layout()
    plt.show()

In [None]:
# GPT-2 magnitude histograms. With as_sparsity=True, `thresh` is a tail
# fraction: roughly only the top 8e-6 of activations per panel are kept.
plot_model_hists('gpt2', 12, results, 'hist', thresh=8e-6, as_sparsity=True)

In [None]:
# GPT-2: histogram of the number of nonzero SAE features per token.
plot_model_hists('gpt2', 12, results, 'nnz')

In [None]:
# Pythia-70m-deduped: log10 activation-magnitude histograms, no threshold.
plot_model_hists('EleutherAI/pythia-70m-deduped', 6, results, 'hist')

In [None]:
# Pythia-70m-deduped: histogram of the number of nonzero SAE features per token.
plot_model_hists('EleutherAI/pythia-70m-deduped', 6, results, 'nnz')

In [None]:
# Back-of-envelope: e**15 (the histogram titles report log(total) as a natural
# log) divided by 64 * 958.
# NOTE(review): presumably 64 = seq_len and 958 = number of valid examples,
# giving counts per token - confirm. get_activation_stats only histograms the
# second half of each sequence (seq_len // 2 positions).
(2.71828 ** 15) / (64 * 958)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

# Create a directed graph
G = nx.DiGraph()

# Toy DAG for trying out `filter_graph` in the next cell.
# NOTE(review): `nodes` is never used below. The graph's nodes come only from
# the edges plus the isolated node "Q", so X, Y and Z never enter the graph.
nodes = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]

# Edges grouped under integer keys. Only `edges.values()` is used, so the keys
# are labels only. NOTE(review): they don't consistently match the y
# coordinates in `positions` (e.g. key 4 holds ("G", "J") but G is at y=7).
edges = {   # keys are labels only (see note above)
    10: [("A", "B"), ("A", "C")],
    9: [("B", "D"), ("B", "E")],
    8: [("C", "F"), ("I", "H")],
    4: [("G", "J")],
    7: [("F", "G"), ("H", "G")],
    6: [("T", "U"), ("T", "R"), ("S", "R"), ("J", "K"), ("J", "L")],
    5: [("U", "V"), ("U", "W"), ("K", "M"), ("K", "N"), ("L", "O"), ("L", "P")]
}
all_edges = []
for v in edges.values():
    all_edges.extend(v)
G.add_edges_from(all_edges)
# "Q" has no edges; added explicitly so it appears as an isolated node.
G.add_node("Q")

# Manual (x, y) layout for drawing; every node in G needs an entry here.
positions = {
    "A": (0, 10),
    "B": (-1, 9),
    "C": (1, 9),
    "D": (-2, 8),
    "E": (-1, 8),
    "F": (0, 8),
    "G": (1, 7),
    "H": (2, 8),
    "I": (3, 9),
    "J": (1, 6),
    "K": (0, 5),
    "L": (3, 5),
    "M": (-1, 4),
    "N": (0, 4),
    "O": (2, 4),
    "P": (4, 4),
    "Q": (-4, 7),
    "R": (-2, 6),
    "S": (-1, 7),
    "T": (-2, 7),
    "U": (-3, 6),
    "V": (-5, 5),
    "W": (-4, 5),
}


nx.draw(G, pos=positions, with_labels=True, node_size=200, node_color='skyblue', font_size=10, font_weight='bold', font_color='black', edge_color='black', width=2, alpha=0.5)
plt.show()


In [None]:
from copy import copy


# %%
# Filter graph functionality.
def filter_graph(
    graph, leaf_nodes: list | None = None
):
    """Filter a directed graph down to a source-to-sink subgraph."""

    # `leaf_nodes` starts off with all nodes without outgoing edges.
    if leaf_nodes is None:
        leaf_nodes = []

        for node in graph.nodes:
            if not list(graph.successors(node)):
                leaf_nodes.append(node)
    print(leaf_nodes)
    # Prunes out all the non-final-layer leaf nodes and append upstream
    # relevant nodes.
    for node in copy(leaf_nodes):
        if node != f"A":
            upstream_nodes = list(graph.predecessors(node))
            leaf_nodes.extend(upstream_nodes)

            graph.remove_node(node)
        leaf_nodes.remove(node)

    # Recurses if necessary.
    if leaf_nodes:
        leaf_nodes: list = list(set(leaf_nodes))
        graph = filter_graph(graph, leaf_nodes)

    return graph
# filter_graph prunes G in place and returns it, so prune_G is G itself and
# the unpruned graph is no longer available (re-run the previous cell to rebuild it).
prune_G = filter_graph(G)
nx.draw(prune_G, pos=positions, with_labels=True, node_size=200, node_color='skyblue', font_size=10, font_weight='bold', font_color='black', edge_color='black', width=2, alpha=0.5)

In [None]:
from huggingface_hub import hf_hub_download

# Download (or fetch from the local HF cache) the resid-post SAE weights for
# all 12 GPT-2 layers. Only the path of the last one (layer 11) survives in
# `tens`, which the next cell opens.
for layer in range(12):
    tens = hf_hub_download('jbloom/GPT2-Small-OAI-v5-32k-resid-post-SAEs',
                    f'v5_32k_layer_{layer}.pt/sae_weights.safetensors')

In [None]:
from safetensors import safe_open
# Read every tensor in the safetensors file at `tens` (set by the previous
# cell) into a dict keyed by tensor name.
tensors = {}
with safe_open(tens, 'pt') as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

In [None]:
# Inspect the shape of the 'W_enc' tensor loaded above.
tensors['W_enc'].shape

In [None]:
"""Read from and write to the module-level interface."""


import csv
import gc
import os
from pathlib import Path
from textwrap import dedent

import yaml
import numpy as np
import torch as t
from transformers import PreTrainedModel


def parse_slice(slice_string: str) -> slice:
    """
    Parse a slice string such as "2:", ":-1" or "0:6:2" into a slice object.

    Empty fields become None, so ":" is the full slice and "5" is slice(5, None).

    Raises:
        ValueError: for more than three ":"-separated fields, or a field that
            is not an integer.
        AssertionError: for a forward slice (step None or positive) whose start
            and stop index from the same end with start >= stop, e.g. "3:1" or
            "-1:-3". Mixed-sign bounds like "2:-1" and reverse-stepped slices
            like "5:1:-1" are accepted; the previous check rejected them.
    """
    slice_parts: list = slice_string.split(":")

    # str.split always returns at least one field, so only the upper bound
    # can be violated.
    if len(slice_parts) > 3:
        raise ValueError(
            dedent(
                f"""
                Slice string {slice_string} is not well-formed.
                """
            )
        )

    # Pad to three fields; empty strings are falsy and map to None.
    padded = slice_parts + [""] * (3 - len(slice_parts))
    start, stop, step = (int(part) if part else None for part in padded)

    layers_slice = slice(start, stop, step)

    # start/stop are only comparable when they count from the same end.
    same_end = start is not None and stop is not None and (start < 0) == (stop < 0)
    forward = step is None or step > 0
    if same_end and forward and start >= stop:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise AssertionError(
            dedent(
                f"""
                Slice start ({layers_slice.start}) must be less than stop
                ({layers_slice.stop})
                """
            )
        )

    return layers_slice


def validate_slice(model: PreTrainedModel, layers_slice: slice) -> None:
    """
    Check that a layer slice's stop does not run past the model's last layer.

    A slice without a stop always passes. Only the raw slice object is
    checked; `slice_to_seq` output does not need this.

    Raises:
        ValueError: when `layers_slice.stop - 1` exceeds the last layer index.
    """
    stop = layers_slice.stop
    if stop is not None:
        # num_hidden_layers is a count, so the last valid index is one less;
        # slice stops are exclusive, so the last index touched is stop - 1.
        n_layers: int = model.config.num_hidden_layers
        if stop - 1 > n_layers - 1:
            raise ValueError(
                dedent(
                    f"""
                    The layers slice {layers_slice} is out of bounds for the
                    model's layer count.
                    """
                )
            )


def sanitize_model_name(model_name: str) -> str:
    """Replace every "/" with "_" so an HF hub id like "org/name" becomes one directory name."""
    return "_".join(model_name.split("/"))


def cache_layer_tensor(
    layer_tensor: t.Tensor,
    layer_idx: int,
    save_append: str,
    base_file: str,
    model_name: str,
) -> None:
    """
    Save a per-layer tensor to data/<sanitized model_name>/<layer_idx>/<save_append>.

    `base_file` is `__file__` of the calling module; `save_paths` resolves the
    `data` directory next to it. `save_append` must be a bare file name with
    extension, no directories. The layer subdirectory is created if missing.

    Raises:
        AssertionError: if `layer_idx` is not an int, or is a bool.
    """
    assert isinstance(layer_idx, int), f"Layer index {layer_idx} is not an int."
    # bool subclasses int, so it has to be rejected explicitly.
    assert not isinstance(layer_idx, bool), f"Layer index {layer_idx} is a bool, not an int."

    layer_dir: str = f"{save_paths(base_file, '')}/{sanitize_model_name(model_name)}/{layer_idx}"
    os.makedirs(layer_dir, exist_ok=True)
    t.save(layer_tensor, f"{layer_dir}/{save_append}")


def slice_to_range(model: PreTrainedModel, input_slice: slice) -> range:
    """
    Turn a layer slice into a concrete range of layer indices.

    - A missing start/stop defaults to 0 / num_hidden_layers.
    - Negative bounds count back from num_hidden_layers.
    - A missing step is 1.
    - The start is then raised to at least 0 and the stop capped at num_hidden_layers.
    """
    n_layers: int = model.config.num_hidden_layers

    def _resolve(bound, default):
        # Resolve one slice bound the way negative indexing does.
        if bound is None:
            return default
        return n_layers + bound if bound < 0 else bound

    lo = _resolve(input_slice.start, 0)
    hi = _resolve(input_slice.stop, n_layers)
    stride = input_slice.step if input_slice.step is not None else 1

    # Clamp to the model's actual layers.
    return range(max(lo, 0), min(hi, n_layers), stride)


def load_input_token_ids(prompt_ids_path: str) -> list[list[int]]:
    """
    Load saved prompt token ids and remove one level of nesting.

    The ids are the same for every layer, so one file serves all layers.
    Example: a stored [[a, b], [c]] becomes [a, b, c]. The elements are
    presumably per-prompt token-id lists (confirm against the writer).

    NOTE: the file is opened with `allow_pickle=True`, which can execute
    arbitrary code - only load trusted files.
    """
    stored = np.load(prompt_ids_path, allow_pickle=True).tolist()
    flattened: list[list[int]] = []
    for group in stored:
        flattened.extend(group)
    return flattened


def load_yaml_constants(base_file):
    """
    Load the HF access config and the central config.

    `base_file` is `__file__` of the calling module. The config files live in
    a `config/` directory located:
    - next to `base_file` when it sits in `sparse_coding/`;
    - one level up when it sits in `interp_tools/` or `rasp/`.
    A missing `hf_access.yaml` is created with an empty `HF_ACCESS_TOKEN`,
    and `{}` is returned for it.

    Returns:
        (access, config): the parsed YAML documents; an empty file yields `{}`.

    Raises:
        ValueError: for an unrecognised calling directory, or when either file
            is not valid YAML. Previously a parse error was printed and the
            function then failed with UnboundLocalError at the return.
        FileNotFoundError: if the central config file is missing.
    """
    current_dir = Path(base_file).parent

    if current_dir.name == "sparse_coding":
        config_root = current_dir
    elif current_dir.name in ("interp_tools", "rasp"):
        config_root = current_dir.parent
    else:
        raise ValueError(
            dedent(
                f"""
                Trying to access config files from an unfamiliar working
                directory: {current_dir}
                """
            )
        )

    hf_access_path = config_root / "config/hf_access.yaml"
    central_config_path = config_root / "config/central_config.yaml"

    try:
        with open(hf_access_path, "r", encoding="utf-8") as f:
            access = yaml.safe_load(f) or {}
    except FileNotFoundError:
        print("hf_access.yaml not found. Creating it now.")
        with open(hf_access_path, "w", encoding="utf-8") as w:
            w.write('HF_ACCESS_TOKEN: ""\n')
        access = {}
    except yaml.YAMLError as e:
        raise ValueError(f"Could not parse {hf_access_path}: {e}") from e

    with open(central_config_path, "r", encoding="utf-8") as f:
        try:
            config = yaml.safe_load(f) or {}
        except yaml.YAMLError as e:
            raise ValueError(f"Could not parse {central_config_path}: {e}") from e

    return access, config


def save_paths(base_file, save_append: str) -> str:
    """
    Return `<directory of base_file>/data/<save_append>` as a string.

    `base_file` is normally `__file__` of the calling module. An empty
    `save_append` yields the `data` directory itself.

    Raises:
        AssertionError: if `save_append` is not a string.
    """
    assert isinstance(save_append, str), f"`save_append` must be a string: {save_append}."
    data_dir = Path(base_file).parent.joinpath("data")
    return str(data_dir.joinpath(save_append))



def load_layer_tensors(
    model_dir: str,
    layer_idx: int,
    encoder_file: str,
    biases_file: str,
    base_file: str,
) -> tuple[t.Tensor, t.Tensor]:
    """
    Load the cached autoencoder encoder and bias tensors for one model layer.

    Both files are read from data/<sanitized model_dir>/<layer_idx>/ next to
    `base_file`, which should be `__file__` in the calling module. This is the
    same layout `cache_layer_tensor` writes.

    Returns:
        (encoder, bias), each as returned by `torch.load`. (The previous
        annotation said a single Tensor.)
    """
    layer_dir = f"{sanitize_model_name(model_dir)}/{layer_idx}"
    encoder = t.load(save_paths(base_file, f"{layer_dir}/{encoder_file}"))
    bias = t.load(save_paths(base_file, f"{layer_dir}/{biases_file}"))
    return encoder, bias


def load_layer_feature_indices(
    model_dir: str,
    layer_idx: int,
    top_k_info_file: str,
    base_file: str,
) -> list[int]:
    """
    Read the feature indices listed in a layer's top-k info CSV.

    The CSV is data/<sanitized model_dir>/<layer_idx>/<top_k_info_file> next to
    `base_file` (pass `__file__` from the calling module). The first row is a
    header and is skipped; the first column of every other row is parsed as an int.
    """
    csv_path = save_paths(
        base_file,
        sanitize_model_name(model_dir) + "/" + str(layer_idx) + "/" + top_k_info_file,
    )
    with open(csv_path, mode="r", encoding="utf-8") as file:
        rows = csv.reader(file)
        next(rows)  # header
        return [int(row[0]) for row in rows]


def _parse_nested_number_lists(text: str, convert) -> list:
    """Parse a stringified list of lists such as "[[1, 2], [3]]", applying `convert` to each element."""
    parsed = []
    for chunk in text.split("], "):
        cleaned = chunk.replace("[", "").replace("]", "")
        # Skip empty fields so an empty "[]" sublist parses to [] instead of crashing.
        parsed.append([convert(item) for item in cleaned.split(", ") if item])
    return parsed


def load_layer_feature_labels(
    model_dir: str,
    layer_idx: int,
    feature_idx: int,
    top_k_info_file: str,
    base_file: str,
) -> tuple[list[list[int]], list[list[float]]]:
    """
    Return the top-k contexts and their activations for one autoencoder feature.

    Reads data/<sanitized model_dir>/<layer_idx>/<top_k_info_file> next to
    `base_file` (pass `__file__` from the calling module) and finds the row
    whose first column equals `feature_idx`. From that row:
    - column 1 is parsed as a list of int lists (presumably token ids);
    - the last column is parsed as a list of float lists (activations).

    Returns:
        (context_ints, act_floats).

    Raises:
        ValueError: if no row matches `feature_idx`.
    """
    csv_path = save_paths(
        base_file,
        f"{sanitize_model_name(model_dir)}/{layer_idx}/{top_k_info_file}",
    )
    with open(csv_path, mode="r", encoding="utf-8") as file:
        reader = csv.reader(file)
        # Skip the header.
        next(reader)

        for row in reader:
            if int(row[0]) == feature_idx:
                # Context entries may be quoted, e.g. "'123'".
                context_ints = _parse_nested_number_lists(row[1], lambda s: int(s.strip("'")))
                act_floats = _parse_nested_number_lists(row[-1], float)
                return (context_ints, act_floats)

    raise ValueError(
        dedent(
            f"""
            Feature index {feature_idx} not found in layer {layer_idx}
            autoencoder.
            """
        )
    )


def pad_activations(
    tensor: t.Tensor, max_length: int, accelerator
) -> t.Tensor:
    """
    Right-pad a 3-D tensor with zeros along dim 1 up to `max_length`.

    Expects a 3-D tensor (presumably batch, seq, hidden). The padding:
    - is created directly with the input's dtype and device, so e.g. fp16
      activations stay fp16 instead of being promoted to float32 by `torch.cat`;
    - is passed through `accelerator.prepare`.

    If the concatenation raises RuntimeError (presumably out of memory), the
    function runs garbage collection, empties the CUDA cache when CUDA is
    available, and retries once.
    """
    complement_length: int = max_length - tensor.size(1)
    padding: t.Tensor = t.zeros(
        tensor.size(0),
        complement_length,
        tensor.size(2),
        dtype=tensor.dtype,
        device=tensor.device,
    )
    padding = accelerator.prepare(padding)
    try:
        return t.cat([tensor, padding], dim=1)
    except RuntimeError:
        gc.collect()
        if t.cuda.is_available():
            # gc alone does not return cached blocks to the CUDA allocator.
            t.cuda.empty_cache()
        return t.cat([tensor, padding], dim=1)


In [None]:
import torch as t
from huggingface_hub import hf_hub_download
from safetensors import safe_open


def load_sublayer_autoencoder(
    autoencoder_repo: str,
    encoder_file: str,
    enc_biases_file: str,
    decoder_file: str,
    dec_biases_file: str,
    model_dir: str,
    acts_layers_range: range,
    base_file: str,
    subfolder_template: str = "v5_128k_layer_{idx}",
):
    """
    Download per-layer SAE weights from the HF Hub and cache them as .pt files.

    The HF Hub layout for these is different, so they get their own import
    function. For each layer index in `acts_layers_range`:
    - fetch `<subfolder>/sae_weights.safetensors` from `autoencoder_repo`;
    - save its W_enc, b_enc, W_dec and b_dec tensors to
      data/<sanitized model_dir>/<idx>/<file> next to `base_file` (the calling
      module's `__file__`), creating that directory if needed.

    Args:
        subfolder_template: per-layer repo subfolder, formatted with `idx`.
            Defaults to the 128k OAI v5 layout that was previously hard-coded.
    """
    import os

    filename: str = "sae_weights.safetensors"
    safe_model_name = sanitize_model_name(model_dir)
    # (safetensors key, destination file name) pairs.
    targets = [
        ("W_enc", encoder_file),
        ("b_enc", enc_biases_file),
        ("W_dec", decoder_file),
        ("b_dec", dec_biases_file),
    ]

    for idx in acts_layers_range:
        file_url = hf_hub_download(
            repo_id=autoencoder_repo,
            filename=filename,
            subfolder=subfolder_template.format(idx=idx),
        )

        tensors_dict: dict = {}
        with safe_open(file_url, "pt") as f:
            for k in f.keys():
                tensors_dict[k] = f.get_tensor(k)

        # Look up every key before writing anything, so a missing key leaves no partial layer.
        to_save = [(tensors_dict[key], out_file) for key, out_file in targets]

        # t.save does not create parent directories.
        os.makedirs(save_paths(base_file, f"{safe_model_name}/{idx}"), exist_ok=True)
        for tensor, out_file in to_save:
            t.save(tensor, save_paths(base_file, f"{safe_model_name}/{idx}/{out_file}"))

In [None]:
# Cache the GPT-2 attention-output SAE weights for all 12 layers as .pt files.
ATTN_REPO: str = "jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs"
ATTN_ENCODER_FILE = "attn_encoder.pt"
ATTN_ENC_BIASES_FILE = "attn_enc_biases.pt"
ATTN_DECODER_FILE = "attn_decoder.pt"
ATTN_DEC_BIASES_FILE = "attn_dec_biases.pt"
MODEL_DIR = "openai-community/gpt2"
ACTS_LAYERS_RANGE = range(0, 12)

load_sublayer_autoencoder(
    ATTN_REPO,
    ATTN_ENCODER_FILE,
    ATTN_ENC_BIASES_FILE,
    ATTN_DECODER_FILE,
    ATTN_DEC_BIASES_FILE,
    MODEL_DIR,
    ACTS_LAYERS_RANGE,
    # NOTE(review): hard-coded absolute path into one user's checkout. Only its
    # parent directory matters: save_paths writes under <parent>/data/...
    '/home/jack/feature-circuits/dictionaries/dummy.py',
)

In [None]:
import sae_lens

# Bind to a new name. Assigning to `t` here would shadow `torch as t`, which
# other cells (e.g. get_activation_stats: t.zeros, t.cuda, t.histc) rely on.
# NOTE(review): later cells call SAE.from_pretrained with release=/sae_id=;
# this single-argument form may also need an `sae_id` - confirm.
attn_out_sae = sae_lens.SAE.from_pretrained("jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs")

In [None]:
import torch
# Sanity check: convert the flat top-5 indices of a (2, 20, 300) tensor back to
# per-dimension coordinates two ways, by hand (`ind2`) and with
# torch.unravel_index (`ind`). Both are row-major, so the rows should match.
torch.random.manual_seed(1)
x = torch.randn(2, 20, 300)
numel_per_batch = x.shape[1] * x.shape[2]
numel_per_batch_seq = x.shape[2]
ind = x.flatten().topk(5).indices
ind2 = torch.cat([ind // numel_per_batch, (ind % numel_per_batch) // numel_per_batch_seq, ind % numel_per_batch_seq], dim=0).reshape(3, -1).T
ind = torch.stack(torch.unravel_index(ind, x.shape), dim=1)
ind, ind2

In [None]:
import torch
# Load the saved vjv collections for the MR, AR and AM cases. They are
# presumably lists of tensors, since plot_hist below torch.cat's them.
# NOTE: torch.load unpickles - only load trusted files.
vjv_MR = torch.load('vjv_save_MR.pt')
vjv_AR = torch.load('vjv_save_AR.pt')
vjv_AM = torch.load('vjv_save_AM.pt')

In [None]:
# plot histogram of agg, with log scale
import matplotlib.pyplot as plt
def plot_hist(ax, agg, title):
    """
    Draw a log-log histogram of the strictly positive entries of `agg` on `ax`.

    Args:
        ax: a matplotlib Axes (only `hist`, `set_xscale`, `set_yscale` and
            `set_title` are used).
        agg: sequence of tensors, concatenated along dim 0 and flattened (all
            dims after the first must match).
        title: panel title; the maximum value is appended to it.

    Uses 1000 bins spaced logarithmically between the smallest and largest
    positive value, so bins have equal width on the log x-axis. Linear bins
    would put almost everything into the first visible bin.
    """
    x = torch.cat(agg, dim=0).flatten()
    # Only positive values can be shown on a log axis; zeros and negatives are dropped.
    positive = x[x > 0]
    ax.set_xscale("log")
    ax.set_yscale("log")
    if positive.numel() == 0:
        ax.set_title(f"{title} (no positive values)")
        return
    values = positive.cpu().numpy()
    lo, hi = float(values.min()), float(values.max())
    if hi > lo:
        bins = np.logspace(np.log10(lo), np.log10(hi), 1001)
        # Pin the end edges exactly so rounding in 10**log10 can't exclude min/max.
        bins[0], bins[-1] = lo, hi
    else:
        bins = 1
    ax.hist(values, bins=bins)
    ax.set_title(f"{title} (max={positive.max()})")

# One log-log panel per saved vjv collection.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_hist(axs[0], vjv_MR, "MR")
plot_hist(axs[1], vjv_AR, "AR")
plot_hist(axs[2], vjv_AM, "AM")

In [None]:
# Pythia-70m-deduped wrapped by nnsight, loaded to inspect its module tree.
eleu_model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda', dispatch=True)

In [None]:
# Display the model's repr (presumably its module tree).
eleu_model

In [None]:
# NOTE(review): no visible cell assigns a LanguageModel to the global `model`,
# so this fails on a fresh kernel. It presumably refers to a GPT-2
# LanguageModel (which has `lm_head` and `transformer.h`) from a removed cell.
model.lm_head

In [None]:
# NOTE(review): like the cell above, this needs a GPT-2 LanguageModel bound to
# the global `model`, which no visible cell provides.
model.transformer.h

In [None]:
import loading_utils
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE
from functools import partial

# Residual-stream SAEs (hook_resid_pre) for all 12 GPT-2 layers, via sae_lens.
saes = []
for i in range(12):
    saes.append(SAE.from_pretrained(
        release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
        sae_id = f"blocks.{i}.hook_resid_pre", # won't always be a hook point
        device = 'cuda'
    ))  # returns SAE, config, sparsity

In [None]:
# Attention (hook_z) SAEs for all 12 GPT-2 layers.
attn_saes = []
for i in range(12):
    # Collect into `attn_saes` so `saes` keeps only the residual-stream SAEs.
    attn_saes.append(SAE.from_pretrained(
        release = "gpt2-small-hook-z-kk",
        sae_id = f"blocks.{i}.hook_z", # won't always be a hook point
        device = 'cuda'
    ))  # returns SAE, config, sparsity

In [None]:
from huggingface_hub import hf_hub_download
from safetensors import safe_open
import matplotlib.pyplot as plt

# For each attention-output SAE layer, download its `sparsity.safetensors`
# and plot a histogram of the stored 'sparsity' tensor.
# NOTE(review): whether these are raw firing frequencies or log10 values is
# not visible here - check before interpreting the x-axis.
for i in range(12):
    repo = "jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs"
    filename = f"v5_128k_layer_{i}/sparsity.safetensors"

    path = hf_hub_download(repo, filename)

    tensor_dict = dict()

    with safe_open(path, 'pt') as f:
        for k in f.keys():
            tensor_dict[k] = f.get_tensor(k)

    # plot sparsity values
    plt.title(f"Sparsity histogram for layer {i}")
    plt.hist(tensor_dict['sparsity'].flatten().cpu().numpy(), bins=100)
    plt.show()

In [None]:
# MLP-output SAEs for all 12 GPT-2 layers.
mlp_saes = []
for i in range(12):
    # Collect into `mlp_saes` so `saes` keeps only the residual-stream SAEs.
    mlp_saes.append(SAE.from_pretrained(
        release = "gpt2-small-mlp-tm",
        sae_id = f"blocks.{i}.hook_mlp_out", # won't always be a hook point
        device = 'cuda'
    ))  # returns SAE, config, sparsity

In [None]:
# saes[0] is the value SAE.from_pretrained returned for layer 0 (noted above as
# "SAE, config, sparsity"); [0] selects the SAE, then inspect W_enc's shape.
saes[0][0].W_enc.data.shape