In [8]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

from jamba import JambaLMConfig as myJambaLMConfig, JambaLM as myJambaLM

In [2]:
tokenizer = AutoTokenizer.from_pretrained(
    "TechxGenus/Mini-Jamba",
    trust_remote_code=True,
)
# Reuse EOS as the padding token (presumably the checkpoint's tokenizer has no pad token of its own — verify).
tokenizer.pad_token = tokenizer.eos_token

# Reference HF implementation, loaded in fp16. use_mamba_kernels=False requests the
# naive (non-fused) Mamba implementation instead of the optional CUDA kernels.
model = AutoModelForCausalLM.from_pretrained(
    "TechxGenus/Mini-Jamba",
    torch_dtype=torch.float16,
    use_mamba_kernels=False,
    device_map="auto",
    trust_remote_code=True,
)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config


In [4]:
prompt = "A Python function is"

# `encode(..., return_tensors="pt")` builds the ids on the CPU, while device_map="auto"
# may have placed the model on a GPU; move the ids to the model's device before generating.
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(input_ids=inputs, max_new_tokens=64, do_sample=False)
tokenizer.decode(outputs[0], skip_special_tokens=True)

'A Python function isipv chambre Inn Iterate\n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n        \n'

In [7]:
# Mirror the HF Jamba config into our own config class (our field name = HF attribute).
config_jamba = myJambaLMConfig(
    vocab_size=model.config.vocab_size,
    d_model=model.config.hidden_size,
    n_layers=model.config.num_hidden_layers,
    rms_norm_eps=model.config.rms_norm_eps,
    mlp_size=model.config.intermediate_size,
    # Mamba mixer settings (HF prefixes these with "mamba_")
    inner_layernorms=model.config.mamba_inner_layernorms,
    expand_factor=model.config.mamba_expand,
    dt_rank=model.config.mamba_dt_rank,
    d_state=model.config.mamba_d_state,
    d_conv=model.config.mamba_d_conv,
    conv_bias=model.config.mamba_conv_bias,
    initializer_range=model.config.initializer_range,
    # mixture-of-experts and the attention / expert layer schedule
    num_experts=model.config.num_experts,
    num_experts_per_tok=model.config.num_experts_per_tok,
    attn_layer_offset=model.config.attn_layer_offset,
    attn_layer_period=model.config.attn_layer_period,
    expert_layer_offset=model.config.expert_layer_offset,
    expert_layer_period=model.config.expert_layer_period,
    # attention heads
    num_key_value_heads=model.config.num_key_value_heads,
    num_attention_heads=model.config.num_attention_heads,
    pad_token_id=model.config.pad_token_id,
    bias=model.config.mamba_proj_bias,  # our `bias` maps to HF's Mamba projection bias
    attention_dropout=model.config.attention_dropout,
    tie_lm_weights=model.config.tie_word_embeddings,
)

my_model = myJambaLM(config_jamba)

# Copy every HF parameter into the matching parameter of `my_model`.
# Name mapping: HF's "model." prefix becomes "jamba."; the token embedding maps to the
# top-level "embedding.weight" and the final norm lives at the top level of `my_model`.
with torch.no_grad():
    for name, param in model.named_parameters():
        name = name.replace("model.", "jamba.", 1)  # rewrite only the prefix

        if "embed_tokens" in name:
            name = "embedding.weight"

        if "final_layernorm" in name:
            name = name.replace("jamba.", "", 1)

        # get_parameter raises AttributeError for an unknown name (it never returns None),
        # so an unmapped HF parameter stops the cell here instead of being skipped.
        counterpart_param = my_model.get_parameter(name)

        # copy_ would silently broadcast a smaller tensor into a larger one; require an exact match.
        if counterpart_param.shape != param.shape:
            raise ValueError(
                f"shape mismatch for {name}: HF {tuple(param.shape)} vs ours {tuple(counterpart_param.shape)}"
            )

        # copy_ also handles the fp16 -> fp32 cast and any device difference.
        counterpart_param.copy_(param)

In [75]:
def gen_no_caching(num_tokens, model=None, start_token_id=1, batch_size=1):
    """Greedily decode ``num_tokens`` tokens, re-running the whole sequence at every step.

    No cache is used: each step feeds every token produced so far through the model
    and appends the argmax of the logits at the last position.

    Args:
        num_tokens: number of tokens to generate.
        model: module called as ``model(input_ids) -> (logits, _)`` where logits are
            indexed as (batch, seq, vocab). Defaults to the notebook global ``my_model``
            (looked up at call time).
        start_token_id: id of the seed token every sequence starts from (1 by default,
            presumably the BOS id — confirm against the tokenizer).
        batch_size: number of (identical) sequences to decode.

    Returns:
        int64 tensor of shape (batch_size, num_tokens) with the generated ids; the seed
        token is not included. The tensor lives on the model's device.

    The model's train/eval mode is restored on exit.
    """
    if model is None:
        model = my_model  # notebook global

    device = next(model.parameters()).device
    input_ids = torch.full((batch_size, 1), start_token_id, dtype=torch.int64, device=device)

    was_training = model.training
    model.eval()  # inference: disable dropout
    try:
        with torch.no_grad():  # no autograd graph for pure generation
            for _step in range(num_tokens):
                logits, _ = model(input_ids)
                # argmax over raw logits equals argmax over their softmax
                next_token = logits[:, -1:].argmax(dim=-1)  # (batch_size, 1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
    finally:
        model.train(was_training)

    return input_ids[:, 1:]

