## Load the target (explainer) model

In [30]:
CACHE_DIR = "/mnt/raid10/ak-research-01/ak-research-01/codes/.cache"

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from contextlib import contextmanager
import random

# Seed the RNGs used in this notebook. `make_feature_prompt` picks a template
# with `random` when no template_id is given, so without a seed the prompts
# (and therefore the generated explanations) change from run to run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# 1) Load Llama-3.1-8B (target/explainer model)
model_name = "meta-llama/Llama-3.1-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model.eval()  # inference only in this notebook section
# Device and dtype read by the helper functions for inputs / feature vectors.
device = model.device
dtype = next(model.parameters()).dtype

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [31]:
# 2) Ensure [s] and [e] are special tokens
special_tokens = {"additional_special_tokens": ["[s]", "[e]"]}

# Look the vocabulary up once rather than once per marker token.
existing_vocab = tokenizer.get_vocab()
need_add = any(tok not in existing_vocab
               for tok in special_tokens["additional_special_tokens"])
if need_add:
    tokenizer.add_special_tokens(special_tokens)
    # Grow the embedding matrix to cover the new ids. For the base model these
    # rows are freshly initialized, so [s]/[e] carry no learned meaning yet.
    model.resize_token_embeddings(len(tokenizer))

ID_S = tokenizer.convert_tokens_to_ids("[s]")
ID_E = tokenizer.convert_tokens_to_ids("[e]")
# Fail fast if either marker did not resolve to a real vocabulary id; otherwise
# the embedding-patching hook would only fail later with a less direct error.
unresolved_markers = [tok for tok, tok_id in (("[s]", ID_S), ("[e]", ID_E))
                      if tok_id is None or tok_id == tokenizer.unk_token_id]
if unresolved_markers:
    raise ValueError(f"Marker tokens not in tokenizer vocabulary: {unresolved_markers}")

# 3) Paper templates with [s]v[e]
FEATURE_TEMPLATES = [
    "At layer {layer}, [s]v[e] encodes",
    "[s]v[e] activates at layer {layer} for",
    "We can describe [s]v[e] at layer {layer} as encoding",
    "Generate a description of this feature at layer {layer}: [s]v[e].",
    "What does [s]v[e] mean at layer {layer}?",
    "[s]v[e] activates at layer {layer} for inputs with the following features:",
]
# Every template must expose the [s]v[e] slot that gets patched and a {layer} field.
assert all("[s]v[e]" in t and "{layer}" in t for t in FEATURE_TEMPLATES)

In [32]:
def make_feature_prompt(layer: int,
                        template_id: int | None = None,
                        rng: random.Random | None = None) -> str:
    """Build the textual prompt asking the model to describe a feature.

    Args:
        layer: Layer index substituted into the template's ``{layer}`` field.
        template_id: Index into ``FEATURE_TEMPLATES``. If None, a template is
            chosen at random.
        rng: Optional ``random.Random`` used for the random choice. Defaults to
            the module-level ``random`` state; pass a seeded instance for
            reproducible template selection.

    Returns:
        The formatted template followed by a single trailing space.

    Raises:
        IndexError: if ``template_id`` is outside ``[0, len(FEATURE_TEMPLATES))``.
    """
    n_templates = len(FEATURE_TEMPLATES)
    if template_id is None:
        source = rng if rng is not None else random
        template_id = source.randrange(n_templates)
    elif not 0 <= template_id < n_templates:
        # Reject negatives explicitly: plain list indexing would silently wrap.
        raise IndexError(
            f"template_id must be in [0, {n_templates}), got {template_id}."
        )
    # NOTE(review): the trailing space may tokenize as a separate token; it is
    # kept for compatibility, presumably matching the training prompts — confirm
    # against the dataset code.
    return FEATURE_TEMPLATES[template_id].format(layer=layer) + " "

def normalize_feature(v: torch.Tensor) -> torch.Tensor:
    """Return ``v`` scaled to unit L2 norm, on the notebook's ``device``/``dtype``.

    The norm is computed in float32 *before* casting to the model dtype
    (bfloat16 here), so the low-precision cast does not also degrade the
    normalization itself. A small epsilon guards against a zero vector.
    """
    v32 = v.to(device=device, dtype=torch.float32)
    unit = v32 / (v32.norm() + 1e-8)
    return unit.to(dtype=dtype)


