In [1]:
# %%
"""
Equivalency prompt experiment using hotflip nearest-neighbor.

Goal: Starting from a reference prompt P_REF, use only token-level hotflip
(local gradient-based replacements over nearest neighbors in embedding space)
to find an alternative prompt that induces *similar completions* on a suite
of suffix tasks.

Viability criteria (informal):
- Candidate prefix has non-trivial edit distance from P_REF.
- Candidate completions match P_REF completions with low CE and high token match.
"""

import math
import random
import re
from dataclasses import dataclass
from typing import List, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# Reproducibility
SEED = 4738
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = (
    torch.bfloat16
    if DEVICE == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float16 if DEVICE == "cuda" else torch.float32
)
print("DEVICE:", DEVICE, "DTYPE:", DTYPE)


DEVICE: cuda DTYPE: torch.bfloat16


In [2]:
# %%
MODEL_NAME = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map=None,
).to(DEVICE)
model.eval()

for p in model.parameters():
    p.requires_grad_(False)

MODEL_DTYPE = next(model.parameters()).dtype
DEVICE = next(model.parameters()).device

embed_layer = model.get_input_embeddings()

print("Loaded:", MODEL_NAME)
print("vocab_size:", model.config.vocab_size, "d_model:", model.config.hidden_size)


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Loaded: google/gemma-2-2b-it
vocab_size: 256000 d_model: 2304


In [3]:
# %%
# Reference prompt whose behavior we want to emulate with a different prefix
P_REF = "Talk only about cats."

# A small suite of suffix tasks to define “behavior”
SUFFIXES = [
    "Write exactly two sentences explaining a simple process.",
    "Output valid JSON with keys \"a\" and \"b\", values as integers.",
    "Ask one clarifying question before providing any answer.",
    "If unsure, say you are not sure and explain why.",
    "Respond in exactly three bullet points, no extra text.",
    "Explain a concept in a friendly, informal tone.",
    "If the task seems ambiguous, refuse politely and state the ambiguity.",
    "Respond in a neutral, formal tone.",
    "Explain the difference between speed and accuracy briefly.",
    "Ask one clarifying question, no additional commentary.",
]

max_n_suffixes = 8  # you can bump this up once things work
SUFFIXES = SUFFIXES[:max_n_suffixes]

