In [12]:
# Disabled scratch cell: the entire example below is wrapped in a single
# triple-quoted string literal, so none of it executes. The cell's only effect
# is to evaluate that string, which the notebook then echoes as its output.
'''
import torch
import blobfile as bf
import transformer_lens
from sparse_model import Autoencoder
from tqdm import tqdm
import pickle as pkl

device = "mps"

# Load the autoencoder
layer_index = 0  # in range(12)
autoencoder_input = ["mlp_post_act", "resid_delta_mlp"][0]
filename = f"az://openaipublic/sparse-autoencoder/gpt2-small/{autoencoder_input}/autoencoders/{layer_index}.pt"

with bf.BlobFile(filename, mode="rb") as f:
    print("Inside the blobfile")
    print()
    state_dict = torch.load(f)
    with open("state_dict.pkl", "wb") as f:
        pkl.dump(state_dict, f)
    autoencoder = Autoencoder.from_state_dict(state_dict)



# Extract neuron activations with transformer_lens
model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
prompt = "This is an example of a prompt that"
print(prompt)
print()
tokens = model.to_tokens(prompt)  # (1, n_tokens)
print(model.to_str_tokens(tokens))
with torch.no_grad():
    print("Inside grad")
    logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)
if autoencoder_input == "mlp_post_act":
    input_tensor = activation_cache[f"blocks.{layer_index}.mlp.hook_post"]  # (n_tokens, n_neurons)
elif autoencoder_input == "resid_delta_mlp":
    input_tensor = activation_cache[f"blocks.{layer_index}.hook_mlp_out"]  # (n_tokens, n_residual_channels)

# Encode neuron activations with the autoencoder
device = next(model.parameters()).device
autoencoder.to(device)
with torch.no_grad():
    print(f"The input tensor is {input_tensor}")
    print()
    latent_activations = autoencoder.encode(input_tensor)  # (n_tokens, n_latents)
    print(f"The latent activation is {latent_activations}")
    print()
    original_activations = autoencoder.decode(latent_activations)  # (n_tokens, n_neurons)
    print(f"The original activation is {original_activations}")
    print()
'''

'\nimport torch\nimport blobfile as bf\nimport transformer_lens\nfrom sparse_model import Autoencoder\nfrom tqdm import tqdm\nimport pickle as pkl\n\ndevice = "mps"\n\n# Load the autoencoder\nlayer_index = 0  # in range(12)\nautoencoder_input = ["mlp_post_act", "resid_delta_mlp"][0]\nfilename = f"az://openaipublic/sparse-autoencoder/gpt2-small/{autoencoder_input}/autoencoders/{layer_index}.pt"\n\nwith bf.BlobFile(filename, mode="rb") as f:\n    print("Inside the blobfile")\n    print()\n    state_dict = torch.load(f)\n    with open("state_dict.pkl", "wb") as f:\n        pkl.dump(state_dict, f)\n    autoencoder = Autoencoder.from_state_dict(state_dict)\n\n\n\n# Extract neuron activations with transformer_lens\nmodel = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)\nprompt = "This is an example of a prompt that"\nprint(prompt)\nprint()\ntokens = model.to_tokens(prompt)  # (1, n_tokens)\nprint(model.to_str_tokens(tokens))\nwith torch.no_grad():\n 

In [80]:
from imports import *
from transformer_lens import HookedTransformer, utils
import torch
import torch as t
import einops
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from functools import partial
from datasets import load_dataset
import numpy as np
from jaxtyping import Float
from transformer_lens import ActivationCache
from pathlib import Path
import torch.nn as nn
import pprint
import json 
import torch.nn.functional as F

# DEVICE = "cpu"

# Maps the "enc_dtype" config string to the torch dtype used for the parameters.
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
# Directory where save()/load() keep "<version>.pt" and "<version>_cfg.json".
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")


