Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 147 additions & 28 deletions delphi/latents/latents.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass, field
from typing import NamedTuple, Optional
from typing import Literal, NamedTuple, Optional

import blobfile as bf
import orjson
import torch
from jaxtyping import Float, Int
from torch import Tensor
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
Expand Down Expand Up @@ -203,6 +204,8 @@ def display(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
threshold: float = 0.0,
n: int = 10,
do_display: bool = True,
example_source: Literal["examples", "train", "test"] = "examples",
):
"""
Display the latent record in a formatted string.
Expand All @@ -216,9 +219,8 @@ def display(
Returns:
str: The formatted string.
"""
from IPython.core.display import HTML, display # type: ignore

def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
def _to_string(toks, activations: Float[Tensor, "ctx_len"]) -> str:
"""
Convert tokens and activations to a string.

Expand All @@ -229,28 +231,145 @@ def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
Returns:
str: The formatted string.
"""
result = []
i = 0

max_act = activations.max()
_threshold = max_act * threshold

while i < len(tokens):
if activations[i] > _threshold:
result.append("<mark>")
while i < len(tokens) and activations[i] > _threshold:
result.append(tokens[i])
i += 1
result.append("</mark>")
else:
result.append(tokens[i])
i += 1
return "".join(result)
return ""

strings = [
_to_string(tokenizer.batch_decode(example.tokens), example.activations)
for example in self.examples[:n]
]

display(HTML("<br><br>".join(strings)))
text_spacing = "0.00em"
toks = convert_token_array_to_list(toks)
activations = convert_token_array_to_list(activations)
inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
toks = [
[
inverse_vocab[int(t)]
.replace("Ġ", " ")
.replace("▁", " ")
.replace("\n", "\\n")
for t in tok
]
for tok in toks
]
highlighted_text = []
highlighted_text.append(
"""
<body style="background-color: black; color: white;">
"""
)
max_value = max([max(activ) for activ in activations])
min_value = min([min(activ) for activ in activations])
# Add color bar
highlighted_text.append(
"Token Activations: " + make_colorbar(min_value, max_value)
)

highlighted_text.append('<div style="margin-top: 0.5em;"></div>')
for seq_ind, (act, tok) in enumerate(zip(activations, toks)):
for act_ind, (a, t) in enumerate(zip(act, tok)):
text_color, background_color = value_to_color(
a, max_value, min_value
)
highlighted_text.append(
f'<span style="background-color:{background_color};'
f'margin-right: {text_spacing}; color:rgb({text_color})"'
f">{escape(t)}</span>"
) # noqa: E501
highlighted_text.append('<div style="margin-top: 0.2em;"></div>')
highlighted_text = "".join(highlighted_text)
return highlighted_text

match example_source:
case "examples":
examples = self.examples
case "train":
examples = self.train
case "test":
examples = [x[0] for x in self.test]
case _:
raise ValueError(f"Unknown example source: {example_source}")
examples = examples[:n]
strings = _to_string(
[example.tokens for example in examples],
[example.activations for example in examples],
)

if do_display:
from IPython.display import HTML, display

display(HTML(strings))
else:
return strings


