In [1]:
from pprint import pprint
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from parsers import ModelParser

# Checkpoint shard files, listed in shard order (1 to 4 of 4).
shard_paths = [
    "./Meta-Llama-3-8B/model-00001-of-00004.safetensors",
    "./Meta-Llama-3-8B/model-00002-of-00004.safetensors",
    "./Meta-Llama-3-8B/model-00003-of-00004.safetensors",
    "./Meta-Llama-3-8B/model-00004-of-00004.safetensors",
]

# One parser object is handed every shard path at once.
model_parser = ModelParser(shard_paths)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tensor names held by the second shard only, looked up by file path in the
# parser's underscore-prefixed (internal) `_parsers_dict` mapping.
model_parser._parsers_dict['./Meta-Llama-3-8B/model-00002-of-00004.safetensors'].tensor_names

['model.layers.10.input_layernorm.weight',
 'model.layers.10.mlp.down_proj.weight',
 'model.layers.10.mlp.gate_proj.weight',
 'model.layers.10.mlp.up_proj.weight',
 'model.layers.10.post_attention_layernorm.weight',
 'model.layers.10.self_attn.k_proj.weight',
 'model.layers.10.self_attn.o_proj.weight',
 'model.layers.10.self_attn.q_proj.weight',
 'model.layers.10.self_attn.v_proj.weight',
 'model.layers.11.input_layernorm.weight',
 'model.layers.11.mlp.down_proj.weight',
 'model.layers.11.mlp.gate_proj.weight',
 'model.layers.11.mlp.up_proj.weight',
 'model.layers.11.post_attention_layernorm.weight',
 'model.layers.11.self_attn.k_proj.weight',
 'model.layers.11.self_attn.o_proj.weight',
 'model.layers.11.self_attn.q_proj.weight',
 'model.layers.11.self_attn.v_proj.weight',
 'model.layers.12.input_layernorm.weight',
 'model.layers.12.mlp.down_proj.weight',
 'model.layers.12.mlp.gate_proj.weight',
 'model.layers.12.mlp.up_proj.weight',
 'model.layers.12.post_attention_layernorm.weight',


In [3]:
# Tensor names reported by the top-level parser (presumably the union over all
# shards — TODO confirm). The list is long; slice it if only a preview is needed.
pprint(model_parser.tensor_names)

['model.embed_tokens.weight',
 'model.layers.0.input_layernorm.weight',
 'model.layers.0.mlp.down_proj.weight',
 'model.layers.0.mlp.gate_proj.weight',
 'model.layers.0.mlp.up_proj.weight',
 'model.layers.0.post_attention_layernorm.weight',
 'model.layers.0.self_attn.k_proj.weight',
 'model.layers.0.self_attn.o_proj.weight',
 'model.layers.0.self_attn.q_proj.weight',
 'model.layers.0.self_attn.v_proj.weight',
 'model.layers.1.input_layernorm.weight',
 'model.layers.1.mlp.down_proj.weight',
 'model.layers.1.mlp.gate_proj.weight',
 'model.layers.1.mlp.up_proj.weight',
 'model.layers.1.post_attention_layernorm.weight',
 'model.layers.1.self_attn.k_proj.weight',
 'model.layers.1.self_attn.o_proj.weight',
 'model.layers.1.self_attn.q_proj.weight',
 'model.layers.1.self_attn.v_proj.weight',
 'model.layers.2.input_layernorm.weight',
 'model.layers.2.mlp.down_proj.weight',
 'model.layers.2.mlp.gate_proj.weight',
 'model.layers.2.mlp.up_proj.weight',
 'model.layers.2.post_attention_layernorm.we

## Prepare text and embeddings

In [4]:
# Device used for the tokenized prompt, the model and every generation buffer
# created in the cells below.
device="cpu"

In [5]:
# Load the tokenizer from the local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("./Meta-Llama-3-8B/")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Alternative prompt construction through the chat template, kept for
# reference only (disabled):
#
# messages = [
#     {"role": "system", "content": "You are a programmer chatbot who always responds clearly and concisely!"},
#     {"role": "user", "content": "Who are you?"},
# ]
#
# input_ids = tokenizer.apply_chat_template(
#     messages,
#     add_generation_prompt=True,
#     return_tensors="pt"
# ).to(device)

# Plain completion prompt. The tokenizer receives a list of strings, so the
# result is a batch (one row per prompt) of token ids moved to `device`.
input_ids = tokenizer(
    ["I believe the meaning of life is"],
    return_tensors="pt",
)["input_ids"].to(device)

# End-of-sequence token ids. NOTE(review): nothing in the visible cells reads
# `terminators`; the generation loop stops on `stop_tokens` instead.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

In [7]:
# Display the tokenized prompt (the output below shows one row of 8 token ids).
input_ids

tensor([[128000,     40,   4510,    279,   7438,    315,   2324,    374]])

## Forward passes

In [8]:
# Model hyper-parameters as a plain dict; it is passed to `Transformer` in a
# later cell together with the checkpoint parser.
config = {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": False,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  # The larger value is kept commented out; 2048 is the value in effect and
  # equals the `max_seq_len` used by the generation cells.
  #"max_position_embeddings": 8192,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": None,
  "rope_theta": 500000.0,
  "tie_word_embeddings": False,
  # A torch dtype object rather than a string.
  "torch_dtype": torch.bfloat16,
  "transformers_version": "4.40.0.dev0",
  "use_cache": True,
  "vocab_size": 128256
}

In [9]:
from transformer_ops import Transformer

In [10]:
# Instantiate the Transformer with the config dict, the checkpoint parser and
# the target device.
model = Transformer(config, model_parser, device=device)

In [11]:
# Sequence-length budget for generation.
# The maximum sequence length is read from the config so the two values cannot
# drift apart (the config cell sets it to 2048).
max_seq_len = config["max_position_embeddings"]
max_gen_len = 32

# Length of every prompt row, computed once.
prompt_lens = [len(t) for t in input_ids]
min_prompt_len = min(prompt_lens)
max_prompt_len = max(prompt_lens)
assert max_prompt_len <= max_seq_len, (max_prompt_len, max_seq_len)

# Prompt plus generated tokens, capped at the maximum sequence length.
total_len = min(max_seq_len, max_gen_len + max_prompt_len)

In [12]:
# The EOS id doubles as the padding value for the token buffer.
pad_id = tokenizer.eos_token_id
# One buffer row per prompt; derived from the input instead of hard-coding 1 so
# that the fill loop below cannot index past the buffer for larger batches.
batch_size = len(input_ids)
prev_pos = 0

# Buffer of shape (batch_size, total_len): prompt tokens first, padding after.
tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(input_ids):
    tokens[k, : len(t)] = t

# Per-row flag set once a stop token has been generated for that row.
eos_reached = torch.zeros(batch_size, dtype=torch.bool, device=device)
# True where the buffer holds a prompt token (anything that is not padding).
input_text_mask = tokens != pad_id

In [13]:
# Token ids that end generation for a row once produced. The id 13 is taken
# from the inline note below (it maps to "."); it is not re-derived from the
# tokenizer here.
stop_tokens = torch.tensor([13], device=device) # 13=.

In [14]:
from tqdm import tqdm

In [15]:
# Greedy decoding over the pre-filled `tokens` buffer.
#
# Reset the per-run state first so that re-running this cell starts again from
# the prompt instead of continuing from whatever the previous run left behind.
prev_pos = 0
eos_reached = torch.zeros_like(eos_reached)
tokens[~input_text_mask] = pad_id

# Generation needs no gradients, so run the whole loop in inference mode.
with torch.inference_mode():
    for cur_pos in tqdm(range(min_prompt_len, total_len)):
        # Only the tokens added since the previous step are passed in, along
        # with the position at which that window starts.
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        # assume temperature 0: pick the highest-scoring token at the last position
        next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        print(f"Decoded token={tokenizer.decode(next_token)}")
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        # A row is finished when it emits a stop token outside its prompt.
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
            torch.isin(next_token, stop_tokens)
        )
        prev_pos = cur_pos
        if eos_reached.all():
            break

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  3%|█████▎                                                                                                                                                                     | 1/32 [00:21<10:55, 21.16s/it]

Decoded token= to


  6%|██████████▋                                                                                                                                                                | 2/32 [00:29<06:52, 13.76s/it]

Decoded token= be


  9%|████████████████                                                                                                                                                           | 3/32 [00:38<05:27, 11.30s/it]

Decoded token= happy


  9%|████████████████                                                                                                                                                           | 3/32 [00:46<07:31, 15.55s/it]

Decoded token=.





In [60]:
# Leftover inspection of the token window between prev_pos and cur_pos.
# NOTE(review): the execution count (60) is out of sequence with the cells
# above, so the output shown reflects kernel state that a fresh
# Restart & Run All may not reproduce.
tokens[:, prev_pos:cur_pos]

tensor([[2303]])

AttributeError: 'PreTrainedTokenizerFast' object has no attribute 'stop_tokens'

In [None]:
@torch.inference_mode()
def generate(
    self,
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    logprobs: bool = False,
    echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
    """
    Generate token sequences that continue the given prompts.

    The function is written as a method: ``self`` must expose ``self.model``
    (with ``params.max_batch_size``, ``params.max_seq_len`` and ``forward``)
    and ``self.tokenizer`` (with ``pad_id`` and ``stop_tokens``).

    Args:
        prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
        max_gen_len (int): Maximum number of tokens generated after the longest prompt.
        temperature (float, optional): Sampling temperature. A value of 0 (or below) selects the arg-max token
            instead of sampling. Defaults to 0.6.
        top_p (float, optional): Top-p probability threshold passed to ``sample_top_p``. Defaults to 0.9.
        logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
        echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

    Returns:
        Tuple[List[List[int]], Optional[List[List[float]]]]: The generated token sequences (each cut before its first
        stop token) and, if logprobs is True, the matching token log probabilities; otherwise None.

    Raises:
        AssertionError: If the batch is larger than ``params.max_batch_size`` or a prompt is longer than
            ``params.max_seq_len``.

    Note:
        ``sample_top_p`` is not defined in the visible part of this notebook; it must be available before this
        function is called with ``temperature > 0``.
        Buffers are allocated on the model's device when it can be determined, otherwise on CUDA if available
        and on CPU if not.
    """
    params = self.model.params
    bsz = len(prompt_tokens)
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    min_prompt_len = min(len(t) for t in prompt_tokens)
    max_prompt_len = max(len(t) for t in prompt_tokens)
    assert max_prompt_len <= params.max_seq_len
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

    # Every buffer below lives on the same device, so no tensor op mixes devices.
    device = _resolve_device(self.model)

    pad_id = self.tokenizer.pad_id
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
    if logprobs:
        token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

    prev_pos = 0
    eos_reached = torch.tensor([False] * bsz, device=device)
    # True where the buffer holds a prompt token rather than padding.
    input_text_mask = tokens != pad_id
    if min_prompt_len == total_len:
        # No room to generate: one forward pass over the whole buffer, whose
        # negated cross-entropy is stored as the token log probabilities.
        logits = self.model.forward(tokens, prev_pos)
        token_logprobs = -F.cross_entropy(
            input=logits.transpose(1, 2),
            target=tokens,
            reduction="none",
            ignore_index=pad_id,
        )

    stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), device=device)

    for cur_pos in range(min_prompt_len, total_len):
        # Only the tokens added since the previous step are fed to the model.
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)

        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        if logprobs:
            token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens[:, prev_pos + 1 : cur_pos + 1],
                reduction="none",
                ignore_index=pad_id,
            )
        # A row is finished when it emits a stop token outside its prompt.
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
            torch.isin(next_token, stop_tokens)
        )
        prev_pos = cur_pos
        if eos_reached.all():
            break

    if logprobs:
        token_logprobs = token_logprobs.tolist()
    out_tokens, out_logprobs = [], []
    for i, toks in enumerate(tokens.tolist()):
        # cut to max gen len
        start = 0 if echo else len(prompt_tokens[i])
        toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
        probs = None
        if logprobs:
            probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
        # cut to after eos tok if any
        for stop_token in self.tokenizer.stop_tokens:
            try:
                eos_idx = toks.index(stop_token)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            except ValueError:
                # This stop token does not occur in the sequence.
                pass
        out_tokens.append(toks)
        out_logprobs.append(probs)
    return (out_tokens, out_logprobs if logprobs else None)


def _resolve_device(model) -> torch.device:
    """Return the device on which generation buffers should be allocated.

    Uses the device of the model's first parameter when the model exposes a
    callable ``parameters``; otherwise CUDA when available, else CPU.
    """
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first_param = next(iter(parameters()), None)
        if first_param is not None:
            return first_param.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")