In [None]:
from __future__ import annotations
"""
TODO(Adriano) after getting some plots for the blobs post-SAE in GPT2, it's important to check whether these
SAEs are actually any good. Unfortunately, I have really bad FVUs, MSEs, etc... It's also unclear if a error
norm of 30 is normal. People seem not to be reporting this very well and it's deeply annoying.
"""
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union, Any
import torch
import json
import torch.nn as nn
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
from torch.utils.data import DataLoader
import dotenv
from transformers import AutoTokenizer
from transformer_lens.utils import tokenize_and_concatenate
import tqdm
import torch
import gc
import itertools
import einops
from jaxtyping import Float, Int
import re
import matplotlib.pyplot as plt
import pydantic
from transformer_lens import ActivationCache
from transformer_lens.components import TransformerBlock, LayerNormPre
import shutil

# Load our own imports, etc...
from sans_sae_lib.utils import plot_cosine_kernel, plot_all_nc2_top_pcs, plot_all_nc2_top_pcs_errs
from sans_sae_lib.schemas import ExtractedActivations, FlattenedExtractedActivations

dotenv.load_dotenv()
assert "CUDA_VISIBLE_DEVICES" in os.environ, "CUDA_VISIBLE_DEVICES is not set"
assert len(os.environ["CUDA_VISIBLE_DEVICES"].strip()) > 0, "CUDA_VISIBLE_DEVICES is empty"

print("="*50 + " [Loading Dataset] " + "="*50) # DEBUG
# dataset = load_dataset("openwebtext", split="train", trust_remote_code=True)
dataset = load_dataset("stas/openwebtext-10k", split="train", trust_remote_code=True) # Smaller version
tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=tokenizer,  # type: ignore
    streaming=True,
    # NOTE: all these have context 128
    max_length=128, #sae.cfg.context_size,
    add_bos_token=True, #sae.cfg.prepend_bos,
)
print("="*50 + " [Loading Model] " + "="*50) # DEBUG
# TODO(Adriano) do this below...
# for d in extractor.cfg_dics:
#     print(d["context_size"]) # NOTE: you should picke the smallest of these...


# Shorten the dataset for testing more quickly
dataset_size = 300 # XXX make this longer please
token_dataset_short = token_dataset[:dataset_size]['tokens']
dataset_length = token_dataset_short.shape[0]
sequence_length = token_dataset_short.shape[1]

