diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index 91a4b176..0f4ff94d 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -1,8 +1,9 @@
from dataclasses import dataclass, field
-from typing import NamedTuple, Optional
+from typing import Literal, NamedTuple, Optional
import blobfile as bf
import orjson
+import torch
from jaxtyping import Float, Int
from torch import Tensor
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -203,6 +204,8 @@ def display(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
threshold: float = 0.0,
n: int = 10,
+ do_display: bool = True,
+ example_source: Literal["examples", "train", "test"] = "examples",
):
"""
Display the latent record in a formatted string.
@@ -216,9 +219,8 @@ def display(
Returns:
str: The formatted string.
"""
- from IPython.core.display import HTML, display # type: ignore
- def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
+ def _to_string(toks, activations: Float[Tensor, "ctx_len"]) -> str:
"""
Convert tokens and activations to a string.
@@ -229,28 +231,145 @@ def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
Returns:
str: The formatted string.
"""
- result = []
- i = 0
-
- max_act = activations.max()
- _threshold = max_act * threshold
-
- while i < len(tokens):
- if activations[i] > _threshold:
- result.append("")
- while i < len(tokens) and activations[i] > _threshold:
- result.append(tokens[i])
- i += 1
- result.append("")
- else:
- result.append(tokens[i])
- i += 1
- return "".join(result)
- return ""
-
- strings = [
- _to_string(tokenizer.batch_decode(example.tokens), example.activations)
- for example in self.examples[:n]
- ]
-
- display(HTML("
".join(strings)))
+ text_spacing = "0.00em"
+ toks = convert_token_array_to_list(toks)
+ activations = convert_token_array_to_list(activations)
+ inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
+ toks = [
+ [
+ inverse_vocab[int(t)]
+ .replace("Ġ", " ")
+ .replace("▁", " ")
+ .replace("\n", "\\n")
+ for t in tok
+ ]
+ for tok in toks
+ ]
+ highlighted_text = []
+ highlighted_text.append(
+ """
+