In [None]:
from sae import Sae

# Load one sparse autoencoder per hookpoint for the first two layers,
# keyed by the hookpoint name it was requested for.
saes = {
    hookpoint: Sae.load_from_hub(
        "EleutherAI/sae-pythia-70m-deduped-32k", hookpoint=hookpoint
    )
    for hookpoint in ("layers.0", "layers.1")
}

In [None]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name shared by the tokenizer and the language model
PYTHIA_CHECKPOINT = "EleutherAI/pythia-70m-deduped"

# Request the training split in streaming mode
dataset = load_dataset(
    "EleutherAI/the_pile_deduplicated",
    split="train",
    streaming=True,
)

# The tokenizer reuses its end-of-sequence token as the padding token
tokenizer = AutoTokenizer.from_pretrained(PYTHIA_CHECKPOINT)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(PYTHIA_CHECKPOINT)

In [None]:
import torch
from tqdm import tqdm
from collections import Counter
from types import SimpleNamespace

# Process parameters
batch_size = 4
max_batches = 5
max_samples = batch_size * max_batches

# Initialize counters and tracking variables
# Counts how often each layer-1 latent index appears among the top
# activations of a real (non-padding) token.
common_indices_layer_1 = Counter()
batch_texts = []
total_samples_processed = 0

# Per-batch SAE outputs for each layer, consumed by the later cells.
# Each entry exposes `.top_acts` and `.top_indices` with a leading dimension
# of 1, followed by the real tokens of the batch and the per-token top values.
all_layer_0_latent_acts = []  # Store all data from layer.0
all_layer_1_latent_acts = []  # Store all data from layer.1

# Set model to evaluation mode
model.eval()

# Process dataset with progress bar
for sample in tqdm(dataset, total=max_samples, desc="Processing samples"):
    batch_texts.append(sample["text"])
    total_samples_processed += 1

    # Process in batches
    if len(batch_texts) == batch_size:
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # True for real tokens, False for the padding added by `padding=True`.
        # Padding positions are excluded below so they neither inflate the
        # frequency counter nor end up in the regression data.
        token_mask = inputs["attention_mask"].bool()

        # Both the forward pass and the SAE encoding run without gradient
        # tracking; otherwise every stored SAE output would keep its autograd
        # graph (and the tensors saved for backward) alive.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

            # Unmasked encoder outputs of the most recently processed batch.
            latent_acts_batch = []

            # NOTE(review): Hugging Face models usually return the embedding
            # output as hidden_states[0], so this pairs the "layers.i" SAE with
            # hidden_states[i] -- confirm this matches the activations the SAEs
            # were trained on.
            for (layer_name, sae), hidden_state in zip(
                saes.items(), outputs.hidden_states
            ):
                latent_acts = sae.encode(hidden_state)  # Encode the hidden state using SAE
                latent_acts_batch.append(latent_acts)

                # Boolean indexing keeps only real tokens and merges the batch
                # and sequence dimensions; unsqueeze(0) restores a leading
                # dimension so downstream code can keep iterating batch -> token.
                # The same mask is used for every layer, so token i of layer 0
                # still lines up with token i of layer 1.
                real_token_acts = SimpleNamespace(
                    top_acts=latent_acts.top_acts[token_mask].unsqueeze(0),
                    top_indices=latent_acts.top_indices[token_mask].unsqueeze(0),
                )

                if layer_name == "layers.1":  # Focus on activations for layer 1
                    common_indices_layer_1.update(
                        real_token_acts.top_indices.flatten().tolist()
                    )
                    all_layer_1_latent_acts.append(real_token_acts)
                elif layer_name == "layers.0":  # Collect layer.0 activations
                    all_layer_0_latent_acts.append(real_token_acts)

        # Clear batch after processing
        batch_texts = []

    if total_samples_processed >= max_samples:
        break

In [4]:
def extract_flattened_top_acts_indices(layer_acts):
    """Flatten the outer two levels of each encoder output into flat lists.

    Args:
        layer_acts: Iterable of objects exposing ``top_acts`` and
            ``top_indices``, each supporting ``.tolist()``.

    Returns:
        A pair ``(all_acts, all_indices)``. Each list holds, in order, the
        second-level entries of every ``top_acts`` / ``top_indices`` nested
        list; entry ``i`` of one list corresponds to entry ``i`` of the other.
    """
    flat_acts, flat_indices = [], []

    for encoded in layer_acts:
        nested_acts = encoded.top_acts.tolist()
        nested_indices = encoded.top_indices.tolist()

        # Walk the first level in lockstep, then pair up the second level so
        # both output lists always grow by the same number of entries.
        for row_acts, row_indices in zip(nested_acts, nested_indices):
            paired = list(zip(row_acts, row_indices))
            flat_acts.extend(act for act, _ in paired)
            flat_indices.extend(ind for _, ind in paired)

    return flat_acts, flat_indices