In [2]:
class ResidAndLn2Comparer:
    """
    A class that basically automates the task of taking in a dataset, tokenizing it, running it through
    GPT2 using HookedSAETransformer with JBloom's SAEs.

    It is meant for measuring/plotting the impacts on the activations on the model.
    The outputs of this module are/should be:
    1. Plots of histograms for PCA applied on
        - The residual stream activations
        - The ln2 activations
        - The SAE processed residual stream activations
        - The SAE processed ln2 activations
        (there are 2D histograms for any pair of PCs in the top 10 PCs)
    2. Plots of the MSE, FVU, etc... for the SAEs at each location (including also at ln2) as 1D histograms
    3. Plots of the cosine similarity between the PCs pre and post-sae intervention (a 2D heatmap)
        (this also will include a )
    4. Plots of the cosine similarity between input datapoints and output datapoints pre and post-SAE
        intervention in the form of a 1D histogram.
    5. SAE error norm/mse/fvu (or in the future any general function)
    6. Plots for each K out of a set of top ks of the distribution of reconstructed activations
        (this is critical). To do this with non-top-k-trained SAEs we just force the latents to be top-K'ed
        but ideally we should also support some top-k SAEs.
    
    TODO(Adriano) later we want to see how the distributions change as we keep applying SAEs.
    """
    def __init__(self):
        # Load the model and set some basic settins
        self.model_name = "gpt2"
        self.sae_release = "gpt2-small-res-jb"
        self.device = "cuda" # NOTE: you should use CUDA_VISIBLE_DEVICES to select the GPU
        self.model = HookedSAETransformer.from_pretrained(self.model_name, device=self.device)
        self.model.eval()

        # Load the SAEs and set helper variables
        self.d_model = self.model.cfg.d_model
        self.n_layers = self.model.cfg.n_layers
        self.load_jbloom_gpt2_saes(sae_release=self.sae_release)

    def load_jbloom_gpt2_saes(self, sae_release: str = "gpt2-small-res-jb"):
        self.saes = []
        self.sae_cfg_dicts = []
        self.sae_sparsities = []
        for layer in range(self.model.cfg.n_layers):
            sae, cfg_dict, sparsity = SAE.from_pretrained(
                release = sae_release,
                sae_id = f"blocks.{layer}.hook_resid_pre",
                device = self.device
            )
            sae.eval()
            self.saes.append(sae)
            self.sae_cfg_dicts.append(cfg_dict)
            self.sae_sparsities.append(sparsity)

    def get_post_ln2_hookpoint_after_hookpoint(self, hookpoint: str) -> torch.Tensor:
        """
        Look here to see where mlp_in happens:
        ------------------------------------------------------------
        pre:
        https://github.com/TransformerLensOrg/TransformerLens/blob/e65fafb4791c66076bc54ec9731920de1e8c676f/transformer_lens/components/transformer_block.py#L191

        also
        ------------------------------------------------------------
        post:
        https://github.com/TransformerLensOrg/TransformerLens/blob/e65fafb4791c66076bc54ec9731920de1e8c676f/transformer_lens/components/layer_norm_pre.py#L52
        """
        assert re.match(r"^blocks.[0-9]+.hook_resid_pre$", hookpoint)
        layer = int(re.match(r"^blocks.([0-9]+)\.hook_resid_pre$", hookpoint).group(1))
        return f"blocks.{layer}.ln2.hook_normalized"

    def get_hookpoint_layer(self, hookpoint: str) -> int:
        assert re.match(r"^blocks.[0-9]+\..*$", hookpoint)
        return int(re.match(r"^blocks.([0-9]+)\..*$", hookpoint).group(1))

    def get_post_ln2_value_after_hookpoint_from_cache(
            self,
            cache: ActivationCache,
            hookpoint: str
        ) -> Float[torch.Tensor, "layer batch seq d_model"]:
        """
        Quick helper meethod to get the post-ln2 value after a hookpoint in the model (vanilla)
        This is used for the express purpose of being able to calculate the SAE's MSE/FVU impact on a
        cerain later location (right after ln2 right before MLP).
        """
        # 1. Get the ln2 that we will have to apply
        layer_num: int = self.get_hookpoint_layer(hookpoint)
        block: TransformerBlock = self.model.blocks[layer_num]
        assert isinstance(block, TransformerBlock), f"block is {type(block)}"
        ln2 = block.ln2
        assert isinstance(ln2, LayerNormPre), f"ln2 is {type(ln2)}"
        # 2. Get the activations
        post_ln2_hp: str = self.get_post_ln2_hookpoint_after_hookpoint(hookpoint)
        assert post_ln2_hp in cache.keys(), f"post_ln2_hp={post_ln2_hp} not in cache.keys()={cache.keys()}"
        post_ln2_act: torch.Tensor = cache[post_ln2_hp]
        return post_ln2_act
    
    def get_post_ln2_value_after_hookpoint_from_activations(
            self,
            activations: Float[torch.Tensor, "batch seq d_model"],
            hookpoint: str
    ) -> Float[torch.Tensor, "layer batch seq d_model"]:
        """
        Like `get_post_ln2_value_after_hookpoint` but this will instead take in activations
        instead of a cache object. The point is that we will propagate forwards the SAE-processed
        activations instead of the vanilla activations.
        """
        # 1. Get the ln2 that we will have to apply
        layer_num: int = self.get_hookpoint_layer(hookpoint)
        block: TransformerBlock = self.model.blocks[layer_num]
        assert isinstance(block, TransformerBlock), f"block is {type(block)}"
        ln2 = block.ln2
        assert isinstance(ln2, LayerNormPre), f"ln2 is {type(ln2)}"
        # 2. run the model for exactly one layer and hook out the spot right before ln2
        debug_number = torch.randn(1, device=self.device) * 999999
        buffer_post_ln2 = torch.ones_like(activations) * debug_number
        def write_hook(normalized_resid, hook):
            buffer_post_ln2[:] = normalized_resid
            return normalized_resid
        self.model.run_with_hooks(
            activations,
            fwd_hooks=[
                (
                    f"blocks.{layer_num}.ln2.hook_normalized",
                    write_hook,
                ),
            ],
            start_at_layer=layer_num,
            stop_at_layer=layer_num+1,
        )
        # Sanity check that we actually wrote out
        assert not torch.any(buffer_post_ln2 == debug_number), f"buffer_post_ln2 is still {debug_number}"
        return buffer_post_ln2
        
        
    def extract_activations(self, tokens: Int[torch.Tensor, "batch seq"], batch_size: int = 30) -> ExtractedActivations:
        """
        Extracts the activations of the model and the SAEs and returns them as tensors.
        """
        # 1. Define output buffers
        debug_numbers = torch.randn(4, device="cpu") * 999999
        sae_ins = torch.ones((len(self.saes), dataset_length, sequence_length, self.d_model), device="cpu") * debug_numbers[0] # fmt: skip
        sae_outs = torch.ones((len(self.saes), dataset_length, sequence_length, self.d_model), device="cpu") * debug_numbers[1] # fmt: skip
        ln2s = torch.ones((len(self.saes), dataset_length, sequence_length, self.d_model), device="cpu") * debug_numbers[2] # fmt: skip
        ln2s_saed = torch.ones((len(self.saes), dataset_length, sequence_length, self.d_model), device="cpu") * debug_numbers[3] # fmt: skip

        while True:
            try:
                pbar = tqdm.trange(0, len(tokens), batch_size, desc=f"Batch Size = {batch_size}")
                with torch.no_grad():
                    for i in pbar:
                        j = min(i + batch_size, len(tokens))
                        # activation store can give us tokens.
                        # 1. Get the activations for our current batch of tokens
                        # TODO(Adriano) if you do this with run_with_hooks (etc...) you could
                        # get better performance for sure.
                        batch_tokens = tokens[i:j]
                        _, cache = self.model.run_with_cache(batch_tokens, prepend_bos=True)

                        # 2. Extract the desired activations from the cache
                        # print(cache.keys())
                        # Use the SAE
                        # print(len(extractor.block2sae))
                        # print(f"hook_name={extractor.block2sae[8].cfg.hook_name}") # Nope
                        sae_in = torch.stack([cache[self.saes[i].cfg.hook_name].detach() for i in range(len(self.saes))])
                        ln2_ = torch.stack([self.get_post_ln2_value_after_hookpoint_from_cache(cache, self.saes[i].cfg.hook_name) for i in range(len(self.saes))]) # fmt: skip
                        del cache
                        gc.collect()
                        torch.cuda.empty_cache()
                        assert sae_in.shape == ln2_.shape, f"sae_in.shape={sae_in.shape}, ln2_.shape={ln2_.shape}"
                        # feature_acts = [extractor.block2sae[i].encode(sae_in[i])
                        # TODO(Adriano) this should be possible to parallelize
                        sae_out = torch.stack([self.saes[i](sae_in[i]).detach()for i in range(len(self.saes))])
                        assert sae_in.shape == sae_out.shape, f"sae_in.shape={sae_in.shape}, sae_out.shape={sae_out.shape}"
                        ln2_saed = torch.stack([self.get_post_ln2_value_after_hookpoint_from_activations(sae_out[i], self.saes[i].cfg.hook_name) for i in range(len(self.saes))]) # fmt: skip

                        # 2. Sanity check the sizes
                        assert sae_in.shape == sae_out.shape
                        assert sae_in.shape == ln2_.shape # NOTE: this will not scale to different layers but eh
                        assert sae_in.shape == ln2_saed.shape
                        assert sae_in.shape[0] == len(self.saes)
                        assert sae_in.shape[1] == j - i
                        assert sae_in.shape[2] == sequence_length
                        assert sae_in.shape[3] == self.d_model, f"sae_in.shape={sae_in.shape}, need [2] = {self.d_model}" # fmt: skip
                        assert sae_in.ndim == 4

                        # 3. Store the activations to the appropriate buffers
                        sae_ins[:, i:j, :] = sae_in.cpu()
                        sae_outs[:, i:j, :] = sae_out.cpu()
                        ln2s[:, i:j, :] = ln2_.cpu()
                        ln2s_saed[:, i:j, :] = ln2_saed.cpu()
                
                # Sanity check that we actually wrote out
                assert not torch.any(sae_ins == debug_numbers[0]), f"sae_ins is still {debug_numbers[0]}"
                assert not torch.any(sae_outs == debug_numbers[1]), f"sae_outs is still {debug_numbers[1]}"
                assert not torch.any(ln2s == debug_numbers[2]), f"ln2s is still {debug_numbers[2]}"
                assert not torch.any(ln2s_saed == debug_numbers[3]), f"ln2s_saed is still {debug_numbers[3]}"
                
                # Done
                return ExtractedActivations(
                    sae_ins=sae_ins,
                    sae_outs=sae_outs,
                    ln2s=ln2s,
                    ln2s_saed=ln2s_saed
                )
            except torch.OutOfMemoryError as e:
                print(type(e), e)
                if batch_size == 1:
                    raise e
                batch_size = min(1, batch_size // 2)
                pbar.update(desc=f"Batch Size = {batch_size}")

print("="*50 + " [Loading Model] " + "="*50) # DEBUG
comparer = ResidAndLn2Comparer()
# print(comparer.model) # DEBUG

print("="*50 + " [Extracting Activations] " + "="*50) # DEBUG
extracted_activations: ExtractedActivations = comparer.extract_activations(
    token_dataset_short,
    batch_size=150
)



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gpt2 into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)




