In [1]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from jamba import JambaLMConfig as myJambaLMConfig, JambaLM as myJambaLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/Mini-Jamba", trust_remote_code=True,)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("TechxGenus/Mini-Jamba", torch_dtype=torch.float16, use_mamba_kernels=False, 
                                             device_map="auto", trust_remote_code=True)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config


In [3]:
prompt = "A Python function is"

inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(input_ids=inputs, max_new_tokens=64, do_sample=False)
tokenizer.decode(outputs[0], skip_special_tokens=True)

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


'A Python function isipv chambre Inn Iterate\n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n'

In [4]:
config_jamba = myJambaLMConfig(vocab_size=model.config.vocab_size, d_model=model.config.hidden_size, n_layers=model.config.num_hidden_layers, 
                               rms_norm_eps=model.config.rms_norm_eps, mlp_size=model.config.intermediate_size, inner_layernorms=model.config.mamba_inner_layernorms,
                               expand_factor=model.config.mamba_expand, dt_rank=model.config.mamba_dt_rank, d_state=model.config.mamba_d_state,
                               d_conv=model.config.mamba_d_conv, conv_bias=model.config.mamba_conv_bias, initializer_range=model.config.initializer_range,
                               num_experts=model.config.num_experts, num_experts_per_tok=model.config.num_experts_per_tok, 
                               attn_layer_offset=model.config.attn_layer_offset, attn_layer_period=model.config.attn_layer_period, 
                               expert_layer_offset=model.config.expert_layer_offset, expert_layer_period=model.config.expert_layer_period,
                               num_key_value_heads=model.config.num_key_value_heads, num_attention_heads=model.config.num_attention_heads,
                               pad_token_id=model.config.pad_token_id, bias=model.config.mamba_proj_bias, attention_dropout=model.config.attention_dropout,
                               tie_lm_weights=model.config.tie_word_embeddings)

my_model = myJambaLM(config_jamba)

for name, param in model.named_parameters():
    name = name.replace("model.", "jamba.")
    
    if "embed_tokens" in name:
        name = "embedding.weight"
    
    if "final_layernorm" in name:
        name = name.replace("jamba.", "")

    counterpart_param = my_model.get_parameter(name)
    if counterpart_param is not None:
        counterpart_param.data.copy_(param.data)

In [5]:
x = torch.randint(low=0, high=60, size=(10, 12))
torch.allclose(model(x).logits, my_model(x), atol=0.1)

False

In [6]:
model(x).logits

tensor([[[-19.3125, -19.3750,  26.2031,  ..., -19.3594, -19.4219, -19.2188],
         [-22.5312, -22.5156,  27.9219,  ..., -22.6094, -22.5781, -22.3594],
         [-21.9219, -21.8594,  28.4531,  ..., -21.9531, -21.9531, -21.7500],
         ...,
         [-26.5000, -26.4688,  30.5312,  ..., -26.6406, -26.5156, -26.4219],
         [-28.5625, -28.4844,  31.1250,  ..., -28.6875, -28.5312, -28.4844],
         [-28.9688, -28.8594,  31.4375,  ..., -29.1094, -28.9531, -28.8438]],

        [[-25.2656, -25.1719,  28.7344,  ..., -25.3125, -25.2656, -25.2031],
         [-27.7969, -27.6719,  30.7969,  ..., -27.9219, -27.8125, -27.7656],
         [-26.9375, -26.8438,  30.3750,  ..., -27.0938, -26.9531, -26.9062],
         ...,
         [-27.8125, -27.7031,  32.2812,  ..., -27.9219, -27.7969, -27.7188],
         [-28.0000, -27.9062,  32.3125,  ..., -28.0625, -27.9688, -27.9062],
         [-27.6562, -27.5625,  32.2500,  ..., -27.7031, -27.6250, -27.5781]],

        [[-22.2656, -22.2500,  26.8125,  ...

In [7]:
my_model(x)

tensor([[[-19.3141, -19.3836,  26.2029,  ..., -19.3653, -19.4294, -19.2251],
         [-22.5247, -22.5133,  27.9237,  ..., -22.6124, -22.5718, -22.3550],
         [-21.9077, -21.8545,  28.4431,  ..., -21.9459, -21.9527, -21.7437],
         ...,
         [-26.4991, -26.4622,  30.5308,  ..., -26.6301, -26.5195, -26.4149],
         [-28.5515, -28.4696,  31.1276,  ..., -28.6707, -28.5148, -28.4668],
         [-28.9560, -28.8569,  31.4315,  ..., -29.0955, -28.9422, -28.8281]],

        [[-25.2732, -25.1907,  28.7347,  ..., -25.3240, -25.2770, -25.2212],
         [-27.7995, -27.6675,  30.7923,  ..., -27.9272, -27.8126, -27.7696],
         [-26.9448, -26.8476,  30.3769,  ..., -27.0928, -26.9667, -26.9075],
         ...,
         [-27.8120, -27.7041,  32.2676,  ..., -27.9230, -27.7895, -27.7203],
         [-28.0004, -27.9071,  32.3101,  ..., -28.0653, -27.9692, -27.9188],
         [-27.6509, -27.5559,  32.2521,  ..., -27.7057, -27.6208, -27.5710]],

        [[-22.2703, -22.2582,  26.8164,  ...