1 change: 1 addition & 0 deletions .vscode/cspell.json
@@ -69,6 +69,7 @@
     "runcap",
     "sharded",
     "snapshottest",
+    "solu",
     "tqdm",
     "transformer_lens",
     "typecheck",
188 changes: 26 additions & 162 deletions demo.ipynb
@@ -16,9 +16,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 5,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "# Autoreload\n",
     "%load_ext autoreload\n",
@@ -27,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -36,13 +45,12 @@
     "from transformer_lens import HookedTransformer\n",
     "from transformer_lens.utils import get_device\n",
     "from transformers import PreTrainedTokenizerBase\n",
-    "import torch\n",
-    "import wandb"
+    "import torch"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,27 +66,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 9,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Loaded pretrained model solu-1l into HookedTransformer\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "2048"
-      ]
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "src_model = HookedTransformer.from_pretrained(\"solu-1l\", dtype=\"float32\")\n",
     "src_d_mlp: int = src_model.cfg.d_mlp # type: ignore\n",
@@ -94,38 +84,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
-      "To disable this warning, you can either:\n",
-      "\t- Avoid using `tokenizers` before the fork if possible\n",
-      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a1ce590449484e1788109c4f13a2e8bf",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "tokenizer: PreTrainedTokenizerBase = src_model.tokenizer # type: ignore\n",
-    "source_data = PileUncopyrightedDataset(tokenizer=tokenizer)\n",
-    "src_dataloader = source_data.get_dataloader(batch_size=8)"
+    "source_data = PileUncopyrightedDataset(tokenizer=tokenizer)"
    ]
   },
   {
@@ -137,7 +101,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -154,30 +118,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "SparseAutoencoder(\n",
-       "  (encoder): Sequential(\n",
-       "    (0): TiedBias(position=TiedBiasPosition.PRE_ENCODER)\n",
-       "    (1): ConstrainedUnitNormLinear(in_features=2048, out_features=16384, bias=True)\n",
-       "    (2): ReLU()\n",
-       "  )\n",
-       "  (decoder): Sequential(\n",
-       "    (0): ConstrainedUnitNormLinear(in_features=16384, out_features=2048, bias=False)\n",
-       "    (1): TiedBias(position=TiedBiasPosition.POST_DECODER)\n",
-       "  )\n",
-       ")"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "autoencoder = SparseAutoencoder(src_d_mlp, src_d_mlp * 8, torch.zeros(src_d_mlp))\n",
     "autoencoder"
@@ -199,7 +142,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -208,94 +151,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "309fbf4a29a147ada581ba09b0cff34d",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Generate/Train Cycles: 0it [00:00, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a26f99ac95d44bf196f1d5fe70bafbe9",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Generate Activations: 0%| | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5cd07ef70a1f4b4c97cd2828f4cfd745",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Train Autoencoder: 0%| | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/alan/Documents/Repos/sparse_autoencoder/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py:251: UserWarning: The operator 'aten::sgn.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
-      " Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "242a91de8f694f64a7d04e93f25b95dc",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Generate Activations: 0%| | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5337eac728eb4ced9590c001e20e53ed",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Train Autoencoder: 0%| | 0/1000000 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "pipeline(\n",
     "    src_model=src_model,\n",
     "    src_model_activation_hook_point=\"blocks.0.mlp.hook_post\",\n",
     "    src_model_activation_layer=0,\n",
-    "    src_dataloader=src_dataloader,\n",
+    "    source_dataset=source_data,\n",
     "    activation_store=store,\n",
     "    num_activations_before_training=max_items,\n",
     "    autoencoder=autoencoder,\n",
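Taken together, the demo.ipynb changes drop the hand-built dataloader and pass the dataset itself into the pipeline. A minimal sketch of the updated flow, assuming the import path for PileUncopyrightedDataset matches this repo's layout; `pipeline`, `store`, and `max_items` are defined in cells outside the visible hunks and are assumed here:

# Sketch only: `pipeline`, `store`, and `max_items` come from collapsed cells;
# the import path is inferred from the source files changed below.
from transformer_lens import HookedTransformer
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset

src_model = HookedTransformer.from_pretrained("solu-1l", dtype="float32")
source_data = PileUncopyrightedDataset(tokenizer=src_model.tokenizer)

pipeline(
    src_model=src_model,
    src_model_activation_hook_point="blocks.0.mlp.hook_post",
    src_model_activation_layer=0,
    source_dataset=source_data,  # previously: src_dataloader=source_data.get_dataloader(batch_size=8)
    activation_store=store,
    num_activations_before_training=max_items,
    autoencoder=autoencoder,
)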
22 changes: 16 additions & 6 deletions sparse_autoencoder/source_data/abstract_dataset.py
@@ -3,6 +3,8 @@
 from typing import Any, Generic, TypedDict, TypeVar, final
 
 from datasets import IterableDataset, load_dataset
+from jaxtyping import Int
+from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset
 
@@ -11,12 +13,18 @@
 """A tokenized prompt."""
 
 
-class PreprocessTokenizedPrompts(TypedDict):
-    """Preprocess tokenized prompts return type."""
+class TokenizedPrompts(TypedDict):
+    """Tokenized prompts."""
 
     input_ids: list[TokenizedPrompt]
 
 
+class TorchTokenizedPrompts(TypedDict):
+    """Tokenized prompts prepared for PyTorch."""
+
+    input_ids: Int[Tensor, "batch pos"]
+
+
 HuggingFaceDatasetItem = TypeVar("HuggingFaceDatasetItem", bound=Any)
 """Hugging face dataset item typed dict.
 
@@ -65,7 +73,7 @@ def preprocess(
         source_batch: HuggingFaceDatasetItem,
         *,
         context_size: int,
-    ) -> PreprocessTokenizedPrompts:
+    ) -> TokenizedPrompts:
         """Preprocess function.
 
         Takes a `preprocess_batch_size` ($m$) batch of source data (which may e.g. include string
@@ -119,6 +127,8 @@ def __init__(
             preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
                 tokenizing prompts).
         """
+        self.context_size = context_size
+
         # Load the dataset
         dataset: IterableDataset = load_dataset(dataset_path, streaming=True, split=dataset_split) # type: ignore
 
@@ -154,7 +164,7 @@ def __next__(self) -> Any: # noqa: ANN401
         return next(iter(self))
 
     @final
-    def get_dataloader(self, batch_size: int) -> DataLoader:
+    def get_dataloader(self, batch_size: int) -> DataLoader[TorchTokenizedPrompts]:
         """Get a PyTorch DataLoader.
 
         Args:
@@ -163,9 +173,9 @@ def get_dataloader(self, batch_size: int) -> DataLoader:
         Returns:
             PyTorch DataLoader.
         """
-        torch_dataset: TorchDataset = self.dataset.with_format("torch") # type: ignore
+        torch_dataset: TorchDataset[TorchTokenizedPrompts] = self.dataset.with_format("torch") # type: ignore
 
-        return DataLoader(
+        return DataLoader[TorchTokenizedPrompts](
             torch_dataset,
             batch_size=batch_size,
             # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not
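For illustration, a short sketch of consuming the now-typed dataloader. The TypedDicts come from the diff above; the dataset instance (`source_data`) is assumed to already exist:

# Each batch is a TorchTokenizedPrompts, so `input_ids` is a 2-D integer
# tensor with shape [batch, pos].
from torch import Tensor

dataloader = source_data.get_dataloader(batch_size=8)
for batch in dataloader:
    input_ids: Tensor = batch["input_ids"]  # Int[Tensor, "batch pos"]
    assert input_ids.ndim == 2  # batch and position dimensions
    break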
4 changes: 2 additions & 2 deletions sparse_autoencoder/source_data/c4_pre_tokenized.py
@@ -7,8 +7,8 @@
 from typing import TypedDict, final
 
 from sparse_autoencoder.source_data.abstract_dataset import (
-    PreprocessTokenizedPrompts,
     SourceDataset,
+    TokenizedPrompts,
 )
 
 
@@ -33,7 +33,7 @@ def preprocess(
         source_batch: NeelC4SourceDataBatch,
         *,
         context_size: int,
-    ) -> PreprocessTokenizedPrompts:
+    ) -> TokenizedPrompts:
         """Preprocess a batch of prompts.
 
         As this dataset is already tokenized, all this does is split up each item based on the
4 changes: 2 additions & 2 deletions sparse_autoencoder/source_data/pile_uncopyrighted.py
@@ -4,8 +4,8 @@
 from transformers import PreTrainedTokenizerBase
 
 from sparse_autoencoder.source_data.abstract_dataset import (
-    PreprocessTokenizedPrompts,
     SourceDataset,
+    TokenizedPrompts,
 )
 
 
@@ -33,7 +33,7 @@ def preprocess(
         source_batch: PileUncopyrightedSourceDataBatch,
         *,
         context_size: int,
-    ) -> PreprocessTokenizedPrompts:
+    ) -> TokenizedPrompts:
         """Preprocess a batch of prompts.
 
         For each prompt's `text`, tokenize it and chunk into a list of tokenized prompts of length
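The preprocess docstring above describes a tokenize-then-chunk scheme. A hedged sketch of the chunking step, with an illustrative helper name (the actual implementation sits below the visible hunk):

# Split one tokenized text into non-overlapping context_size-length prompts,
# dropping any short remainder.
def chunk_tokens(token_ids: list[int], context_size: int) -> list[list[int]]:
    return [
        token_ids[i : i + context_size]
        for i in range(0, len(token_ids) - context_size + 1, context_size)
    ]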