Batch Size = 150: 100%|██████████| 2/2 [00:06<00:00,  3.22s/it]


In [3]:
print("="*50 + " [Flattening + Calculating PCA & Errors] " + "="*50) # DEBUG
extracted_activations_flattened: FlattenedExtractedActivations = extracted_activations.flatten()
print("="*50 + " [Done, now ready to plot (next cell)] " + "="*50) # DEBUG



In [4]:
"""Create the global plots folder for us to be able to plot everything"""
global_plot_folder_path = Path("sae_sans_plots")
if global_plot_folder_path.exists() and len(list(global_plot_folder_path.glob("*"))) == 0:
    shutil.rmtree(global_plot_folder_path)
global_plot_folder_path.mkdir(parents=True, exist_ok=False)

In [None]:
"""
Plot the errors from the SAEs, etc...
"""
# res folders
res_sae_err_norms_output_folder = global_plot_folder_path / "res_sae_err_norms" # fmt: skip
res_sae_variance_explained_output_folder = global_plot_folder_path / "res_sae_variance_explained" # fmt: skip
res_sae_mse_output_folder = global_plot_folder_path / "res_sae_mse" # fmt: skip

# ln2 folders
ln2_sae_err_norms_output_folder = global_plot_folder_path / "ln2_sae_err_norms" # fmt: skip
ln2_sae_variance_explained_output_folder = global_plot_folder_path / "ln2_sae_variance_explained" # fmt: skip
ln2_sae_mse_output_folder = global_plot_folder_path / "ln2_sae_mse" # fmt: skip