class AutoEncoder(nn.Module):
    """Single-hidden-layer sparse autoencoder with an L1 sparsity penalty.

    The config dict must provide:
        dict_size  -- number of hidden (dictionary) features
        act_size   -- width of the activations being reconstructed
        l1_coeff   -- weight of the L1 penalty on the hidden activations
        enc_dtype  -- key into ``DTYPES`` selecting the parameter dtype
        seed       -- value passed to ``torch.manual_seed`` before init
        device     -- device the module is moved to after construction
    Extra keys are ignored by the model but are kept and written out by
    ``save()``.
    """

    def __init__(self, cfg):
        super().__init__()
        # Keep a private copy so save() serialises the config this instance was
        # built with, rather than whatever a notebook-global `cfg` holds later.
        self.cfg = dict(cfg)

        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        # Note: this reseeds the *global* torch RNG so the init is reproducible.
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        # Start with every decoder row (one per dictionary feature) at unit norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x, mask=None):
        """Encode ``x``, optionally rescale one position's features, and decode.

        Args:
            x: activations whose last dimension is ``act_size``. When ``mask``
                is given, ``x`` must be 3-D with at least two entries along
                dim 1, because the mask is applied at index ``-2`` of dim 1.
            mask: optional tensor broadcastable against ``acts[:, -2, :]``.
                The hidden activations at that position are multiplied by it
                in place; an all-ones mask (or ``None``) leaves them unchanged.

        Returns:
            ``(loss, x_reconstruct, acts, l2_loss, l1_loss)`` where
            ``loss = l2_loss + l1_loss``.
        """
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        # Shape trace kept from the original notebook; it is part of the
        # recorded cell output.
        print(acts.shape)
        if mask is not None:
            acts[:, -2, :] = acts[:, -2, :] * mask
        x_reconstruct = acts @ self.W_dec + self.b_dec
        # Squared error summed over the feature dim, averaged over dim 0 only.
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        # L1 penalty is summed over *all* elements (no averaging).
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalise decoder rows and strip the gradient component along them.

        Requires ``self.W_dec.grad`` to be populated (i.e. call after
        ``backward()``).
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        # Projection of each row's gradient onto that row's unit direction.
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        """Return the next free integer checkpoint version in ``SAVE_DIR``.

        Only files named ``<integer>.pt`` are counted; config JSONs and any
        other files are ignored. Returns 0 when the directory is missing or
        holds no checkpoints.
        """
        if not SAVE_DIR.is_dir():
            return 0
        versions = []
        for file in SAVE_DIR.iterdir():
            prefix = file.name.split(".")[0]
            # Match on the real suffix: a substring test on the full path also
            # matches e.g. "<n>_cfg.json" inside any directory containing "pt".
            if file.suffix == ".pt" and prefix.isdecimal():
                versions.append(int(prefix))
        return 1 + max(versions) if versions else 0

    def save(self):
        """Write the weights and this instance's config under a new version."""
        SAVE_DIR.mkdir(parents=True, exist_ok=True)
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(self.cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        """Rebuild an autoencoder from ``SAVE_DIR`` for the given version."""
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "r") as f:
            cfg = json.load(f)
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        # Map tensors onto the configured device so a checkpoint written on
        # another device type can still be deserialised.
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt"), map_location=cfg["device"]))
        return self

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads the saved autoencoder from HuggingFace. 
        
        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.

        If device_override is given it replaces the "device" entry of the
        downloaded config before the model is constructed.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47
        
        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self



# load gpt2-small
# model = HookedTransformer.from_pretrained("gpt2-small").to('mps')
# NOTE(review): LanguageModel, GPT2Tokenizer and DEVICE are not defined in this
# cell; presumably they arrive via `from imports import *` -- confirm.
model = LanguageModel("openai-community/gpt2", device_map=DEVICE)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

# HuggingFace repo holding the pretrained gpt2-small sparse autoencoders.
SAE_REPO = "jacobcd52/gpt2-small-sparse-autoencoders"


def load_gpt2_small_sae(point, layer, cfg):
    """Download one pretrained SAE checkpoint and load it into an AutoEncoder.

    Args:
        point: hook-point name used in the checkpoint filename (e.g. "resid_pre").
        layer: layer index used in the checkpoint filename.
        cfg: AutoEncoder config dict; its sizes must match the checkpoint.

    Returns:
        ``(encoder, state_dict)`` -- the populated AutoEncoder and the raw
        state dict that was downloaded.
    """
    state_dict = utils.download_file_from_hf(SAE_REPO, f"gpt2-small_6144_{point}_{layer}.pt", force_is_torch=True)
    encoder = AutoEncoder(cfg)
    encoder.load_state_dict(state_dict)
    return encoder, state_dict


