In [1]:
import torch
import torch.nn.functional as F

from jamba import JambaLMConfig, Jamba, AttentionLayer, MambaLayer

In [2]:
# Small configuration: a single layer with a 64-wide model dimension.
config = JambaLMConfig(
    d_model=64,
    mlp_size=12,
    n_layers=1,
    num_attention_heads=32,
    num_key_value_heads=8,
    attn_layer_offset=0,
)
# NOTE: despite its name, `layer` is a `Jamba` instance built from `config`
# (it exposes `.layers`), not one individual layer.
layer = Jamba(config)

In [3]:
# Random input with a leading axis of 1 and 10 positions along axis 1. The
# feature width is read from the config instead of repeating the literal 64,
# so it cannot drift from the model's d_model.
x = torch.randn(1, 10, config.d_model)

In [4]:
# Run the whole input through the model in a single call.
output = layer(x)
# Cell result: the shape of the full-sequence output.
output.size()

torch.Size([1, 10, 64])

In [5]:
# Index of the largest value along the last axis, one per position.
torch.argmax(output, dim=-1)

tensor([[46, 14, 47, 29, 50, 35, 36, 49, 20, 55]])

In [6]:
# Replay the same input one position at a time through `layer.step`, threading
# the caches returned by each call into the next one.
#
# Build one empty cache per layer rather than hard-coding layer 0, so this cell
# keeps working when `config.n_layers` is raised (the argument 1 is presumably
# the batch size -- confirm against `get_empty_cache`).
cache = [layer.layers[i].get_empty_cache(1) for i in range(config.n_layers)]
step_outputs = []
for i in range(x.shape[1]):
    # x[:, [i]] keeps axis 1, so each step receives a length-1 slice of x.
    y, cache = layer.step(x[:, [i]], cache)
    print(y[0].argmax(-1))
    step_outputs.append(y)

# Join the single-position outputs along axis 1. This assumes each `y` keeps
# the (batch, 1, features) layout of its input -- consistent with the printed
# shape-(1,) argmax above. Concatenating whole `y` tensors avoids dropping the
# batch axis via `y[0]`, which was only correct for a batch of one.
ys = torch.cat(step_outputs, dim=1)

tensor([46])
tensor([14])
tensor([47])
tensor([29])
tensor([50])
tensor([35])
tensor([36])
tensor([49])
tensor([20])
tensor([55])


In [7]:
# Element-wise closeness of the full-sequence output and the step-by-step
# outputs (rtol=0.01, default atol), reduced to a single Python bool.
# Argument order is kept: the tolerance is relative to the second tensor.
torch.isclose(output, ys, rtol=0.01).all().item()

True

In [8]:
# Start from fresh caches, one per layer, and step through the input again.
caches = [layer.layers[idx].get_empty_cache(1) for idx in range(config.n_layers)]

for i in range(10):
    step_input = x[:, [i]]  # length-1 slice along axis 1
    next_token_logits, caches = layer.step(step_input, caches)
    print(next_token_logits.mean())
    # Greedy choice from the softmax of this step's output. The result is not
    # fed back into the loop; the next input still comes from `x`.
    probs = next_token_logits.softmax(dim=-1)
    next_token = probs.argmax(dim=-1)

tensor(0.0433, grad_fn=<MeanBackward0>)
tensor(-0.1767, grad_fn=<MeanBackward0>)
tensor(0.2905, grad_fn=<MeanBackward0>)
tensor(-0.1766, grad_fn=<MeanBackward0>)
tensor(0.0499, grad_fn=<MeanBackward0>)
tensor(-0.1062, grad_fn=<MeanBackward0>)
tensor(-0.0894, grad_fn=<MeanBackward0>)
tensor(0.0102, grad_fn=<MeanBackward0>)
tensor(-0.1781, grad_fn=<MeanBackward0>)
tensor(-0.1762, grad_fn=<MeanBackward0>)


In [9]:
# Mean over axis 2 of the step-by-step outputs: one value per position.
ys.mean(dim=2)

tensor([[ 0.0433, -0.1767,  0.2905, -0.1766,  0.0499, -0.1062, -0.0894,  0.0102,
         -0.1781, -0.1762]], grad_fn=<MeanBackward1>)

In [10]:
# Mean over axis 2 of the full-sequence output: one value per position.
output.mean(dim=2)

tensor([[ 0.0433, -0.1767,  0.2905, -0.1766,  0.0499, -0.1062, -0.0894,  0.0102,
         -0.1781, -0.1762]], grad_fn=<MeanBackward1>)