for (name, folder, arr) in tqdm.tqdm([
    # res
    ("res_sae_err_norms", res_sae_err_norms_output_folder, extracted_activations_flattened.res_sae_error_norms), # fmt: skip
    ("res_sae_variance_explained", res_sae_variance_explained_output_folder, extracted_activations_flattened.res_sae_var_explained), # fmt: skip
    ("res_sae_mse", res_sae_mse_output_folder, extracted_activations_flattened.res_sae_mse), # fmt: skip
    # ln2
    ("ln2_sae_err_norms", ln2_sae_err_norms_output_folder, extracted_activations_flattened.ln2_sae_error_norms), # fmt: skip
    ("ln2_sae_variance_explained", ln2_sae_variance_explained_output_folder, extracted_activations_flattened.ln2_sae_var_explained), # fmt: skip
    ("ln2_sae_mse", ln2_sae_mse_output_folder, extracted_activations_flattened.ln2_sae_mse), # fmt: skip
    
]):
    if folder.exists() and len(list(folder.glob("*"))) == 0:
        shutil.rmtree(folder)
    folder.mkdir(parents=True, exist_ok=False)
    for layer in tqdm.trange(len(comparer.saes)):
        filepath = folder / f"layer_{layer}.png"
        plt.hist(arr.flatten().cpu().log10().numpy(), bins=100)
        plt.title(f"log10({name}) (layer {layer})")
        plt.savefig(filepath)
        plt.close()
        filepath_meta = folder / f"layer_{layer}.json"
        with open(filepath_meta, "w") as f:
            json.dump({
                "layer": layer,
                "name": name,
                "shape": str(arr.shape),
                "min": arr.min().item(),
                "max": arr.max().item(),
                "mean": arr.mean().item(),
                "std": arr.std().item(),
                "median": arr.median().item(),
                "q1": arr.quantile(0.25).item(),
                "q3": arr.quantile(0.75).item(),
            }, f, indent=4)
        


