In [1]:
import os

# Point the Hugging Face caches at a project-local directory.
# The paths are relative to the current working directory, and the variables
# are set before huggingface_hub is imported below.
os.environ["HF_HOME"] = "./.hf_cache"
os.environ["HF_HUB_CACHE"] = "./.hf_cache/hub"

from huggingface_hub import login

# Log in to Hugging Face; make sure you have been granted access to the
# Mistral-7B-v0.1 model.
# NOTE(review): os.getenv returns None when HUGGINGFACE_TOKEN is unset —
# confirm how login() behaves in that case (e.g. in a non-interactive run).
login(
    token=os.getenv("HUGGINGFACE_TOKEN")
)

In [2]:
import numpy as np
import pandas as pd
import pickle

from tqdm import tqdm
import torch
from sae_lens import SAE
from transformers import AutoTokenizer

from itertools import islice

# from huggingface_hub import hf_hub_download, notebook_login

In [3]:
def get_gpt2_sae(layer):
    """Load the pretrained GPT-2 small residual-stream SAE for one layer.

    Args:
        layer: Block index interpolated into the SAE id
            ``blocks.{layer}.hook_resid_pre``.

    Returns:
        The first element of whatever ``SAE.from_pretrained`` returns, loaded
        with ``device=DEVICE`` (the module-level device string).
    """
    hook_name = f"blocks.{layer}.hook_resid_pre"  # won't always be a hook point
    loaded = SAE.from_pretrained(
        release="gpt2-small-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
        sae_id=hook_name,
        device=DEVICE,
    )
    return loaded[0]

def get_mistral_sae(layer):
    """Load the pretrained Mistral-7B residual-stream SAE for one layer.

    Uses the ``SAE`` class imported at the top of the notebook, exactly like
    ``get_gpt2_sae``; the former function-scope re-import was redundant.

    Args:
        layer: Block index interpolated into the SAE id
            ``blocks.{layer}.hook_resid_pre``.

    Returns:
        The first element of whatever ``SAE.from_pretrained`` returns, loaded
        with ``device=DEVICE`` (the module-level device string).
    """
    return SAE.from_pretrained(
        release="mistral-7b-res-wg",  # see other options in sae_lens/pretrained_saes.yaml
        sae_id=f"blocks.{layer}.hook_resid_pre",  # won't always be a hook point
        device=DEVICE,
    )[0]

def get_cluster_activations_gpt2(sparse_sae_activations, sae_neurons_in_cluster, decoder_vecs):
    """Rebuild per-token activations using only the SAE features of one cluster.

    The three parallel sequences in ``sparse_sae_activations`` are walked in
    lockstep. Consecutive entries that share a token index are treated as one
    token, and for every token with at least one feature in the cluster the
    contributions ``value * decoder_vecs[feature]`` are summed.

    Args:
        sparse_sae_activations: Mapping providing the parallel sequences
            ``"sparse_sae_values"``, ``"sparse_sae_indices"`` and
            ``"all_token_indices"``.
        sae_neurons_in_cluster: Container of SAE feature ids. Only ``in`` is
            used, so passing a ``set`` keeps the membership test cheap.
        decoder_vecs: 2-D array indexed by feature id; its second dimension
            determines the width of each reconstruction.

    Returns:
        ``(activations, token_indices)``. ``activations`` has shape
        ``(n_tokens, decoder_vecs.shape[1])`` (``n_tokens`` may be 0).
        ``token_indices[k]`` is the token index of row ``k`` *plus one*, i.e.
        an exclusive end position: ``tokens[idx - 10:idx]`` ends on the token
        that produced the row. This keeps the convention existing callers
        compensate for, but no longer relies on token indices being
        consecutive.
    """
    hidden_dim = decoder_vecs.shape[1]
    all_activations = []
    all_token_indices = []

    current_token = None
    current_activations = np.zeros(hidden_dim)
    updated = False
    for sae_value, sae_index, token_index in zip(
        sparse_sae_activations["sparse_sae_values"],
        sparse_sae_activations["sparse_sae_indices"],
        sparse_sae_activations["all_token_indices"],
    ):
        if current_token is None:
            current_token = token_index
        elif token_index != current_token:
            # A new token starts: emit the finished one if the cluster fired on it.
            if updated:
                all_activations.append(current_activations)
                all_token_indices.append(current_token + 1)
            updated = False
            current_token = token_index
            current_activations = np.zeros(hidden_dim)
        if sae_index in sae_neurons_in_cluster:
            updated = True
            current_activations += sae_value * decoder_vecs[sae_index]

    # The last token has no successor to trigger the flush above.
    if updated:
        all_activations.append(current_activations)
        all_token_indices.append(current_token + 1)

    if not all_activations:
        # np.stack raises on an empty list; return an empty, well-shaped array.
        return np.zeros((0, hidden_dim)), all_token_indices
    return np.stack(all_activations), all_token_indices


