diff --git a/delphi/__main__.py b/delphi/__main__.py
index 5bef1bb7..f7b36f98 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -114,6 +114,12 @@ class RunConfig:
     overwrite: list[str] = list_field()
     """Whether to overwrite existing parts of the run. Options are 'cache', 'scores', and 'visualize'."""
 
+    num_examples_per_scorer_prompt: int = field(
+        default=1,
+    )
+    """Number of examples to use for each scorer prompt. Using more than 1 improves scoring speed but can
+    leak information to the fuzzing scorer, and increases the scorer LLM task difficulty."""
+
 
 def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str, hookpoints):
     submodules = {}
@@ -131,6 +137,8 @@ def _forward(sae, x):
             submodule = model.model.layers[layer]
         elif type == "mlp":
             submodule = model.model.layers[layer].post_feedforward_layernorm
+        else:
+            raise ValueError(f"Invalid autoencoder type: {type}")
         submodule.ae = AutoencoderLatents(
             sae, partial(_forward, sae), width=sae.W_enc.shape[1]
         )
@@ -264,7 +272,7 @@ async def process_cache(
     constructor = partial(
         default_constructor,
         token_loader=None,
-        n_random=experiment_cfg.n_random,
+        n_not_active=experiment_cfg.n_non_activating,
         ctx_len=experiment_cfg.example_ctx_len,
         max_examples=latent_cfg.max_examples,
     )
@@ -304,7 +312,7 @@ def scorer_postprocess(result, score_dir):
         DetectionScorer(
             client,
             tokenizer=dataset.tokenizer, # type: ignore
-            batch_size=10,
+            batch_size=run_cfg.num_examples_per_scorer_prompt,
             verbose=False,
             log_prob=False,
         ),
@@ -315,7 +323,7 @@ def scorer_postprocess(result, score_dir):
         FuzzingScorer(
             client,
             tokenizer=dataset.tokenizer, # type: ignore
-            batch_size=10,
+            batch_size=run_cfg.num_examples_per_scorer_prompt,
             verbose=False,
             log_prob=False,
         ),
@@ -363,10 +371,10 @@ def populate_cache(
         flattened_tokens = tokens.flatten()
         mask = ~torch.isin(flattened_tokens, torch.tensor([tokenizer.bos_token_id]))
         masked_tokens = flattened_tokens[mask]
-        truncated_tokens = masked_tokens[
-            : len(masked_tokens) - (len(masked_tokens) % cfg.ctx_len)
-        ]
-        tokens = truncated_tokens.reshape(-1, cfg.ctx_len)
+        truncated_tokens = masked_tokens[
+            : len(masked_tokens) - (len(masked_tokens) % cfg.ctx_len)
+        ]
+        tokens = truncated_tokens.reshape(-1, cfg.ctx_len)
 
     tokens = cast(TensorType["batch", "seq"], tokens)
 
@@ -399,6 +407,7 @@ async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_
     latents_path = base_path / "latents"
     explanations_path = base_path / "explanations"
     scores_path = base_path / "scores"
+    visualize_path = base_path / "visualize"
 
     latent_range = (
         torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None
     )
@@ -445,7 +454,7 @@ async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_
         print(f"Files found in {scores_path}, skipping...")
 
     if run_cfg.log:
-        log_results(scores_path, run_cfg.hookpoints)
+        log_results(scores_path, visualize_path, run_cfg.hookpoints)
 
 
 if __name__ == "__main__":
diff --git a/delphi/autoencoders/DeepMind/__init__.py b/delphi/autoencoders/DeepMind/__init__.py
index 536354d5..52e668db 100644
--- a/delphi/autoencoders/DeepMind/__init__.py
+++ b/delphi/autoencoders/DeepMind/__init__.py
@@ -25,6 +25,8 @@ def _forward(sae, x):
             submodule = model.model.layers[layer]
         elif type == "mlp":
             submodule = model.model.layers[layer].post_feedforward_layernorm
+        else:
+            raise ValueError(f"Invalid autoencoder type: {type}")
         submodule.ae = AutoencoderLatents(
             sae, partial(_forward, sae), width=sae.W_enc.shape[1]
         )
diff --git a/delphi/autoencoders/wrapper.py b/delphi/autoencoders/wrapper.py
index 299a1d77..765eb54d 100644
--- a/delphi/autoencoders/wrapper.py
+++ b/delphi/autoencoders/wrapper.py
@@ -43,7 +43,7 @@ def from_pretrained(
         autoencoder_type = config.autoencoder_type
         model_name_or_path = config.model_name_or_path
         if autoencoder_type == "SAE":
-            from sae import Sae
+            from sparsify import Sae
             local = kwargs.get("local",None)
             assert local is not None, "local must be specified for SAE"
             if local:
diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py
index 8261bb2f..6b833345 100644
--- a/delphi/log/result_analysis.py
+++ b/delphi/log/result_analysis.py
@@ -3,9 +3,13 @@
 from torch import Tensor
 from pathlib import Path
 import numpy as np
+import plotly.express as px
+import plotly.io as pio
+pio.kaleido.scope.mathjax = None # https://github.com/plotly/plotly.py/issues/3469
 
 
-def feature_balanced_score_metrics(df: pd.DataFrame, score_type: str):
+
+def latent_balanced_score_metrics(df: pd.DataFrame, score_type: str, log: bool = True):
     # Calculate weights based on non-errored examples
     valid_examples = df['total_examples']
     weights = valid_examples / valid_examples.sum()
@@ -29,25 +33,26 @@ def feature_balanced_score_metrics(df: pd.DataFrame, score_type: str):
         'false_negative_rate': np.average(df['false_negative_rate'], weights=weights),
     }
 
-    print(f"\n--- {score_type.title()} Metrics ---")
-    print(f"Accuracy: {weighted_mean_metrics['accuracy']:.3f}")
-    print(f"F1 Score: {weighted_mean_metrics['f1_score']:.3f}")
-    print(f"Precision: {weighted_mean_metrics['precision']:.3f}")
-    print(f"Recall: {weighted_mean_metrics['recall']:.3f}")
-
-    fractions_failed = [failed_count / total_examples for failed_count, total_examples in zip(df['failed_count'], df['total_examples'])]
-    print(f"Average fraction of failed examples: {sum(fractions_failed) / len(fractions_failed):.3f}")
-
-    print("\nConfusion Matrix:")
-    print(f"True Positive Rate: {weighted_mean_metrics['true_positive_rate']:.3f}")
-    print(f"True Negative Rate: {weighted_mean_metrics['true_negative_rate']:.3f}")
-    print(f"False Positive Rate: {weighted_mean_metrics['false_positive_rate']:.3f}")
-    print(f"False Negative Rate: {weighted_mean_metrics['false_negative_rate']:.3f}")
-
-    print(f"\nClass Distribution:")
-    print(f"Positives: {df['total_positives'].sum():.0f} ({weighted_mean_metrics['positive_class_ratio']:.1%})")
-    print(f"Negatives: {df['total_negatives'].sum():.0f} ({weighted_mean_metrics['negative_class_ratio']:.1%})")
-    print(f"Total: {df['total_examples'].sum():.0f}")
+    if log:
+        print(f"\n--- {score_type.title()} Metrics ---")
+        print(f"Accuracy: {weighted_mean_metrics['accuracy']:.3f}")
+        print(f"F1 Score: {weighted_mean_metrics['f1_score']:.3f}")
+        print(f"Precision: {weighted_mean_metrics['precision']:.3f}")
+        print(f"Recall: {weighted_mean_metrics['recall']:.3f}")
+
+        fractions_failed = [failed_count / (total_examples + failed_count) for failed_count, total_examples in zip(df['failed_count'], df['total_examples'])]
+        print(f"Average fraction of failed examples: {sum(fractions_failed) / len(fractions_failed):.3f}")
+
+        print("\nConfusion Matrix:")
+        print(f"True Positive Rate: {weighted_mean_metrics['true_positive_rate']:.3f}")
+        print(f"True Negative Rate: {weighted_mean_metrics['true_negative_rate']:.3f}")
+        print(f"False Positive Rate: {weighted_mean_metrics['false_positive_rate']:.3f}")
+        print(f"False Negative Rate: {weighted_mean_metrics['false_negative_rate']:.3f}")
+
+        print(f"\nClass Distribution:")
+        print(f"Positives: {df['total_positives'].sum():.0f} ({weighted_mean_metrics['positive_class_ratio']:.1%})")
+        print(f"Negatives: {df['total_negatives'].sum():.0f} ({weighted_mean_metrics['negative_class_ratio']:.1%})")
+        print(f"Total: {df['total_examples'].sum():.0f}")
 
     return weighted_mean_metrics
 
@@ -68,8 +73,8 @@ def parse_score_file(file_path):
     } for example in data])
 
     # Calculate basic counts
