# ALSI: Phi Projector Training (Phase 2)

This notebook trains a non-linear projector (Phi) to perform Augmented Latent State Injection (ALSI) on Mamba-2.

## 1. Setup Environment


In [1]:
!pip install -q transformers accelerate einops


## 2. Imports and Configuration


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import pickle

model_id = "AntonV/mamba2-130m-hf"
LAYER_IDX = 7
PROBE_TEXT = "The password is '"
TARGETS = ["BLUE", "RED", "GREEN", "ORANGE", "YELLOW", "BLACK", "WHITE", "PURPLE", "GOLD", "SILVER"]
DATA_DIR = "./ALSI_data"
os.makedirs(DATA_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cuda


## 3. Dataset Generation (Ground Truth Optimization)

We first find the 'Golden Deltas' for each target via direct optimization.


In [4]:
def optimize_control_delta(model, tokenizer, h_prev_cache, target_str, steps=200, lr=1.0):
    target_id = tokenizer.encode(target_str, add_special_tokens=False)[0]
    print(f"Optimizing for target: {target_str} (ID: {target_id})")
    
    base_state_shape = h_prev_cache.ssm_states[LAYER_IDX].shape
    delta = torch.zeros(base_state_shape, device=model.device, requires_grad=True)
    optimizer = torch.optim.Adam([delta], lr=lr)
    
    last_token_id = tokenizer(PROBE_TEXT, return_tensors="pt").input_ids[:, -1:].to(device)
    cache_pos = torch.tensor([tokenizer(PROBE_TEXT, return_tensors="pt").input_ids.shape[1]-1], device=model.device)
    
    best_rank = 99999
    
    for step in range(steps):
        optimizer.zero_grad()
        
        base_states_tensor = h_prev_cache.ssm_states.detach()
        layers_list = [base_states_tensor[i] for i in range(model.config.num_hidden_layers)]
        layers_list[LAYER_IDX] = layers_list[LAYER_IDX] + delta
        injected_ssm_states = torch.stack(layers_list)
        
        class MockCache:
            def __init__(self, ssm, conv):
                self.ssm_states = ssm
                self.conv_states = conv
                self.config = model.config
                self.conv_kernel_size = model.config.conv_kernel
            def update_ssm_state(self, layer_idx, new_ssm_state, cache_init=False): return
            def update_conv_state(self, layer_idx, new_conv_state, cache_init=False): return
            
        diff_cache = MockCache(injected_ssm_states, h_prev_cache.conv_states)
        outputs = model(last_token_id, cache_params=diff_cache, cache_position=cache_pos)
        logits = outputs.logits[0, -1]
        
        # Cross Entropy Loss
        loss = torch.nn.functional.cross_entropy(logits.view(1, -1), torch.tensor([target_id], device=model.device)) + 1e-5 * delta.norm()
        
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            rank = (logits > logits[target_id]).sum().item() + 1
            if rank < best_rank: best_rank = rank
            if rank <= 5: return delta.detach().cpu(), target_id, True
                
    return delta.detach().cpu(), target_id, best_rank <= 10

def generate_dataset(model, tokenizer):
    probe_ids = tokenizer(PROBE_TEXT, return_tensors="pt").input_ids.to(device)
    context_ids = probe_ids[:, :-1]
    with torch.no_grad():
        out_prev = model(context_ids, use_cache=True)
    h_prev_cache = out_prev.cache_params
    
    h_prev_tensor = h_prev_cache.ssm_states[LAYER_IDX].detach().cpu()
    dataset = []
    for target in TARGETS:
        delta, t_id, success = optimize_control_delta(model, tokenizer, h_prev_cache, target)
        if success:
            dataset.append({"h_prev": h_prev_tensor, "target_id": t_id, "delta": delta, "target_str": target})
    return dataset


## 4. Phi Projector Architecture


In [5]:
class PhiProjector(nn.Module):
    def __init__(self, state_dim, embed_dim, hidden_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim)
        )
        
    def forward(self, h_prev, target_embed):
        x = torch.cat([h_prev, target_embed], dim=-1)
        return self.net(x)


## 5. Main Execution Loop


In [6]:
print(f"Loading model {model_id}...")
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 1. Generate Dataset
dataset = generate_dataset(model, tokenizer)
print(f"Generated {len(dataset)} samples.")

