# Benchmarks

Benchmarks to help with design/architecture decisions of the lib.

## Setup

In [1]:
# Autoreload
# %load_ext autoreload
# %autoreload 2

import gzip
import os
import shutil
import tempfile
import random
import numpy as np

import pandas as pd
import torch
from torch import Tensor
import time
from datasets import load_dataset
from transformer_lens import HookedTransformer

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.dataset.dataloader import (
    collate_neel_c4_tokenized,
    create_dataloader,
)
from sparse_autoencoder.activations.ListActivationStore import ListActivationStore
from sparse_autoencoder.train.train import pipeline

## Activation Tensor Sizes

It's useful to know both the size and how much they can be compressed.

In [None]:
# Create a batch of text data
dataset = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", streaming=True)
first_batch = []
for idx, example in enumerate(dataset):
    if not idx <= 24:
        break
    first_batch.append(example["tokens"])
first_batch = torch.tensor(first_batch)
f"Number of activations to store in this benchmark test: {first_batch.numel()}"

In [None]:
# Create the activations
src_model = HookedTransformer.from_pretrained("NeelNanda/GELU_1L512W_C4_Code")
logits, cache = src_model.run_with_cache(first_batch)
activations = cache["blocks.0.mlp.hook_post"].half()
number_activations = activations.numel()
size_bytes_activations = number_activations * 2  # Assume float 16
size_mb_activations = f"{size_bytes_activations / (10**6):.2f} MB"
f"With {activations.numel()} features at half precision, the features take up {size_mb_activations} of memory"

Next we try compressing on the disk (and find the impact is small so probably not worth it):

In [None]:
# Save to temp dir
temp_dir = tempfile.gettempdir()
temp_file = temp_dir + "/temp.pt"
temp_file_gz = temp_file + ".gz"
torch.save(activations, temp_file)

# Zip it
with open(temp_file, "rb") as f_in:
    with gzip.open(temp_file_gz, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)

# Get the file size back
fs_bytes = os.path.getsize(temp_file_gz)
f"Compressed file size is {fs_bytes / (10**6):.2f} MB"

Now let's calculate assuming 8 billion activations:

In [None]:
assumed_n_activation_batches = 8 * (10**9)
assumed_n_activations_per_batch = 2048
uncompressed_size_per_activation = 2  # float16
estimated_size = (
    assumed_n_activation_batches
    * assumed_n_activations_per_batch
    * uncompressed_size_per_activation
)
f"With {assumed_n_activation_batches/10**9}B activations with {assumed_n_activations_per_batch} features, \
the estimated size is {estimated_size / (10**12):.2f} TB"

In [None]:
# Calculate the amount of activations you can store with different sizes
sizes_gb = [10, 50, 100, 300, 500, 1000]
activations_per_size = [
    i * (10**9) / uncompressed_size_per_activation / assumed_n_activations_per_batch
    for i in sizes_gb
]

table = pd.DataFrame({"Size (GB)": sizes_gb, "Activations": activations_per_size})
table["Activations"] = table["Activations"].apply(
    lambda x: "{:,.0f}".format(x / 10**6) + "M"
)
table

VastAI systems often have quite a lot of HD space (e.g. 300GB) but available ram is often smaller
(e.g. 50GB and we need a reasonable amount left over for moving tensors around etc). This means that
we can store c. 5-10M activations on a typical instance in CPU RAM (sometimes 25M+), or 50-100M on
disk. Both seem like plenty!

To note that replenishing a buffer of cached activations when half used in training seems like a lot
of pain, considering that the improvement is likely marginal. Particularly if we also randomly sort
the prompts for the forward pass of the source model, we'll have a chance of two tokens coming from
the same/nearby prompts as very small.

The conclusion is therefore that we do a need some sort of buffer, as we can't store 40TB on disk
easily, and this buffer can be disk or ram. It needs to store asynchronously (so it doesn't block
the forward pass), and it needs to be able to handle multiple simultaneous writes from e.g.
distributed GPUs. The best approaches here are probably (a) pre-allocating a cpu ram space with
torch.empty, or (b) writing asynchronously to disk.

## Dataset Fetching