@contextmanager
def patch_at_v_between_s_e(model, input_ids: torch.Tensor, v: torch.Tensor):
    """
    Temporarily replace the embedding of the placeholder token between [s] and [e].

    Finds the first span ``[s] <tok> [e]`` in ``input_ids[0]`` (only the first
    row of the tokenized prompt is inspected). While the context is active, it
    overwrites the input-embedding output at ``<tok>``'s position with ``v`` in
    every batch row. The patch is applied once: on the first forward pass whose
    sequence is long enough to contain that position. With cached generation
    that is the full-prompt pass; later steps embed a single token
    (seq_len == 1) and are left untouched.

    NOTE(review): with ``use_cache=False`` every generation step re-embeds the
    full sequence, but only the first step is patched — confirm generation
    keeps the default KV cache.

    Args:
        model: model exposing ``get_input_embeddings()``.
        input_ids: tokenized prompt, read as ``input_ids[0]``.
        v: replacement vector; its last dimension must equal the embedding size.

    Raises:
        ValueError: if [s] is missing, no [e] follows it, the span does not
            enclose exactly one token, or ``v`` has the wrong size.
    """
    emb = model.get_input_embeddings()
    v = v.to(device=emb.weight.device, dtype=emb.weight.dtype)

    # Validate up front instead of failing with a broadcast error inside the hook.
    hidden_size = emb.weight.shape[-1]
    if v.dim() == 0 or v.shape[-1] != hidden_size:
        raise ValueError(
            f"Feature vector must have last dimension {hidden_size}, got shape {tuple(v.shape)}."
        )

    ids = input_ids[0]  # [seq_len] for the *prompt*

    # locate the first [s]
    s_positions = (ids == ID_S).nonzero(as_tuple=True)[0]
    if len(s_positions) == 0:
        raise ValueError("Prompt does not contain [s] token.")
    s_idx = int(s_positions[0].item())

    # locate the first [e] strictly AFTER [s]
    e_positions = (ids == ID_E).nonzero(as_tuple=True)[0]
    e_positions = e_positions[e_positions > s_idx]
    if len(e_positions) == 0:
        raise ValueError("Prompt does not contain [e] token after [s].")
    e_idx = int(e_positions[0].item())

    # Exactly one token ('v') must sit strictly between [s] and [e].
    n_between = e_idx - s_idx - 1
    if n_between != 1:
        raise ValueError(
            f"Expected exactly one token between [s] and [e], got {n_between}."
        )
    v_idx = s_idx + 1

    patched = False

    def hook(module, inputs, output):
        nonlocal patched
        # output: [batch, seq_len, d_model]
        if patched or output.size(1) <= v_idx:
            return None  # None keeps the embedding output unchanged (no copy)
        out = output.clone()  # don't write into the module's output in place
        out[:, v_idx, :] = v
        patched = True
        return out

    handle = emb.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()  # always detach, even if generation raises


def explain_feature_with_llama(v: torch.Tensor,
                               layer: int,
                               template_id: int | None = None,
                               max_new_tokens: int = 96,
                               include_prompt: bool = True) -> str:
    """
    Generate a natural-language description of a feature vector.

    The prompt comes from ``make_feature_prompt`` (a random template when
    ``template_id`` is None). ``v`` is normalized with ``normalize_feature`` and
    injected as the embedding of the token inside ``[s]v[e]`` during greedy
    decoding (``do_sample=False``).

    Args:
        v: SAE feature direction in residual space, shape [hidden_size] (4096).
        layer: layer index ℓ for the textual prompt.
        template_id: optional index into ``FEATURE_TEMPLATES``.
        max_new_tokens: generation budget.
        include_prompt: if True (default, original behavior), return the decoded
            prompt followed by the continuation. If False, return only the newly
            generated text, which is handier when comparing to a reference
            description.

    Returns:
        Decoded text with special tokens removed (``skip_special_tokens=True``).
    """
    prompt = make_feature_prompt(layer, template_id)
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    v_norm = normalize_feature(v)

    with patch_at_v_between_s_e(model, input_ids, v_norm):
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens; slice them off on request.
    tokens = out[0] if include_prompt else out[0, input_ids.shape[1]:]
    return tokenizer.decode(tokens, skip_special_tokens=True)


In [33]:
# Source of FeatureExplanationDataset.
# NOTE(review): star import — prefer an explicit import once the used names are known.
from feature_description_dataset import *
from torch.utils.data import Dataset

import os