-    failed_count = (df['prediction'] == -1).sum()
-    df = df[df['prediction'] != -1]
+    failed_count = (df['prediction'].isna()).sum()
+    df = df[df['prediction'].notna()]
     df.reset_index(drop=True, inplace=True)
     total_examples = len(df)
     total_positives = (df["ground_truth"]).sum()
@@ -93,7 +98,7 @@ def parse_score_file(file_path):
     f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
 
     # Calculate accuracy
-    accuracy = (true_positives + true_negatives) / total_examples
+    accuracy = (true_positives + true_negatives) / total_examples if total_examples > 0 else 0
 
     # Add metrics to first row
     metrics = {
@@ -112,8 +117,8 @@ def parse_score_file(file_path):
         "total_examples": total_examples,
         "total_positives": total_positives,
         "total_negatives": total_negatives,
-        "positive_class_ratio": total_positives / total_examples,
-        "negative_class_ratio": total_negatives / total_examples,
+        "positive_class_ratio": total_positives / total_examples if total_examples > 0 else 0,
+        "negative_class_ratio": total_negatives / total_examples if total_examples > 0 else 0,
         "failed_count": failed_count,
     }
 
@@ -133,7 +138,7 @@ def build_scores_df(path: Path, target_modules: list[str], range: Tensor | None
     ]
     df_data = {
        col: []
-        for col in ["file_name", "score_type", "feature_idx", "module"] + metrics_cols
+        for col in ["file_name", "score_type", "latent_idx", "module"] + metrics_cols
     }
 
     # Get subdirectories in the scores path
