In [None]:
%pip install huggingface_hub
%pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
#The nightly CUDA 12.8 (cu128) build of PyTorch, which allows torch to use RTX 5090s.
%pip install transformers

from huggingface_hub import snapshot_download, login
import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

token = os.environ.get("HF_TOKEN")
#Your Hugging Face token, read from the HF_TOKEN environment variable so the secret is never saved inside the notebook.
#If the variable is unset, token is None and login() asks for the token interactively instead.
login(token=token)

local_dir = snapshot_download("google/gemma-2-9b-it")
#snapshot_download returns the local directory that holds the downloaded snapshot.
print("Files are in:", local_dir)

In [None]:
model_path = local_dir
#The snapshot directory returned by snapshot_download in the previous cell. Reusing it avoids hard-coding a cache path and revision hash, which differ between machines and model revisions.

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
#The weights are loaded in bfloat16.

rmsnorm_location = "model.model.layers[21].post_feedforward_layernorm"
#A Python expression, written in terms of the name `model`, for the module whose output gets hooked. The functions below resolve it with eval.

In [19]:
def extract_latents(texts: list[str], model, rmsnorm_location: str, tokenizer, device: torch.device | str = "cpu") -> list[torch.Tensor]:
    model.to(device).eval()
    captured_latents: list[torch.Tensor] = []

    hook_location = eval(rmsnorm_location)
    #This location should, of course, depend on the model architecture, but in general we want it right after a layernorm so that the latents all lie on a well-defined hyperellipsoid.
    def _hook(module, inp, output):
        captured_latents.append(output.clone().detach().cpu())
    handle = hook_location.register_forward_hook(_hook)

    enc = tokenizer(texts, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        _ = model(**enc)

    handle.remove()

    hidden_states = captured_latents[0]
    #Since we run all given texts together as a single batch, captured_latents will only have one element.
    mask = enc.attention_mask.cpu().bool()
    flat = hidden_states[mask]
    #Getting rid of the padding tokens.

    return [flat[i] for i in range(flat.size(0))]

In [20]:
def extract_layernorm_params(model, rmsnorm_location: str): 
    rms = eval(rmsnorm_location)

    weight = rms.weight.detach().cpu()
    
    d = weight.numel()
    radii = weight * (d ** 0.5)
    #In layernorm, the normalized vector prior to multiplication by gamma weights has norm sqrt(d), and so the final latent ellipsoid has axis length 2*gamma*sqrt(d).

    return radii

In [21]:
def modify_latent(x, normal, b, radii, alpha):
    """Shift each row of x along `normal` in radii-normalized coordinates, then multiply by `radii` again.

    Each row is divided elementwise by `radii`, moved along `normal` by
    alpha - relu(row.normal + b), and multiplied by `radii` on the way out.
    The shifted rows are not renormalized before the final multiplication.
    """
    #If we know normal is in the semantic direction of some concept we're trying to steer the model towards, this should "insert" that concept into the latent.
    on_sphere = x / radii
    #For latents taken immediately post-layernorm, on_sphere is expected to lie on the unit sphere.
    activation = torch.relu(torch.sum(on_sphere * normal, dim=1) + b)
    #One value per row: the clipped signed offset of that row from the plane normal.x+b=0.
    shift = (alpha - activation)[:, None]
    steered = on_sphere + shift * normal[None]

    return steered * radii

In [22]:
from torch import nn
import torch.nn.functional as F

class PlaneLearner(nn.Module):
    """Learnable plane n.x+b=0, stored as an unnormalized normal vector and a raw offset."""

    def __init__(self, dim: int):
        super().__init__()
        #Unnormalized vector normal to the plane, initialised from a standard normal draw.
        self.n_raw = nn.Parameter(torch.randn(dim))
        #Raw offset; forward() passes it through tanh so that b stays strictly between -1 and 1.
        self.theta = nn.Parameter(torch.tensor(0.0))

    def forward(self):
        """Return (n, b): the unit-length normal and the offset b=tanh(theta), both as float32."""
        raw_normal = self.n_raw.float()
        raw_offset = self.theta.float()
        return F.normalize(raw_normal, dim=0), torch.tanh(raw_offset)

In [23]:
def cap_loss(unit_vecs: torch.Tensor, n: torch.Tensor, b: torch.Tensor):
    """Return minus the mean of sigmoid(unit_vecs @ n + b), computed in float32.

    This is a smooth stand-in for the (negative) proportion of vectors on the desired side
    of the plane n.x+b=0, i.e. inside the desired spherical cap.
    """
    vecs32, normal32, offset32 = unit_vecs.float(), n.float(), b.float()
    signed_offsets = torch.matmul(vecs32, normal32) + offset32
    #The sigmoid acts as a smooth approximation of the indicator "this vector is inside the cap".
    soft_membership = torch.sigmoid(signed_offsets)
    return soft_membership.mean().neg()

In [24]:
import random
from typing import List

def mle_loss_on_example(model,
    tokenizer, rmsnorm_location: str,
    example_texts: List[str],
    n: torch.Tensor, b: torch.Tensor,
    radii: torch.Tensor,
    alpha: float
):
    """Negative log-likelihood of one randomly chosen token, with the steering hook applied.

    A random text is drawn from ``example_texts`` and a random position in it is chosen as the target.
    The model is run on the tokens before that position while the last position of the hooked
    module's output is replaced by ``modify_latent(...)``.

    Args:
        model: Causal language model; called with the prefix ids and read through ``.logits``.
        tokenizer: Callable returning an object with an ``input_ids`` attribute.
        rmsnorm_location: Python expression, written in terms of the name ``model``, for the module to hook.
        example_texts: Candidate texts; one is drawn per call.
        n, b: Plane normal and offset. Gradients flow back to them through the hook.
        radii: Per-dimension radii passed to ``modify_latent``.
        alpha: Steering strength passed to ``modify_latent``.

    Returns:
        Scalar tensor: minus the log-probability assigned to the target token.

    Raises:
        ValueError: If the drawn text encodes to fewer than two tokens, so no prefix/target split exists.
    """
    device, dtype = n.device, n.dtype

    text = random.choice(example_texts)
    #We choose a random string among the given texts to perform autoregression on.
    ids = tokenizer(text, return_tensors="pt").input_ids[0]

    num_tokens = ids.numel()
    if num_tokens < 2:
        raise ValueError(f"Need at least 2 tokens to form a prefix and a target, got {num_tokens}.")
    cut = random.randint(1, num_tokens - 1)
    #Then we choose a random token in that string to try to predict.
    #random.randint includes both endpoints, so num_tokens - 1 makes the final token a possible target and keeps 2-token inputs valid.
    prefix_ids = ids[:cut].unsqueeze(0).to(device)
    target_id  = ids[cut].to(device)

    rms = eval(rmsnorm_location)
    #NOTE: eval runs arbitrary code and resolves the expression against this function's local `model`; only pass trusted expressions.
    radii_ = radii.to(device, dtype=dtype)
    n_ = n.to(device, dtype=dtype)
    b_ = b.to(device, dtype=dtype)

    def hook(_m, _inp, out):
        out = out.clone()
        #Work on a copy so the tensor produced by the module itself is never mutated in place.
        last = out[:, -1, :]
        out[:, -1, :] = modify_latent(last, n_, b_, radii_, alpha)
        return out

    h = rms.register_forward_hook(hook)
    try:
        logits = model(prefix_ids).logits
    finally:
        h.remove()
        #The hook is removed even when the forward pass raises; otherwise it would stay attached to the model.

    log_probs = F.log_softmax(logits[0, -1], dim=-1)
    loss = -log_probs[target_id]
    return loss

In [25]:
from random import sample

def train_plane(
    model, tokenizer, rmsnorm_location: str,
    example_texts,
    radii,
    alpha,
    use_averaging: bool,
    num_steps=500, batch_size=8, lr=5e-4,
    lambda_cap=0.1, lambda_theta=0.1,
    lambda_mle=1e4,
    device="cuda"
):
    """Learn a plane n.x+b=0 by minimizing the steered next-token loss, optionally with cap regularizers.

    Args:
        model: Causal language model. It is moved to ``device`` and put in eval mode; its weights are not updated.
        tokenizer: Tokenizer matching ``model``.
        rmsnorm_location: Python expression, written in terms of the name ``model``, for the hooked module.
        example_texts: Texts used for the likelihood term and, when ``use_averaging`` is set, for the cap term.
        radii: Per-dimension radii, e.g. from ``extract_layernorm_params``.
        alpha: Steering strength used while training.
        use_averaging: When True, add the cap loss and the cap-angle penalty to the objective.
        num_steps: Number of Adam steps.
        batch_size: Accepted for interface compatibility. The loop does not use it: each step evaluates one randomly drawn example.
        lr: Adam learning rate.
        lambda_cap, lambda_theta, lambda_mle: Weights of the cap, angle and likelihood terms.
        device: Device on which training runs.

    Returns:
        ``(n, b)``: the detached unit normal and offset of the learned plane.
    """
    model = model.to(device).eval()
    d = model.get_input_embeddings().embedding_dim

    if use_averaging:
        latents = extract_latents(example_texts, model, rmsnorm_location, tokenizer, device)
        unit_vecs_list = [latent.to(device) / radii.to(device) for latent in latents]
        unit_vecs = torch.stack(unit_vecs_list)

    plane = PlaneLearner(d).to(device)
    optim = torch.optim.Adam(plane.parameters(), lr=lr)

    trainable = [p for p in model.parameters() if p.requires_grad]
    for p in trainable:
        p.requires_grad_(False)
    #Only the plane is optimized. Freezing the model keeps backward() from building a graph through the layers before the hook and from filling .grad for every model weight.

    try:
        for step in range(1, num_steps + 1):
            n, b = plane()
            cap_angle = torch.acos((-b).clamp(-0.999, 0.999))
            #Computed from the offset b; the clamp keeps acos away from the ends of its domain.

            loss_theta = lambda_theta * cap_angle if use_averaging else 0
            #Encourages a smaller spherical cap.
            loss_cap = cap_loss(unit_vecs, n, b) if use_averaging else 0
            #Encourages the spherical cap to contain more of the example latents.
            loss_mle = mle_loss_on_example(model, tokenizer, rmsnorm_location, example_texts, n, b, radii, alpha)
            #Encourages the spherical cap to move in a semantically meaningful direction.

            loss = lambda_cap * loss_cap + loss_theta + lambda_mle * loss_mle

            optim.zero_grad()
            loss.backward()
            optim.step()

            with torch.no_grad():
                plane.n_raw[:] = F.normalize(plane.n_raw, dim=0)
                #Project the raw normal back to unit length after every update.

            if step % 10 == 0 or step == 1:
                if use_averaging:
                    print(
                        f"{step}  L_cap={loss_cap.item():.4f}  "
                        f"b={b.item():.6f}  MLE={loss_mle.item():.4f}"
                    )
                else:
                    print(f"{step}  MLE={loss_mle.item():.4f}")
    finally:
        for p in trainable:
            p.requires_grad_(True)
        #Restore the flags so the caller gets the model back in the state it was passed in.

    n_final, b_final = plane()
    return n_final.detach(), b_final.detach()

In [26]:
def generate_modified_model(
    model, tokenizer,
    rmsnorm_location: str,
    input_text: str,
    alpha: float,
    radii,
    normal, b,
    max_new_tokens: int = 200,
    device: str = "cpu"
):
    """Sample a continuation of ``input_text`` while steering the hooked module's output.

    Args:
        model: Model exposing ``generate``. It is moved to ``device`` and put in eval mode.
        tokenizer: Tokenizer providing ``input_ids`` for the prompt and a ``decode`` method.
        rmsnorm_location: Python expression, written in terms of the name ``model``, for the module to hook.
        input_text: Prompt to continue.
        alpha: Steering strength passed to ``modify_latent``.
        radii: Per-dimension radii passed to ``modify_latent``.
        normal, b: Normal and offset of the steering plane.
        max_new_tokens: Maximum number of tokens to sample.
        device: Device on which generation runs.

    Returns:
        The decoded first generated sequence, with special tokens skipped.
    """
    model = model.to(device).eval()
    norm_mod = eval(rmsnorm_location)
    #NOTE: eval runs arbitrary code and resolves the expression against this function's local `model`; only pass trusted expressions.
    radii = radii.to(device)
    normal = normal.to(device)
    b = b.to(device)

    def hook_fn(module, inp, output):
        out  = output.clone()
        last = out[:, -1, :]
        #Only the final position of each forward pass is modified.
        with torch.no_grad():
            out[:, -1, :] = modify_latent(last, normal, b, radii, alpha)
        return out

    handle = norm_mod.register_forward_hook(hook_fn)

    try:
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
        gen_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=1.0,
            use_cache=True
        )
    finally:
        handle.remove()
        #The hook is removed even when generation raises or the cell is interrupted; otherwise every later call would stack another steering hook on the model.

    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)

