In [1]:
from sae_vis import parse_feature_data, SaeVisConfig, SaeVisLayoutConfig, Column, ActsHistogramConfig, LogitsTableConfig, LogitsHistogramConfig, SequencesConfig
from datasets import load_dataset
from modelling_sae import normalize_dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from modelling_sae import BasicSAE, RICA
import tqdm.notebook as tqdm

In [2]:
dataset = load_dataset("Elriggs/openwebtext-100k", split="train[-10000:]")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
tokenizer.pad_token_id = 0

MAX_LEN = 64
NUM_FEATS = 64

dataset = dataset.map(
    lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=MAX_LEN),
    batched=True,
)
dataset.set_format(type="torch", columns=["input_ids"])

data = torch.tensor(dataset["input_ids"]).cuda()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  data = torch.tensor(dataset["input_ids"]).cuda()


In [3]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m").cuda()

sae = BasicSAE(1024, 8192, 0).cuda()
sae_state_dict = torch.load("models/epoch_9/basic_sae_1.05e-04.pt")
sae.load_state_dict(sae_state_dict)

model.eval()

saved_feat_acts = []
saved_model_acts = []
saved_final_hidden_states = []

is_nonnegative = True

with torch.no_grad():
    for i in tqdm.tqdm(range(0, data.shape[0], 128)):
        batch = data[i:i+128]

        hidden_states = model(batch, output_hidden_states=True).hidden_states
        model_acts = hidden_states[8]
        final_hidden_states = hidden_states[-1]
        all_feat_acts = sae(model_acts.reshape(-1, 1024))[2].reshape(-1, MAX_LEN, 8192)[:, :, :NUM_FEATS]
        
        if not is_nonnegative:
            all_feat_acts = torch.cat([all_feat_acts, -all_feat_acts], dim=-1)
            all_feat_acts = torch.clamp(all_feat_acts, min=0)
        
        saved_feat_acts.append(all_feat_acts)
        saved_model_acts.append(model_acts)
        saved_final_hidden_states.append(final_hidden_states)

model_acts = torch.cat(saved_model_acts, dim=0)
all_feat_acts = torch.cat(saved_feat_acts, dim=0)
final_hidden_states = torch.cat(saved_final_hidden_states, dim=0)

feature_idxs = list(range(NUM_FEATS))

if not is_nonnegative:
    feature_idxs = list(range(NUM_FEATS * 2))

if is_nonnegative:
    feature_resid_dir = normalize_dict(sae.unembed)[:NUM_FEATS]
elif not is_nonnegative:
    feature_resid_dir = normalize_dict(sae.embed)[:NUM_FEATS]
    feature_resid_dir = torch.cat([feature_resid_dir, -feature_resid_dir], dim=0)

output_embed = model.embed_out.weight.data.T

feat_tables_cfg = SaeVisLayoutConfig(
    columns = [
        Column(ActsHistogramConfig(), SequencesConfig(stack_mode='stack-none')),
    ],
    height=750
)
cfg = SaeVisConfig(
    feature_centric_layout=feat_tables_cfg,
)

  0%|          | 0/79 [00:00<?, ?it/s]

In [4]:
sae_vis_data = parse_feature_data(
    tokens=data,
    feature_indices=feature_idxs,
    all_feat_acts=all_feat_acts,
    feature_resid_dir=feature_resid_dir,
    # feature_out_dir=feature_resid_dir,
    all_resid_post=final_hidden_states,
    W_U=output_embed,
    cfg=cfg,
)[0]

class PointlessStupidity:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

sae_vis_data.model = PointlessStupidity(tokenizer)

In [5]:
sae_vis_data.save_feature_centric_vis("sae_vis.html")

In [6]:
print(data.shape)

torch.Size([10000, 64])