GEN_KW = dict(
    do_sample=False,
    num_beams=1,
    max_new_tokens=60,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

MATCH_FIRST_K = 50
EARLY_K = 32
EARLY_WEIGHT = 3.0

P_REF_IDS = tokenizer(P_REF, add_special_tokens=False).input_ids
STEER_LEN = len(P_REF_IDS)
print("Reference tokens:", len(P_REF_IDS), P_REF_IDS[:10])
print("STEER_LEN:", STEER_LEN)


Reference tokens: 5 [27586, 1297, 1105, 19493, 235265]
STEER_LEN: 5


In [4]:
# %%
# Chat-template support (Gemma-IT expects chat formatting)

USE_CHAT_TEMPLATE = hasattr(tokenizer, "apply_chat_template")
SEP_TEXT = "\n\n"

def _find_subseq(hay: List[int], needle: List[int]) -> int:
    if not needle:
        return -1
    for i in range(0, len(hay) - len(needle) + 1):
        if hay[i : i + len(needle)] == needle:
            return i
    return -1

@dataclass
class ChatScaffold:
    gen_pre: List[int]
    gen_post: List[int]
    tf_pre: List[int]
    tf_between: List[int]
    tf_post: List[int]

def build_chat_scaffold(tok) -> ChatScaffold:
    if not getattr(tok, "apply_chat_template", None):
        raise RuntimeError("Tokenizer has no chat template; set USE_CHAT_TEMPLATE=False.")

    user_ph = "<<<USER_CONTENT_PLACEHOLDER>>>"
    asst_ph = "<<<ASSISTANT_CONTENT_PLACEHOLDER>>>"

    user_ph_ids = tok(user_ph, add_special_tokens=False).input_ids
    asst_ph_ids = tok(asst_ph, add_special_tokens=False).input_ids

    gen_ids = tok.apply_chat_template(
        [{"role": "user", "content": user_ph}],
        tokenize=True,
        add_generation_prompt=True,
    )
    if isinstance(gen_ids, torch.Tensor):
        gen_ids = gen_ids.tolist()
    u0 = _find_subseq(gen_ids, user_ph_ids)
    if u0 < 0:
        raise RuntimeError("Couldn't locate user placeholder in gen template.")
    gen_pre = gen_ids[:u0]
    gen_post = gen_ids[u0 + len(user_ph_ids) :]

    tf_ids = tok.apply_chat_template(
        [{"role": "user", "content": user_ph}, {"role": "assistant", "content": asst_ph}],
        tokenize=True,
        add_generation_prompt=False,
    )
    if isinstance(tf_ids, torch.Tensor):
        tf_ids = tf_ids.tolist()
    u1 = _find_subseq(tf_ids, user_ph_ids)
    a1 = _find_subseq(tf_ids, asst_ph_ids)
    if u1 < 0 or a1 < 0 or a1 <= u1:
        raise RuntimeError("Couldn't locate placeholders in TF template.")
    tf_pre = tf_ids[:u1]
    tf_between = tf_ids[u1 + len(user_ph_ids) : a1]
    tf_post = tf_ids[a1 + len(asst_ph_ids) :]

    return ChatScaffold(gen_pre=gen_pre, gen_post=gen_post,
                        tf_pre=tf_pre, tf_between=tf_between, tf_post=tf_post)

if USE_CHAT_TEMPLATE:
    CHAT = build_chat_scaffold(tokenizer)
    SEP_IDS = tokenizer(SEP_TEXT, add_special_tokens=False).input_ids
else:
    CHAT = None
    SEP_IDS = tokenizer(SEP_TEXT, add_special_tokens=False).input_ids


In [5]:
# %%
@torch.no_grad()
def greedy_generate_ids(prefix_ids: List[int], suffix: str, **gen_kw) -> str:
    """
    Greedy generate given a discrete prefix token list + suffix text.
    Uses chat template if available.
    """
    sfx_ids = tokenizer(suffix, add_special_tokens=False).input_ids

    if USE_CHAT_TEMPLATE:
        prompt_ids = CHAT.gen_pre + prefix_ids + SEP_IDS + sfx_ids + CHAT.gen_post
    else:
        prompt_ids = prefix_ids + SEP_IDS + sfx_ids

    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    attn = torch.ones_like(input_ids, dtype=torch.long)

    out = model.generate(input_ids=input_ids, attention_mask=attn, **gen_kw)
    gen_ids = out[0].tolist()[len(prompt_ids):]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)

@torch.no_grad()
def greedy_generate_text(prefix_text: str, suffix: str, **gen_kw) -> str:
    prefix_ids = tokenizer(prefix_text, add_special_tokens=False).input_ids
    return greedy_generate_ids(prefix_ids, suffix, **gen_kw)

# Build deterministic reference outputs for P_REF
Y_REF: Dict[str, str] = {sfx: greedy_generate_text(P_REF, sfx, **GEN_KW) for sfx in SUFFIXES}

for sfx in SUFFIXES[:2]:
    print("SUFFIX:", sfx)
    print("Y_REF:", Y_REF[sfx][:200].replace("\n","\\n"), "...\n")


SUFFIX: Write exactly two sentences explaining a simple process.
Y_REF: Cats groom themselves by licking their fur, removing dirt and loose hairs.  This process helps keep their coat healthy and shiny. \n ...

SUFFIX: Output valid JSON with keys "a" and "b", values as integers.
Y_REF: ```json\n{\n  "a": 3,\n  "b": 17\n}\n``` ...



