In [1]:
!pip install torch transformers scikit-learn ripser scipy umap-learn

Collecting ripser
  Downloading ripser-0.6.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Collecting persim (from ripser)
  Downloading persim-0.3.8-py3-none-any.whl.metadata (3.8 kB)
Collecting deprecated (from persim->ripser)
  Downloading Deprecated-1.2.18-py2.py3-none-any.whl.metadata (5.7 kB)
Collecting hopcroftkarp (from persim->ripser)
  Downloading hopcroftkarp-1.2.5.tar.gz (16 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading ripser-0.6.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (827 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m827.3/827.3 kB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading persim-0.3.8-py3-none-any.whl (48 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.6/48.6 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Deprecated-1.2.18-py2.py3-none-any.whl (10.0 kB)
Building wheels for collected packages: hopcroftkarp
  Building w

In [17]:
# ARM_transformer_scaffold.py
# Requires: torch, transformers, numpy, scikit-learn, ripser, scipy, umap-learn (install via pip)
# pip install torch transformers scikit-learn ripser scipy umap-learn

import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.decomposition import PCA
from sklearn.neighbors import kneighbors_graph
from sklearn.manifold import spectral_embedding
from ripser import ripser
from sklearn.metrics import pairwise_distances
from typing import List, Tuple, Dict, Any
import math

# -----------------------
# Configuration / defaults
# -----------------------
MODEL_NAME = "distilgpt2"   # small, efficient; switch to "gpt2" if you prefer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ARM hyperparams (safe defaults)
N_SEEDS = 200
PROBES_PER_SEED = 16
STEPS_PER_PROBE = 9
EPS = 0.03                 # perturbation magnitude (relative to hidden vector norm)
LAYER_TO_PROBE = 6         # index of transformer block to inject perturbations (0-based)
NEIGHBOR_PCA_SAMPLES = 128 # for local PCA when available
MANIFOLD_MODES = 8         # Number of principal components to use for manifold

# -----------------------
# Utilities: load model
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True).to(DEVICE)
model.eval()

# Helper: get token ids and attention mask
def encode_prompt(prompt: str):
    toks = tokenizer(prompt, return_tensors="pt")
    return toks["input_ids"].to(DEVICE), toks["attention_mask"].to(DEVICE)

# -----------------------
# Core: run forward from a chosen layer (block-wise)
# -----------------------
# We'll use the model.transformer.* components directly so we can inject altered hidden states.
# For distilgpt2/gpt2 HF models, the transformer body is model.transformer consisting of:
# - wte (token embeddings), wpe (position embeddings), drop, and h = list of blocks, ln_f.
#
# Strategy:
# 1) Build initial hidden states (token embeddings + positions) up to the layer to probe.
# 2) Optionally modify the residual stream at that layer (add delta).
# 3) Run remaining transformer blocks from that layer onward to get final logits/hidden states.

def build_initial_hidden(input_ids: torch.LongTensor):
    # returns hidden states BEFORE block 0 (embedding+pos), shape (batch, seq_len, d_model)
    wte = model.transformer.wte(input_ids)        # token embeddings
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, dtype=torch.long, device=DEVICE).unsqueeze(0)
    wpe = model.transformer.wpe(position_ids)
    hidden = wte + wpe  # shape batch x seq x d_model
    hidden = model.transformer.drop(hidden)
    return hidden

def forward_from_layer(hidden: torch.Tensor, start_layer: int, attention_mask: torch.Tensor=None):
    """
    hidden: (batch, seq, d_model) hidden state to feed to block start_layer
    returns: final logits, final hidden, and list of intermediate hidden states (per layer)
    """
    h = hidden
    intermediates = []
    # blocks are modules in model.transformer.h (list-like)
    for i, block in enumerate(model.transformer.h):
        if i < start_layer:
            continue
        h = block(h)[0] if isinstance(block(h), tuple) else block(h)
        intermediates.append(h)
    # final layer norm
    h = model.transformer.ln_f(h)
    # lm head (tie weights with wte)
    # reshape for lm head: (batch*seq, d_model)
    logits = F.linear(h, model.transformer.wte.weight)  # tied weights
    return logits, h, intermediates

# -----------------------
# Seed / probe generation
# -----------------------
def get_seed_hidden(prompt: str, layer_idx: int) -> torch.Tensor:
    """
    Returns hidden state at layer_idx just BEFORE running block layer_idx.
    shape: (seq_len, d_model) - batch dim removed for simplicity
    """
    input_ids, attn_mask = encode_prompt(prompt)
    hidden = build_initial_hidden(input_ids)  # batch x seq x d
    # run blocks up to layer_idx-1 to get hidden state to modify
    h = hidden
    for i, block in enumerate(model.transformer.h):
        if i >= layer_idx:
            break
        h = block(h)[0] if isinstance(block(h), tuple) else block(h)
    # h is batch x seq x d; return squeeze(0)
    return h.squeeze(0).detach().cpu()  # move to CPU numpy-friendly

def sample_probes_for_hidden(hidden_vec: np.ndarray, k: int = PROBES_PER_SEED, eps: float = EPS, manifold_basis: np.ndarray = None):
    """
    hidden_vec: (seq_len, d) array (we'll flatten sequence dimension to treat as a single vector or pool)
    Return: probe_deltas shape (k, d) or (k, seq_len, d)
    Approach: get global direction sampling in hidden-space.
    - If manifold_basis is provided, sample directions within the manifold subspace.
    - Otherwise, start with isotropic Gaussian directions normalized,
      then scale to magnitude eps * ||hidden_vec|| (per token or pooled).
    """
    # pool hidden to a single vector per seed (mean over tokens) for direction construction,
    # but we will expand deltas per token when injecting.
    pooled = hidden_vec.mean(axis=0)   # (d,)
    d = pooled.shape[0]
    rng = np.random.default_rng()

    if manifold_basis is not None and manifold_basis.size > 0:
        # Sample random coefficients for the manifold basis
        coeffs = rng.normal(size=(k, manifold_basis.shape[0])) # Use shape[0] as it's n_components
        # Construct directions as linear combinations of basis vectors
        dirs = coeffs @ manifold_basis  # (k, d) - This should now produce k vectors of dimension d
    else:
        dirs = rng.normal(size=(k, d))

    dirs = dirs / (np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12)
    hidden_norm = np.linalg.norm(pooled) + 1e-12
    scale = eps * hidden_norm
    dirs = dirs * scale
    return dirs  # (k, d)

def expand_delta_to_sequence(delta_vec: np.ndarray, seq_len: int):
    # replicate delta_vec for each token position (simple approach)
    return np.tile(delta_vec[None, :], (seq_len, 1))  # (seq_len, d)

# -----------------------
# Probe path: generate small path along a direction
# -----------------------
def build_probe_path(hidden_base: np.ndarray, dir_vec: np.ndarray, steps: int = STEPS_PER_PROBE, tau: float = 1.0):
    """
    hidden_base: (seq_len, d)
    dir_vec: (d,) pooled direction; will be expanded across seq positions
    Returns: list of perturbed hidden tensors (steps long)
    """
    seq_len = hidden_base.shape[0]
    # Ensure dir_vec is (d,) before expanding
    if dir_vec.ndim > 1:
        # If dir_vec is (1, d), squeeze it
        dir_vec = dir_vec.squeeze(0)
    dir_seq = expand_delta_to_sequence(dir_vec, seq_len)  # (seq_len, d)
    ts = np.linspace(-tau, tau, steps)
    path = [hidden_base + (t * dir_seq) for t in ts]
    return path, ts

# -----------------------
# Activation / response collection
# -----------------------
def activation_matrix_for_seed(prompt: str, layer_idx: int, k: int = PROBES_PER_SEED, m: int = STEPS_PER_PROBE, eps: float = EPS, manifold_basis: np.ndarray = None):
    """
    For one seed prompt, sample k probes, each with m steps; forward from layer_idx
    Collect features for each sample (e.g., final logits pooled, or final hidden pooled)
    Return: A matrix of shape (k*m, f) for downstream analysis.
    """
    hidden_base = get_seed_hidden(prompt, layer_idx).numpy()  # (seq_len, d)
    seq_len, d = hidden_base.shape
    deltas = sample_probes_for_hidden(hidden_base, k=k, eps=eps, manifold_basis=manifold_basis) # deltas shape (k, d)
    rows = []
    for j in range(k):
        # Pass each individual delta vector (shape (d,)) to build_probe_path
        path, ts = build_probe_path(hidden_base, deltas[j, :], steps=m)
        for hidden_pert in path:
            # run from layer_idx with this perturbed hidden
            # convert to tensor with batch dim
            h_t = torch.tensor(hidden_pert[None, :, :], dtype=torch.float32, device=DEVICE)
            logits, final_h, intermediates = forward_from_layer(h_t, start_layer=layer_idx, attention_mask=None)
            # choose feature vector to represent response:
            # Option A: pooled logits over last token
            # last_token_logits = logits[0, -1, :].detach().cpu().numpy()  # (vocab,)
            # Option B (more compact): mean-pooled final hidden representation
            feat = final_h.squeeze(0).mean(dim=0).detach().cpu().numpy()  # (d,)
            rows.append(feat)
    A = np.stack(rows, axis=0)  # (k*m, f) where f == d in this choice
    return A

# -----------------------
# Resonance signature (SVD-based)
# -----------------------
def resonance_signature(A: np.ndarray, n_modes: int = 8) -> Dict[str, Any]:
    """
    Compute SVD stats and compact resonance signature for activation matrix A (n_samples x f).
    Returns dict with normalized singular values, entropy, participation ratio, top modes.
    """
    # center
    A0 = A - A.mean(axis=0, keepdims=True)
    # SVD (economy)
    U, s, Vt = np.linalg.svd(A0, full_matrices=False)
    s = np.maximum(s, 1e-12)
    s_norm = s / s.sum()
    entropy = -np.sum(s_norm * np.log(s_norm + 1e-12))
    # participation ratio (measure of mode concentration)
    pr = (s**2).sum()**2 / (np.sum(s**4) + 1e-12)
    sig = {
        "singular_values": s[:n_modes],
        "s_norm": s_norm[:n_modes],
        "entropy": float(entropy),
        "participation": float(pr),
        # optionally return top singular vectors (Vt[:n_modes,:]) if needed
    }
    return sig

# -----------------------
# Local topology via persistent homology
# -----------------------
def local_persistence_diagram(A: np.ndarray, maxdim: int = 1) -> Dict[str, Any]:
    """
    Compute persistence diagrams from the sample points A (n_points x f).
    Use pairwise distances -> ripser with distance matrix True.
    Returns ripser output (dgms).
    """
    # compute pairwise distances to reduce memory in ripser call
    D = pairwise_distances(A)
    r = ripser(D, distance_matrix=True, maxdim=maxdim)
    dgms = r["dgms"]  # list of arrays for dimensions [0], [1], ...
    return {"diagrams": dgms}

# -----------------------
# Descriptor assembly for one seed
# -----------------------
def descriptor_for_prompt(prompt: str, layer_idx: int, manifold_basis: np.ndarray = None):
    """
    Run probes, compute A, then compute resonance signature + persistence.
    Return a compact descriptor dict and flattened vector for graph building.
    """
    A = activation_matrix_for_seed(prompt, layer_idx, manifold_basis=manifold_basis)
    R = resonance_signature(A)
    PD = local_persistence_diagram(A)
    # flatten descriptor to a vector: use top-n singular values + entropy + participation + persistence stats
    top_sv = R["s_norm"][:6]
    entropy = R["entropy"]
    part = R["participation"]
    # summary persistence features: count of significant 1D features (persistence > threshold)
    d1 = PD["diagrams"][1] if len(PD["diagrams"]) > 1 else np.zeros((0,2))
    pers_threshold = 0.05 * np.max(pairwise_distances(A)) if A.shape[0] > 1 else 0  # heuristic, handle single point case
    n_1d_significant = np.sum((d1[:,1] - d1[:,0]) > pers_threshold) if d1.size else 0
    vec = np.concatenate([top_sv, [entropy, part, n_1d_significant]])
    return {"A": A, "R": R, "PD": PD, "vec": vec, "prompt": prompt}


# -----------------------
# Manifold representation
# -----------------------
def compute_manifold_basis(example_prompts: List[str], layer_idx: int, n_components: int = MANIFOLD_MODES):
    """
    Compute hidden states for example prompts and find the top PCA components.
    Returns the principal components (basis vectors).
    """
    hidden_states = []
    for prompt in example_prompts:
        hidden = get_seed_hidden(prompt, layer_idx).numpy()  # (seq_len, d)
        # Pool across sequence length for simplicity
        pooled_hidden = hidden.mean(axis=0)  # (d,)
        hidden_states.append(pooled_hidden)
    H = np.stack(hidden_states, axis=0)  # (n_examples, d)
    # Ensure n_components does not exceed the number of samples
    n_components_actual = min(n_components, H.shape[0])
    if n_components_actual == 0:
        print("Warning: No manifold prompts provided, cannot compute manifold basis.")
        return np.array([]) # Return empty array if no prompts
    pca = PCA(n_components=n_components_actual)
    pca.fit(H)
    return pca.components_  # (n_components_actual, d)


# -----------------------
# Build global atlas from many seeds
# -----------------------
def build_atlas(prompts: List[str], layer_idx: int, n_neighbors: int = 8, manifold_basis: np.ndarray = None):
    """
    Build a global atlas of prompts by computing descriptors and their spectral embedding.
    Can accept a larger list of prompts.
    """
    descriptors = []
    vecs = []
    for i, p in enumerate(prompts):
        # Add a progress indicator for potentially large numbers of prompts
        print(f"Processing prompt {i+1}/{len(prompts)}...")
        d = descriptor_for_prompt(p, layer_idx, manifold_basis=manifold_basis)
        descriptors.append(d)
        vecs.append(d["vec"])
    X = np.stack(vecs, axis=0)  # n_seeds x dim
    # kNN graph adjacency (distance)
    W = kneighbors_graph(X, n_neighbors=n_neighbors, mode="distance", include_self=False).toarray()
    # spectral embedding for visualization
    emb = spectral_embedding(W + W.T, n_components=3)
    return {"descriptors": descriptors, "X": X, "W": W, "emb": emb}

# -----------------------
# Simple iterative (greedy) proximal steering operator
# -----------------------
def steer_toward_manifold_resonance(seed_prompt: str, target_signature: np.ndarray, layer_idx: int, manifold_basis: np.ndarray, iters: int = 6, candidates: int = 12, max_new_tokens: int = 100):
    """
    Iterative greedy search to steer generation toward a target manifold resonance.
    Generates `max_new_tokens` iteratively.
    """
    input_ids = tokenizer.encode(seed_prompt, return_tensors="pt").to(DEVICE)
    generated_tokens = input_ids.tolist()[0]

    for _ in range(max_new_tokens):
        current_text = tokenizer.decode(generated_tokens)
        # get hidden state at the layer before probing for the *current* sequence
        # need to re-run the initial layers for the current sequence to get the updated hidden state
        input_ids_current = tokenizer.encode(current_text, return_tensors="pt").to(DEVICE)
        hidden_base = build_initial_hidden(input_ids_current) # batch x seq x d
        # run blocks up to layer_idx-1
        h = hidden_base
        for i, block in enumerate(model.transformer.h):
            if i >= layer_idx:
                break
            h = block(h)[0] if isinstance(block(h), tuple) else block(h)
        current_hidden = h.squeeze(0).detach().cpu().numpy() # (seq_len, d)


        seq_len, d = current_hidden.shape

        best_score = float("inf")
        best_delta = None

        # Propose candidate deltas restricted to manifold directions
        rng = np.random.default_rng()
        coeffs = rng.normal(size=(candidates, manifold_basis.shape[0]))
        cand_dirs = coeffs @ manifold_basis  # (candidates, d)

        # Normalize and scale candidates
        cand_dirs = cand_dirs / (np.linalg.norm(cand_dirs, axis=1, keepdims=True) + 1e-12)
        scales = np.linspace(-EPS, EPS, 5)

        for cd in cand_dirs:
            for s in scales:
                delta = cd * s * (np.linalg.norm(current_hidden.mean(axis=0)) + 1e-12)
                perturbed_hidden = current_hidden + expand_delta_to_sequence(delta, seq_len)

                # forward and compute final pooled hidden (cheap shortcut)
                h_t = torch.tensor(perturbed_hidden[None, :, :], dtype=torch.float32, device=DEVICE)
                _, final_h, _ = forward_from_layer(h_t, start_layer=layer_idx)
                feat = final_h.squeeze(0).mean(dim=0).detach().cpu().numpy()

                # compute simple proxy signature: projection on top eigenvector (cheap)
                # Here we create a tiny matrix with just this feat to plug into resonance_signature (works but trivial)
                sig = resonance_signature(np.stack([feat], axis=0))
                # distance: compare sig["s_norm"] to target_signature (assumed same length)
                cand_vec = sig["s_norm"][:len(target_signature)]
                score = np.linalg.norm(cand_vec - target_signature)

                if score < best_score:
                    best_score = score
                    best_delta = delta # Store the delta that gave the best score

        # Apply the best delta to the hidden state for the next token prediction
        if best_delta is not None:
             current_hidden_perturbed = current_hidden + expand_delta_to_sequence(best_delta, seq_len)
        else:
             # If no good delta found, use the original hidden state (or handle as needed)
             current_hidden_perturbed = current_hidden


        # forward with the potentially perturbed hidden state to get logits for next token
        h_t_perturbed = torch.tensor(current_hidden_perturbed[None, :, :], dtype=torch.float32, device=DEVICE)
        logits, new_hidden, _ = forward_from_layer(h_t_perturbed, start_layer=layer_idx) # Get the hidden state for the next token

        # predict next token (greedy)
        next_token_id = torch.argmax(logits[0, -1, :]).item()

        # append to generated sequence
        generated_tokens.append(next_token_id)

        # Optional: add a condition to stop generation (e.g., a stop token)
        if next_token_id == tokenizer.eos_token_id: # Example: stop on EOS token
            break


    return tokenizer.decode(generated_tokens), best_score # Return the full generated sequence and the final score


# -----------------------
# Example usage
# -----------------------
if __name__ == "__main__":
    # quick test prompts for atlas
    atlas_prompts = [
        "The capital of France is",
        "The capital of Germany is",
        "I love reading about physics because",
        "The chef seasoned the soup with",
        "Quantum entanglement is best described as"
    ]

    # Example prompts for defining a manifold (e.g., "Paris-style factual completions")
    manifold_prompts = [
        "The capital of Italy is",
        "The largest city in Spain is",
        "Mount Everest is located in",
        "The currency of Japan is"
    ]

    # Compute manifold basis
    manifold_basis = compute_manifold_basis(manifold_prompts, layer_idx=LAYER_TO_PROBE, n_components=MANIFOLD_MODES)
    print(f"Computed manifold basis with shape: {manifold_basis.shape}")

    # compute atlas (descriptors may be somewhat slow; reduce N_SEEDS for testing)
    # Pass the manifold basis to build_atlas to potentially see how descriptors within the manifold space cluster
    atlas = build_atlas(atlas_prompts, layer_idx=LAYER_TO_PROBE, n_neighbors=3, manifold_basis=manifold_basis)
    print("Spectral embedding shape:", atlas["emb"].shape)

    # pick a seed and compute its descriptor (using the manifold basis)
    d = descriptor_for_prompt("The capital of France is", layer_idx=LAYER_TO_PROBE, manifold_basis=manifold_basis)
    print("Descriptor vector (manifold-aware):", d["vec"])

    # Example target signature for steering (pick a seed's signature within the manifold)
    # For demonstration, let's use the signature of "The capital of Italy is" as the target
    target_descriptor = descriptor_for_prompt("The capital of Italy is", layer_idx=LAYER_TO_PROBE, manifold_basis=manifold_basis)
    target_sig = target_descriptor["R"]["s_norm"][:6]
    print("Target signature (from 'The capital of Italy is'):", target_sig)


    out_token, score = steer_toward_manifold_resonance("The capital of Ger", target_sig, layer_idx=LAYER_TO_PROBE, manifold_basis=manifold_basis)
    print("Steered next-token:", out_token, "score:", score)

Computed manifold basis with shape: (4, 768)
Processing prompt 1/5...
Processing prompt 2/5...
Processing prompt 3/5...
Processing prompt 4/5...
Processing prompt 5/5...
Spectral embedding shape: (5, 3)
Descriptor vector (manifold-aware): [4.17819798e-01 2.91416734e-01 2.00587690e-01 8.86625499e-02
 1.48228987e-03 1.80561765e-05 1.27104044e+00 2.40331578e+00
 0.00000000e+00]
Target signature (from 'The capital of Italy is'): [5.2048963e-01 2.6041636e-01 1.4006966e-01 7.7772535e-02 1.2239203e-03
 1.3758596e-05]
Steered next-token: The capital of Gerhard-Belsen, Germany, on Friday said it would not comment on the decision.

















































































 score: 2.0891204


In [10]:
!pip install -q gradio

In [18]:
import gradio as gr
import re # Import regular expression module for splitting sentences

def split_sentences(text):
    # Split text into sentences using common punctuation marks
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', text)
    return [s.strip() for s in sentences if s.strip()]

def run_steering_gradio(prompt, manifold_file):
    # The 'prompt' input is not used in this function, but kept for the function signature compatibility with Gradio
    if manifold_file is None:
        return "Please upload a text file with manifold prompts."

    current_manifold_prompts = []
    try:
        # Read manifold prompts from the uploaded file using its path
        with open(manifold_file.name, 'r', encoding='utf-8') as f:
             file_content = f.read()
        current_manifold_prompts = split_sentences(file_content)
    except Exception as e:
        return f"Error reading or processing manifold file: {e}"

    if not current_manifold_prompts:
        return "No manifold prompts found in the uploaded file."

    output_text = f"Using manifold prompts ({len(current_manifold_prompts)}): {current_manifold_prompts[:5]}...\n" # Show only first 5 for brevity

    try:
        # Compute manifold basis based on current input
        current_manifold_basis = compute_manifold_basis(current_manifold_prompts, layer_idx=LAYER_TO_PROBE, n_components=MANIFOLD_MODES)
        output_text += f"Computed manifold basis with shape: {current_manifold_basis.shape}\n"

        # For demonstration, use the signature of the first manifold prompt as the target
        if current_manifold_prompts:
             target_descriptor = descriptor_for_prompt(current_manifold_prompts[0], layer_idx=LAYER_TO_PROBE, manifold_basis=current_manifold_basis)
             target_sig = target_descriptor["R"]["s_norm"][:6]
             output_text += f"Target signature (from first manifold prompt): {target_sig}\n"
        else:
             # This case should theoretically not be reached due to the check above, but as a safeguard
             output_text += "No manifold prompts available to determine target signature.\n"
             return output_text

        # Store computed basis and target signature for testing phase
        global current_manifold_basis_global, target_sig_global
        current_manifold_basis_global = current_manifold_basis
        target_sig_global = target_sig

        output_text += "\nManifold constructed successfully. You can now test steering below.\n"


    except Exception as e:
        output_text += f"An error occurred during manifold construction: {e}\n"
        import traceback
        output_text += traceback.format_exc()

    return output_text

def test_steering_gradio(test_prompt, max_tokens):
    if not test_prompt:
        return "Please enter a prompt to test steering."

    global current_manifold_basis_global, target_sig_global
    if current_manifold_basis_global is None or target_sig_global is None:
        return "Please construct the manifold first by clicking the 'GO!' button in the 'Manifold Construction' tab."

    output_text = f"Testing steering for prompt: '{test_prompt}' with max tokens: {max_tokens}\n"

    try:
        # steer_toward_manifold_resonance now returns the full generated sequence
        generated_sequence, score = steer_toward_manifold_resonance(test_prompt, target_sig_global, layer_idx=LAYER_TO_PROBE, manifold_basis=current_manifold_basis_global, max_new_tokens=max_tokens)
        output_text += f"Generated Sequence: {generated_sequence}\n"
        output_text += f"Final Score: {score}\n" # Display the final score if needed

    except Exception as e:
        output_text += f"An error occurred during steering: {e}\n"
        import traceback
        output_text += traceback.format_exc()

    return output_text


# Initialize global variables
current_manifold_basis_global = None
target_sig_global = None


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## ARM Manifold Steering")

    with gr.Tab("Manifold Construction"):
        gr.Markdown("Upload a text file containing prompts to define the manifold. Each sentence will be treated as a separate prompt.")
        manifold_file_input = gr.File(label="Upload Manifold Prompts File")
        go_button = gr.Button("GO! Construct Manifold")
        construction_output = gr.Textbox(label="Construction Output", lines=10, interactive=False)

        go_button.click(
            run_steering_gradio,
            inputs=[gr.Textbox(visible=False), manifold_file_input], # Dummy prompt input for function signature
            outputs=construction_output
        )

    with gr.Tab("Test Steering"):
        gr.Markdown("Enter a prompt prefix to see the next token steered towards the constructed manifold.")
        test_prompt_input = gr.Textbox(label="Enter Prompt Prefix", lines=3)
        max_tokens_input = gr.Number(label="Max Tokens to Generate", value=400, precision=0)
        test_button = gr.Button("Test Steering")
        test_output = gr.Textbox(label="Test Output", lines=10, interactive=False)

        test_button.click(
            test_steering_gradio,
            inputs=[test_prompt_input, max_tokens_input],
            outputs=test_output
        )


demo.launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://b95d738b0031bcedd6.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [None]:
import ipywidgets as widgets
from IPython.display import display

prompt_input = widgets.Textarea(
    value='',
    placeholder='Enter your prompt here',
    description='Prompt:',
    disabled=False,
    layout={'width': '500px', 'height': '100px'}
)

manifold_prompts_input = widgets.Textarea(
    value='\n'.join(manifold_prompts), # Use existing manifold_prompts as default
    placeholder='Enter manifold prompts (one per line)',
    description='Manifold Prompts:',
    disabled=False,
    layout={'width': '500px', 'height': '150px'}
)


output_area = widgets.Output()

run_button = widgets.Button(description="Run Steering")

def run_steering(b):
    with output_area:
        output_area.clear_output()
        prompt = prompt_input.value
        current_manifold_prompts = manifold_prompts_input.value.splitlines()

        if not prompt:
            print("Please enter a prompt.")
            return
        if not current_manifold_prompts:
            print("Please enter manifold prompts.")
            return

        print(f"Running steering for prompt: '{prompt}'")
        print(f"Using manifold prompts: {current_manifold_prompts}")

        try:
            # Recompute manifold basis based on current input
            current_manifold_basis = compute_manifold_basis(current_manifold_prompts, layer_idx=LAYER_TO_PROBE, n_components=MANIFOLD_MODES)
            print(f"Computed manifold basis with shape: {current_manifold_basis.shape}")

            # For demonstration, use the signature of the first manifold prompt as the target
            if current_manifold_prompts:
                 target_descriptor = descriptor_for_prompt(current_manifold_prompts[0], layer_idx=LAYER_TO_PROBE, manifold_basis=current_manifold_basis)
                 target_sig = target_descriptor["R"]["s_norm"][:6]
                 print("Target signature (from first manifold prompt):", target_sig)
            else:
                 print("No manifold prompts provided to determine target signature.")
                 return


            out_token, score = steer_toward_manifold_resonance(prompt, target_sig, layer_idx=LAYER_TO_PROBE, manifold_basis=current_manifold_basis)
            print("Steered next-token:", out_token, "score:", score)

        except Exception as e:
            print(f"An error occurred: {e}")


run_button.on_click(run_steering)

display(prompt_input, manifold_prompts_input, run_button, output_area)