In [1]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

from jamba import JambaLMConfig as myJambaLMConfig, JambaLM as myJambaLM

  from .autonotebook import tqdm as notebook_tqdm


In [74]:
# Reference implementation: the Mini-Jamba checkpoint from the Hugging Face Hub.
# trust_remote_code=True lets transformers execute code shipped with the repo.
tokenizer = AutoTokenizer.from_pretrained(
    "TechxGenus/Mini-Jamba",
    trust_remote_code=True,
)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    "TechxGenus/Mini-Jamba",
    torch_dtype=torch.float16,
    use_mamba_kernels=False,  # presumably selects the non-fused Mamba path — confirm
    device_map="auto",
    trust_remote_code=True,
)

In [70]:
# Docstring-only prompt: the model has to write the body of `min`.
prompt = '''def min(arr):
    """
    Returns the minimum value from the list `arr`.
    
    Parameters:
    - arr (list): A list of numerical values.
    
    Returns:
    - The minimum value in `arr`.
    """
'''

# Reference greedy generation with the Hugging Face model.
# The ids are moved to the model's device because device_map="auto" may have
# placed it on a GPU. The attention mask is passed explicitly: a single,
# unpadded prompt attends to every position, and transformers cannot tell
# padding from EOS on its own because pad_token is the EOS token.
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    input_ids=inputs,
    attention_mask=torch.ones_like(inputs),
    max_new_tokens=64,
    do_sample=False,
)
tokenizer.decode(outputs[0], skip_special_tokens=True)

'def min(arr):\n    """\n    Returns the minimum value from the list `arr`.\n    \n    Parameters:\n    - arr (list): A list of numerical values.\n    \n    Returns:\n    - The minimum value in `arr`.\n    """\n    return min(arr)\n\n\n'

In [4]:
# Mirror the Hugging Face config field by field into our own JambaLMConfig so
# that both models are built with the same hyper-parameters.
config_jamba = myJambaLMConfig(
    # global sizes
    vocab_size=model.config.vocab_size,
    d_model=model.config.hidden_size,
    n_layers=model.config.num_hidden_layers,
    rms_norm_eps=model.config.rms_norm_eps,
    mlp_size=model.config.intermediate_size,
    # Mamba block settings
    inner_layernorms=model.config.mamba_inner_layernorms,
    expand_factor=model.config.mamba_expand,
    dt_rank=model.config.mamba_dt_rank,
    d_state=model.config.mamba_d_state,
    d_conv=model.config.mamba_d_conv,
    conv_bias=model.config.mamba_conv_bias,
    initializer_range=model.config.initializer_range,
    # mixture-of-experts settings
    num_experts=model.config.num_experts,
    num_experts_per_tok=model.config.num_experts_per_tok,
    # layer layout: offset/period of attention layers and of expert layers
    attn_layer_offset=model.config.attn_layer_offset,
    attn_layer_period=model.config.attn_layer_period,
    expert_layer_offset=model.config.expert_layer_offset,
    expert_layer_period=model.config.expert_layer_period,
    # attention heads
    num_key_value_heads=model.config.num_key_value_heads,
    num_attention_heads=model.config.num_attention_heads,
    # miscellaneous
    pad_token_id=model.config.pad_token_id,
    bias=model.config.mamba_proj_bias,  # our `bias` maps to the HF Mamba projection bias flag
    attention_dropout=model.config.attention_dropout,
    tie_lm_weights=model.config.tie_word_embeddings,
)

# Freshly initialised model; its weights are filled in from `model` next.
my_model = myJambaLM(config_jamba)

def _hf_param_name_to_mine(hf_name: str) -> str:
    """Translate a Hugging Face Jamba parameter name into `myJambaLM`'s naming.

    - any name containing "embed_tokens" maps to "embedding.weight";
    - a leading "model." prefix becomes "jamba.";
    - the final layer norm is looked up without the "jamba." prefix;
    - every other name is returned unchanged.
    """
    if "embed_tokens" in hf_name:
        return "embedding.weight"
    # Only rewrite the prefix, not every "model." substring in the name.
    if hf_name.startswith("model."):
        hf_name = "jamba." + hf_name[len("model."):]
    if "final_layernorm" in hf_name and hf_name.startswith("jamba."):
        hf_name = hf_name[len("jamba."):]
    return hf_name


