# Load Model

In [4]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Checkpoint to load. Any one of the following ids can be assigned to
# `pretrained_model_name`:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-130m'

# Load the model for the selected checkpoint name using the project-local
# `model` module imported above.
model = Mamba.from_pretrained(pretrained_model_name)
# The tokenizer id is fixed and does not change with the checkpoint chosen
# above (presumably every listed checkpoint shares this vocabulary — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Generate Text

In [5]:
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40) -> str:
    """Autoregressively extend `prompt` and return the decoded text.

    At every step the whole sequence produced so far is fed through `model`
    again, and only the logits at the last position are used to choose the
    next token.

    Args:
        model: Callable mapping a tensor of token ids to logits. It is indexed
            as `model(ids)[:, -1]` and the result is treated as
            (batch, vocab_size). Must provide `.eval()`.
        tokenizer: Callable such that `tokenizer(prompt, return_tensors='pt')`
            exposes an `input_ids` tensor, and which provides
            `.decode(list_of_ids)`.
        prompt: Text to continue.
        n_tokens_to_gen: Number of tokens to append to the prompt.
        sample: If True, draw the next token from the (optionally top-k
            filtered) distribution; if False, take the most probable token.
        top_k: Keep only the `top_k` most probable tokens before choosing
            (tokens tied with the k-th probability are kept as well). `None`
            disables the filter. Values larger than the vocabulary size are
            clamped to it.

    Returns:
        The decoded first sequence of the batch: the prompt followed by the
        generated tokens.

    Raises:
        ValueError: If `top_k` is given but is smaller than 1.
    """
    if top_k is not None and top_k < 1:
        raise ValueError(f'top_k must be a positive integer or None, got {top_k}')

    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    # Nothing in the loop needs gradients, so disable autograd for all of it.
    with torch.no_grad():
        for _ in range(n_tokens_to_gen):
            next_token_logits = model(input_ids)[:, -1]
            probs = F.softmax(next_token_logits, dim=-1)

            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size, so clamp.
                k = min(top_k, probs.shape[-1])
                kth_largest = torch.topk(probs, k=k).values[:, -1, None]
                # Zero everything strictly below the k-th probability, then
                # renormalise so the kept entries sum to 1 again.
                probs = probs.masked_fill(probs < kth_largest, 0.0)
                probs = probs / probs.sum(dim=1, keepdim=True)

            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Only the first sequence is returned, so only that one is decoded.
    return tokenizer.decode(input_ids[0].tolist())

In [6]:
# Continue a short open-ended prompt using generate()'s default settings.
print(
    generate(
        model=model,
        tokenizer=tokenizer,
        prompt='Mamba is the',
    )
)

Mamba is the most successful Mamba that has ever entered the game. After his amazing win over the Brazilian Mamba in Mombasa, he had only two more goals to his name the following week.

Mamba got his first call up to the M


In [7]:
# Continue a two-speaker dialogue prompt using generate()'s default settings.
print(
    generate(
        model=model,
        tokenizer=tokenizer,
        prompt='John: Hi!\nSally:',
    )
)

John: Hi!
Sally: Hey!
John: This is Sally! (He's smiling a little at the show!)
(I also think we made up. Not sure!)
Sally: Hey there, John. I think it’s nice of us to introduce


In [8]:
# Continue a prompt that ends in a trailing space, using generate()'s default settings.
print(
    generate(
        model=model,
        tokenizer=tokenizer,
        prompt='The meaning of life is ',
    )
)

The meaning of life is 
not defined 
by some simple definition 
or by some one-on-one interaction. 
If we are discussing philosophy, 
we are talking about the meaning of life. If we are discussing 
some other philosophical concept


In [9]:
# Continue a partial Python function signature using generate()'s default settings.
print(
    generate(
        model=model,
        tokenizer=tokenizer,
        prompt='def reverse_string(',
    )
)

def reverse_string(s):
    return "{0}{1}{2} - {:3d}".format(s[0], s[1], s[2])

class BaseCommandRunner(CommandRunnerBase):
    """ An instance of Base