def filter_by_neuron(
    front_top_acts, front_top_indices, back_top_acts, back_top_indices, target_neuron
):
    """Keep only the tokens whose back-layer indices contain ``target_neuron``.

    Args:
        front_top_acts: Per-token activations of the front layer.
        front_top_indices: Per-token indices of the front layer.
        back_top_acts: Per-token activations of the back layer.
        back_top_indices: Per-token indices of the back layer.
        target_neuron: Index that must be present in a token's back-layer
            indices for that token to be kept.

    Returns:
        A 4-tuple ``(front_acts, front_indices, back_acts, back_indices)``
        restricted to the kept tokens, in their original order. Front-layer
        entries are taken from the same position as the matching back-layer
        token.
    """
    # Positions (with their back-layer data) at which the target is present.
    matches = [
        (position, token_acts, token_indices)
        for position, (token_acts, token_indices) in enumerate(
            zip(back_top_acts, back_top_indices)
        )
        if target_neuron in token_indices
    ]

    kept_front_acts = [front_top_acts[position] for position, _, _ in matches]
    kept_front_indices = [front_top_indices[position] for position, _, _ in matches]
    kept_back_acts = [token_acts for _, token_acts, _ in matches]
    kept_back_indices = [token_indices for _, _, token_indices in matches]

    return kept_front_acts, kept_front_indices, kept_back_acts, kept_back_indices


def update_acts(front_acts, front_indices):
    """Expand per-token sparse activations onto one shared, sorted index axis.

    Args:
        front_acts: Per-token lists of activation values.
        front_indices: Per-token lists of the indices those values belong to.

    Returns:
        ``(updated_acts, unified_indices)`` where ``unified_indices`` is the
        sorted list of every index seen in ``front_indices`` and each row of
        ``updated_acts`` holds the token's value for each unified index, or
        ``0`` where the token does not list that index.
    """
    seen = set()
    for token_indices in front_indices:
        seen.update(token_indices)
    unified_indices = sorted(seen)

    dense_rows = []
    for token_acts, token_indices in zip(front_acts, front_indices):
        value_by_index = dict(zip(token_indices, token_acts))
        row = [
            value_by_index[neuron] if neuron in value_by_index else 0
            for neuron in unified_indices
        ]
        dense_rows.append(row)

    return dense_rows, unified_indices


def extract_Y(back_acts, back_indices, target_neuron):
    """Collect the activation of ``target_neuron`` for every token that has it.

    Args:
        back_acts: Per-token lists of activation values.
        back_indices: Per-token lists of the indices those values belong to.
        target_neuron: Index whose activation is extracted.

    Returns:
        A list with one value per token whose indices contain
        ``target_neuron``, in token order; other tokens are skipped.
    """
    targets = []
    for token_acts, token_indices in zip(back_acts, back_indices):
        if target_neuron not in token_indices:
            continue
        position = token_indices.index(target_neuron)
        targets.append(token_acts[position])
    return targets

In [None]:
# Show only the most frequent layer-1 latents; printing the whole Counter
# would list every index that was ever seen.
print(common_indices_layer_1.most_common(10))

# Fail with a readable message instead of an IndexError when nothing was counted.
assert common_indices_layer_1, (
    "common_indices_layer_1 is empty: no layer-1 SAE activations were collected"
)

# The most frequent layer-1 latent becomes the regression target.
target_neuron = common_indices_layer_1.most_common(1)[0][0]

In [None]:
# Flatten the collected encoder outputs of each layer into per-token lists
# (layer 0 first, then layer 1).
(
    (layer_0_top_acts_list, layer_0_top_indices_list),
    (layer_1_top_acts_list, layer_1_top_indices_list),
) = (
    extract_flattened_top_acts_indices(layer_acts)
    for layer_acts in (all_layer_0_latent_acts, all_layer_1_latent_acts)
)

# Keep only the tokens whose layer-1 indices contain the target neuron,
# together with the layer-0 data at the same token positions.
filtered_0_acts, filtered_0_indices, filtered_1_acts, filtered_1_indices = (
    filter_by_neuron(
        front_top_acts=layer_0_top_acts_list,
        front_top_indices=layer_0_top_indices_list,
        back_top_acts=layer_1_top_acts_list,
        back_top_indices=layer_1_top_indices_list,
        target_neuron=target_neuron,
    )
)

# Spread the layer-0 activations over one shared, sorted index axis
updated_0_acts, unified_indices_list = update_acts(filtered_0_acts, filtered_0_indices)

# X: layer-0 activations per kept token; Y: the target neuron's layer-1 activation for the same tokens
X, Y = updated_0_acts, extract_Y(filtered_1_acts, filtered_1_indices, target_neuron)

print(len(X), len(Y))

In [None]:
from pysr import PySRRegressor

# Initialize the symbolic regression model.
# It gets its own name so the language model bound to `model` in the earlier
# cells is not overwritten; the processing cell calls `model.eval()` and
# `model(**inputs, ...)` and would fail on re-run if `model` were the regressor.
symbolic_regressor = PySRRegressor(
    niterations=100,  # Increase if necessary
    binary_operators=["+", "-", "*", "/"],  # Operations to use
    unary_operators=["cos", "sin", "exp", "log"],
)

# Fit an expression that predicts the target neuron's layer-1 activation (Y)
# from the layer-0 activations of the same tokens (X).
symbolic_regressor.fit(X, Y)

symbolic_regressor