In [1]:
from dataclasses import dataclass, field
from typing import List
from sae_lens.training.config import LanguageModelSAERunnerConfig

@dataclass
class SAETrainConfig(LanguageModelSAERunnerConfig):
    dataset_path: str = 'imagenet_data'
    num_workers: int = 0
    num_epochs: int = 3

    expansion_factor: int = 24
    context_size: int = 257
    d_in: int = 1024
    model_name: str = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
    hook_point: str = "blocks.{layer}.hook_mlp_out"
    hook_point_layer: List[int] = field(default_factory=lambda: [22])
    dead_feature_window: int = 2500
    use_ghost_grads: bool = True
    feature_sampling_window: int = 250
    from_pretrained_path: str = None

    b_dec_init_method: str = "geometric_median"
    normalize_sae_decoder: bool = True

    lr: float = 0.0005
    l1_coefficient: int = 0.006
    lr_scheduler_name: str = "cosineannealing"
    train_batch_size_tokens: int = 8
    lr_warm_up_steps: int = 4000

    n_batches_in_buffer: int = 8
    store_batch_size: int = 4

    log_to_wandb: bool = True
    wandb_project: str = "openclip_sae_training"
    wandb_entity: str = "willfulbytes"
    wandb_log_frequency: int = 25
    eval_every_n_wandb_logs: int = 10
    run_name: str = None

    device: str = "cuda"
    seed: int = 42
    n_checkpoints: int = 10
    checkpoint_path: str = "checkpoints24"
    dtype: str = "torch.float32"

In [2]:
import torch
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f0d4a26e0e0>

In [3]:
from torch.utils.data import Dataset
from vit_prisma.models.base_vit import HookedViT
from open_clip import tokenize
import datasets
from typing import Any, Iterator, cast
from torch.utils.data import DataLoader

class HFDataset(Dataset):
    def __init__(self, data_location, transforms, image_col, text_col):
        self.dataset = datasets.load_dataset(data_location, split="train")
        self.image_col = image_col
        self.text_col = text_col
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Remove the extra dimension by squeezing the tensor
        images = self.transforms(self.dataset[idx][self.image_col], return_tensors="pt")["pixel_values"].squeeze(0)
        texts = tokenize([self.dataset[idx][self.text_col]])[0]
        return images, texts

# Update the collate functions accordingly
def collate_fn(data):
    imgs, _ = zip(*data)
    return torch.stack(imgs, dim=0)

def collate_fn_eval(data):
    imgs, texts = zip(*data)
    return torch.stack(imgs, dim=0), torch.stack(texts, dim=0)