@@ -146,16 +151,19 @@ def build_scores_df(path: Path, target_modules: list[str], range: Tensor | None
             for score_file in list(score_type_path.glob(f"*{module}*")) + list(
                 score_type_path.glob(f".*{module}*")
             ):
-                feature_idx = int(score_file.stem.split("feature")[-1])
-                if range is not None and feature_idx not in range:
+                if "latent" in score_file.stem:
+                    latent_idx = int(score_file.stem.split("latent")[-1])
+                else:
+                    latent_idx = int(score_file.stem.split("feature")[-1])
+                if range is not None and latent_idx not in range:
                     continue
 
                 df = parse_score_file(score_file)
 
-                # Calculate the accuracy and cross entropy loss for this feature
+                # Calculate the accuracy and cross entropy loss for this latent
                 df_data["file_name"].append(score_file.stem)
                 df_data["score_type"].append(score_type)
-                df_data["feature_idx"].append(feature_idx)
+                df_data["latent_idx"].append(latent_idx)
                 df_data["module"].append(module)
                 for col in metrics_cols:
                     df_data[col].append(df.loc[0, col])
@@ -165,8 +173,27 @@ def build_scores_df(path: Path, target_modules: list[str], range: Tensor | None
     return df
 
 
-def log_results(scores_path: Path, target_modules: list[str]):
+def plot_line(df: pd.DataFrame, visualize_path: Path):
+    visualize_path.mkdir(parents=True, exist_ok=True)
+
+    for score_type in df["score_type"].unique():
+        mask = (df["score_type"] == score_type)
+
+        fig = px.histogram(
+            df[mask],
+            x="accuracy",
+            title=f"Latent explanation accuracies for {score_type} scorer",
+            nbins=100
+        )
+
+        fig.write_image(visualize_path / f"{score_type}_accuracies.pdf", format="pdf")
+
+
+def log_results(scores_path: Path, visualize_path: Path, target_modules: list[str]):
     df = build_scores_df(scores_path, target_modules)
+    plot_line(df, visualize_path)
+
     for score_type in df["score_type"].unique():
         score_df = df[df['score_type'] == score_type]
-        feature_balanced_score_metrics(score_df, score_type)
+        latent_balanced_score_metrics(score_df, score_type)
+
diff --git a/delphi/pipeline.py b/delphi/pipeline.py
index 386503c9..dc8c5d21 100644
--- a/delphi/pipeline.py
+++ b/delphi/pipeline.py
@@ -5,7 +5,7 @@
 from tqdm.asyncio import tqdm
 
 
-def process_wrapper(function: Callable, preprocess: Callable = None, postprocess: Callable = None) -> Callable:
+def process_wrapper(function: Callable, preprocess: Callable | None = None, postprocess: Callable | None = None) -> Callable:
     """
     Wraps a function with optional preprocessing and postprocessing steps.
diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py
index ba6b87a6..4491c132 100644
--- a/delphi/tests/e2e.py
+++ b/delphi/tests/e2e.py
@@ -6,13 +6,14 @@
 from delphi.config import ExperimentConfig, LatentConfig, CacheConfig
 from delphi.__main__ import run, RunConfig
-from delphi.log.result_analysis import build_scores_df, feature_balanced_score_metrics
+from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics
 
 
 async def test():
     cache_cfg = CacheConfig(
-        dataset_repo="EleutherAI/rpj-v2-sample",
+        dataset_repo="EleutherAI/fineweb-edu-dedup-10b",
         dataset_split="train[:1%]",
+        dataset_row="text",
         batch_size=8,
         ctx_len=256,
         n_splits=5,
@@ -30,17 +31,17 @@ async def test():
         max_examples=10_000, # The maximum number of examples a latent may activate on before being excluded from explanation
     )
 
     run_cfg = RunConfig(
-        name='test',
+        name="test",
         overwrite=["cache", "scores"],
         model="EleutherAI/pythia-160m",
         sparse_model="EleutherAI/sae-pythia-160m-32k",
-        explainer_model="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
         hookpoints=["layers.3"],
+        explainer_model="ibnzterrell/Meta-Llama-3.3-70B-Instruct-AWQ-INT4",
         explainer_model_max_len=4208,
         max_latents=100,
         seed=22,
         num_gpus=torch.cuda.device_count(),
-        filter_bos=True
+        filter_bos=True,
     )
 
     start_time = time.time()
@@ -53,7 +54,7 @@ async def test():
     df = build_scores_df(scores_path, run_cfg.hookpoints)
     for score_type in df["score_type"].unique():
         score_df = df[df['score_type'] == score_type]
-        weighted_mean_metrics = feature_balanced_score_metrics(score_df, score_type)
+        weighted_mean_metrics = latent_balanced_score_metrics(score_df, score_type, log=False)
         assert weighted_mean_metrics['accuracy'] > 0.55, f"Score type {score_type} has an accuracy of {weighted_mean_metrics['accuracy']}"