100%|██████████| 12/12 [00:02<00:00,  4.43it/s]
100%|██████████| 12/12 [00:02<00:00,  4.02it/s]
100%|██████████| 12/12 [00:02<00:00,  4.35it/s]
100%|██████████| 12/12 [00:03<00:00,  3.90it/s]
100%|██████████| 12/12 [00:02<00:00,  4.24it/s]
100%|██████████| 12/12 [00:03<00:00,  3.78it/s]
100%|██████████| 6/6 [00:17<00:00,  2.93s/it]


In [6]:
"""
Calculate the similarity via cosine between each pair of input and output PCs (after
applying an SAE). If the SAE is good, we should expect an exact match (i.e. a diagonal-
like stripe).
"""
plots_output_folder = global_plot_folder_path / "plots_cosine_sim_pca_post_sae"
if plots_output_folder.exists() and len(list(plots_output_folder.glob("*"))) == 0:
    shutil.rmtree(plots_output_folder)
plots_output_folder.mkdir(parents=True, exist_ok=False)

# You can also analyze multiple layers
print("\nAnalyzing multiple layers...")
for layer in tqdm.trange(len(comparer.saes)):
    # Acquire all the principle components we want to compare
    eigenvectors_sae_ins = extracted_activations_flattened.sae_ins_pca_eigenvectors[layer]
    eigenvectors_sae_outs = extracted_activations_flattened.sae_outs_pca_eigenvectors[layer]
    eigenvectors_ln2s = extracted_activations_flattened.ln2s_pca_eigenvectors[layer]
    eigenvectors_ln2s_saed = extracted_activations_flattened.ln2s_saed_pca_eigenvectors[layer]
    # Save them to the plot folder
    cosine_sim = plot_cosine_kernel(eigenvectors_sae_ins, eigenvectors_sae_outs, force_positive=True, save_to_file=plots_output_folder / f"layer_{layer}.png")
    cosine_sim = plot_cosine_kernel(eigenvectors_sae_ins, eigenvectors_ln2s, force_positive=True, save_to_file=plots_output_folder / f"layer_{layer}_sae_ins_ln2.png")
    
    # TODO(Adriano) this is some tidbit code written by Claude, not sure if I honestly want it :P
    # Find the top aligned eigenvectors
    # max_values, max_indices = torch.max(cosine_sim.abs(), dim=1)
    # top_k = 5
    # top_indices = torch.argsort(max_values, descending=True)[:top_k]
    
    # print(f"Top {top_k} aligned eigenvector pairs (SAE in → SAE out):")
    # for i, idx in enumerate(top_indices):
    #     out_idx = max_indices[idx]
    #     sim_value = cosine_sim[idx, out_idx].item()
    #     print(f"  {i+1}. In eigenvector {idx} aligns with out eigenvector {out_idx} (sim={sim_value:.4f})")
    pass