def get_cluster_activations_mistral(
    sparse_sae_activations,
    sae_neurons_in_cluster,
    decoder_vecs,
    sample_limit,
    max_indices=1e8,
):
    """Rebuild per-token activations using only the SAE features of one cluster.

    The three parallel sequences in ``sparse_sae_activations`` are walked in
    lockstep (at most ``max_indices`` entries). Consecutive entries that share
    a token index are treated as one token, and for every token with at least
    one feature in the cluster the contributions
    ``value * decoder_vecs[feature]`` are summed. Iteration stops early once
    ``sample_limit`` tokens have been collected.

    Args:
        sparse_sae_activations: Mapping providing the parallel sequences
            ``"sparse_sae_values"``, ``"sparse_sae_indices"`` and
            ``"all_token_indices"``.
        sae_neurons_in_cluster: Container of SAE feature ids. Only ``in`` is
            used, so passing a ``set`` keeps the membership test cheap.
        decoder_vecs: 2-D array indexed by feature id; its second dimension
            determines the width of each reconstruction.
        sample_limit: Stop as soon as this many tokens have been collected.
        max_indices: Maximum number of sparse entries to read; converted with
            ``int`` so float literals such as ``1e8`` are accepted.

    Returns:
        ``(activations, token_indices)``. ``activations`` has shape
        ``(n_tokens, decoder_vecs.shape[1])`` (``n_tokens`` may be 0) and
        ``token_indices[k]`` is the token index whose entries produced row
        ``k``.
    """
    max_indices = int(max_indices)
    hidden_dim = decoder_vecs.shape[1]

    # Look each sequence up once (with an .npz archive every lookup may reload
    # the array) and size the progress bar by what will actually be read.
    sae_values = sparse_sae_activations["sparse_sae_values"]
    sae_indices = sparse_sae_activations["sparse_sae_indices"]
    token_indices = sparse_sae_activations["all_token_indices"]
    n_available = min(len(sae_values), len(sae_indices), len(token_indices))
    n_entries = min(n_available, max_indices)

    current_token = None
    current_activations = np.zeros(hidden_dim)
    all_activations = []
    all_token_indices = []
    updated = False
    for sae_value, sae_index, token_index in tqdm(
        islice(zip(sae_values, sae_indices, token_indices), 0, max_indices),
        total=n_entries,
        disable=False,
    ):
        if current_token is None:
            current_token = token_index
        elif token_index != current_token:
            # A new token starts: emit the finished one if the cluster fired on it.
            if updated:
                all_activations.append(current_activations)
                # Record the finished token itself; `token_index - 1` is only
                # equal to it when token indices are consecutive.
                all_token_indices.append(current_token)
                if len(all_activations) >= sample_limit:
                    break
            updated = False
            current_token = token_index
            current_activations = np.zeros(hidden_dim)
        if sae_index in sae_neurons_in_cluster:
            updated = True
            current_activations += sae_value * decoder_vecs[sae_index]
    else:
        # Reached only when the loop was not stopped by sample_limit. The last
        # token has no successor to trigger the flush above; emit it, unless
        # max_indices cut the stream short and its entries may be incomplete.
        stream_complete = n_available <= max_indices
        if updated and stream_complete and len(all_activations) < sample_limit:
            all_activations.append(current_activations)
            all_token_indices.append(current_token)

    if not all_activations:
        # np.stack raises on an empty list; return an empty, well-shaped array.
        return np.zeros((0, hidden_dim)), all_token_indices
    return np.stack(all_activations), all_token_indices

In [4]:
# Pick the torch backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

Using device: mps


In [5]:
# GPT-2 run configuration.
layer = 7  # block whose SAE is loaded; the activations file below is the layer-7 dump
sample_limit = 20_000  # cap applied to the reconstructed "years" samples
activations_file = "data/sae_activations_big_layer-7.npz"

In [6]:
# Load the GPT-2 SAE for `layer` and pull its decoder matrix out as a NumPy array.
ae = get_gpt2_sae(layer=layer)
# .detach() instead of the legacy .data attribute, matching the Mistral section below.
decoder_vecs = ae.W_dec.detach().cpu().numpy()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
sparse_activations = np.load(activations_file)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [7]:
# SAE feature ids that make up the "years" cluster.
cluster_years = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]

