2 changes: 0 additions & 2 deletions delphi/latents/constructors.py
@@ -47,7 +47,6 @@ def prepare_non_activating_examples(
NonActivatingExample(
tokens=toks,
activations=acts,
normalized_activations=None,
distance=distance,
str_tokens=tokenizer.batch_decode(toks),
)
@@ -281,7 +280,6 @@ def constructor(
ActivatingExample(
tokens=toks,
activations=acts,
normalized_activations=None,
)
for toks, acts in zip(token_windows, act_windows)
]
17 changes: 10 additions & 7 deletions delphi/latents/latents.py
@@ -75,12 +75,6 @@ class Example:
activations: Float[Tensor, "ctx_len"]
"""Activation values for the input sequence."""

str_tokens: list[str] | None = None
"""Tokenized input sequence as strings."""

normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
"""Activations quantized to integers in [0, 10]."""

@property
def max_activation(self) -> float:
"""
@@ -98,6 +92,12 @@ class ActivatingExample(Example):
An example of a latent that activates a model.
"""

normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
"""Activations quantized to integers in [0, 10]."""

str_tokens: Optional[list[str]] = None
"""Tokenized input sequence as strings."""

quantile: int = 0
"""The quantile of the activating example."""

@@ -108,6 +108,9 @@ class NonActivatingExample(Example):
An example of a latent that does not activate a model.
"""

str_tokens: list[str]
"""Tokenized input sequence as strings."""

distance: float = 0.0
"""
The distance from the neighbouring latent.
@@ -125,7 +128,7 @@ class LatentRecord:
"""The latent associated with the record."""

examples: list[ActivatingExample] = field(default_factory=list)
"""Example sequences where the latent activations, assumed to be sorted in
"""Example sequences where the latent activates, assumed to be sorted in
descending order by max activation."""

not_active: list[NonActivatingExample] = field(default_factory=list)
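For reference, a minimal sketch of how the reorganised dataclasses would be constructed after this change. It assumes `ActivatingExample` and `NonActivatingExample` remain keyword-constructible dataclasses, as the constructor code above suggests; the tensor values and token strings below are illustrative only.

```python
import torch

from delphi.latents.latents import ActivatingExample, NonActivatingExample

# str_tokens and normalized_activations now live on the subclasses, not on Example.
activating = ActivatingExample(
    tokens=torch.tensor([101, 2054, 102]),
    activations=torch.tensor([0.0, 3.2, 0.1]),
    str_tokens=["[CLS]", "what", "[SEP]"],  # optional on ActivatingExample
    quantile=1,
)

non_activating = NonActivatingExample(
    tokens=torch.tensor([101, 7592, 102]),
    activations=torch.zeros(3),
    str_tokens=["[CLS]", "hello", "[SEP]"],  # required on NonActivatingExample
    distance=2.0,
)

# max_activation is still inherited from the shared Example base class.
print(activating.max_activation)
```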
78 changes: 38 additions & 40 deletions delphi/scorers/embedding/embedding.py
@@ -1,9 +1,9 @@
import asyncio
import random
from dataclasses import dataclass
from typing import NamedTuple
from typing import NamedTuple, Sequence

from transformers import PreTrainedTokenizer
from delphi.latents.latents import ActivatingExample, NonActivatingExample

from ...latents import Example, LatentRecord
from ..scorer import Scorer, ScorerResult
@@ -33,56 +33,53 @@ class EmbeddingScorer(Scorer):
def __init__(
self,
model,
tokenizer: PreTrainedTokenizer | None = None,
verbose: bool = False,
**generation_kwargs,
):
self.model = model
self.verbose = verbose
self.tokenizer = tokenizer
self.generation_kwargs = generation_kwargs

async def __call__( # type: ignore
self, # type: ignore
record: LatentRecord, # type: ignore
) -> ScorerResult: # type: ignore
async def __call__(
self,
record: LatentRecord,
) -> ScorerResult:
samples = self._prepare(record)

random.shuffle(samples)
results = self._query(
record.explanation,
samples, # type: ignore
samples,
)

return ScorerResult(record=record, score=results)

def call_sync(self, record: LatentRecord) -> list[EmbeddingOutput]:
return asyncio.run(self.__call__(record)) # type: ignore
def call_sync(self, record: LatentRecord) -> ScorerResult:
return asyncio.run(self.__call__(record))

def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
def _prepare(self, record: LatentRecord) -> list[Sample]:
"""
Prepare and shuffle a list of samples for classification.
"""
samples = []

assert (
record.extra_examples is not None
), "Extra (non-activating) examples need to be provided"

defaults = {
"tokenizer": self.tokenizer,
}
samples = examples_to_samples(
record.extra_examples, # type: ignore
distance=-1,
**defaults, # type: ignore
samples.extend(
examples_to_samples(
record.extra_examples,
)
)

for i, examples in enumerate(record.test):
samples.extend(
examples_to_samples(
examples, # type: ignore
distance=i + 1,
**defaults, # type: ignore
)
samples.extend(
examples_to_samples(
record.test,
)
)

return samples # type: ignore
return samples

def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutput]:
explanation_string = (
@@ -93,38 +90,39 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutpu
query_embeding = self.model.encode(explanation_prompt)
samples_text = [sample.text for sample in samples]

# # Temporary batching
# sample_embedings = []
# for i in range(0, len(samples_text), 10):
# sample_embedings.extend(self.model.encode(samples_text[i:i+10]))
sample_embedings = self.model.encode(samples_text)
similarity = self.model.similarity(query_embeding, sample_embedings)[0]

results = []
for i in range(len(samples)):
# print(i)
samples[i].data.similarity = similarity[i].item()
results.append(samples[i].data)
return results


def examples_to_samples(
examples: list[Example],
tokenizer: PreTrainedTokenizer,
**sample_kwargs,
examples: Sequence[Example],
) -> list[Sample]:
samples = []
for example in examples:
if tokenizer is not None:
text = "".join(tokenizer.batch_decode(example.tokens))
else:
text = "".join(example.tokens)
assert isinstance(example, ActivatingExample) or isinstance(
example, NonActivatingExample
)
assert example.str_tokens is not None
text = "".join(str(token) for token in example.str_tokens)
activations = example.activations.tolist()
samples.append(
Sample(
text=text,
activations=activations,
data=EmbeddingOutput(text=text, **sample_kwargs),
data=EmbeddingOutput(
text=text,
distance=(
example.quantile
if isinstance(example, ActivatingExample)
else example.distance
),
),
)
)

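A rough sketch of the reworked `examples_to_samples` helper in use, assuming it stays a module-level function in `delphi/scorers/embedding/embedding.py`: the text is rebuilt from `str_tokens` rather than re-decoded with a tokenizer, and the recorded `distance` now comes from the example itself (the quantile of an activating example, the neighbour distance of a non-activating one). The concrete examples below are hypothetical.

```python
import torch

from delphi.latents.latents import ActivatingExample, NonActivatingExample
from delphi.scorers.embedding.embedding import examples_to_samples

examples = [
    ActivatingExample(
        tokens=torch.tensor([5, 6, 7]),
        activations=torch.tensor([0.1, 2.3, 0.0]),
        str_tokens=["the", " cat", " sat"],
        quantile=2,
    ),
    NonActivatingExample(
        tokens=torch.tensor([8, 9, 10]),
        activations=torch.zeros(3),
        str_tokens=["a", " dog", " ran"],
        distance=1.5,
    ),
]

samples = examples_to_samples(examples)
# distance reflects the example type: quantile for activating examples,
# neighbour distance for non-activating ones.
print([(s.text, s.data.distance) for s in samples])
# [('the cat sat', 2), ('a dog ran', 1.5)]
```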
60 changes: 29 additions & 31 deletions delphi/scorers/surprisal/surprisal.py
@@ -1,15 +1,17 @@
import random
from dataclasses import dataclass
from typing import NamedTuple
from typing import NamedTuple, Sequence

import torch
from simple_parsing import field
from torch.nn.functional import cross_entropy
from transformers import PreTrainedTokenizer

from delphi.utils import assert_type

from ...latents import ActivatingExample, Example, LatentRecord
from ...latents import (
ActivatingExample,
Example,
LatentRecord,
NonActivatingExample,
)
from ..scorer import Scorer, ScorerResult
from .prompts import BASEPROMPT as base_prompt

@@ -44,21 +46,19 @@ class SurprisalScorer(Scorer):
def __init__(
self,
model,
tokenizer,
verbose: bool,
batch_size: int,
**generation_kwargs,
):
self.model = model
self.verbose = verbose
self.tokenizer = tokenizer
self.batch_size = batch_size
self.generation_kwargs = generation_kwargs

async def __call__( # type: ignore
self, # type: ignore
record: LatentRecord, # type: ignore
) -> ScorerResult: # type: ignore
async def __call__(
self,
record: LatentRecord,
) -> ScorerResult:
samples = self._prepare(record)

random.shuffle(samples)
@@ -74,35 +74,25 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:
Prepare and shuffle a list of samples for classification.
"""

defaults = {
"tokenizer": self.tokenizer,
}

assert record.extra_examples is not None, "No extra examples provided"
samples = examples_to_samples(
record.extra_examples,
distance=-1,
**defaults,
)

for i, examples in enumerate(record.test):
examples = assert_type(list, examples)
samples.extend(
examples_to_samples(
examples,
distance=i + 1,
**defaults,
)
samples.extend(
examples_to_samples(
record.test,
)
)

return samples

def compute_loss_with_kv_cache(
self, explanation: str, samples: list[Sample], batch_size=2
):
# print(explanation_prompt)
model = self.model
tokenizer = self.model.tokenizer
assert tokenizer is not None, "Tokenizer is not set in model.tokenizer"
# Tokenize explanation
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
@@ -187,20 +177,28 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[SurprisalOutpu


def examples_to_samples(
examples: list[Example] | list[ActivatingExample],
tokenizer: PreTrainedTokenizer,
**sample_kwargs,
examples: Sequence[Example],
) -> list[Sample]:
samples = []
for example in examples:
text = "".join(tokenizer.batch_decode(example.tokens))
assert isinstance(example, ActivatingExample) or isinstance(
example, NonActivatingExample
)
assert example.str_tokens is not None
text = "".join(str(token) for token in example.str_tokens)
activations = example.activations.tolist()
samples.append(
Sample(
text=text,
activations=activations,
data=SurprisalOutput(
activations=activations, text=text, **sample_kwargs
activations=activations,
text=text,
distance=(
example.quantile
if isinstance(example, ActivatingExample)
else example.distance
),
),
)
)