def generate(self, tokenizer, prompt: str, num_tokens: int = 50, batch_size: int = 1, sample: bool = True, top_k: int = 40, temperature: float = 1.0):
    """Autoregressively continue ``prompt`` with the model ``self`` (no caching).

    Each step re-runs the model on the full sequence so far and picks the next
    token from the logits at the last position.

    Args:
        self: model called as ``self(input_ids) -> (logits, _)``; logits are
            indexed as (batch, seq, vocab).
        tokenizer: must support ``tokenizer(prompt, return_tensors='pt').input_ids``
            and ``tokenizer.decode(list_of_ids)``.
        prompt: text to continue.
        num_tokens: number of new tokens to generate.
        batch_size: number of continuations (the prompt is repeated per row).
        sample: draw from the distribution if True, otherwise take the argmax.
        top_k: keep only the ``top_k`` most likely tokens (None disables);
            clamped to the vocabulary size.
        temperature: logits are divided by this before the softmax.

    Returns:
        The decoded string when ``batch_size == 1``, otherwise a list of strings.
        The first token of each encoded sequence is dropped before decoding
        (presumably a BOS token — confirm for tokenizers that do not add one).

    The model's train/eval mode is restored on exit, even on error.
    """
    was_training = self.training
    self.eval()
    try:
        device = next(self.parameters()).device
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)  # (1, prompt_len)
        input_ids = input_ids.repeat(batch_size, 1)  # (batch_size, prompt_len)

        # TODO: per-layer caches so each step only feeds the newest token. Without a
        # cache, forward passes over partial prompts are never reused, so decoding
        # starts directly from the full prompt.
        with torch.no_grad():
            for _step in range(num_tokens):
                logits, _ = self(input_ids)
                next_token_logits = logits[:, -1]  # (batch_size, vocab_size)
                probs = F.softmax(next_token_logits / temperature, dim=-1)  # (batch_size, vocab_size)

                if top_k is not None:
                    k = min(top_k, probs.size(-1))
                    values, _ = torch.topk(probs, k=k)  # (batch_size, k), sorted largest first
                    probs[probs < values[:, -1, None]] = 0  # zero everything below the k-th largest
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                if sample:
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)  # (batch_size,)
                else:
                    next_token = torch.argmax(probs, dim=-1)  # (batch_size,)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
    finally:
        self.train(was_training)

    outputs = [tokenizer.decode(output.tolist()) for output in input_ids[:, 1:]]

    if batch_size == 1:
        return outputs[0]
    return outputs

In [76]:
# Greedy continuation of seed token 1 with our model, decoded for comparison with the HF output below.
tokenizer.decode(gen_no_caching(num_tokens=10)[0])

"uther chambre horseback chambre oreganosem')[\n    return"

In [80]:
# No input_ids: HF generate seeds from its default start token (presumably BOS = 1, which
# would explain why the result matches gen_no_caching above — verify).
outputs = model.generate(do_sample=False, max_new_tokens=10)
tokenizer.decode(outputs[0], skip_special_tokens=True)

"uther chambre horseback chambre oreganosem')[\n    return"

In [90]:
# Greedy decoding with our model (no top-k filtering); compare with the HF run on the same prompt.
generate(my_model, tokenizer, "salut", sample=False, top_k=None, num_tokens=10)

'salut horseback oreganoishable chop vegetable vegetable vegetable\nreturn<|endoftext|>'

In [89]:
# `encode(..., return_tensors="pt")` builds the ids on the CPU, while device_map="auto"
# may have placed the model on a GPU; move the ids to the model's device before generating.
inputs = tokenizer.encode("salut", return_tensors="pt").to(model.device)
outputs = model.generate(input_ids=inputs, max_new_tokens=10, do_sample=False)
tokenizer.decode(outputs[0], skip_special_tokens=True)

'salut horseback oreganoishable chop vegetable vegetable vegetable\nreturn'

"uther chambre horseback chambre oreganosem')[\n    return"

"uther chambre horseback chambre oreganosem')[\n    return"