7 changes: 6 additions & 1 deletion .github/workflows/tests.yml
@@ -23,7 +23,12 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ".[dev]"
pip install ".[dev,visualize]"

- name: Run tests
run: pytest

- name: Type Checking
uses: jakebailey/pyright-action@v1
with:
version: 1.1.378
35 changes: 26 additions & 9 deletions delphi/__main__.py
@@ -26,7 +26,7 @@
from delphi.pipeline import Pipe, Pipeline, process_wrapper
from delphi.scorers import DetectionScorer, FuzzingScorer
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
from delphi.utils import load_tokenized_data
from delphi.utils import assert_type, load_tokenized_data


def load_artifacts(run_cfg: RunConfig):
@@ -70,8 +70,11 @@ def create_neighbours(
neighbours_path.mkdir(parents=True, exist_ok=True)

constructor_cfg = run_cfg.constructor_cfg
if constructor_cfg.neighbours_type != "co-occurrence":
saes = load_sparse_coders(run_cfg, device="cpu")
saes = (
load_sparse_coders(run_cfg, device="cpu")
if constructor_cfg.neighbours_type != "co-occurrence"
else {}
)

for hookpoint in hookpoints:

@@ -90,6 +93,11 @@ def create_neighbours(
neighbour_calculator = NeighbourCalculator(
autoencoder=saes[hookpoint].cuda(), number_of_neighbours=100
)
else:
raise ValueError(
f"Neighbour type {constructor_cfg.neighbours_type} not supported"
)

neighbour_calculator.populate_neighbour_cache(constructor_cfg.neighbours_type)
neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}")

@@ -325,8 +333,11 @@ async def run(
hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

nrh = non_redundant_hookpoints(
hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
nrh = assert_type(
dict,
non_redundant_hookpoints(
hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
),
)
if nrh:
populate_cache(
@@ -340,8 +351,11 @@ async def run(

del model, hookpoint_to_sparse_encode
if run_cfg.constructor_cfg.non_activating_source == "neighbours":
nrh = non_redundant_hookpoints(
hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
nrh = assert_type(
list,
non_redundant_hookpoints(
hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
),
)
if nrh:
create_neighbours(
@@ -353,8 +367,11 @@ async def run(
else:
print("Skipping neighbour creation")

nrh = non_redundant_hookpoints(
hookpoints, scores_path, "scores" in run_cfg.overwrite
nrh = assert_type(
list,
non_redundant_hookpoints(
hookpoints, scores_path, "scores" in run_cfg.overwrite
),
)
if nrh:
await process_cache(
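For reference, the assert_type helper imported from delphi.utils at the top of this file is a runtime type guard: non_redundant_hookpoints can return either a dict or a list depending on its input, and the wrapper both checks and narrows the type for the pyright step added to CI above. A minimal sketch of the conventional implementation, assuming the usual signature rather than quoting delphi/utils.py:

from typing import Any, Type, TypeVar, cast

T = TypeVar("T")

def assert_type(typ: Type[T], obj: Any) -> T:
    # Fail loudly if obj is not an instance of typ; otherwise return it
    # unchanged so static checkers can narrow the type to T.
    if not isinstance(obj, typ):
        raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
    return cast(typ, obj)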
4 changes: 2 additions & 2 deletions delphi/clients/client.py
@@ -6,8 +6,8 @@
@dataclass
class Response:
text: str
logprobs: list[float] = None
prompt_logprobs: list[float] = None
logprobs: list[float] | None = None
prompt_logprobs: list[float] | None = None


class Client(ABC):
15 changes: 7 additions & 8 deletions delphi/clients/offline.py
@@ -91,8 +91,8 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
self.sampling_params.temperature = kwarg["temperature"]
loop = asyncio.get_running_loop()
prompts = []
if self.statistics:
statistics = []
statistics = []

for batch in batches:
prompt = self.tokenizer.apply_chat_template(
batch, add_generation_prompt=True, tokenize=True
@@ -101,7 +101,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
if self.statistics:
non_cached_tokens = len(
self.tokenizer.apply_chat_template(
batch[-1:], add_generation_prompt=True, tokenize=True
batch[-1:], add_generation_prompt=True, tokenize=True # type: ignore
)
)
statistics.append(
@@ -114,7 +114,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
response = await loop.run_in_executor(
None,
partial(
self.client.generate,
self.client.generate, # type: ignore
prompt_token_ids=prompts,
sampling_params=self.sampling_params,
use_tqdm=False,
@@ -127,10 +127,10 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
if self.statistics:
statistics[i].num_generated_tokens = len(r.outputs[0].token_ids)
# save the statistics to a file, name is a hash of the prompt
statistics[i].prompt = batches[i][-1]["content"]
statistics[i].prompt = batches[i][-1]["content"] # type: ignore
statistics[i].response = r.outputs[0].text
with open(
f"statistics/{hash(batches[i][-1]['content'][-100:])}.json", "w"
f"statistics/{hash(batches[i][-1]['content'][-100:])}.json", "w" # type: ignore
) as f:
json.dump(statistics[i].__dict__, f, indent=4)
new_response.append(
@@ -142,15 +142,14 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
)
return new_response

async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str:
async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str: # type: ignore
"""
Enqueue a request and wait for the result.
"""
future = asyncio.Future()
if self.task is None:
self.task = asyncio.create_task(self._process_batches())
await self.queue.put((prompt, future, kwargs))
# print(f"Current queue size: {self.queue.qsize()} prompts")
return await future

def _parse_logprobs(self, response):
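The generate method above follows an enqueue-and-await pattern: each caller parks an asyncio.Future on a shared queue, a single lazily-started background task drains the queue, and every future is resolved with its result. A stripped-down sketch of that pattern; the batch size and the echo response are stand-ins for the real vLLM call inside _process_batches:

import asyncio

class BatchingClient:
    def __init__(self, batch_size: int = 8):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.task: asyncio.Task | None = None
        self.batch_size = batch_size

    async def generate(self, prompt: str, **kwargs) -> str:
        # Lazily start the consumer, then wait for our future to resolve.
        future: asyncio.Future = asyncio.Future()
        if self.task is None:
            self.task = asyncio.create_task(self._process_batches())
        await self.queue.put((prompt, future, kwargs))
        return await future

    async def _process_batches(self):
        while True:
            # Block for one request, then greedily take more up to batch_size.
            batch = [await self.queue.get()]
            while not self.queue.empty() and len(batch) < self.batch_size:
                batch.append(self.queue.get_nowait())
            for prompt, future, _kwargs in batch:
                # The real client runs vLLM generation here; echo instead.
                future.set_result(f"response to {prompt!r}")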
8 changes: 4 additions & 4 deletions delphi/clients/openrouter.py
@@ -20,7 +20,7 @@ class OpenRouter(Client):
def __init__(
self,
model: str,
api_key: str = None,
api_key: str | None = None,
base_url="https://openrouter.ai/api/v1/chat/completions",
):
super().__init__(model)
@@ -35,9 +35,9 @@ def postprocess(self, response):
msg = response_json["choices"][0]["message"]["content"]
return Response(msg)

async def generate(
self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs
) -> Response:
async def generate( # type: ignore
self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs # type: ignore
) -> Response: # type: ignore
kwargs.pop("schema", None)
max_tokens = kwargs.pop("max_tokens", 500)
temperature = kwargs.pop("temperature", 1.0)
15 changes: 13 additions & 2 deletions delphi/explainers/single_token_explainer.py
@@ -32,10 +32,21 @@ def _build_prompt(self, examples):
highlighted_examples = []

for i, example in enumerate(examples):
highlighted_examples.append(self._highlight(i + 1, example))
highlighted_examples.append(
self._highlight(example.str_tokens, example.activations.tolist())
)

if self.activations:
highlighted_examples.append(self._join_activations(example))
assert (
example.normalized_activations is not None
), "Normalized activations are required for activations in explainer"
highlighted_examples.append(
self._join_activations(
example.str_tokens,
example.activations.tolist(),
example.normalized_activations.tolist(),
)
)

return build_single_token_prompt(
examples=highlighted_examples,
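The updated calls pass token strings and raw/normalized activation lists directly into the base-explainer helpers. A rough illustration of what such helpers produce; the << >> delimiters, the positive-activation threshold, and the output format are assumptions for this sketch, not the exact base-class code:

def highlight(str_tokens: list[str], activations: list[float]) -> str:
    # Wrap each token with a positive activation in << >> delimiters.
    return "".join(
        f"<<{tok}>>" if act > 0 else tok
        for tok, act in zip(str_tokens, activations)
    )

def join_activations(
    str_tokens: list[str],
    activations: list[float],
    normalized_activations: list[float],
) -> str:
    # Pair each activating token with its normalized activation value,
    # which is why the assert above requires normalized_activations.
    pairs = ", ".join(
        f'("{tok}", {norm:.0f})'
        for tok, act, norm in zip(str_tokens, activations, normalized_activations)
        if act > 0
    )
    return f"Activations: {pairs}"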
1 change: 0 additions & 1 deletion delphi/latents/cache.py
@@ -279,7 +279,6 @@ def run(self, n_tokens: int, tokens: token_tensor_shape):

print(f"Total tokens processed: {total_tokens:,}")
self.cache.save()
del sae_latents

def save(self, save_dir: Path, save_tokens: bool = True):
"""
3 changes: 3 additions & 0 deletions delphi/latents/constructors.py
@@ -189,6 +189,9 @@ def constructor(
seed=seed,
tokenizer=tokenizer,
)
else:
raise ValueError(f"Invalid non-activating source: {source_non_activating}")

record.not_active = non_activating_examples
return record

7 changes: 5 additions & 2 deletions delphi/latents/latents.py
@@ -134,7 +134,7 @@ class LatentRecord:
train: list[ActivatingExample] = field(default_factory=list)
"""Training examples."""

test: list[ActivatingExample] = field(default_factory=list)
test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
"""Test examples."""

neighbours: list[Neighbour] = field(default_factory=list)
@@ -143,6 +143,9 @@
explanation: str = ""
"""Explanation of the latent."""

extra_examples: Optional[list[Example]] = None
"""Extra examples to include in the record."""

@property
def max_activation(self) -> float:
"""
@@ -203,7 +206,7 @@ def display(
Returns:
str: The formatted string.
"""
from IPython.core.display import HTML, display
from IPython.core.display import HTML, display # type: ignore

def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
"""
3 changes: 2 additions & 1 deletion delphi/latents/neighbours.py
@@ -179,6 +179,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]]
n_tokens = int(idx_cantor.max().item())

token_batch_size = 20_000
co_occurrence_matrix = None
done = False
while not done:
try:
@@ -197,7 +198,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]]
co_occurrence_matrix = torch.zeros(
(n_latents, n_latents), dtype=torch.int32
)
# co_occurrence_matrix = co_occurrence_matrix.cuda()

for start, end in tqdm(
zip(batch_boundaries[:-1], batch_boundaries[1:])
@@ -239,6 +239,7 @@ def compute_jaccard(cooc_matrix):
return jaccard_matrix

# Compute Jaccard similarity matrix
assert co_occurrence_matrix is not None, "Co-occurrence matrix is not computed"
jaccard_matrix = compute_jaccard(co_occurrence_matrix)

# get the indices of the top k neighbours for each feature
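For reference, the Jaccard step converts raw co-occurrence counts into similarities in [0, 1]: for latents i and j, J(i, j) = cooc(i, j) / (count(i) + count(j) - cooc(i, j)). A vectorized sketch of compute_jaccard, under the assumption that each latent's own occurrence count sits on the matrix diagonal:

import torch

def compute_jaccard(cooc_matrix: torch.Tensor) -> torch.Tensor:
    # Per-latent occurrence counts, assumed to be stored on the diagonal.
    counts = cooc_matrix.diagonal().to(torch.float32)
    # |A ∪ B| = |A| + |B| - |A ∩ B|, broadcast over all latent pairs.
    union = counts[:, None] + counts[None, :] - cooc_matrix
    return cooc_matrix / union.clamp(min=1)  # guard against empty unions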
2 changes: 1 addition & 1 deletion delphi/log/result_analysis.py
@@ -258,5 +258,5 @@ def log_results(scores_path: Path, visualize_path: Path, target_modules: list[st
plot_line(df, visualize_path)

for score_type in df["score_type"].unique():
score_df = df[df["score_type"] == score_type]
score_df = df.query(f"score_type == '{score_type}'")
latent_balanced_score_metrics(score_df, score_type)
16 changes: 8 additions & 8 deletions delphi/pipeline.py
@@ -81,7 +81,9 @@ def __init__(self, loader: AsyncIterable | Callable, *pipes: Pipe | Callable):
loader (Callable): The loader to be executed first.
*pipes (list[Pipe]): Pipes to be executed in the pipeline.
"""
self.pipes = [loader] + list(pipes)

self.loader = loader
self.pipes = pipes

async def run(self, max_concurrent: int = 10) -> list[Any]:
"""
@@ -136,13 +138,11 @@ async def generate_items(self) -> AsyncIterable[Any]:
Raises:
TypeError: If the first pipe is neither an async iterable nor a callable.
"""
first_pipe = self.pipes[0]

if isinstance(first_pipe, AsyncIterable):
async for item in first_pipe:
if isinstance(self.loader, AsyncIterable):
async for item in self.loader:
yield item
elif callable(first_pipe):
for item in first_pipe():
elif callable(self.loader):
for item in self.loader():
yield item
await asyncio.sleep(0) # Allow other coroutines to run
else:
@@ -164,6 +164,6 @@ async def process_item(
"""
async with semaphore:
result = item
for pipe in self.pipes[1:]:
for pipe in self.pipes:
result = await pipe(result)
return result
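With the loader split out, a pipeline is constructed as Pipeline(loader, *pipes): the loader seeds the items (an async iterable, or a plain callable returning an iterable) and each pipe is awaited on every item in order. A minimal usage sketch under that contract; the item strings and pipe bodies are made up for illustration:

import asyncio

from delphi.pipeline import Pipeline

def loader():
    # A plain callable loader: yields the items to process.
    yield from ["latent_0", "latent_1", "latent_2"]

async def explain(item: str) -> str:
    # Pipes are async callables applied to each item in sequence.
    return f"explanation for {item}"

async def score(item: str) -> str:
    return f"score for {item}"

results = asyncio.run(Pipeline(loader, explain, score).run(max_concurrent=2))
print(results)  # one processed result per loaded item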
8 changes: 4 additions & 4 deletions delphi/scorers/classifier/classifier.py
@@ -41,10 +41,10 @@ def __init__(
self.generation_kwargs = generation_kwargs
self.log_prob = log_prob

async def __call__(
self,
record: LatentRecord,
) -> ScorerResult:
async def __call__( # type: ignore
self, # type: ignore
record: LatentRecord, # type: ignore
) -> ScorerResult: # type: ignore
samples = self._prepare(record)
random.shuffle(samples)

4 changes: 2 additions & 2 deletions delphi/scorers/classifier/detection.py
@@ -42,7 +42,7 @@ def __init__(
def prompt(self, examples: str, explanation: str) -> list[dict]:
return prompt(examples, explanation)

def _prepare(self, record: LatentRecord) -> list[Sample]:
def _prepare(self, record: LatentRecord) -> list[Sample]: # type: ignore
"""
Prepare and shuffle a list of samples for classification.
"""
@@ -57,7 +57,7 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:

samples.extend(
examples_to_samples(
record.test,
record.test, # type: ignore
)
)

6 changes: 3 additions & 3 deletions delphi/scorers/classifier/fuzz.py
@@ -61,13 +61,13 @@ def mean_n_activations_ceil(self, examples: list[ActivatingExample]):

return ceil(avg)

def _prepare(self, record: LatentRecord) -> list[Sample]:
def _prepare(self, record: LatentRecord) -> list[Sample]: # type: ignore
"""
Prepare and shuffle a list of samples for classification.
"""
assert len(record.test) > 0, "No test records found"

n_incorrect = self.mean_n_activations_ceil(record.test)
n_incorrect = self.mean_n_activations_ceil(record.test) # type: ignore

if len(record.not_active) > 0:
samples = examples_to_samples(
Expand All @@ -81,7 +81,7 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:

samples.extend(
examples_to_samples(
record.test,
record.test, # type: ignore
n_incorrect=0,
highlighted=True,
)