# Instantiate small random model

In [1]:
from mamba_model import MambaModel
from mamba_config import MambaConfig
import torch

# Random init model

config = MambaConfig(
    num_layers = 3, # 5
    hidden_size = 256, # 1024
    mamba_headdim = 64,
    mamba_ngroups = 1,
    state_size = 16,
    conv_dimension = 4,
    expansion_factor = 2,
    rms_norm = True,
    bias = False,
    use_mem_mlp = False, #True,
    use_mem_rope =  False, #True,
    num_attention_heads = 16,
    num_mem_heads = 0, #16,
    num_mem_blocks = 0, #2,
    vocab_size = 32000,
    layer_mapping = ["m", "m", "m"]
)

model = MambaModel(config = config, max_sequence_length = 4096)
model = model.cuda().half()
# Tokenizer
from megatron.tokenizer.tokenizer import _HFAutoTokenizer
from megatron.text_generation.tokenization import tokenize_prompts
tokenizer = _HFAutoTokenizer("mistralai/Mistral-7B-v0.1")


Zarr-based strategies will not be registered because of missing packages


# Load 3B checkpoint

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from mamba_model import MambaModel
from mamba_config import MambaConfig
import torch
import copy

# checkpoint_name = "/workspace/zamba2_tiny_test/iter_0000015/mp_rank_00/model_optim_rng.pt"
checkpoint_name = "/checkpoints/zamba_3B_attempt_2/iter_1235000/mp_rank_00/model_optim_rng.pt"
state_dict = torch.load(checkpoint_name, map_location = "cpu")
# config_path = "/workspace/zamba2_tiny_test/config.sh"
config_path = "/workspace/Mamba-MoE/examples/mamba_memory_block/zamba_3B_attempt3.sh"

def extract_keyword_args(filestr, keyword):
    gpt_split = filestr.split(keyword)
    if len(gpt_split) <=1:
        raise ValueError("Config provided does not have a GPT_ARGS variable provided")
    arg_splits = gpt_split[1].split("\"")
    gpt_args = arg_splits[1]
    gpt_args = gpt_args.replace("\n","").replace("\\","").replace("\t","")
    gpt_args = ' '.join(gpt_args.split())
    return gpt_args.strip().split(" ")
    
with open(config_path,"r") as f:
        filestr = f.read()
config_args = extract_keyword_args(filestr, "GPT_ARGS")
print(config_args)

config = MambaConfig(
    num_layers = 54,
    hidden_size = 2560,
    mamba_headdim = 64,
    mamba_ngroups = 1,
    state_size = 64,
    conv_dimension = 4,
    expansion_factor = 2,
    rms_norm = True,
    bias = False,
    use_mem_mlp = True,
    num_attention_heads = 32,
    num_mem_heads = 32,
    num_mem_blocks = 2,
    use_shared_block_lora = True,
    lora_rank = 128,
    vocab_size = 32000,
    layer_mapping = ['m', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'm', 'g', 'm', 'm', 'm', 'g', 'm', 'm']
    
)
#print(config["num_layers"])
# #layer_mapping = ["r", "r", "g", "r", "r"]
model = MambaModel(config = config, max_sequence_length = 4096)

checkpoint_state_dict = dict(copy.deepcopy(state_dict["model"]))
model_state_dict = model.state_dict()
del checkpoint_state_dict["embedding.position_embeddings.weight"] # delete pos embeddings from the state dict
checkpoint_state_dict["embedding.weight"] = checkpoint_state_dict["embedding.word_embeddings.weight"].clone()
del checkpoint_state_dict["embedding.word_embeddings.weight"]
# delete the extra state
for k in list(checkpoint_state_dict.keys()):
    if "extra_state" in k:
        del checkpoint_state_dict[k]
    if "buffer_params" in k:
        del checkpoint_state_dict[k]
        # okay let's leave this out. I was wanting to do key matching
# for (k_c, k_m) in zip(checkpoint_state_dict.keys(), model_state_dict.keys()):
#     model_state_dict[k_m].data = checkpoint_state_dict[k_c].clone().data
#     #print(checkpoint_state_dict[k_c].clone().data.dtype,checkpoint_state_dict[k_c].clone().data.device)
print(model_state_dict.keys())
model.load_state_dict(checkpoint_state_dict)
model = model.cuda()#.float()
# Tokenizer
from megatron.tokenizer.tokenizer import _HFAutoTokenizer
from megatron.text_generation.tokenization import tokenize_prompts
tokenizer = _HFAutoTokenizer("mistralai/Mistral-7B-v0.1")

