25 changes: 17 additions & 8 deletions delphi/__main__.py
@@ -114,6 +114,12 @@ class RunConfig:
overwrite: list[str] = list_field()
"""Whether to overwrite existing parts of the run. Options are 'cache', 'scores', and 'visualize'."""

num_examples_per_scorer_prompt: int = field(
default=1,
)
"""Number of examples to use for each scorer prompt. Using more than 1 improves scoring speed but can
leak information to the fuzzing scorer and makes the scoring task harder for the scorer LLM."""
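
For reviewers trying the new flag, a rough sketch of how it is intended to be set, mirroring the e2e test further down in this PR (field values are illustrative and other required fields may be needed):

    from delphi.__main__ import RunConfig

    run_cfg = RunConfig(
        name="demo",
        model="EleutherAI/pythia-160m",
        sparse_model="EleutherAI/sae-pythia-160m-32k",
        hookpoints=["layers.3"],
        explainer_model="ibnzterrell/Meta-Llama-3.3-70B-Instruct-AWQ-INT4",
        # Default of 1 avoids leaking information to the fuzzing scorer;
        # raising it packs more examples into each scorer prompt and speeds up scoring.
        num_examples_per_scorer_prompt=5,
    )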


def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str, hookpoints):
submodules = {}
@@ -131,6 +137,8 @@ def _forward(sae, x):
submodule = model.model.layers[layer]
elif type == "mlp":
submodule = model.model.layers[layer].post_feedforward_layernorm
else:
raise ValueError(f"Invalid autoencoder type: {type}")
submodule.ae = AutoencoderLatents(
sae, partial(_forward, sae), width=sae.W_enc.shape[1]
)
@@ -264,7 +272,7 @@ async def process_cache(
constructor = partial(
default_constructor,
token_loader=None,
n_random=experiment_cfg.n_random,
n_not_active=experiment_cfg.n_non_activating,
ctx_len=experiment_cfg.example_ctx_len,
max_examples=latent_cfg.max_examples,
)
@@ -304,7 +312,7 @@ def scorer_postprocess(result, score_dir):
DetectionScorer(
client,
tokenizer=dataset.tokenizer, # type: ignore
batch_size=10,
batch_size=run_cfg.num_examples_per_scorer_prompt,
verbose=False,
log_prob=False,
),
@@ -315,7 +323,7 @@ def scorer_postprocess(result, score_dir):
FuzzingScorer(
client,
tokenizer=dataset.tokenizer, # type: ignore
batch_size=10,
batch_size=run_cfg.num_examples_per_scorer_prompt,
verbose=False,
log_prob=False,
),
@@ -363,10 +371,10 @@ def populate_cache(
flattened_tokens = tokens.flatten()
mask = ~torch.isin(flattened_tokens, torch.tensor([tokenizer.bos_token_id]))
masked_tokens = flattened_tokens[mask]
truncated_tokens = masked_tokens[
: len(masked_tokens) - (len(masked_tokens) % cfg.ctx_len)
]
tokens = truncated_tokens.reshape(-1, cfg.ctx_len)
truncated_tokens = masked_tokens[
: len(masked_tokens) - (len(masked_tokens) % cfg.ctx_len)
]
tokens = truncated_tokens.reshape(-1, cfg.ctx_len)

tokens = cast(TensorType["batch", "seq"], tokens)
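
As a sanity check on the BOS-filtering block above, here is a self-contained sketch of the flatten, mask, truncate, and reshape steps with toy values (not the repo's API):

    import torch

    ctx_len = 4
    tokens = torch.arange(1, 11).unsqueeze(0)   # shape (1, 10)
    bos_token_id = 1

    flattened = tokens.flatten()
    mask = ~torch.isin(flattened, torch.tensor([bos_token_id]))
    masked = flattened[mask]                    # 9 tokens remain
    truncated = masked[: len(masked) - (len(masked) % ctx_len)]  # drop the trailing remainder
    reshaped = truncated.reshape(-1, ctx_len)   # shape (2, 4)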

@@ -399,6 +407,7 @@ async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_
latents_path = base_path / "latents"
explanations_path = base_path / "explanations"
scores_path = base_path / "scores"
visualize_path = base_path / "visualize"

latent_range = (
torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None
@@ -445,7 +454,7 @@ async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_
print(f"Files found in {scores_path}, skipping...")

if run_cfg.log:
log_results(scores_path, run_cfg.hookpoints)
log_results(scores_path, visualize_path, run_cfg.hookpoints)


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions delphi/autoencoders/DeepMind/__init__.py
@@ -25,6 +25,8 @@ def _forward(sae, x):
submodule = model.model.layers[layer]
elif type == "mlp":
submodule = model.model.layers[layer].post_feedforward_layernorm
else:
raise ValueError(f"Invalid autoencoder type: {type}")
submodule.ae = AutoencoderLatents(
sae, partial(_forward, sae), width=sae.W_enc.shape[1]
)
2 changes: 1 addition & 1 deletion delphi/autoencoders/wrapper.py
@@ -43,7 +43,7 @@ def from_pretrained(
autoencoder_type = config.autoencoder_type
model_name_or_path = config.model_name_or_path
if autoencoder_type == "SAE":
from sae import Sae
from sparsify import Sae
local = kwargs.get("local",None)
assert local is not None, "local must be specified for SAE"
if local:
91 changes: 59 additions & 32 deletions delphi/log/result_analysis.py
@@ -3,9 +3,13 @@
from torch import Tensor
from pathlib import Path
import numpy as np
import plotly.express as px
import plotly.io as pio

pio.kaleido.scope.mathjax = None # https://github.com/plotly/plotly.py/issues/3469

def feature_balanced_score_metrics(df: pd.DataFrame, score_type: str):

def latent_balanced_score_metrics(df: pd.DataFrame, score_type: str, log: bool = True):
# Calculate weights based on non-errored examples
valid_examples = df['total_examples']
weights = valid_examples / valid_examples.sum()
@@ -29,25 +33,26 @@ def feature_balanced_score_metrics(df: pd.DataFrame, score_type: str):
'false_negative_rate': np.average(df['false_negative_rate'], weights=weights),
}

print(f"\n--- {score_type.title()} Metrics ---")
print(f"Accuracy: {weighted_mean_metrics['accuracy']:.3f}")
print(f"F1 Score: {weighted_mean_metrics['f1_score']:.3f}")
print(f"Precision: {weighted_mean_metrics['precision']:.3f}")
print(f"Recall: {weighted_mean_metrics['recall']:.3f}")

fractions_failed = [failed_count / total_examples for failed_count, total_examples in zip(df['failed_count'], df['total_examples'])]
print(f"Average fraction of failed examples: {sum(fractions_failed) / len(fractions_failed):.3f}")

print("\nConfusion Matrix:")
print(f"True Positive Rate: {weighted_mean_metrics['true_positive_rate']:.3f}")
print(f"True Negative Rate: {weighted_mean_metrics['true_negative_rate']:.3f}")
print(f"False Positive Rate: {weighted_mean_metrics['false_positive_rate']:.3f}")
print(f"False Negative Rate: {weighted_mean_metrics['false_negative_rate']:.3f}")

print(f"\nClass Distribution:")
print(f"Positives: {df['total_positives'].sum():.0f} ({weighted_mean_metrics['positive_class_ratio']:.1%})")
print(f"Negatives: {df['total_negatives'].sum():.0f} ({weighted_mean_metrics['negative_class_ratio']:.1%})")
print(f"Total: {df['total_examples'].sum():.0f}")
if log:
print(f"\n--- {score_type.title()} Metrics ---")
print(f"Accuracy: {weighted_mean_metrics['accuracy']:.3f}")
print(f"F1 Score: {weighted_mean_metrics['f1_score']:.3f}")
print(f"Precision: {weighted_mean_metrics['precision']:.3f}")
print(f"Recall: {weighted_mean_metrics['recall']:.3f}")

fractions_failed = [failed_count / (total_examples + failed_count) for failed_count, total_examples in zip(df['failed_count'], df['total_examples'])]
print(f"Average fraction of failed examples: {sum(fractions_failed) / len(fractions_failed):.3f}")

print("\nConfusion Matrix:")
print(f"True Positive Rate: {weighted_mean_metrics['true_positive_rate']:.3f}")
print(f"True Negative Rate: {weighted_mean_metrics['true_negative_rate']:.3f}")
print(f"False Positive Rate: {weighted_mean_metrics['false_positive_rate']:.3f}")
print(f"False Negative Rate: {weighted_mean_metrics['false_negative_rate']:.3f}")

print(f"\nClass Distribution:")
print(f"Positives: {df['total_positives'].sum():.0f} ({weighted_mean_metrics['positive_class_ratio']:.1%})")
print(f"Negatives: {df['total_negatives'].sum():.0f} ({weighted_mean_metrics['negative_class_ratio']:.1%})")
print(f"Total: {df['total_examples'].sum():.0f}")

return weighted_mean_metrics
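
A toy illustration of the example-count weighting used above (numbers are made up):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"accuracy": [0.9, 0.6], "total_examples": [100, 25]})
    weights = df["total_examples"] / df["total_examples"].sum()       # 0.8 and 0.2
    weighted_accuracy = np.average(df["accuracy"], weights=weights)   # 0.84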

@@ -68,8 +73,8 @@ def parse_score_file(file_path):
} for example in data])

# Calculate basic counts
failed_count = (df['prediction'] == -1).sum()
df = df[df['prediction'] != -1]
failed_count = (df['prediction'].isna()).sum()
df = df[df['prediction'].notna()]
df.reset_index(drop=True, inplace=True)
total_examples = len(df)
total_positives = (df["ground_truth"]).sum()
@@ -93,7 +98,7 @@ def parse_score_file(file_path):
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

# Calculate accuracy
accuracy = (true_positives + true_negatives) / total_examples
accuracy = (true_positives + true_negatives) / total_examples if total_examples > 0 else 0

# Add metrics to first row
metrics = {
@@ -112,8 +117,8 @@ def parse_score_file(file_path):
"total_examples": total_examples,
"total_positives": total_positives,
"total_negatives": total_negatives,
"positive_class_ratio": total_positives / total_examples,
"negative_class_ratio": total_negatives / total_examples,
"positive_class_ratio": total_positives / total_examples if total_examples > 0 else 0,
"negative_class_ratio": total_negatives / total_examples if total_examples > 0 else 0,
"failed_count": failed_count,
}

@@ -133,7 +138,7 @@ def build_scores_df(path: Path, target_modules: list[str], range: Tensor | None
]
df_data = {
col: []
for col in ["file_name", "score_type", "feature_idx", "module"] + metrics_cols
for col in ["file_name", "score_type", "latent_idx", "module"] + metrics_cols
}

# Get subdirectories in the scores path
@@ -146,16 +151,19 @@
for score_file in list(score_type_path.glob(f"*{module}*")) + list(
score_type_path.glob(f".*{module}*")
):
feature_idx = int(score_file.stem.split("feature")[-1])
if range is not None and feature_idx not in range:
if "latent" in score_file.stem:
latent_idx = int(score_file.stem.split("latent")[-1])
else:
latent_idx = int(score_file.stem.split("feature")[-1])
if range is not None and latent_idx not in range:
continue

df = parse_score_file(score_file)

# Calculate the accuracy and cross entropy loss for this feature
# Calculate the accuracy and cross entropy loss for this latent
df_data["file_name"].append(score_file.stem)
df_data["score_type"].append(score_type)
df_data["feature_idx"].append(feature_idx)
df_data["latent_idx"].append(latent_idx)
df_data["module"].append(module)
for col in metrics_cols: df_data[col].append(df.loc[0, col])

@@ -165,8 +173,27 @@
return df


def log_results(scores_path: Path, target_modules: list[str]):
def plot_line(df: pd.DataFrame, visualize_path: Path):
visualize_path.mkdir(parents=True, exist_ok=True)

for score_type in df["score_type"].unique():
mask = (df["score_type"] == score_type)

fig = px.histogram(
df[mask],
x="accuracy",
title=f"Latent explanation accuracies for {score_type} scorer",
nbins=100
)

fig.write_image(visualize_path / f"{score_type}_accuracies.pdf", format="pdf")


def log_results(scores_path: Path, visualize_path: Path, target_modules: list[str]):
df = build_scores_df(scores_path, target_modules)
plot_line(df, visualize_path)

for score_type in df["score_type"].unique():
score_df = df[df['score_type'] == score_type]
feature_balanced_score_metrics(score_df, score_type)
latent_balanced_score_metrics(score_df, score_type)
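
Callers now pass the extra visualize path, as in the __main__ change above; a minimal usage sketch (paths are illustrative):

    from pathlib import Path
    from delphi.log.result_analysis import log_results

    base_path = Path("results/test")  # hypothetical run directory
    log_results(base_path / "scores", base_path / "visualize", ["layers.3"])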

2 changes: 1 addition & 1 deletion delphi/pipeline.py
@@ -5,7 +5,7 @@
from tqdm.asyncio import tqdm


def process_wrapper(function: Callable, preprocess: Callable = None, postprocess: Callable = None) -> Callable:
def process_wrapper(function: Callable, preprocess: Callable | None = None, postprocess: Callable | None = None) -> Callable:
"""
Wraps a function with optional preprocessing and postprocessing steps.

13 changes: 7 additions & 6 deletions delphi/tests/e2e.py
@@ -6,13 +6,14 @@

from delphi.config import ExperimentConfig, LatentConfig, CacheConfig
from delphi.__main__ import run, RunConfig
from delphi.log.result_analysis import build_scores_df, feature_balanced_score_metrics
from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics


async def test():
cache_cfg = CacheConfig(
dataset_repo="EleutherAI/rpj-v2-sample",
dataset_repo="EleutherAI/fineweb-edu-dedup-10b",
dataset_split="train[:1%]",
dataset_row="text",
batch_size=8,
ctx_len=256,
n_splits=5,
@@ -30,17 +31,17 @@ async def test():
max_examples=10_000, # The maximum number of examples a latent may activate on before being excluded from explanation
)
run_cfg = RunConfig(
name='test',
name="test",
overwrite=["cache", "scores"],
model="EleutherAI/pythia-160m",
sparse_model="EleutherAI/sae-pythia-160m-32k",
explainer_model="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
hookpoints=["layers.3"],
explainer_model="ibnzterrell/Meta-Llama-3.3-70B-Instruct-AWQ-INT4",
explainer_model_max_len=4208,
max_latents=100,
seed=22,
num_gpus=torch.cuda.device_count(),
filter_bos=True
filter_bos=True,
)

start_time = time.time()
@@ -53,7 +54,7 @@ async def test():
df = build_scores_df(scores_path, run_cfg.hookpoints)
for score_type in df["score_type"].unique():
score_df = df[df['score_type'] == score_type]
weighted_mean_metrics = feature_balanced_score_metrics(score_df, score_type)
weighted_mean_metrics = latent_balanced_score_metrics(score_df, score_type, log=False)

assert weighted_mean_metrics['accuracy'] > 0.55, f"Score type {score_type} has an accuracy of {weighted_mean_metrics['accuracy']}"
