In [1]:
from pprint import pprint
from parsers import ModelParser

import torch
from transformers import AutoTokenizer
from utils.utils import load_json

model_parser = ModelParser([
    "./Phi-3-mini-4k-instruct/model-00001-of-00002.safetensors",
    "./Phi-3-mini-4k-instruct/model-00002-of-00002.safetensors",
])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_parser.tensor_names

['model.embed_tokens.weight',
 'model.layers.0.input_layernorm.weight',
 'model.layers.0.mlp.down_proj.weight',
 'model.layers.0.mlp.gate_up_proj.weight',
 'model.layers.0.post_attention_layernorm.weight',
 'model.layers.0.self_attn.o_proj.weight',
 'model.layers.0.self_attn.qkv_proj.weight',
 'model.layers.1.input_layernorm.weight',
 'model.layers.1.mlp.down_proj.weight',
 'model.layers.1.mlp.gate_up_proj.weight',
 'model.layers.1.post_attention_layernorm.weight',
 'model.layers.1.self_attn.o_proj.weight',
 'model.layers.1.self_attn.qkv_proj.weight',
 'model.layers.10.input_layernorm.weight',
 'model.layers.10.mlp.down_proj.weight',
 'model.layers.10.mlp.gate_up_proj.weight',
 'model.layers.10.post_attention_layernorm.weight',
 'model.layers.10.self_attn.o_proj.weight',
 'model.layers.10.self_attn.qkv_proj.weight',
 'model.layers.11.input_layernorm.weight',
 'model.layers.11.mlp.down_proj.weight',
 'model.layers.11.mlp.gate_up_proj.weight',
 'model.layers.11.post_attention_layernorm.w

## Prepare text and embeddings

In [3]:
device="mps"
base_model_dir = "./Phi-3-mini-4k-instruct"
config = load_json(f"./{base_model_dir}/config.json")

In [4]:
tokenizer = AutoTokenizer.from_pretrained(base_model_dir)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
def text_to_ids(text, tokenizer):
    if type(text) != list: text = [text]
    input_ids = tokenizer(
        text,
        #return_tensors="pt",
        #padding=True
    )["input_ids"]
    return input_ids

## Forward passes

In [6]:
from ops.transformer_ops import Transformer

In [7]:
model_dir = "PHI3-MINI-4K-PKL-int8"
model = Transformer(model_dir, config, device=device)

In [16]:
def sample_top_p(probs, p):
    """
    Taken from: https://github.com/meta-llama/llama3/blob/main/llama/generation.py
    
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

def generate_text(model, tokenizer, input_ids, max_gen_len, temperature=0.6, top_p=0.9, stop_tokens_ids=None, streaming=False, echo=False):
    """
    If temperature > 0, then top_p is used for sampling.

    echo: whether to output also input tokens or not.
    """
    device=model.device
    max_seq_len = model.max_seq_len
    min_prompt_len = min(len(t) for t in input_ids)
    max_prompt_len = max(len(t) for t in input_ids)
    assert max_prompt_len <= max_seq_len
    total_len = min(max_seq_len, max_gen_len + max_prompt_len)

    pad_id = tokenizer.eos_token_id
    batch_size = len(input_ids)
    prev_pos = 0
    
    tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
    for k, t in enumerate(input_ids):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
    
    eos_reached = torch.tensor([False] * batch_size, device=device)
    input_text_mask = tokens != pad_id

    if stop_tokens_ids == None:
        stop_tokens = torch.tensor([13], device="cpu") # 13
        #stop_tokens = torch.tensor(list(tokenizer.stop_tokens))
    else:
        stop_tokens = torch.tensor(stop_tokens_ids, device="cpu")

    tokens_output = []

    for cur_pos in range(min_prompt_len, total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        decoded_token = tokenizer.decode(next_token)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        """
        Needs to be on CPU:
        NotImplementedError: The operator 'aten::isin.Tensor_Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
        """
        is_in = torch.isin(next_token.cpu(), stop_tokens).to(device)
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
            is_in
        )
        prev_pos = cur_pos
        if all(eos_reached):
            break

    tokens_output = []
    for idx, generated_tokens in enumerate(tokens.tolist()):
        current_prompt_len = len(input_ids[idx])
        start_pos = 0 if echo else current_prompt_len
        generated_tokens = generated_tokens[start_pos: current_prompt_len + max_gen_len]
        for stop_token in stop_tokens_ids:
            try:
                idx_of_stop_token = generated_tokens.index(stop_token)
                generated_tokens = generated_tokens[:idx_of_stop_token]
            except ValueError:
                pass
        tokens_output.append(generated_tokens)
    
    return [tokenizer.decode(generated_tokens) for generated_tokens in tokens_output]

In [17]:
import time

In [18]:
texts = ["<|user|>I am going to Paris, what should I see?<|end|><|assistant|>"]
input_ids = text_to_ids(texts, tokenizer)

In [None]:
start_time = time.time()
outputs = generate_text(model, tokenizer, input_ids, max_gen_len=128, stop_tokens_ids=terminators)
delta_time = time.time() - start_time
print([text+output for text, output in zip(texts, outputs)])
outputs_total_tokens = sum([len(output) for output in outputs])
print(f"Generation took {delta_time} seconds, {delta_time/outputs_total_tokens} tokens/s.")