['--base-model-type', 'mamba', '--num-layers', '54', '--mamba-headdim', '64', '--mamba-ngroups', '1', '--hidden-size', '2560', '--state-size', '64', '--conv-dimension', '4', '--expansion-factor', '2', '--seq-length', '4096', '--max-position-embeddings', '4096', '--micro-batch-size', '3', '--global-batch-size', '384', '--lr', '2.0e-4', '--train-iters', '1_700_000', '--lr-decay-iters', '1_700_000', '--lr-decay-style', 'cosine', '--min-lr', '5.0e-5', '--lr-warmup-init', '0.0', '--weight-decay', '0.1', '--adam-beta2', '0.95', '--lr-warmup-iters', '5000', '--clip-grad', '1.0', '--bf16', '--recompute-granularity', 'selective', '--accumulate-allreduce-grads-in-fp32', '--attention-dropout', '0.0', '--hidden-dropout', '0.0', '--rms-norm', 'True', '--disable-bias-linear', '--num-mem-heads', '32', '--use-mem-mlp', '--num-attention-heads', '32', '--num-mem-blocks', '2', '--use-shared-block-lora', '--lora-rank', '128', '--use-distributed-optimizer', '--mamba-moe-layers', 'm', 'm', 'm', 'm', 'm', 'm

Zarr-based strategies will not be registered because of missing packages


# Generation

In [None]:
# Greedy generation without cache
prompt = 'Hello, how are you?'
tokens_to_generate = 10

prompts_tokens = [tokenizer.tokenize(prompt)]
text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)
print(text_ids.shape)
model.eval()
with torch.no_grad():
    for _ in range(tokens_to_generate):
        out = model(text_ids)
        out_last = out[:, -1]
        id = torch.argmax(out_last)[None, None]
        text_ids = torch.cat((text_ids, id), dim=0)
print(text_ids.shape)
text_ids = text_ids.transpose(0, 1)[0]
tokenizer.detokenize(text_ids.cpu().numpy().tolist())

In [30]:
# Adapted from megatron/text_generation/forward_step.py
# and megatron/core/inference_params.py

from collections.abc import Iterable

class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}
        self.key_value_memory_dict_mamba = {}

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict in empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

            