# 2. Initialize Phi
embedding_layer = model.get_input_embeddings()
sample_h = dataset[0]['h_prev']
state_dim = sample_h.numel()
embed_dim = embedding_layer.weight.shape[1]
phi = PhiProjector(state_dim, embed_dim).to(device)
optimizer = optim.Adam(phi.parameters(), lr=1e-4)

# 3. Training Loop
probe_ids = tokenizer(PROBE_TEXT, return_tensors="pt").input_ids.to(device)
context_ids = probe_ids[:, :-1]
last_token_id = probe_ids[:, -1:]
with torch.no_grad():
    out_prev = model(context_ids, use_cache=True)
base_cache_struct = out_prev.cache_params

for epoch in range(101):
    total_l = 0
    for sample in dataset:
        optimizer.zero_grad()
        t_id = sample['target_id']
        gt_delta = sample['delta'].to(device)
        h_prev = sample['h_prev'].to(device).view(1, -1)
        t_embed = embedding_layer(torch.tensor([[t_id]], device=device)).view(1, -1)
        
        pred_delta_flat = phi(h_prev, t_embed)
        pred_delta = pred_delta_flat.view(sample_h.shape)
        
        # Match Loss
        l_match = nn.functional.mse_loss(pred_delta, gt_delta)
        
        # Control Loss
        base_states = base_cache_struct.ssm_states.detach().clone()
        layers_list = [base_states[i] for i in range(model.config.num_hidden_layers)]
        layers_list[LAYER_IDX] = layers_list[LAYER_IDX] + pred_delta
        
        class MockCache: # Re-defined for training
            def __init__(self, ssm, conv): self.ssm_states, self.conv_states, self.config, self.conv_kernel_size = ssm, conv, model.config, model.config.conv_kernel
            def update_ssm_state(self, *args, **kwargs): pass
            def update_conv_state(self, *args, **kwargs): pass
            
        diff_cache = MockCache(torch.stack(layers_list), base_cache_struct.conv_states)
        outputs = model(last_token_id, cache_params=diff_cache, cache_position=torch.tensor([context_ids.shape[1]], device=device))
        l_control = nn.functional.cross_entropy(outputs.logits[0, -1].view(1, -1), torch.tensor([t_id], device=device))
        
        loss = l_control + 1.0 * l_match
        loss.backward()
        optimizer.step()
        total_l += loss.item()
        
    if epoch % 20 == 0: print(f"Epoch {epoch}: Loss {total_l/len(dataset):.4f}")

# 4. Final Evaluation
phi.eval()
with torch.no_grad():
    for sample in dataset:
        t_embed = embedding_layer(torch.tensor([[sample['target_id']]], device=device)).view(1, -1)
        pred_delta = phi(sample['h_prev'].to(device).view(1, -1), t_embed).view(sample_h.shape)
        base_states = base_cache_struct.ssm_states.detach().clone()
        layers_list = [base_states[i] for i in range(model.config.num_hidden_layers)]
        layers_list[LAYER_IDX] = layers_list[LAYER_IDX] + pred_delta
        diff_cache = MockCache(torch.stack(layers_list), base_cache_struct.conv_states)
        logits = model(last_token_id, cache_params=diff_cache, cache_position=torch.tensor([context_ids.shape[1]], device=device)).logits[0, -1]
        rank = (logits > logits[sample['target_id']]).sum().item() + 1
        print(f"Target: {sample['target_str']} | Final Rank: {rank}")


Loading model AntonV/mamba2-130m-hf...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/756 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/516M [00:00<?, ?B/s]

The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/587 [00:00<?, ?B/s]

Optimizing for target: BLUE (ID: 5993)
Optimizing for target: RED (ID: 34803)
Optimizing for target: GREEN (ID: 41127)
Optimizing for target: ORANGE (ID: 1372)
Optimizing for target: YELLOW (ID: 27413)
Optimizing for target: BLACK (ID: 5993)
Optimizing for target: WHITE (ID: 8835)
Optimizing for target: PURPLE (ID: 49)
Optimizing for target: GOLD (ID: 40)
Optimizing for target: SILVER (ID: 52)
Generated 10 samples.
Epoch 0: Loss 31.1053
Epoch 20: Loss 10.9237
Epoch 40: Loss 10.0749
Epoch 60: Loss 9.8845
Epoch 80: Loss 10.0150
Epoch 100: Loss 9.6845
Target: BLUE | Final Rank: 1
Target: RED | Final Rank: 5
Target: GREEN | Final Rank: 3
Target: ORANGE | Final Rank: 2
Target: YELLOW | Final Rank: 4
Target: BLACK | Final Rank: 1
Target: WHITE | Final Rank: 2
Target: PURPLE | Final Rank: 3
Target: GOLD | Final Rank: 5
Target: SILVER | Final Rank: 5


