In [None]:
import argparse
import gc
import json
import math
import os
from collections import defaultdict

import torch as t
import numpy as np
from einops import rearrange
from tqdm import tqdm, trange
from nnsight import LanguageModel

from dictionary_learning import AutoEncoder

from activation_utils import SparseAct
from attribution import patching_effect, jvp
from circuit_plotting import plot_circuit, plot_circuit_posaligned
from loading_utils import load_examples, load_examples_nopair
import histogram_aggregator as ha


%load_ext autoreload
%autoreload 2

In [None]:
# Histogram dumps written by the circuit-discovery runs (Pythia-70m and GPT-2 small).
# NOTE: t.load unpickles its input — only load files produced by this project.
PYTHIA_HIST_PATH = './circuits/NeelNanda_pile-10k_dict10_node0.1_edge0.01_n9990_aggnone_threshTrue_methodig_prunefirst-layer-sink_modelEleutherAI_pythia-70m-deduped.hist.pt'
GPT2_HIST_PATH = './circuits/NeelNanda_pile-10k_dictgpt2_node8e-06_edge8e-06_n9990_aggnone_threshFalse_methodig_prunefirst-layer-sink_modelgpt2.hist.pt'
dp = t.load(PYTHIA_HIST_PATH)
dg = t.load(GPT2_HIST_PATH)

In [None]:
def get_activation_stats(model_str, layer, component, dset, seq_len=64, batch_size=5, simulate=False,
                         n_samples=1000, n_bins=200, min_power=-10, max_power=6):
    """Histogram SAE feature activations for one submodule of a language model.

    Loads ``model_str`` with nnsight plus the SAE for (``layer``, ``component``), runs the
    first ``n_samples`` entries of ``dset`` through it in batches, and accumulates two
    histograms over token positions ``seq_len//2`` onward:

    * log10(|activation|) of every non-zero SAE feature activation, ``n_bins`` bins over
      ``[min_power, max_power]``;
    * log10(nnz), where nnz is the number of non-zero features at a token position,
      ``n_bins`` bins over ``[log10(1), log10(feat_size)]``.

    Args:
        model_str: 'gpt2' selects the ``jbloom/GPT2-Small-OAI-v5-*`` SAEs and
            ``model.transformer.h``; any other value loads dictionaries from
            ``dictionaries/pythia-70m-deduped/`` and uses ``model.gpt_neox.layers``.
        layer: transformer block index.
        component: 'resid', 'attn' or 'mlp'.
        dset: indexable sequence of raw text strings.
        seq_len: texts are truncated to this many tokens; shorter texts are skipped.
        batch_size: number of texts tokenized per step.
        simulate: if True, only tokenize/filter (no forward passes); histograms stay zero.
        n_samples: number of leading ``dset`` entries to consider (formerly hard-coded 1000).
        n_bins, min_power, max_power: histogram layout; defaults match get_hist_settings.

    Returns:
        (hist, nnz_hist, feat_size): two numpy arrays of length ``n_bins`` and the SAE width.

    Raises:
        ValueError: if ``layer`` is not a valid block index for the model.
    """
    model = LanguageModel(model_str, device_map='cuda', dispatch=True)
    if model_str == 'gpt2':
        # Residual-stream SAEs live in the "32k ... post" repos, attn/mlp ones in "128k ... out".
        if component == 'resid':
            n_feats, loc, ext = '32k', 'post', '.pt'
        else:
            n_feats, loc, ext = '128k', 'out', ''
        repo = f"jbloom/GPT2-Small-OAI-v5-{n_feats}-{component}-{loc}-SAEs"
        sae = AutoEncoder.from_hf(repo, f"v5_{n_feats}_layer_{layer}{ext}/sae_weights.safetensors", device="cuda")
        blocks = model.transformer.h
    else:
        sae = AutoEncoder.from_pretrained(
            f'dictionaries/pythia-70m-deduped/{component}_out_layer{layer}/10_32768/ae.pt',
            device='cuda'
        )
        blocks = model.gpt_neox.layers

    # Pick the hooked submodule; previously an out-of-range layer left it unbound (NameError).
    submodule = None
    for i, t_layer in enumerate(blocks):
        if i != layer:
            continue
        if component == 'resid':
            submodule = t_layer
        elif model_str == 'gpt2':
            submodule = getattr(t_layer, component)
        elif component == 'attn':
            submodule = t_layer.attention
        else:
            submodule = t_layer.mlp
        break
    if submodule is None:
        raise ValueError(f"layer {layer} is out of range for model {model_str!r}")

    feat_size = sae.encoder.weight.shape[0]

    # Probe once: if submodule.output is a tuple, its first element is what gets encoded.
    with model.trace("_"):
        output_submod = submodule.output.save()
    is_tuple = isinstance(output_submod.value, tuple)

    hist = t.zeros(n_bins).to('cuda')
    nnz_hist = t.zeros(n_bins).to('cuda')  # log10(nnz) over [log10(1), log10(feat_size)]

    for i in trange(0, n_samples, batch_size):
        entries = dset[i:i+batch_size]
        # Keep only texts that fill the whole window so every batch row has seq_len tokens.
        valid_entries = []
        for e in entries:
            encoded = model.tokenizer(e, return_tensors='pt', max_length=seq_len, truncation=True).to('cuda')['input_ids']
            if encoded.shape[1] == seq_len:
                valid_entries.append(encoded)
        if not valid_entries or simulate:
            continue
        batch = t.cat(valid_entries, dim=0)

        with model.trace(batch), t.no_grad():
            x = submodule.output
            if is_tuple:
                x = x[0]
            f = sae.encode(x).save()

        if f.ndim == 2:
            f = f.unsqueeze(0)

        # Only positions seq_len//2 onward are histogrammed.
        f_late = f[:, seq_len//2:, :]
        nnz = (f_late != 0).sum(dim=2).flatten()  # non-zero feature count per token position
        abs_f = abs(f_late)
        # t.histc ignores values outside [min, max]; log10(0) = -inf, so positions with no
        # active feature are not counted in nnz_hist.
        nnz_hist += t.histc(t.log10(nnz), bins=n_bins, min=np.log10(1), max=np.log10(feat_size))
        hist += t.histc(t.log10(abs_f[abs_f != 0]), bins=n_bins, min=min_power, max=max_power)

    return hist.cpu().numpy(), nnz_hist.cpu().numpy(), feat_size

In [None]:
import pickle

# Cache of activation statistics keyed by (model, component, layer); entries already present
# are skipped by the collection loop. NOTE: pickle.load can execute arbitrary code — only
# load cache files this notebook wrote.
save_path = 'distribs-64-nnz-log10.pkl'
results = {}
if os.path.exists(save_path):
    with open(save_path, 'rb') as f:
        results = pickle.load(f)

In [None]:
from datasets import load_dataset

dset = load_dataset("NeelNanda/pile-10k")['train']['text']

# Compute every missing (model, component, layer) entry; the cache is re-pickled after each
# one so an interrupted run resumes where it stopped.
layers_per_model = {'gpt2': 12, 'EleutherAI/pythia-70m-deduped': 6}
for model, n_model_layers in layers_per_model.items():
    for component in ('attn', 'resid', 'mlp'):
        for layer in range(n_model_layers):
            stats_key = (model, component, layer)
            if stats_key in results:
                continue
            print(model, component, layer)
            results[stats_key] = get_activation_stats(model, layer, component, dset, batch_size=16)
            t.cuda.empty_cache()
            gc.collect()
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)

In [None]:
# plot results for GPT2, do a 12 x 3 plot where column 1 is resid, column 2 is attn, column 3 is mlp

import matplotlib.pyplot as plt
import numpy as np

def plot_hist(hist, bins, ax, xlabel, title):
    """Draw one histogram as a line on ``ax`` and annotate its median.

    Args:
        hist: array of counts.
        bins: x value for each entry of ``hist`` (same length).
        ax: matplotlib Axes to draw on.
        xlabel: x-axis label.
        title: title prefix; the title also reports log10 of the total count.

    A dashed red vertical line marks the median bin. The annotation shows
    ``median +- std``, where std is the count-weighted standard deviation of ``bins``
    about their weighted mean. A histogram with zero total count is drawn with an
    "(empty)" title and no statistics, instead of producing NaN values.
    """
    value_hist_color = 'blue'
    ax.set_xlabel(xlabel, color=value_hist_color)
    ax.set_ylabel('Frequency', color=value_hist_color)
    ax.plot(bins, hist, color=value_hist_color)
    ax.tick_params(axis='x', colors=value_hist_color)
    ax.tick_params(axis='y', colors=value_hist_color)

    total = hist.sum()
    if total <= 0:
        ax.set_title(f'{title} : (empty)')
        return

    # median: first bin where the cumulative count reaches half of the total
    median_idx = (hist.cumsum() >= total / 2).nonzero()[0][0]
    median_val = bins[median_idx]
    # count-weighted standard deviation of the bin values
    mean = (bins * hist).sum() / total
    std = np.sqrt(((bins - mean)**2 * hist).sum() / total)
    ax.set_title(f'{title} : (log10(total) = {np.log10(total):.2f})')
    # vertical line at the median
    ax.axvline(median_val, color='r', linestyle='--')
    ax.text(median_val+0.5, hist.max(), f'{median_val:.2f} +- {std:.2f}', color='r')

def get_hist_settings(hist, n_feats, hist_or_nnz='hist', thresh=None, as_sparsity=False):
    """Prepare a histogram and its x-axis values for plotting.

    Args:
        hist: numpy array of counts; its length is the number of bins.
        n_feats: dictionary size; sets the NNZ axis range (unused for 'hist').
        hist_or_nnz: 'hist' for a histogram of log10(activation magnitude) over [-10, 6];
            anything else for a histogram of log10(nnz) over [0, log10(n_feats)].
        thresh: 'hist' only. Without ``as_sparsity``, an absolute activation threshold: bins
            below log10(thresh) are zeroed. With ``as_sparsity``, a fraction: only bins holding
            roughly the top ``thresh`` share of the total count are kept.
        as_sparsity: interpret ``thresh`` as a fraction (see above).

    Returns:
        (hist, bins, xlabel). For NNZ, ``bins`` are linear NNZ values and trailing empty bins
        are trimmed (an all-empty histogram is returned untrimmed).

    Raises:
        ValueError: if ``thresh`` is given for an NNZ histogram.
    """
    n_bins = len(hist)  # histograms are produced with 200 bins by default
    if hist_or_nnz == 'hist':
        min_val, max_val = -10, 6
        xlabel = 'log10(Activation magnitude)'
        bins = np.linspace(min_val, max_val, n_bins)
        if thresh is not None:
            if as_sparsity:
                percentile_hist = np.cumsum(hist) / hist.sum()
                thresh_loc = np.searchsorted(percentile_hist, 1 - thresh)
            else:
                # bins are log10 values, so the threshold must be log10 too (was natural log)
                thresh_loc = np.searchsorted(bins, np.log10(thresh))
            hist = hist.copy()
            # Keep the bin straddling the threshold; clamp so thresh_loc == 0 does not turn
            # into hist[:-1] (which would wipe everything but the last bin).
            hist[:max(thresh_loc - 1, 0)] = 0
        return hist, bins, xlabel

    if thresh is not None:
        raise ValueError("thresh can only be computed for hist")
    xlabel = 'NNZ'
    bins = 10 ** np.linspace(0, np.log10(n_feats), n_bins)
    nonzero_idx = np.nonzero(hist)[0]
    if nonzero_idx.size:
        end = nonzero_idx.max() + 1
        bins, hist = bins[:end], hist[:end]
    return hist, bins, xlabel

def get_hist_and_nfeats_activations(layer, component, model_str, hist_or_nnz, results):
    """Fetch one cached activation-stats histogram together with the SAE width.

    ``results[(model_str, component, layer)]`` is indexed as (value_hist, nnz_hist, feat_size):
    element 0 is returned when ``hist_or_nnz == 'hist'``, element 1 otherwise.

    Returns:
        (hist, feat_size)
    """
    entry = results[model_str, component, layer]
    which = 0 if hist_or_nnz == 'hist' else 1
    return entry[which], entry[2]

def get_hist_for_node_effect(layer, component, model_str, hist_or_nnz, results):
    hist_type = 'acts' if hist_or_nnz == 'hist' else 'nnz'
    key = f'node_{hist_type}'
    if model_str == 'gpt2':
        mod_name = f'.transformer.h.{layer}'
    else:
        mod_name = f'.gpt_neox.layers.{layer}'

    match component:
        case 'resid':
            mod_name += ''
        case 'attn':
            mod_name += '.attention'
        case 'mlp':
            mod_name += '.mlp'
    hist = results[key][mod_name]
    if model_str == 'gpt2':
        feat_size = 32768 if component == 'resid' else 131072

    feat_size = results[model_str, component, layer][2]
    return hist, feat_size


def plot_model_hists(model_str, n_layers, hist_getter, hist_or_nnz='hist', thresh=None, as_sparsity=False):
    """Plot an ``n_layers`` x 3 grid of histograms (columns: resid, attn, mlp) for one model.

    Args:
        model_str: model name, used for lookups and subplot titles.
        n_layers: number of rows (transformer blocks) to plot.
        hist_getter: either a callable ``(layer, component, model_str, hist_or_nnz) ->
            (hist, feat_size)``, or a results dict keyed by ``(model_str, component, layer)``
            with values indexed as (value_hist, nnz_hist, feat_size). The dict form is
            accepted because the notebook passes ``results`` directly, which previously
            failed with "dict object is not callable".
        hist_or_nnz, thresh, as_sparsity: forwarded to get_hist_settings.
    """
    if callable(hist_getter):
        get_hist = hist_getter
    else:
        stats = hist_getter

        def get_hist(layer, component, model, kind):
            entry = stats[model, component, layer]
            return (entry[0] if kind == 'hist' else entry[1]), entry[2]

    components = ['resid', 'attn', 'mlp']
    # squeeze=False keeps axs 2-D even when n_layers == 1
    fig, axs = plt.subplots(n_layers, len(components), figsize=(15, 3.6*n_layers), squeeze=False)

    for layer in range(n_layers):
        for col, component in enumerate(components):
            hist, feat_size = get_hist(layer, component, model_str, hist_or_nnz)
            hist, bins, xlabel = get_hist_settings(hist, feat_size, hist_or_nnz, thresh, as_sparsity)
            plot_hist(hist, bins, axs[layer, col], xlabel, f'{model_str} {component} layer {layer}')

    plt.tight_layout()
    plt.show()


def plot_model_edge_hists(model_str, result_dict, hist_or_nnz, thresh=None, as_sparsity=False, n_feats=None):
    """Plot every edge histogram of ``result_dict`` in a three-column grid.

    Args:
        model_str: model name, used only in subplot titles.
        result_dict: dict holding 'edge_acts' (read when hist_or_nnz == 'hist') or
            'edge_nnz'. Each maps an edge name to its histogram; nested
            ``{up: {down: hist}}`` mappings (the layout update_dict builds) are flattened
            into (up, down) edges.
        hist_or_nnz: 'hist' for activation-magnitude histograms, anything else for NNZ.
        thresh, as_sparsity: forwarded to get_hist_settings.
        n_feats: dictionary size for the NNZ x-axis; required unless hist_or_nnz == 'hist'.
            (The original body referenced undefined ``feat_size``/``layer``/``component``.)

    Raises:
        ValueError: if an NNZ plot is requested without ``n_feats``.
    """
    if hist_or_nnz != 'hist' and n_feats is None:
        raise ValueError("n_feats is required to plot NNZ edge histograms")
    hist_type = 'acts' if hist_or_nnz == 'hist' else 'nnz'

    edges = []
    for name, value in result_dict[f'edge_{hist_type}'].items():
        if isinstance(value, dict):
            edges.extend(((name, down), down_hist) for down, down_hist in value.items())
        else:
            edges.append((name, value))
    if not edges:
        return

    n_cols = 3
    n_rows = math.ceil(len(edges) / n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 3.6*n_rows), squeeze=False)
    flat_axs = axs.ravel()

    for ax, (edge_name, hist) in zip(flat_axs, edges):
        label = ' -> '.join(map(str, edge_name)) if isinstance(edge_name, tuple) else str(edge_name)
        hist, bins, xlabel = get_hist_settings(hist, n_feats, hist_or_nnz, thresh, as_sparsity)
        plot_hist(hist, bins, ax, xlabel, f'{model_str} {label}')
    # hide grid cells left over in the last row
    for ax in flat_axs[len(edges):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()



In [None]:
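        # Reference only: the keys HistAggregator.load presumably reads from its input dict
        # (copied from its loader; verify against histogram_aggregator). update_dict keeps the
        # dump's node_*/edge_* keys and build_nnz_max adds nnz_max, act_min_max and model_str.
        # self.node_nnz = data['node_nnz']
        # self.node_acts = data['node_acts']
        # self.edge_nnz = data['edge_nnz']
        # self.edge_acts = data['edge_acts']
        # self.nnz_max = data['nnz_max']
        # self.act_min, self.act_max = data['act_min_max']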

In [None]:
def update_dict(d):
    """Re-key a histogram dump using normalized submodule paths.

    For every group in ``d`` (e.g. 'node_acts' or 'edge_nnz'), each entry key is passed through
    ``ha.normalize_path``. Tuple keys ``(up, down)`` become a nested mapping
    ``{norm(up): {norm(down): hist}}``; any other key maps directly as ``norm(path) -> hist``.
    Groups with no entries are omitted from the result. Histogram values are not copied.
    """
    converted = {}
    for group, entries in d.items():
        for path, hist in entries.items():
            bucket = converted.setdefault(group, {})
            if not isinstance(path, tuple):
                bucket[ha.normalize_path(path)] = hist
                continue
            upstream = ha.normalize_path(path[0])
            downstream = ha.normalize_path(path[1])
            bucket.setdefault(upstream, {})[downstream] = hist
    return converted


def get_n_feats(model_str, component):
    """Return the SAE dictionary size for ``component`` of ``model_str``.

    GPT-2 uses 32768 features for the residual stream and 131072 for every other
    component; any other model uses 32768.
    """
    uses_wide_sae = model_str == 'gpt2' and component != 'resid'
    return 131072 if uses_wide_sae else 32768

def build_nnz_max(d_new, model_str):
    """Add plotting metadata to a converted histogram dict, in place.

    Sets ``d_new['nnz_max']`` to ``{node: log10(n_feats)}`` for every node in
    ``d_new['node_acts']``. The component is taken as the part of the node name before the
    first '_' and sized with get_n_feats. Also stores ``act_min_max = (-10, 10)`` and
    ``model_str``. Returns None.
    """
    nnz_max = d_new['nnz_max'] = {}
    for node_name in d_new['node_acts']:
        component = node_name.split('_')[0]
        nnz_max[node_name] = np.log10(get_n_feats(model_str, component))
    d_new.update(act_min_max=(-10, 10), model_str=model_str)

# Normalize paths and attach plotting metadata, one model at a time.
dg_new = update_dict(dg)
build_nnz_max(dg_new, 'gpt2')

dp_new = update_dict(dp)
build_nnz_max(dp_new, 'EleutherAI/pythia-70m-deduped')

In [None]:
# Wrap each converted histogram dict in its own HistAggregator.
pythia_hist = ha.HistAggregator()
pythia_hist.load(dp_new)

gpt_hist = ha.HistAggregator()
gpt_hist.load(dg_new)

In [None]:
resid0_node_acts = pythia_hist.node_acts['resid_0']
resid0_node_acts[720:780]  # entries 720..779 of the Pythia resid_0 node-activation table

In [None]:
# NOTE(review): every other HistAggregator.plot call in this notebook passes
# (n_layers, 'nodes'|'edges', 'acts'|'nnz'); this one also passes the model name, which
# build_nnz_max already stores under 'model_str'. Confirm against the current
# HistAggregator.plot signature — this cell may be stale.
pythia_hist.plot(6, 'EleutherAI/pythia-70m-deduped', 'nodes', 'acts')

In [None]:
# Entry 750 of a 1500-point grid over [-10, 10], mapped back from log10 scale.
log10_grid_750 = np.linspace(-10, 10, 1500)[750]
10 ** log10_grid_750

In [None]:
# Node whose activation table has the lowest index of a non-zero entry.
min(gpt_hist.node_acts.items(), key=lambda item: item[1].nonzero()[0].min())[0]

In [None]:
grid_750 = np.linspace(start=-10, stop=10, num=1500)[750]
grid_750

In [None]:
# GPT-2 (12 layers) node plot via HistAggregator.plot; the signature is not visible here —
# arguments presumably (n_layers, 'nodes'|'edges', 'acts'|'nnz').
gpt_hist.plot(12, 'nodes', 'acts')

In [None]:
# GPT-2 (12 layers) edge plot via HistAggregator.plot; arguments presumably
# (n_layers, 'nodes'|'edges', 'acts'|'nnz') — signature not visible here.
gpt_hist.plot(12, 'edges', 'acts')

In [None]:
# Pythia-70m (6 layers) edge plot, 'nnz' variant; HistAggregator.plot arguments presumably
# (n_layers, 'nodes'|'edges', 'acts'|'nnz') — signature not visible here.
pythia_hist.plot(6, 'edges', 'nnz')

In [None]:
# Load the Pythia dump straight from disk (load is also given a path here, not a dict).
new_pythia_hist_path = './circuits/NeelNanda_pile-10k_dict10_node0.1_edge0.01_n9990_aggnone_threshTrue_methodig_prunefirst-layer-sink_modelEleutherAI_pythia-70m-deduped.hist.pt'
new_pythia_hist = ha.HistAggregator()
new_pythia_hist.load(new_pythia_hist_path)

In [None]:
# Rebuild gpt_hist from the on-disk dump (replaces the aggregator built from dg_new above).
gpt2_hist_path = './circuits/NeelNanda_pile-10k_dictgpt2_node8e-06_edge8e-06_n9990_aggnone_threshFalse_methodig_prunefirst-layer-sink_modelgpt2.hist.pt'
gpt_hist = ha.HistAggregator()
gpt_hist.load(gpt2_hist_path)

In [None]:
# Same GPT-2 node plot as earlier, now on the aggregator loaded from the on-disk dump.
gpt_hist.plot(12, 'nodes', 'acts')

In [None]:
# Pythia-70m (6 layers) node plot, 'nnz' variant, from the aggregator loaded from disk;
# HistAggregator.plot arguments presumably (n_layers, 'nodes'|'edges', 'acts'|'nnz').
new_pythia_hist.plot(6, 'nodes', 'nnz')

In [None]:
# plot_model_hists calls hist_getter(layer, component, model_str, hist_or_nnz), so bind the
# cached `results` into get_hist_and_nfeats_activations instead of passing the dict itself.
plot_model_hists('gpt2', 12, lambda *args: get_hist_and_nfeats_activations(*args, results), 'hist', thresh=8e-6, as_sparsity=True)

In [None]:
# Bind `results` into a getter: plot_model_hists calls hist_getter(layer, component, model_str, hist_or_nnz).
plot_model_hists('gpt2', 12, lambda *args: get_hist_and_nfeats_activations(*args, results), 'nnz')

In [None]:
# Bind `results` into a getter: plot_model_hists calls hist_getter(layer, component, model_str, hist_or_nnz).
plot_model_hists('EleutherAI/pythia-70m-deduped', 6, lambda *args: get_hist_and_nfeats_activations(*args, results), 'hist')

In [None]:
# Bind `results` into a getter: plot_model_hists calls hist_getter(layer, component, model_str, hist_or_nnz).
plot_model_hists('EleutherAI/pythia-70m-deduped', 6, lambda *args: get_hist_and_nfeats_activations(*args, results), 'nnz')

In [None]:
from huggingface_hub import hf_hub_download
from safetensors import safe_open
import matplotlib.pyplot as plt

# Sparsity tensors published with the GPT-2 small 128k attn-out SAEs, one file per layer.
SPARSITY_REPO = "jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs"

for i in range(12):
    path = hf_hub_download(SPARSITY_REPO, f"v5_128k_layer_{i}/sparsity.safetensors")

    # Only the 'sparsity' tensor is plotted, so read just that one from the file.
    with safe_open(path, 'pt') as f:
        sparsity = f.get_tensor('sparsity')

    fig, ax = plt.subplots()
    ax.hist(sparsity.flatten().cpu().numpy(), bins=100)
    ax.set_title(f"Sparsity histogram for layer {i}")
    ax.set_xlabel('sparsity value')
    ax.set_ylabel('count')
    plt.show()
    plt.close(fig)  # avoid accumulating 12 open figures