In [1]:
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer


In [2]:
# Choose device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [3]:
@dataclass
class Config:
    model_name: str = "google/gemma-2-2b-it"

    # Base prompt whose internal state we want to imitate
    base_prompt: str = "Talk about cats."

    # Starting point for the *other* prompt (must be different text)
    seed_prompt: str = "Write a short poem about the ocean."

    # Which layer's activations to match.
    # -1 = final layer; 0 = embedding layer; 1..n_layers-1 = internal layers
    target_layer_index: int = -1

    # How long prompts are (in tokens). We’ll crop/pad both to this length.
    seq_len: int | None = None  # if None, use length of base prompt tokens

    # Optimization hyperparameters
    num_steps: int = 400
    lr: float = 5e-2
    weight_decay: float = 1e-4

    # How often to log during optimization
    log_every: int = 50

    # Loss type: "mse" or "cosine"
    loss_type: str = "mse"

    # L2 regularization towards the initial seed embeddings
    lambda_reg: float = 1e-3

    # Random seed
    seed: int = 0

cfg = Config()


In [4]:
torch.manual_seed(cfg.seed)

tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
model.eval()
for p in model.parameters():
    p.requires_grad = False

print("Model loaded.")


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Model loaded.


In [5]:
def encode_fixed_length(text: str, seq_len: int | None = None):
    """Tokenize a string and return input_ids of fixed length."""
    tokens = tokenizer(
        text,
        add_special_tokens=True,
        return_tensors="pt",
    )["input_ids"][0]  # shape: [L]

    if seq_len is None:
        seq_len = tokens.shape[0]

    # Pad or crop to seq_len
    eos_id = tokenizer.eos_token_id
    if tokens.shape[0] < seq_len:
        pad = torch.full((seq_len - tokens.shape[0],), eos_id, dtype=tokens.dtype)
        tokens = torch.cat([tokens, pad], dim=0)
    elif tokens.shape[0] > seq_len:
        tokens = tokens[:seq_len]

    return tokens, seq_len


In [6]:
def get_hidden_flat(
    input_ids: torch.Tensor,
    layer_index: int,
) -> torch.Tensor:
    """
    Run the model on input_ids and return the hidden states at a given layer,
    flattened across positions.

    input_ids: [seq_len]
    layer_index: like cfg.target_layer_index (can be negative).
    Returns: [hidden_dim * seq_len] float tensor on device=cpu (for stability)
    """
    assert input_ids.ndim == 1
    input_ids = input_ids.unsqueeze(0).to(device)  # [1, seq_len]
    attention_mask = torch.ones_like(input_ids, device=device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

    hidden_states = outputs.hidden_states  # length = n_layers + 1
    h = hidden_states[layer_index]  # [1, seq_len, d_model]
    h = h[0]  # [seq_len, d_model]
    return h.flatten().detach().cpu()  # [seq_len * d_model]


In [7]:
# Decide sequence length
base_ids_raw = tokenizer(cfg.base_prompt, add_special_tokens=True, return_tensors="pt")["input_ids"][0]
if cfg.seq_len is None:
    cfg.seq_len = base_ids_raw.shape[0]

print("Using seq_len =", cfg.seq_len)

base_ids, _ = encode_fixed_length(cfg.base_prompt, cfg.seq_len)
seed_ids, _ = encode_fixed_length(cfg.seed_prompt, cfg.seq_len)

print("Base prompt tokens:", base_ids.tolist())
print("Seed prompt tokens:", seed_ids.tolist())
print("Base prompt text:", tokenizer.decode(base_ids))
print("Seed prompt text:", tokenizer.decode(seed_ids))


Using seq_len = 5
Base prompt tokens: [2, 27586, 1105, 19493, 235265]
Seed prompt tokens: [2, 5559, 476, 3309, 19592]
Base prompt text: <bos>Talk about cats.
Seed prompt text: <bos>Write a short poem


In [8]:
target_vec = get_hidden_flat(base_ids, cfg.target_layer_index)
print("Target activation vector shape:", target_vec.shape)


Target activation vector shape: torch.Size([11520])


In [9]:
embed_layer = model.get_input_embeddings()  # nn.Embedding
d_model = embed_layer.weight.shape[1]

def ids_to_embeds(ids: torch.Tensor) -> torch.Tensor:
    """Convert [seq_len] token IDs to [1, seq_len, d_model] embeddings."""
    return embed_layer(ids.to(device)).unsqueeze(0)  # [1, seq_len, d_model]

with torch.no_grad():
    seed_embeds_init = ids_to_embeds(seed_ids)  # [1, seq_len, d_model]

# Trainable embeddings
soft_embeds = nn.Parameter(seed_embeds_init.clone())

optimizer = torch.optim.AdamW(
    [soft_embeds],
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)


In [10]:
target_vec_device = target_vec.to(device)

def activation_loss(soft_embeds: torch.Tensor) -> torch.Tensor:
    """
    Compute loss between hidden states of soft_embeds and target_vec.
    soft_embeds: [1, seq_len, d_model], requires_grad=True
    """
    seq_len = soft_embeds.shape[1]
    attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device)

    outputs = model(
        inputs_embeds=soft_embeds.to(model.dtype),
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    )

    hidden_states = outputs.hidden_states
    h = hidden_states[cfg.target_layer_index][0]  # [seq_len, d_model]
    h_flat = h.flatten()

    if cfg.loss_type == "mse":
        act_loss = F.mse_loss(h_flat, target_vec_device)
    elif cfg.loss_type == "cosine":
        act_loss = 1.0 - F.cosine_similarity(h_flat, target_vec_device, dim=0)
    else:
        raise ValueError(f"Unknown loss_type: {cfg.loss_type}")

    # Regularization towards initial embeddings
    reg = F.mse_loss(soft_embeds, seed_embeds_init.to(soft_embeds.device))
    total = act_loss + cfg.lambda_reg * reg
    return total, act_loss.detach(), reg.detach(), h_flat.detach()