class ForwardStep:
    """Forward step function.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_length):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_length)


    def __call__(self, tokens):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Run a simple forward pass.
        logits = model(tokens, inference_params=self.inference_params)
        
        # if self.inference_params.sequence_len_offset > 0:
        #     for k in range(1, 50):
        #         temp_list = list(self.inference_params.key_value_memory_dict_mamba[k])
        #         temp_list[0] = temp_list[0] * 0
        #         temp_list[1] = temp_list[1] * 0
        #         self.inference_params.key_value_memory_dict_mamba[k] = tuple(temp_list)
        
        # Update the sequence length offset.
        self.inference_params.sequence_len_offset += tokens.size(0)


        return logits

In [31]:
# greedy generation with cache

prompt = 'Hello, how are you?'
tokens_to_generate = 10
batch_size = 1

prompts_tokens = [tokenizer.tokenize(prompt)]
text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)
prompt_length = text_ids.shape[0]
max_sequence_length = prompt_length + tokens_to_generate

# allocate tensor for full sequence
tokens = torch.zeros((max_sequence_length, batch_size), dtype=torch.int64, device=torch.cuda.current_device())
tokens[:prompt_length].copy_(text_ids)

forward_step = ForwardStep(model, batch_size, max_sequence_length)

model.eval()
with torch.no_grad():
    prev_context_length = 0
    for context_length in range(prompt_length, max_sequence_length):
        # Pick the slice that we need to pass through the network.
        tokens2use = tokens[prev_context_length:context_length]
        logits = forward_step(tokens2use)
        last_token_logits = logits[:, -1]
        new_id = torch.argmax(last_token_logits)
        tokens[context_length, 0] = new_id
        prev_context_length = context_length

print(tokens.shape)
tokens = tokens.transpose(0, 1)[0]
tokenizer.detokenize(tokens.cpu().numpy().tolist())

torch.Size([17, 1])


'<s> Hello, how are you?\nI am fine, thank you.\nHow'

# Optimize rope

In [39]:
# Forward pass with old rope
prompt = 'This is a prompt that I am writing just to check that the output logits of the new rope implementation agree with those of the old implementation'

prompts_tokens = [tokenizer.tokenize(prompt)]
text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)
print(text_ids.shape)
model.eval()
with torch.no_grad():
    out = model(text_ids)
print(out.shape)
print(out[0])

torch.Size([29, 1])
torch.Size([1, 29, 32000])
tensor([[-4.7043,  4.1115, -1.6878,  ..., -4.9308, -4.6628, -3.5261],
        [-7.4544,  1.9833, -5.7091,  ..., -6.7184, -5.4518, -6.2321],
        [-7.2265,  0.9190, -5.1137,  ..., -5.7556, -4.5535, -5.0144],
        ...,
        [-6.3463,  2.1663, -5.2459,  ..., -7.2303, -5.1301, -6.2166],
        [-7.8079,  2.8875, -7.0855,  ..., -8.5312, -6.3525, -6.8940],
        [-8.7074,  5.8500, -5.4719,  ..., -8.3026, -6.2768, -6.7983]],
       device='cuda:0')


In [2]:
# Forward pass with new rope
prompt = 'This is a prompt that I am writing just to check that the output logits of the new rope implementation agree with those of the old implementation'

prompts_tokens = [tokenizer.tokenize(prompt)]
text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)
print(text_ids.shape)
model.eval()
with torch.no_grad():
    out = model(text_ids)
print(out.shape)
print(out[0])

  text_ids = torch.cuda.LongTensor(prompts_tokens).transpose(0, 1)


torch.Size([29, 1])
torch.Size([1, 29, 32000])
tensor([[-4.7043,  4.1115, -1.6878,  ..., -4.9308, -4.6628, -3.5261],
        [-7.4544,  1.9833, -5.7091,  ..., -6.7184, -5.4518, -6.2321],
        [-7.2265,  0.9190, -5.1137,  ..., -5.7556, -4.5535, -5.0144],
        ...,
        [-6.3463,  2.1663, -5.2459,  ..., -7.2303, -5.1301, -6.2166],
        [-7.8079,  2.8875, -7.0855,  ..., -8.5312, -6.3525, -6.8940],
        [-8.7074,  5.8500, -5.4719,  ..., -8.3026, -6.2768, -6.7983]],
       device='cuda:0')


# Scratchwork

In [25]:
prompt = 'Hello, how are you?'
tokens_to_generate = 10
return_output_log_probs = False
top_k_sampling = 1
top_p_sampling=0.0,
top_p_decay=0.0,
top_p_bound=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1,
return_logits=False

# some lines below are adapted from Mamba-MoE/tasks/msdp/prompt.py and api.py


# Tokenize
context_tokens_tensor = [tokenizer.tokenize(prompt)]
context_tokens_tensor = torch.cuda.LongTensor(context_tokens_tensor)
context_length_tensor = context_tokens_tensor.shape[1]
print(context_tokens_tensor.shape, context_length_tensor)

torch.Size([1, 7]) 7


In [27]:
# simplified generation

import torch.nn.functional as F

def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False,
        prevent_newline_after_colon=True
        ):
    """Main token generation function.
    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is gready. Also, these paramters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination if
            all the sequences have reached this token.
        prevent_newline_after_colon: if True, it will disable generating new line \n after :
    Outputs: Note that is size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    batch_size = tokens.size(0)
    min_prompt_length = lengths
    max_sequence_length = tokens.size(1)

    # if max_sequence_length > args.max_position_embeddings:
    #     raise ValueError("Length of prompt + tokens_to_generate longer than allowed")

    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # =============
    # Run infernece
    # =============

    with torch.no_grad():
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):
            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            # logits will be meanigful only in the last pipeline stage.
            logits = forward_step(tokens2use)
            # Sample.
            last_token_logits = logits[:, -1, :]
            new_sample = torch.argmax(last_token_logits)
            ############# isn't tokens shorter than this? shouldn't do concat? can we actually allocate the total tokens tensor to memory, or unnecessary?
            tokens[0, context_length] = new_sample
            # Update the context length for the next token generation.
            prev_context_length = context_length

    return tokens