Analyzing multiple layers...


100%|██████████| 12/12 [00:17<00:00,  1.50s/it]


In [7]:
"""
Plot for each pair of PCs the histogram of their values on that projection.
"""
res_sae_in_pca_histograms_folder = global_plot_folder_path / "res_sae_in_pca_histograms"
res_sae_out_pca_histograms_folder = global_plot_folder_path / "res_sae_out_pca_histograms"
ln2_pca_histograms_folder = global_plot_folder_path / "ln2_pca_histograms"
ln2_sae_effect_pca_histograms_folder = global_plot_folder_path / "ln2_sae_effect_pca_histograms"
n_pcs = 2 # ehh

for output_folder, (activations, mean, eigenvectors) in tqdm.tqdm([
    (
        res_sae_in_pca_histograms_folder,
        (
            extracted_activations_flattened.sae_ins,
            extracted_activations_flattened.sae_ins_means,
            extracted_activations_flattened.sae_ins_pca_eigenvectors
        )
    ),
    (
        res_sae_out_pca_histograms_folder, 
        (
            extracted_activations_flattened.sae_outs,
            extracted_activations_flattened.sae_outs_means,
            extracted_activations_flattened.sae_outs_pca_eigenvectors
        )
    ),
    (
        ln2_pca_histograms_folder,
        (
            extracted_activations_flattened.ln2s,
            extracted_activations_flattened.ln2s_means,
            extracted_activations_flattened.ln2s_pca_eigenvectors
        )
    ),
    (
        ln2_sae_effect_pca_histograms_folder,
        (
            extracted_activations_flattened.ln2s_saed,
            extracted_activations_flattened.ln2s_saed_means,
            extracted_activations_flattened.ln2s_saed_pca_eigenvectors
        )
    )
]):
    for layer in tqdm.trange(len(comparer.saes)):
        output_folder_layer = output_folder / f"layer_{layer}"
        output_folder_layer.mkdir(parents=True, exist_ok=False)
        # NOTE: use default plotting kwargs
        plot_all_nc2_top_pcs(n_pcs, activations[layer], mean[layer], eigenvectors[layer], output_folder_layer)

  0%|          | 0/4 [00:00<?, ?it/s]
[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.36it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.41it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.46it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.47it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.46it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:01<00:00,  1.44s/it]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]

