In [1]:
from nnsight.models.Mamba import MambaInterp
from nnsight import LanguageModel
import torch
from torch import nn
from transformers import MambaForCausalLM
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
mamba_sight = LanguageModel("state-spaces/mamba-2.8b-hf",tokenizer=tokenizer)
mamba = MambaForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.51it/s]


## ICL

In [11]:
input = "cat: cat, dog: dog, tree:tree, window:"
input_ids = tokenizer.encode(input, return_tensors="pt")
output = mamba.generate(input_ids, max_length=60, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))

cat: cat, dog: dog, tree:tree, window: window,
        door: door, lightbulb: lightbulb, mouse: mouse, elephant: elephant,
        turtle: turtle, catfish: catfish, monkey: monkey, pig: pig,



In [12]:
input = "cat: dog, tree:window, pyramid:sand, cat:"
input_ids = tokenizer.encode(input, return_tensors="pt")
output = mamba.generate(input_ids, max_length=30, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))

cat: dog, tree:window, pyramid:sand, cat: dog, and so on.

In this paper, we extend the classic


In [13]:
input = "absfads: absfads, daksfjksadshajsd: daksfjksadshajsd, tffassdwe:tffassdwe, tffassdwe:"
input_ids = tokenizer.encode(input, return_tensors="pt")
output = mamba.generate(input_ids, max_length=60, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))

absfads: absfads, daksfjksadshajsd: daksfjksadshajsd, tffassdwe:tffassdwe, tffassdwe:tffassdwe
<kennyloggins> i


In [6]:
input = "1: absfads, 2: daksfjksadshajsd, 3:tffassdwe, 1:"
input_ids = tokenizer.encode(input, return_tensors="pt")
output = mamba.generate(input_ids, max_length=60, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))

1: absfads, 2: daksfjksadshajsd, 3:tffassdwe, 1:sdsaafsad, 3:adfkjskd, 1:sadsfadshj, 2:sjkaskdksad


# Ablations

In [39]:
mamba

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(104, 512)
    (layers): ModuleList(
      (0-15): 16 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
          (act): SiLU()
          (in_proj): Linear(in_features=512, out_features=2048, bias=False)
          (x_proj): Linear(in_features=1024, out_features=112, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=512, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=512, out_features=104, bias=False)
  (generator): WrapperModule()
)

In [42]:
def SSM(discrete_A, deltaB_u, C, seq_len):
    hidden_dimension = 2560
    state_size = 16
    scan_outputs = []
    ssm_state = torch.zeros((1, hidden_dimension*2, state_size))
    ssm_states = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
        scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])
        ssm_states.append(ssm_state.save())
    return torch.stack(scan_outputs, dim=-1),ssm_states   

## No Ablating

In [117]:
state_size = 16
dt_size = 160
seq_len = 2
with mamba_sight.generate("Once upon",max_new_tokens=20, do_sample=True, top_k=50, top_p=0.95) as tracer:
    seq_len = 2
    ssm_parameters = mamba_sight.backbone.layers[0].mixer.x_proj.output
    projected_states = mamba_sight.backbone.layers[0].mixer.in_proj.output.transpose(1, 2)
    projected_states.save()
    a , gate = projected_states.chunk(2, dim=1)
    
    timestep,B,C = torch.split(ssm_parameters,[dt_size,state_size,state_size],dim=-1)
    discrete_time_step = mamba_sight.backbone.layers[0].mixer.dt_proj.output                            
    discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) 

    hidden_states = mamba_sight.backbone.layers[0].mixer.act(mamba_sight.backbone.layers[0].mixer.conv1d.output[..., :seq_len])
    A = -torch.exp(mamba_sight.backbone.layers[0].mixer.A_log)
    discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
    discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediade_size, seq_len, ssm_state_size]
    deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
    ssm,ssm_state = SSM(discrete_A, deltaB_u, C, seq_len)
    ssm = ssm+hidden_states*mamba_sight.backbone.layers[0].mixer.D[None, :, None]
    ssm.save()
    gated_ssm = ssm*mamba_sight.backbone.layers[0].mixer.act(gate)
    gated_ssm.save()
    input = mamba_sight.backbone.layers[0].mixer.out_proj.input[0][0].save()
    output = mamba_sight.generator.output.save()    

print(tokenizer.decode(output[0], skip_special_tokens=True))


#print(more.shape)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Once upon a time, the name of any given person could be changed and nobody knew that the name had been


In [81]:
original_input = input
original_gated = gated_ssm
original_ssm = ssm

## Ablation 1: Unitary SSM