unmatched = []  # "hf_name -> my_name" entries with no counterpart in my_model
with torch.no_grad():
    for hf_name, hf_param in model.named_parameters():
        my_name = _hf_param_name_to_mine(hf_name)
        try:
            my_param = my_model.get_parameter(my_name)
        except AttributeError:
            # get_parameter raises (it never returns None) for unknown names.
            unmatched.append(f"{hf_name} -> {my_name}")
            continue
        # copy_ would silently broadcast compatible-but-different shapes.
        if my_param.shape != hf_param.shape:
            raise ValueError(
                f"shape mismatch for {hf_name} -> {my_name}: "
                f"{tuple(hf_param.shape)} vs {tuple(my_param.shape)}"
            )
        # copy_ casts to my_param's dtype (the HF model was loaded in float16)
        # and copies across devices if the two models live on different ones.
        my_param.copy_(hf_param)

# A skipped weight would leave random initialisation in my_model: fail loudly,
# listing every missing name at once.
if unmatched:
    raise KeyError(f"no counterpart in my_model for: {unmatched}")

In [58]:
def _sample_next_token(logits, sample: bool = True, top_k: int = 40, temperature: float = 1.0):
    """Choose the next token id for every sequence in the batch.

    Args:
        logits: next-token logits of shape (batch_size, vocab_size).
        sample: if True, draw from the temperature-scaled, top-k-filtered
            distribution; if False, decode greedily (argmax).
        top_k: keep only the `top_k` most likely tokens before sampling.
            None or 0 disables the filter; values larger than the vocabulary
            are clamped to the vocabulary size.
        temperature: softmax temperature; must be > 0 when `sample` is True.

    Returns:
        Tensor of shape (batch_size,) holding the chosen token ids.

    Raises:
        ValueError: if `sample` is True and `temperature` <= 0.
    """
    if not sample:
        # A positive temperature and top-k filtering never move the argmax,
        # so greedy decoding reads the logits directly (and temperature=0 works).
        return torch.argmax(logits, dim=-1)

    if temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")

    probs = F.softmax(logits / temperature, dim=-1)  # (batch_size, vocab_size)

    if top_k:
        k = min(top_k, probs.size(-1))
        # torch.topk returns values in DESCENDING order, so the last column is
        # the k-th largest probability: everything below it is discarded.
        kth_largest = torch.topk(probs, k=k, dim=-1).values[:, -1:]  # (batch_size, 1)
        probs = probs.masked_fill(probs < kth_largest, 0.0)
        probs = probs / probs.sum(dim=-1, keepdim=True)

    return torch.multinomial(probs, num_samples=1).squeeze(1)  # (batch_size,)


def generate(self, tokenizer, prompt: str, num_tokens: int = 50, batch_size: int = 1, sample: bool = True, top_k: int = 40, temperature: float = 1.0):
    """Generate `num_tokens` tokens after `prompt`, re-running the whole sequence each step (no cache).

    `self` is the model (called as `generate(my_model, ...)`). It is invoked as
    `self(input_ids)`, and its first return value is read as logits of shape
    (batch_size, seq_len, vocab_size).

    Args:
        tokenizer: callable returning an object with `.input_ids`, plus `.decode(list_of_ids)`.
        prompt: text to continue.
        num_tokens: number of new tokens to generate.
        batch_size: number of independent continuations of the same prompt.
        sample, top_k, temperature: decoding options, see `_sample_next_token`.

    Returns:
        The decoded text when batch_size == 1, otherwise a list of batch_size strings.
        The first token of every sequence is dropped before decoding.

    The model's train/eval mode is restored on exit.
    """
    was_training = self.training
    self.eval()
    try:
        device = next(self.parameters()).device
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)  # (1, prompt_len)
        input_ids = input_ids.repeat(batch_size, 1)  # (batch_size, prompt_len)

        with torch.no_grad():
            for _step in range(num_tokens):
                # Without a cache the full sequence is re-encoded every step;
                # only the logits at the last position are used.
                logits, _ = self(input_ids)  # (batch_size, seq_len, vocab_size)
                next_token = _sample_next_token(logits[:, -1], sample, top_k, temperature)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
    finally:
        self.train(was_training)

    # Drop the first token of each sequence (presumably the BOS token added by
    # the tokenizer — TODO confirm for tokenizers that add none).
    outputs = [tokenizer.decode(output.tolist()) for output in input_ids[:, 1:]]
    return outputs[0] if batch_size == 1 else outputs


