## Sparse Information Encoders (SIEs)

### A. Data and Model Setup

In [1]:
import transformer_lens

In [2]:
from datasets import load_dataset

In [3]:
# Load the 'train' split of the 'monology/pile-uncopyrighted' dataset via
# Hugging Face `datasets`; the result is bound to the notebook-global `data`.
data = load_dataset('monology/pile-uncopyrighted', split='train')

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/30 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

KeyboardInterrupt: 

In [None]:
# Write the loaded dataset to 'data/pile.hf', then switch the dataset's output
# format so that the 'tokens' column is returned as torch tensors.
# NOTE(review): this relies on `data` having a 'tokens' column, but no
# tokenization step is visible in the cells above -- confirm where it is added.
data.save_to_disk('data/pile.hf')
data.set_format(type='torch', columns=['tokens'])

In [None]:
# Pull the full 'tokens' column out of the dataset into the notebook-global
# `all_tokens`; the bare expression below shows its shape as the cell output.
all_tokens = data["tokens"]
all_tokens.shape

In [None]:
import torch
import torch.nn as nn
import einops

In [None]:
# Split every row of `all_tokens` into x=8 consecutive chunks of seq_len=128
# tokens, one chunk per output row (so each input row must hold 8 * 128 = 1024
# tokens for the pattern to match).
all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128)
# Overwrite the first token of every chunk with the tokenizer's BOS token id.
# NOTE(review): `model` is not defined in any cell visible above this one --
# confirm it is created before this cell runs.
all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
# Put the chunks (rows) into a random order.
all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])]
# NOTE(review): the file name says "c4_code_2b" although the dataset loaded
# above is the Pile -- confirm the intended path before relying on it.
torch.save(all_tokens_reshaped, 'data/c4_code_2b_tokens_reshaped.pt')

In [None]:
def shuffle_data(all_tokens):
    """Return `all_tokens` with its entries along dim 0 in a random order.

    The order is one permutation drawn with `torch.randperm`; the result is
    produced by indexing with that permutation and returned to the caller.
    Prints "Shuffled data" before the permutation is drawn.
    """
    print("Shuffled data")
    row_order = torch.randperm(all_tokens.shape[0])
    return all_tokens[row_order]

In [None]:
class Buffer():
    """Shuffled store of activations that is refilled as it is consumed.

    The buffer is a `(cfg["buffer_size"], cfg["d_mlp"])` bfloat16 CUDA tensor.
    `refresh()` runs the model over successive slices of the token tensor and
    writes the cached activations into the buffer; `next()` hands out
    `cfg["batch_size"]` rows at a time and triggers a refresh once roughly the
    first half of the buffer has been served.

    Config keys read: "buffer_size", "d_mlp", "buffer_batches",
    "model_batch_size", "batch_size".

    NOTE(review): `refresh()` reads the names `all_tokens`, `model` and
    `utils` from the enclosing (notebook-global) scope. `model` and `utils`
    are not defined in the cells visible here (`utils` is presumably
    `transformer_lens.utils`) -- confirm they exist before instantiating,
    because `__init__` calls `refresh()` immediately.
    """

    def __init__(self, cfg):
        """Allocate the buffer on the GPU and fill it for the first time."""
        self.buffer = torch.zeros((cfg["buffer_size"], cfg["d_mlp"]), dtype=torch.bfloat16, requires_grad=False).cuda()
        self.cfg = cfg
        # Index of the next unread row of `all_tokens`; only ever advances.
        self.token_pointer = 0
        # True until the first refresh has started; that refresh does a full fill.
        self.first = True
        self.refresh()

    @torch.no_grad()
    def refresh(self):
        """Write fresh activations into the buffer from row 0, then shuffle it.

        The first call loops over `cfg["buffer_batches"]` in steps of
        `cfg["model_batch_size"]`; later calls loop over half that count, so
        they replace only the leading part of the buffer before the whole
        buffer is re-shuffled.

        NOTE(review): nothing here handles `token_pointer` running past the
        end of `all_tokens`, or a write that would run past the end of the
        buffer -- confirm the config and dataset sizes make both impossible.
        """
        self.pointer = 0
        with torch.autocast("cuda", torch.bfloat16):
            if self.first:
                num_batches = self.cfg["buffer_batches"]
            else:
                num_batches = self.cfg["buffer_batches"]//2
            self.first = False
            for _ in range(0, num_batches, self.cfg["model_batch_size"]):
                tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]]
                # Only the activation named by get_act_name("post", 0) is cached
                # (presumably the layer-0 MLP post-activation -- TODO confirm).
                _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
                # Flatten to one row of width d_mlp per activation vector.
                mlp_acts = cache[utils.get_act_name("post", 0)].reshape(-1, self.cfg["d_mlp"])
                self.buffer[self.pointer: self.pointer+mlp_acts.shape[0]] = mlp_acts
                self.pointer += mlp_acts.shape[0]
                self.token_pointer += self.cfg["model_batch_size"]

        # Reading restarts at row 0 of the freshly shuffled buffer.
        self.pointer = 0
        self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).cuda()]

    @torch.no_grad()
    def next(self):
        """Return the next `cfg["batch_size"]` rows of the buffer.

        When the read pointer passes `buffer.shape[0] // 2 - batch_size` the
        buffer is refreshed before returning. In that case the returned batch
        is a copy; otherwise it is a slice of the buffer.
        """
        batch_size = self.cfg["batch_size"]
        out = self.buffer[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size
        if self.pointer > self.buffer.shape[0] // 2 - batch_size:
            # `out` is a view into the buffer's storage, and refresh() assigns
            # new activations into that storage in place starting at row 0.
            # Copy the batch first so the caller gets the rows it asked for
            # rather than whatever refresh() wrote over them.
            out = out.clone()
            self.refresh()
        return out