In [115]:
state_size = 16
dt_size = 160
with torch.no_grad():
    with mamba_sight.generate("Once upon",max_new_tokens=20) as tracer:
        seq_len = mamba_sight.backbone.layers[0].mixer.in_proj.output.shape[1]
        
        
        #ssm=1
        for i in range(29,30):
            ssm_parameters = mamba_sight.backbone.layers[0].mixer.x_proj.output
            projected_states = mamba_sight.backbone.layers[0].mixer.in_proj.output.transpose(1, 2)
            projected_states.save()
            a , gate = projected_states.chunk(2, dim=1)
            
            timestep,B,C = torch.split(ssm_parameters,[dt_size,state_size,state_size],dim=-1)
            discrete_time_step = mamba_sight.backbone.layers[0].mixer.dt_proj.output                            
            discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) 

            hidden_states = mamba_sight.backbone.layers[0].mixer.act(mamba_sight.backbone.layers[0].mixer.conv1d.output[..., :seq_len])
            A = -torch.exp(mamba_sight.backbone.layers[0].mixer.A_log)
            discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
            discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # [batch, intermediade_size, seq_len, ssm_state_size]
            deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
            ssm,ssm_state = SSM(discrete_A, deltaB_u, C, seq_len)
            ssm = ssm+hidden_states*mamba_sight.backbone.layers[0].mixer.D[None, :, None]
            
            gated_ssm = ssm*mamba_sight.backbone.layers[i].mixer.act(gate)
            gated_ssm.save()
            gated_ssm = gated_ssm.transpose(1, 2)
            mamba_sight.backbone.layers[i].mixer.out_proj.input[0][0][:] = gated_ssm[0]
            input = mamba_sight.backbone.layers[i].mixer.out_proj.input[0][0]
        input.save()
        output = mamba_sight.generator.output.save()

print(tokenizer.decode(output[0], skip_special_tokens=True))


#print(more.shape)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Once upon a time, there was a little girl who was very, very, very, very, very,


In [83]:
input

tensor([[-0.2715, -0.2718,  0.1379,  ..., -0.1511, -0.0687, -0.2392],
        [-0.1506, -0.1783, -0.2216,  ...,  0.1278, -0.2739,  0.1495],
        [ 0.2125,  1.5953,  0.1440,  ..., -0.2545,  0.5127,  0.2982],
        ...,
        [-0.2784, -0.2781, -0.1417,  ..., -0.2575, -0.2784, -0.2480],
        [ 0.2984, -0.1127, -0.2026,  ..., -0.0838,  0.0210,  0.1993],
        [-0.2772, -0.2221, -0.2215,  ..., -0.2477, -0.2764, -0.2683]])

In [86]:
original_input

tensor([[[ 3.8533e-03, -6.0776e-04, -2.9002e-04,  ..., -2.7613e-04,
          -9.6215e-04, -6.7582e-05],
         [ 2.8024e-03,  5.5461e-04, -4.3667e-03,  ...,  2.7560e-04,
          -2.1598e-03,  1.6535e-04],
         [-4.1623e-03, -5.8760e-03, -7.8286e-03,  ..., -4.2756e-02,
          -4.6123e-03,  1.5350e-03],
         ...,
         [ 2.7521e-03, -2.8168e-03,  1.1099e-03,  ...,  2.8431e-04,
          -1.6862e-03,  1.8183e-04],
         [-8.8093e-03,  3.3298e-04, -5.7871e-03,  ...,  4.7501e-04,
           6.1060e-04, -3.2732e-04],
         [ 3.1671e-03, -2.1554e-03, -4.9180e-04,  ...,  2.3122e-03,
          -5.3262e-03,  1.7630e-04]]])

## Ablate all

In [98]:
state_size = 16
dt_size = 160
with mamba_sight.generate("The Eiffel Tower is in the city of",max_new_tokens=10) as tracer:
    seq_len = mamba_sight.backbone.layers[0].mixer.in_proj.output.shape[1]
    
    ssm_parameters = mamba_sight.backbone.layers[0].mixer.x_proj.output
    projected_states = mamba_sight.backbone.layers[0].mixer.in_proj.output.transpose(1, 2)
    projected_states.save()
    a , gate = projected_states.chunk(2, dim=1)
    
    ssm=0
    gated_ssm = ssm*mamba_sight.backbone.layers[0].mixer.act(gate)
    gated_ssm.save()
    #print(gated_ssm.shape)
    gated_ssm = gated_ssm.transpose(1, 2)
    for i in range(64):
        mamba_sight.backbone.layers[i].mixer.out_proj.input[0][0][:] = gated_ssm[0]
    input = mamba_sight.backbone.layers[63].mixer.out_proj.input[0][0].save()
    output = mamba_sight.generator.output.save()

print(tokenizer.decode(output[0], skip_special_tokens=True))


#print(more.shape)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


The Eiffel Tower is in the city of of the city.

The city is also


In [97]:
input

tensor([[[-0., -0., 0.,  ..., -0., -0., -0.],
         [-0., -0., -0.,  ..., 0., -0., 0.],
         [0., 0., 0.,  ..., -0., 0., 0.],
         ...,
         [-0., -0., -0.,  ..., -0., -0., -0.],
         [0., -0., -0.,  ..., -0., 0., 0.],
         [-0., -0., -0.,  ..., -0., -0., -0.]]])