In [6]:
# %%
@dataclass
class Example:
    suffix: str
    ref_completion: str
    suffix_ids: torch.Tensor
    completion_ids: torch.Tensor

def build_example(suffix: str, completion: str) -> Example:
    sfx_ids = torch.tensor(tokenizer(suffix, add_special_tokens=False).input_ids, dtype=torch.long)
    comp_ids = torch.tensor(tokenizer(completion, add_special_tokens=False).input_ids, dtype=torch.long)
    return Example(suffix=suffix, ref_completion=completion,
                   suffix_ids=sfx_ids, completion_ids=comp_ids)

examples: List[Example] = [build_example(sfx, Y_REF[sfx]) for sfx in SUFFIXES]
max_len = max((ex.suffix_ids.numel() + ex.completion_ids.numel()) for ex in examples)
print("Num examples:", len(examples), "Max suffix+completion length:", max_len)


Num examples: 8 Max suffix+completion length: 72


In [7]:
# %%
def pad_1d(seqs: List[torch.Tensor], pad_value: int) -> torch.Tensor:
    max_len = max(x.numel() for x in seqs)
    out = torch.full((len(seqs), max_len), pad_value, dtype=seqs[0].dtype, device=seqs[0].device)
    for i, x in enumerate(seqs):
        out[i, : x.numel()] = x
    return out

def make_batch_with_prefix_ids(prefix_ids: torch.Tensor, batch: List[Example]) -> Dict[str, torch.Tensor]:
    """
    Teacher-forcing batch:
      input_ids = [chat_pre] + prefix_ids + [SEP] + suffix_ids + [chat_between] + completion_ids + [chat_post]
    labels = -100 except completion tokens.
    loss_weights emphasize early completion tokens.
    """
    device = DEVICE
    prefix_ids = prefix_ids.to(device)

    pre = torch.tensor(CHAT.tf_pre if USE_CHAT_TEMPLATE else [], dtype=torch.long, device=device)
    between = torch.tensor(CHAT.tf_between if USE_CHAT_TEMPLATE else [], dtype=torch.long, device=device)
    post = torch.tensor(CHAT.tf_post if USE_CHAT_TEMPLATE else [], dtype=torch.long, device=device)
    sep = torch.tensor(SEP_IDS, dtype=torch.long, device=device)

    input_ids_list, labels_list, attn_list, w_list = [], [], [], []

    for ex in batch:
        sfx = ex.suffix_ids.to(device)
        comp = ex.completion_ids.to(device)

        ids = torch.cat([pre, prefix_ids, sep, sfx, between, comp, post], dim=0)

        labels = torch.full_like(ids, -100)
        weights = torch.zeros_like(ids, dtype=torch.float)

        comp_start = pre.numel() + prefix_ids.numel() + sep.numel() + sfx.numel() + between.numel()
        labels[comp_start : comp_start + comp.numel()] = comp

        for i in range(comp.numel()):
            w = EARLY_WEIGHT if i < EARLY_K else 1.0
            weights[comp_start + i] = w

        attn = torch.ones_like(ids, dtype=torch.long)

        input_ids_list.append(ids)
        labels_list.append(labels)
        attn_list.append(attn)
        w_list.append(weights)

    input_ids = pad_1d(input_ids_list, pad_value=tokenizer.pad_token_id)
    labels = pad_1d(labels_list, pad_value=-100)
    attention_mask = pad_1d(attn_list, pad_value=0)
    loss_weights = pad_1d(w_list, pad_value=0).to(torch.float)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "loss_weights": loss_weights,
    }


In [8]:
# %%
def weighted_teacher_forced_ce(prefix_ids: torch.Tensor, batch: List[Example]) -> torch.Tensor:
    b = make_batch_with_prefix_ids(prefix_ids, batch)
    out = model(input_ids=b["input_ids"], attention_mask=b["attention_mask"])
    logits = out.logits[:, :-1, :].contiguous()
    labels = b["labels"][:, 1:].contiguous()
    weights = b["loss_weights"][:, 1:].contiguous()

    per = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
        reduction="none",
    ).view(labels.shape)

    mask = (labels != -100).float()
    denom = (weights * mask).sum().clamp_min(1.0)
    return (per * weights * mask).sum() / denom

