From c6034bf024080efc9835fd0ca0fea759a5edb59b Mon Sep 17 00:00:00 2001
From: nev
Date: Mon, 24 Mar 2025 15:21:17 +0000
Subject: [PATCH 1/2] Add functionality to .display()

---
 delphi/latents/latents.py | 175 ++++++++++++++++++++++++++++++++------
 1 file changed, 147 insertions(+), 28 deletions(-)

diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index cf142f35..0de723e2 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -1,8 +1,9 @@
 from dataclasses import dataclass, field
-from typing import NamedTuple, Optional
+from typing import Literal, NamedTuple, Optional
 
 import blobfile as bf
 import orjson
+import torch
 from jaxtyping import Float
 from torch import Tensor
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -193,6 +194,8 @@ def display(
         tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
         threshold: float = 0.0,
         n: int = 10,
+        do_display: bool = True,
+        example_source: Literal["examples", "train", "test"] = "examples",
     ):
         """
         Display the latent record in a formatted string.
@@ -206,9 +209,8 @@ def display(
         Returns:
             str: The formatted string.
         """
-        from IPython.core.display import HTML, display  # type: ignore
 
-        def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
+        def _to_string(toks, activations: Float[Tensor, "ctx_len"]) -> str:
             """
             Convert tokens and activations to a string.
 
@@ -219,28 +221,145 @@ def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
             Returns:
                 str: The formatted string.
             """
-            result = []
-            i = 0
-
-            max_act = activations.max()
-            _threshold = max_act * threshold
-
-            while i < len(tokens):
-                if activations[i] > _threshold:
-                    result.append("<mark>")
-                    while i < len(tokens) and activations[i] > _threshold:
-                        result.append(tokens[i])
-                        i += 1
-                    result.append("</mark>")
-                else:
-                    result.append(tokens[i])
-                    i += 1
-            return "".join(result)
-            return ""
-
-        strings = [
-            _to_string(tokenizer.batch_decode(example.tokens), example.activations)
-            for example in self.examples[:n]
-        ]
-
-        display(HTML("<br><br>".join(strings)))
+            text_spacing = "0.00em"
+            toks = convert_token_array_to_list(toks)
+            activations = convert_token_array_to_list(activations)
+            inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
+            toks = [
+                [
+                    inverse_vocab[int(t)]
+                    .replace("Ġ", " ")
+                    .replace("▁", " ")
+                    .replace("\n", "\\n")
+                    for t in tok
+                ]
+                for tok in toks
+            ]
+            highlighted_text = []
+            highlighted_text.append(
+                """
+
+            """
+            )
+            max_value = max([max(activ) for activ in activations])
+            min_value = min([min(activ) for activ in activations])
+            # Add color bar
+            highlighted_text.append(
+                "Token Activations: " + make_colorbar(min_value, max_value)
+            )
+
+            highlighted_text.append('<br>')
+            for seq_ind, (act, tok) in enumerate(zip(activations, toks)):
+                for act_ind, (a, t) in enumerate(zip(act, tok)):
+                    text_color, background_color = value_to_color(
+                        a, max_value, min_value
+                    )
+                    highlighted_text.append(
+                        f'<span style="background-color:{background_color};margin-right: {text_spacing};color:rgb({text_color})">{escape(t)}</span>'  # noqa: E501
+                    )
+                highlighted_text.append('<br>')
+            highlighted_text = "".join(highlighted_text)
+            return highlighted_text
+
+        match example_source:
+            case "examples":
+                examples = self.examples
+            case "train":
+                examples = self.train
+            case "test":
+                examples = [x[0] for x in self.test]
+            case _:
+                raise ValueError(f"Unknown example source: {example_source}")
+        examples = examples[:n]
+        strings = _to_string(
+            [example.tokens for example in examples],
+            [example.activations for example in examples],
+        )
+
+        if do_display:
+            from IPython.display import HTML, display
+
+            display(HTML(strings))
+        else:
+            return strings
+
+
+def make_colorbar(
+    min_value,
+    max_value,
+    white=255,
+    red_blue_ness=250,
+    positive_threshold=0.01,
+    negative_threshold=0.01,
+):
+    # Add color bar
+    colorbar = ""
+    num_colors = 4
+    if min_value < -negative_threshold:
+        for i in range(num_colors, 0, -1):
+            ratio = i / (num_colors)
+            value = round((min_value * ratio), 1)
+            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
+            colorbar += f'<span style="background-color:rgba(255, {int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},1);color:rgb({text_color})">&nbsp;{value}&nbsp;</span>'  # noqa: E501
+    # Do zero
+    colorbar += f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp;0.0&nbsp;</span>'  # noqa: E501
+    # Do positive
+    if max_value > positive_threshold:
+        for i in range(1, num_colors + 1):
+            ratio = i / (num_colors)
+            value = round((max_value * ratio), 1)
+            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
+            colorbar += f'<span style="background-color:rgba({int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},255,1);color:rgb({text_color})">&nbsp;{value}&nbsp;</span>'  # noqa: E501
+    return colorbar
+
+
+def value_to_color(
+    activation,
+    max_value,
+    min_value,
+    white=255,
+    red_blue_ness=250,
+    positive_threshold=0.01,
+    negative_threshold=0.01,
+):
+    if activation > positive_threshold:
+        ratio = activation / max_value
+        text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
+        background_color = f"rgba({int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},255,1)"  # noqa: E501
+    elif activation < -negative_threshold:
+        ratio = activation / min_value
+        text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
+        background_color = f"rgba(255, {int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},1)"  # noqa: E501
+    else:
+        text_color = "0,0,0"
+        background_color = f"rgba({white},{white},{white},1)"
+    return text_color, background_color
+
+
+def convert_token_array_to_list(array):
+    if isinstance(array, torch.Tensor):
+        if array.dim() == 1:
+            array = [array.tolist()]
+        elif array.dim() == 2:
+            array = array.tolist()
+        else:
+            raise NotImplementedError("tokens must be 1 or 2 dimensional")
+    elif isinstance(array, list):
+        # ensure it's a list of lists
+        if isinstance(array[0], int):
+            array = [array]
+        if isinstance(array[0], torch.Tensor):
+            array = [t.tolist() for t in array]
+    return array
+
+
+def escape(t):
+    t = (
+        t.replace(" ", "&nbsp;")
+        .replace("<s>", "BOS")
+        .replace("<", "&lt;")
+        .replace(">", "&gt;")
+    )
+    return t

From aa0ccf2a008e3e169535e5ce4b45a0ef654ed888 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 23 Jun 2025 19:26:38 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 delphi/latents/latents.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index 7275842c..0f4ff94d 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -4,7 +4,6 @@
 import blobfile as bf
 import orjson
 import torch
-from jaxtyping import Float
 from jaxtyping import Float, Int
 from torch import Tensor
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
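
A minimal usage sketch of the API these patches introduce, for review context. The `record` object and the tokenizer checkpoint are illustrative assumptions (any populated LatentRecord and a compatible Hugging Face tokenizer), not part of the patch:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

    # "record" is assumed to be a populated LatentRecord.
    # Render the first five train examples inline (e.g. in a notebook).
    record.display(tokenizer, n=5, example_source="train")

    # Or capture the raw HTML string instead of displaying it.
    html = record.display(tokenizer, n=5, example_source="test", do_display=False)
    with open("latent.html", "w") as f:
        f.write(html)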
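
The value_to_color helper drives both the colorbar and the per-token spans: positive activations shade toward blue, negative toward red, and anything inside the thresholds stays white. A quick sketch of the mapping (sample values are illustrative):

    # Text color flips to white once the background passes half intensity
    # (ratio > 0.5); near-zero activations get a white background.
    for a in (2.0, 0.5, 0.0, -1.5):
        text, background = value_to_color(a, max_value=2.0, min_value=-1.5)
        print(f"{a:>5}: text=rgb({text}) background={background}")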