In [28]:
# Generate
model.eval()
with torch.no_grad():
    tokens, lengths, output_log_probs, logits = generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        top_k=top_k_sampling,
        top_p=top_p_sampling,
        top_p_decay=top_p_decay,
        top_p_bound=top_p_bound,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        stop_on_double_eol=stop_on_double_eol,
        stop_on_eol=stop_on_eol,
        prevent_newline_after_colon=prevent_newline_after_colon)


# tokens, prompts_plus_generations, prompts_plus_generations_segments = \
#     detokenize_generations(tokens, lengths, True)
prompts_plus_generations = tokenizer.detokenize(tokens.transpose(0, 1)[0].cpu().numpy().tolist())
print(prompts_plus_generations)

ValueError: not enough values to unpack (expected 4, got 1)

In [20]:
# main generation function taken from megatron/text_generation/generation.py

from megatron.text_generation.sampling import sample
import torch.nn.functional as F

def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False,
        prevent_newline_after_colon=True
        ):
    """Main token generation function.
    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is gready. Also, these paramters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination if
            all the sequences have reached this token.
        prevent_newline_after_colon: if True, it will disable generating new line \n after :
    Outputs: Note that is size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    batch_size = tokens.size(0)
    min_prompt_length = lengths
    max_sequence_length = tokens.size(1)

    # if max_sequence_length > args.max_position_embeddings:
    #     raise ValueError("Length of prompt + tokens_to_generate longer than allowed")

    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # if hasattr(args, 'eos_id'):
    #     termination_id = args.eos_id
    # else:
    termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated seuquence including including prompts.
    generated_sequence_lengths = None

    if return_output_log_probs:
        output_log_probs = torch.empty(output_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
    generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length
    
    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run infernece
    # =============

    with torch.no_grad():
        prev_context_length = 0
        if not lengths == min_prompt_length:
            print("We want all prompt lengths to be the same, otherwise it seems from the algo below that prompts longer than `min_prompt_length` will be cut.")
        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]

            # logits will be meanigful only in the last pipeline stage.
            logits = forward_step(tokens2use)

            
            if prevent_newline_after_colon:
                logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
            assert logits is not None

            # Sample.
            last_token_logits = logits[:, -1, :]
            new_sample = sample(last_token_logits,
                                top_k=top_k,
                                top_p=top_p,
                                temperature=temperature,
                                vocab_size=tokenizer.vocab_size)
            if top_p > 0.0 and top_p_decay > 0.0:
                top_p = top_p * top_p_decay
                if top_p_bound > 0.0:
                    top_p = max(top_p, top_p_bound)

            # If a prompt length is smaller or equal th current context
            # length, it means we have started generating tokens
            started = lengths <= context_length
            # Update the tokens.
            tokens[started, context_length] = new_sample[started]

            # Calculate the log probabilities.
            if return_output_log_probs:
                log_probs = F.log_softmax(logits, dim=2)
                if return_output_log_probs:
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:,
                                        prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            
            # TODO(rprenger) These stopping methods are tokenizer dependent
            # instead tokenization should be in the inference loop so stop sequences can be used
            if stop_on_double_eol:
                hit_double_eol = (new_sample == 628).byte() & started.byte()
                hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
                done_token = hit_double_eol | hit_two_eols
            elif stop_on_eol:
                hit_double_eol = (new_sample == 628).byte() & started.byte()
                hit_eol = (new_sample == 198).byte() & started.byte()
                done_token = hit_double_eol | hit_eol
            else: 
                done_token = (new_sample == termination_id).byte() & \
                    started.byte()
            
            just_finished = (done_token & ~is_generation_done).bool()
            generated_sequence_lengths[just_finished.view(-1)] = \
                context_length + 1
            is_generation_done = is_generation_done | done_token
            done = torch.all(is_generation_done)
            if use_eod_token_for_early_termination and done:
                break
            

    tokens = tokens[:, :(max_sequence_length + 1)]

    if return_output_log_probs:
        output_log_probs = output_log_probs[:, :context_length]

    return tokens, generated_sequence_lengths, output_log_probs, None