From 5308cbc990bd2d4c4244e60ecbe2c0b7e5b2231b Mon Sep 17 00:00:00 2001 From: CharlesR-W Date: Mon, 18 Aug 2025 20:38:58 +0000 Subject: [PATCH 1/2] Add preliminary best of k explainer --- GP_run_experiment.py | 519 +++++++++++++++++++++++ delphi/explainers/__init__.py | 2 + delphi/explainers/best_of_k_explainer.py | 183 ++++++++ delphi/explainers/default/prompts.py | 15 + delphi/pipeline.py | 19 +- 5 files changed, 737 insertions(+), 1 deletion(-) create mode 100755 GP_run_experiment.py create mode 100644 delphi/explainers/best_of_k_explainer.py diff --git a/GP_run_experiment.py b/GP_run_experiment.py new file mode 100755 index 00000000..1dc316aa --- /dev/null +++ b/GP_run_experiment.py @@ -0,0 +1,519 @@ +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" + +import asyncio + +from functools import partial +from pathlib import Path +from typing import Callable + +import orjson +import torch +from torch import Tensor +from transformers import ( + AutoModel, + AutoTokenizer, + BitsAndBytesConfig, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + +from delphi.delphi.clients import Offline, OpenRouter +from delphi.delphi.config import ( + RunConfig, + CacheConfig, + ConstructorConfig, + SamplerConfig, +) +from delphi.delphi.explainers import DefaultExplainer +from delphi.delphi.latents import LatentCache, LatentDataset # , LatentRecord +from delphi.delphi.latents.neighbours import NeighbourCalculator +from delphi.delphi.log.result_analysis import log_results +from delphi.delphi.pipeline import ( + Pipe, + Pipeline, + fan_out_fan_in_wrapper, + process_wrapper, +) +from delphi.delphi.scorers import DetectionScorer, FuzzingScorer +from delphi.delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders +from delphi.delphi.utils import assert_type, load_tokenized_data + + +def load_artifacts(run_cfg: RunConfig): + if run_cfg.load_in_8bit: + dtype = torch.float16 + elif torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + else: + dtype = "auto" + + model = AutoModel.from_pretrained( + run_cfg.model, + device_map={"": "cuda"}, + quantization_config=( + BitsAndBytesConfig(load_in_8bit=run_cfg.load_in_8bit) + if run_cfg.load_in_8bit + else None + ), + torch_dtype=dtype, + token=run_cfg.hf_token, + ) + + hookpoint_to_sparse_encode, transcode = load_hooks_sparse_coders( + model, + run_cfg, + compile=True, + ) + + return run_cfg.hookpoints, hookpoint_to_sparse_encode, model, transcode + + +def create_neighbours( + run_cfg: RunConfig, + latents_path: Path, + neighbours_path: Path, + hookpoints: list[str], +): + """ + Creates a neighbours file for the given hookpoints. 
+ """ + neighbours_path.mkdir(parents=True, exist_ok=True) + + constructor_cfg = run_cfg.constructor_cfg + saes = ( + load_sparse_coders(run_cfg, device="cpu") + if constructor_cfg.neighbours_type != "co-occurrence" + else {} + ) + + for hookpoint in hookpoints: + if constructor_cfg.neighbours_type == "co-occurrence": + neighbour_calculator = NeighbourCalculator( + cache_dir=latents_path / hookpoint, number_of_neighbours=250 + ) + + elif constructor_cfg.neighbours_type == "decoder_similarity": + neighbour_calculator = NeighbourCalculator( + autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250 + ) + + elif constructor_cfg.neighbours_type == "encoder_similarity": + neighbour_calculator = NeighbourCalculator( + autoencoder=saes[hookpoint].cuda(), number_of_neighbours=250 + ) + else: + raise ValueError( + f"Neighbour type {constructor_cfg.neighbours_type} not supported" + ) + + neighbour_calculator.populate_neighbour_cache(constructor_cfg.neighbours_type) + neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}") + + +async def process_cache( + run_cfg: RunConfig, + latents_path: Path, + neighbours_path: Path, + explanations_path: Path, + scores_path: Path, + hookpoints: list[str], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + latent_range: Tensor | None, + # non_active_to_show: int, +): + """ + Converts SAE latent activations in on-disk cache in the `latents_path` directory + to latent explanations in the `explanations_path` directory and explanation + scores in the `fuzz_scores_path` directory. + """ + explanations_path.mkdir(parents=True, exist_ok=True) + + fuzz_scores_path = scores_path / "fuzz" + detection_scores_path = scores_path / "detection" + fuzz_scores_path.mkdir(parents=True, exist_ok=True) + detection_scores_path.mkdir(parents=True, exist_ok=True) + + if latent_range is None: + latent_dict = None + else: + latent_dict = { + hook: latent_range for hook in hookpoints + } # The latent range to explain + + dataset = LatentDataset( + raw_dir=str(latents_path), + sampler_cfg=run_cfg.sampler_cfg, + constructor_cfg=run_cfg.constructor_cfg, + modules=hookpoints, + latents=latent_dict, + tokenizer=tokenizer, + neighbours_path=str(neighbours_path), + ) + + if run_cfg.explainer_provider == "offline": + client = Offline( + run_cfg.explainer_model, + max_memory=0.9, + # Explainer models context length - must be able to accommodate the longest + # set of examples + max_model_len=run_cfg.explainer_model_max_len, + num_gpus=run_cfg.num_gpus, + statistics=run_cfg.verbose, + ) + elif run_cfg.explainer_provider == "openrouter": + client = OpenRouter( + run_cfg.explainer_model, + api_key="", + ) + else: + raise ValueError( + f"Explainer provider {run_cfg.explainer_provider} not supported" + ) + + def explainer_postprocess(result): + result_dict = {} + result_dict["explanation"] = result.explanation + # result_dict["short_name"]=result.short_name + # result_dict["confidence"]=result.confidence + with open(explanations_path / f"{result.record.latent}.txt", "wb") as f: + f.write(orjson.dumps(result_dict)) + return result + + def explainer_preprocess(result): + if result is None: + return None + record = result + # remove the first non_active examples, save the rest in extra_examples + # record.extra_examples = record.not_active[non_active_to_show:] + # record.not_active = record.not_active[:non_active_to_show] + + return record + + explainer_pipe = process_wrapper( + DefaultExplainer( + client, + threshold=0.3, + verbose=run_cfg.verbose, + ), + 
preprocess=explainer_preprocess, + postprocess=explainer_postprocess, + ) + + # Builds the record from result returned by the pipeline + def scorer_preprocess(result): + record = result.record + + record.explanation = result.explanation + # record.not_active = record.extra_examples + + return record + + # Saves the score to a file + def scorer_postprocess(result, score_dir): + safe_latent_name = str(result.record.latent).replace("/", "--") + + with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: + f.write(orjson.dumps(result.score)) + + # change the scorer wrapper to handle fan-out if using best of k. + def scorer_wrapper(function: Callable): + if run_cfg.explainer == "best_of_k": + return process_wrapper( + fan_out_fan_in_wrapper(function), + preprocess=scorer_preprocess, + postprocess=partial( + scorer_postprocess, score_dir=detection_scores_path + ), + ) + else: + return process_wrapper( + function, + preprocess=scorer_preprocess, + postprocess=partial( + scorer_postprocess, score_dir=detection_scores_path + ), + ) + + scorer_pipe = Pipe( + scorer_wrapper( + DetectionScorer( + client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=True, + ), + ), + FuzzingScorer( + client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=True, + ), + preprocess=scorer_preprocess, + postprocess=partial(scorer_postprocess, score_dir=fuzz_scores_path), + ) + + pipeline = Pipeline( + dataset, + explainer_pipe, + scorer_pipe, + ) + + await pipeline.run(10) + + +def populate_cache( + run_cfg: RunConfig, + model: PreTrainedModel, + hookpoint_to_sparse_encode: dict[str, Callable], + latents_path: Path, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + transcode: bool, +): + """ + Populates an on-disk cache in `latents_path` with SAE latent activations. + """ + latents_path.mkdir(parents=True, exist_ok=True) + + # Create a log path within the run directory + log_path = latents_path.parent / "log" + log_path.mkdir(parents=True, exist_ok=True) + + cache_cfg = run_cfg.cache_cfg + + tokens = load_tokenized_data( + cache_cfg.cache_ctx_len, + tokenizer, + cache_cfg.dataset_repo, + cache_cfg.dataset_split, + cache_cfg.dataset_name, + cache_cfg.dataset_column, + run_cfg.seed, + ) + + if run_cfg.filter_bos: + if tokenizer.bos_token_id is None: + print("Tokenizer does not have a BOS token, skipping BOS filtering") + else: + flattened_tokens = tokens.flatten() + mask = ~torch.isin(flattened_tokens, torch.tensor([tokenizer.bos_token_id])) + masked_tokens = flattened_tokens[mask] + truncated_tokens = masked_tokens[ + : len(masked_tokens) - (len(masked_tokens) % cache_cfg.cache_ctx_len) + ] + tokens = truncated_tokens.reshape(-1, cache_cfg.cache_ctx_len) + + cache = LatentCache( + model, + hookpoint_to_sparse_encode, + batch_size=cache_cfg.batch_size, + transcode=transcode, + log_path=log_path, + ) + cache.run(cache_cfg.n_tokens, tokens) + + cache.save_splits( + # Split the activation and location indices into different files to make + # loading faster + n_splits=cache_cfg.n_splits, + save_dir=latents_path, + ) + + cache.save_config(save_dir=latents_path, cfg=cache_cfg, model_name=run_cfg.model) + + +def non_redundant_hookpoints( + hookpoint_to_sparse_encode: dict[str, Callable] | list[str], + results_path: Path, + overwrite: bool, +) -> dict[str, Callable] | list[str]: + """ + Returns a list of hookpoints that are not already in the cache. 
+ """ + if overwrite: + print("Overwriting results from", results_path) + return hookpoint_to_sparse_encode + in_results_path = [x.name for x in results_path.glob("*")] + if isinstance(hookpoint_to_sparse_encode, dict): + non_redundant_hookpoints = { + k: v + for k, v in hookpoint_to_sparse_encode.items() + if k not in in_results_path + } + else: + non_redundant_hookpoints = [ + hookpoint + for hookpoint in hookpoint_to_sparse_encode + if hookpoint not in in_results_path + ] + if not non_redundant_hookpoints: + print(f"Files found in {results_path}, skipping...") + return non_redundant_hookpoints + + +async def run( + run_cfg: RunConfig, + # start_latent: int, + # non_active_to_show: int, +): + base_path = Path.cwd() / "results" + + # latents_path = base_path / "latents" + + base_path = base_path / run_cfg.name + + base_path.mkdir(parents=True, exist_ok=True) + + run_cfg.save_json(base_path / "run_config.json", indent=4) + + # All latents will be in the first part of the name + + latents_path = base_path / "latents" + explanations_path = base_path / "explanations" + scores_path = base_path / "scores" + neighbours_path = base_path / "neighbours" + visualize_path = base_path / "visualize" + + latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None + + hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg) + tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token) + + nrh = assert_type( + dict, + non_redundant_hookpoints( + hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite + ), + ) + + if nrh: + populate_cache( + run_cfg, + model, + hookpoint_to_sparse_encode, + latents_path, + tokenizer, + transcode, + ) + + del model, hookpoint_to_sparse_encode + if run_cfg.constructor_cfg.non_activating_source == "neighbours": + nrh = assert_type( + list, + non_redundant_hookpoints( + hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite + ), + ) + if nrh: + create_neighbours( + run_cfg, + latents_path, + neighbours_path, + nrh, + ) + else: + print("Skipping neighbour creation") + + nrh = assert_type( + list, + non_redundant_hookpoints( + hookpoints, scores_path, "scores" in run_cfg.overwrite + ), + ) + if nrh: + await process_cache( + run_cfg, + latents_path, + neighbours_path, + explanations_path, + scores_path, + nrh, + tokenizer, + latent_range, + # non_active_to_show, + ) + + if run_cfg.verbose: + log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers) + + +if __name__ == "__main__": + print("Creating cache config") + # Create the individual config objects + cache_cfg = CacheConfig( + dataset_repo="EleutherAI/SmolLM2-135M-10B", + # dataset_split="train[:1%]", + # dataset_name="", + # dataset_column="text", # default + # batch_size=32, + cache_ctx_len=32, + n_tokens=10_000_00, + # n_splits=5, + ) + + print("Creating constructor config") + constructor_cfg = ConstructorConfig( + # faiss_embedding_model="sentence-transformers/all-MiniLM-L6-v2", + # faiss_embedding_cache_dir=".embedding_cache", + # faiss_embedding_cache_enabled=True, + example_ctx_len=32, + min_examples=20, # Increased to allow for test examples + n_non_activating=10, # Reduced for smoke test + # center_examples=True, + # non_activating_source="random", + # neighbours_type="co-occurrence", + ) + + print("Creating sampler config") + sampler_cfg = SamplerConfig( + n_examples_train=20, # Reduced for smoke test + n_examples_test=10, # Reduced for smoke test + n_quantiles=5, # Reduced to work with fewer examples + # 
train_type="quantiles", + # test_type="quantiles", + # ratio_top=0.2, + ) + + print("Creating run config") + # Create RunConfig object with the same parameters as the shell script + run_cfg = RunConfig( + cache_cfg=cache_cfg, + # skip_generate_cache_if_exists=True, + constructor_cfg=constructor_cfg, + sampler_cfg=sampler_cfg, + model="EleutherAI/pythia-70m", + sparse_model="EleutherAI/sae-pythia-70m-32k", + hookpoints=["layers.5"], + # explainer_model="EleutherAI/pythia-70m", + explainer_model="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + explainer_model_max_len=5120, + explainer_provider="offline", + # explainer="best_of_k", + explainer="default", + # num_explanations=3, + scorers=["fuzz", "detection"], + name="pythia-70m-smoketest", + max_latents=10, + filter_bos=True, + # log_probs=False, + # load_in_8bit=False, + # hf_token=None, + # pipeline_num_proc=4, + num_gpus=2, + # seed=22, + # verbose=True, + # num_examples_per_scorer_prompt=5, + # overwrite=[], + ) + + # NUM_LATENTS_PRINT = 10 + + """parser = ArgumentParser() + parser.add_arguments(RunConfig, dest="run_cfg") + args = parser.parse_args() + + asyncio.run(run(args.run_cfg))""" + asyncio.run(run(run_cfg)) diff --git a/delphi/explainers/__init__.py b/delphi/explainers/__init__.py index 8cbc5579..ac1abc92 100644 --- a/delphi/explainers/__init__.py +++ b/delphi/explainers/__init__.py @@ -1,3 +1,4 @@ +from .best_of_k_explainer import BestOfKExplainer from .contrastive_explainer import ContrastiveExplainer from .default.default import DefaultExplainer from .explainer import Explainer, explanation_loader, random_explanation_loader @@ -5,6 +6,7 @@ from .single_token_explainer import SingleTokenExplainer __all__ = [ + "BestOfKExplainer", "Explainer", "DefaultExplainer", "SingleTokenExplainer", diff --git a/delphi/explainers/best_of_k_explainer.py b/delphi/explainers/best_of_k_explainer.py new file mode 100644 index 00000000..5a114016 --- /dev/null +++ b/delphi/explainers/best_of_k_explainer.py @@ -0,0 +1,183 @@ +import asyncio +import re +from dataclasses import dataclass +from typing import List + +from delphi import logger +from delphi.explainers.default.prompts import SYSTEM_BEST_OF_K_ONESHOT +from delphi.explainers.explainer import Explainer, ExplainerResult, Response +from delphi.latents.latents import ActivatingExample, LatentRecord, NonActivatingExample + + +@dataclass +class BestOfKExplainer(Explainer): + num_explanations: int = 3 + """Number of explanations to generate.""" + is_one_shot: bool = True + """Whether to use a different prompt for each explanation, or just request K""" + + async def __call__(self, record: LatentRecord) -> List[ExplainerResult]: + """ + Override the base __call__ method to implement the best-of-k explainer. + + Args: + record: The latent record containing both activating and + non-activating examples. + + Returns: + ExplainerResultMulti: The explainer result containing the explanations. 
+ """ + messages = self._build_prompt(record.train) + response = await self.client.generate( + messages, temperature=self.temperature, **self.generation_kwargs + ) + + try: + if isinstance(response, Response): + response_text = response.text + else: + response_text = response + assert isinstance(response_text, str) + + explanations = self.parse_explanations(response_text) + + if self.verbose: + logger.info( + f"[BestOfKExplainer::__call__] Explanation(s): {explanations}" + ) + logger.info( + f"[BestOfKExplainer::__call__] Messages: {messages[-1]['content']}" + ) + logger.info(f"[BestOfKExplainer::__call__] Response: {response}") + + return [ + ExplainerResult(record=record, explanation=explanation) + for explanation in explanations + ] + # ExplainerResultMulti(record=record, explanations=explanations) + except Exception as e: + logger.error( + f"[BestOfKExplainer::__call__] Explanation parsing failed: {repr(e)}" + ) + return [ + ExplainerResult( + record=record, + explanation="[BestOfKExplainer::__call__] Explanation\ + could not be parsed.", + ) + ] + + def parse_explanations(self, text: str) -> List[str]: + try: + # Extract as many explanation lines as present, up to K, but return whatever + # we find + pattern = re.compile( + r"^\s*\[\s*EXPLANATION\s*\]\s*:\s*(.+?)\s*$", # \s* is whitespace + re.IGNORECASE | re.MULTILINE, # Multiline makes ^ and $ activate for + # each line, not just whole string + ) + matches = pattern.findall(text) + matches = [m.strip() for m in matches] + + if matches: + # Return up to the requested number; join to keep return + # type consistent (str) + if self.verbose: + logger.info( + f"[BestOfKExplainer::parse_explanation] Found {len(matches)} \ + well-formed explanations. Requested {self.num_explanations}" + ) + if len(matches) > self.num_explanations: + logger.warning( + f"[BestOfKExplainer::parse_explanations] Found {len(matches)} \ + explanations, but requested {self.num_explanations}. \ + Returning {self.num_explanations} explanations." + ) + return matches + else: + logger.error( + "[BestOfKExplainer::parse_explanations] No explanations found." + + "\n[BestOfKExplainer::parse_explanations] Text: {text}" + ) + + return ["[BestOfKExplainer::parse_explanations] No explanations found."] + except Exception as e: + logger.error( + f"[BestOfKExplainer::parse_explanations] Explanation parsing \ + regex failed: {repr(e)}" + ) + raise + + def _build_prompt( # type: ignore + self, examples: list[ActivatingExample | NonActivatingExample] + ) -> list[dict]: + """ + Build a prompt with both activating and non-activating examples clearly labeled. + + Args: + examples: List containing both activating and non-activating examples. + + Returns: + A list of message dictionaries for the prompt. 
+ """ + highlighted_examples = [] + + # First, separate activating and non-activating examples + activating_examples = [ + ex for ex in examples if isinstance(ex, ActivatingExample) + ] + non_activating_examples = [ + ex for ex in examples if not isinstance(ex, ActivatingExample) + ] + + # Process activating examples + if activating_examples: + highlighted_examples.append("ACTIVATING EXAMPLES:") + for i, example in enumerate(activating_examples, 1): + str_toks = example.str_tokens + activations = example.activations.tolist() + ex = self._highlight(str_toks, activations).strip().replace("\n", "") + highlighted_examples.append(f"Example {i}: {ex}") + + if self.activations and example.normalized_activations is not None: + normalized_activations = example.normalized_activations.tolist() + highlighted_examples.append( + self._join_activations( + str_toks, activations, normalized_activations + ) + ) + + # Process non-activating examples + if non_activating_examples: + highlighted_examples.append("\nNON-ACTIVATING EXAMPLES:") + for i, example in enumerate(non_activating_examples, 1): + str_toks = example.str_tokens + activations = example.activations.tolist() + # Note: For non-activating examples, the _highlight method won't + # highlight anything since activation values will be below threshold + ex = self._highlight(str_toks, activations).strip().replace("\n", "") + highlighted_examples.append(f"Example {i}: {ex}") + + # Join all sections into a single string + highlighted_examples_str = "\n".join(highlighted_examples) + num_explanations_str = ( + "\n" + + f"The number of explanations you are asked to \ + generate is: {self.num_explanations}." + ) + system_prompt = SYSTEM_BEST_OF_K_ONESHOT + num_explanations_str + # Create messages array with the system prompt + return [ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": highlighted_examples_str, + }, + ] + + def call_sync(self, record: LatentRecord) -> List[ExplainerResult]: + """Synchronous wrapper for the asynchronous __call__ method.""" + return asyncio.run(self.__call__(record)) diff --git a/delphi/explainers/default/prompts.py b/delphi/explainers/default/prompts.py index ddb7bace..7013948f 100644 --- a/delphi/explainers/default/prompts.py +++ b/delphi/explainers/default/prompts.py @@ -62,6 +62,21 @@ """ +SYSTEM_BEST_OF_K_ONESHOT = """You are a meticulous AI researcher conducting an important investigation into patterns found in language.""" \ ++"""These patterns will be presented to you in the form of fragments of text wherein certain words are marked as activating the pattern in question.""" \ ++"""Your task is to analyze text and conjecture one or more explanations of the pattern which thoroughly the patterns shown. +Guidelines: + +You will be given a list of text examples on which special words are selected and placed between delimiters like <>. """\ ++"""If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <>. How important each token is for the behavior is listed as a parenthesized number after each example. + +- Below, an integer K will be specified; you are tasked with providing K meaningfully distinct possible explanations for the patterns found in the examples. These will be scored and the best chosen. +- Try to be concise, but complete - the best answers are those which give a concise natural-lnguage condition which is necessary and sufficient for the pattern to hold. 
+- Counterexamples where no activating words are present are also provided to help you understand the pattern's negative cases.
+- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens.
+- Do not mention the marker tokens (<< >>) in your explanation.
+- The final K lines of your response must be the K formatted explanations; each explanation should be placed on its own line and preceded by [EXPLANATION]:. This text will be processed programmatically so it is imperative you obey these formatting restrictions without fail.
+"""
 
 
 ### EXAMPLE 1 ###
diff --git a/delphi/pipeline.py b/delphi/pipeline.py
index 428d817e..b3da1e1c 100644
--- a/delphi/pipeline.py
+++ b/delphi/pipeline.py
@@ -1,7 +1,7 @@
 import asyncio
 from collections.abc import AsyncIterable, Awaitable, Callable
 from functools import wraps
-from typing import Any
+from typing import Any, List
 
 from tqdm.asyncio import tqdm
 
@@ -40,6 +40,23 @@ async def wrapped(input: Any):
     return wrapped
 
 
+def fan_out_fan_in_wrapper(
+    function: Callable[..., Awaitable],
+) -> Callable[..., Awaitable]:
+    """
+    Wraps a function with fan-out and fan-in steps.
+    Applies the function element-wise to a list of inputs.
+    """
+
+    @wraps(function)
+    async def wrapped(input: List[Any]) -> List[Any]:
+        tasks = [asyncio.create_task(function(item)) for item in input]
+        results = await asyncio.gather(*tasks)
+        return results
+
+    return wrapped
+
+
 class Pipe:
     """
     Represents a pipe of functions to be executed with the same input.

From 7eb14a93293c8a04f8415002b937723b125a21f5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 18 Aug 2025 20:54:04 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 GP_run_experiment.py                 |  3 +--
 delphi/explainers/default/prompts.py | 12 +++++++-----
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/GP_run_experiment.py b/GP_run_experiment.py
index 1dc316aa..d4badd53 100755
--- a/GP_run_experiment.py
+++ b/GP_run_experiment.py
@@ -3,7 +3,6 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
 
 import asyncio
-
 from functools import partial
 from pathlib import Path
 from typing import Callable
@@ -22,9 +21,9 @@
 
 from delphi.delphi.clients import Offline, OpenRouter
 from delphi.delphi.config import (
-    RunConfig,
     CacheConfig,
     ConstructorConfig,
+    RunConfig,
     SamplerConfig,
 )
 from delphi.delphi.explainers import DefaultExplainer
diff --git a/delphi/explainers/default/prompts.py b/delphi/explainers/default/prompts.py
index 7013948f..58528d0c 100644
--- a/delphi/explainers/default/prompts.py
+++ b/delphi/explainers/default/prompts.py
@@ -62,13 +62,14 @@
 """
 
-SYSTEM_BEST_OF_K_ONESHOT = """You are a meticulous AI researcher conducting an important investigation into patterns found in language.""" \
-+"""These patterns will be presented to you in the form of fragments of text wherein certain words are marked as activating the pattern in question.""" \
-+"""Your task is to analyze text and conjecture one or more explanations of the pattern which thoroughly explain the patterns shown.
+SYSTEM_BEST_OF_K_ONESHOT = (
+    """You are a meticulous AI researcher conducting an important investigation into patterns found in language."""
+    + """These patterns will be presented to you in the form of fragments of text wherein certain words are marked as activating the pattern in question."""
+    + """Your task is to analyze text and conjecture one or more explanations of the pattern which thoroughly explain the patterns shown.
 Guidelines:
 
-You will be given a list of text examples on which special words are selected and placed between delimiters like <>. """\
-+"""If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <>. How important each token is for the behavior is listed as a parenthesized number after each example.
+You will be given a list of text examples on which special words are selected and placed between delimiters like <>. """
+    + """If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <>. How important each token is for the behavior is listed as a parenthesized number after each example.
 
 - Below, an integer K will be specified; you are tasked with providing K meaningfully distinct possible explanations for the patterns found in the examples. These will be scored and the best chosen.
 - Try to be concise, but complete - the best answers are those which give a concise natural-language condition which is necessary and sufficient for the pattern to hold.
@@ -77,6 +78,7 @@
 - Do not mention the marker tokens (<< >>) in your explanation.
 - The final K lines of your response must be the K formatted explanations; each explanation should be placed on its own line and preceded by [EXPLANATION]:. This text will be processed programmatically so it is imperative you obey these formatting restrictions without fail.
 """
+)
 
 ### EXAMPLE 1 ###
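
Usage sketch: BestOfKExplainer returns a list of ExplainerResult objects, one per candidate explanation, so GP_run_experiment.py wraps its scorer with the new fan_out_fan_in_wrapper from delphi/pipeline.py when run_cfg.explainer == "best_of_k", letting each candidate be scored independently. The snippet below is a minimal illustration of the wrapper's behaviour only; it assumes the patched delphi package is importable, and score_one is a hypothetical stand-in for a per-explanation scorer call, not a delphi API.

import asyncio

from delphi.pipeline import fan_out_fan_in_wrapper


async def score_one(explanation: str) -> int:
    # Hypothetical stand-in for scoring a single candidate explanation.
    return len(explanation)


async def main() -> None:
    # fan_out_fan_in_wrapper turns a per-item coroutine into one that accepts a
    # list, schedules every item concurrently, and gathers the results in order.
    score_all = fan_out_fan_in_wrapper(score_one)
    print(await score_all(["short", "a longer candidate"]))  # [5, 18]


if __name__ == "__main__":
    asyncio.run(main())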