In [7]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

from jamba import JambaLMConfig, JambaLM

In [8]:
tokenizer = AutoTokenizer.from_pretrained("TechxGenus/Mini-Jamba", trust_remote_code=True,)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("TechxGenus/Mini-Jamba", torch_dtype=torch.float16, use_mamba_kernels=False, 
                                             device_map="auto", trust_remote_code=True)

In [9]:
config_jamba = JambaLMConfig(vocab_size=model.config.vocab_size, d_model=model.config.hidden_size, n_layers=model.config.num_hidden_layers, 
                               rms_norm_eps=model.config.rms_norm_eps, mlp_size=model.config.intermediate_size, inner_layernorms=model.config.mamba_inner_layernorms,
                               expand_factor=model.config.mamba_expand, dt_rank=model.config.mamba_dt_rank, d_state=model.config.mamba_d_state,
                               d_conv=model.config.mamba_d_conv, conv_bias=model.config.mamba_conv_bias, initializer_range=model.config.initializer_range,
                               num_experts=model.config.num_experts, num_experts_per_tok=model.config.num_experts_per_tok, 
                               attn_layer_offset=model.config.attn_layer_offset, attn_layer_period=model.config.attn_layer_period, 
                               expert_layer_offset=model.config.expert_layer_offset, expert_layer_period=model.config.expert_layer_period,
                               num_key_value_heads=model.config.num_key_value_heads, num_attention_heads=model.config.num_attention_heads,
                               pad_token_id=model.config.pad_token_id, bias=model.config.mamba_proj_bias, attention_dropout=model.config.attention_dropout,
                               tie_lm_weights=model.config.tie_word_embeddings)

model = JambaLM(config_jamba)

for name, param in model.named_parameters():
    name = name.replace("model.", "jamba.")
    
    if "embed_tokens" in name:
        name = "embedding.weight"
    
    if "final_layernorm" in name:
        name = name.replace("jamba.", "")

    counterpart_param = model.get_parameter(name)
    if counterpart_param is not None:
        counterpart_param.data.copy_(param.data)

model = model.eval()

In [10]:
def gen_no_caching(num_tokens):
    batch_size = 1
    
    input_ids = 2*torch.ones(batch_size, 1, dtype=torch.int64)

    for i in range(num_tokens):
        next_token_logits, _ = model(input_ids[:, 0:i+1])
        probs = F.softmax(next_token_logits[:, [-1]], dim=-1)
        next_token = torch.argmax(probs, dim=-1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids

In [11]:
gen_no_caching(4)

tensor([[    2,  9421, 24881, 24881, 10872]])

In [14]:
def gen(num_tokens):
    batch_size = 1

    caches = [model.jamba.layers[i].get_empty_cache(batch_size) for i in range(config_jamba.n_layers)]
    
    input_ids = 2*torch.ones(batch_size, 1, dtype=torch.int64)

    for i in range(num_tokens):
        next_token_logits, caches = model(input_ids[:, [i]], caches)
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.argmax(probs, dim=-1)

        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids

In [15]:
gen(4)

tensor([[    2,  9421, 33330,  9421, 33330]])