<a href="https://colab.research.google.com/github/Nikita-Gz/mamba-meet-and-greet/blob/main/Mamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
import pickle

model_name = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaForCausalLM.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d



In [None]:
def get_token_ids(text: str) -> torch.Tensor:
    input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    return input_ids

def get_logits(input_ids: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        # plain forward pass; the logits have shape (batch, sequence_length, vocab_size)
        outputs = model(input_ids)
        logits = outputs.logits
    return logits

def get_probabilities(logits: torch.Tensor) -> torch.Tensor:
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    return probabilities

def clean_up_token(token: str) -> str:
    """ Replaces the byte-level BPE markers for spaces and newlines with the actual characters """
    clean_token = token.replace("Ġ", " ").replace("Ċ", "\n")
    return clean_token

def get_top_p_token_strings(
        probabilities: torch.Tensor,
        top_p: float,
        top_n: int,
        substrings_to_skip=(
            '<|endoftext|>',
            '\n'
        )) -> list[str]:
    """ Returns the most probable token strings, stopping at top_n tokens or top_p cumulative probability """
    sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
    token_strings = []
    cumulative_p = 0.0
    for prob, token_id in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        token_string = tokenizer.convert_ids_to_tokens([token_id])[0]
        token_string = clean_up_token(token_string)
        if any(substring in token_string for substring in substrings_to_skip):
            continue

        cumulative_p += prob
        token_strings.append(token_string)
        # only the tokens that were kept count towards top_n, skipped ones do not
        if len(token_strings) >= top_n or cumulative_p >= top_p:
            break
    return token_strings

from collections.abc import Iterator

def explore_text_tree(tree: dict) -> Iterator[tuple[dict, list[str]]]:
    """ Yields each leaf's dict (always empty, by definition of a leaf) together with the path leading to it """
    stack = [(tree, [])]

    while len(stack) > 0:
        current_tree, path = stack.pop()
        for subtree_text, subtree in current_tree.items():
            current_path = path + [subtree_text]
            if len(subtree) == 0:
                yield subtree, current_path
            else:
                stack.append((subtree, current_path))

def print_text_tree(tree: dict, previous_text: str='') -> None:
    """ Prints the full text of every root-to-leaf branch as a numbered list, walking the tree iteratively """
    stack = [(tree, previous_text)]

    total_lines_printed = 0
    while len(stack) > 0:
        current_tree, current_text = stack.pop()
        for node_text, nodes in current_tree.items():
            new_text = current_text + node_text
            if nodes:
                stack.append((nodes, new_text))
            else:
                print(f"{total_lines_printed+1}) {new_text}")
                total_lines_printed += 1

def generate_text_branches(
        max_depth,
        max_leafs_to_explore_per_level,
        max_branching_paths_to_create,
        top_p,
        top_n,
        input_text):
    """
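    Generates continuations of the input text as a tree of tokens. Each explored leaf gets one
    continuation token, and additional branches are split off for the other tokens that pass
    top_p and top_n, as long as the branch count stays within max_branching_paths_to_create.
    A possible improvement: reuse the model state between steps instead of rerunning the model
    on the whole branch text for every new token.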
    """
    text_tree = {input_text: {}}
    branching_paths_created = 0
    for level in range(max_depth):
        print(f"Generating level {level+1} out of {max_depth}")
        leafs_explored_per_current_level = 0
        for leaf, words_on_branch_path in explore_text_tree(text_tree):

            # runs the text on this branch through the model and gets token probabilities
            text_on_branch_path = "".join(words_on_branch_path)
            input_ids = get_token_ids(text_on_branch_path)
            logits = get_logits(input_ids)
            probabilities = get_probabilities(logits)[0, -1]

            # adds the most probable token to the tree, and adds extra branches if possible
            top_tokens = get_top_p_token_strings(probabilities, top_p, top_n)
            leaf[top_tokens.pop(0)] = {} # this creates a new leaf to explore later
            for token in top_tokens:
                if branching_paths_created >= max_branching_paths_to_create:
                    break
                leaf[token] = {}
                branching_paths_created += 1

            leafs_explored_per_current_level += 1
            if leafs_explored_per_current_level >= max_leafs_to_explore_per_level:
                break
    return text_tree
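
The tree built above is a plain nested dict: each key is a piece of text, each value is the subtree that continues it, and an empty dict marks a leaf that can still be extended. The sketch below illustrates that structure without the model. The names `iter_leaves` and `toy_tree` and the tree contents are made up for illustration, and the traversal is a simplified stand-in for `explore_text_tree`, not a replacement for it.

```python
from collections.abc import Iterator


def iter_leaves(tree: dict) -> Iterator[tuple[dict, list[str]]]:
    """Yield (leaf_dict, path) for every empty dict in a nested-dict tree."""
    stack = [(tree, [])]
    while stack:
        node, path = stack.pop()
        for text, child in node.items():
            child_path = path + [text]
            if child:
                # non-empty subtree: visit it later
                stack.append((child, child_path))
            else:
                # empty dict: a leaf that can still be extended
                yield child, child_path


# Hypothetical tree: one prompt that has already branched into two continuations.
toy_tree = {"The cat": {" sat": {" down": {}}, " ran": {}}}

# Joining the keys along a path gives the full text of that branch.
leaf_texts = sorted("".join(path) for _, path in iter_leaves(toy_tree))
print(leaf_texts)

# Extending a leaf means adding a key to its (empty) dict.
# The leaves are snapshotted with list() before the tree is modified.
for leaf, _ in list(iter_leaves(toy_tree)):
    leaf["!"] = {}

grown_texts = sorted("".join(path) for _, path in iter_leaves(toy_tree))
print(grown_texts)
```

Because a leaf is just an empty dict, growing the tree by one level only requires writing new keys into the dicts that the traversal yields.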

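To see how `top_p` and `top_n` interact when candidate tokens are picked, here is a minimal pure-Python sketch of the same cut-off rule. It assumes a hypothetical `select_top_tokens` helper and a made-up next-token distribution; it does not call the model or the tokenizer, and it leaves out the `substrings_to_skip` filter.

```python
def select_top_tokens(
    token_probs: dict[str, float],
    top_p: float,
    top_n: int,
) -> list[str]:
    """Pick tokens by descending probability until top_n tokens or top_p mass is reached."""
    ranked = sorted(token_probs.items(), key=lambda item: item[1], reverse=True)
    selected = []
    cumulative_p = 0.0
    for token, prob in ranked:
        selected.append(token)
        cumulative_p += prob
        # stop as soon as either limit is hit
        if len(selected) >= top_n or cumulative_p >= top_p:
            break
    return selected


# Hypothetical next-token distribution (values chosen for illustration only).
toy_probs = {" a": 0.30, " the": 0.25, " not": 0.20, " very": 0.15, " so": 0.10}

# With top_p=0.5 the probability mass limit is reached after two tokens.
narrow = select_top_tokens(toy_probs, top_p=0.5, top_n=3)
# With top_p=0.9 the top_n limit is reached first.
wide = select_top_tokens(toy_probs, top_p=0.9, top_n=3)
print(narrow)
print(wide)
```

A lower `top_p` therefore produces fewer branches when the model is confident, while `top_n` caps the branching when the distribution is flat.
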
In [None]:
input_text = 'This new "Mamba" AI architecture is...'
top_p = 0.5
top_n = 3
text_tree = generate_text_branches(
    max_depth=30,
    max_leafs_to_explore_per_level=15,
    max_branching_paths_to_create=14,
    top_p=top_p,
    top_n=top_n,
    input_text=input_text)

print(f"\nGenerated continuations from the text '{input_text}' with top_p={top_p} and top_n={top_n}:")
print_text_tree(text_tree)

Generating level 1 out of 30
Generating level 2 out of 30
Generating level 3 out of 30
Generating level 4 out of 30
Generating level 5 out of 30
Generating level 6 out of 30
Generating level 7 out of 30
Generating level 8 out of 30
Generating level 9 out of 30
Generating level 10 out of 30
Generating level 11 out of 30
Generating level 12 out of 30
Generating level 13 out of 30
Generating level 14 out of 30
Generating level 15 out of 30
Generating level 16 out of 30
Generating level 17 out of 30
Generating level 18 out of 30
Generating level 19 out of 30
Generating level 20 out of 30
Generating level 21 out of 30
Generating level 22 out of 30
Generating level 23 out of 30
Generating level 24 out of 30
Generating level 25 out of 30
Generating level 26 out of 30
Generating level 27 out of 30
Generating level 28 out of 30
Generating level 29 out of 30
Generating level 30 out of 30

Generated continuations from the text 'This new "Mamba" AI architecture is...' with top_p=0.5 and top_n=3:
1