@torch.no_grad()
def batch_ce(prefix_ids: List[int], batch: List[Example]) -> float:
    p = torch.tensor(prefix_ids, dtype=torch.long, device=DEVICE)
    return float(weighted_teacher_forced_ce(p, batch).item())


In [9]:
# %%
def levenshtein(a: List[int], b: List[int]) -> int:
    n, m = len(a), len(b)
    dp = list(range(m + 1))
    for i in range(1, n + 1):
        prev = dp[0]
        dp[0] = i
        for j in range(1, m + 1):
            cur = dp[j]
            cost = 0 if a[i - 1] == b[j - 1] else 1
            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
            prev = cur
    return dp[m]

@dataclass
class PromptConstraints:
    ref_ids: List[int]
    min_edit_distance: int = 4
    forbid_ref_substring: bool = True
    max_len_chars: int = 80
    require_printable: bool = True

def constraints_ok(prefix_ids: List[int], cons: PromptConstraints) -> bool:
    ref_ids = cons.ref_ids

    if levenshtein(prefix_ids, ref_ids) < cons.min_edit_distance:
        return False

    s = tokenizer.decode(prefix_ids, skip_special_tokens=True)
    if cons.forbid_ref_substring and ref_ids:
        ref_s = tokenizer.decode(ref_ids, skip_special_tokens=True).strip()
        if ref_s and ref_s in s:
            return False

    if not s.strip():
        return False
    if cons.require_printable and any(not ch.isprintable() for ch in s):
        return False
    if len(s) > cons.max_len_chars:
        return False

    return True



In [10]:
# %%
def compute_prefix_grads(prefix_ids: List[int], batch: List[Example]) -> torch.Tensor:
    """
    Gradient of loss w.r.t. the *prefix positions* in embedding space.
    Returns [L, d].
    """
    device = DEVICE
    p = torch.tensor(prefix_ids, dtype=torch.long, device=device)
    b = make_batch_with_prefix_ids(p, batch)

    inputs_embeds = embed_layer(b["input_ids"]).detach().to(dtype=MODEL_DTYPE)
    inputs_embeds.requires_grad_(True)

    out = model(inputs_embeds=inputs_embeds, attention_mask=b["attention_mask"])
    logits = out.logits[:, :-1, :]
    labels = b["labels"][:, 1:]
    weights = b["loss_weights"][:, 1:]

    per = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view_as(labels)

    mask = (labels != -100).float()
    loss = (per * weights * mask).sum() / (weights * mask).sum().clamp_min(1.0)

    if inputs_embeds.grad is not None:
        inputs_embeds.grad.zero_()
    loss.backward()

    pre_len = len(CHAT.tf_pre) if USE_CHAT_TEMPLATE else 0
    L = len(prefix_ids)
    g = inputs_embeds.grad[:, pre_len : pre_len + L, :].mean(dim=0)  # [L, d]
    return g.detach()

@torch.no_grad()
def hotflip_candidates(
    grad: torch.Tensor,     # [d]
    current_id: int,
    top_k: int = 32,
    banned_token_ids: Optional[set] = None,
) -> List[int]:
    """
    Hotflip: scores token replacements by grad · (e_new - e_cur) ~= grad · e_new.
    We take the K best (smallest score if we treat it as loss-increasing direction).
    """
    if banned_token_ids is None:
        banned_token_ids = set()

    w = embed_layer.weight.detach()   # [V, d]
    g = grad.float()
    w_f = w.float()
    scores = torch.mv(w_f, g)  # [V]

    scores[current_id] = float("inf")
    for t in banned_token_ids:
        if 0 <= t < scores.numel():
            scores[t] = float("inf")

    _, idx = torch.topk(scores, k=top_k, largest=False)
    return idx.tolist()


