Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Install this library as a local editable installation. Run the following command

To run the default pipeline from the command line, use the following command:

`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`
`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`

This command will:
1. Cache activations for the first 10 million tokens of EleutherAI/fineweb-edu-dedup-10b (the `--dataset_repo` passed in the command above).
Expand Down
5 changes: 1 addition & 4 deletions delphi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from simple_parsing import ArgumentParser
from sparsify.data import chunk_and_tokenize
from torch import Tensor
from torchtyping import TensorType
from transformers import (
AutoModel,
AutoTokenizer,
Expand Down Expand Up @@ -88,7 +87,7 @@ async def process_cache(
latent_dict = {
hook: latent_range for hook in hookpoints
} # The latent range to explain
latent_dict = cast(dict[str, int | Tensor], latent_dict)
latent_dict = cast(dict[str, Tensor], latent_dict)

constructor = partial(
default_constructor,
Expand Down Expand Up @@ -235,8 +234,6 @@ def populate_cache(
]
tokens = truncated_tokens.reshape(-1, cfg.ctx_len)

tokens = cast(TensorType["batch", "seq"], tokens)

cache = LatentCache(
model,
hookpoint_to_sparse_encode,
Expand Down
111 changes: 23 additions & 88 deletions delphi/explainers/default/default.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
import re

from ...logger import logger
from ..explainer import Explainer, ExplainerResult
from ..explainer import Example, Explainer
from .prompt_builder import build_prompt


Expand All @@ -20,98 +18,35 @@ def __init__(
temperature: float = 0.0,
**generation_kwargs,
):
self.client = client
self.tokenizer = tokenizer
self.verbose = verbose

self.activations = activations
self.cot = cot
self.threshold = threshold
self.temperature = temperature
self.generation_kwargs = generation_kwargs

async def __call__(self, record):
messages = self._build_prompt(record.train)

response = await self.client.generate(
messages, temperature=self.temperature, **self.generation_kwargs
super().__init__(
client,
tokenizer,
verbose,
activations,
cot,
threshold,
temperature,
**generation_kwargs,
)

try:
explanation = self.parse_explanation(response.text)
if self.verbose:
logger.info(f"Explanation: {explanation}")
logger.info(f"Final message to explainer: {messages[-1]['content']}")
logger.info(f"Response from explainer: {response.text}")

return ExplainerResult(record=record, explanation=explanation)
except Exception as e:
logger.error(f"Explanation parsing failed: {e}")
return ExplainerResult(
record=record, explanation="Explanation could not be parsed."
)

def parse_explanation(self, text: str) -> str:
try:
match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL)
return (
match.group(1).strip() if match else "Explanation could not be parsed."
)
except Exception as e:
logger.error(f"Explanation parsing regex failed: {e}")
raise

def _highlight(self, index, example):
result = f"Example {index}: "

threshold = example.max_activation * self.threshold
if self.tokenizer is not None:
str_toks = self.tokenizer.batch_decode(example.tokens)
example.str_toks = str_toks
else:
str_toks = example.tokens
example.str_toks = str_toks
activations = example.activations

def check(i):
return activations[i] > threshold

i = 0
while i < len(str_toks):
if check(i):
result += "<<"

while i < len(str_toks) and check(i):
result += str_toks[i]
i += 1
result += ">>"
else:
result += str_toks[i]
i += 1

return "".join(result)

def _join_activations(self, example):
activations = []

for i, activation in enumerate(example.activations):
if activation > example.max_activation * self.threshold:
activations.append(
(example.str_toks[i], int(example.normalized_activations[i]))
)

acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations)

return "Activations: " + acts

def _build_prompt(self, examples):
def _build_prompt(self, examples: list[Example]) -> list[dict]:
highlighted_examples = []

for i, example in enumerate(examples):
highlighted_examples.append(self._highlight(i + 1, example))
str_toks = self.tokenizer.batch_decode(example.tokens)
activations = example.activations.tolist()
highlighted_examples.append(self._highlight(str_toks, activations))

if self.activations:
highlighted_examples.append(self._join_activations(example))
assert (
example.normalized_activations is not None
), "Normalized activations are required for activations in explainer"
normalized_activations = example.normalized_activations.tolist()
highlighted_examples.append(
self._join_activations(
str_toks, activations, normalized_activations
)
)

highlighted_examples = "\n".join(highlighted_examples)

Expand Down
5 changes: 2 additions & 3 deletions delphi/explainers/default/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def build_examples(


def build_prompt(
examples,
examples: str,
activations: bool = False,
cot: bool = False,
):
) -> list[dict]:
messages = system(
cot=cot,
)
Expand All @@ -49,7 +49,6 @@ def build_prompt(
"content": user_start,
}
)
print(messages)

return messages

Expand Down
116 changes: 110 additions & 6 deletions delphi/explainers/explainer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
import os
import random
import re
from abc import ABC, abstractmethod
from typing import NamedTuple

import aiofiles

from ..latents.latents import LatentRecord
from ..latents.latents import Example, LatentRecord
from ..logger import logger


class ExplainerResult(NamedTuple):
Expand All @@ -18,18 +20,120 @@ class ExplainerResult(NamedTuple):


class Explainer(ABC):
    """
    Abstract base class for explainers.

    Subclasses implement ``_build_prompt``. This class sends the resulting
    messages to ``client``, extracts the text following ``[EXPLANATION]:``
    from the response, and offers helpers for rendering tokens and their
    activations as prompt text.
    """

    # Upper bound on the (token, activation) pairs listed per example by
    # ``_join_activations``.
    _MAX_LISTED_ACTIVATIONS = 10

    def __init__(
        self,
        client,
        tokenizer,
        verbose: bool = False,
        activations: bool = False,
        cot: bool = False,
        threshold: float = 0.6,
        temperature: float = 0.0,
        **generation_kwargs,
    ):
        """
        Store the configuration shared by all explainers.

        Args:
            client: Object whose awaitable ``generate`` method is called with
                the prompt messages.
            tokenizer: Tokenizer made available to subclasses.
            verbose: When true, log the explanation, the last prompt message
                and the raw response.
            activations: Flag stored for subclasses to read.
            cot: Flag stored for subclasses to read.
            threshold: Fraction of an example's maximum activation that a
                token must exceed to be highlighted or listed.
            temperature: Forwarded to ``client.generate``.
            **generation_kwargs: Extra keyword arguments forwarded to
                ``client.generate``.
        """
        self.client = client
        self.tokenizer = tokenizer
        self.verbose = verbose
        self.activations = activations
        self.cot = cot
        self.threshold = threshold
        self.temperature = temperature
        self.generation_kwargs = generation_kwargs

    async def __call__(self, record: LatentRecord) -> ExplainerResult:
        """
        Generate an explanation for ``record``.

        Builds the prompt from ``record.train``, awaits the client, and parses
        the response text. Any exception raised while parsing or logging is
        logged and replaced by a placeholder explanation rather than
        propagated.
        """
        messages = self._build_prompt(record.train)

        response = await self.client.generate(
            messages, temperature=self.temperature, **self.generation_kwargs
        )

        try:
            explanation = self.parse_explanation(response.text)
            if self.verbose:
                logger.info(f"Explanation: {explanation}")
                logger.info(f"Messages: {messages[-1]['content']}")
                logger.info(f"Response: {response}")

            return ExplainerResult(record=record, explanation=explanation)
        except Exception as e:
            logger.error(f"Explanation parsing failed: {e}")
            return ExplainerResult(
                record=record, explanation="Explanation could not be parsed."
            )

    def parse_explanation(self, text: str) -> str:
        """
        Return the stripped text that follows ``[EXPLANATION]:`` in ``text``.

        Returns the placeholder ``"Explanation could not be parsed."`` when
        the marker is absent. Errors raised by the regex search itself are
        logged and re-raised.
        """
        try:
            match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL)
            if match:
                return match.group(1).strip()
            else:
                return "Explanation could not be parsed."
        except Exception as e:
            logger.error(f"Explanation parsing regex failed: {e}")
            raise

    def _highlight(self, str_toks: list[str], activations: list[float]) -> str:
        """
        Concatenate ``str_toks``, wrapping each run of consecutive tokens whose
        activation exceeds ``max(activations) * self.threshold`` in ``<<`` and
        ``>>``.

        Returns an empty string when there are no tokens.
        """
        if not str_toks:
            # Nothing to render; also avoids calling max() on an empty list.
            return ""

        threshold = max(activations) * self.threshold
        pieces: list[str] = []
        n_toks = len(str_toks)

        i = 0
        while i < n_toks:
            if activations[i] > threshold:
                # Consume the whole run of active tokens as one highlighted span.
                start = i
                while i < n_toks and activations[i] > threshold:
                    i += 1
                pieces.append("<<" + "".join(str_toks[start:i]) + ">>")
            else:
                pieces.append(str_toks[i])
                i += 1

        return "".join(pieces)

    def _join_activations(
        self,
        str_toks: list[str],
        token_activations: list[float],
        normalized_activations: list[float],
    ) -> str:
        """
        List the tokens whose activation exceeds
        ``max(token_activations) * self.threshold`` together with their
        normalized activation truncated to an integer.

        The result has the form ``Activations: ("tok" : 3), ("tok2" : 7), ``
        (each entry is followed by ``", "``). At most
        ``_MAX_LISTED_ACTIVATIONS`` entries are emitted per example.
        """
        if not token_activations:
            return "Activations: "

        # Loop-invariant: compute the cutoff once instead of once per token.
        cutoff = max(token_activations) * self.threshold

        entries: list[str] = []
        for str_tok, token_activation, normalized_activation in zip(
            str_toks, token_activations, normalized_activations
        ):
            if token_activation > cutoff:
                # TODO: for each example, we only show the first 10 activations
                # decide on the best way to do this
                if len(entries) >= self._MAX_LISTED_ACTIVATIONS:
                    break
                entries.append(f'("{str_tok}" : {int(normalized_activation)}), ')

        return "Activations: " + "".join(entries)

    @abstractmethod
    def _build_prompt(self, examples: list[Example]) -> list[dict]:
        """Build the chat messages sent to the client for ``examples``."""
        pass


async def explanation_loader(
    record: LatentRecord, explanation_dir: str
) -> ExplainerResult:
    """
    Load a previously saved explanation for ``record``.

    Reads ``{explanation_dir}/{record.latent}.txt`` and decodes its contents
    as JSON. When the file does not exist, a warning is logged and the
    placeholder explanation ``"No explanation found"`` is returned instead of
    raising.
    """
    path = f"{explanation_dir}/{record.latent}.txt"
    try:
        # Keep the guarded region to the read itself so only a missing file
        # is turned into the placeholder result.
        async with aiofiles.open(path, "r") as f:
            explanation = json.loads(await f.read())
    except FileNotFoundError:
        logger.warning(f"No explanation found for {record.latent}")
        return ExplainerResult(record=record, explanation="No explanation found")

    return ExplainerResult(record=record, explanation=explanation)


async def random_explanation_loader(
Expand Down
Loading