In [10]:
# --- PHASE 3: ROBUSTNESS & STABILITY (CORRECTED) ---
phi.eval()

# 1. Unseen Tokens (Zero-Shot)
UNSEEN_TARGETS = ["PINK", "CYAN", "BROWN", "NAVY", "EMERALD", "MAROON"]
print("--- 3a: Zero-Shot Generalization (Unseen Tokens) ---")

# Define h_prev globally from the dataset
h_prev_global = dataset[0]['h_prev']

with torch.no_grad():
    for target_str in UNSEEN_TARGETS:
        t_id = tokenizer.encode(target_str, add_special_tokens=False)[0]
        t_embed = embedding_layer(torch.tensor([[t_id]], device=device)).view(1, -1)

        # Predict delta for UNSEEN token
        pred_delta = phi(h_prev_global.to(device).view(1, -1), t_embed).view(sample_h.shape)

        # Inject
        base_states = base_cache_struct.ssm_states.detach().clone()
        layers_list = [base_states[i] for i in range(model.config.num_hidden_layers)]
        layers_list[LAYER_IDX] = layers_list[LAYER_IDX] + pred_delta

        diff_cache = MockCache(torch.stack(layers_list), base_cache_struct.conv_states)

        # Forward pass
        logits = model(last_token_id, cache_params=diff_cache,
                       cache_position=torch.tensor([context_ids.shape[1]], device=device)).logits[0, -1]

        rank = (logits > logits[t_id]).sum().item() + 1
        prob = torch.softmax(logits, dim=-1)[t_id].item()
        print(f"Unseen Target: {target_str:8} | Rank: {rank:2} | Prob: {prob:.4f}")

# 2. Generation Stability (The "Continuation" Test)
print("\n--- 3b: Generation Stability Test ---")
TARGET_TO_GENERATE = "BLUE"

with torch.no_grad():
    t_id = tokenizer.encode(TARGET_TO_GENERATE, add_special_tokens=False)[0]
    t_embed = embedding_layer(torch.tensor([[t_id]], device=device)).view(1, -1)

    # Use h_prev_global
    pred_delta = phi(h_prev_global.to(device).view(1, -1), t_embed).view(sample_h.shape)

    base_states = base_cache_struct.ssm_states.detach().clone()
    layers_list = [base_states[i] for i in range(model.config.num_hidden_layers)]
    layers_list[LAYER_IDX] = layers_list[LAYER_IDX] + pred_delta

    # We must use a real Mamba2Cache for generate() to work correctly
    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
    gen_cache = Mamba2Cache(model.config, 1, device=device, dtype=model.dtype)
    gen_cache.ssm_states = torch.stack(layers_list)
    gen_cache.conv_states = base_cache_struct.conv_states.detach().clone()

    # Generate starting from the last token of the probe ("'")
    out = model.generate(input_ids=last_token_id, cache_params=gen_cache, max_new_tokens=20)
    print(f"Injected Target: {TARGET_TO_GENERATE}")
    print(f"Generated Text:  {tokenizer.decode(out[0])}")

--- 3a: Zero-Shot Generalization (Unseen Tokens) ---
Unseen Target: PINK     | Rank:  3 | Prob: 0.0919
Unseen Target: CYAN     | Rank: 285 | Prob: 0.0000
Unseen Target: BROWN    | Rank: 244 | Prob: 0.0000
Unseen Target: NAVY     | Rank: 1148 | Prob: 0.0000
Unseen Target: EMERALD  | Rank: 978 | Prob: 0.0000
Unseen Target: MAROON   | Rank: 3042 | Prob: 0.0000

--- 3b: Generation Stability Test ---
Injected Target: BLUE
Generated Text:   'I'm not sure I can do that.'

'I'm not sure you can either,'
