Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 0 additions & 120 deletions delphi/latents/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,123 +322,3 @@ async def _aprocess_latent(self, buffer_output: BufferOutput) -> LatentRecord:
if self.transform is not None:
self.transform(record)
return record


class LatentLoader:
"""
Loader class for processing latent records from a LatentDataset.
"""

def __init__(
self,
latent_dataset: "LatentDataset",
constructor: Optional[Callable] = None,
sampler: Optional[Callable] = None,
transform: Optional[Callable] = None,
):
"""
Initialize a LatentLoader.

Args:
latent_dataset (LatentDataset): The dataset to load latents from.
constructor (Optional[Callable]): Function to construct latent records.
sampler (Optional[Callable]): Function to sample from latent records.
transform (Optional[Callable]): Function to transform latent records.
"""
self.latent_dataset = latent_dataset
self.constructor = constructor
self.sampler = sampler
self.transform = transform

async def __aiter__(self):
"""
Asynchronous iterator for processing latent records.

Yields:
LatentRecord: Processed latent records.
"""
for buffer in self.latent_dataset.buffers:
async for record in self._aprocess_buffer(buffer):
yield record

async def _aprocess_buffer(self, buffer):
"""
Asynchronously process a buffer.

Args:
buffer (TensorBuffer): Buffer to process.

Yields:
Optional[LatentRecord]: Processed latent record or None.
"""
for data in buffer:
if data is not None:
record = await self._aprocess_latent(data)
if record is not None:
yield record
await asyncio.sleep(0) # Allow other coroutines to run

async def _aprocess_latent(self, buffer_output):
"""
Asynchronously process a single latent.

Args:
buffer_output (BufferOutput): Latent data to process.

Returns:
Optional[LatentRecord]: Processed latent record or None.
"""
record = LatentRecord(buffer_output.latent)
if self.constructor is not None:
self.constructor(record=record, buffer_output=buffer_output)
if self.sampler is not None:
self.sampler(record)
if self.transform is not None:
self.transform(record)
return record

def __iter__(self):
"""
Synchronous iterator for processing latent records.

Yields:
LatentRecord: Processed latent records.
"""
for buffer in self.latent_dataset.buffers:
for record in self._process_buffer(buffer):
yield record

def _process_buffer(self, buffer):
"""
Process a buffer synchronously.

Args:
buffer (TensorBuffer): Buffer to process.

Yields:
Optional[LatentRecord]: Processed latent record or None.
"""
for data in buffer:
if data is not None:
record = self._process_latent(data)
if record is not None:
yield record

def _process_latent(self, buffer_output):
"""
Process a single latent synchronously.

Args:
buffer_output (BufferOutput): Latent data to process.

Returns:
Optional[LatentRecord]: Processed latent record or None.
"""
record = LatentRecord(buffer_output.latent)
if self.constructor is not None:
self.constructor(record=record, buffer_output=buffer_output)
if self.sampler is not None:
self.sampler(record)
if self.transform is not None:
self.transform(record)
return record
22 changes: 11 additions & 11 deletions examples/generate_explanations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import os\n",
"from functools import partial\n",
"\n",
Expand All @@ -30,7 +29,7 @@
"from delphi.clients import OpenRouter\n",
"from delphi.config import ExperimentConfig, LatentConfig\n",
"from delphi.explainers import DefaultExplainer\n",
"from delphi.latents import LatentDataset, LatentLoader\n",
"from delphi.latents import LatentDataset\n",
"from delphi.latents.constructors import default_constructor\n",
"from delphi.latents.samplers import sample\n",
"from delphi.pipeline import Pipeline, process_wrapper\n",
Expand Down Expand Up @@ -61,12 +60,7 @@
"module = \".model.layers.10\" # The layer to explain\n",
"latent_dict = {module: torch.arange(0,5)} # The what latents to explain\n",
"\n",
"dataset = LatentDataset(\n",
" raw_dir=\"latents\", # The folder where the cache is stored\n",
" cfg=latent_cfg,\n",
" modules=[module],\n",
" latents=latent_dict,\n",
")\n"
"\n"
]
},
{
Expand Down Expand Up @@ -116,8 +110,14 @@
" max_examples=latent_cfg.max_examples\n",
" )\n",
"sampler=partial(sample,cfg=experiment_cfg)\n",
"loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n",
" "
"dataset = LatentDataset(\n",
" raw_dir=\"latents\", # The folder where the cache is stored\n",
" cfg=latent_cfg,\n",
" modules=[module],\n",
" latents=latent_dict,\n",
" constructor=constructor,\n",
" sampler=sampler\n",
") "
]
},
{
Expand Down Expand Up @@ -210,7 +210,7 @@
],
"source": [
"pipeline = Pipeline(\n",
" loader,\n",
" dataset,\n",
" explainer_pipe,\n",
")\n",
"number_of_parallel_latents = 10\n",
Expand Down
67 changes: 34 additions & 33 deletions examples/latent_contexts.ipynb

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions examples/score_explanations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
"import os \n",
"import torch\n",
"import orjson\n",
"import asyncio\n",
"from delphi.clients import OpenRouter\n",
"from delphi.config import ExperimentConfig, LatentConfig\n",
"from delphi.explainers import explanation_loader\n",
"from delphi.latents import (\n",
" LatentDataset,\n",
" LatentLoader\n",
" LatentDataset\n",
")\n",
"from delphi.latents.constructors import default_constructor\n",
"from delphi.latents.samplers import sample\n",
Expand Down Expand Up @@ -65,12 +63,7 @@
"module = \".model.layers.10\" # The layer to score\n",
"latent_dict = {module: torch.arange(0,3)} # The what latents to score\n",
"\n",
"dataset = LatentDataset(\n",
" raw_dir=\"latents\", # The folder where the cache is stored\n",
" cfg=latent_cfg,\n",
" modules=[module],\n",
" latents=latent_dict,\n",
")\n"
"\n"
]
},
{
Expand Down Expand Up @@ -120,7 +113,14 @@
" max_examples=latent_cfg.max_examples\n",
" )\n",
"sampler=partial(sample,cfg=experiment_cfg)\n",
"loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n",
"dataset = LatentDataset(\n",
" raw_dir=\"latents\", # The folder where the cache is stored\n",
" cfg=latent_cfg,\n",
" modules=[module],\n",
" latents=latent_dict,\n",
" constructor=constructor,\n",
" sampler=sampler\n",
")\n",
" "
]
},
Expand Down Expand Up @@ -217,7 +217,7 @@
],
"source": [
"pipeline = Pipeline(\n",
" loader,\n",
" dataset,\n",
" explainer_pipe,\n",
" scorer_pipe,\n",
")\n",
Expand Down