In [11]:
# %%
def refine_discrete_prefix(
    init_prefix_ids: List[int],
    batch_examples: List[Example],
    cons: PromptConstraints,
    passes: int = 6,
    hotflip_top_k: int = 64,
    eval_top_k_per_pos: int = 12,
    banned_token_ids: Optional[set] = None,
    temperature: float = 0.3,
) -> List[int]:
    """
    Stochastic refinement with HotFlip proposals.

    - At each position, get top-k HotFlip candidates.
    - Filter by constraints and banned_token_ids.
    - Move *nondeterministically* to one of the nearest neighbors
      (sample from candidates, biased toward lower loss).
    - If the current token is banned, we *force* a move away from it.
    - We keep track of the best prefix seen (by CE), but don't require
      each individual move to be loss-decreasing.

    Args:
        init_prefix_ids: starting prefix token IDs.
        batch_examples: teacher-forcing dataset.
        cons: PromptConstraints object.
        passes: number of full passes over positions.
        hotflip_top_k: how many neighbors to consider per position.
        eval_top_k_per_pos: how many of the top-k to actually evaluate.
        banned_token_ids: token IDs we are never allowed to end up with.
        temperature: softmax temperature for sampling neighbors
                    (smaller => more greedy; larger => more random).
    """
    if banned_token_ids is None:
        banned_token_ids = set()

    prefix = init_prefix_ids[:]
    # Current loss and best-so-far
    cur_loss = batch_ce(prefix, batch_examples)
    best_loss = cur_loss
    best_prefix = prefix[:]
    print("Init CE:", cur_loss)

    for p_i in range(passes):
        improved_this_pass = False

        # Randomize position order each pass for more exploration
        positions = list(range(len(prefix)))
        random.shuffle(positions)

        # Compute grads once per pass
        grads = compute_prefix_grads(prefix, batch_examples)  # [L, d]

        for pos in positions:
            cur_id = prefix[pos]

            cand_ids = hotflip_candidates(
                grads[pos],
                cur_id,
                top_k=hotflip_top_k,
                banned_token_ids=banned_token_ids,
            )
            shortlist = cand_ids[:eval_top_k_per_pos]

            candidates: List[Tuple[int, float]] = []  # (token, loss)

            for tok in shortlist:
                trial = prefix[:]
                trial[pos] = tok
                if not constraints_ok(trial, cons):
                    continue
                loss = batch_ce(trial, batch_examples)
                candidates.append((tok, loss))

            if not candidates:
                continue

            # If current token is banned, we MUST move off it.
            # Choose the candidate with lowest loss deterministically.
            if cur_id in banned_token_ids:
                tok_new, loss_new = min(candidates, key=lambda x: x[1])
            else:
                # Stochastic choice among neighbors, biased toward lower loss.
                losses = torch.tensor([c[1] for c in candidates], dtype=torch.float32)
                # Convert to "scores": lower loss => higher score
                scores = -losses / max(temperature, 1e-6)
                probs = torch.softmax(scores, dim=0).cpu().numpy()
                idx = int(np.random.choice(len(candidates), p=probs))
                tok_new, loss_new = candidates[idx]

            # Apply the move
            if tok_new != cur_id:
                prefix[pos] = tok_new
                cur_loss = loss_new
                print(f"[pass {p_i+1}/{passes}] pos {pos:2d} -> loss {cur_loss:.7f} prompt >{tokenizer.decode(prefix)}<")

                # Track global best
                if cur_loss < best_loss:
                    best_loss = cur_loss
                    best_prefix = prefix[:]
                    improved_this_pass = True

        if not improved_this_pass:
            print(f"[pass {p_i+1}/{passes}] no improvement in best; stopping.")
            break
        else:
            print(f"[pass {p_i+1}/{passes}] best CE so far: {best_loss:.7f}")

    return best_prefix