# Both residual-stream SAEs share the same hyperparameters.
cfg = {
    "dict_size": 6144,
    "act_size": 768,
    "l1_coeff": 0.001,
    "enc_dtype": "fp32",
    "seed": 0,
    "device": DEVICE,
    "model_batch_size": 1028,
}

point, layer = "resid_pre", 10
encoder_resid_pre_10, dic = load_gpt2_small_sae(point, layer, cfg)

# `point`, `layer` and `dic` end up describing the layer-11 checkpoint, exactly
# as they did when the two loads were written out by hand.
point, layer = "resid_pre", 11
encoder_resid_pre_11, dic = load_gpt2_small_sae(point, layer, cfg)

data = load_dataset("stas/openwebtext-10k", split="train")
# tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
# tokenized_data = tokenized_data.shuffle(22)

# from transformer_lens import utils

example_prompt = "After Jack and Mary went to the store, Jack gave a bottle of milk to"
example_answer = " Mary"
# utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [76]:
# Read the layer-0 MLP-output SAE weights from a local checkpoint, placing the
# tensors directly on the working device.
dic = torch.load(
    "gpt2-small-sparse-autoencoders/gpt2-small_6144_mlp_out_0.pt",
    map_location=torch.device(DEVICE),
)

# Hyperparameters for the autoencoder that will receive those weights.
cfg = {
    "dict_size": 6144, "act_size": 768,
    "l1_coeff": 0.001,
    "enc_dtype": "fp32",
    "seed": 0,
    "device": DEVICE,
    "model_batch_size": 1,
}

# Build the model, then populate it; the load result is the cell's last
# expression, so the notebook displays the key-matching report.
encoder_mlp_out_0 = AutoEncoder(cfg)
encoder_mlp_out_0.load_state_dict(dic)

<All keys matched successfully>

In [78]:
# Experiment: run GPT-2 on a prompt, replace the output of transformer block 0
# with the sparse autoencoder's reconstruction of it, and decode the model's
# top prediction for the final position.
text = "The Eiffel Tower is in the city of"
a = model.tokenizer.encode(text, return_tensors = 'pt')
# Despite its name, `a_len` holds the result of tokenize(text); only its
# length is printed below.
a_len = model.tokenizer.tokenize(text)
# The recorded run printed 1 here and 10 on the next line, i.e. len(a) is a
# leading batch dimension rather than the token count.
print(len(a))
print(len(a_len))
# AutoEncoder.forward multiplies the hidden activations at position -2 by this
# mask, so an all-ones mask leaves them unchanged. 6144 matches cfg["dict_size"].
mask = torch.ones((1,1, 6144))

with model.trace(a) as tracer:
    # First element of block 0's output.
    # NOTE(review): the encoder is named `mlp_out_0` but is fed the block's
    # output[0] here -- confirm this is the intended hook point.
    acts = model.transformer.h[0].output[0]
    # loss, x_reconstruct, acts, l2_loss, l1_loss = encoder_resid_pre_10(acts, mask)
    loss, x_reconstruct, acts, l2_loss, l1_loss = encoder_mlp_out_0(acts, mask)
    # Saved for inspection; `b` is not read again in this cell.
    b = model.transformer.h[0].output.save()
    # Overwrite the block output in place with the SAE reconstruction.
    model.transformer.h[0].output[0][:,:,:] = x_reconstruct
    # Argmax over the last dim of the lm_head output, saved so it can be read
    # after the trace.
    output = model.lm_head.output.argmax(dim = -1).save()

# Decode the predicted token id at the last position of the first sequence.
model.tokenizer.decode(output[0][-1])

1
10
torch.Size([1, 10, 6144])


' Paris'

In [62]:
# Sanity check: multiplying elementwise by a vector of ones is a no-op.
a = torch.full((10,), 2.0, dtype=torch.float32)
print(a)
b = torch.ones(10, dtype=torch.float32)
print(a * b)

tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
