In [2]:
import torch
import torch.nn.functional as F

from jamba import JambaLMConfig, JambaLM, from_pretrained, load_balancing_loss

from transformers import AutoModelForCausalLM

In [3]:
# Load the same checkpoint twice: once through the project-local `jamba`
# loader and once through Hugging Face Transformers, so the two
# implementations can be compared on identical inputs further down.
model = from_pretrained("TechxGenus/Mini-Jamba")

model_hf = AutoModelForCausalLM.from_pretrained(
    "TechxGenus/Mini-Jamba",
    torch_dtype=torch.float32,
    # Opt into the naive (non-kernel) Mamba implementation, as suggested by
    # the library warning emitted when the fast-path kernels are missing.
    use_mamba_kernels=False,
    device_map="auto",
    # NOTE(review): this executes model code shipped with the checkpoint
    # repository — only keep it enabled for sources you trust.
    trust_remote_code=True,
)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config


In [4]:
# Build a reproducible batch of random token ids. A dedicated, explicitly
# seeded generator is used so that re-running the notebook from a fresh kernel
# yields the same ids (and therefore the same losses) without touching the
# global RNG state.
SEED = 0
# ids are drawn from [0, 600); presumably 600 stays below the model's
# vocabulary size — TODO confirm against the checkpoint config.
data = torch.randint(0, 600, (1, 20), generator=torch.Generator().manual_seed(SEED))

# Next-token setup: `x` drops the last position and `y` drops the first, so
# y[:, i] is the id that immediately follows x[:, i] in `data`.
x = data[:, :-1]
y = data[:, 1:]

In [5]:
# Forward pass through the project-local model; the call is unpacked into the
# output logits and the router logits.
logits, router_logits = model(x)
# Auxiliary load-balancing loss computed from the router logits.
# NOTE(review): 8 and 2 are presumably the number of experts and the number of
# experts selected per token — confirm against `load_balancing_loss`'s
# signature and the model config.
aux_loss = load_balancing_loss(router_logits, 8, 2)

In [6]:
# Compare the Hugging Face model's logits with the project-local model's
# logits on the same input `x`.
# NOTE(review): rtol=0.1 is a loose relative tolerance — confirm it is
# intentional, or tighten it if the two implementations are expected to match
# more closely.
torch.allclose(model_hf(x).logits, logits, rtol=0.1)

True

In [7]:
# Cross-entropy between the project-local model's logits and the shifted
# targets. Both tensors are flattened first: logits become 2-D with the last
# dimension preserved, and targets become 1-D.
loss = F.cross_entropy(
    input=logits.view(-1, logits.size(-1)),
    target=y.view(-1),
)
loss

tensor(55.8649, grad_fn=<NllLossBackward0>)

In [8]:
# Backpropagate from `loss` (the cross-entropy of the project-local model's
# logits) to check that the backward pass runs.
loss.backward()

In [9]:
# Same cross-entropy as for the project-local model, but on the Hugging Face
# model's logits, so the two loss values can be compared.
# Run the forward pass once and reuse the result for both the reshape and the
# last-dimension lookup, instead of invoking the model a second time just to
# read a size.
logits_hf = model_hf(x).logits
loss_hf = F.cross_entropy(logits_hf.view(-1, logits_hf.size(-1)), y.view(-1))
loss_hf

tensor(55.8649, grad_fn=<NllLossBackward0>)

In [10]:
from transformers import AutoTokenizer

In [11]:
# Tokenizer for the same checkpoint as the models above.
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run — only keep it enabled for sources you trust.
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/Mini-Jamba", trust_remote_code=True)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [13]:
# Show the id of the end-of-sequence token (the same token was assigned as the
# pad token when the tokenizer was set up).
tokenizer.eos_token_id

2