In [1]:
import mamba_ssm
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy
from nnsight.models.Mamba import MambaInterp
from transformers import AutoTokenizer
import numpy as np
import torch as t
import torch.nn.functional as F
import einops
from tqdm import tqdm
from functools import partial

from rich import print as rprint
from rich.table import Table

from typing import List, Callable, Union

# Use GPU index 2 ("cuda:2") when CUDA is available; otherwise fall back to the CPU.
device = t.device("cpu" if not t.cuda.is_available() else "cuda:2")


In [2]:
from datasets import Dataset, DatasetDict, load_dataset

In [3]:
TOKENIZER_ID = "EleutherAI/gpt-neox-20b"
MODEL_ID = "state-spaces/mamba-2.8b"

# Tokenizer: left padding, with the EOS token reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model: nnsight's MambaInterp wrapper, placed on `device` and sharing the tokenizer above.
mamba_model = MambaInterp(MODEL_ID, device=device, tokenizer=tokenizer)

# Sampling options passed as **kwargs to `mamba_model.generate`
# (consumed in mamba_ssm/utils/generation.py).
sampling_kwargs = dict(
    top_p=0.2,
    top_k=0,
    repetition_penalty=1.1,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Sanity check: generate a short continuation of a factual prompt and decode it.
# sampling_kwargs presumably enable random sampling in mamba_ssm's generation (top_p / top_k),
# so seed torch's RNG (all devices) to make the decoded text reproducible across re-runs.
t.manual_seed(0)
with mamba_model.generate(
    "The Eiffel tower is located",
    max_length=20,
    scan=False,
    validate=False,
    **sampling_kwargs,
) as tracer:
    output = mamba_model.generator.output.save()
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded)

The Eiffel tower is located in Paris, France. It was built between 1889 and
19


In [15]:
# Display the submodule tree of the first block's mixer. The ssm.* entries (discA, discB, hx, yh)
# are the submodules whose .input/.output later cells read and overwrite.
mamba_model.backbone.layers[0].mixer