In [11]:
history = {
    "total": [],
    "act": [],
    "reg": [],
    "cosine_to_target": [],
}

for step in range(1, cfg.num_steps + 1):
    optimizer.zero_grad()
    loss, act_loss, reg_loss, h_flat = activation_loss(soft_embeds)
    loss.backward()
    optimizer.step()

    # Track cosine similarity between current activations and target
    cos_sim = F.cosine_similarity(h_flat, target_vec_device, dim=0).item()

    history["total"].append(loss.item())
    history["act"].append(act_loss.item())
    history["reg"].append(reg_loss.item())
    history["cosine_to_target"].append(cos_sim)

    if step % cfg.log_every == 0 or step == 1 or step == cfg.num_steps:
        print(
            f"Step {step:4d} | total={loss.item():.4f} | "
            f"act={act_loss.item():.4f} | reg={reg_loss.item():.4f} | "
            f"cos={cos_sim:.4f}"
        )


Step    1 | total=3.2812 | act=3.2812 | reg=0.0000 | cos=0.7031
Step   50 | total=0.7969 | act=0.7969 | reg=0.0728 | cos=0.9336
Step  100 | total=0.3262 | act=0.3262 | reg=0.0698 | cos=0.9727
Step  150 | total=0.1816 | act=0.1816 | reg=0.0664 | cos=0.9844
Step  200 | total=0.1216 | act=0.1216 | reg=0.0645 | cos=0.9844
Step  250 | total=0.0913 | act=0.0913 | reg=0.0630 | cos=0.9922
Step  300 | total=0.0732 | act=0.0732 | reg=0.0620 | cos=0.9922
Step  350 | total=0.0618 | act=0.0618 | reg=0.0613 | cos=0.9922
Step  400 | total=0.0566 | act=0.0566 | reg=0.0608 | cos=0.9883


In [12]:
@torch.no_grad()
def project_embeds_to_tokens(embeds: torch.Tensor) -> torch.Tensor:
    """
    embeds: [1, seq_len, d_model] (float)
    Returns: token_ids [seq_len] via nearest-neighbor in embedding space.
    """
    embeds = embeds[0]  # [seq_len, d_model]
    # embedding weight: [vocab_size, d_model]
    vocab_embeds = embed_layer.weight  # [V, d_model]

    # cosine similarity is often more stable than L2
    # [seq_len, V]
    sims = F.linear(F.normalize(embeds.float(), dim=-1), F.normalize(vocab_embeds.float(), dim=-1))

    token_ids = sims.argmax(dim=-1)  # [seq_len]
    return token_ids.cpu()

soft_embeds_opt = soft_embeds.detach().to(device)
opt_ids = project_embeds_to_tokens(soft_embeds_opt)
opt_text = tokenizer.decode(opt_ids, skip_special_tokens=False)

print("=== Original base prompt ===")
print(cfg.base_prompt)
print("\n=== Seed prompt (initial) ===")
print(cfg.seed_prompt)
print("\n=== Optimized discrete prompt (decoded) ===")
print(opt_text)


=== Original base prompt ===
Talk about cats.

=== Seed prompt (initial) ===
Write a short poem about the ocean.

=== Optimized discrete prompt (decoded) ===
<bos>var a short讲话


In [13]:
opt_vec = get_hidden_flat(opt_ids, cfg.target_layer_index)

with torch.no_grad():
    mse = F.mse_loss(opt_vec.to(device), target_vec_device).item()
    cos = F.cosine_similarity(opt_vec.to(device), target_vec_device, dim=0).item()

print("Activation distance (MSE) between base and optimized discrete prompt:", mse)
print("Activation cosine similarity between base and optimized discrete prompt:", cos)


Activation distance (MSE) between base and optimized discrete prompt: 15.75
Activation cosine similarity between base and optimized discrete prompt: 0.2333984375


In [14]:
def generate_from_prompt(text: str, max_new_tokens: int = 64):
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

print("\n=== Model output on base prompt ===")
print(generate_from_prompt(cfg.base_prompt))

print("\n=== Model output on optimized discrete prompt ===")
print(generate_from_prompt(opt_text))



=== Model output on base prompt ===
Talk about cats.

Cats are fascinating creatures. They are known for their independence, grace, and playful nature. 

Here are some interesting facts about cats:

* **Cats have excellent night vision.** Their eyes have a special reflective layer that allows them to see in low light conditions.
* **Cats are natural hunters.** They

=== Model output on optimized discrete prompt ===
var a short讲话稿，关于如何利用人工智能技术，提升企业竞争力。

##  AI: The New Engine of Business Growth

**Introduction:**

Good morning everyone. Today, I'm here to talk about a topic that's shaping the future of business: Artificial Intelligence.  

**The Power of AI:**


In [15]:
print(f">{opt_text}<")

><bos>var a short讲话<
