delphi/latents/constructors.py (45 additions, 0 deletions)

@@ -109,6 +109,40 @@ def pool_max_activation_windows(
return token_windows, activation_windows


def is_single_token_feature(
activations: Float[Tensor, "examples ctx_len"],
quantile_threshold: float = 0.5,
activation_ratio_threshold: float = 0.8,
) -> bool:
"""
Determine if a feature is primarily activated by single tokens.

Args:
activations: Activation values across context windows
quantile_threshold: Threshold for considering top activations (0.5 means top 50%)
activation_ratio_threshold: Ratio of single-token activations needed (0.8 means 80%)

Returns:
bool: True if the feature is primarily single-token activated
"""
# For each example, check if activation is concentrated in a single position
max_activations = activations.max(dim=1).values
top_k = int(len(max_activations) * quantile_threshold)
top_indices = max_activations.topk(top_k).indices
Review comment from @luciaquirke (Contributor), Feb 25, 2025:
I believe that the activations are already in top-k order (tbh they should probably be renamed to ordered_example_acts or similar to reflect this, starting in _top_k_pools), so we should be able to directly slice the first len(activations) * quantile_threshold activations.

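To make the suggestion concrete, here is a minimal sketch of that simplification (editor's illustration, not part of the diff), assuming the example rows really are sorted by descending max activation:

# Hypothetical sketch of the reviewer's suggestion: if rows of `activations`
# are already sorted by descending max activation, a plain slice replaces
# the topk call above.
top_k = int(len(activations) * quantile_threshold)
top_examples = activations[:top_k]  # rather than activations[max_activations.topk(top_k).indices]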

# For top activating examples, check if activation is concentrated in single token
top_examples = activations[top_indices]

# Count positions where activation is significant
threshold = top_examples.max(dim=1).values.unsqueeze(1) * 0.5
significant_activations = (top_examples > threshold).sum(dim=1)

# Calculate ratio of single token activations
single_token_ratio = (significant_activations == 1).float().mean().item()

return single_token_ratio >= activation_ratio_threshold
Review comment from @luciaquirke (Contributor), Feb 25, 2025:
It doesn't look like this method actually checks whether the activations are single token 🤔 maybe I'm just confused but I can't see it.



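For readers weighing the comment above, a small worked example (editor's illustration, not part of the PR) shows what the heuristic counts: a row whose activation is concentrated at one position crosses the half-of-row-max threshold exactly once, so it is treated as single-token.

import torch

# Two toy context windows: row 0 spikes at a single position,
# row 1 is spread across the whole window.
acts = torch.tensor([
    [0.0, 9.0, 0.1, 0.2],   # only 9.0 exceeds half its row max (4.5) -> count 1
    [4.0, 5.0, 4.5, 4.8],   # every value exceeds half its row max (2.5) -> count 4
])
threshold = acts.max(dim=1).values.unsqueeze(1) * 0.5
print((acts > threshold).sum(dim=1))  # tensor([1, 4])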
def constructor(
record: LatentRecord,
activation_data: ActivationData,
@@ -125,6 +159,17 @@ def constructor(
max_examples = constructor_cfg.max_examples
min_examples = constructor_cfg.min_examples

token_windows, act_windows = pool_max_activation_windows(
Review comment from @luciaquirke (Contributor), Feb 25, 2025:
It looks like this call has been duplicated by mistake. I don't believe this code will run. I have added documentation on how to run the tests to the README.md.

Review comment from @luciaquirke (Contributor), Feb 25, 2025:
@meghana-0211 try pytest . and python -m delphi.tests.e2e.

activations=activations,
tokens=reshaped_tokens,
ctx_indices=ctx_indices,
index_within_ctx=index_within_ctx,
ctx_len=example_ctx_len,
max_examples=max_examples,
)

record.is_single_token = is_single_token_feature(act_windows)

# Get all positions where the latent is active
flat_indices = (
activation_data.locations[:, 0] * cache_ctx_len
delphi/latents/latents.py (3 additions, 0 deletions)

@@ -143,6 +143,9 @@ class LatentRecord:
explanation: str = ""
"""Explanation of the latent."""

is_single_token: bool = False
"""Whether this latent primarily activates on single tokens."""

@property
def max_activation(self) -> float:
"""
delphi/latents/neighbours.py (11 additions, 5 deletions)

@@ -177,25 +177,31 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]]
latent_index = latent_index[idx_cantor_sorted_idx]

n_tokens = int(idx_cantor.max().item())

token_batch_size = 20_000
done = False
while not done:
try:
print("Trying with batch size", token_batch_size)
# Find indices where idx_cantor crosses each batch boundary
-            bounday_values = torch.arange(token_batch_size, n_tokens, token_batch_size)
+            bounday_values = torch.arange(
+                token_batch_size, n_tokens, token_batch_size
+            )

batch_boundaries_tensor = torch.searchsorted(idx_cantor, bounday_values)
batch_boundaries = [0] + batch_boundaries_tensor.tolist()

if batch_boundaries[-1] != len(idx_cantor):
batch_boundaries.append(len(idx_cantor))

-            co_occurrence_matrix = torch.zeros((n_latents, n_latents), dtype=torch.int32)
-            #co_occurrence_matrix = co_occurrence_matrix.cuda()
+            co_occurrence_matrix = torch.zeros(
+                (n_latents, n_latents), dtype=torch.int32
+            )
+            # co_occurrence_matrix = co_occurrence_matrix.cuda()

-            for start, end in tqdm(zip(batch_boundaries[:-1], batch_boundaries[1:])):
+            for start, end in tqdm(
+                zip(batch_boundaries[:-1], batch_boundaries[1:])
+            ):
                # get all idx_cantor values between start and start + token_batch_size
selected_idx_cantor = idx_cantor[start:end]
selected_latent_index = latent_index[start:end]
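As an aside, the boundary computation above is easier to see in isolation. Here is a minimal sketch (editor's illustration with hypothetical values, not part of the diff) of how torch.searchsorted turns token-count boundaries into index boundaries over the sorted idx_cantor tensor:

import torch

# Sorted token ids (hypothetical); each entry marks one latent activation.
idx_cantor = torch.tensor([0, 3, 7, 12, 25, 31, 44, 58, 90, 120])
token_batch_size = 50
n_tokens = int(idx_cantor.max().item())

# Token-count boundaries (50, 100, ...) mapped to positions in idx_cantor.
boundary_values = torch.arange(token_batch_size, n_tokens, token_batch_size)
batch_boundaries = [0] + torch.searchsorted(idx_cantor, boundary_values).tolist()
if batch_boundaries[-1] != len(idx_cantor):
    batch_boundaries.append(len(idx_cantor))
# batch_boundaries == [0, 7, 9, 10]: each slice covers ~token_batch_size tokens.
for start, end in zip(batch_boundaries[:-1], batch_boundaries[1:]):
    chunk = idx_cantor[start:end]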
delphi/log/result_analysis.py (13 additions, 0 deletions)

@@ -34,6 +34,9 @@ def latent_balanced_score_metrics(
"true_negative_rate": np.average(df["true_negative_rate"], weights=weights),
"false_positive_rate": np.average(df["false_positive_rate"], weights=weights),
"false_negative_rate": np.average(df["false_negative_rate"], weights=weights),
"single_token_ratio": (
df["is_single_token"].mean() if "is_single_token" in df.columns else None
),
}

if verbose:
@@ -54,6 +57,12 @@ def latent_balanced_score_metrics(
{sum(fractions_failed) / len(fractions_failed):.3f}"""
)

if metrics["single_token_ratio"] is not None:
print("\nSingle Token Features:")
print(
f"Ratio of single token features: {metrics['single_token_ratio']:.3f}"
)

print("\nConfusion Matrix:")
print(f"True Positive Rate: {metrics['true_positive_rate']:.3f}")
print(f"True Negative Rate: {metrics['true_negative_rate']:.3f}")
@@ -77,6 +86,7 @@ def latent_balanced_score_metrics(
def parse_score_file(file_path):
with open(file_path, "rb") as f:
data = orjson.loads(f.read())
is_single_token = data.get("is_single_token", False)
df = pd.DataFrame(
[
{
@@ -87,6 +97,7 @@ def parse_score_file(file_path):
"probability": example["probability"],
"correct": example["correct"],
"activations": example["activations"],
"is_single_token": is_single_token,
}
for example in data
]
@@ -158,6 +169,7 @@ def parse_score_file(file_path):
total_negatives / total_examples if total_examples > 0 else 0
),
"failed_count": failed_count,
"is_single_token": is_single_token,
}

for key, value in metrics.items():
@@ -187,6 +199,7 @@ def build_scores_df(path: Path, target_modules: list[str], range: Tensor | None
"positive_class_ratio",
"negative_class_ratio",
"failed_count",
"is_single_token",
]
df_data = {
col: []
delphi/scorers/scorer.py (8 additions, 0 deletions)

@@ -11,6 +11,14 @@ class ScorerResult(NamedTuple):
score: Any
"""Generated score for latent."""

def to_dict(self):
"""Convert the scorer result to a dictionary for serialization."""
return {
**asdict(self.record),
"score": self.score,
"is_single_token": self.record.is_single_token,
}
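
A brief usage sketch (editor's illustration, not part of the diff; it assumes asdict is imported from dataclasses, that record is a LatentRecord dataclass, and that score is JSON-serializable):

import orjson

# `result` is a ScorerResult produced by some scorer (hypothetical variable).
with open("score.json", "wb") as f:
    f.write(orjson.dumps(result.to_dict()))

Note that if LatentRecord is a dataclass, asdict(self.record) already includes is_single_token, so the explicit key is redundant but harmless; it keeps the field prominent in the serialized output.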


class Scorer(ABC):
@abstractmethod