In [None]:
from transformers import AutoTokenizer
import torch

import html
from IPython.display import display, HTML

In [None]:
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code — only point this at files you trust.
examples_path = "quantile_examples/crosscoder/examples.pt"
qe, seqs, activation_details = torch.load(examples_path, weights_only=False)

# Quick sanity check of what was loaded:
#   qe                 - indexed below as qe[quantile_idx][feature_id] -> examples
#   seqs               - token sequences, indexed by sequence index
#   activation_details - indexed below as activation_details[feature_idx][sequence_idx]
print(type(qe), len(seqs), len(activation_details), sep="\n")

In [None]:
# Tokenizer used below to decode the stored token-id sequences back into text.
# NOTE(review): "Qwen3-1.7B" has no organisation prefix, so this presumably
# resolves to a local directory next to the notebook — confirm.
tokenizer = AutoTokenizer.from_pretrained("Qwen3-1.7B")

In [None]:
def visualize_activations(feature_idx, sequence_idx):
    """Render one sequence as HTML with the feature's active tokens highlighted.

    Looks up ``activation_details[feature_idx][sequence_idx]``, which is
    unpacked as ``(positions, values)``: token positions within
    ``seqs[sequence_idx]`` and the activation value at each position. Every
    active token is wrapped in a red ``<span>`` whose opacity is the
    activation divided by the strongest activation in the sequence, and whose
    tooltip shows the raw value.

    Args:
        feature_idx: Key into ``activation_details`` selecting the feature.
        sequence_idx: Index into ``seqs`` selecting the token sequence.

    Returns:
        None. The HTML is shown with ``IPython.display.display``. If the
        feature has no entry for this sequence, a message is printed instead.
    """
    feature_activations = activation_details[feature_idx]
    if sequence_idx not in feature_activations:
        print(f"No activations for feature {feature_idx} in sequence {sequence_idx}")
        return

    positions, values = feature_activations[sequence_idx]
    act_map = {int(pos): float(val) for pos, val in zip(positions, values)}

    # Scale by the strongest activation. Fall back to 1.0 when there is no
    # positive activation so the division below can never hit zero.
    peak = max(act_map.values(), default=0.0)
    max_act_in_seq = peak if peak > 0 else 1.0

    token_ids = [int(tok) for tok in seqs[sequence_idx]]

    # Decode the whole sequence once so the text is rendered cleanly.
    decoded_text = tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)

    # Re-tokenize the decoded text to obtain character offsets per token.
    encoding = tokenizer(
        decoded_text,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    # Each segment is (text preceding the token, token text).
    if list(encoding["input_ids"]) == token_ids:
        # The round trip reproduced the original ids, so offset i really
        # belongs to original position i and act_map positions line up.
        segments = []
        last_end = 0
        for start, end in encoding["offset_mapping"]:
            segments.append((decoded_text[last_end:start], decoded_text[start:end]))
            last_end = end
        trailing_text = decoded_text[last_end:]
    else:
        # decode -> encode is not guaranteed to round-trip. When it does not,
        # the offsets would shift highlights onto the wrong tokens, so decode
        # each original token separately to keep positions exact. A character
        # split across several tokens may then render imperfectly.
        segments = [
            ("", tokenizer.decode([tok], clean_up_tokenization_spaces=False))
            for tok in token_ids
        ]
        trailing_text = ""

    parts = []
    for position, (gap_text, token_text) in enumerate(segments):
        parts.append(html.escape(gap_text))
        escaped_token = html.escape(token_text)

        if position in act_map:
            activation = act_map[position]
            # Clamp so the CSS alpha stays valid even for negative values.
            intensity = min(max(activation / max_act_in_seq, 0.0), 1.0)
            color = f"rgba(255, 0, 0, {intensity:.2f})"
            parts.append(
                f'<span style="background-color: {color}" '
                f'title="Activation: {activation:.2f}">{escaped_token}</span>'
            )
        else:
            parts.append(escaped_token)

    parts.append(html.escape(trailing_text))
    html_output = "".join(parts)

    display(
        HTML(
            f"<div style='font-family: monospace; white-space: pre-wrap;'>{html_output}</div>"
        )
    )

In [None]:
# Select the examples of one feature in one quantile bin, strongest first.

feature_idx_to_analyze = 2
quantile_to_analyze = 3  # 4 is the strongest, 0 is the weakest

# Each example is indexed as example[0] = activation strength (the sort key).
# sorted() builds a new list, so the data held in `qe` is left untouched and
# other cells that read `qe` see it exactly as it was loaded.
examples_bin = sorted(
    qe[quantile_to_analyze][feature_idx_to_analyze],
    key=lambda example: example[0],
    reverse=True,
)

In [None]:
# Take the entry at index 8 of the current `examples_bin` and render it.
# Each entry unpacks as (activation, sequence index).
selected_example = examples_bin[8]
act, top_sequence_idx = selected_example
print(f"Visualizing Feature {feature_idx_to_analyze} on Sequence {top_sequence_idx}")
visualize_activations(feature_idx_to_analyze, top_sequence_idx)

# Plotting: share of each latent's examples that come from the false documents

In [None]:
from datasets import load_from_disk

In [None]:
# Load the two processed datasets from disk and collect the 'text' column of
# each 'test' split into a set for fast membership checks.
true_val_data = load_from_disk("MATS_true_processed/")
false_val_data = load_from_disk("MATS_false_processed/")
true_set, false_set = (
    set(dataset['test']['text']) for dataset in (true_val_data, false_val_data)
)

In [None]:
import os
from tqdm import tqdm

# Latent ids used as labels in the results. The position of an id in this list
# is the index used to look the feature up in `qe` (see the loop below).
true_feature_idxs = [
    1774, 3609, 6425, 15943, 18232, 24833, 30340,
    36090, 43592, 44647, 51653, 51802, 53929, 56428,
    57787, 58237
]

# The same sequence can appear for several features and quantiles; cache its
# decoded text so each sequence is decoded only once.
decoded_text_cache = {}

# For every feature, count how many of its examples in quantiles 1-4 decode to
# a text that is in `false_set`. Quantile 0 is not scanned.
# Result: true latent id -> (false_doc_count, total_examples).
feats_false_counts = {}
for feat_idx, true_idx in enumerate(true_feature_idxs):
    counter = 0
    total = 0
    for quantile in range(1, 5):

        # Deliberately not named `examples_bin`: that global holds the sorted
        # bin used by the visualization cell and must not be overwritten here.
        quantile_examples = qe[quantile][feat_idx]
        print(f"Feature {true_idx}, quantile {quantile}: {len(quantile_examples)} examples")

        # Use tqdm for progress bar
        for _activation, seq_idx in tqdm(quantile_examples, desc=f"Feature {true_idx}, Q{quantile}"):
            cache_key = int(seq_idx)
            if cache_key not in decoded_text_cache:
                decoded_text_cache[cache_key] = tokenizer.decode(seqs[seq_idx])
            if decoded_text_cache[cache_key] in false_set:
                counter += 1
            total += 1
    feats_false_counts[true_idx] = (counter, total)
                

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# feats_false_counts maps latent id -> (false_doc_count, total_examples).
data = feats_false_counts

latent_idx = list(data.keys())
false_docs = [v[0] for v in data.values()]
total_docs = [v[1] for v in data.values()]
# Percentage of examples found in the false documents. A latent with no
# examples is plotted as 0% instead of raising ZeroDivisionError.
ratios = [f / t * 100 if t else 0.0 for f, t in data.values()]

x = np.arange(len(latent_idx))
width = 0.65

# Single bar colour for all latents
bar_color = "#4C9F70"  # a muted green

fig, ax = plt.subplots(figsize=(14, 7))
bars = ax.bar(x, ratios, width, color=bar_color, edgecolor="black", alpha=0.85)

# Annotate each bar with the number of examples behind its percentage
for bar, total in zip(bars, total_docs):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 2,
        f"n={total}",
        ha="center",
        va="bottom",
        fontsize=9,
        fontweight="bold",
    )

# Formatting
ax.set_xticks(x)
ax.set_xticklabels(latent_idx, rotation=45, ha="right", fontsize=10)
ax.set_yticks(np.arange(0, 110, 10))
ax.tick_params(axis="y", labelsize=10)
ax.set_ylabel("Percentage of Activations on False Docs (%)", fontsize=12)
ax.set_xlabel("Latent ID", fontsize=12)
ax.set_title("Latent Activations on False Documents (%)", fontsize=14, fontweight="bold")

# Headroom above 100% leaves space for the "n=" labels
ax.set_ylim(0, 110)
ax.grid(axis="y", linestyle="--", alpha=0.5)

fig.tight_layout()
fig.savefig("latent_activations_falsepercentage.png", dpi=300, bbox_inches="tight")
plt.show()