## Getting Activations (Forward Pass)

## Activations Buffer

### Storage methods

Use the ListActivationStore:

In [3]:
# Benchmark storing activations
def benchmark_list_activation_store(
    n_items: int = 1_000_000,
    n_features: int = 2056,
    batch_size: int = 100,
    model_device: torch.device = torch.device("cpu"),
    storage_device: torch.device = torch.device("cpu"),
    multiprocessing_enabled: bool = False,
):
    # Create the data
    n_batches = int(n_items / batch_size)
    data = [
        torch.rand((batch_size, n_features), device=model_device)
        for _ in range(n_batches)
    ]

    # Create the data store
    dataset = ListActivationStore(
        device=storage_device,
        multiprocessing_enabled=multiprocessing_enabled,
    )

    start = time.time()
    for batch in data:
        dataset.extend(batch)

    dataset.finalize()
    print(len(dataset))

    end = time.time()
    duration = end - start
    print(
        f"Storing {n_items} activations with {n_features} features took {duration:.2f} seconds."
    )

    equivalent_time_10b_activations = duration * (10**10) / n_items / (60 * 60)
    print(
        f"Equivalent time for 10B activations: {equivalent_time_10b_activations:.2f} hours."
    )


print("Benchmark without multiprocessing:")
benchmark_list_activation_store()

print("\nBenchmark with multiprocessing:")
benchmark_list_activation_store(multiprocessing_enabled=True)

print("\nBenchmark with GPU & multiprocessing:")
benchmark_list_activation_store(
    model_device=torch.device("mps"),
    storage_device=torch.device("mps"),
    multiprocessing_enabled=True,
)

Benchmark without multiprocessing:
Storing 1000000 activations with 2056 features took 0.11 seconds.
Equivalent time for 10B activations: 0.30 hours.

Benchmark with multiprocessing:


Exception ignored in: <function ListActivationStore.__del__ at 0x2a8e40720>
Traceback (most recent call last):
  File "/Users/alan/Documents/Repos/sparse_autoencoder/sparse_autoencoder/activations/ListActivationStore.py", line 236, in __del__
    self.finalise()
  File "/Users/alan/Documents/Repos/sparse_autoencoder/sparse_autoencoder/activations/ListActivationStore.py", line 232, in finalise
    self._thread_pool.shutdown(wait=True)
  File "/opt/homebrew/Cellar/python@3.11/3.11.6/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/thread.py", line 235, in shutdown
    t.join()
  File "/opt/homebrew/Cellar/python@3.11/3.11.6/Frameworks/Python.framework/Versions/3.11/lib/python3.11/threading.py", line 1116, in join
    raise RuntimeError("cannot join current thread")
RuntimeError: cannot join current thread
Exception ignored in: <function ListActivationStore.__del__ at 0x2a8e40720>
Traceback (most recent call last):
  File "/Users/alan/Documents/Repos/sparse_auto

Storing 1000000 activations with 2056 features took 0.03 seconds.
Equivalent time for 10B activations: 0.10 hours.

Benchmark with GPU & multiprocessing:
Storing 1000000 activations with 2056 features took 0.11 seconds.
Equivalent time for 10B activations: 0.32 hours.


Here we compare pre-creating a tensor in memory and then 

In [None]:
def method_1(
    batches_activations: list[Tensor],
    n_items: int,
    n_features: int,
    batch_size: int,
    storage_device: torch.device = torch.device("cpu"),
):
    """Method 1: Using torch.empty to pre-allocate"""
    with torch.no_grad():
        # Setup storage
        tensor_storage = torch.empty((n_items, n_features), device=storage_device)

        # Append
        start_time = time.time()
        for idx, batch in enumerate(batches_activations):
            start = idx * batch_size
            end = start + batch_size
            tensor_storage[start:end] = batch.to(storage_device)
        end_time = time.time()
        del tensor_storage

        print(f"Method 1 (torch.empty) time: {end_time - start_time:.5f} seconds")