In [12]:
# %%
@torch.no_grad()
def eval_prefix_equivalence(
    prefix_ids: List[int],
    suffixes: List[str],
    y_ref: Dict[str, str],
) -> Dict[str, float]:
    """
    Compare completions under prefix_ids vs reference completions y_ref.
    """
    batch = [build_example(s, y_ref[s]) for s in suffixes]
    p = torch.tensor(prefix_ids, dtype=torch.long, device=DEVICE)
    ce = float(weighted_teacher_forced_ce(p, batch).item())

    exact = 0
    exact_k = 0
    tok_match_sum = 0
    tok_total = 0
    tok_match_k_sum = 0
    tok_total_k = 0

    for s in suffixes:
        cand = greedy_generate_ids(prefix_ids, s, **GEN_KW)
        ref = y_ref[s]
        if cand == ref:
            exact += 1

        c_ids = tokenizer(cand, add_special_tokens=False).input_ids
        r_ids = tokenizer(ref, add_special_tokens=False).input_ids

        L = min(len(c_ids), len(r_ids))
        tok_match_sum += sum(1 for i in range(L) if c_ids[i] == r_ids[i])
        tok_total += max(1, len(r_ids))

        K = min(MATCH_FIRST_K, len(r_ids), len(c_ids))
        tok_match_k_sum += sum(1 for i in range(K) if c_ids[i] == r_ids[i])
        tok_total_k += max(1, min(MATCH_FIRST_K, len(r_ids)))

        if c_ids[:K] == r_ids[:K] and K > 0:
            exact_k += 1

    n = max(1, len(suffixes))
    return {
        "ce_loss": ce,
        "exact_match_rate": exact / n,
        "exact_match_rate_firstk": exact_k / n,
        "token_match_rate": tok_match_sum / max(1, tok_total),
        "token_match_rate_firstk": tok_match_k_sum / max(1, tok_total_k),
    }


In [13]:
# %%
def build_banned_token_ids(phrases: List[str]) -> set:
    """
    Build a set of token IDs that should never be used in the steering prefix,
    by tokenizing several simple variants of each phrase.
    """
    variants: List[str] = []
    for p in phrases:
        if not p:
            continue
        variants.extend([
            p,
            " " + p,
            "\n" + p,
            "\n\n" + p,
            p.lower(),
            p.upper(),
        ])

    banned = set()
    for v in variants:
        ids = tokenizer(v, add_special_tokens=False).input_ids
        banned.update(ids)
    return banned


In [14]:
import numpy as np

# %%
constraints = PromptConstraints(
    ref_ids=P_REF_IDS,
    min_edit_distance=max(4, len(P_REF_IDS)),
    forbid_ref_substring=True,
)