MambaModuleInterp(
  (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
  (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
  (act): SiLU()
  (x_proj): Linear(in_features=5120, out_features=192, bias=False)
  (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
  (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
  (dt): WrapperModule()
  (B): WrapperModule()
  (C): WrapperModule()
  (ssm): SSM(
    (discA): DiscA()
    (discB): DiscB()
    (hx): Hx(
      (bx): Bx()
      (ah): Ah()
    )
    (yh): Yh()
  )
  (delta_softplus): Softplus(beta=1, threshold=20)
)

In [38]:
# Shape of the generated token ids. Presumably (batch, max_length): [1, 20] for the single prompt with max_length=20.
output.shape

torch.Size([1, 20])

In [60]:
# Collect every layer's ssm.discA output while generating from the prompt.
# NOTE(review): the original appended `layer.mixer.ssm.discA.save()`, which saves the
# submodule itself rather than its `.output`. Other cells read the `.output` of these
# submodules (e.g. `layer.mixer.ssm.discA.output`). The original also raised
# "Accessing Proxy value before it's been set"; saving the module is the most likely
# cause. Re-run to confirm.
t.manual_seed(0)  # sampling is presumably stochastic; seed so the decoded text is reproducible
outputs = []
with mamba_model.generate(
    "The Eiffel tower is located",
    max_length=20,
    scan=False,
    validate=False,
    **sampling_kwargs,
) as tracer:
    for layer in mamba_model.backbone.layers:
        outputs.append(layer.mixer.ssm.discA.output.save())
    output = mamba_model.generator.output.save()

decoded = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded)

ValueError: Accessing Proxy value before it's been set.

In [52]:
# Ablation: zero the outputs of ssm.discA and ssm.discB in every layer, then compute the
# loss on the first 800 characters of the first `text` entry of `dataset`.
# NOTE(review): `dataset` and `compute_loss` are not defined in any cell shown here. They
# come from kernel state; define them in an earlier cell so Restart & Run All works.
prompt_text = dataset["text"][0][:800]
with mamba_model.invoke(prompt_text, scan=True) as invoker:

    for layer in mamba_model.backbone.layers:
        ssm = layer.mixer.ssm
        A_bar = ssm.discA.output
        dB = ssm.discB.output
        dA = ssm.dA.output
        # Zero-order-hold style expression (dB / dA) * (A_bar - 1). Only its shape matters
        # here, because the replacement written back to discB is all zeros.
        zoh_delta_B = dB / dA * (A_bar - 1)
        ssm.discB.output = t.zeros_like(zoh_delta_B)
        ssm.discA.output = t.zeros_like(A_bar)

output = invoker.output
loss = compute_loss(prompt_text, output.logits[0])
print(loss)

tensor(38.9072, device='cuda:2')


In [53]:
# Display the full model tree. The mixer's ssm now also lists a `dA` submodule, which the
# ablation cell reads; the earlier printout of layer 0's mixer did not include it.
mamba_model

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 2560)
    (layers): ModuleList(
      (0-63): 64 x Block(
        (mixer): MambaModuleInterp(
          (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
          (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
          (act): SiLU()
          (x_proj): Linear(in_features=5120, out_features=192, bias=False)
          (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
          (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
          (dt): WrapperModule()
          (B): WrapperModule()
          (C): WrapperModule()
          (ssm): SSM(
            (discA): DiscA()
            (discB): DiscB()
            (dA): DA()
            (hx): Hx(
              (bx): Bx()
              (ah): Ah()
            )
            (yh): Yh()
          )
          (delta_softplus): Softplus(beta=1, threshold=20)
        )
        (norm)

In [108]:
# Bypass conv1d in every layer. Each conv output is replaced by the conv's own input,
# right-padded with CONV_PAD zero columns so its length matches the conv output
# (kernel_size=4, padding=3 gives L + 3 positions).
# NOTE(review): this all-layer ablation reports a lower loss (22.78) than bypassing only
# layer 0 (42.15, next cell). That is surprising; double-check that the in-loop
# setitem is applied to every layer.
CONV_PAD = 3  # conv1d is built with padding=(3,) per the mixer printout
prompt_text = dataset["text"][0][:50]
with mamba_model.invoke(prompt_text, scan=True) as invoker:

    for layer in mamba_model.backbone.layers:
        conv = layer.mixer.conv1d
        conv_out = conv.output[0]
        # Presumably first positional arg, batch row 0 (assumes .input is (args, kwargs) -- TODO confirm)
        conv_in = conv.input[0][0][0]
        pad = t.zeros_like(conv_out)[:, :CONV_PAD]
        conv.output[0][..., :] = t.concat([conv_in, pad], dim=1)

output = invoker.output
loss = compute_loss(prompt_text, output.logits[0])
print(loss)

tensor(22.7835, device='cuda:2')


In [110]:
# Bypass conv1d in layer 0 only. Its output is replaced by the conv's own input,
# right-padded with CONV_PAD zero columns to match the conv output length.
CONV_PAD = 3  # conv1d is built with padding=(3,) per the mixer printout
prompt_text = dataset["text"][0][:50]
with mamba_model.invoke(prompt_text, scan=True) as invoker:

    first_conv = mamba_model.backbone.layers[0].mixer.conv1d
    conv_out = first_conv.output[0]
    # Presumably first positional arg, batch row 0 (assumes .input is (args, kwargs) -- TODO confirm)
    conv_in = first_conv.input[0][0][0]
    pad = t.zeros_like(conv_out)[:, :CONV_PAD]
    first_conv.output[0][..., :] = t.concat([conv_in, pad], dim=1)

output = invoker.output
loss = compute_loss(prompt_text, output.logits[0])
print(loss)

tensor(42.1457, device='cuda:2')


In [106]:
# Zero every layer's conv1d output and compute the loss on the first 50 characters.
# The original called `.save()` on each layer's conv output, but nothing reads those saved
# values; only the last layer's stays bound to the loop variable. Saving presumably just
# keeps one extra activation per layer alive after the trace. The zeros only need the
# output's shape, dtype and device, so no save is required (as in the conv-bypass cells).
prompt_text = dataset["text"][0][:50]
with mamba_model.invoke(prompt_text, scan=True) as invoker:

    for layer in mamba_model.backbone.layers:
        conv_out = layer.mixer.conv1d.output[0]
        layer.mixer.conv1d.output[0] = t.zeros_like(conv_out)

output = invoker.output
loss = compute_loss(prompt_text, output.logits[0])
print(loss)

tensor(176.0256, device='cuda:2')


In [82]:
# Logits for batch item 0 of the most recent invoke. Presumably (seq_len, vocab); 50280
# matches the embedding size in the model printout.
# NOTE(review): `invoker` is whatever invoke cell ran last, so this is fragile under Restart & Run All.
invoker.output.logits[0].shape

torch.Size([12, 50280])

In [86]:
# NOTE(review): `conv_input` is not defined in any cell shown here. It comes from earlier
# kernel state, so this cell fails under Restart & Run All. It is presumably a saved conv1d
# `.input` proxy; [0][0] then [0, :, 0] would pick batch 0, all channels, position 0. TODO confirm.
conv_input.value[0][0][0,:,0]

tensor([-1.2534, -0.4671, -0.5139,  ...,  0.6724, -0.2481,  0.2129],
       device='cuda:2')

In [85]:
# NOTE(review): `conv_output` is not defined in any cell shown here (hidden kernel state).
# It is presumably a saved conv1d `.output` proxy; [0][:, 0] would pick batch 0, all
# channels, position 0. TODO confirm.
conv_output.value[0][:,0]

tensor([ 0.0650,  0.0308,  0.0203,  ..., -0.0028,  0.2437, -0.0757],
       device='cuda:2')