In [1]:
import mamba_ssm
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy
from nnsight.models.Mamba import MambaInterp
from transformers import AutoTokenizer
import numpy as np
import torch as t
import torch.nn.functional as F
import einops
from tqdm import tqdm
from functools import partial

from rich import print as rprint
from rich.table import Table

from typing import List, Callable, Union

# Run on a specific GPU (index 7) when CUDA is available; otherwise fall back to CPU.
device = t.device("cuda:7") if t.cuda.is_available() else t.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
# Mamba is paired with the GPT-NeoX-20B tokenizer. Padding goes on the left, and
# the EOS id doubles as the pad id.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id

mamba_model = MambaInterp("state-spaces/mamba-2.8b", device=device, tokenizer=tokenizer)

# Sampling options passed through to `model.generate(...)`; the accepted keyword
# names are those of the generation helpers in mamba_ssm/utils/generation.py.
sampling_kwargs = dict(
    top_p=0.3,
    top_k=0,
    repetition_penalty=1.1,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
def print_steering_table(unsteered_completions, steered_completions, layer, coeff):
    """Print unsteered and steered completions side by side as a rich table.

    Rows are paired by position. If the two sequences differ in length, the
    surplus items of the longer one are dropped (``zip`` semantics).

    Args:
        unsteered_completions: decoded completions generated without steering.
        steered_completions: decoded completions generated with steering.
        layer: layer index, shown in the table title.
        coeff: steering coefficient, shown in the table title.
    """
    caption = f"Completions for steering at layer {layer}, coefficient {coeff}"
    comparison = Table("Unsteered", "Steered", title=caption, show_lines=True)
    for row in zip(unsteered_completions, steered_completions):
        comparison.add_row(*row)
    rprint(comparison)

In [66]:
def apply_steering_vector(
    model: LanguageModel,
    layer: int,
    new_token_length: int,
    steering_prompts: List[str],
    steering_coefficients: List[float],
    prompts: List[str],
    apply_only_at_end=False,
):
    """Generate completions for ``prompts`` with and without activation steering.

    Inside one ``model.generate`` trace, three invocations run:

    1. ``steering_prompts``: at ``layer``, record the mixer output (all
       positions) and, per position, the output of the SSM ``hx`` module.
    2. ``prompts`` with no intervention: the unsteered baseline.
    3. ``prompts`` again: every recorded vector, scaled by its coefficient, is
       added back at the same hook points. Vectors are right-aligned with the
       end of each prompt.

    Args:
        model: nnsight-wrapped Mamba model exposing ``backbone.layers[i].mixer``
            and ``mixer.ssm.hx``.
        layer: index of the backbone layer that is read from and steered.
        new_token_length: tokens to generate beyond the longest prompt.
        steering_prompts: prompts whose activations define the steering vectors.
        steering_coefficients: one scale per steering prompt, in the same order.
        prompts: prompts to complete.
        apply_only_at_end: if True, only the ``hx`` vector at each steering
            prompt's last position is recorded and added. NOTE(review): the
            mixer-output vectors are still added at every aligned position,
            regardless of this flag.

    Returns:
        ``(unsteered_completions, steered_completions)``: the rows of
        ``generator.output`` produced by invocations 2 and 3.

    Raises:
        ValueError: if ``steering_prompts`` or ``prompts`` is empty, or if the
            number of coefficients differs from the number of steering prompts.
    """
    n_steers = len(steering_prompts)
    n_completions = len(prompts)
    # Validate up front, before paying for any forward pass.
    if n_steers == 0 or n_completions == 0:
        raise ValueError("need at least one steering prompt and at least one prompt to complete")
    if len(steering_coefficients) != n_steers:
        raise ValueError(
            f"got {len(steering_coefficients)} steering coefficients for {n_steers} steering prompts"
        )

    # Forward pass used only to read the tokenized prompt length, so that
    # `max_length` below leaves room for `new_token_length` generated tokens.
    with model.invoke(prompts) as invoker:
        prompt_seq_len = max(len(invoker.input["input_ids"][i]) for i in range(n_completions))

    with model.generate(max_length=prompt_seq_len + new_token_length, **sampling_kwargs) as generator:
        print("Generating vectors")
        with generator.invoke(steering_prompts) as invoker:
            steer_seq_lens = [len(invoker.input["input_ids"][i]) for i in range(n_steers)]
            max_len = max(steer_seq_lens)

            mixer = model.backbone.layers[layer].mixer
            ssm_state = mixer.ssm.hx

            # One mixer-output vector per steering prompt, covering all of its positions.
            extracted_vectors = [mixer.output[batch, ...] for batch in range(n_steers)]

            # extracted_state_vectors[position][batch] holds that steering prompt's
            # hx output at that position, or None when the position is skipped.
            # Always append exactly one entry per batch row, so index
            # `steer_batch` stays aligned in the steering loop below.
            extracted_state_vectors = [[] for _ in range(max_len)]
            for position in tqdm(range(max_len)):
                for batch, prompt_len in enumerate(steer_seq_lens):
                    keep = position < prompt_len and (
                        not apply_only_at_end or position == prompt_len - 1
                    )
                    extracted_state_vectors[position].append(
                        ssm_state.output[batch, ...] if keep else None
                    )
                # Advance once per position, not once per batch row. This matches
                # the steering loop below, which reads/writes every batch row of a
                # single hx call before stepping.
                ssm_state = ssm_state.next()

        # Unsteered baseline: the same prompts with no interventions.
        with generator.invoke(prompts) as invoker:
            pass

        print("Steering step")
        with generator.invoke(prompts) as invoker:
            mixer = model.backbone.layers[layer].mixer
            ssm_state = mixer.ssm.hx
            target_lens = [len(invoker.input["input_ids"][b]) for b in range(n_completions)]

            # Mixer-output steering: right-align each steering vector with the end
            # of each prompt and add it position by position.
            for vec, coeff in zip(extracted_vectors, steering_coefficients):
                steer_len = len(vec)
                for prompt_batch in range(n_completions):
                    start = target_lens[prompt_batch] - steer_len
                    for i in range(steer_len):
                        mixer.output[prompt_batch, start + i, ...] += coeff * vec[i]

            # Skip hx calls for the prompt positions that come before the aligned
            # steering window. Uses the last prompt row and the last steering
            # vector, as the original did; with a padded batch all rows share one
            # length.
            offset = target_lens[-1] - len(extracted_vectors[-1])
            for _ in range(offset):
                ssm_state = ssm_state.next()

            # SSM-state steering: one hx call per aligned position.
            for position in tqdm(range(max_len)):
                for vec, coeff in zip(extracted_state_vectors[position], steering_coefficients):
                    if vec is None:
                        continue
                    scaled = coeff * vec
                    for prompt_batch in range(n_completions):
                        ssm_state.output[prompt_batch, ...] += scaled
                ssm_state = ssm_state.next()

    # generator.output rows are ordered by invocation:
    # [steering prompts | unsteered prompts | steered prompts].
    unsteered_completions = generator.output[n_steers:-n_completions]
    steered_completions = generator.output[-n_completions:]

    return unsteered_completions, steered_completions

In [71]:


num_prompts = 3
new_tokens = 30  # Number of new tokens to generate

# (layer, coefficient) pairs to sweep. The "love" prompt is applied with +coeff
# and the "hate" prompt with -coeff.
layer_coeff_sweep = [(16, 10), (24, 10), (40, 10), (60, 10)]
target_prompts = ["When I think of you I think of"] * num_prompts

for layer, coeff in layer_coeff_sweep:
    steering_prompts = ("I think I love you", "I think I hate you")
    steering_coefficients = (+coeff, -coeff)

    unsteered_completions, steered_completions = apply_steering_vector(
        mamba_model,
        layer,
        new_tokens,
        steering_prompts,
        steering_coefficients,
        target_prompts,
        apply_only_at_end=False,
    )

    unsteered_sents = tokenizer.batch_decode(unsteered_completions)
    steered_sents = tokenizer.batch_decode(steered_completions)
    print_steering_table(unsteered_sents, steered_sents, layer, coeff)

Generating vectors


100%|██████████| 5/5 [00:00<00:00, 18.20it/s]


Steering step


100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Generating vectors


100%|██████████| 5/5 [00:00<00:00, 18.15it/s]


Steering step


100%|██████████| 5/5 [00:01<00:00,  2.88it/s]


Generating vectors


100%|██████████| 5/5 [00:00<00:00, 18.41it/s]


Steering step


100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


Generating vectors


100%|██████████| 5/5 [00:00<00:00, 18.47it/s]


Steering step


100%|██████████| 5/5 [00:01<00:00,  2.91it/s]


In [10]:
# Show the module tree to locate hook points such as backbone.layers[i].mixer.ssm.hx.
mamba_model

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 2560)
    (layers): ModuleList(
      (0-63): 64 x Block(
        (mixer): MambaModuleInterp(
          (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
          (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
          (act): SiLU()
          (x_proj): Linear(in_features=5120, out_features=192, bias=False)
          (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
          (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
          (dt): WrapperModule()
          (B): WrapperModule()
          (C): WrapperModule()
          (ssm): SSM(
            (discA): DiscA()
            (discB): DiscB()
            (hx): Hx(
              (bx): Bx()
              (ah): Ah()
            )
            (yh): Yh()
          )
          (delta_softplus): Softplus(beta=1, threshold=20)
        )
        (norm): RMSNorm()
      )
   