PROCESSED_ROOT = (
    "/mnt/raid10/ak-research-01/ak-research-01/codes/steer-vector/"
    "latentqa/1_feature_description/dataset/processed_dataset"
)

# Layers L01–L13, skipping L03.
# NOTE(review): the reason L03 is excluded is not recorded here.
layers = [layer for layer in range(1, 14) if layer != 3]


def _split_jsonl_paths(split: str) -> list[str]:
    """Per-layer JSONL paths for one split, e.g. ``<PROCESSED_ROOT>/L01/train.jsonl``."""
    return [
        os.path.join(PROCESSED_ROOT, f"L{layer:02d}", f"{split}.jsonl")
        for layer in layers
    ]


TRAIN_JSON_LIST = _split_jsonl_paths("train")
TEST_JSON_LIST = _split_jsonl_paths("test")

# Fail fast, listing every missing file up front.
missing_files = [p for p in TRAIN_JSON_LIST + TEST_JSON_LIST if not os.path.isfile(p)]
if missing_files:
    raise FileNotFoundError(f"Missing processed dataset files: {missing_files}")

MAX_LENGTH = 128   # passed to the dataset (presumably a token limit); increase if explanations are long

train_dataset = FeatureExplanationDataset(TRAIN_JSON_LIST, tokenizer, MAX_LENGTH)
eval_dataset = FeatureExplanationDataset(TEST_JSON_LIST, tokenizer, MAX_LENGTH)

In [20]:
# fake_v.shape

## Test the pretrained model

In [34]:
# Explain the first training feature with the base model and compare to its label.
sample = train_dataset.data[0]
explanation = explain_feature_with_llama(torch.tensor(sample["vector"]), sample["layer"])
print("Prompted explanation:\n", explanation)

print("\nTrue explanation:\n", sample["description"])

Prompted explanation:
 We can describe v at layer 1 as encoding 0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x0x

True explanation:
 modal verbs expressing ability or possibility


## Test the trained model

In [18]:
import gc
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

DATA_DIR = "/mnt/raid10/ak-research-01/ak-research-01/codes/steer-vector/latentqa/1_feature_description/model"
CKPT_DIR = os.path.join(DATA_DIR, "explainer_manual_ckpt")

# 0) Release any previously loaded model first, so two copies of the model are
#    not resident on the GPU(s) at the same time.
if "model" in globals():
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# 1) Load tokenizer and model from the checkpoint directory
# NOTE(review): loading this tokenizer printed a warning about an incorrect
# regex pattern (suggesting `fix_mistral_regex=True`); verify its tokenization
# matches the base Llama-3.1 tokenizer.
tokenizer = AutoTokenizer.from_pretrained(CKPT_DIR)
model = AutoModelForCausalLM.from_pretrained(
    CKPT_DIR,
    torch_dtype=torch.bfloat16,     # match what you used in training
    device_map="auto",              # or device="cuda"
)

# 2) Ensure pad token is set (needed for padding and generation)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

model.eval()   # for inference
device = model.device

# 3) Refresh the globals the helper functions read (dtype, marker ids) so they
#    match *this* checkpoint, not whatever was loaded earlier in the kernel.
dtype = next(model.parameters()).dtype
ID_S = tokenizer.convert_tokens_to_ids("[s]")
ID_E = tokenizer.convert_tokens_to_ids("[e]")
if any(tok_id is None or tok_id == tokenizer.unk_token_id for tok_id in (ID_S, ID_E)):
    raise ValueError(f"Checkpoint tokenizer in {CKPT_DIR} lacks the [s]/[e] marker tokens.")

print("Loaded explainer from:", CKPT_DIR)
print("Device:", device)


The tokenizer you are loading from '/mnt/raid10/ak-research-01/ak-research-01/codes/steer-vector/latentqa/1_feature_description/model/explainer_manual_ckpt' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded explainer from: /mnt/raid10/ak-research-01/ak-research-01/codes/steer-vector/latentqa/1_feature_description/model/explainer_manual_ckpt
Device: cuda:0


In [29]:
# Explain one held-out feature with the trained explainer and compare to its label.
idx = 72
sample = eval_dataset.data[idx]

explanation = explain_feature_with_llama(torch.tensor(sample["vector"]), sample["layer"])
print("Prompted explanation:\n", explanation)

print("\nTrue explanation:\n", sample["description"])

Prompted explanation:
 At layer 2, v encodes the name "Richard" and variations of it

True explanation:
 references to specific individuals, particularly those named "Chadwick."