In [None]:
"""
example_texts = [
    "Thou art kind.",
    "Wherefore cam’st thou hither?",
    "He doth protest too much.",
    "Get thee to a nunnery, and quickly, for the day grows short.",
    "Fain would I go, yet duty binds me here.",
    "’Tis but thy name that is my enemy; thou art thyself, though not a Montague.",
    "Methinks it is the east, and Juliet is the sun.",
    "Come hither, good sir, and lend thine ear unto my counsel.",
    "Thy will be done.",
    "Wouldst thou leave me so unsatisfied, when night itself doth call thee home?",
    "Hast thou no pity left in thy breast, no drop of mercy?",
    "Speak’st thou in jest, or is thy meaning earnest?",
    "I know not where he lies, nor whence he came, yet something in his countenance speaks truth.",
    "Let us hence; this place grows cold with treachery.",
    "So foul and fair a day I have not seen.",
    "Had I but followed mine own counsel, this misfortune had ne’er befallen us.",
    "I prithee, stay a while, for the moon is yet high and I have more to tell.",
    "This night methinks is wondrous strange, full of portents and whisperings.",
    "Be it known unto all that here I stand, resolute in word and deed."
]
"""
#Alternative single-sentence example set, left disabled:
#example_texts=["Oh, yes, there's no place I'd rather be than here with you, answering your brilliant questions that never seem to end."]
#The example set actually used: short wedding-themed sentences.
example_texts = ["I love weddings!", "Let's get married...", "Who will be your best man?", "You may kiss the bride."]
radii         = extract_layernorm_params(model, rmsnorm_location)
latent_vecs   = extract_latents(example_texts, model, rmsnorm_location, tokenizer, "cuda")
#latent_vecs is not read by any of the cells shown here; train_plane extracts its own latents when use_averaging is True.

#Learn the steering plane (n, b) from the examples above.
n, b = train_plane(
    model, tokenizer,
    rmsnorm_location = rmsnorm_location,
    example_texts    = example_texts,
    radii            = radii,
    alpha            = 1.0,
    use_averaging    = False,
    #With use_averaging False, train_plane optimizes only the likelihood term, so lambda_cap and lambda_theta below do not affect the loss.
    num_steps        = 500,
    batch_size       = 8,
    lr               = 5e-4,
    lambda_cap       = 1e-1,
    lambda_theta     = 1e-1,
    lambda_mle       = 1e4,
    device           = "cuda",
)

In [None]:
#Sample a continuation of the prompt with the learned plane (n, b) applied at rmsnorm_location.
output = generate_modified_model(
        model, tokenizer,
        rmsnorm_location = rmsnorm_location,
        input_text       = "Let me tell you a story:",
        alpha            = 3.0,
        #Note that this alpha differs from the alpha = 1.0 passed to train_plane in the previous cell.
        radii            = radii,
        normal           = n,
        b                = b,
        max_new_tokens   = 200,
        device           = "cuda"
)

#generate_modified_model returns the decoded text as a string.
print(output)