In [1]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

from jamba import JambaLMConfig, JambaLM, MambaLayer, AttentionLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/Mini-Jamba", trust_remote_code=True,)
tokenizer.pad_token = tokenizer.eos_token

hf_model = AutoModelForCausalLM.from_pretrained("TechxGenus/Mini-Jamba", torch_dtype=torch.float16, use_mamba_kernels=False, 
                                             device_map="auto", trust_remote_code=True)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config


In [3]:
config_jamba = JambaLMConfig(vocab_size=hf_model.config.vocab_size, d_model=hf_model.config.hidden_size, n_layers=hf_model.config.num_hidden_layers, 
                               rms_norm_eps=hf_model.config.rms_norm_eps, mlp_size=hf_model.config.intermediate_size, inner_layernorms=hf_model.config.mamba_inner_layernorms,
                               expand_factor=hf_model.config.mamba_expand, dt_rank=hf_model.config.mamba_dt_rank, d_state=hf_model.config.mamba_d_state,
                               d_conv=hf_model.config.mamba_d_conv, conv_bias=hf_model.config.mamba_conv_bias, initializer_range=hf_model.config.initializer_range,
                               num_experts=hf_model.config.num_experts, num_experts_per_tok=hf_model.config.num_experts_per_tok, 
                               attn_layer_offset=hf_model.config.attn_layer_offset, attn_layer_period=hf_model.config.attn_layer_period, 
                               expert_layer_offset=hf_model.config.expert_layer_offset, expert_layer_period=hf_model.config.expert_layer_period,
                               num_key_value_heads=hf_model.config.num_key_value_heads, num_attention_heads=hf_model.config.num_attention_heads,
                               pad_token_id=hf_model.config.pad_token_id, bias=hf_model.config.mamba_proj_bias, attention_dropout=hf_model.config.attention_dropout,
                               tie_lm_weights=hf_model.config.tie_word_embeddings)

model = JambaLM(config_jamba)

for name, param in hf_model.named_parameters():
    name = name.replace("model.", "jamba.")
    
    if "embed_tokens" in name:
        name = "embedding.weight"
    
    if "final_layernorm" in name:
        name = name.replace("jamba.", "")

    counterpart_param = model.get_parameter(name)
    if counterpart_param is not None:
        counterpart_param.data.copy_(param.data)

"""
config_jamba = JambaLMConfig(d_model=64, n_layers=32, mlp_size=128, vocab_size=60, attn_layer_period=2, attn_layer_offset=1)
# 100, 4 = full mamba; 1, 0 = full attention
model = JambaLM(config_jamba)
model = model.eval()
"""

'\nconfig_jamba = JambaLMConfig(d_model=64, n_layers=32, mlp_size=128, vocab_size=60, attn_layer_period=2, attn_layer_offset=1)\n# 100, 4 = full mamba; 1, 0 = full attention\nmodel = JambaLM(config_jamba)\nmodel = model.eval()\n'

In [4]:
print(sum([isinstance(layer, MambaLayer) for layer in model.jamba.layers]))
print(sum([isinstance(layer, AttentionLayer) for layer in model.jamba.layers]))

8
8


In [5]:
def gen_no_caching(num_tokens):
    batch_size = 1
    
    input_ids = 1*torch.ones(batch_size, 1, dtype=torch.int64)

    for i in range(num_tokens):
        next_token_logits, _ = model(input_ids[:, 0:i+1])
        print(next_token_logits[:, [-1]].mean())
        probs = F.softmax(next_token_logits[:, [-1]], dim=-1)
        next_token = torch.argmax(probs, dim=-1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids

In [6]:
gen_no_caching(10)

tensor(4.8229, grad_fn=<MeanBackward0>)
tensor(36.0137, grad_fn=<MeanBackward0>)
tensor(40.9549, grad_fn=<MeanBackward0>)
tensor(42.1674, grad_fn=<MeanBackward0>)
tensor(33.2820, grad_fn=<MeanBackward0>)
tensor(4.1003, grad_fn=<MeanBackward0>)
tensor(-0.9146, grad_fn=<MeanBackward0>)
tensor(-3.4191, grad_fn=<MeanBackward0>)
tensor(-4.0525, grad_fn=<MeanBackward0>)
tensor(-6.5642, grad_fn=<MeanBackward0>)


tensor([[    1, 16249, 35307, 56791, 35307, 60128,  7007, 52545,  1554,  2534,
          3068]])

In [7]:
def gen(num_tokens):
    batch_size = 1

    caches = [model.jamba.layers[i].get_empty_cache(batch_size) for i in range(config_jamba.n_layers)]
    
    input_ids = 1*torch.ones(batch_size, 1, dtype=torch.int64)

    for i in range(num_tokens):
        next_token_logits, caches = model(input_ids[:, [i]], caches)
        print(next_token_logits.mean())
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.argmax(probs, dim=-1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids

In [8]:
gen(10)

tensor(4.8229, grad_fn=<MeanBackward0>)
tensor(36.0137, grad_fn=<MeanBackward0>)
tensor(40.9549, grad_fn=<MeanBackward0>)
tensor(42.1674, grad_fn=<MeanBackward0>)
tensor(33.2820, grad_fn=<MeanBackward0>)
tensor(4.1003, grad_fn=<MeanBackward0>)
tensor(-0.9146, grad_fn=<MeanBackward0>)
tensor(-3.4191, grad_fn=<MeanBackward0>)
tensor(-4.0525, grad_fn=<MeanBackward0>)
tensor(-6.5642, grad_fn=<MeanBackward0>)


tensor([[    1, 16249, 35307, 56791, 35307, 60128,  7007, 52545,  1554,  2534,
          3068]])