In [1]:
from model import Mamba
import numpy as np
import jax
import jax.numpy as jnp

from transformers import AutoTokenizer
from utils import generate


model, params = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
out = generate(model, jax.random.PRNGKey(42), params, tokenizer, 'Mamba is the', n_tokens_to_gen=40, pad=False, pad_token_id=0, sample=True, top_k=40, do_jit=False, deterministic=False)
out

'Mamba is the only way you\'re gonna survive.  I ain\'t gonna get in the way of you gonna get away from me." "You ain\'t got no say." "You\'re on your own." "'

In [10]:
x = jnp.array([[1,2,3]])
x = model.apply(params, x)
x[:, :, :10]

Array([[[  2.4653006 ,  -8.554799  ,   2.6238873 ,   5.3640375 ,
          -1.8180041 ,   0.78379506,  -0.48186472,   0.54381496,
           5.0406475 ,   3.7135432 ],
        [  7.6108055 ,  -8.149187  ,   4.8119187 ,   5.1954722 ,
           3.9434814 ,   2.215487  ,   1.5003986 ,   3.3085713 ,
           3.557276  ,   2.6064222 ],
        [-12.345375  , -29.163307  , -15.7517395 , -16.207508  ,
         -15.9281025 , -19.773325  , -19.072887  , -17.463512  ,
         -17.213531  , -17.006496  ]]], dtype=float32)

In [1]:
import torch
from transformers import AutoTokenizer
from model_torch import Mamba as MambaTorch
model_torch = MambaTorch.from_pretrained('state-spaces/mamba-370m').cuda()
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_torch(torch.tensor([[1,2,3]]).cuda())[:, :, :10]

tensor([[[  2.4702,  -8.5650,   2.6273,   5.3573,  -1.8029,   0.7887,  -0.4877,
            0.5402,   5.0373,   3.7238],
         [  7.6153,  -8.1547,   4.8065,   5.1982,   3.9556,   2.2247,   1.5022,
            3.3063,   3.5594,   2.6049],
         [-12.3332, -29.1413, -15.7434, -16.1949, -15.9096, -19.7548, -19.0556,
          -17.4474, -17.2015, -17.0009]]], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [4]:
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    model.eval()
    
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
    
    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]
        
        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape
        
        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(axis=1, keepdims=True)
        
        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]
        
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
    
    return output_completions

print(generate(model_torch, tokenizer, 'John: Hi!\nSally:', sample=False, n_tokens_to_gen=40, top_k=None))

John: Hi!
Sally: Hi!
John: I'm John.
Sally: I'm Sally.
John: I'm John.
Sally: I'm Sally.
John: I'm John.



In [16]:
def step(x1, x2):
    return x1 + x2, x1 + x2

carry, ys = jax.lax.scan(step, init=0, xs=jnp.array([1, 2, 3, 4]))
# carry: 6
# ys: [2, 6, 12]

carry, ys 

(Array(10, dtype=int32), Array([ 1,  3,  6, 10], dtype=int32))

In [14]:
def combine(x1, x2):
    return x1 + x2

result = jax.lax.associative_scan(combine, jnp.array([1, 2, 3, 4]))
# result: [1, 3, 6, 10]
result

Array([ 1,  3,  6, 10], dtype=int32)

In [11]:
import jax
import jax.numpy as jnp

# Sequential implementation with jax.lax.scan
def cumulative_sum_scan(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry  # Update carry and output it
    _, ys = jax.lax.scan(step, 0, xs)  # Initial carry = 0
    return ys

# Parallel implementation with jax.lax.associative_scan
def cumulative_sum_associative_scan(xs):
    def combine(x1, x2):
        return x1 + x2
    return jax.lax.associative_scan(combine, xs)

# Example input
xs = jnp.array([1, 2, 3, 4, 5])

# Compute with both methods
result_scan = cumulative_sum_scan(xs)
result_associative_scan = cumulative_sum_associative_scan(xs)

# Print results
print("Input:", xs)
print("Cumulative sum with scan:", result_scan)
print("Cumulative sum with associative scan:", result_associative_scan)

# Verify equality
assert jnp.allclose(result_scan, result_associative_scan), "Results do not match!"


Input: [1 2 3 4 5]
Cumulative sum with scan: [ 1  3  6 10 15]
Cumulative sum with associative scan: [ 1  3  6 10 15]


In [7]:
import jax
import jax.numpy as jnp
from einops import einsum

# Standard scan implementation


def run_standard_scan(Ab, Bb_u, Cb):
    def step(x_k_1, inputs):
        Ab_k, Bb_u_k, Cb_k = inputs
        x_k = Ab_k * x_k_1 + Bb_u_k
        y_k = einsum(x_k, Cb_k, 'b d_in n, b n -> b d_in')
        return x_k, y_k

    x0 = jnp.zeros((B, D_in, n))
    _, ys = jax.lax.scan(step, x0, (Ab, Bb_u, Cb))
    return ys


def run_parallel_scan(Ab, Bb_u, Cb):
    # Associative operation
    def combine_parallel(state1, state2):
        # (a_1​,b_1​)⊕(a_2​,b_2​)=(a_1 * ​a_2​,a_2 *​ b_1​ + b_2​)
        fl, xl = state1
        fr, xr = state2
        f = fr * fl
        x = fr * xl + xr
        return f, x

    # Perform associative scan
    results = jax.lax.associative_scan(combine_parallel, (Ab, Bb_u))

    return einsum(results[1], Cb, 'l b d_in n, l b n -> l b d_in')

# Example inputs
B, L, D_in, n = 2, 5, 4, 3  # Batch, sequence length, input dim, model dim
Ab = jnp.arange(0, L* B* D_in* n).reshape((L, B, D_in, n)) / 100
Bb_u = jnp.arange(0, L* B* D_in* n).reshape((L, B, D_in, n))/ 100
Cb = jnp.arange(0, L* B* n).reshape((L, B, n)) / 100

# Compare outputs
standard_output = run_standard_scan(Ab, Bb_u, Cb)
parallel_output = run_parallel_scan(Ab, Bb_u, Cb)

# Verify if the outputs are close
is_close = jnp.allclose(standard_output, parallel_output, atol=1e-6)
print(f"Outputs are close: {is_close}")


Outputs are close: True


In [23]:
# Example inputs
B, L, D_in, n = 2, 1000000, 4, 3  # Batch, sequence length, input dim, model dim
Ab = jnp.arange(0, L* B* D_in* n).reshape((L, B, D_in, n)) / 100
Bb_u = jnp.arange(0, L* B* D_in* n).reshape((L, B, D_in, n))/ 100
Cb = jnp.arange(0, L* B* n).reshape((L, B, n)) / 100

# Compare outputs


In [24]:
%%timeit
standard_output = run_standard_scan(Ab, Bb_u, Cb)

5.6 s ± 61.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
%%timeit

parallel_output = run_parallel_scan(Ab, Bb_u, Cb)

53.6 ms ± 7.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