# Reconstruct activations from that cluster only, then keep the first `sample_limit` samples.
years_result = get_cluster_activations_gpt2(
    sparse_activations,
    set(cluster_years),
    decoder_vecs,
)
reconstructions_years = years_result[0][:sample_limit]
token_indices_years = years_result[1][:sample_limit]

# One decoded string per token id in the activations file.
token_strs_years = tokenizer.batch_decode(sparse_activations['all_tokens'])

In [8]:
# Up to 10 tokens of context per sample. The slice ends at token_index (exclusive)
# rather than token_index + 1: the indices coming out of the extraction appear to be
# shifted by one, so this compensates for it.
contexts_years = [
    token_strs_years[max(0, end - 10):end]
    for end in token_indices_years
]

In [9]:
# subselect tokens corresponding to years

def _parse_year(context, first_year=1900, last_year=1999):
    """Return the year encoded by the last token of `context`, or None.

    A token counts as a year when, after stripping whitespace, it consists of
    ASCII digits only and its value lies in [first_year, last_year].
    str.isdigit() alone is not enough: it accepts characters such as "²"
    that int() cannot parse.
    """
    if not context:  # empty context window: nothing to inspect
        return None
    token = context[-1].strip()
    if not (token.isascii() and token.isdigit()):
        return None
    year = int(token)
    return year if first_year <= year <= last_year else None


parsed_years = [_parse_year(context) for context in contexts_years]
# mask_years has one entry per sample; years only holds the accepted ones, in order.
mask_years = np.array([year is not None for year in parsed_years], dtype=bool)
years = np.array([year for year in parsed_years if year is not None], dtype=int)

In [10]:
# Keep only the reconstructions whose token was recognised as a year.
X_years = reconstructions_years[mask_years]

In [11]:
# The reconstructions lie in the span of the 10 cluster decoder vectors, i.e. in an
# (at most) 10-dimensional subspace of the 768-dimensional activation space, so we
# express them in an orthonormal basis of that subspace (obtained from a QR
# decomposition) to store them space-efficiently.

D_years = decoder_vecs[np.array(cluster_years), :]
Q, _ = np.linalg.qr(D_years.T)
# Project from the unreduced reconstructions rather than from X_years itself, so that
# re-running this cell does not try to project already-reduced data.
X_years = reconstructions_years[mask_years, :] @ Q

In [12]:
# Wrap the projected representations in a pandas DataFrame (one row per sample).
X_years = pd.DataFrame(data=X_years)

In [13]:
# Make sure the output directory exists; to_csv does not create missing directories.
os.makedirs("representations", exist_ok=True)

# Write the representations as a bare numeric table (no index column, no header row).
X_years.to_csv(
    "representations/years_reprs.csv",
    index=False,
    header=False,
)

In [25]:
# save labels

# Make sure the output directory exists; to_csv does not create missing directories.
os.makedirs("representations", exist_ok=True)

# Labels are stored row-for-row next to the representations, so they must line up.
assert len(years) == len(X_years), "year labels and representations are misaligned"

pd.DataFrame({
    "label": years
}).to_csv(
    "representations/years_labels.csv",
    index=False,
    header=False
)

Mistral

In [15]:
# Path of the activations dump used for the Mistral section.
# NOTE: this rebinds `activations_file`, which the GPT-2 section above set to its own
# file; re-running GPT-2 cells after this one would pick up this path.
activations_file = "data/sae_activations_big_layer-8.npz"

In [16]:
# Mistral-7B tokenizer; its EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token

# Layer-8 Mistral SAE and its decoder matrix as a NumPy array.
sae = get_mistral_sae(layer=8)
decoder_weights = sae.W_dec.detach()
decoder_vecs = decoder_weights.cpu().numpy()

# Sparse activations dump, plus one token string per stored token id.
sparse_activations = np.load(activations_file)
all_token_ids = sparse_activations["all_tokens"]
token_strs = tokenizer.convert_ids_to_tokens(all_token_ids)

In [17]:
# SAE feature ids that make up the "days of the week" cluster.
cluster_days = [
    398, 35234, 31484, 54166, 52464, 23936, 20629,
]

# SAE feature ids that make up the "months of the year" cluster.
cluster_months = [
    7411, 33259, 49189, 46031, 9117, 57916, 26027,
    16820, 41121, 23434, 39714, 59285, 47182, 22809,
    17555, 52568, 8934, 16406, 63163, 15477, 54144,
]

In [18]:
# Both clusters are extracted with the same limits: stop after 4,000 samples,
# reading at most 1e9 sparse entries.
extraction_limits = dict(sample_limit=4_000, max_indices=1e9)

