# Gemma-2-2B-IT: Reference-State Matching Notebook

This notebook:

- Loads **gemma-2-2b-it**.
- Uses a **reference prompt** (default `"Talk about cats."`).
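- Optimizes a continuous **soft prompt** so that its hidden states at a chosen layer match those from the reference prompt, optionally penalizing similarity to a set of banned words.
- Projects each soft-prompt vector back to its nearest vocabulary token (by cosine similarity) to obtain a discrete prompt, ideally a *different* one.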
- Compares internal state similarity between the reference prompt, the continuous soft prompt, and the discrete alternative prompt.


In [1]:
import math
from typing import List, Tuple

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

In [2]:
def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
print("Using device:", device)

In [3]:
def load_model_and_tokenizer(
    model_name: str = "google/gemma-2-2b-it",
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Load a causal LM (default: Gemma-2 2B IT) and its tokenizer.

    The model is loaded in float32 and put in eval mode. If the tokenizer
    has no pad token, the EOS token is reused as the pad token.
    Gemma checkpoints are gated: the Hugging Face license must have been
    accepted for this download to succeed.
    """
    print(f"Loading model '{model_name}'...")
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=AutoConfig.from_pretrained(model_name),
        # full precision keeps things simple and sidesteps bfloat16 issues
        torch_dtype=torch.float32,
    )
    tok = AutoTokenizer.from_pretrained(model_name)

    # Some tokenizers ship without a pad token; fall back to EOS.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    lm.eval()
    return lm, tok


model, tokenizer = load_model_and_tokenizer("google/gemma-2-2b-it")
model.to(device)

In [6]:
# Surface-form variants of each word of the reference prompt: three casings
# (Title, lower, UPPER), bare or with a leading/trailing space, and each of
# those optionally followed by a period (54 strings in total). Kept for
# experimentation: pass `banned_word_variants` as `banned_words` to penalize
# the individual words instead of the whole phrase.
_reference_words = ("talk", "about", "cats")
_affixes = (
    ("", ""), (" ", ""), ("", " "),      # bare / leading space / trailing space
    ("", "."), (" ", "."), ("", " ."),   # the same, followed by a period
)
banned_word_variants = [
    f"{prefix}{form}{suffix}"
    for word in _reference_words
    for form in (word.capitalize(), word, word.upper())
    for prefix, suffix in _affixes
]

# Ban the reference prompt as ONE phrase. A list keeps this unambiguous:
# iterating a bare string would yield its individual characters instead.
banned_words = ["Talk about cats."]

banned_weight = 1

In [7]:
class SoftPromptMatcher(nn.Module):
    """
    Learn a soft prompt whose hidden states at a given layer
    match those of a reference prompt as closely as possible,
    while discouraging similarity to some banned words.

    The soft prompt is a trainable ``[prompt_length, d_model]`` parameter fed
    to the model through ``inputs_embeds``; ``forward()`` returns the hidden
    states at ``layer_idx`` together with the training loss.

    Args:
        model: causal LM; must expose ``get_input_embeddings()`` and return
            ``hidden_states`` when called with ``output_hidden_states=True``.
        tokenizer: used only to tokenize ``banned_words``.
        d_model: width of each soft-prompt vector.
        prompt_length: number of soft-prompt positions.
        layer_idx: index into ``outputs.hidden_states`` whose states are matched.
        H_ref: reference hidden states, ``[T_ref, d_model]``.
        device: device holding the soft prompt and the buffers.
        banned_words: one phrase (``str``) or an iterable of phrases. Each
            phrase is represented by the unit-normalized mean of its token
            embeddings; phrases that tokenize to nothing are skipped. A bare
            string is treated as a single phrase, not as a sequence of
            characters.
        banned_weight: scale of the banned-word penalty; ``0`` disables it.
        freeze_model: if True (default), call ``model.requires_grad_(False)``.
            Only the soft prompt is optimized, so weight gradients are never
            used; freezing avoids computing them and holding a ``.grad``
            buffer for every model parameter.
    """

    def __init__(
        self,
        model: "AutoModelForCausalLM",
        tokenizer: "AutoTokenizer",
        d_model: int,
        prompt_length: int,
        layer_idx: int,
        H_ref: torch.Tensor,
        device: torch.device,
        banned_words=None,
        banned_weight: float = 0.0,
        freeze_model: bool = True,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.layer_idx = layer_idx
        self.device = device

        if freeze_model:
            # The optimizer only sees the soft prompt; without this every
            # backward() also fills (and accumulates) .grad for all weights.
            self.model.requires_grad_(False)

        # Reference states: [T_ref, d_model]
        self.register_buffer("H_ref", H_ref.to(device))

        self.prompt_length = prompt_length
        self.banned_weight = banned_weight

        # Initialize soft prompt with small random noise
        self.soft_prompt = nn.Parameter(
            torch.randn(prompt_length, d_model, device=device) * 0.02
        )

        # [n_banned, d] unit vectors, or None when there is nothing to ban.
        self.register_buffer("banned_emb", self._embed_banned_phrases(banned_words))

    def _embed_banned_phrases(self, banned_words):
        """Return unit-normalized mean token embeddings [n_banned, d], or None."""
        if not banned_words:
            return None
        if isinstance(banned_words, str):
            # A bare string is one phrase; iterating it would yield characters.
            banned_words = [banned_words]

        token_embed = self.model.get_input_embeddings()
        phrase_vecs = []
        for w in banned_words:
            ids = self.tokenizer(w, add_special_tokens=False).input_ids
            if not ids:
                continue
            # average subtoken embeddings for this word/phrase
            phrase_vecs.append(
                token_embed.weight[ids].detach().to(self.device).mean(dim=0)
            )

        if not phrase_vecs:
            return None
        banned_emb = torch.stack(phrase_vecs, dim=0)  # [n_banned, d]
        return banned_emb / (banned_emb.norm(dim=-1, keepdim=True) + 1e-8)

    def forward(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the model on the soft prompt.

        Returns:
            H_soft: ``[T_soft, d_model]`` hidden states at ``layer_idx``
                (not cropped).
            loss: mean squared error between unit-normalized soft and
                reference states over the first ``min(T_soft, T_ref)``
                positions, plus ``banned_weight`` times the mean squared dot
                product between the position-averaged unit soft states and
                each unit banned embedding.

        NOTE(review): the penalty compares layer-``layer_idx`` hidden states
        with input-embedding vectors; presumably this is most meaningful near
        the embedding layer — confirm for deeper layers.
        """
        # [1, T, d_model]
        inputs_embeds = self.soft_prompt.unsqueeze(0)

        outputs = self.model(
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
        )
        hs_layer = outputs.hidden_states[self.layer_idx]  # [1, T, d_model]
        H_soft = hs_layer[0]                              # [T, d_model]

        # If lengths differ, compare only the overlapping prefix.
        T = min(H_soft.shape[0], self.H_ref.shape[0])
        H_soft_sel = H_soft[:T]       # [T, d]
        H_ref_sel = self.H_ref[:T]    # [T, d]

        # Normalize each token vector before comparing (cosine-ish)
        H_soft_norm = H_soft_sel / (H_soft_sel.norm(dim=-1, keepdim=True) + 1e-8)
        H_ref_norm = H_ref_sel / (H_ref_sel.norm(dim=-1, keepdim=True) + 1e-8)

        # Base loss: MSE between normalized states
        mse_loss = ((H_soft_norm - H_ref_norm) ** 2).mean()

        # Banned-word penalty: avoid average hidden state aligning with banned embeddings
        penalty = 0.0
        if self.banned_emb is not None and self.banned_weight > 0.0:
            h_mean = H_soft_norm.mean(dim=0)                        # [d]
            sim_to_banned = torch.matmul(self.banned_emb, h_mean)   # [n_banned]
            # penalize large |similarity|^2
            penalty = self.banned_weight * sim_to_banned.pow(2).mean()

        loss = mse_loss + penalty

        return H_soft, loss

In [8]:
# Show the reference sequence length in the notebook output.
# NOTE(review): T_ref is not assigned in any cell visible here — presumably it
# is set in an earlier cell (In [4]/[5]) that is not shown; confirm.
T_ref

In [9]:
# One soft-prompt position per reference token, so every reference state has
# a counterpart to be matched against.
prompt_length = T_ref
matcher = SoftPromptMatcher(
    model=model,
    tokenizer=tokenizer,
    H_ref=H_ref,
    d_model=d_model,
    prompt_length=prompt_length,
    layer_idx=layer_idx,
    banned_words=banned_words,
    banned_weight=banned_weight,
    device=device,
).to(device)

steps = 1
lr = 1e-2

# Only the soft prompt is handed to the optimizer.
optimizer = torch.optim.Adam([matcher.soft_prompt], lr=lr)

history = []

for step in range(steps):
    optimizer.zero_grad()
    _, loss = matcher()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    history.append(loss_value)

    # Log roughly ten times over the run, and always on the final step.
    is_log_step = step % max(1, steps // 10) == 0
    is_last_step = step == steps - 1
    if is_log_step or is_last_step:
        print(f"[step {step:4d}] loss={loss_value:.6f}")

In [10]:
@torch.no_grad()
def project_soft_prompt_to_tokens(
    soft_prompt: torch.Tensor,
    model: "AutoModelForCausalLM",
    tokenizer: "AutoTokenizer",
    top_k: int = 1,
) -> List[int]:
    """
    For each soft embedding vector, pick the nearest token embedding by cosine similarity.

    soft_prompt: [T, d]; moved (and cast) to the embedding table's device/dtype.
    model: supplies the input-embedding table via ``get_input_embeddings()``.
    tokenizer: unused; kept for signature compatibility.
    top_k: kept for backward compatibility only — the returned id is always
        the single best match, whatever its value.
    Returns: list of T token IDs.
    """
    token_embed = model.get_input_embeddings().weight  # [V, d]
    soft = soft_prompt.to(device=token_embed.device, dtype=token_embed.dtype)

    # Cosine similarity without building a normalized copy of the whole
    # [V, d] table: normalize the T soft vectors, do a single [V, T] matmul,
    # then divide each row by that token's norm.
    soft_norm = soft / (soft.norm(dim=-1, keepdim=True) + 1e-8)      # [T, d]
    token_norm = token_embed.norm(dim=-1, keepdim=True) + 1e-8      # [V, 1]
    sims = torch.matmul(token_embed, soft_norm.T) / token_norm      # [V, T]

    return sims.argmax(dim=0).tolist()


# Snap the learned soft prompt to its nearest vocabulary tokens and decode
# them, dropping special tokens from the resulting string.
soft_prompt_learned = matcher.soft_prompt.detach().cpu()
alt_token_ids = project_soft_prompt_to_tokens(soft_prompt_learned, model, tokenizer, top_k=1)
alt_prompt_str = tokenizer.decode(alt_token_ids, skip_special_tokens=True)

prompts_match = reference_prompt == alt_prompt_str
print("Reference prompt:", repr(reference_prompt))
print("Alternative prompt:", repr(alt_prompt_str))
print("Are they exactly equal?", prompts_match)

In [11]:
@torch.no_grad()
def hidden_sequence_for_prompt(
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    layer_idx: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Tokenize *prompt* (default special tokens, no padding, truncation on),
    run the model once, and return the [T, d] hidden states found at
    ``outputs.hidden_states[layer_idx]`` for the single sequence.
    """
    input_ids = tokenizer(
        prompt, return_tensors="pt", padding=False, truncation=True
    )["input_ids"].to(device)
    per_layer = model(input_ids=input_ids, output_hidden_states=True).hidden_states
    return per_layer[layer_idx][0].detach()


def sequence_cosine_similarity(
    H_a: torch.Tensor,
    H_b: torch.Tensor,
) -> float:
    """
    Mean per-position cosine similarity of two [T, d] sequences.

    Only the first ``min(len(H_a), len(H_b))`` positions are compared; a
    1e-8 term in the denominators guards against zero-norm rows.
    """
    def _unit_rows(x: torch.Tensor) -> torch.Tensor:
        return x / (x.norm(dim=-1, keepdim=True) + 1e-8)

    n = min(H_a.shape[0], H_b.shape[0])
    per_position = (_unit_rows(H_a[:n]) * _unit_rows(H_b[:n])).sum(dim=-1)  # [n]
    return per_position.mean().item()


# 1) Reference hidden states (computed earlier), placed on the working device.
H_ref_eval = H_ref.to(device)

# 2) States produced by the optimized continuous soft prompt; no graph needed.
with torch.no_grad():
    H_soft_continuous = matcher()[0].detach()  # [T_soft, d]

# 3) States produced by re-tokenizing the discrete alternative prompt.
H_alt_discrete = hidden_sequence_for_prompt(
    alt_prompt_str, model, tokenizer, layer_idx=layer_idx, device=device
)

sim_ref_soft, sim_ref_alt = (
    sequence_cosine_similarity(H_ref_eval, states)
    for states in (H_soft_continuous, H_alt_discrete)
)

print("Average cosine similarity with reference states:")
print(f"  Continuous soft prompt: {sim_ref_soft:.4f}")
print(f"  Discrete alternative prompt: {sim_ref_alt:.4f}")

In [12]:
random_prompt = "This is a totally random unrelated sentence about weather."

# Control comparison: an unrelated sentence gives a rough baseline against
# which the soft / alternative similarities can be judged.
sim_ref_random = sequence_cosine_similarity(
    H_ref_eval,
    hidden_sequence_for_prompt(
        random_prompt, model, tokenizer, layer_idx=layer_idx, device=device
    ),
)
H_random = hidden_sequence_for_prompt(
    random_prompt, model, tokenizer, layer_idx=layer_idx, device=device
) if False else None

In [16]:
# If you already defined generate_from_prompt earlier, you can delete this block.
@torch.no_grad()
def generate_from_prompt_local(
    prompt: str,
    model: "AutoModelForCausalLM",
    tokenizer: "AutoTokenizer",
    device: torch.device,
    max_new_tokens: int = 80,
    temperature: float = 0.8,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 3,
) -> str:
    """
    Sample a continuation of *prompt* and return the decoded full sequence.

    The returned text contains the prompt followed by up to
    ``max_new_tokens`` sampled tokens, with special tokens removed.
    If *prompt* already begins with the tokenizer's BOS text (as produced by
    ``tokenizer.decode(ids)`` without ``skip_special_tokens=True``), special
    tokens are not added again, so the model never sees a doubled BOS.
    """
    bos = getattr(tokenizer, "bos_token", None)
    already_has_bos = bool(bos) and prompt.startswith(bos)
    enc = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=not already_has_bos,
    ).to(device)
    out_ids = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)

def project_state_to_tokens():
    """
    Decode the current module-level ``matcher``'s soft prompt to text.

    Reads the global ``matcher``, ``model`` and ``tokenizer``. Each soft
    vector is snapped to its nearest token and the ids are decoded WITHOUT
    skipping special tokens, so any special token (e.g. ``<bos>``) stays in
    the returned string.
    """
    snapshot = matcher.soft_prompt.detach().cpu()
    nearest_ids = project_soft_prompt_to_tokens(snapshot, model, tokenizer, top_k=1)
    return tokenizer.decode(nearest_ids)

# print("=== Reference prompt ===")
# print("Prompt:", repr(reference_prompt))
# ref_gen = generate_from_prompt_local(
#     reference_prompt,
#     model,
#     tokenizer,
#     device=device,
# )
# print("Generation:\n", ref_gen)

# print("\n=== Alternative (matched) prompt ===")
# print("Prompt:", repr(alt_prompt_str))
# alt_gen = generate_from_prompt_local(
#     alt_prompt_str,
#     model,
#     tokenizer,
#     device=device,
# )
# print("Generation:\n", alt_gen)
# print(alt_gen == ref_gen)


In [None]:
# === Layer sweep with fixed seeds, early stopping, and generation ===

import random
import numpy as np
import torch

target_layers = range(0, 26, 2)

steps = 1000
lr = 1e-2

n_tokens_to_generate = 80      # how many *new* tokens to generate
lower_loss_threshold = 1e-4    # early stopping threshold

layer_prompts = {}
layer_histories = {}

# Global base seed – controls everything
base_seed = 4738


def _seed_everything(seed):
    """Seed Python's `random`, NumPy and torch (CPU and every CUDA device)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _strip_leading_bos(text):
    """Drop one leading BOS string; the tokenizer adds its own BOS when encoding."""
    bos = tokenizer.bos_token
    if bos and text.startswith(bos):
        return text[len(bos):]
    return text


# Projected prompts are decoded WITHOUT skipping special tokens, so decode the
# reference the same way (e.g. "<bos>Talk about cats.") for the early-stop check.
reference_decoded = tokenizer.decode(tokenizer(reference_prompt).input_ids)

for layer_idx in target_layers:
    # ----- set seeds so this run is fully reproducible -----
    # If you want the *same* seed for every layer, just use base_seed
    # instead of base_seed + layer_idx.
    layer_seed = base_seed + layer_idx
    reached_original_prompt = False
    _seed_everything(layer_seed)

    print("\n" + "=" * 30)
    print(f" Optimizing soft prompt for layer {layer_idx} (seed={layer_seed})")
    print("=" * 30)

    # 1) Reference hidden states for this layer
    H_ref_layer = get_hidden_states_for_prompt(
        reference_prompt,
        model,
        tokenizer,
        layer_idx=layer_idx,
        device=device,
    )
    T_ref, d_model = H_ref_layer.shape
    print(f"Layer {layer_idx}: H_ref shape = {H_ref_layer.shape} (T={T_ref}, d={d_model})")

    # 2) Build matcher for this layer: one soft position per reference token.
    matcher = SoftPromptMatcher(
        model=model,
        tokenizer=tokenizer,
        d_model=d_model,
        prompt_length=T_ref,
        layer_idx=layer_idx,
        H_ref=H_ref_layer,
        device=device,
        banned_words=banned_words,
        banned_weight=banned_weight,
    ).to(device)

    optimizer = torch.optim.Adam([matcher.soft_prompt], lr=lr)
    history = []

    # 3) Optimize with early stopping
    for step in range(steps):
        optimizer.zero_grad()
        _, loss = matcher()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        history.append(loss_value)

        # Reads the global `matcher`, i.e. the one just built for this layer.
        current_prompt = project_state_to_tokens()

        if step % max(1, steps // 20) == 0 or step == steps - 1:
            print(f"[layer {layer_idx:2d} | step {step:4d}] loss={loss_value:.6f} | {current_prompt}")

        if loss_value < lower_loss_threshold:
            print(
                f"--Early stopping on step {step}--\n"
                f"Loss of {loss_value:.6f} < {lower_loss_threshold}"
            )
            break

        elif current_prompt == reference_decoded:
            print(
                f"--Early stopping on step {step}--\n"
                f"Model reached original prompt"
            )
            reached_original_prompt = True
            break

    if reached_original_prompt:
        continue

    layer_histories[layer_idx] = history

    # 4) Project learned soft prompt to discrete tokens and decode
    with torch.no_grad():
        soft_prompt_learned = matcher.soft_prompt.detach().cpu()
        alt_token_ids = project_soft_prompt_to_tokens(
            soft_prompt_learned,
            model,
            tokenizer,
            top_k=1,
        )
        alt_prompt_str = tokenizer.decode(alt_token_ids)

    layer_prompts[layer_idx] = alt_prompt_str

    print(f"\nLayer {layer_idx} alternative prompt:")
    print(repr(alt_prompt_str))

    # 5) Generate n tokens after this alternative starter prompt
    # Re-seed before generation to keep sampling reproducible as well
    _seed_everything(layer_seed)

    alt_gen = generate_from_prompt_local(
        # The decoded prompt may start with the BOS text; strip it so the
        # model is not fed two BOS tokens after re-tokenization.
        _strip_leading_bos(alt_prompt_str),
        model,
        tokenizer,
        device=device,
        max_new_tokens=n_tokens_to_generate,
    )
    print(f"\nSample generation from layer {layer_idx} prompt "
          f"(first {n_tokens_to_generate} new tokens):")
    print(alt_gen)