class OpenCLIPActivationsStore:
    """
    Class for streaming tokens and generating and storing activations
    while training SAEs.
    """

    def __init__(
        self,
        config: SAETrainConfig,
        model: HookedViT,
        dataset: torch.utils.data.Dataset,
        eval_dataset: torch.utils.data.Dataset = None,
        num_workers: int = 0,
    ):
        self.config = config
        assert (
            not self.config.normalize_activations
        ), "Normalize activations is currently not implemented for vision, sorry!"
        self.normalize_activations = self.config.normalize_activations
        self.model = model
        self.dataset = dataset
        self.eval_dataset = eval_dataset

        self.image_dataloader = torch.utils.data.DataLoader(
            self.dataset,
            shuffle=True,
            num_workers=num_workers,
            batch_size=self.config.store_batch_size,
            collate_fn=collate_fn,
            drop_last=True,
        )
        self.image_dataloader_eval = torch.utils.data.DataLoader(
            self.eval_dataset,
            shuffle=True,
            num_workers=num_workers,
            batch_size=self.config.store_batch_size,
            collate_fn=collate_fn_eval,
            drop_last=True,
        )

        self.image_dataloader_iter = self.get_batch_tokens_internal()
        self.image_dataloader_eval_iter = self.get_val_batch_tokens_internal()

        self.storage_buffer = self.get_buffer(self.config.n_batches_in_buffer // 2)
        self.dataloader = self.get_data_loader()


    def get_batch_tokens_internal(self):
        """
        Streams a batch of tokens from a dataset.
        """
        device = self.config.device
        while True:
            for data in self.image_dataloader:
                data.requires_grad_(False)
                yield data.to(device)  # 5

    def get_batch_tokens(self):
        return next(self.image_dataloader_iter)

    # returns the ground truth class as well.
    def get_val_batch_tokens_internal(self):
        """
        Streams a batch of tokens from a dataset.
        """
        device = self.config.device
        while True:
            for image_data, labels in self.image_dataloader_eval:
                image_data.requires_grad_(False)
                labels.requires_grad_(False)
                yield image_data.to(device), labels.to(device)

    def get_val_batch_tokens(self):
        return next(self.image_dataloader_eval_iter)

    def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
        """
        Returns activations of shape (batches, context, num_layers, d_in)
        """
        layers = (
            self.config.hook_point_layer
            if isinstance(self.config.hook_point_layer, list)
            else [self.config.hook_point_layer]
        )
        act_names = [self.config.hook_point.format(layer=layer) for layer in layers]
        hook_point_max_layer = max(layers)

        if self.config.hook_point_head_index is not None:
            layerwise_activations = self.model.run_with_cache(
                batch_tokens,
                names_filter=act_names,
                stop_at_layer=hook_point_max_layer + 1,
            )[1]
            activations_list = [
                layerwise_activations[act_name][:, :, self.config.hook_point_head_index]
                for act_name in act_names
            ]
        else:
            layerwise_activations = self.model.run_with_cache(  ####
                batch_tokens,
                names_filter=act_names,
                stop_at_layer=hook_point_max_layer + 1,
            )[1]
            activations_list = [
                layerwise_activations[act_name] for act_name in act_names
            ]

        # Stack along a new dimension to keep separate layers distinct
        stacked_activations = torch.stack(activations_list, dim=2)

        return stacked_activations

    def get_buffer(self, n_batches_in_buffer: int):
        context_size = self.config.context_size
        batch_size = self.config.store_batch_size
        d_in = self.config.d_in
        total_size = batch_size * n_batches_in_buffer
        num_layers = (
            len(self.config.hook_point_layer)
            if isinstance(self.config.hook_point_layer, list)
            else 1
        )  # Number of hook points or layers

        refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
        # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
        new_buffer = torch.zeros(
            (total_size, context_size, num_layers, d_in),
            dtype=self.config.dtype,
            device=self.config.device,
        )

        for refill_batch_idx_start in refill_iterator:
            refill_batch_tokens = self.get_batch_tokens()  ######
            refill_activations = self.get_activations(refill_batch_tokens)

            new_buffer[
                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
            ] = refill_activations

            # pbar.update(1)

        new_buffer = new_buffer.reshape(-1, num_layers, d_in)
        new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

        return new_buffer

    def get_data_loader(
        self,
    ) -> Iterator[Any]:
        """
        Return a torch.utils.dataloader which you can get batches from.

        Should automatically refill the buffer when it gets to n % full.
        (better mixing if you refill and shuffle regularly).

        """

        batch_size = self.config.train_batch_size_tokens

        # 1. # create new buffer by mixing stored and new buffer
        mixing_buffer = torch.cat(
            [self.get_buffer(self.config.n_batches_in_buffer // 2), self.storage_buffer], ####
            dim=0,
        )

        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

        # 2.  put 50 % in storage
        self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

        # 3. put other 50 % in a dataloader
        dataloader = iter(
            DataLoader(
                # TODO: seems like a typing bug?
                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
                batch_size=batch_size,
                shuffle=True,
            )
        )

        return dataloader

    def next_batch(self):
        """
        Get the next batch from the current DataLoader.
        If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
        """
        try:
            # Try to get the next batch
            return next(self.dataloader)
        except StopIteration:
            # If the DataLoader is exhausted, create a new one
            self.dataloader = self.get_data_loader() #### 97
            return next(self.dataloader)


In [4]:
from sae_lens.training.sae_group import SparseAutoencoderDictionary
from transformers import CLIPProcessor

config = SAETrainConfig()
processor = CLIPProcessor.from_pretrained(config.model_name)
dataset = HFDataset("awilliamson/fashion-train", processor.image_processor, "image", "text") # load_dataset("awilliamson/fashion-train", split="train")
eval_dataset = HFDataset("awilliamson/fashion-eval", processor.image_processor, "image", "text") # load_dataset("awilliamson/fashion-validation", split="train")
# cfg.training_tokens = int(1_300_000*setup_args['num_epochs']) * cfg.context_size
config.training_tokens = len(dataset) * config.num_epochs
sae_group = SparseAutoencoderDictionary(config)
model = HookedViT.from_pretrained(config.model_name, is_timm=False, is_clip=True)
model.to(config.device)

activation_store = OpenCLIPActivationsStore(
    config = config,
    model = model,
    dataset = dataset,
    eval_dataset = eval_dataset,
    num_workers = 0,
)

for i, (name, sae) in enumerate(sae_group):
    hyp = sae.cfg
    print(
        f"{i}: Name: {name} Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}"
    )




Run name: 24576-L1-0.006-LR-0.0005-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.065792
Lower bound: n_contexts_per_buffer (millions): 0.000256
Total training steps: 250000
Total wandb updates: 10000
n_tokens_per_feature_sampling_window (millions): 0.514
n_tokens_per_dead_feature_window (millions): 5.14
We will reset the sparsity calculation 1000 times.
Number tokens in sparsity calculation window: 2.00e+03
Using Ghost Grads.




Run name: 24576-L1-0.006-LR-0.0005-Tokens-4.024e+04
n_tokens_per_buffer (millions): 0.065792
Lower bound: n_contexts_per_buffer (millions): 0.000256
Total training steps: 5029
Total wandb updates: 201
n_tokens_per_feature_sampling_window (millions): 0.514
n_tokens_per_dead_feature_window (millions): 5.14
We will reset the sparsity calculation 20 times.
Number tokens in sparsity calculation window: 2.00e+03
Using Ghost Grads.
Run name: 24576-L1-0.006-LR-0.0005-Tokens-4.024e+04
n_tokens_per_buffer (millions): 0.065792
Lower bound: n_contexts_per_buffer (millions): 0.000256
Total training steps: 5029
Total wandb updates: 201
n_tokens_per_feature_sampling_window (millions): 0.514
n_tokens_per_dead_feature_window (millions): 5.14
We will reset the sparsity calculation 20 times.
Number tokens in sparsity calculation window: 2.00e+03
Using Ghost Grads.
{'n_layers': 24, 'd_model': 1024, 'd_head': 64, 'model_name': '', 'n_heads': 16, 'd_mlp': 4096, 'activation_name': 'gelu', 'eps': 1e-05, 'orig

In [5]:
import wandb
torch.set_grad_enabled(True)
from sae.train import train_sae_group_on_vision_model


if config.log_to_wandb:
    wandb.init(project=config.wandb_project, config=cast(Any, config), name=config.run_name)

train_sae_group_on_vision_model(
    model,
    sae_group,
    activation_store,
    train_contexts=None, #TODO load checkpoints correctly to match saelens v2.1.3 lm_runner!
    training_run_state=None,  #TODO load checkpoints correctly to match saelens v2.1.3 lm_runner!
    n_checkpoints=config.n_checkpoints,
    batch_size=config.train_batch_size_tokens,
    feature_sampling_window=config.feature_sampling_window,
    use_wandb=config.log_to_wandb,
    wandb_log_frequency=config.wandb_log_frequency,
    eval_every_n_wandb_logs=config.eval_every_n_wandb_logs,
    autocast=config.autocast,
)
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwillfulbytes[0m. Use [1m`wandb login --relogin`[0m to force relogin


Objective value: 34442.7500:   5%|▌         | 5/100 [00:00<00:00, 180.81it/s]
  out = torch.tensor(origin, dtype=self.dtype, device=self.device)
  scaler = torch.cuda.amp.GradScaler(enabled=autocast)
  lambda data: self._console_raw_callback("stderr", data),


VBox(children=(Label(value='2115.303 MB of 2115.313 MB uploaded (0.022 MB deduped)\r'), FloatProgress(value=0.…

0,1
details/current_l1_coefficient_layer22,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/current_learning_rate_layer22,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇█████▆▅▄▃▂▂
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss_layer22,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██
losses/l1_loss_layer22,█████▇▆▆▆▄▃▃▂▃▃▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂
losses/mse_loss_layer22,█▄▃▃▃▂▂▂▂▂▁▂▂▃▅▂▃▂▂▁▂▁▁▂▂▂▃▂▁▂▂▂▃▁▂▁▁▁▁▁
losses/overall_loss_layer22,█▄▃▃▃▂▂▂▂▂▁▂▂▂▄▂▂▂▂▁▂▁▁▂▂▂▃▂▁▁▂▂▂▁▂▁▁▁▁▁
metrics/explained_variance_layer22,▁▃▅▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████
metrics/explained_variance_std_layer22,█▆▄▄▃▃▂▂▃▂▁▂▂▂▃▁▁▂▃▁▂▂▁▂▁▂▅▃▂▂▂▃▃▁▁▂▁▂▁▁
metrics/l0_layer22,█████▇▇▇▇▆▄▄▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
details/current_l1_coefficient_layer22,0.006
details/current_learning_rate_layer22,5e-05
details/n_training_tokens,40200.0
losses/ghost_grad_loss_layer22,0.00621
losses/l1_loss_layer22,129.13976
losses/mse_loss_layer22,6.37519
losses/overall_loss_layer22,7.15624
metrics/explained_variance_layer22,0.92948
metrics/explained_variance_std_layer22,0.01954
metrics/l0_layer22,782.25