# TODO: compare sample=True outputs against the Hugging Face reference generate.

def generate_caching(self, tokenizer, prompt: str, num_tokens: int = 50, batch_size: int = 1, sample: bool = True, top_k: int = 40, temperature: float = 1.0):
    """Generate like `generate`, but feed one token per step and carry per-layer caches.

    Step i feeds token i together with the caches from the previous step; the
    model is called as `self(token_ids, caches)` and returns `(logits, caches)`,
    with logits of shape (batch_size, 1, vocab_size). Prompt tokens are
    consumed one at a time. Only from the last prompt token onward is a new
    token chosen and appended. The final prediction is never fed back, so
    prompt_len + num_tokens - 1 steps produce num_tokens new tokens.

    Uses `self.jamba.layers[i].get_empty_cache(batch_size)` for the initial
    caches and `self.config.n_layers` for their count. Arguments and return
    value are the same as `generate`. The model's train/eval mode is restored on exit.
    """
    was_training = self.training
    self.eval()
    try:
        device = next(self.parameters()).device
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)  # (1, prompt_len)
        input_ids = input_ids.repeat(batch_size, 1)  # (batch_size, prompt_len)
        prompt_len = input_ids.size(1)

        # One cache per layer, starting empty for this batch size.
        caches = [self.jamba.layers[i].get_empty_cache(batch_size) for i in range(self.config.n_layers)]

        with torch.no_grad():
            for i in range(prompt_len + num_tokens - 1):
                logits, caches = self(input_ids[:, [i]], caches)  # (batch_size, 1, vocab_size), caches
                if i + 1 < prompt_len:
                    # Still consuming the prompt: only the updated caches matter.
                    continue
                next_token = _sample_next_token(logits.squeeze(1), sample, top_k, temperature)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
    finally:
        self.train(was_training)

    # Drop the first token of each sequence (presumably the BOS token added by
    # the tokenizer — TODO confirm for tokenizers that add none).
    outputs = [tokenizer.decode(output.tolist()) for output in input_ids[:, 1:]]
    return outputs[0] if batch_size == 1 else outputs

In [61]:
# Greedy decoding with the uncached generator. The seed has no effect while
# sample=False, but keeps the cell reproducible if sampling is switched on.
torch.manual_seed(12345678)
output = generate(
    my_model,
    tokenizer,
    prompt,
    num_tokens=10,
    top_k=None,
    sample=False,
)
print(output)

def min(arr):
    """
    Returns the minimum value from the list `arr`.
    
    Parameters:
    - arr (list): A list of numerical values.
    
    Returns:
    - The minimum value in `arr`.
    """
    return min(arr)


<|endoftext|>


In [73]:
# Cached generator with sampling enabled: no top-k filter, default temperature.
torch.manual_seed(12345678)
output = generate_caching(
    my_model,
    tokenizer,
    prompt,
    num_tokens=10,
    top_k=None,
    sample=True,
)
print(output)

def min(arr):
    """
    Returns the minimum value from the list `arr`.
    
    Parameters:
    - arr (list): A list of numerical values.
    
    Returns:
    - The minimum value in `arr`.
    """
    return min(arr)


<|endoftext|>


In [48]:
# Sample 10 tokens from the Hugging Face model after "a=", with a fixed seed.
torch.manual_seed(12345678)
# Move the ids to the model's device (device_map="auto" may have used a GPU).
# Pass the attention mask explicitly: the single unpadded prompt attends to
# every position, and transformers cannot tell padding from EOS on its own
# because pad_token is the EOS token.
inputs = tokenizer.encode("a=", return_tensors="pt").to(model.device)
outputs = model.generate(
    input_ids=inputs,
    attention_mask=torch.ones_like(inputs),
    max_new_tokens=10,
    do_sample=True,
)
tokenizer.decode(outputs[0], skip_special_tokens=True)

'a=ו cooperate<|reserved_890|>oungeutations rreditemsi\n'