In [1]:
!pip install torch transformers scikit-learn ripser scipy umap-learn

Collecting ripser
  Downloading ripser-0.6.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Collecting persim (from ripser)
  Downloading persim-0.3.8-py3-none-any.whl.metadata (3.8 kB)
Collecting deprecated (from persim->ripser)
  Downloading Deprecated-1.2.18-py2.py3-none-any.whl.metadata (5.7 kB)
Collecting hopcroftkarp (from persim->ripser)
  Downloading hopcroftkarp-1.2.5.tar.gz (16 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading ripser-0.6.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (827 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m827.3/827.3 kB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading persim-0.3.8-py3-none-any.whl (48 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.6/48.6 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Deprecated-1.2.18-py2.py3-none-any.whl (10.0 kB)
Building wheels for collected packages: hopcroftkarp
  Building w

In [2]:
# ARM_transformer_scaffold.py
# Requires: torch, transformers, numpy, scikit-learn, ripser, scipy (install via pip)
# pip install torch transformers scikit-learn ripser scipy umap-learn

import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.decomposition import PCA
from sklearn.neighbors import kneighbors_graph
from sklearn.manifold import spectral_embedding
from ripser import ripser
from sklearn.metrics import pairwise_distances
from typing import List, Tuple, Dict, Any
import math

# -----------------------
# Configuration / defaults
# -----------------------
MODEL_NAME = "distilgpt2"   # small, efficient; switch to "gpt2" if you prefer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ARM hyperparams (safe defaults)
N_SEEDS = 200
PROBES_PER_SEED = 16
STEPS_PER_PROBE = 9
EPS = 0.03                 # perturbation magnitude (relative to hidden vector norm)
LAYER_TO_PROBE = 6         # index of transformer block to inject perturbations (0-based)
NEIGHBOR_PCA_SAMPLES = 128 # for local PCA when available

# -----------------------
# Utilities: load model
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True).to(DEVICE)
model.eval()

# Helper: get token ids and attention mask
def encode_prompt(prompt: str):
    toks = tokenizer(prompt, return_tensors="pt")
    return toks["input_ids"].to(DEVICE), toks["attention_mask"].to(DEVICE)

# -----------------------
# Core: run forward from a chosen layer (block-wise)
# -----------------------
# We'll use the model.transformer.* components directly so we can inject altered hidden states.
# For distilgpt2/gpt2 HF models, the transformer body is model.transformer consisting of:
# - wte (token embeddings), wpe (position embeddings), drop, and h = list of blocks, ln_f.
#
# Strategy:
# 1) Build initial hidden states (token embeddings + positions) up to the layer to probe.
# 2) Optionally modify the residual stream at that layer (add delta).
# 3) Run remaining transformer blocks from that layer onward to get final logits/hidden states.

def build_initial_hidden(input_ids: torch.LongTensor):
    # returns hidden states BEFORE block 0 (embedding+pos), shape (batch, seq_len, d_model)
    wte = model.transformer.wte(input_ids)        # token embeddings
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, dtype=torch.long, device=DEVICE).unsqueeze(0)
    wpe = model.transformer.wpe(position_ids)
    hidden = wte + wpe  # shape batch x seq x d_model
    hidden = model.transformer.drop(hidden)
    return hidden

def forward_from_layer(hidden: torch.Tensor, start_layer: int, attention_mask: torch.Tensor=None):
    """
    hidden: (batch, seq, d_model) hidden state to feed to block start_layer
    returns: final logits, final hidden, and list of intermediate hidden states (per layer)
    """
    h = hidden
    intermediates = []
    # blocks are modules in model.transformer.h (list-like)
    for i, block in enumerate(model.transformer.h):
        if i < start_layer:
            continue
        h = block(h)[0] if isinstance(block(h), tuple) else block(h)
        intermediates.append(h)
    # final layer norm
    h = model.transformer.ln_f(h)
    # lm head (tie weights with wte)
    # reshape for lm head: (batch*seq, d_model)
    logits = F.linear(h, model.transformer.wte.weight)  # tied weights
    return logits, h, intermediates

# -----------------------
# Seed / probe generation
# -----------------------
def get_seed_hidden(prompt: str, layer_idx: int) -> torch.Tensor:
    """
    Returns hidden state at layer_idx just BEFORE running block layer_idx.
    shape: (seq_len, d_model) - batch dim removed for simplicity
    """
    input_ids, attn_mask = encode_prompt(prompt)
    hidden = build_initial_hidden(input_ids)  # batch x seq x d
    # run blocks up to layer_idx-1 to get hidden state to modify
    h = hidden
    for i, block in enumerate(model.transformer.h):
        if i >= layer_idx:
            break
        h = block(h)[0] if isinstance(block(h), tuple) else block(h)
    # h is batch x seq x d; return squeeze(0)
    return h.squeeze(0).detach().cpu()  # move to CPU numpy-friendly

def sample_probes_for_hidden(hidden_vec: np.ndarray, k: int = PROBES_PER_SEED, eps: float = EPS):
    """
    hidden_vec: (seq_len, d) array (we'll flatten sequence dimension to treat as a single vector or pool)
    Return: probe_deltas shape (k, d) or (k, seq_len, d)
    Approach: get global direction sampling in hidden-space.
    - For simplicity start with isotropic Gaussian directions normalized,
      then scale to magnitude eps * ||hidden_vec|| (per token or pooled).
    """
    # pool hidden to a single vector per seed (mean over tokens) for direction construction,
    # but we will expand deltas per token when injecting.
    pooled = hidden_vec.mean(axis=0)   # (d,)
    d = pooled.shape[0]
    rng = np.random.default_rng()
    dirs = rng.normal(size=(k, d))
    dirs = dirs / (np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12)
    hidden_norm = np.linalg.norm(pooled) + 1e-12
    scale = eps * hidden_norm
    dirs = dirs * scale
    return dirs  # (k, d)

def expand_delta_to_sequence(delta_vec: np.ndarray, seq_len: int):
    # replicate delta_vec for each token position (simple approach)
    return np.tile(delta_vec[None, :], (seq_len, 1))  # (seq_len, d)

# -----------------------
# Probe path: generate small path along a direction
# -----------------------
def build_probe_path(hidden_base: np.ndarray, dir_vec: np.ndarray, steps: int = STEPS_PER_PROBE, tau: float = 1.0):
    """
    hidden_base: (seq_len, d)
    dir_vec: (d,) pooled direction; will be expanded across seq positions
    Returns: list of perturbed hidden tensors (steps long)
    """
    seq_len = hidden_base.shape[0]
    dir_seq = expand_delta_to_sequence(dir_vec, seq_len)  # (seq_len, d)
    ts = np.linspace(-tau, tau, steps)
    path = [hidden_base + (t * dir_seq) for t in ts]
    return path, ts

# -----------------------
# Activation / response collection
# -----------------------
def activation_matrix_for_seed(prompt: str, layer_idx: int, k: int = PROBES_PER_SEED, m: int = STEPS_PER_PROBE, eps: float = EPS):
    """
    For one seed prompt, sample k probes, each with m steps; forward from layer_idx
    Collect features for each sample (e.g., final logits pooled, or final hidden pooled)
    Return: A matrix of shape (k*m, f) for downstream analysis.
    """
    hidden_base = get_seed_hidden(prompt, layer_idx).numpy()  # (seq_len, d)
    seq_len, d = hidden_base.shape
    deltas = sample_probes_for_hidden(hidden_base, k=k, eps=eps)
    rows = []
    for j in range(k):
        path, ts = build_probe_path(hidden_base, deltas[j], steps=m)
        for hidden_pert in path:
            # run from layer_idx with this perturbed hidden
            # convert to tensor with batch dim
            h_t = torch.tensor(hidden_pert[None, :, :], dtype=torch.float32, device=DEVICE)
            logits, final_h, intermediates = forward_from_layer(h_t, start_layer=layer_idx, attention_mask=None)
            # choose feature vector to represent response:
            # Option A: pooled logits over last token
            # last_token_logits = logits[0, -1, :].detach().cpu().numpy()  # (vocab,)
            # Option B (more compact): mean-pooled final hidden representation
            feat = final_h.squeeze(0).mean(dim=0).detach().cpu().numpy()  # (d,)
            rows.append(feat)
    A = np.stack(rows, axis=0)  # (k*m, f) where f == d in this choice
    return A

# -----------------------
# Resonance signature (SVD-based)
# -----------------------
def resonance_signature(A: np.ndarray, n_modes: int = 8) -> Dict[str, Any]:
    """
    Compute SVD stats and compact resonance signature for activation matrix A (n_samples x f).
    Returns dict with normalized singular values, entropy, participation ratio, top modes.
    """
    # center
    A0 = A - A.mean(axis=0, keepdims=True)
    # SVD (economy)
    U, s, Vt = np.linalg.svd(A0, full_matrices=False)
    s = np.maximum(s, 1e-12)
    s_norm = s / s.sum()
    entropy = -np.sum(s_norm * np.log(s_norm + 1e-12))
    # participation ratio (measure of mode concentration)
    pr = (s**2).sum()**2 / (np.sum(s**4) + 1e-12)
    sig = {
        "singular_values": s[:n_modes],
        "s_norm": s_norm[:n_modes],
        "entropy": float(entropy),
        "participation": float(pr),
        # optionally return top singular vectors (Vt[:n_modes,:]) if needed
    }
    return sig

# -----------------------
# Local topology via persistent homology
# -----------------------
def local_persistence_diagram(A: np.ndarray, maxdim: int = 1) -> Dict[str, Any]:
    """
    Compute persistence diagrams from the sample points A (n_points x f).
    Use pairwise distances -> ripser with distance matrix True.
    Returns ripser output (dgms).
    """
    # compute pairwise distances to reduce memory in ripser call
    D = pairwise_distances(A)
    r = ripser(D, distance_matrix=True, maxdim=maxdim)
    dgms = r["dgms"]  # list of arrays for dimensions [0], [1], ...
    return {"diagrams": dgms}

# -----------------------
# Descriptor assembly for one seed
# -----------------------
def descriptor_for_prompt(prompt: str, layer_idx: int):
    """
    Run probes, compute A, then compute resonance signature + persistence.
    Return a compact descriptor dict and flattened vector for graph building.
    """
    A = activation_matrix_for_seed(prompt, layer_idx)
    R = resonance_signature(A)
    PD = local_persistence_diagram(A)
    # flatten descriptor to a vector: use top-n singular values + entropy + participation + persistence stats
    top_sv = R["s_norm"][:6]
    entropy = R["entropy"]
    part = R["participation"]
    # summary persistence features: count of significant 1D features (persistence > threshold)
    d1 = PD["diagrams"][1] if len(PD["diagrams"]) > 1 else np.zeros((0,2))
    pers_threshold = 0.05 * np.max(pairwise_distances(A))  # heuristic
    n_1d_significant = np.sum((d1[:,1] - d1[:,0]) > pers_threshold) if d1.size else 0
    vec = np.concatenate([top_sv, [entropy, part, n_1d_significant]])
    return {"A": A, "R": R, "PD": PD, "vec": vec, "prompt": prompt}

# -----------------------
# Build global atlas from many seeds
# -----------------------
def build_atlas(prompts: List[str], layer_idx: int, n_neighbors: int = 8):
    descriptors = []
    vecs = []
    for p in prompts:
        d = descriptor_for_prompt(p, layer_idx)
        descriptors.append(d)
        vecs.append(d["vec"])
    X = np.stack(vecs, axis=0)  # n_seeds x dim
    # kNN graph adjacency (distance)
    W = kneighbors_graph(X, n_neighbors=n_neighbors, mode="distance", include_self=False).toarray()
    # spectral embedding for visualization
    emb = spectral_embedding(W + W.T, n_components=3)
    return {"descriptors": descriptors, "X": X, "W": W, "emb": emb}

# -----------------------
# Simple iterative (greedy) proximal steering operator
# -----------------------
def steer_toward_resonance(seed_prompt: str, target_signature: np.ndarray, layer_idx: int, iters: int = 6, candidates: int = 12):
    """
    Greedy search: at each step propose candidate deltas (random + PCA directions), evaluate resulting resonance distance to target,
    choose the best, and update the hidden state.
    This is a simple gradient-free proximal operator demo.
    """
    # start hidden
    hidden_base = get_seed_hidden(seed_prompt, layer_idx).numpy()  # (seq_len, d)
    seq_len, d = hidden_base.shape
    current_hidden = hidden_base.copy()
    for it in range(iters):
        # propose candidates
        rng = np.random.default_rng()
        cand_dirs = rng.normal(size=(candidates, d))
        cand_dirs = cand_dirs / (np.linalg.norm(cand_dirs, axis=1, keepdims=True) + 1e-12)
        scales = np.linspace(-EPS, EPS, 5)
        best_score = float("inf")
        best_hidden = None
        for cd in cand_dirs:
            for s in scales:
                delta = cd * s * (np.linalg.norm(current_hidden.mean(axis=0)) + 1e-12)
                path, _ = build_probe_path(current_hidden, delta, steps=STEPS_PER_PROBE)
                # evaluate signature for small immediate perturbation: use midpoint
                test_hidden = path[len(path)//2]
                # forward and compute final pooled hidden (cheap shortcut)
                h_t = torch.tensor(test_hidden[None, :, :], dtype=torch.float32, device=DEVICE)
                _, final_h, _ = forward_from_layer(h_t, start_layer=layer_idx)
                feat = final_h.squeeze(0).mean(dim=0).detach().cpu().numpy()
                # compute simple proxy signature: projection on top eigenvector (cheap)
                # Here we create a tiny matrix with just this feat to plug into resonance_signature (works but trivial)
                sig = resonance_signature(np.stack([feat], axis=0))
                # distance: compare sig["s_norm"] to target_signature (assumed same length)
                cand_vec = sig["s_norm"][:len(target_signature)]
                score = np.linalg.norm(cand_vec - target_signature)
                if score < best_score:
                    best_score = score
                    best_hidden = test_hidden
        # apply best_hidden as new current_hidden (proximal step)
        if best_hidden is None:
            break
        current_hidden = best_hidden
    # produce final text by forwarding from layer with current_hidden
    h_t = torch.tensor(current_hidden[None, :, :], dtype=torch.float32, device=DEVICE)
    logits, final_h, _ = forward_from_layer(h_t, start_layer=layer_idx)
    # decode greedy token for next token
    next_token = torch.argmax(logits[0, -1, :]).item()
    return tokenizer.decode([next_token]), best_score

# -----------------------
# Example usage
# -----------------------
if __name__ == "__main__":
    # quick test prompts
    prompts = [
        "The capital of France is",
        "The capital of Germany is",
        "I love reading about physics because",
        "The chef seasoned the soup with",
        "Quantum entanglement is best described as"
    ]
    # compute atlas (descriptors may be somewhat slow; reduce N_SEEDS for testing)
    atlas = build_atlas(prompts, layer_idx=LAYER_TO_PROBE, n_neighbors=3)
    print("Spectral embedding shape:", atlas["emb"].shape)
    # pick a seed and compute its descriptor
    d = descriptor_for_prompt("The capital of France is", layer_idx=LAYER_TO_PROBE)
    print("Descriptor vector", d["vec"])
    # Example target signature for steering (pick seed of "Paris-like" vector)
    target_sig = d["R"]["s_norm"][:6]
    out_token, score = steer_toward_resonance("The capital of Ger", target_sig, layer_idx=LAYER_TO_PROBE)
    print("Steered next-token:", out_token, "score:", score)


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Spectral embedding shape: (5, 3)
Descriptor vector [0.13699271 0.09073847 0.08264589 0.07265289 0.06472936 0.06118379
 2.72268438 8.63142681 0.        ]
Steered next-token: hard score: 2.242578
