## (1) Load the model and tokenizer

In [11]:
import torch
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'
# Tokenizer loaded alongside the Mamba checkpoint above.
tokenizer_name = 'EleutherAI/gpt-neox-20b'

# Seed torch's RNG so the sampled generations in this notebook are reproducible
# on Restart & Run All (generation draws tokens at random by default).
seed = 0
torch.manual_seed(seed)

# 1. Select the device: the GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# 2. Load the pretrained weights and move the model to the selected device.
model = Mamba.from_pretrained(pretrained_model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

Using device: cuda


## (2) Generate text

In [12]:
import torch
import torch.nn.functional as F


def _infer_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    for param in model.parameters():
        return param.device
    return torch.device('cpu')


def _top_k_filter(probs, top_k):
    """Zero all but the ``top_k`` largest probabilities per row and renormalise."""
    # torch.topk raises if k exceeds the size of the last dimension (the vocab).
    k = min(top_k, probs.shape[-1])
    kth_largest = torch.topk(probs, k=k, dim=-1).values[:, -1, None]
    # Entries tied with the k-th value survive, so a row may keep slightly more than k.
    filtered = probs.masked_fill(probs < kth_largest, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40,
             device=None) -> str:
    """Autoregressively extend ``prompt`` with tokens produced by ``model``.

    At each step the whole token sequence is fed back through the model. The
    logits at the last position are converted to probabilities and optionally
    restricted to the ``top_k`` most likely tokens. The next token is then
    either sampled or chosen greedily.

    Args:
        model: Module mapping a batch of token ids to logits; the last position
            is read with ``[:, -1]`` (assumes ``(batch, seq, vocab)`` output —
            TODO confirm against ``model.Mamba``). Put into eval mode.
        tokenizer: Supports ``tokenizer(prompt, return_tensors='pt').input_ids``
            and ``tokenizer.decode(list_of_ids)``.
        prompt: Text to continue.
        n_tokens_to_gen: Number of new tokens to append.
        sample: If True, sample from the (filtered) distribution; otherwise take
            the most likely token.
        top_k: Keep only the ``top_k`` most probable tokens before choosing.
            ``None`` or a value <= 0 disables filtering; values larger than the
            vocabulary are clamped to it.
        device: Device for the input ids. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        The decoded prompt followed by the generated continuation.
    """
    model.eval()

    if device is None:
        # Derive the device from the model instead of relying on a notebook global.
        device = _infer_device(model)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

    with torch.no_grad():
        for _ in range(n_tokens_to_gen):
            next_token_logits = model(input_ids)[:, -1]
            probs = F.softmax(next_token_logits, dim=-1)

            if top_k is not None and top_k > 0:
                probs = _top_k_filter(probs, top_k)

            # Tensors created from `probs` stay on the same device as `input_ids`.
            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    # A single string prompt yields a batch of one; decode that sequence only.
    return tokenizer.decode(input_ids[0].tolist())

In [13]:
# Continue a short phrase using generate()'s default settings.
print(generate(model, tokenizer, prompt='Mamba is the'))

Mamba is the first full-length album by British rock songwriter Dave Eringa (Eddie Troutman). Mamba was released in 1983. Two singles were released from the album. Mamba was used for The Last Temptation of Christ.


In [14]:
# Dialogue-formatted prompt: the model is asked to continue Sally's line.
print(generate(model, tokenizer, prompt='John: Hi!\nSally:'))

John: Hi!
Sally: Hi John.
John: Good to see you.
Sally: Oh my God, I know.
John: This is a problem.
Sally: I really miss you, though, baby.
John: I know, we


In [15]:
# Open-ended sentence prompt (note the trailing space in the prompt).
print(generate(model, tokenizer, prompt='The meaning of life is '))

The meaning of life is 
To be a king is 
to know the sky is blue.
To have your mind is 
to know that you are alive.
To have the strength is 
to know that you are strong.
As a king,


In [16]:
# Code-completion prompt: a Python function signature.
print(generate(model, tokenizer, prompt='def reverse_string('))

def reverse_string(str):
        # Reverse a string that has a '.' at the beginning
        return str.replace('..', '_').join([str.replace(':', '').replace('.', '')
            for i in range(len(