def method_2(
    batches_activations: list[Tensor],
    storage_device: torch.device = torch.device("cpu"),
):
    """Method 2: Appending to list"""
    # Setup storage
    result_list = []

    # Append
    start_time = time.time()
    for idx, batch in enumerate(batches_activations):
        for item in batch:
            result_list.append(item.to(storage_device))
    end_time = time.time()

    print(f"Method 2 (appending to list) time: {end_time - start_time:.5f} seconds")


def run_test():
    # Config
    features = 512
    batch_size = 10
    batches = int(1_000_000 / batch_size)
    items = int(batches * batch_size)

    # Create the activations data
    model_device = torch.device("cpu")
    activations = [
        torch.randn((batch_size, features), device=model_device) for _ in range(batches)
    ]

    # Run the methods
    print("CPU -> CPU:")
    method_1(activations, items, features, batch_size)
    method_2(activations)

    # Move to another device and test again
    activations = [i.to(torch.device("mps")) for i in activations]
    print("MPS -> CPU:")
    method_1(activations, items, features, batch_size)
    method_2(activations)


run_test()

Interestingly if they're already on the CPU, both methods are quite fast. But if they are both on
the GPU, moving them across one at a time vs pre-allocating and then filling is very slow.

This suggests pre-allocating is better from a write perspective.

In [None]:
def method1(
    n_items: int,
    n_features: int,
    batch_size: int,
    random_reads: int,
    should_delete: bool = False,
) -> None:
    with torch.no_grad():
        data = torch.randn((n_items, n_features))
        has_data = torch.randn((n_items)) > 0.5

        start_time = time.time()
        batches = []
        for i in range(random_reads):
            has_data_indices = torch.where(has_data)[0]
            random_indices = torch.randint(0, len(has_data_indices), (batch_size,))
            items = data[has_data_indices[random_indices]]
            batches.append(items)

            if should_delete:
                has_data[has_data_indices[random_indices]] = False

        end_time = time.time()
        print(
            f"Method 1 (torch.empty) random read time: {end_time - start_time:.5f} seconds"
        )


def method2(
    n_items: int,
    n_features: int,
    batch_size: int,
    random_reads: int,
    should_delete: bool = False,
) -> None:
    data = [torch.randn((n_features,)) for _ in range(n_items)]

    start_time = time.time()
    batches = []
    for i in range(random_reads):
        if should_delete:
            data_np = np.array(data, dtype=object)
            sampled_indices = np.random.choice(
                data_np.shape[0], size=batch_size, replace=False
            )
            items = data_np[sampled_indices]
            data = np.delete(data, sampled_indices).tolist()

        else:
            items = random.sample(data, batch_size)
            batches.append(items)

    end_time = time.time()
    print(f"Method 2 (list) random read time: {end_time - start_time:.5f} seconds")


def run():
    n_items = 1_000_000
    n_features = 512
    batch_size = 10
    random_reads = 1000

    print("Without deleting:")
    method1(n_items, n_features, batch_size, random_reads)
    method2(n_items, n_features, batch_size, random_reads)

    print("With deleting:")
    method1(n_items, n_features, batch_size, random_reads, should_delete=True)
    # method2(n_items, n_features, batch_size, random_reads, should_delete=True)


run()

Random reads are substantially faster with list (torch.empty is slow), but we're still talking
8h/TB (roughly) so this is still too slow.

It seems like we need a better setup for reading e.g. `TensorDataset` or just a DataSet reading the
list. Let's try that:

In [None]:
def tensor_dataset_test(
    n_items: int,
    n_features: int,
    batch_size: int,
    n_reads: int,
):
    with torch.no_grad():
        data = torch.randn((n_items, n_features))
        tensor_dataset = torch.utils.data.TensorDataset(data)
        dataloader = torch.utils.data.DataLoader(
            tensor_dataset, batch_size=batch_size, shuffle=True, num_workers=2
        )

        start_time = time.time()
        for i in range(n_reads):
            batch = next(iter(dataloader))
        end_time = time.time()
        print(
            f"With TensorDataset, reading {n_reads} took {end_time - start_time:.5f} seconds"
        )


tensor_dataset_test(1_000_000, 512, 10, 1_000)

This is v. slow (dataloader shuffle has a big overhead compared to our optimum approach. Therefore 

## Learning