[A
Plotting PCA histograms: 100%|██████████| 1/1 [00:00<00:00,  1.52it/s]
100%|██████████| 12/12 [00:09<00:00,  1.33it/s]
 25

In [8]:
"""
Plot the error projected onto each pair of PCs as above. Specifically,
we accumulate in each bin (same bins as above) the error in the SAE-modified
representation in that bin. Then we store 2 versions:
1. Divide by the total number of samples (so basically this is proportional to
    expected error)
2. Divide by the number of samples in that bin (so basically here we are looking to
    see if there is an additional contribution of the SAE to that bin above just
    "there were more samples here")

And for each of these two we do it for each of:
1. Error norm (on that datapoint)
2. Variance explained (on that datapoint)
3. MSE (on that datapoint)
"""
res_sae_in_pca_histograms_folder = global_plot_folder_path / "res_sae_in_err_pca_histograms"
res_sae_out_pca_histograms_folder = global_plot_folder_path / "res_sae_out_err_pca_histograms"
ln2_pca_histograms_folder = global_plot_folder_path / "ln2_err_pca_histograms"
ln2_sae_effect_pca_histograms_folder = global_plot_folder_path / "ln2_sae_effect_err_pca_histograms"
n_pcs = 2 # ehh, copy from above lmao
for output_folder, (activations, mean, eigenvectors, err_norm, err_var_explained, err_mse) in tqdm.tqdm([
    (
        res_sae_in_pca_histograms_folder,
        (
            # Projection (binning) data
            extracted_activations_flattened.sae_ins,
            extracted_activations_flattened.sae_ins_means,
            extracted_activations_flattened.sae_ins_pca_eigenvectors,
            # Errors (coloring data)
            extracted_activations_flattened.res_sae_error_norms,
            extracted_activations_flattened.res_sae_var_explained,
            extracted_activations_flattened.res_sae_mse
        )
    ),
    (
        res_sae_out_pca_histograms_folder, 
        (
            # Projection (binning) data
            extracted_activations_flattened.sae_outs,
            extracted_activations_flattened.sae_outs_means,
            extracted_activations_flattened.sae_outs_pca_eigenvectors,
            # Errors (coloring data)
            extracted_activations_flattened.res_sae_error_norms,
            extracted_activations_flattened.res_sae_var_explained,
            extracted_activations_flattened.res_sae_mse
        )
    ),
    (
        ln2_pca_histograms_folder,
        (
            # Projection (binning) data
            extracted_activations_flattened.ln2s,
            extracted_activations_flattened.ln2s_means,
            extracted_activations_flattened.ln2s_pca_eigenvectors,
            # Errors (coloring data)
            extracted_activations_flattened.ln2_sae_error_norms,
            extracted_activations_flattened.ln2_sae_var_explained,
            extracted_activations_flattened.ln2_sae_mse
        )
    ),
    (
        ln2_sae_effect_pca_histograms_folder,
        (
            # Projection (binning) data
            extracted_activations_flattened.ln2s_saed,
            extracted_activations_flattened.ln2s_saed_means,
            extracted_activations_flattened.ln2s_saed_pca_eigenvectors,
            # Errors (coloring data)
            extracted_activations_flattened.ln2_sae_error_norms,
            extracted_activations_flattened.ln2_sae_var_explained,
            extracted_activations_flattened.ln2_sae_mse
        )
    )
]):
    for layer in tqdm.trange(len(comparer.saes)):
        output_folder_layer = output_folder / f"layer_{layer}"
        output_folder_layer.mkdir(parents=True, exist_ok=False)
        err_type_names = ["error_norm", "variance_explained", "mse"] # fmt: skip
        err_arrays = [err_norm[layer], err_var_explained[layer], err_mse[layer]] # fmt: skip
        for err_type_name, err_array in zip(err_type_names, err_arrays):
            for normalize_by_n_in_bin, normalize_by_n_in_bin_name in zip([False, True], ["unnormalized", "normalized"]):
                sub_output_folder = output_folder_layer / f"{normalize_by_n_in_bin_name}_{err_type_name}"
                plot_all_nc2_top_pcs_errs(
                    n_pcs,
                    # Projection parameters
                    activations[layer],
                    mean[layer],
                    eigenvectors[layer],
                    # Error parameters
                    err_array, # already layered
                    normalize_by_n_in_bin, # already layered
                    # Storage parameters
                    sub_output_folder,
                    # Settings
                    store_accumulation_values=True,
                    normalize_accumulation_values_by_n_in_bin=normalize_by_n_in_bin, # normalize by in bin but not total
                    normalize_accumulation_values_by_n_total=not normalize_by_n_in_bin # expected value if not normalize by in bin; fmt: skip
                )

  0%|          | 0/12 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]


NotImplementedError: Not implemented

In [12]:
"""
Generally the output will look like this:

XXX put the filetree and later improve it please!
"""

SyntaxError: incomplete input (969763094.py, line 1)