reconstructions_days, token_indices_days = get_cluster_activations_mistral(
    sparse_activations,
    set(cluster_days),
    decoder_vecs,
    **extraction_limits,
)
reconstructions_months, token_indices_months = get_cluster_activations_mistral(
    sparse_activations,
    set(cluster_months),
    decoder_vecs,
    **extraction_limits,
)

 33%|███▎      | 325405029/1000000000 [01:28<03:04, 3660752.00it/s]
  5%|▌         | 53230130/1000000000 [00:11<03:19, 4753569.91it/s]


In [19]:
# Normalised token text -> weekday position (singular and plural forms).
days_of_week = {
    "monday": 0,
    "mondays": 0,
    "tuesday": 1,
    "tuesdays": 1,
    "wednesday": 2,
    "wednesdays": 2,
    "thursday": 3,
    "thursdays": 3,
    "friday": 4,
    "fridays": 4,
    "saturday": 5,
    "saturdays": 5,
    "sunday": 6,
    "sundays": 6,
}

# Normalised token text -> month position.
months_of_year = {
    "january": 0,
    "february": 1,
    "march": 2,
    "april": 3,
    "may": 4,
    "june": 5,
    "july": 6,
    "august": 7,
    "september": 8,
    "october": 9,
    "november": 10,
    "december": 11,
}


def _select_labelled_tokens(token_indices, vocabulary):
    """Match the tokens at `token_indices` against the keys of `vocabulary`.

    Each token string is normalised by removing the "▁" marker (presumably
    the tokenizer's word-boundary symbol), lower-casing and stripping
    whitespace.

    Returns:
        (labels, mask): `mask` has one boolean per entry of `token_indices`;
        `labels` holds the normalised text of the accepted tokens, in order.
    """
    labels = []
    mask = []
    for token_i in token_indices:
        token = token_strs[token_i].replace("▁", "").lower().strip()
        is_known = token in vocabulary
        mask.append(is_known)
        if is_known:
            labels.append(token)
    return np.array(labels, dtype=str), np.array(mask, dtype=bool)


# Both label collections are NumPy arrays (previously only `months` was converted).
days, mask_days = _select_labelled_tokens(token_indices_days, days_of_week)
months, mask_months = _select_labelled_tokens(token_indices_months, months_of_year)


In [20]:
# Orthonormal basis (via QR) for the span of the "days" cluster decoder vectors.
D_days = decoder_vecs[np.array(cluster_days), :]
Q_days, _ = np.linalg.qr(D_days.T)

# Keep the reconstructions whose token is a weekday name and express them in that basis.
selected_days = reconstructions_days[mask_days, :]
X_days = selected_days @ Q_days

In [27]:
# Make sure the output directory exists; to_csv does not create missing directories.
os.makedirs("representations", exist_ok=True)

X_days = pd.DataFrame(X_days)
# Labels are stored row-for-row next to the representations, so they must line up.
assert len(days) == len(X_days), "day labels and representations are misaligned"

X_days.to_csv(
    "representations/days_reprs.csv",
    index=False,
    header=False,
)
# save labels
pd.DataFrame({
    "label": days
}).to_csv(
    "representations/days_labels.csv",
    index=False,
    header=False
)

In [22]:
X_months = reconstructions_months[mask_months, :]

# remove outlier
# NOTE: outlier_id is a row position within the masked samples, so it is only
# meaningful for the exact data and extraction settings used above.
outlier_id = 696
# `months` is created in an earlier cell and pruned here. Only prune it while it is
# still aligned with the unpruned X_months, so that re-running this cell does not
# delete a second label and shift labels against rows.
if len(months) == len(X_months):
    months = np.delete(months, outlier_id)
X_months = np.delete(X_months, outlier_id, axis=0)
assert len(months) == len(X_months), "month labels and representations are misaligned"

# Express the reconstructions in an orthonormal basis (via QR) of the span of the
# "months" cluster decoder vectors.
D_months = decoder_vecs[np.array(cluster_months), :]
Q_months, _ = np.linalg.qr(D_months.T)
X_months = X_months @ Q_months

In [26]:
# Make sure the output directory exists; to_csv does not create missing directories.
os.makedirs("representations", exist_ok=True)

X_months = pd.DataFrame(X_months)
# Labels are stored row-for-row next to the representations, so they must line up.
assert len(months) == len(X_months), "month labels and representations are misaligned"

X_months.to_csv(
    "representations/months_reprs.csv",
    index=False,
    header=False,
)
# save labels
pd.DataFrame({
    "label": months
}).to_csv(
    "representations/months_labels.csv",
    index=False,
    header=False
)