In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules (e.g. the
# project's `src` package) are re-imported automatically before a cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the parent directory on the import path so the `src` package imported
# below resolves. This line has to run before any `src` import.
sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils

# Display the library / CUDA versions the notebook was run with.
torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.36.2', '12.1')

In [1]:
from src.models import ModelandTokenizer

# NOTE(review): the saved output of this cell is
# "ModuleNotFoundError: No module named 'src'" and its execution count is 1,
# so it appears to have been run on a fresh kernel before the cell that appends
# "../" to sys.path. Re-run the notebook top to bottom before trusting outputs.

# Checkpoint to load; "state-spaces/mamba-2.8b" is the alternative the author noted.
MODEL_PATH = "state-spaces/mamba-2.8b-slimpj"

# `mt` bundles the model and tokenizer; later cells read mt.model, mt.tokenizer,
# mt.device and mt.n_layer from it.
mt = ModelandTokenizer(model_path=MODEL_PATH, torch_dtype=torch.float32)

ModuleNotFoundError: No module named 'src'

In [9]:
# The subject string fills the `{}` slot of the template to form the query.
subject = "The Space Needle"
prompt_template = "{} is located in the city of"

# NOTE(review): maybe_prefix_eos presumably prepends an EOS token when the model
# expects one -- confirm in src.tokens.
prompt = tokenization_utils.maybe_prefix_eos(
    mt, prompt_template.format(subject)
)

# Next-token prediction for the prompt; the cell displays the returned candidates.
functional.predict_next_token(mt=mt, prompt=prompt)

[[PredictedToken(token=' Seattle', prob=0.9798887372016907),
  PredictedToken(token=' Se', prob=0.0017078507225960493),
  PredictedToken(token=' the', prob=0.0015009533381089568),
  PredictedToken(token=' Sea', prob=0.0008902765694074333),
  PredictedToken(token=' se', prob=0.0006061139283701777)]]

In [10]:
from src.hooking.mamba import selective_scan_with_mask

# Tokenize the prompt, asking for character offsets so the subject can be
# mapped from a substring to token positions.
tokenized = mt.tokenizer(
    prompt,
    return_tensors="pt",
    padding=True,
    return_offsets_mapping=True,
)
tokenized = tokenized.to(mt.device)

# The offsets are taken out of `tokenized` and used only for the lookup below.
offsets = tokenized.pop("offset_mapping")

subj_first, subj_end = functional.find_token_range(
    string=prompt,
    substring=subject,
    offset_mapping=offsets[0],
)
# NOTE(review): presumably find_token_range returns an exclusive end index, which
# is why it is decremented to get the last subject token -- confirm in src.functional.
subj_last = subj_end - 1

# Index of the final token of the prompt.
last_idx = tokenized.input_ids.shape[1] - 1

In [26]:
import types

# Module-name template: filled with a layer index and handed to
# baukit.get_module to fetch that layer's mixer block.
mamba_block_format = "layers.{}.mixer"


def mask_layers(model, layer_indices, mask_config):
    """Replace `selective_scan` on the chosen mixer blocks with a masked version.

    Args:
        model: root module; each block is looked up with ``baukit.get_module``
            using the name produced by ``mamba_block_format``.
        layer_indices: iterable of layer numbers whose block gets patched.
        mask_config: forwarded unchanged as ``mask=`` to
            ``selective_scan_with_mask``. The call sites in this notebook pass a
            dict mapping a target token index to a list of source token indices.

    Returns:
        None. Each selected block is modified in place: its ``selective_scan``
        attribute becomes a bound method that delegates to
        ``selective_scan_with_mask``.
    """

    def scan_with_mask(self, u, delta, A, B, C, D):
        # Same arguments as the scan being replaced, plus the captured mask.
        return selective_scan_with_mask(
            self=self, u=u, delta=delta, A=A, B=B, C=C, D=D, mask=mask_config
        )

    for idx in layer_indices:
        module_name = mamba_block_format.format(idx)
        mixer = baukit.get_module(model, module_name)
        # Bind to this instance so `self` inside the scan is the mixer itself.
        mixer.selective_scan = types.MethodType(scan_with_mask, mixer)

# NOTE(review): reset_forward() is called before patching, presumably to restore
# the original selective_scan on every block -- confirm in src.models.
mt.reset_forward()

# For each position idx, lists every position from 0 up to and including idx.
# Only used by the "mask_everything" alternative listed below.
mask_everything = {idx: list(range(idx + 1)) for idx in range(last_idx + 1)}

# Positions of the subject tokens, and of every token that follows the subject.
subject_span = range(subj_first, subj_last + 1)
query_span = range(subj_last + 1, last_idx + 1)

# Active setting: the entire query can't see the entire subject.
# Each key gets its own list object.
query_cannot_see_subject = {
    target_idx: list(subject_span) for target_idx in query_span
}

# Alternatives the author tried for mask_config:
#   {last_idx: []}                           -> no mask
#   {last_idx: [subj_last]}                  -> last token can't see subj_last
#   {last_idx: list(subject_span)}           -> last token can't see the entire subject
#   {t: [subj_last] for t in query_span}     -> the entire query can't see subj_last
#   mask_everything
# Alternative tried for layer_indices: [20] (a single layer instead of all).
mask_layers(
    model=mt.model,
    layer_indices=list(range(mt.n_layer)),
    mask_config=query_cannot_see_subject,
)

In [28]:
# mt.reset_forward()   # left disabled by the author

# Author's notes on this experiment:
# ! Ablating only the diagonal SSM is not enough.
# TODO: figure out how to ablate the shift-SSM / the conv as well.
# The "attention" visualization is also wrong, because it does not take the
# conv into account: technically the SSM does not attend to one particular
# token, it attends to the entire receptive field.

# Same prediction call as the earlier baseline cell, run after mask_layers.
functional.predict_next_token(mt=mt, prompt=prompt)

deltaA.shape=torch.Size([1, 11, 5120, 16]) | deltaB_u.shape=torch.Size([1, 11, 5120, 16])
y.norm()=tensor(3.4754, device='cuda:0')
||y_0|| = 2.2530975341796875 | IS IT ZERO: False | max = 0.6196796894073486 | min = -0.9921875596046448
||y_1|| = 0.28607794642448425 | IS IT ZERO: False | max = 0.09646371006965637 | min = -0.08225928992033005
||y_2|| = 0.8664557933807373 | IS IT ZERO: False | max = 0.27339398860931396 | min = -0.2156476527452469
||y_3|| = 1.9508823156356812 | IS IT ZERO: False | max = 0.30057957768440247 | min = -0.4940722584724426
||y_4|| = 0.9365155696868896 | IS IT ZERO: False | max = 0.10947712510824203 | min = -0.2693585455417633
||y_5|| = 0.2292548418045044 | IS IT ZERO: False | max = 0.03318154439330101 | min = -0.14536695182323456
||y_6|| = 0.6427634954452515 | IS IT ZERO: False | max = 0.1837357133626938 | min = -0.1968754678964615
||y_7|| = 0.32628703117370605 | IS IT ZERO: False | max = 0.07218272984027863 | min = -0.06766912341117859
||y_8|| = 0.25064778327941

[[PredictedToken(token=' Seattle', prob=0.9744139909744263),
  PredictedToken(token=' the', prob=0.0027675111778080463),
  PredictedToken(token=' Se', prob=0.0017259021988138556),
  PredictedToken(token=' Pier', prob=0.0008594307000748813),
  PredictedToken(token=' se', prob=0.0006244106334634125)]]

In [185]:
from src.utils import experiment_utils

# NOTE(review): assumes set_seed seeds torch's global generator -- confirm in src.utils.
experiment_utils.set_seed(123456)

# Sizes of the random inputs for a standalone call of the masked scan.
n_batch, n_tokens, n_channels, n_state = 1, 4, 5120, 16

# Keep this creation order: every randn call advances the random generator.
u = torch.randn(n_batch, n_tokens, n_channels)
delta = torch.randn(n_batch, n_tokens, n_channels)
A = torch.randn(n_channels, n_state)
B = torch.randn(n_batch, n_tokens, n_state)
C = torch.randn(n_batch, n_tokens, n_state)
D = torch.randn(n_channels)

# Each position t lists itself and all earlier positions:
# {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [0, 1, 2, 3]}.
# Alternative the author tried: {0: []}.
full_mask = {t: list(range(t + 1)) for t in range(n_tokens)}

output = selective_scan_with_mask(
    self=None, u=u, delta=delta, A=A, B=B, C=C, D=D, mask=full_mask
)

output.shape

deltaA.shape=torch.Size([1, 4, 5120, 16]) | deltaB_u.shape=torch.Size([1, 4, 5120, 16])
delta_A_src_to_target.shape=torch.Size([1, 5120, 16])
delta_B_src.shape=torch.Size([1, 5120, 16])
subtracting src_idx=0 from target_idx=0 >> 289.8774719238281
delta_A_src_to_target.shape=torch.Size([1, 5120, 16])
delta_B_src.shape=torch.Size([1, 5120, 16])
subtracting src_idx=0 from target_idx=1 >> 3233.42333984375
delta_A_src_to_target.shape=torch.Size([1, 5120, 16])
delta_B_src.shape=torch.Size([1, 5120, 16])
subtracting src_idx=1 from target_idx=1 >> 182.55360412597656
delta_A_src_to_target.shape=torch.Size([1, 5120, 16])
delta_B_src.shape=torch.Size([1, 5120, 16])
subtracting src_idx=0 from target_idx=2 >> 127030.8671875
delta_A_src_to_target.shape=torch.Size([1, 5120, 16])
delta_B_src.shape=torch.Size([1, 5120, 16])
subtracting src_idx=1 from target_idx=2 >> 23893.58203125
delta_A_src_to_target.shape=torch.Size([1, 5120, 16])
delta_B_src.shape=torch.Size([1, 5120, 16])
subtracting src_idx=2 fro

torch.Size([1, 4, 5120])

In [202]:
# Check of the src == target case: the slice below is then empty along dim 1,
# and a product over an empty dimension is all ones.
delta_A = torch.randn((1, 4, 5120, 16))
src_idx = target_idx = 2

steps_after_src = delta_A[:, src_idx + 1 : target_idx + 1]
delta_A_src_to_target = steps_after_src.prod(dim=1)
delta_A_src_to_target.shape

torch.Size([1, 5120, 16])

In [203]:
# Random tensor with the same shape as delta_A_src_to_target from the previous cell.
delta_B_src = torch.randn((1, 5120, 16))

In [204]:
# Scale delta_B_src elementwise by the src->target product computed above.
delta_AB_src = delta_B_src * delta_A_src_to_target

# True when the scaling left delta_B_src (numerically) unchanged.
torch.allclose(delta_B_src, delta_AB_src)

True

In [5]:
# Display the module tree; it shows the submodule names (layers -> mixer) that
# the "layers.{}.mixer" lookups in this notebook rely on.
mt.model

Mamba(
  (embedding): Embedding(50280, 2560)
  (layers): ModuleList(
    (0-63): 64 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
        (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
        (x_proj): Linear(in_features=5120, out_features=192, bias=False)
        (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
        (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=2560, out_features=50280, bias=False)
)

In [15]:
# Fetch the mixer of layer 10 by its dotted module name, for inspection.
block = baukit.get_module(mt.model, "layers.10.mixer")

In [17]:
# Shape of this block's A_log; the saved output (5120, 16) matches the shape of
# the random `A` used in the standalone scan cell.
block.A_log.shape

torch.Size([5120, 16])