# %%
def run_hotflip_equivalence_experiment(
    n_restarts: int = 5,
    passes: int = 8,
    hotflip_top_k: int = 64,
    eval_top_k_per_pos: int = 12,
):
    print("=== Baseline (reference prompt) ===")
    base_stats = eval_prefix_equivalence(P_REF_IDS, SUFFIXES, Y_REF)
    print("Reference text:", repr(P_REF))
    print("CE:", base_stats["ce_loss"])
    print("exact:", base_stats["exact_match_rate"],
          "exact@K:", base_stats["exact_match_rate_firstk"],
          "tok@K:", base_stats["token_match_rate_firstk"])
    print()

    banned_token_ids = build_banned_token_ids([P_REF, "cat", "cats", "talk"])

    best_prefix = None
    best_stats = None
    ce_sum = 0

    for r in range(n_restarts):
        print(f"\n=== Restart {r+1}/{n_restarts} ===")

        # NEW: start from a fully random prefix (discrete tokens), not P_REF.
        # Keep sampling until it satisfies constraints.
        while True:
            start = [
                random.randint(0, tokenizer.vocab_size - 1)
                for _ in range(STEER_LEN)
            ]
            if constraints_ok(start, constraints):
                break

        refined = refine_discrete_prefix(
            init_prefix_ids=start,
            batch_examples=examples,
            cons=constraints,
            passes=passes,
            hotflip_top_k=hotflip_top_k,
            eval_top_k_per_pos=eval_top_k_per_pos,
            banned_token_ids=banned_token_ids,
        )

        stats = eval_prefix_equivalence(refined, SUFFIXES, Y_REF)
        text = tokenizer.decode(refined, skip_special_tokens=True)

        ce_sum += stats["ce_loss"]

        print("\nCandidate prompt:", repr(text))
        print("Edit distance to ref:", levenshtein(refined, P_REF_IDS))
        print("CE:", stats["ce_loss"])
        print("exact:", stats["exact_match_rate"],
              "exact@K:", stats["exact_match_rate_firstk"],
              "tok@K:", stats["token_match_rate_firstk"])

        if best_stats is None or stats["ce_loss"] < best_stats["ce_loss"]:
            best_stats = stats
            best_prefix = refined


    print("\n=== Best candidate over restarts ===")
    if best_prefix is None:
        print("No candidate found.")
        return

    print(f"Average CE: {ce_sum/n_restarts:.4f}")
    best_text = tokenizer.decode(best_prefix, skip_special_tokens=True)
    print("Best prompt text:", repr(best_text))
    print("Edit distance to ref:", levenshtein(best_prefix, P_REF_IDS))
    print("CE:", best_stats["ce_loss"])
    print("exact:", best_stats["exact_match_rate"],
          "exact@K:", best_stats["exact_match_rate_firstk"],
          "tok@K:", best_stats["token_match_rate_firstk"])




In [15]:
# %%
# Run the experiment
run_hotflip_equivalence_experiment(
    n_restarts=10,
    passes=8,
    hotflip_top_k=64,
    eval_top_k_per_pos=32,
)


=== Baseline (reference prompt) ===
Reference text: 'Talk only about cats.'
CE: 0.3745701313018799
exact: 1.0 exact@K: 1.0 tok@K: 1.0


=== Restart 1/10 ===
Init CE: 1.1555033922195435
[pass 1/8] best CE so far: 0.8676822
[pass 2/8] pos  3 -> loss 0.8700951 prompt >CatsrisasSEScoded willst<
[pass 2/8] pos  0 -> loss 1.1996185 prompt >ActionCreatorsrisasSEScoded willst<
[pass 2/8] no improvement in best; stopping.

Edit distance to ref: 5
CE: 0.8676822185516357
exact: 0.0 exact@K: 0.0 tok@K: 0.09210526315789473

=== Restart 2/10 ===
Init CE: 1.1815112829208374
[pass 1/8] pos  3 -> loss 1.0871129 prompt > reag bystekot pa Pasa<
[pass 1/8] pos  2 -> loss 1.1172889 prompt > reag bysteftagPool pa Pasa<
[pass 1/8] pos  1 -> loss 1.2041976 prompt > reagcolgroupftagPool pa Pasa<
[pass 1/8] pos  0 -> loss 1.1167033 prompt > '{@colgroupftagPool pa Pasa<
[pass 1/8] pos  4 -> loss 0.8022580 prompt > '{@colgroupftagPool pa kitty<
[pass 1/8] best CE so far: 0.8022580
[pass 2/8] pos  0 -> loss 0.8080

In [17]:
y_ref_final = {s: greedy_generate_text(P_REF, s, **GEN_KW) for s in SUFFIXES}
for s in SUFFIXES:
    ref = y_ref_final[s]
    out = greedy_generate_ids(cand.prefix_ids, s, **GEN_KW)
    print("\nSUFFIX:", s)
    print("SAME?", out == ref)
    print("REF :", ref[:240].replace("\n","\\n"), "...")
    print("CAND:", out[:240].replace("\n","\\n"), "...")

KeyboardInterrupt: 