From 52a05d9f549c9eb13ad6a4c21b61febe0c575110 Mon Sep 17 00:00:00 2001 From: Alan Cooney <41682961+alan-cooney@users.noreply.github.com> Date: Wed, 8 Nov 2023 11:54:17 +0800 Subject: [PATCH 1/3] Add stateful iterator to the pipeline --- .../source_data/abstract_dataset.py | 20 +++++-- .../source_data/c4_pre_tokenized.py | 4 +- .../source_data/pile_uncopyrighted.py | 4 +- sparse_autoencoder/source_data/random_int.py | 9 +-- .../tests/test_abstract_dataset.py | 4 +- .../train/generate_activations.py | 21 +++---- sparse_autoencoder/train/pipeline.py | 58 +++++++++++++++++-- 7 files changed, 89 insertions(+), 31 deletions(-) diff --git a/sparse_autoencoder/source_data/abstract_dataset.py b/sparse_autoencoder/source_data/abstract_dataset.py index a8c040bd..d32fbb4b 100644 --- a/sparse_autoencoder/source_data/abstract_dataset.py +++ b/sparse_autoencoder/source_data/abstract_dataset.py @@ -3,6 +3,8 @@ from typing import Any, Generic, TypedDict, TypeVar, final from datasets import IterableDataset, load_dataset +from jaxtyping import Int +from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset @@ -11,12 +13,18 @@ """A tokenized prompt.""" -class PreprocessTokenizedPrompts(TypedDict): - """Preprocess tokenized prompts return type.""" +class TokenizedPrompts(TypedDict): + """Tokenized prompts.""" input_ids: list[TokenizedPrompt] +class TorchTokenizedPrompts(TypedDict): + """Tokenized prompts prepared for PyTorch.""" + + input_ids: Int[Tensor, "batch pos"] + + HuggingFaceDatasetItem = TypeVar("HuggingFaceDatasetItem", bound=Any) """Hugging face dataset item typed dict. @@ -65,7 +73,7 @@ def preprocess( source_batch: HuggingFaceDatasetItem, *, context_size: int, - ) -> PreprocessTokenizedPrompts: + ) -> TokenizedPrompts: """Preprocess function. Takes a `preprocess_batch_size` ($m$) batch of source data (which may e.g. include string @@ -154,7 +162,7 @@ def __next__(self) -> Any: # noqa: ANN401 return next(iter(self)) @final - def get_dataloader(self, batch_size: int) -> DataLoader: + def get_dataloader(self, batch_size: int) -> DataLoader[TorchTokenizedPrompts]: """Get a PyTorch DataLoader. Args: @@ -163,9 +171,9 @@ def get_dataloader(self, batch_size: int) -> DataLoader: Returns: PyTorch DataLoader. """ - torch_dataset: TorchDataset = self.dataset.with_format("torch") # type: ignore + torch_dataset: TorchDataset[TorchTokenizedPrompts] = self.dataset.with_format("torch") # type: ignore - return DataLoader( + return DataLoader[TorchTokenizedPrompts]( torch_dataset, batch_size=batch_size, # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not diff --git a/sparse_autoencoder/source_data/c4_pre_tokenized.py b/sparse_autoencoder/source_data/c4_pre_tokenized.py index 85a9fc7f..58a41dc4 100644 --- a/sparse_autoencoder/source_data/c4_pre_tokenized.py +++ b/sparse_autoencoder/source_data/c4_pre_tokenized.py @@ -7,8 +7,8 @@ from typing import TypedDict, final from sparse_autoencoder.source_data.abstract_dataset import ( - PreprocessTokenizedPrompts, SourceDataset, + TokenizedPrompts, ) @@ -33,7 +33,7 @@ def preprocess( source_batch: NeelC4SourceDataBatch, *, context_size: int, - ) -> PreprocessTokenizedPrompts: + ) -> TokenizedPrompts: """Preprocess a batch of prompts. 
As this dataset is already tokenized, all this does is split up each item based on the diff --git a/sparse_autoencoder/source_data/pile_uncopyrighted.py b/sparse_autoencoder/source_data/pile_uncopyrighted.py index f0928d3e..1500aabd 100644 --- a/sparse_autoencoder/source_data/pile_uncopyrighted.py +++ b/sparse_autoencoder/source_data/pile_uncopyrighted.py @@ -4,8 +4,8 @@ from transformers import PreTrainedTokenizerBase from sparse_autoencoder.source_data.abstract_dataset import ( - PreprocessTokenizedPrompts, SourceDataset, + TokenizedPrompts, ) @@ -33,7 +33,7 @@ def preprocess( source_batch: PileUncopyrightedSourceDataBatch, *, context_size: int, - ) -> PreprocessTokenizedPrompts: + ) -> TokenizedPrompts: """Preprocess a batch of prompts. For each prompt's `text`, tokenize it and chunk into a list of tokenized prompts of length diff --git a/sparse_autoencoder/source_data/random_int.py b/sparse_autoencoder/source_data/random_int.py index d9c021b7..0adb93b6 100644 --- a/sparse_autoencoder/source_data/random_int.py +++ b/sparse_autoencoder/source_data/random_int.py @@ -9,8 +9,9 @@ from transformers import PreTrainedTokenizerFast from sparse_autoencoder.source_data.abstract_dataset import ( - PreprocessTokenizedPrompts, SourceDataset, + TokenizedPrompts, + TorchTokenizedPrompts, ) @@ -65,7 +66,7 @@ def preprocess( source_batch: RandomIntSourceData, *, context_size: int, - ) -> PreprocessTokenizedPrompts: + ) -> TokenizedPrompts: """Preprocess a batch of prompts. Not implemented for this dummy dataset. @@ -107,6 +108,6 @@ def __init__( """ self.dataset = RandomIntHuggingFaceDataset(50000, context_size=context_size) # type: ignore - def get_dataloader(self, batch_size: int) -> DataLoader: # type: ignore + def get_dataloader(self, batch_size: int) -> DataLoader[TorchTokenizedPrompts]: # type: ignore """Get Dataloader.""" - return DataLoader(self.dataset, batch_size=batch_size) # type: ignore + return DataLoader[TorchTokenizedPrompts](self.dataset, batch_size=batch_size) # type: ignore diff --git a/sparse_autoencoder/source_data/tests/test_abstract_dataset.py b/sparse_autoencoder/source_data/tests/test_abstract_dataset.py index bd24acb5..593063d7 100644 --- a/sparse_autoencoder/source_data/tests/test_abstract_dataset.py +++ b/sparse_autoencoder/source_data/tests/test_abstract_dataset.py @@ -7,8 +7,8 @@ import torch from sparse_autoencoder.source_data.abstract_dataset import ( - PreprocessTokenizedPrompts, SourceDataset, + TokenizedPrompts, ) @@ -30,7 +30,7 @@ def preprocess( source_batch: MockHuggingFaceDatasetItem, # noqa: ARG002 *, context_size: int, # noqa: ARG002 - ) -> PreprocessTokenizedPrompts: + ) -> TokenizedPrompts: """Preprocess a batch of prompts.""" preprocess_batch = 100 tokenized_texts = torch.randint( diff --git a/sparse_autoencoder/train/generate_activations.py b/sparse_autoencoder/train/generate_activations.py index f1473a96..9b162328 100644 --- a/sparse_autoencoder/train/generate_activations.py +++ b/sparse_autoencoder/train/generate_activations.py @@ -1,10 +1,10 @@ """Generate activations for training a Sparse Autoencoder.""" +from collections.abc import Iterable from functools import partial from jaxtyping import Int import torch from torch import Tensor -from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformer_lens import HookedTransformer @@ -12,6 +12,7 @@ ActivationStore, StoreFullError, ) +from sparse_autoencoder.source_data.abstract_dataset import TorchTokenizedPrompts from sparse_autoencoder.src_model.store_activations_hook import 
store_activations_hook @@ -20,7 +21,9 @@ def generate_activations( layer: int, cache_name: str, store: ActivationStore, - dataloader: DataLoader[Int[Tensor, " pos"]], + source_data: Iterable[TorchTokenizedPrompts], + context_size: int, + batch_size: int, num_items: int, device: torch.device | None = None, ) -> None: @@ -48,7 +51,9 @@ def generate_activations( `blocks.0.ln2.hook_scale`, `blocks.0.ln2.hook_normalized`, `blocks.0.mlp.hook_pre`, `blocks.0.mlp.hook_post`, `blocks.0.hook_mlp_out`, `blocks.0.hook_resid_post`]. store: The activation store to use. - dataloader: Dataloader containing source model input tokens. + source_data: Stateful iterator that yields batches of data to generate activations. + context_size: Number of tokens in each prompt. + batch_size: Size of each batch. num_items: Number of activation vectors to generate. This is an approximate rather than strict limit. device: Device to run the model on. @@ -62,15 +67,11 @@ def generate_activations( model.add_hook(cache_name, hook) # Get the input dimensions for logging - first_item: Int[Tensor, "batch pos"] = next(iter(dataloader))["input_ids"] - batch_size: int = first_item.shape[0] - context_size: int = first_item.shape[1] activations_per_batch: int = context_size * batch_size total: int = num_items - num_items % activations_per_batch # Loop through the dataloader until the store reaches the desired size with torch.no_grad(), tqdm( - dataloader, desc="Generate Activations", total=total, colour="green", @@ -78,10 +79,10 @@ def generate_activations( leave=False, dynamic_ncols=True, ) as progress_bar: - for batch in dataloader: + for batch in source_data: try: - input_ids = batch["input_ids"].to(device) - model.forward(input_ids, stop_at_layer=layer + 1) # type: ignore + input_ids: Int[Tensor, "batch pos"] = batch["input_ids"].to(device) + model.forward(input_ids, stop_at_layer=layer + 1) # type: ignore (TLens is typed incorrectly) progress_bar.update(activations_per_batch) # Break the loop if the store is full diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index 74191c3c..a80ee590 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -1,7 +1,8 @@ """Training Pipeline.""" -from jaxtyping import Int +from collections.abc import Iterable + +from jaxtyping import Float, Int import torch -from torch import Tensor from torch.optim import Adam from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -10,16 +11,58 @@ from sparse_autoencoder.activation_store.base_store import ActivationStore from sparse_autoencoder.autoencoder.model import SparseAutoencoder +from sparse_autoencoder.source_data.abstract_dataset import ( + SourceDataset, + TokenizedPrompts, + TorchTokenizedPrompts, +) from sparse_autoencoder.train.generate_activations import generate_activations from sparse_autoencoder.train.sweep_config import SweepParametersRuntime from sparse_autoencoder.train.train_autoencoder import train_autoencoder +def stateful_dataloader_iterable( + dataloader: DataLoader[TorchTokenizedPrompts] +) -> Iterable[TorchTokenizedPrompts]: + """Create a stateful dataloader iterable. + + Create an iterable that maintains it's position in the dataloader between loops. + + Examples: + Without this, when iterating over a DataLoader with 2 loops, each loop get the same data + (assuming shuffle is turned off). That is to say, the second loop won't maintain the + position from where the first loop left off. 
+ + >>> from datasets import Dataset + >>> from torch.utils.data import DataLoader + >>> def gen(): + ... yield {"int": 0} + ... yield {"int": 1} + >>> data = DataLoader(Dataset.from_generator(gen)) + >>> next(iter(data))["int"], next(iter(data))["int"] + (tensor([0]), tensor([0])) + + By contrast if you create a stateful iterable from the dataloader, each loop will get + different data. + + >>> iterator = stateful_dataloader_iterable(data) + >>> next(iter(data))["int"], next(iter(data))["int"] + (tensor([0]), tensor([1])) + + Args: + dataloader: PyTorch DataLoader. + + Returns: + Stateful iterable over the data in the dataloader. + """ + yield from dataloader + + def pipeline( src_model: HookedTransformer, src_model_activation_hook_point: str, src_model_activation_layer: int, - src_dataloader: DataLoader[Int[Tensor, " pos"]], + source_dataset: SourceDataset[TokenizedPrompts], activation_store: ActivationStore, num_activations_before_training: int, autoencoder: SparseAutoencoder, @@ -35,7 +78,7 @@ def pipeline( src_model_activation_hook_point: The hook point to get activations from. src_model_activation_layer: The layer to get activations from. This is used to stop the model after this layer, as we don't need the final logits. - src_dataloader: DataLoader containing source model inputs (typically batches of prompts) + source_dataset: Source dataset containing source model inputs (typically batches of prompts) that are used to generate the activations data. activation_store: The store to buffer activations in once generated, before training the autoencoder. @@ -56,6 +99,9 @@ def pipeline( weight_decay=sweep_parameters.adam_weight_decay, ) + source_dataloader = source_dataset.get_dataloader(sweep_parameters.batch_size) + source_data_iterator = stateful_dataloader_iterable(source_dataloader) + # Run loop until source data is exhausted: with logging_redirect_tqdm(), tqdm( desc="Generate/Train Cycles", @@ -69,9 +115,11 @@ def pipeline( src_model_activation_layer, src_model_activation_hook_point, activation_store, - src_dataloader, + source_data_iterator, device=device, + context_size=source_dataset.context_size, num_items=num_activations_before_training, + batch_size=sweep_parameters.batch_size, ) if len(activation_store) == 0: break From c64c5b9825be033640440543de04795702fbd988 Mon Sep 17 00:00:00 2001 From: Alan Cooney <41682961+alan-cooney@users.noreply.github.com> Date: Wed, 8 Nov 2023 12:03:42 +0800 Subject: [PATCH 2/3] Fix checks --- sparse_autoencoder/train/pipeline.py | 3 +-- sparse_autoencoder/train/tests/test_generate_activations.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index a80ee590..7dc4c562 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -1,7 +1,6 @@ """Training Pipeline.""" from collections.abc import Iterable -from jaxtyping import Float, Int import torch from torch.optim import Adam from torch.utils.data import DataLoader @@ -46,7 +45,7 @@ def stateful_dataloader_iterable( different data. 
>>> iterator = stateful_dataloader_iterable(data) - >>> next(iter(data))["int"], next(iter(data))["int"] + >>> next(iterator)["int"], next(iterator)["int"] (tensor([0]), tensor([1])) Args: diff --git a/sparse_autoencoder/train/tests/test_generate_activations.py b/sparse_autoencoder/train/tests/test_generate_activations.py index a569be00..a159c126 100644 --- a/sparse_autoencoder/train/tests/test_generate_activations.py +++ b/sparse_autoencoder/train/tests/test_generate_activations.py @@ -26,8 +26,10 @@ def test_activations_generated() -> None: layer=1, cache_name="blocks.1.mlp.hook_post", store=store, - dataloader=dataloader, + source_data=iter(dataloader), num_items=num_items, + context_size=dataset.context_size, + batch_size=2, ) assert len(store) >= num_items From fa9adacc32a70abc133b6b6ce87a8d0d955d8037 Mon Sep 17 00:00:00 2001 From: Alan Cooney <41682961+alan-cooney@users.noreply.github.com> Date: Wed, 8 Nov 2023 13:34:37 +0800 Subject: [PATCH 3/3] Update demo --- .vscode/cspell.json | 1 + demo.ipynb | 188 +++--------------- .../source_data/abstract_dataset.py | 2 + sparse_autoencoder/train/pipeline.py | 9 +- 4 files changed, 34 insertions(+), 166 deletions(-) diff --git a/.vscode/cspell.json b/.vscode/cspell.json index a48a99f1..38bfeec8 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -69,6 +69,7 @@ "runcap", "sharded", "snapshottest", + "solu", "tqdm", "transformer_lens", "typecheck", diff --git a/demo.ipynb b/demo.ipynb index 4c3f94fd..0684a590 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -16,9 +16,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "# Autoreload\n", "%load_ext autoreload\n", @@ -27,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -36,13 +45,12 @@ "from transformer_lens import HookedTransformer\n", "from transformer_lens.utils import get_device\n", "from transformers import PreTrainedTokenizerBase\n", - "import torch\n", - "import wandb" + "import torch" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -58,27 +66,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model solu-1l into HookedTransformer\n" - ] - }, - { - "data": { - "text/plain": [ - "2048" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "src_model = HookedTransformer.from_pretrained(\"solu-1l\", dtype=\"float32\")\n", "src_d_mlp: int = src_model.cfg.d_mlp # type: ignore\n", @@ -94,38 +84,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a1ce590449484e1788109c4f13a2e8bf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Resolving data files: 0%| | 0/30 [00:00 None: @@ -85,6 +85,7 @@ def pipeline( autoencoder. As a guide, 1 million activations, each of size 1024, will take up about 2GB of memory (assuming float16/bfloat16). autoencoder: The autoencoder to train. + source_dataset_batch_size: Batch size of tokenized prompts for generating the source data. sweep_parameters: Parameter config to use. device: Device to run pipeline on. """ @@ -98,7 +99,7 @@ def pipeline( weight_decay=sweep_parameters.adam_weight_decay, ) - source_dataloader = source_dataset.get_dataloader(sweep_parameters.batch_size) + source_dataloader = source_dataset.get_dataloader(source_dataset_batch_size) source_data_iterator = stateful_dataloader_iterable(source_dataloader) # Run loop until source data is exhausted: @@ -118,7 +119,7 @@ def pipeline( device=device, context_size=source_dataset.context_size, num_items=num_activations_before_training, - batch_size=sweep_parameters.batch_size, + batch_size=source_dataset_batch_size, ) if len(activation_store) == 0: break
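
For reference, below is a minimal, self-contained sketch of the stateful-iterator behaviour these commits rely on: wrapping the DataLoader in a generator so that successive generate/train cycles resume from where the previous cycle stopped, rather than restarting the dataset. `ToyDataset` and `mock_generate_cycle` are illustrative stand-ins, not part of the codebase; only the name and behaviour of `stateful_dataloader_iterable` mirror the patch.

from collections.abc import Iterator

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Tiny stand-in for a tokenized prompts dataset (illustrative only)."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> int:
        return index


def stateful_dataloader_iterable(dataloader: DataLoader) -> Iterator:
    """Yield batches while keeping the position between partial loops (mirrors the patch)."""
    yield from dataloader


def mock_generate_cycle(source_data: Iterator, num_batches: int) -> list:
    """Illustrative stand-in for one generate_activations call that stops once its store fills."""
    return [next(source_data) for _ in range(num_batches)]


dataloader = DataLoader(ToyDataset(), batch_size=2)
source_data_iterator = stateful_dataloader_iterable(dataloader)

# Both generate/train cycles share the same iterator, so the second cycle resumes
# from batch 2 rather than re-reading batches 0 and 1.
print(mock_generate_cycle(source_data_iterator, 2))  # [tensor([0, 1]), tensor([2, 3])]
print(mock_generate_cycle(source_data_iterator, 2))  # [tensor([4, 5]), tensor([6, 7])]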