def make_colorbar(
    min_value,
    max_value,
    white=255,
    red_blue_ness=250,
    positive_threshold=0.01,
    negative_threshold=0.01,
):
    """Build an HTML legend mapping activation values to colors.

    Renders up to four red swatches for the negative range (strongest
    first), a white swatch for zero, and up to four blue swatches for the
    positive range (weakest first). Sides whose extreme does not exceed
    the matching threshold are omitted.

    Returns:
        str: concatenated ``<span>`` elements forming the color bar.
    """
    num_colors = 4
    parts = []

    def _fade(ratio):
        # Channel intensity shrinks from red_blue_ness toward 0 as ratio grows.
        return int(red_blue_ness - (red_blue_ness * ratio))

    def _label(ratio):
        # White labels on dark swatches, black labels on light ones.
        return "255,255,255" if ratio > 0.5 else "0,0,0"

    # Negative (red) side, most negative value first.
    if min_value < -negative_threshold:
        for step in range(num_colors, 0, -1):
            ratio = step / (num_colors)
            value = round((min_value * ratio), 1)
            parts.append(
                f'<span style="background-color:rgba(255, {_fade(ratio)},{_fade(ratio)},1); color:rgb({_label(ratio)})">&nbsp{value}&nbsp</span>'  # noqa: E501
            )
    # Zero swatch is always present.
    parts.append(
        f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>'  # noqa: E501
    )
    # Positive (blue) side, smallest value first.
    if max_value > positive_threshold:
        for step in range(1, num_colors + 1):
            ratio = step / (num_colors)
            value = round((max_value * ratio), 1)
            parts.append(
                f'<span style="background-color:rgba({_fade(ratio)},{_fade(ratio)},255,1);color:rgb({_label(ratio)})">&nbsp{value}&nbsp</span>'  # noqa: E501
            )
    return "".join(parts)


def value_to_color(
    activation,
    max_value,
    min_value,
    white=255,
    red_blue_ness=250,
    positive_threshold=0.01,
    negative_threshold=0.01,
):
    """Map one activation to CSS ``(text_color, background_color)`` strings.

    Positive activations get a blue background scaled by ``activation /
    max_value``; negative ones get a red background scaled by
    ``activation / min_value``; near-zero values get a plain white swatch.
    Text flips to white once the background is more than half saturated.
    """
    # Near-zero band: white background, black text.
    if -negative_threshold <= activation <= positive_threshold:
        return "0,0,0", f"rgba({white},{white},{white},1)"

    if activation > positive_threshold:
        ratio = activation / max_value
        fade = int(red_blue_ness - (red_blue_ness * ratio))
        background_color = f"rgba({fade},{fade},255,1)"  # noqa: E501
    else:
        ratio = activation / min_value
        fade = int(red_blue_ness - (red_blue_ness * ratio))
        background_color = f"rgba(255, {fade},{fade},1)"  # noqa: E501

    text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
    return text_color, background_color


def convert_token_array_to_list(array):
    """Normalize a token-id container to a list of per-sequence lists.

    Accepts a 1-D or 2-D ``torch.Tensor``, a flat list of ints, a list of
    1-D tensors, or an already-nested list of lists, and returns a
    ``list`` of row lists in every case.

    Args:
        array: tokens as a tensor or (possibly nested) list.

    Returns:
        list: list of lists of token ids.

    Raises:
        NotImplementedError: if a tensor with more than 2 dims is given.
    """
    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            # Single sequence -> wrap so callers can always iterate rows.
            return [array.tolist()]
        if array.dim() == 2:
            return array.tolist()
        raise NotImplementedError("tokens must be 1 or 2 dimensional")
    if isinstance(array, list):
        # Bug fix: previously an empty list raised IndexError on array[0].
        if not array:
            return array
        if isinstance(array[0], int):
            # Flat list of ids -> single-row nested list.
            return [array]
        if isinstance(array[0], torch.Tensor):
            return [t.tolist() for t in array]
    return array


def escape(t):
    """HTML-escape a decoded token for embedding inside a ``<span>``.

    Spaces become ``&nbsp;`` so runs of whitespace stay visible, the
    ``<bos>`` marker is renamed to ``BOS`` before angle brackets are
    escaped so it survives as readable text, and ``<``/``>`` become
    entities.

    Args:
        t: decoded token text.

    Returns:
        str: HTML-safe representation of the token.
    """
    t = (
        # Bug fix: escape '&' first, so literal ampersands in token text
        # are rendered correctly and the entities inserted below are not
        # double-escaped.
        t.replace("&", "&amp;")
        .replace(" ", "&nbsp;")
        .replace("<bos>", "BOS")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    return t